fix: correctly check user state (#8631)

# Which Problems Are Solved

ZITADEL's user account deactivation mechanism did not work correctly
with service accounts. Deactivated service accounts retained the ability
to request tokens, which could lead to unauthorized access to
applications and resources.

# How the Problems Are Solved

Additionally to checking the user state on the session API and login UI,
the state is checked on all oidc session methods resulting in a new
token or when returning the user information (userinfo, introspection,
id_token / access_token and saml attributes)
This commit is contained in:
Livio Spring 2024-09-17 15:21:49 +02:00 committed by GitHub
parent ca1914e235
commit 5b40af79f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 520 additions and 33 deletions

View File

@ -330,6 +330,9 @@ func (o *OPStorage) setUserinfo(ctx context.Context, userInfo *oidc.UserInfo, us
if err != nil { if err != nil {
return err return err
} }
if user.State != domain.UserStateActive {
return zerrors.ThrowUnauthenticated(nil, "OIDC-S3tha", "Errors.Users.NotActive")
}
var allRoles bool var allRoles bool
roles := make([]string, 0) roles := make([]string, 0)
for _, scope := range scopes { for _, scope := range scopes {

View File

@ -24,6 +24,9 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
machine, name, clientID, clientSecret, err := Instance.CreateOIDCCredentialsClient(CTX) machine, name, clientID, clientSecret, err := Instance.CreateOIDCCredentialsClient(CTX)
require.NoError(t, err) require.NoError(t, err)
_, _, clientIDInactive, clientSecretInactive, err := Instance.CreateOIDCCredentialsClientInactive(CTX)
require.NoError(t, err)
type claims struct { type claims struct {
name string name string
username string username string
@ -71,6 +74,13 @@ func TestServer_ClientCredentialsExchange(t *testing.T) {
scope: []string{oidc.ScopeOpenID}, scope: []string{oidc.ScopeOpenID},
wantErr: true, wantErr: true,
}, },
{
name: "inactive machine user error",
clientID: clientIDInactive,
clientSecret: clientSecretInactive,
scope: []string{oidc.ScopeOpenID},
wantErr: true,
},
{ {
name: "wrong secret error", name: "wrong secret error",
clientID: clientID, clientID: clientID,

View File

@ -66,7 +66,10 @@ func (s *Server) UserInfo(ctx context.Context, r *op.Request[oidc.UserInfoReques
false, false,
)(ctx, true, domain.TriggerTypePreUserinfoCreation) )(ctx, true, domain.TriggerTypePreUserinfoCreation)
if err != nil { if err != nil {
return nil, err if !zerrors.IsNotFound(err) {
return nil, err
}
return nil, op.NewStatusError(oidc.ErrAccessDenied().WithDescription("no active user").WithParent(err).WithReturnParentToClient(authz.GetFeatures(ctx).DebugOIDCParentError), http.StatusUnauthorized)
} }
return op.NewResponse(userInfo), nil return op.NewResponse(userInfo), nil
} }

View File

@ -131,6 +131,9 @@ func (p *Storage) SetUserinfoWithUserID(ctx context.Context, applicationID strin
if err != nil { if err != nil {
return err return err
} }
if user.State != domain.UserStateActive {
return zerrors.ThrowPreconditionFailed(nil, "SAML-S3gFd", "Errors.User.NotActive")
}
userGrants, err := p.getGrants(ctx, userID, applicationID) userGrants, err := p.getGrants(ctx, userID, applicationID)
if err != nil { if err != nil {
@ -157,6 +160,9 @@ func (p *Storage) SetUserinfoWithLoginName(ctx context.Context, userinfo models.
if err != nil { if err != nil {
return err return err
} }
if user.State != domain.UserStateActive {
return zerrors.ThrowPreconditionFailed(nil, "SAML-FJ262", "Errors.User.NotActive")
}
setUserinfo(user, userinfo, attributes, map[string]*customAttribute{}) setUserinfo(user, userinfo, attributes, map[string]*customAttribute{})
return nil return nil

View File

@ -144,7 +144,7 @@ func (c *Commands) CreateOIDCSessionFromDeviceAuth(ctx context.Context, deviceCo
return nil, DeviceAuthStateError(deviceAuthModel.State) return nil, DeviceAuthStateError(deviceAuthModel.State)
} }
cmd, err := c.newOIDCSessionAddEvents(ctx, deviceAuthModel.UserOrgID) cmd, err := c.newOIDCSessionAddEvents(ctx, deviceAuthModel.UserID, deviceAuthModel.UserOrgID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -126,7 +126,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
pushErr := errors.New("pushErr") pushErr := errors.New("pushErr")
type fields struct { type fields struct {
eventstore *eventstore.Eventstore eventstore func(*testing.T) *eventstore.Eventstore
} }
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -149,7 +149,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
{ {
name: "not found error", name: "not found error",
fields: fields{ fields: fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter(), expectFilter(),
), ),
}, },
@ -169,7 +169,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
{ {
name: "push error", name: "push error",
fields: fields{ fields: fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter(eventFromEventPusherWithInstanceID( expectFilter(eventFromEventPusherWithInstanceID(
"instance1", "instance1",
deviceauth.NewAddedEvent( deviceauth.NewAddedEvent(
@ -211,7 +211,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
{ {
name: "success", name: "success",
fields: fields{ fields: fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter(eventFromEventPusherWithInstanceID( expectFilter(eventFromEventPusherWithInstanceID(
"instance1", "instance1",
deviceauth.NewAddedEvent( deviceauth.NewAddedEvent(
@ -256,7 +256,7 @@ func TestCommands_ApproveDeviceAuth(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &Commands{ c := &Commands{
eventstore: tt.fields.eventstore, eventstore: tt.fields.eventstore(t),
} }
gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.userID, tt.args.userOrgID, tt.args.authMethods, tt.args.authTime, tt.args.preferredLanguage, tt.args.userAgent, tt.args.sessionID) gotDetails, err := c.ApproveDeviceAuth(tt.args.ctx, tt.args.id, tt.args.userID, tt.args.userOrgID, tt.args.authMethods, tt.args.authTime, tt.args.preferredLanguage, tt.args.userAgent, tt.args.sessionID)
require.ErrorIs(t, err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
@ -271,7 +271,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
pushErr := errors.New("pushErr") pushErr := errors.New("pushErr")
type fields struct { type fields struct {
eventstore *eventstore.Eventstore eventstore func(*testing.T) *eventstore.Eventstore
} }
type args struct { type args struct {
ctx context.Context ctx context.Context
@ -288,7 +288,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
{ {
name: "not found error", name: "not found error",
fields: fields{ fields: fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter(), expectFilter(),
), ),
}, },
@ -298,7 +298,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
{ {
name: "push error", name: "push error",
fields: fields{ fields: fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter(eventFromEventPusherWithInstanceID( expectFilter(eventFromEventPusherWithInstanceID(
"instance1", "instance1",
deviceauth.NewAddedEvent( deviceauth.NewAddedEvent(
@ -323,7 +323,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
{ {
name: "success/denied", name: "success/denied",
fields: fields{ fields: fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter(eventFromEventPusherWithInstanceID( expectFilter(eventFromEventPusherWithInstanceID(
"instance1", "instance1",
deviceauth.NewAddedEvent( deviceauth.NewAddedEvent(
@ -350,7 +350,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
{ {
name: "success/expired", name: "success/expired",
fields: fields{ fields: fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter(eventFromEventPusherWithInstanceID( expectFilter(eventFromEventPusherWithInstanceID(
"instance1", "instance1",
deviceauth.NewAddedEvent( deviceauth.NewAddedEvent(
@ -378,7 +378,7 @@ func TestCommands_CancelDeviceAuth(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &Commands{ c := &Commands{
eventstore: tt.fields.eventstore, eventstore: tt.fields.eventstore(t),
} }
gotDetails, err := c.CancelDeviceAuth(tt.args.ctx, tt.args.id, tt.args.reason) gotDetails, err := c.CancelDeviceAuth(tt.args.ctx, tt.args.id, tt.args.reason)
require.ErrorIs(t, err, tt.wantErr) require.ErrorIs(t, err, tt.wantErr)
@ -586,6 +586,69 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
}, },
wantErr: DeviceAuthStateError(domain.DeviceAuthStateDone), wantErr: DeviceAuthStateError(domain.DeviceAuthStateDone),
}, },
{
name: "user not active",
fields: fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewAddedEvent(
ctx,
deviceauth.NewAggregate("123", "instance1"),
"clientID", "123", "456", time.Now().Add(-time.Minute),
[]string{"openid", "offline_access"},
[]string{"audience"}, false,
),
),
eventFromEventPusherWithInstanceID(
"instance1",
deviceauth.NewApprovedEvent(ctx,
deviceauth.NewAggregate("123", "instance1"),
"userID", "org1",
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
testNow, &language.Afrikaans, &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
"sessionID",
),
),
),
expectFilter(
user.NewHumanAddedEvent(
ctx,
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.English,
domain.GenderUnspecified,
"email",
false,
),
user.NewUserDeactivatedEvent(
ctx,
&user.NewAggregate("userID", "org1").Aggregate,
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx,
"123",
},
wantErr: zerrors.ThrowPreconditionFailed(nil, "OIDCS-kj3g2", "Errors.User.NotActive"),
},
{ {
name: "approved, success", name: "approved, success",
fields: fields{ fields: fields{
@ -617,6 +680,21 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
), ),
), ),
), ),
expectFilter(
user.NewHumanAddedEvent(
ctx,
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.English,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -699,6 +777,21 @@ func TestCommands_CreateOIDCSessionFromDeviceAuth(t *testing.T) {
), ),
), ),
), ),
expectFilter(
user.NewHumanAddedEvent(
ctx,
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.English,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,

View File

@ -80,7 +80,7 @@ func (c *Commands) CreateOIDCSessionFromAuthRequest(ctx context.Context, authReq
return nil, "", err return nil, "", err
} }
cmd, err := c.newOIDCSessionAddEvents(ctx, sessionModel.UserResourceOwner) cmd, err := c.newOIDCSessionAddEvents(ctx, sessionModel.UserID, sessionModel.UserResourceOwner)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
@ -141,7 +141,7 @@ func (c *Commands) CreateOIDCSession(ctx context.Context,
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
cmd, err := c.newOIDCSessionAddEvents(ctx, resourceOwner) cmd, err := c.newOIDCSessionAddEvents(ctx, userID, resourceOwner)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -265,7 +265,14 @@ func (c *Commands) RevokeOIDCSessionToken(ctx context.Context, token, clientID s
return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate)) return c.pushAppendAndReduce(ctx, writeModel, oidcsession.NewAccessTokenRevokedEvent(ctx, writeModel.aggregate))
} }
func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, resourceOwner string, pending ...eventstore.Command) (*OIDCSessionEvents, error) { func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, userID, resourceOwner string, pending ...eventstore.Command) (*OIDCSessionEvents, error) {
userStateModel, err := c.userStateWriteModel(ctx, userID)
if err != nil {
return nil, err
}
if !userStateModel.UserState.IsEnabled() {
return nil, zerrors.ThrowPreconditionFailed(nil, "OIDCS-kj3g2", "Errors.User.NotActive")
}
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx) accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -281,6 +288,7 @@ func (c *Commands) newOIDCSessionAddEvents(ctx context.Context, resourceOwner st
encryptionAlg: c.keyAlgorithm, encryptionAlg: c.keyAlgorithm,
events: pending, events: pending,
oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, resourceOwner), oidcSessionWriteModel: NewOIDCSessionWriteModel(sessionID, resourceOwner),
userStateModel: userStateModel,
accessTokenLifetime: accessTokenLifetime, accessTokenLifetime: accessTokenLifetime,
refreshTokenLifeTime: refreshTokenLifeTime, refreshTokenLifeTime: refreshTokenLifeTime,
refreshTokenIdleLifetime: refreshTokenIdleLifetime, refreshTokenIdleLifetime: refreshTokenIdleLifetime,
@ -321,6 +329,13 @@ func (c *Commands) newOIDCSessionUpdateEvents(ctx context.Context, refreshToken
if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil { if err = sessionWriteModel.CheckRefreshToken(refreshTokenID); err != nil {
return nil, err return nil, err
} }
userStateWriteModel, err := c.userStateWriteModel(ctx, sessionWriteModel.UserID)
if err != nil {
return nil, err
}
if !userStateWriteModel.UserState.IsEnabled() {
return nil, zerrors.ThrowPreconditionFailed(nil, "OIDCS-J39h2", "Errors.User.NotActive")
}
accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx) accessTokenLifetime, refreshTokenLifeTime, refreshTokenIdleLifetime, err := c.tokenTokenLifetimes(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
@ -342,6 +357,7 @@ type OIDCSessionEvents struct {
encryptionAlg crypto.EncryptionAlgorithm encryptionAlg crypto.EncryptionAlgorithm
events []eventstore.Command events []eventstore.Command
oidcSessionWriteModel *OIDCSessionWriteModel oidcSessionWriteModel *OIDCSessionWriteModel
userStateModel *UserV2WriteModel
accessTokenLifetime time.Duration accessTokenLifetime time.Duration
refreshTokenLifeTime time.Duration refreshTokenLifeTime time.Duration

View File

@ -205,6 +205,103 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"), err: zerrors.ThrowPreconditionFailed(nil, "COMMAND-Flk38", "Errors.Session.NotExisting"),
}, },
}, },
{
"user not active",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusher(
authrequest.NewAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"loginClient",
"clientID",
"redirectURI",
"state",
"nonce",
[]string{"openid", "offline_access"},
[]string{"audience"},
domain.OIDCResponseTypeCode,
domain.OIDCResponseModeQuery,
&domain.OIDCCodeChallenge{
Challenge: "challenge",
Method: domain.CodeChallengeMethodS256,
},
[]domain.Prompt{domain.PromptNone},
[]string{"en", "de"},
gu.Ptr(time.Duration(0)),
gu.Ptr("loginHint"),
gu.Ptr("hintUserID"),
true,
),
),
eventFromEventPusher(
authrequest.NewCodeAddedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
),
eventFromEventPusher(
authrequest.NewSessionLinkedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate,
"sessionID",
"userID",
testNow,
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
),
),
),
expectFilter(
eventFromEventPusher(
session.NewAddedEvent(context.Background(),
&session.NewAggregate("sessionID", "instance1").Aggregate,
&domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
),
),
eventFromEventPusher(
session.NewUserCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
"userID", "org1", testNow, &language.Afrikaans),
),
eventFromEventPusher(
session.NewPasswordCheckedEvent(context.Background(), &session.NewAggregate("sessionID", "instanceID").Aggregate,
testNow),
),
),
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
user.NewUserDeactivatedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
authRequestID: "V2_authRequestID",
complianceCheck: mockAuthRequestComplianceChecker(nil),
needRefreshToken: true,
},
res{
err: zerrors.ThrowPreconditionFailed(nil, "OIDCS-kj3g2", "Errors.User.NotActive"),
},
},
{ {
"add successful", "add successful",
fields{ fields{
@ -266,6 +363,21 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
testNow), testNow),
), ),
), ),
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate), authrequest.NewCodeExchangedEvent(context.Background(), &authrequest.NewAggregate("V2_authRequestID", "instanceID").Aggregate),
@ -382,6 +494,21 @@ func TestCommands_CreateOIDCSessionFromAuthRequest(t *testing.T) {
testNow), testNow),
), ),
), ),
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -521,10 +648,81 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
}, },
wantErr: io.ErrClosedPipe, wantErr: io.ErrClosedPipe,
}, },
{
name: "not active user",
fields: fields{
eventstore: expectEventstore(
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
user.NewUserDeactivatedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args: args{
ctx: context.Background(),
userID: "userID",
resourceOwner: "org1",
clientID: "clientID",
audience: []string{"audience"},
scope: []string{"openid", "offline_access"},
authMethods: []domain.UserAuthMethodType{domain.UserAuthMethodTypePassword},
authTime: testNow,
nonce: "nonce",
preferredLanguage: &language.Afrikaans,
userAgent: &domain.UserAgent{
FingerprintID: gu.Ptr("fp1"),
IP: net.ParseIP("1.2.3.4"),
Description: gu.Ptr("firefox"),
Header: http.Header{"foo": []string{"bar"}},
},
reason: domain.TokenReasonAuthRequest,
actor: &domain.TokenActor{
UserID: "user2",
Issuer: "foo.com",
},
needRefreshToken: false,
},
wantErr: zerrors.ThrowPreconditionFailed(nil, "OIDCS-kj3g2", "Errors.User.NotActive"),
},
{ {
name: "without refresh token", name: "without refresh token",
fields: fields{ fields: fields{
eventstore: expectEventstore( eventstore: expectEventstore(
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -606,6 +804,21 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
name: "with refresh token", name: "with refresh token",
fields: fields{ fields: fields{
eventstore: expectEventstore( eventstore: expectEventstore(
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -689,6 +902,21 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
name: "with sessionID", name: "with sessionID",
fields: fields{ fields: fields{
eventstore: expectEventstore( eventstore: expectEventstore(
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -772,6 +1000,21 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
name: "impersonation not allowed", name: "impersonation not allowed",
fields: fields{ fields: fields{
eventstore: expectEventstore( eventstore: expectEventstore(
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
), ),
idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID"), idGenerator: mock.NewIDGeneratorExpectIDs(t, "oidcSessionID"),
@ -813,6 +1056,21 @@ func TestCommands_CreateOIDCSession(t *testing.T) {
name: "impersonation allowed", name: "impersonation allowed",
fields: fields{ fields: fields{
eventstore: expectEventstore( eventstore: expectEventstore(
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
user.NewUserImpersonatedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "clientID", &domain.TokenActor{ user.NewUserImpersonatedEvent(context.Background(), &user.NewAggregate("userID", "org1").Aggregate, "clientID", &domain.TokenActor{
@ -1067,6 +1325,63 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
err: zerrors.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"), err: zerrors.ThrowPreconditionFailed(nil, "OIDCS-3jt2w", "Errors.OIDCSession.RefreshTokenInvalid"),
}, },
}, },
{
"user not active",
fields{
eventstore: expectEventstore(
expectFilter(
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"userID", "org1", "sessionID", "clientID", []string{"audience"}, []string{"openid", "profile", "offline_access"},
[]domain.UserAuthMethodType{domain.UserAuthMethodTypePassword}, testNow, "nonce", &language.Afrikaans,
&domain.UserAgent{FingerprintID: gu.Ptr("browserFP")},
),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"at_accessTokenID", []string{"openid", "profile", "offline_access"}, time.Hour, domain.TokenReasonAuthRequest, nil),
),
eventFromEventPusherWithCreationDateNow(
oidcsession.NewRefreshTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
"rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
),
),
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
user.NewUserDeactivatedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
),
),
),
idGenerator: mock.NewIDGeneratorExpectIDs(t),
defaultAccessTokenLifetime: time.Hour,
defaultRefreshTokenLifetime: 7 * 24 * time.Hour,
defaultRefreshTokenIdleLifetime: 24 * time.Hour,
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
},
args{
ctx: authz.WithInstanceID(context.Background(), "instanceID"),
refreshToken: "VjJfb2lkY1Nlc3Npb25JRC1ydF9yZWZyZXNoVG9rZW5JRDp1c2VySUQ", //V2_oidcSessionID:rt_refreshTokenID:userID
scope: []string{"openid", "offline_access"},
complianceCheck: mockRefreshTokenComplianceChecker(nil),
},
res{
err: zerrors.ThrowPreconditionFailed(nil, "OIDCS-J39h2", "Errors.User.NotActive"),
},
},
{ {
"refresh successful", "refresh successful",
fields{ fields{
@ -1088,6 +1403,21 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
"rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour), "rt_refreshTokenID", 7*24*time.Hour, 24*time.Hour),
), ),
), ),
expectFilter(
user.NewHumanAddedEvent(
context.Background(),
&user.NewAggregate("userID", "org1").Aggregate,
"username",
"firstname",
"lastname",
"nickname",
"displayname",
language.Afrikaans,
domain.GenderUnspecified,
"email",
false,
),
),
expectFilter(), // token lifetime expectFilter(), // token lifetime
expectPush( expectPush(
oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAccessTokenAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1153,7 +1483,7 @@ func TestCommands_ExchangeOIDCSessionRefreshAndAccessToken(t *testing.T) {
func TestCommands_OIDCSessionByRefreshToken(t *testing.T) { func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
type fields struct { type fields struct {
eventstore *eventstore.Eventstore eventstore func(*testing.T) *eventstore.Eventstore
idGenerator id.Generator idGenerator id.Generator
defaultAccessTokenLifetime time.Duration defaultAccessTokenLifetime time.Duration
defaultRefreshTokenLifetime time.Duration defaultRefreshTokenLifetime time.Duration
@ -1177,7 +1507,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
{ {
"invalid refresh token format error", "invalid refresh token format error",
fields{ fields{
eventstore: eventstoreExpect(t), eventstore: expectEventstore(),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
}, },
args{ args{
@ -1191,7 +1521,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
{ {
"inactive session error", "inactive session error",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter(), expectFilter(),
), ),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
@ -1207,7 +1537,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
{ {
"invalid refresh token error", "invalid refresh token error",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1235,7 +1565,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
{ {
"expired refresh token error", "expired refresh token error",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1267,7 +1597,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
{ {
"get successful", "get successful",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusherWithCreationDateNow( eventFromEventPusherWithCreationDateNow(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1316,7 +1646,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &Commands{ c := &Commands{
eventstore: tt.fields.eventstore, eventstore: tt.fields.eventstore(t),
idGenerator: tt.fields.idGenerator, idGenerator: tt.fields.idGenerator,
defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime, defaultAccessTokenLifetime: tt.fields.defaultAccessTokenLifetime,
defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime, defaultRefreshTokenLifetime: tt.fields.defaultRefreshTokenLifetime,
@ -1348,7 +1678,7 @@ func TestCommands_OIDCSessionByRefreshToken(t *testing.T) {
func TestCommands_RevokeOIDCSessionToken(t *testing.T) { func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
type fields struct { type fields struct {
eventstore *eventstore.Eventstore eventstore func(*testing.T) *eventstore.Eventstore
keyAlgorithm crypto.EncryptionAlgorithm keyAlgorithm crypto.EncryptionAlgorithm
} }
type args struct { type args struct {
@ -1368,7 +1698,7 @@ func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
{ {
"invalid token", "invalid token",
fields{ fields{
eventstore: eventstoreExpect(t), eventstore: expectEventstore(),
keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)), keyAlgorithm: crypto.CreateMockEncryptionAlg(gomock.NewController(t)),
}, },
args{ args{
@ -1382,7 +1712,7 @@ func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
{ {
"refresh_token inactive", "refresh_token inactive",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1407,7 +1737,7 @@ func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
{ {
"refresh_token invalid client", "refresh_token invalid client",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1432,7 +1762,7 @@ func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
{ {
"refresh_token revoked", "refresh_token revoked",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1468,7 +1798,7 @@ func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
{ {
"access_token inactive session", "access_token inactive session",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1493,7 +1823,7 @@ func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
{ {
"access_token invalid client", "access_token invalid client",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1518,7 +1848,7 @@ func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
{ {
"access_token revoked", "access_token revoked",
fields{ fields{
eventstore: eventstoreExpect(t, eventstore: expectEventstore(
expectFilter( expectFilter(
eventFromEventPusher( eventFromEventPusher(
oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate, oidcsession.NewAddedEvent(context.Background(), &oidcsession.NewAggregate("V2_oidcSessionID", "org1").Aggregate,
@ -1555,7 +1885,7 @@ func TestCommands_RevokeOIDCSessionToken(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := &Commands{ c := &Commands{
eventstore: tt.fields.eventstore, eventstore: tt.fields.eventstore(t),
keyAlgorithm: tt.fields.keyAlgorithm, keyAlgorithm: tt.fields.keyAlgorithm,
} }
err := c.RevokeOIDCSessionToken(tt.args.ctx, tt.args.token, tt.args.clientID) err := c.RevokeOIDCSessionToken(tt.args.ctx, tt.args.token, tt.args.clientID)

View File

@ -22,6 +22,7 @@ import (
"github.com/zitadel/zitadel/pkg/grpc/authn" "github.com/zitadel/zitadel/pkg/grpc/authn"
"github.com/zitadel/zitadel/pkg/grpc/management" "github.com/zitadel/zitadel/pkg/grpc/management"
"github.com/zitadel/zitadel/pkg/grpc/user" "github.com/zitadel/zitadel/pkg/grpc/user"
user_v2 "github.com/zitadel/zitadel/pkg/grpc/user/v2"
) )
func (i *Instance) CreateOIDCClient(ctx context.Context, redirectURI, logoutRedirectURI, projectID string, appType app.OIDCAppType, authMethod app.OIDCAuthMethodType, devMode bool, grantTypes ...app.OIDCGrantType) (*management.AddOIDCAppResponse, error) { func (i *Instance) CreateOIDCClient(ctx context.Context, redirectURI, logoutRedirectURI, projectID string, appType app.OIDCAppType, authMethod app.OIDCAuthMethodType, devMode bool, grantTypes ...app.OIDCGrantType) (*management.AddOIDCAppResponse, error) {
@ -355,6 +356,31 @@ func (i *Instance) CreateOIDCCredentialsClient(ctx context.Context) (machine *ma
return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil
} }
func (i *Instance) CreateOIDCCredentialsClientInactive(ctx context.Context) (machine *management.AddMachineUserResponse, name, clientID, clientSecret string, err error) {
name = gofakeit.Username()
machine, err = i.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
Name: name,
UserName: name,
AccessTokenType: user.AccessTokenType_ACCESS_TOKEN_TYPE_JWT,
})
if err != nil {
return nil, "", "", "", err
}
secret, err := i.Client.Mgmt.GenerateMachineSecret(ctx, &management.GenerateMachineSecretRequest{
UserId: machine.GetUserId(),
})
if err != nil {
return nil, "", "", "", err
}
_, err = i.Client.UserV2.DeactivateUser(ctx, &user_v2.DeactivateUserRequest{
UserId: machine.GetUserId(),
})
if err != nil {
return nil, "", "", "", err
}
return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil
}
func (i *Instance) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) { func (i *Instance) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) {
name = gofakeit.Username() name = gofakeit.Username()
machine, err = i.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{ machine, err = i.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{

View File

@ -2,7 +2,7 @@ with usr as (
select u.id, u.creation_date, u.change_date, u.sequence, u.state, u.resource_owner, u.username, n.login_name as preferred_login_name select u.id, u.creation_date, u.change_date, u.sequence, u.state, u.resource_owner, u.username, n.login_name as preferred_login_name
from projections.users13 u from projections.users13 u
left join projections.login_names3 n on u.id = n.user_id and u.instance_id = n.instance_id left join projections.login_names3 n on u.id = n.user_id and u.instance_id = n.instance_id
where u.id = $1 where u.id = $1 and u.state = 1 -- only allow active users
and u.instance_id = $2 and u.instance_id = $2
and n.is_primary = true and n.is_primary = true
), ),