From af4e0484d063a1b6eef607488692a1daffb91630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Thu, 18 Jan 2024 08:10:49 +0200 Subject: [PATCH] fix: uniform oidc errors (#7237) * fix: uniform oidc errors sanitize oidc error reporting when passing package boundary towards oidc. * add should TriggerBulk in get audiences for auth request * upgrade to oidc 3.10.1 * provisional oidc upgrade to error branch * pin oidc 3.10.2 --- go.mod | 2 +- go.sum | 4 +- internal/api/oidc/access_token.go | 6 +- internal/api/oidc/auth_request.go | 58 +++++++++++++---- .../api/oidc/auth_request_integration_test.go | 26 ++++---- internal/api/oidc/client.go | 55 +++++++++++++--- internal/api/oidc/client_integration_test.go | 4 +- internal/api/oidc/device_auth.go | 10 +-- internal/api/oidc/error.go | 49 +++++++++++++++ internal/api/oidc/error_test.go | 63 +++++++++++++++++++ internal/api/oidc/introspect.go | 5 +- internal/api/oidc/jwt-profile.go | 9 ++- internal/api/oidc/key.go | 5 +- internal/api/oidc/oidc_integration_test.go | 14 ++--- internal/api/oidc/server.go | 7 ++- .../eventsourcing/eventstore/auth_request.go | 2 +- internal/query/app.go | 9 ++- 17 files changed, 267 insertions(+), 61 deletions(-) create mode 100644 internal/api/oidc/error.go create mode 100644 internal/api/oidc/error_test.go diff --git a/go.mod b/go.mod index f54ca36462..7a4a364e25 100644 --- a/go.mod +++ b/go.mod @@ -61,7 +61,7 @@ require ( github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203 github.com/ttacon/libphonenumber v1.2.1 github.com/zitadel/logging v0.5.0 - github.com/zitadel/oidc/v3 v3.10.0 + github.com/zitadel/oidc/v3 v3.10.2 github.com/zitadel/passwap v0.5.0 github.com/zitadel/saml v0.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1 diff --git a/go.sum b/go.sum index e68db76780..925f9cd7df 100644 --- a/go.sum +++ b/go.sum @@ -782,8 +782,8 @@ github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8= github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= github.com/zitadel/logging v0.5.0 h1:Kunouvqse/efXy4UDvFw5s3vP+Z4AlHo3y8wF7stXHA= github.com/zitadel/logging v0.5.0/go.mod h1:IzP5fzwFhzzyxHkSmfF8dsyqFsQRJLLcQmwhIBzlGsE= -github.com/zitadel/oidc/v3 v3.10.0 h1:qAGlw6FGQEpkWya8tT03P6pU4AHNrZ0Kfyxmwsd4am0= -github.com/zitadel/oidc/v3 v3.10.0/go.mod h1:nfjWH8ps4B7T0JGJyLLOIUlhr0Z4becyGKui/sXYpA8= +github.com/zitadel/oidc/v3 v3.10.2 h1:nowZrpOBR4tdIlYXE8/l5Nl84QDYwyHpccIE1l2OAd4= +github.com/zitadel/oidc/v3 v3.10.2/go.mod h1:nfjWH8ps4B7T0JGJyLLOIUlhr0Z4becyGKui/sXYpA8= github.com/zitadel/passwap v0.5.0 h1:kFMoRyo0GnxtOz7j9+r/CsRwSCjHGRaAKoUe69NwPvs= github.com/zitadel/passwap v0.5.0/go.mod h1:uqY7D3jqdTFcKsW0Q3Pcv5qDMmSHpVTzUZewUKC1KZA= github.com/zitadel/saml v0.1.3 h1:LI4DOCVyyU1qKPkzs3vrGcA5J3H4pH3+CL9zr9ShkpM= diff --git a/internal/api/oidc/access_token.go b/internal/api/oidc/access_token.go index 173e507530..0c957ade7a 100644 --- a/internal/api/oidc/access_token.go +++ b/internal/api/oidc/access_token.go @@ -27,20 +27,22 @@ type accessToken struct { isPAT bool } +var ErrInvalidTokenFormat = errors.New("invalid token format") + func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (*accessToken, error) { var tokenID, subject string if tokenIDSubject, err := s.Provider().Crypto().Decrypt(tkn); err == nil { split := strings.Split(tokenIDSubject, ":") if len(split) != 2 { - return nil, errors.New("invalid token format") + return nil, zerrors.ThrowPermissionDenied(ErrInvalidTokenFormat, "OIDC-rei1O", "token is not valid or has expired") } tokenID, subject = split[0], split[1] } else { verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.keySet) claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, tkn, verifier) if err != nil { - return nil, err + return nil, zerrors.ThrowPermissionDenied(err, "OIDC-Eib8e", "token is not valid or has expired") } tokenID, subject = claims.JWTID, claims.Subject } diff --git a/internal/api/oidc/auth_request.go b/internal/api/oidc/auth_request.go index 062b0e3351..fb5bd1080f 100644 --- a/internal/api/oidc/auth_request.go +++ b/internal/api/oidc/auth_request.go @@ -28,7 +28,10 @@ const ( func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() headers, _ := http_utils.HeadersFromCtx(ctx) if loginClient := headers.Get(LoginClientHeader); loginClient != "" { @@ -102,7 +105,7 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) if err != nil { return nil, err } - appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}) + appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, true) if err != nil { return nil, err } @@ -112,7 +115,10 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string) func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() if strings.HasPrefix(id, command.IDPrefixV2) { req, err := o.command.GetCurrentAuthRequest(ctx, id) @@ -135,7 +141,10 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() plainCode, err := o.decryptGrant(code) if err != nil { @@ -166,7 +175,10 @@ func (o *OPStorage) decryptGrant(grant string) (string, error) { func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() if strings.HasPrefix(id, command.IDPrefixV2) { return o.command.AddAuthRequestCode(ctx, id, code) @@ -181,14 +193,20 @@ func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err erro func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() return o.repo.DeleteAuthRequest(ctx, id) } func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() var userAgentID, applicationID, userOrgID string switch authReq := req.(type) { @@ -221,7 +239,10 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() // handle V2 request directly switch tokenReq := req.(type) { @@ -279,7 +300,10 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time, func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() plainToken, err := o.decryptGrant(refreshToken) if err != nil { @@ -307,7 +331,10 @@ func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID string) (err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() userAgentID, ok := middleware.UserAgentIDFromCtx(ctx) if !ok { logging.Error("no user agent id") @@ -331,7 +358,10 @@ func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID strin func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionRequest *op.EndSessionRequest) (redirectURI string, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() // check for the login client header // and if not provided, terminate the session using the V1 method @@ -408,6 +438,12 @@ func (o *OPStorage) revokeTokenV1(ctx context.Context, token, userID, clientID s } func (o *OPStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() + plainToken, err := o.decryptGrant(token) if err != nil { return "", "", op.ErrInvalidRefreshToken diff --git a/internal/api/oidc/auth_request_integration_test.go b/internal/api/oidc/auth_request_integration_test.go index 9f8e77688d..44a4b82e6b 100644 --- a/internal/api/oidc/auth_request_integration_test.go +++ b/internal/api/oidc/auth_request_integration_test.go @@ -51,7 +51,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) { // test code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -69,7 +69,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) { require.Error(t, err) // exchange with a used code must fail - _, err = exchangeTokens(t, clientID, code) + _, err = exchangeTokens(t, clientID, code, redirectURI) require.Error(t, err) } @@ -140,7 +140,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) { // test code exchange (expect refresh token to be returned) code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -165,7 +165,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -201,7 +201,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -244,7 +244,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -281,7 +281,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -324,7 +324,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing. // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -359,7 +359,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -391,7 +391,7 @@ func TestOPStorage_TerminateSession(t *testing.T) { // test code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -428,7 +428,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) { // test code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -472,7 +472,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) { // test code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -497,7 +497,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) { require.Error(t, err) } -func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) { +func exchangeTokens(t testing.TB, clientID, code, redirectURI string) (*oidc.Tokens[*oidc.IDTokenClaims], error) { provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI) require.NoError(t, err) diff --git a/internal/api/oidc/client.go b/internal/api/oidc/client.go index a4a4adb1b8..394ecd834b 100644 --- a/internal/api/oidc/client.go +++ b/internal/api/oidc/client.go @@ -42,7 +42,10 @@ const ( func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Client, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() client, err := o.query.GetOIDCClientByID(ctx, id, false) if err != nil { return nil, err @@ -59,7 +62,10 @@ func (o *OPStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID str func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer string) (_ *jose.JSONWebKey, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() publicKeyData, err := o.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, issuer, false) if err != nil { return nil, err @@ -75,7 +81,12 @@ func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer strin }, nil } -func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string, scopes []string) ([]string, error) { +func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string, scopes []string) (_ []string, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() user, err := o.query.GetUserByID(ctx, true, subject) if err != nil { return nil, err @@ -85,7 +96,10 @@ func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secret string) (err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() ctx = authz.SetCtxData(ctx, authz.CtxData{ UserID: oidcCtx, OrgID: oidcCtx, @@ -102,7 +116,10 @@ func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secr func (o *OPStorage) SetUserinfoFromToken(ctx context.Context, userInfo *oidc.UserInfo, tokenID, subject, origin string) (err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() if strings.HasPrefix(tokenID, command.IDPrefixV2) { token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID) @@ -129,7 +146,10 @@ func (o *OPStorage) SetUserinfoFromToken(ctx context.Context, userInfo *oidc.Use func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.UserInfo, userID, applicationID string, scopes []string) (err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() if applicationID != "" { app, err := o.query.AppByOIDCClientID(ctx, applicationID) if err != nil { @@ -159,7 +179,10 @@ func (o *OPStorage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.U func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() if strings.HasPrefix(tokenID, command.IDPrefixV2) { token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID) @@ -196,7 +219,12 @@ func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection token.CreationDate, token.Expiration) } -func (o *OPStorage) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scope []string) (op.TokenRequest, error) { +func (o *OPStorage) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scope []string) (_ op.TokenRequest, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() user, err := o.query.GetUserByLoginName(ctx, false, clientID) if err != nil { return nil, err @@ -545,6 +573,12 @@ func (o *OPStorage) userinfoFlows(ctx context.Context, user *query.User, userGra } func (o *OPStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() + roles := make([]string, 0) var allRoles bool for _, scope := range scopes { @@ -903,7 +937,10 @@ func userinfoClaims(userInfo *oidc.UserInfo) func(c *actions.FieldConfig) interf func (s *Server) VerifyClient(ctx context.Context, r *op.Request[op.ClientCredentials]) (_ op.Client, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials { return s.clientCredentialsAuth(ctx, r.Data.ClientID, r.Data.ClientSecret) diff --git a/internal/api/oidc/client_integration_test.go b/internal/api/oidc/client_integration_test.go index a20e388dca..44219c6107 100644 --- a/internal/api/oidc/client_integration_test.go +++ b/internal/api/oidc/client_integration_test.go @@ -43,7 +43,7 @@ func TestOPStorage_SetUserinfoFromToken(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -152,7 +152,7 @@ func TestServer_Introspect(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, app.GetClientId(), code) + tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) diff --git a/internal/api/oidc/device_auth.go b/internal/api/oidc/device_auth.go index 775ec6f9e3..eb4a6e5a85 100644 --- a/internal/api/oidc/device_auth.go +++ b/internal/api/oidc/device_auth.go @@ -109,14 +109,9 @@ func newDeviceAuthorizationState(d *query.DeviceAuth) *op.DeviceAuthorizationSta // As generated user codes are of low entropy, this implementation also takes care or // device authorization request cleanup, when it has been Approved, Denied or Expired. func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (state *op.DeviceAuthorizationState, err error) { - const logMsg = "get device authorization state" - logger := logging.WithFields("device_code", deviceCode) - ctx, span := tracing.NewSpan(ctx) defer func() { - if err != nil { - logger.WithError(err).Error(logMsg) - } + err = oidcError(err) span.EndWithError(err) }() @@ -124,7 +119,8 @@ func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, de if err != nil { return nil, err } - logger.SetFields( + logging.WithFields( + "device_code", deviceCode, "expires", deviceAuth.Expires, "scopes", deviceAuth.Scopes, "subject", deviceAuth.Subject, "state", deviceAuth.State, ).Debug("device authorization state") diff --git a/internal/api/oidc/error.go b/internal/api/oidc/error.go new file mode 100644 index 0000000000..9c7154092f --- /dev/null +++ b/internal/api/oidc/error.go @@ -0,0 +1,49 @@ +package oidc + +import ( + "errors" + + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + + http_util "github.com/zitadel/zitadel/internal/api/http" + "github.com/zitadel/zitadel/internal/zerrors" +) + +// oidcError ensures [*oidc.Error] and [op.StatusError] types for err. +// It must be used when an error passes the package boundary towards oidc. +// When err is already of the correct type is passed as-is. +// If the err is a Zitadel error, it is transformed with a proper HTTP status code. +// Unknown errors are treated as internal server errors. +func oidcError(err error) error { + if err == nil { + return nil + } + + var ( + sError op.StatusError + oError *oidc.Error + zError *zerrors.ZitadelError + ) + if errors.As(err, &sError) || errors.As(err, &oError) { + return err + } + + // here we are encountering an error type that is completely unknown to us. + if !errors.As(err, &zError) { + err = zerrors.ThrowInternal(err, "OIDC-AhX2u", "Errors.Internal") + errors.As(err, &zError) + } + + statusCode, _ := http_util.ZitadelErrorToHTTPStatusCode(err) + newOidcErr := oidc.ErrServerError + if statusCode < 500 { + newOidcErr = oidc.ErrInvalidRequest + } + return op.NewStatusError( + newOidcErr(). + WithParent(err). + WithDescription(zError.GetMessage()), + statusCode, + ) +} diff --git a/internal/api/oidc/error_test.go b/internal/api/oidc/error_test.go new file mode 100644 index 0000000000..fbacbdc3ff --- /dev/null +++ b/internal/api/oidc/error_test.go @@ -0,0 +1,63 @@ +package oidc + +import ( + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/require" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + + "github.com/zitadel/zitadel/internal/zerrors" +) + +func Test_oidcError(t *testing.T) { + tests := []struct { + name string + err error + wantErr error + }{ + { + name: "nil", + err: nil, + wantErr: nil, + }, + { + name: "status err", + err: op.NewStatusError(io.ErrClosedPipe, http.StatusTeapot), + wantErr: op.NewStatusError(io.ErrClosedPipe, http.StatusTeapot), + }, + { + name: "oidc err", + err: oidc.ErrInvalidClient().WithParent(io.ErrClosedPipe), + wantErr: oidc.ErrInvalidClient().WithParent(io.ErrClosedPipe), + }, + { + name: "unknown err", + err: io.ErrClosedPipe, + wantErr: op.NewStatusError( + oidc.ErrServerError(). + WithParent(io.ErrClosedPipe). + WithDescription("Errors.Internal"), + http.StatusInternalServerError, + ), + }, + { + name: "zitadel error, invalid request", + err: zerrors.ThrowPreconditionFailed(io.ErrClosedPipe, "TEST-123", "oopsie"), + wantErr: op.NewStatusError( + oidc.ErrInvalidRequest(). + WithParent(io.ErrClosedPipe). + WithDescription("oopsie"), + http.StatusBadRequest, + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := oidcError(tt.err) + require.ErrorIs(t, err, tt.wantErr) + }) + } +} diff --git a/internal/api/oidc/introspect.go b/internal/api/oidc/introspect.go index 8c73755199..5690888145 100644 --- a/internal/api/oidc/introspect.go +++ b/internal/api/oidc/introspect.go @@ -18,7 +18,10 @@ import ( func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (resp *op.Response, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() if s.features.LegacyIntrospection { return s.LegacyServer.Introspect(ctx, r) diff --git a/internal/api/oidc/jwt-profile.go b/internal/api/oidc/jwt-profile.go index 805936dff3..fe668b5a8a 100644 --- a/internal/api/oidc/jwt-profile.go +++ b/internal/api/oidc/jwt-profile.go @@ -7,10 +7,17 @@ import ( "github.com/zitadel/oidc/v3/pkg/op" "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" ) -func (o *OPStorage) JWTProfileTokenType(ctx context.Context, request op.TokenRequest) (op.AccessTokenType, error) { +func (o *OPStorage) JWTProfileTokenType(ctx context.Context, request op.TokenRequest) (_ op.AccessTokenType, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() + mapJWTProfileScopesToAudience(ctx, request) user, err := o.query.GetUserByID(ctx, false, request.GetSubject()) if err != nil { diff --git a/internal/api/oidc/key.go b/internal/api/oidc/key.go index d0e781a0c2..8f8ca10c34 100644 --- a/internal/api/oidc/key.go +++ b/internal/api/oidc/key.go @@ -111,7 +111,10 @@ func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWeb // VerifySignature implements the oidc.KeySet interface. func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() if len(jws.Signatures) != 1 { return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid") diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index f956b5c48f..94b3d937b5 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -71,7 +71,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -107,7 +107,7 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -134,7 +134,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertIDTokenClaims(t, tokens.IDTokenClaims, armPassword, startTime, changeTime) @@ -162,7 +162,7 @@ func Test_ZITADEL_API_success(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -196,7 +196,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, false) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -225,7 +225,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) @@ -267,7 +267,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) { // code exchange code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) assertTokens(t, tokens, true) assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime) diff --git a/internal/api/oidc/server.go b/internal/api/oidc/server.go index 8ba186dd7b..f4d604aedb 100644 --- a/internal/api/oidc/server.go +++ b/internal/api/oidc/server.go @@ -110,10 +110,13 @@ func (s *Server) Ready(ctx context.Context, r *op.Request[struct{}]) (_ *op.Resp func (s *Server) Discovery(ctx context.Context, r *op.Request[struct{}]) (_ *op.Response, err error) { ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() + defer func() { + err = oidcError(err) + span.EndWithError(err) + }() restrictions, err := s.query.GetInstanceRestrictions(ctx) if err != nil { - return nil, err + return nil, op.NewStatusError(oidc.ErrServerError().WithParent(err).WithDescription("internal server error"), http.StatusInternalServerError) } allowedLanguages := restrictions.AllowedLanguages if len(allowedLanguages) == 0 { diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 2d2c60e3cc..1a7eff8052 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -140,7 +140,7 @@ func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *dom if err != nil { return nil, err } - appIDs, err := repo.Query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}) + appIDs, err := repo.Query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false) if err != nil { return nil, err } diff --git a/internal/query/app.go b/internal/query/app.go index 92a44742f9..cc7aa1361a 100644 --- a/internal/query/app.go +++ b/internal/query/app.go @@ -476,10 +476,17 @@ func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, wit return apps, err } -func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries) (ids []string, err error) { +func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries, shouldTriggerBulk bool) (ids []string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() + if shouldTriggerBulk { + _, traceSpan := tracing.NewNamedSpan(ctx, "TriggerAppProjection") + ctx, err = projection.AppProjection.Trigger(ctx, handler.WithAwaitRunning()) + logging.OnError(err).Debug("trigger failed") + traceSpan.EndWithError(err) + } + query, scan := prepareClientIDsQuery(ctx, q.client) eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql()