mirror of
https://github.com/zitadel/zitadel
synced 2024-11-22 00:39:36 +00:00
fix: uniform oidc errors (#7237)
* fix: uniform oidc errors sanitize oidc error reporting when passing package boundary towards oidc. * add should TriggerBulk in get audiences for auth request * upgrade to oidc 3.10.1 * provisional oidc upgrade to error branch * pin oidc 3.10.2
This commit is contained in:
parent
cdfcdec101
commit
af4e0484d0
2
go.mod
2
go.mod
@ -61,7 +61,7 @@ require (
|
||||
github.com/superseriousbusiness/exifremove v0.0.0-20210330092427-6acd27eac203
|
||||
github.com/ttacon/libphonenumber v1.2.1
|
||||
github.com/zitadel/logging v0.5.0
|
||||
github.com/zitadel/oidc/v3 v3.10.0
|
||||
github.com/zitadel/oidc/v3 v3.10.2
|
||||
github.com/zitadel/passwap v0.5.0
|
||||
github.com/zitadel/saml v0.1.3
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.46.1
|
||||
|
4
go.sum
4
go.sum
@ -782,8 +782,8 @@ github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8=
|
||||
github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
github.com/zitadel/logging v0.5.0 h1:Kunouvqse/efXy4UDvFw5s3vP+Z4AlHo3y8wF7stXHA=
|
||||
github.com/zitadel/logging v0.5.0/go.mod h1:IzP5fzwFhzzyxHkSmfF8dsyqFsQRJLLcQmwhIBzlGsE=
|
||||
github.com/zitadel/oidc/v3 v3.10.0 h1:qAGlw6FGQEpkWya8tT03P6pU4AHNrZ0Kfyxmwsd4am0=
|
||||
github.com/zitadel/oidc/v3 v3.10.0/go.mod h1:nfjWH8ps4B7T0JGJyLLOIUlhr0Z4becyGKui/sXYpA8=
|
||||
github.com/zitadel/oidc/v3 v3.10.2 h1:nowZrpOBR4tdIlYXE8/l5Nl84QDYwyHpccIE1l2OAd4=
|
||||
github.com/zitadel/oidc/v3 v3.10.2/go.mod h1:nfjWH8ps4B7T0JGJyLLOIUlhr0Z4becyGKui/sXYpA8=
|
||||
github.com/zitadel/passwap v0.5.0 h1:kFMoRyo0GnxtOz7j9+r/CsRwSCjHGRaAKoUe69NwPvs=
|
||||
github.com/zitadel/passwap v0.5.0/go.mod h1:uqY7D3jqdTFcKsW0Q3Pcv5qDMmSHpVTzUZewUKC1KZA=
|
||||
github.com/zitadel/saml v0.1.3 h1:LI4DOCVyyU1qKPkzs3vrGcA5J3H4pH3+CL9zr9ShkpM=
|
||||
|
@ -27,20 +27,22 @@ type accessToken struct {
|
||||
isPAT bool
|
||||
}
|
||||
|
||||
var ErrInvalidTokenFormat = errors.New("invalid token format")
|
||||
|
||||
func (s *Server) verifyAccessToken(ctx context.Context, tkn string) (*accessToken, error) {
|
||||
var tokenID, subject string
|
||||
|
||||
if tokenIDSubject, err := s.Provider().Crypto().Decrypt(tkn); err == nil {
|
||||
split := strings.Split(tokenIDSubject, ":")
|
||||
if len(split) != 2 {
|
||||
return nil, errors.New("invalid token format")
|
||||
return nil, zerrors.ThrowPermissionDenied(ErrInvalidTokenFormat, "OIDC-rei1O", "token is not valid or has expired")
|
||||
}
|
||||
tokenID, subject = split[0], split[1]
|
||||
} else {
|
||||
verifier := op.NewAccessTokenVerifier(op.IssuerFromContext(ctx), s.keySet)
|
||||
claims, err := op.VerifyAccessToken[*oidc.AccessTokenClaims](ctx, tkn, verifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, zerrors.ThrowPermissionDenied(err, "OIDC-Eib8e", "token is not valid or has expired")
|
||||
}
|
||||
tokenID, subject = claims.JWTID, claims.Subject
|
||||
}
|
||||
|
@ -28,7 +28,10 @@ const (
|
||||
|
||||
func (o *OPStorage) CreateAuthRequest(ctx context.Context, req *oidc.AuthRequest, userID string) (_ op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
headers, _ := http_utils.HeadersFromCtx(ctx)
|
||||
if loginClient := headers.Get(LoginClientHeader); loginClient != "" {
|
||||
@ -102,7 +105,7 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}})
|
||||
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -112,7 +115,10 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string)
|
||||
|
||||
func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
req, err := o.command.GetCurrentAuthRequest(ctx, id)
|
||||
@ -135,7 +141,10 @@ func (o *OPStorage) AuthRequestByID(ctx context.Context, id string) (_ op.AuthRe
|
||||
|
||||
func (o *OPStorage) AuthRequestByCode(ctx context.Context, code string) (_ op.AuthRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
plainCode, err := o.decryptGrant(code)
|
||||
if err != nil {
|
||||
@ -166,7 +175,10 @@ func (o *OPStorage) decryptGrant(grant string) (string, error) {
|
||||
|
||||
func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
if strings.HasPrefix(id, command.IDPrefixV2) {
|
||||
return o.command.AddAuthRequestCode(ctx, id, code)
|
||||
@ -181,14 +193,20 @@ func (o *OPStorage) SaveAuthCode(ctx context.Context, id, code string) (err erro
|
||||
|
||||
func (o *OPStorage) DeleteAuthRequest(ctx context.Context, id string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
return o.repo.DeleteAuthRequest(ctx, id)
|
||||
}
|
||||
|
||||
func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest) (_ string, _ time.Time, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
var userAgentID, applicationID, userOrgID string
|
||||
switch authReq := req.(type) {
|
||||
@ -221,7 +239,10 @@ func (o *OPStorage) CreateAccessToken(ctx context.Context, req op.TokenRequest)
|
||||
|
||||
func (o *OPStorage) CreateAccessAndRefreshTokens(ctx context.Context, req op.TokenRequest, refreshToken string) (_, _ string, _ time.Time, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
// handle V2 request directly
|
||||
switch tokenReq := req.(type) {
|
||||
@ -279,7 +300,10 @@ func getInfoFromRequest(req op.TokenRequest) (string, string, string, time.Time,
|
||||
|
||||
func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (_ op.RefreshTokenRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
plainToken, err := o.decryptGrant(refreshToken)
|
||||
if err != nil {
|
||||
@ -307,7 +331,10 @@ func (o *OPStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken
|
||||
|
||||
func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
userAgentID, ok := middleware.UserAgentIDFromCtx(ctx)
|
||||
if !ok {
|
||||
logging.Error("no user agent id")
|
||||
@ -331,7 +358,10 @@ func (o *OPStorage) TerminateSession(ctx context.Context, userID, clientID strin
|
||||
|
||||
func (o *OPStorage) TerminateSessionFromRequest(ctx context.Context, endSessionRequest *op.EndSessionRequest) (redirectURI string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
// check for the login client header
|
||||
// and if not provided, terminate the session using the V1 method
|
||||
@ -408,6 +438,12 @@ func (o *OPStorage) revokeTokenV1(ctx context.Context, token, userID, clientID s
|
||||
}
|
||||
|
||||
func (o *OPStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
plainToken, err := o.decryptGrant(token)
|
||||
if err != nil {
|
||||
return "", "", op.ErrInvalidRefreshToken
|
||||
|
@ -51,7 +51,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) {
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -69,7 +69,7 @@ func TestOPStorage_CreateAccessToken_code(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
|
||||
// exchange with a used code must fail
|
||||
_, err = exchangeTokens(t, clientID, code)
|
||||
_, err = exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@ -140,7 +140,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_code(t *testing.T) {
|
||||
|
||||
// test code exchange (expect refresh token to be returned)
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -165,7 +165,7 @@ func TestOPStorage_CreateAccessAndRefreshTokens_refresh(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -201,7 +201,7 @@ func TestOPStorage_RevokeToken_access_token(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -244,7 +244,7 @@ func TestOPStorage_RevokeToken_access_token_invalid_token_hint_type(t *testing.T
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -281,7 +281,7 @@ func TestOPStorage_RevokeToken_refresh_token(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -324,7 +324,7 @@ func TestOPStorage_RevokeToken_refresh_token_invalid_token_type_hint(t *testing.
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -359,7 +359,7 @@ func TestOPStorage_RevokeToken_invalid_client(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -391,7 +391,7 @@ func TestOPStorage_TerminateSession(t *testing.T) {
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -428,7 +428,7 @@ func TestOPStorage_TerminateSession_refresh_grant(t *testing.T) {
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -472,7 +472,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
|
||||
|
||||
// test code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -497,7 +497,7 @@ func TestOPStorage_TerminateSession_empty_id_token_hint(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func exchangeTokens(t testing.TB, clientID, code string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
|
||||
func exchangeTokens(t testing.TB, clientID, code, redirectURI string) (*oidc.Tokens[*oidc.IDTokenClaims], error) {
|
||||
provider, err := Tester.CreateRelyingParty(CTX, clientID, redirectURI)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@ -42,7 +42,10 @@ const (
|
||||
|
||||
func (o *OPStorage) GetClientByClientID(ctx context.Context, id string) (_ op.Client, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
client, err := o.query.GetOIDCClientByID(ctx, id, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -59,7 +62,10 @@ func (o *OPStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, userID str
|
||||
|
||||
func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer string) (_ *jose.JSONWebKey, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
publicKeyData, err := o.query.GetAuthNKeyPublicKeyByIDAndIdentifier(ctx, keyID, issuer, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -75,7 +81,12 @@ func (o *OPStorage) GetKeyByIDAndIssuer(ctx context.Context, keyID, issuer strin
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string, scopes []string) ([]string, error) {
|
||||
func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string, scopes []string) (_ []string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
user, err := o.query.GetUserByID(ctx, true, subject)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -85,7 +96,10 @@ func (o *OPStorage) ValidateJWTProfileScopes(ctx context.Context, subject string
|
||||
|
||||
func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secret string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
ctx = authz.SetCtxData(ctx, authz.CtxData{
|
||||
UserID: oidcCtx,
|
||||
OrgID: oidcCtx,
|
||||
@ -102,7 +116,10 @@ func (o *OPStorage) AuthorizeClientIDSecret(ctx context.Context, id string, secr
|
||||
|
||||
func (o *OPStorage) SetUserinfoFromToken(ctx context.Context, userInfo *oidc.UserInfo, tokenID, subject, origin string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
if strings.HasPrefix(tokenID, command.IDPrefixV2) {
|
||||
token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID)
|
||||
@ -129,7 +146,10 @@ func (o *OPStorage) SetUserinfoFromToken(ctx context.Context, userInfo *oidc.Use
|
||||
|
||||
func (o *OPStorage) SetUserinfoFromScopes(ctx context.Context, userInfo *oidc.UserInfo, userID, applicationID string, scopes []string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
if applicationID != "" {
|
||||
app, err := o.query.AppByOIDCClientID(ctx, applicationID)
|
||||
if err != nil {
|
||||
@ -159,7 +179,10 @@ func (o *OPStorage) SetUserinfoFromRequest(ctx context.Context, userinfo *oidc.U
|
||||
|
||||
func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
if strings.HasPrefix(tokenID, command.IDPrefixV2) {
|
||||
token, err := o.query.ActiveAccessTokenByToken(ctx, tokenID)
|
||||
@ -196,7 +219,12 @@ func (o *OPStorage) SetIntrospectionFromToken(ctx context.Context, introspection
|
||||
token.CreationDate, token.Expiration)
|
||||
}
|
||||
|
||||
func (o *OPStorage) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scope []string) (op.TokenRequest, error) {
|
||||
func (o *OPStorage) ClientCredentialsTokenRequest(ctx context.Context, clientID string, scope []string) (_ op.TokenRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
user, err := o.query.GetUserByLoginName(ctx, false, clientID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -545,6 +573,12 @@ func (o *OPStorage) userinfoFlows(ctx context.Context, user *query.User, userGra
|
||||
}
|
||||
|
||||
func (o *OPStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (claims map[string]interface{}, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
roles := make([]string, 0)
|
||||
var allRoles bool
|
||||
for _, scope := range scopes {
|
||||
@ -903,7 +937,10 @@ func userinfoClaims(userInfo *oidc.UserInfo) func(c *actions.FieldConfig) interf
|
||||
|
||||
func (s *Server) VerifyClient(ctx context.Context, r *op.Request[op.ClientCredentials]) (_ op.Client, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
if oidc.GrantType(r.Form.Get("grant_type")) == oidc.GrantTypeClientCredentials {
|
||||
return s.clientCredentialsAuth(ctx, r.Data.ClientID, r.Data.ClientSecret)
|
||||
|
@ -43,7 +43,7 @@ func TestOPStorage_SetUserinfoFromToken(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -152,7 +152,7 @@ func TestServer_Introspect(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, app.GetClientId(), code)
|
||||
tokens, err := exchangeTokens(t, app.GetClientId(), code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
|
@ -109,14 +109,9 @@ func newDeviceAuthorizationState(d *query.DeviceAuth) *op.DeviceAuthorizationSta
|
||||
// As generated user codes are of low entropy, this implementation also takes care or
|
||||
// device authorization request cleanup, when it has been Approved, Denied or Expired.
|
||||
func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, deviceCode string) (state *op.DeviceAuthorizationState, err error) {
|
||||
const logMsg = "get device authorization state"
|
||||
logger := logging.WithFields("device_code", deviceCode)
|
||||
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
logger.WithError(err).Error(logMsg)
|
||||
}
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
@ -124,7 +119,8 @@ func (o *OPStorage) GetDeviceAuthorizatonState(ctx context.Context, clientID, de
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.SetFields(
|
||||
logging.WithFields(
|
||||
"device_code", deviceCode,
|
||||
"expires", deviceAuth.Expires, "scopes", deviceAuth.Scopes,
|
||||
"subject", deviceAuth.Subject, "state", deviceAuth.State,
|
||||
).Debug("device authorization state")
|
||||
|
49
internal/api/oidc/error.go
Normal file
49
internal/api/oidc/error.go
Normal file
@ -0,0 +1,49 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
// oidcError ensures [*oidc.Error] and [op.StatusError] types for err.
|
||||
// It must be used when an error passes the package boundary towards oidc.
|
||||
// When err is already of the correct type is passed as-is.
|
||||
// If the err is a Zitadel error, it is transformed with a proper HTTP status code.
|
||||
// Unknown errors are treated as internal server errors.
|
||||
func oidcError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
sError op.StatusError
|
||||
oError *oidc.Error
|
||||
zError *zerrors.ZitadelError
|
||||
)
|
||||
if errors.As(err, &sError) || errors.As(err, &oError) {
|
||||
return err
|
||||
}
|
||||
|
||||
// here we are encountering an error type that is completely unknown to us.
|
||||
if !errors.As(err, &zError) {
|
||||
err = zerrors.ThrowInternal(err, "OIDC-AhX2u", "Errors.Internal")
|
||||
errors.As(err, &zError)
|
||||
}
|
||||
|
||||
statusCode, _ := http_util.ZitadelErrorToHTTPStatusCode(err)
|
||||
newOidcErr := oidc.ErrServerError
|
||||
if statusCode < 500 {
|
||||
newOidcErr = oidc.ErrInvalidRequest
|
||||
}
|
||||
return op.NewStatusError(
|
||||
newOidcErr().
|
||||
WithParent(err).
|
||||
WithDescription(zError.GetMessage()),
|
||||
statusCode,
|
||||
)
|
||||
}
|
63
internal/api/oidc/error_test.go
Normal file
63
internal/api/oidc/error_test.go
Normal file
@ -0,0 +1,63 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func Test_oidcError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "nil",
|
||||
err: nil,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "status err",
|
||||
err: op.NewStatusError(io.ErrClosedPipe, http.StatusTeapot),
|
||||
wantErr: op.NewStatusError(io.ErrClosedPipe, http.StatusTeapot),
|
||||
},
|
||||
{
|
||||
name: "oidc err",
|
||||
err: oidc.ErrInvalidClient().WithParent(io.ErrClosedPipe),
|
||||
wantErr: oidc.ErrInvalidClient().WithParent(io.ErrClosedPipe),
|
||||
},
|
||||
{
|
||||
name: "unknown err",
|
||||
err: io.ErrClosedPipe,
|
||||
wantErr: op.NewStatusError(
|
||||
oidc.ErrServerError().
|
||||
WithParent(io.ErrClosedPipe).
|
||||
WithDescription("Errors.Internal"),
|
||||
http.StatusInternalServerError,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "zitadel error, invalid request",
|
||||
err: zerrors.ThrowPreconditionFailed(io.ErrClosedPipe, "TEST-123", "oopsie"),
|
||||
wantErr: op.NewStatusError(
|
||||
oidc.ErrInvalidRequest().
|
||||
WithParent(io.ErrClosedPipe).
|
||||
WithDescription("oopsie"),
|
||||
http.StatusBadRequest,
|
||||
),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := oidcError(tt.err)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
@ -18,7 +18,10 @@ import (
|
||||
|
||||
func (s *Server) Introspect(ctx context.Context, r *op.Request[op.IntrospectionRequest]) (resp *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
if s.features.LegacyIntrospection {
|
||||
return s.LegacyServer.Introspect(ctx, r)
|
||||
|
@ -7,10 +7,17 @@ import (
|
||||
"github.com/zitadel/oidc/v3/pkg/op"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func (o *OPStorage) JWTProfileTokenType(ctx context.Context, request op.TokenRequest) (op.AccessTokenType, error) {
|
||||
func (o *OPStorage) JWTProfileTokenType(ctx context.Context, request op.TokenRequest) (_ op.AccessTokenType, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
mapJWTProfileScopesToAudience(ctx, request)
|
||||
user, err := o.query.GetUserByID(ctx, false, request.GetSubject())
|
||||
if err != nil {
|
||||
|
@ -111,7 +111,10 @@ func (k *keySetCache) getKey(ctx context.Context, keyID string) (_ *jose.JSONWeb
|
||||
// VerifySignature implements the oidc.KeySet interface.
|
||||
func (k *keySetCache) VerifySignature(ctx context.Context, jws *jose.JSONWebSignature) (_ []byte, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
if len(jws.Signatures) != 1 {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "OIDC-Gid9s", "Errors.Token.Invalid")
|
||||
|
@ -71,7 +71,7 @@ func Test_ZITADEL_API_missing_audience_scope(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -107,7 +107,7 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
|
||||
@ -134,7 +134,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPassword, startTime, changeTime)
|
||||
|
||||
@ -162,7 +162,7 @@ func Test_ZITADEL_API_success(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -196,7 +196,7 @@ func Test_ZITADEL_API_glob_redirects(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, false)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -225,7 +225,7 @@ func Test_ZITADEL_API_inactive_access_token(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
@ -267,7 +267,7 @@ func Test_ZITADEL_API_terminated_session(t *testing.T) {
|
||||
|
||||
// code exchange
|
||||
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
|
||||
tokens, err := exchangeTokens(t, clientID, code)
|
||||
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
|
||||
require.NoError(t, err)
|
||||
assertTokens(t, tokens, true)
|
||||
assertIDTokenClaims(t, tokens.IDTokenClaims, armPasskey, startTime, changeTime)
|
||||
|
@ -110,10 +110,13 @@ func (s *Server) Ready(ctx context.Context, r *op.Request[struct{}]) (_ *op.Resp
|
||||
|
||||
func (s *Server) Discovery(ctx context.Context, r *op.Request[struct{}]) (_ *op.Response, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
defer func() {
|
||||
err = oidcError(err)
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
restrictions, err := s.query.GetInstanceRestrictions(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, op.NewStatusError(oidc.ErrServerError().WithParent(err).WithDescription("internal server error"), http.StatusInternalServerError)
|
||||
}
|
||||
allowedLanguages := restrictions.AllowedLanguages
|
||||
if len(allowedLanguages) == 0 {
|
||||
|
@ -140,7 +140,7 @@ func (repo *AuthRequestRepo) CreateAuthRequest(ctx context.Context, request *dom
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appIDs, err := repo.Query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}})
|
||||
appIDs, err := repo.Query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -476,10 +476,17 @@ func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, wit
|
||||
return apps, err
|
||||
}
|
||||
|
||||
func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries) (ids []string, err error) {
|
||||
func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries, shouldTriggerBulk bool) (ids []string, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if shouldTriggerBulk {
|
||||
_, traceSpan := tracing.NewNamedSpan(ctx, "TriggerAppProjection")
|
||||
ctx, err = projection.AppProjection.Trigger(ctx, handler.WithAwaitRunning())
|
||||
logging.OnError(err).Debug("trigger failed")
|
||||
traceSpan.EndWithError(err)
|
||||
}
|
||||
|
||||
query, scan := prepareClientIDsQuery(ctx, q.client)
|
||||
eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
|
||||
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
|
||||
|
Loading…
Reference in New Issue
Block a user