fix(crypto): reject decrypted strings with non-UTF8 characters. (#8374)

# Which Problems Are Solved

We noticed logging where 500: Internal Server errors were returned from
the token endpoint, mostly for the `refresh_token` grant. The error was
thrown by the database as it received non-UTF8 strings for token IDs

Zitadel uses symmetric encryption for opaque tokens, including refresh
tokens. Encrypted values are base64 encoded. It appeared to be possible
to send garbage base64 to the token endpoint, which will pass decryption
and string-splitting. In those cases the resulting ID is not a valid
UTF-8 string.

Invalid non-UTF8 strings are now rejected during token decryption.

# How the Problems Are Solved

- `AESCrypto.DecryptString()` checks if the decrypted bytes only contain
valid UTF-8 characters before converting them into a string.
- `AESCrypto.Decrypt()` is unmodified and still allows decryption on
non-UTF8 byte strings.
- `FromRefreshToken` now uses `DecryptString` instead of `Decrypt`

# Additional Changes

- Unit tests added for `FromRefreshToken` and
`AESCrypto.DecryptString()`.
- Fuzz tests added for `FromRefreshToken` and
`AESCrypto.DecryptString()`. This was to pinpoint the problem
- Testdata with values that resulted in invalid strings are committed.
In the pipeline this results in the Fuzz tests to execute as regular
unit-test cases. As we don't use the `-fuzz` flag in the pipeline no
further fuzzing is performed.

# Additional Context

- Closes #7765
- https://go.dev/doc/tutorial/fuzz
This commit is contained in:
Tim Möhlmann 2024-08-02 11:38:37 +03:00 committed by GitHub
parent 3d071fc505
commit 4e3fd305ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 262 additions and 19 deletions

View File

@ -293,7 +293,7 @@ func (c *Commands) decryptRefreshToken(refreshToken string) (sessionID, refreshT
}
decrypted, err := c.keyAlgorithm.DecryptString(decoded, c.keyAlgorithm.EncryptionKeyID())
if err != nil {
return "", "", err
return "", "", zerrors.ThrowInvalidArgument(err, "OIDCS-Jei0i", "Errors.User.RefreshToken.Invalid")
}
return parseRefreshToken(decrypted)
}

View File

@ -6,6 +6,7 @@ import (
"crypto/rand"
"encoding/base64"
"io"
"unicode/utf8"
"github.com/zitadel/zitadel/internal/zerrors"
)
@ -46,15 +47,17 @@ func (a *AESCrypto) Decrypt(value []byte, keyID string) ([]byte, error) {
return DecryptAES(value, key)
}
// DecryptString decrypts the value using the key identified by keyID.
// When the decrypted value contains non-UTF8 characters an error is returned.
func (a *AESCrypto) DecryptString(value []byte, keyID string) (string, error) {
key, err := a.decryptionKey(keyID)
b, err := a.Decrypt(value, keyID)
if err != nil {
return "", err
}
b, err := DecryptAES(value, key)
if err != nil {
return "", err
if !utf8.Valid(b) {
return "", zerrors.ThrowPreconditionFailed(err, "CRYPT-hiCh0", "non-UTF-8 in decrypted string")
}
return string(b), nil
}

View File

@ -1,18 +1,109 @@
package crypto
import (
"context"
"errors"
"testing"
"unicode/utf8"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/zerrors"
)
// TODO: refactor test style
func TestDecrypt_OK(t *testing.T) {
encryptedpw, err := EncryptAESString("ThisIsMySecretPw", "passphrasewhichneedstobe32bytes!")
assert.NoError(t, err)
decryptedpw, err := DecryptAESString(encryptedpw, "passphrasewhichneedstobe32bytes!")
assert.NoError(t, err)
assert.Equal(t, "ThisIsMySecretPw", decryptedpw)
type mockKeyStorage struct {
keys Keys
}
func (s *mockKeyStorage) ReadKeys() (Keys, error) {
return s.keys, nil
}
func (s *mockKeyStorage) ReadKey(id string) (*Key, error) {
return &Key{
ID: id,
Value: s.keys[id],
}, nil
}
func (*mockKeyStorage) CreateKeys(context.Context, ...*Key) error {
return errors.New("mockKeyStorage.CreateKeys not implemented")
}
func newTestAESCrypto(t testing.TB) *AESCrypto {
keyConfig := &KeyConfig{
EncryptionKeyID: "keyID",
DecryptionKeyIDs: []string{"keyID"},
}
keys := Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"}
aesCrypto, err := NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys})
require.NoError(t, err)
return aesCrypto
}
func TestAESCrypto_DecryptString(t *testing.T) {
aesCrypto := newTestAESCrypto(t)
const input = "SecretData"
crypted, err := aesCrypto.Encrypt([]byte(input))
require.NoError(t, err)
type args struct {
value []byte
keyID string
}
tests := []struct {
name string
args args
want string
wantErr error
}{
{
name: "unknown key id error",
args: args{
value: crypted,
keyID: "foo",
},
wantErr: zerrors.ThrowNotFound(nil, "CRYPT-nkj1s", "unknown key id"),
},
{
name: "ok",
args: args{
value: crypted,
keyID: "keyID",
},
want: input,
},
}
for _, tt := range tests {
got, err := aesCrypto.DecryptString(tt.args.value, tt.args.keyID)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.want, got)
}
}
func FuzzAESCrypto_DecryptString(f *testing.F) {
aesCrypto := newTestAESCrypto(f)
tests := []string{
" ",
"SecretData",
"FooBar",
"HelloWorld",
}
for _, input := range tests {
tc, err := aesCrypto.Encrypt([]byte(input))
require.NoError(f, err)
f.Add(tc)
}
f.Fuzz(func(t *testing.T, value []byte) {
got, err := aesCrypto.DecryptString(value, "keyID")
if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-23kH1", "cipher text block too short")) {
return
}
if errors.Is(err, zerrors.ThrowPreconditionFailed(nil, "CRYPT-hiCh0", "non-UTF-8 in decrypted string")) {
return
}
require.NoError(t, err)
assert.True(t, utf8.ValidString(got), "result is not valid UTF-8")
})
}

View File

@ -19,6 +19,9 @@ type EncryptionAlgorithm interface {
DecryptionKeyIDs() []string
Encrypt(value []byte) ([]byte, error)
Decrypt(hashed []byte, keyID string) ([]byte, error)
// DecryptString decrypts the value using the key identified by keyID.
// When the decrypted value contains non-UTF8 characters an error is returned.
DecryptString(hashed []byte, keyID string) (string, error)
}
@ -72,6 +75,8 @@ func Decrypt(value *CryptoValue, alg EncryptionAlgorithm) ([]byte, error) {
return alg.Decrypt(value.Crypted, value.KeyID)
}
// DecryptString decrypts the value using the key identified by keyID.
// When the decrypted value contains non-UTF8 characters an error is returned.
func DecryptString(value *CryptoValue, alg EncryptionAlgorithm) (string, error) {
if err := checkEncryptionAlgorithm(value, alg); err != nil {
return "", err

View File

@ -0,0 +1,2 @@
go test fuzz v1
[]byte("0010120C001010070")

View File

@ -25,13 +25,13 @@ func FromRefreshToken(refreshToken string, algorithm crypto.EncryptionAlgorithm)
if err != nil {
return "", "", "", zerrors.ThrowInvalidArgument(err, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid")
}
decrypted, err := algorithm.Decrypt(decoded, algorithm.EncryptionKeyID())
decrypted, err := algorithm.DecryptString(decoded, algorithm.EncryptionKeyID())
if err != nil {
return "", "", "", err
return "", "", "", zerrors.ThrowInvalidArgument(err, "DOMAIN-rie9A", "Errors.User.RefreshToken.Invalid")
}
split := strings.Split(string(decrypted), ":")
split := strings.Split(decrypted, ":")
if len(split) != 3 {
return "", "", "", zerrors.ThrowInvalidArgument(nil, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid")
return "", "", "", zerrors.ThrowInvalidArgument(nil, "DOMAIN-Se8oh", "Errors.User.RefreshToken.Invalid")
}
return split[0], split[1], split[2], nil
}

View File

@ -0,0 +1,129 @@
package domain
import (
"encoding/base64"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/zerrors"
)
type mockKeyStorage struct {
keys crypto.Keys
}
func (s *mockKeyStorage) ReadKeys() (crypto.Keys, error) {
return s.keys, nil
}
func (s *mockKeyStorage) ReadKey(id string) (*crypto.Key, error) {
return &crypto.Key{
ID: id,
Value: s.keys[id],
}, nil
}
func (*mockKeyStorage) CreateKeys(context.Context, ...*crypto.Key) error {
return errors.New("mockKeyStorage.CreateKeys not implemented")
}
func TestFromRefreshToken(t *testing.T) {
const (
userID = "userID"
tokenID = "tokenID"
)
keyConfig := &crypto.KeyConfig{
EncryptionKeyID: "keyID",
DecryptionKeyIDs: []string{"keyID"},
}
keys := crypto.Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"}
algorithm, err := crypto.NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys})
require.NoError(t, err)
refreshToken, err := NewRefreshToken(userID, tokenID, algorithm)
require.NoError(t, err)
invalidRefreshToken, err := algorithm.Encrypt([]byte(userID + ":" + tokenID))
require.NoError(t, err)
type args struct {
refreshToken string
algorithm crypto.EncryptionAlgorithm
}
tests := []struct {
name string
args args
wantUserID string
wantTokenID string
wantToken string
wantErr error
}{
{
name: "invalid base64",
args: args{"~~~", algorithm},
wantErr: zerrors.ThrowInvalidArgument(nil, "DOMAIN-BGDhn", "Errors.User.RefreshToken.Invalid"),
},
{
name: "short cipher text",
args: args{"DEADBEEF", algorithm},
wantErr: zerrors.ThrowInvalidArgument(err, "DOMAIN-rie9A", "Errors.User.RefreshToken.Invalid"),
},
{
name: "incorrect amount of segments",
args: args{base64.RawURLEncoding.EncodeToString(invalidRefreshToken), algorithm},
wantErr: zerrors.ThrowInvalidArgument(nil, "DOMAIN-Se8oh", "Errors.User.RefreshToken.Invalid"),
},
{
name: "success",
args: args{refreshToken, algorithm},
wantUserID: userID,
wantTokenID: tokenID,
wantToken: tokenID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotUserID, gotTokenID, gotToken, err := FromRefreshToken(tt.args.refreshToken, tt.args.algorithm)
require.ErrorIs(t, err, tt.wantErr)
assert.Equal(t, tt.wantUserID, gotUserID)
assert.Equal(t, tt.wantTokenID, gotTokenID)
assert.Equal(t, tt.wantToken, gotToken)
})
}
}
// Fuzz test invalid inputs. None of the inputs should result in a success.
func FuzzFromRefreshToken(f *testing.F) {
keyConfig := &crypto.KeyConfig{
EncryptionKeyID: "keyID",
DecryptionKeyIDs: []string{"keyID"},
}
keys := crypto.Keys{"keyID": "ThisKeyNeedsToHave32Characters!!"}
algorithm, err := crypto.NewAESCrypto(keyConfig, &mockKeyStorage{keys: keys})
require.NoError(f, err)
invalidRefreshToken, err := algorithm.Encrypt([]byte("userID:tokenID"))
require.NoError(f, err)
tests := []string{
"~~~", // invalid base64
"DEADBEEF", // short cipher text
base64.RawURLEncoding.EncodeToString(invalidRefreshToken), // incorrect amount of segments
}
for _, tc := range tests {
f.Add(tc)
}
f.Fuzz(func(t *testing.T, refreshToken string) {
gotUserID, gotTokenID, gotToken, err := FromRefreshToken(refreshToken, algorithm)
target := zerrors.InvalidArgumentError{ZitadelError: new(zerrors.ZitadelError)}
t.Log(gotUserID, gotTokenID, gotToken)
require.ErrorAs(t, err, &target)
})
}

View File

@ -0,0 +1,2 @@
go test fuzz v1
string("0000050000000000000000000000000")

View File

@ -1,6 +1,8 @@
package zerrors
import "fmt"
import (
"fmt"
)
var (
_ InvalidArgument = (*InvalidArgumentError)(nil)
@ -39,6 +41,15 @@ func (err *InvalidArgumentError) Is(target error) bool {
return err.ZitadelError.Is(t.ZitadelError)
}
func (err *InvalidArgumentError) As(target any) bool {
targetErr, ok := target.(*InvalidArgumentError)
if !ok {
return false
}
*targetErr = *err
return true
}
func (err *InvalidArgumentError) Unwrap() error {
return err.ZitadelError
}