From 4e3fd305abe87ec08d021770c9111a1054e356a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Fri, 2 Aug 2024 11:38:37 +0300 Subject: [PATCH] fix(crypto): reject decrypted strings with non-UTF8 characters. (#8374) # Which Problems Are Solved We noticed logging where 500: Internal Server errors were returned from the token endpoint, mostly for the `refresh_token` grant. The error was thrown by the database as it received non-UTF8 strings for token IDs Zitadel uses symmetric encryption for opaque tokens, including refresh tokens. Encrypted values are base64 encoded. It appeared to be possible to send garbage base64 to the token endpoint, which will pass decryption and string-splitting. In those cases the resulting ID is not a valid UTF-8 string. Invalid non-UTF8 strings are now rejected during token decryption. # How the Problems Are Solved - `AESCrypto.DecryptString()` checks if the decrypted bytes only contain valid UTF-8 characters before converting them into a string. - `AESCrypto.Decrypt()` is unmodified and still allows decryption on non-UTF8 byte strings. - `FromRefreshToken` now uses `DecryptString` instead of `Decrypt` # Additional Changes - Unit tests added for `FromRefreshToken` and `AESCrypto.DecryptString()`. - Fuzz tests added for `FromRefreshToken` and `AESCrypto.DecryptString()`. This was to pinpoint the problem - Testdata with values that resulted in invalid strings are committed. In the pipeline this results in the Fuzz tests to execute as regular unit-test cases. As we don't use the `-fuzz` flag in the pipeline no further fuzzing is performed. # Additional Context - Closes #7765 - https://go.dev/doc/tutorial/fuzz --- internal/command/oidc_session.go | 2 +- internal/crypto/aes.go | 11 +- internal/crypto/aes_test.go | 109 +++++++++++++-- internal/crypto/crypto.go | 5 + .../8d609af8fa2eb76f | 2 + internal/domain/refresh_token.go | 8 +- internal/domain/refresh_token_test.go | 129 ++++++++++++++++++ .../FuzzFromRefreshToken/576e811604c701eb | 2 + internal/zerrors/invalid_argument.go | 13 +- 9 files changed, 262 insertions(+), 19 deletions(-) create mode 100644 internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f create mode 100644 internal/domain/refresh_token_test.go create mode 100644 internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb diff --git a/internal/command/oidc_session.go b/internal/command/oidc_session.go index fc384d0d30..1fe82198bf 100644 --- a/internal/command/oidc_session.go +++ b/internal/command/oidc_session.go @@ -293,7 +293,7 @@ func (c *Commands) decryptRefreshToken(refreshToken string) (sessionID, refreshT } decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID()) if err != nil { - return "", "", err + return "", "", zerrors.ThrowInvalidArgument(err, "OIDCS-Jei0i", "Errors.User.RefreshToken.Invalid") } return parseRefreshToken(decrypted) } diff --git a/internal/crypto/aes.go b/internal/crypto/aes.go index e943c2ca8e..f57a78fb85 100644 --- a/internal/crypto/aes.go +++ b/internal/crypto/aes.go @@ -6,6 +6,7 @@ import ( "crypto/rand" "encoding/base64" "io" + "unicode/utf8" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -46,15 +47,17 @@ func (a *AESCrypto) Decrypt(value []byte, keyID string) ([]byte, error) { return DecryptAES(value, key) } +// DecryptString decrypts the value using the key identified by keyID. +// When the decrypted value contains non-UTF8 characters an error is returned. func (a *AESCrypto) DecryptString(value []byte, keyID string) (string, error) { - key, err := a.decryptionKey(keyID) + b, err := a.Decrypt(value, keyID) if err != nil { return "", err } - b, err := DecryptAES(value, key) - if err != nil { - return "", err + if !utf8.Valid(b) { + return "", zerrors.ThrowPreconditionFailed(err, "CRYPT-hiCh0", "non-UTF-8 in decrypted string") } + return string(b), nil } diff --git a/internal/crypto/aes_test.go b/internal/crypto/aes_test.go index 5731f320eb..128fd6c4dc 100644 --- a/internal/crypto/aes_test.go +++ b/internal/crypto/aes_test.go @@ -1,18 +1,109 @@ package crypto import ( + "context" + "errors" "testing" + "unicode/utf8" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/internal/zerrors" ) -// TODO: refactor test style -func TestDecrypt_OK(t *testing.T) { - encryptedpw, err := EncryptAESString("ThisIsMySecretPw", "passphrasewhichneedstobe32bytes!") - assert.NoError(t, err) - - decryptedpw, err := DecryptAESString(encryptedpw, "passphrasewhichneedstobe32bytes!") - assert.NoError(t, err) - - assert.Equal(t, "ThisIsMySecretPw", decryptedpw) +type mockKeyStorage struct { + keys Keys +} + +func (s *mockKeyStorage) ReadKeys() (Keys, error) { + return s.keys, nil +} + +func (s *mockKeyStorage) ReadKey(id string) (*Key, error) { + return &Key{ + ID: id, + Value: s.keys[id], + }, nil +} + +func (*mockKeyStorage) CreateKeys(context.Context, ...*Key) error { + return errors.New("mockKeyStorage.CreateKeys not implemented") +} + +func newTestAESCrypto(t testing.TB) *AESCrypto { + keyConfig := &KeyConfig{ + EncryptionKeyID: "keyID", + DecryptionKeyIDs: []string{"keyID"}, + } + keys := Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"} + aesCrypto, err := NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys}) + require.NoError(t, err) + return aesCrypto +} + +func TestAESCrypto_DecryptString(t *testing.T) { + aesCrypto := newTestAESCrypto(t) + const input = "SecretData" + crypted, err := aesCrypto.Encrypt([]byte(input)) + require.NoError(t, err) + + type args struct { + value []byte + keyID string + } + tests := []struct { + name string + args args + want string + wantErr error + }{ + { + name: "unknown key id error", + args: args{ + value: crypted, + keyID: "foo", + }, + wantErr: zerrors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id"), + }, + { + name: "ok", + args: args{ + value: crypted, + keyID: "keyID", + }, + want: input, + }, + } + for _, tt := range tests { + got, err := aesCrypto.DecryptString(tt.args.value, tt.args.keyID) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.want, got) + } +} + +func FuzzAESCrypto_DecryptString(f *testing.F) { + aesCrypto := newTestAESCrypto(f) + tests := []string{ + " ", + "SecretData", + "FooBar", + "HelloWorld", + } + for _, input := range tests { + tc, err := aesCrypto.Encrypt([]byte(input)) + require.NoError(f, err) + f.Add(tc) + } + f.Fuzz(func(t *testing.T, value []byte) { + got, err := aesCrypto.DecryptString(value, "keyID") + if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short")) { + return + } + if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-hiCh0", "non-UTF-8 in decrypted string")) { + return + } + require.NoError(t, err) + assert.True(t, utf8.ValidString(got), "result is not valid UTF-8") + }) } diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index 2e8e4a71b0..a74f97a054 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -19,6 +19,9 @@ type EncryptionAlgorithm interface { DecryptionKeyIDs() []string Encrypt(value []byte) ([]byte, error) Decrypt(hashed []byte, keyID string) ([]byte, error) + + // DecryptString decrypts the value using the key identified by keyID. + // When the decrypted value contains non-UTF8 characters an error is returned. DecryptString(hashed []byte, keyID string) (string, error) } @@ -72,6 +75,8 @@ func Decrypt(value *CryptoValue, alg EncryptionAlgorithm) ([]byte, error) { return alg.Decrypt(value.Crypted, value.KeyID) } +// DecryptString decrypts the value using the key identified by keyID. +// When the decrypted value contains non-UTF8 characters an error is returned. func DecryptString(value *CryptoValue, alg EncryptionAlgorithm) (string, error) { if err := checkEncryptionAlgorithm(value, alg); err != nil { return "", err diff --git a/internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f b/internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f new file mode 100644 index 0000000000..233de8fb25 --- /dev/null +++ b/internal/crypto/testdata/fuzz/FuzzAESCrypto_DecryptString/8d609af8fa2eb76f @@ -0,0 +1,2 @@ +go test fuzz v1 +[]byte("0010120C001010070") diff --git a/internal/domain/refresh_token.go b/internal/domain/refresh_token.go index 6f2d883df5..25ab32f45b 100644 --- a/internal/domain/refresh_token.go +++ b/internal/domain/refresh_token.go @@ -25,13 +25,13 @@ func FromRefreshToken(refreshToken string, algorithm crypto.EncryptionAlgorithm) if err != nil { return "", "", "", zerrors.ThrowInvalidArgument(err, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid") } - decrypted, err := algorithm.Decrypt(decoded, algorithm.EncryptionKeyID()) + decrypted, err := algorithm.DecryptString(decoded, algorithm.EncryptionKeyID()) if err != nil { - return "", "", "", err + return "", "", "", zerrors.ThrowInvalidArgument(err, "DOMAIN-rie9A", "Errors.User.RefreshToken.Invalid") } - split := strings.Split(string(decrypted), ":") + split := strings.Split(decrypted, ":") if len(split) != 3 { - return "", "", "", zerrors.ThrowInvalidArgument(nil, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid") + return "", "", "", zerrors.ThrowInvalidArgument(nil, "DOMAIN-Se8oh", "Errors.User.RefreshToken.Invalid") } return split[0], split[1], split[2], nil } diff --git a/internal/domain/refresh_token_test.go b/internal/domain/refresh_token_test.go new file mode 100644 index 0000000000..e2719bd238 --- /dev/null +++ b/internal/domain/refresh_token_test.go @@ -0,0 +1,129 @@ +package domain + +import ( + "encoding/base64" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" + + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type mockKeyStorage struct { + keys crypto.Keys +} + +func (s *mockKeyStorage) ReadKeys() (crypto.Keys, error) { + return s.keys, nil +} + +func (s *mockKeyStorage) ReadKey(id string) (*crypto.Key, error) { + return &crypto.Key{ + ID: id, + Value: s.keys[id], + }, nil +} + +func (*mockKeyStorage) CreateKeys(context.Context, ...*crypto.Key) error { + return errors.New("mockKeyStorage.CreateKeys not implemented") +} + +func TestFromRefreshToken(t *testing.T) { + const ( + userID = "userID" + tokenID = "tokenID" + ) + + keyConfig := &crypto.KeyConfig{ + EncryptionKeyID: "keyID", + DecryptionKeyIDs: []string{"keyID"}, + } + keys := crypto.Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"} + algorithm, err := crypto.NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys}) + require.NoError(t, err) + + refreshToken, err := NewRefreshToken(userID, tokenID, algorithm) + require.NoError(t, err) + + invalidRefreshToken, err := algorithm.Encrypt([]byte(userID + ":" + tokenID)) + require.NoError(t, err) + + type args struct { + refreshToken string + algorithm crypto.EncryptionAlgorithm + } + tests := []struct { + name string + args args + wantUserID string + wantTokenID string + wantToken string + wantErr error + }{ + { + name: "invalid base64", + args: args{"~~~", algorithm}, + wantErr: zerrors.ThrowInvalidArgument(nil, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid"), + }, + { + name: "short cipher text", + args: args{"DEADBEEF", algorithm}, + wantErr: zerrors.ThrowInvalidArgument(err, "DOMAIN-rie9A", "Errors.User.RefreshToken.Invalid"), + }, + { + name: "incorrect amount of segments", + args: args{base64.RawURLEncoding.EncodeToString(invalidRefreshToken), algorithm}, + wantErr: zerrors.ThrowInvalidArgument(nil, "DOMAIN-Se8oh", "Errors.User.RefreshToken.Invalid"), + }, + { + name: "success", + args: args{refreshToken, algorithm}, + wantUserID: userID, + wantTokenID: tokenID, + wantToken: tokenID, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotUserID, gotTokenID, gotToken, err := FromRefreshToken(tt.args.refreshToken, tt.args.algorithm) + require.ErrorIs(t, err, tt.wantErr) + assert.Equal(t, tt.wantUserID, gotUserID) + assert.Equal(t, tt.wantTokenID, gotTokenID) + assert.Equal(t, tt.wantToken, gotToken) + }) + } +} + +// Fuzz test invalid inputs. None of the inputs should result in a success. +func FuzzFromRefreshToken(f *testing.F) { + keyConfig := &crypto.KeyConfig{ + EncryptionKeyID: "keyID", + DecryptionKeyIDs: []string{"keyID"}, + } + keys := crypto.Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"} + algorithm, err := crypto.NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys}) + require.NoError(f, err) + + invalidRefreshToken, err := algorithm.Encrypt([]byte("userID:tokenID")) + require.NoError(f, err) + + tests := []string{ + "~~~", // invalid base64 + "DEADBEEF", // short cipher text + base64.RawURLEncoding.EncodeToString(invalidRefreshToken), // incorrect amount of segments + } + for _, tc := range tests { + f.Add(tc) + } + + f.Fuzz(func(t *testing.T, refreshToken string) { + gotUserID, gotTokenID, gotToken, err := FromRefreshToken(refreshToken, algorithm) + target := zerrors.InvalidArgumentError{ZitadelError: new(zerrors.ZitadelError)} + t.Log(gotUserID, gotTokenID, gotToken) + require.ErrorAs(t, err, &target) + }) +} diff --git a/internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb b/internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb new file mode 100644 index 0000000000..0e9296b076 --- /dev/null +++ b/internal/domain/testdata/fuzz/FuzzFromRefreshToken/576e811604c701eb @@ -0,0 +1,2 @@ +go test fuzz v1 +string("0000050000000000000000000000000") diff --git a/internal/zerrors/invalid_argument.go b/internal/zerrors/invalid_argument.go index b2a33fc860..e97519e660 100644 --- a/internal/zerrors/invalid_argument.go +++ b/internal/zerrors/invalid_argument.go @@ -1,6 +1,8 @@ package zerrors -import "fmt" +import ( + "fmt" +) var ( _ InvalidArgument = (*InvalidArgumentError)(nil) @@ -39,6 +41,15 @@ func (err *InvalidArgumentError) Is(target error) bool { return err.ZitadelError.Is(t.ZitadelError) } +func (err *InvalidArgumentError) As(target any) bool { + targetErr, ok := target.(*InvalidArgumentError) + if !ok { + return false + } + *targetErr = *err + return true +} + func (err *InvalidArgumentError) Unwrap() error { return err.ZitadelError }