From ac985e2dfb43afb41ecf6c47a5bc04ce20042632 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Wed, 24 Apr 2024 10:44:55 +0200 Subject: [PATCH] fix(login): correctly reload policies on auth request (#7839) --- .../eventsourcing/eventstore/auth_request.go | 13 +++++++------ internal/auth_request/repository/cache/cache.go | 14 +++++++++----- internal/domain/auth_request.go | 10 ++++++++++ 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index ca95108205..eea29a167f 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -656,7 +656,7 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A } } - if request.LoginPolicy == nil || len(request.AllowedExternalIDPs) == 0 { + if request.LoginPolicy == nil || len(request.AllowedExternalIDPs) == 0 || request.PolicyOrgID() != orgID { loginPolicy, idpProviders, err := repo.getLoginPolicyAndIDPProviders(ctx, orgID) if err != nil { return err @@ -666,21 +666,21 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A request.AllowedExternalIDPs = idpProviders } } - if request.LockoutPolicy == nil { + if request.LockoutPolicy == nil || request.PolicyOrgID() != orgID { lockoutPolicy, err := repo.getLockoutPolicy(ctx, orgID) if err != nil { return err } request.LockoutPolicy = lockoutPolicyToDomain(lockoutPolicy) } - if request.PrivacyPolicy == nil { + if request.PrivacyPolicy == nil || request.PolicyOrgID() != orgID { privacyPolicy, err := repo.GetPrivacyPolicy(ctx, orgID) if err != nil { return err } request.PrivacyPolicy = privacyPolicy } - if request.LabelPolicy == nil { + if request.LabelPolicy == nil || request.PolicyOrgID() != orgID { labelPolicy, err := repo.getLabelPolicy(ctx, request.PrivateLabelingOrgID(orgID)) if err != nil { return err @@ -694,13 +694,14 @@ func (repo *AuthRequestRepo) fillPolicies(ctx context.Context, request *domain.A } request.DefaultTranslations = defaultLoginTranslations } - if len(request.OrgTranslations) == 0 { + if len(request.OrgTranslations) == 0 || request.PolicyOrgID() != orgID { orgLoginTranslations, err := repo.getLoginTexts(ctx, orgID) if err != nil { return err } request.OrgTranslations = orgLoginTranslations } + request.SetPolicyOrgID(orgID) repo.AuthRequests.CacheAuthRequest(ctx, request) return nil } @@ -887,7 +888,7 @@ func (repo *AuthRequestRepo) checkLoginNameInputForResourceOwner(ctx context.Con } func (repo *AuthRequestRepo) checkLoginPolicyWithResourceOwner(ctx context.Context, request *domain.AuthRequest, resourceOwner string) (err error) { - if request.LoginPolicy == nil { + if request.LoginPolicy == nil || request.PolicyOrgID() != resourceOwner { loginPolicy, idps, err := repo.getLoginPolicyAndIDPProviders(ctx, resourceOwner) if err != nil { return err diff --git a/internal/auth_request/repository/cache/cache.go b/internal/auth_request/repository/cache/cache.go index 9919d717de..a59f291d53 100644 --- a/internal/auth_request/repository/cache/cache.go +++ b/internal/auth_request/repository/cache/cache.go @@ -24,16 +24,20 @@ type AuthRequestCache struct { } func Start(dbClient *database.DB, amountOfCachedAuthRequests uint16) *AuthRequestCache { + cache := &AuthRequestCache{ + client: dbClient, + } idCache, err := lru.New[string, *domain.AuthRequest](int(amountOfCachedAuthRequests)) logging.OnError(err).Info("auth request cache disabled") + if err == nil { + cache.idCache = idCache + } codeCache, err := lru.New[string, *domain.AuthRequest](int(amountOfCachedAuthRequests)) logging.OnError(err).Info("auth request cache disabled") - - return &AuthRequestCache{ - client: dbClient, - idCache: idCache, - codeCache: codeCache, + if err == nil { + cache.codeCache = codeCache } + return cache } func (c *AuthRequestCache) Health(ctx context.Context) error { diff --git a/internal/domain/auth_request.go b/internal/domain/auth_request.go index cf406c7625..0a7704d691 100644 --- a/internal/domain/auth_request.go +++ b/internal/domain/auth_request.go @@ -56,6 +56,16 @@ type AuthRequest struct { DefaultTranslations []*CustomText OrgTranslations []*CustomText SAMLRequestID string + // orgID the policies were last loaded with + policyOrgID string +} + +func (a *AuthRequest) SetPolicyOrgID(id string) { + a.policyOrgID = id +} + +func (a *AuthRequest) PolicyOrgID() string { + return a.policyOrgID } type ExternalUser struct {