diff --git a/internal/query/action_flow.go b/internal/query/action_flow.go index a4716be9a3..303d8584d9 100644 --- a/internal/query/action_flow.go +++ b/internal/query/action_flow.go @@ -63,7 +63,7 @@ type Flow struct { } func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID string) (*Flow, error) { - query, scan := prepareFlowQuery() + query, scan := prepareFlowQuery(flowType) stmt, args, err := query.Where( sq.Eq{ FlowsTriggersColumnFlowType.identifier(): flowType, @@ -188,7 +188,7 @@ func prepareTriggerActionsQuery() (sq.SelectBuilder, func(*sql.Rows) ([]*Action, } } -func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { +func prepareFlowQuery(flowType domain.FlowType) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { return sq.Select( ActionColumnID.identifier(), ActionColumnCreationDate.identifier(), @@ -211,6 +211,7 @@ func prepareFlowQuery() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { func(rows *sql.Rows) (*Flow, error) { flow := &Flow{ TriggerActions: make(map[domain.TriggerType][]*Action), + Type: flowType, } for rows.Next() { var ( diff --git a/internal/query/action_flow_test.go b/internal/query/action_flow_test.go index d8a70d83a0..7151cefdcc 100644 --- a/internal/query/action_flow_test.go +++ b/internal/query/action_flow_test.go @@ -8,6 +8,8 @@ import ( "regexp" "testing" + sq "github.com/Masterminds/squirrel" + "github.com/zitadel/zitadel/internal/domain" ) @@ -23,8 +25,10 @@ func Test_FlowPrepares(t *testing.T) { object interface{} }{ { - name: "prepareFlowQuery no result", - prepare: prepareFlowQuery, + name: "prepareFlowQuery no result", + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) + }, want: want{ sqlExpectations: mockQueries( regexp.QuoteMeta(`SELECT projections.actions.id,`+ @@ -47,11 +51,16 @@ func Test_FlowPrepares(t *testing.T) { nil, ), }, - object: &Flow{TriggerActions: map[domain.TriggerType][]*Action{}}, + object: &Flow{ + TriggerActions: map[domain.TriggerType][]*Action{}, + Type: domain.FlowTypeExternalAuthentication, + }, }, { - name: "prepareFlowQuery one action", - prepare: prepareFlowQuery, + name: "prepareFlowQuery one action", + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) + }, want: want{ sqlExpectations: mockQueries( regexp.QuoteMeta(`SELECT projections.actions.id,`+ @@ -129,8 +138,10 @@ func Test_FlowPrepares(t *testing.T) { }, }, { - name: "prepareFlowQuery multiple actions", - prepare: prepareFlowQuery, + name: "prepareFlowQuery multiple actions", + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) + }, want: want{ sqlExpectations: mockQueries( regexp.QuoteMeta(`SELECT projections.actions.id,`+ @@ -236,8 +247,10 @@ func Test_FlowPrepares(t *testing.T) { }, }, { - name: "prepareFlowQuery no action", - prepare: prepareFlowQuery, + name: "prepareFlowQuery no action", + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) + }, want: want{ sqlExpectations: mockQueries( regexp.QuoteMeta(`SELECT projections.actions.id,`+ @@ -302,8 +315,10 @@ func Test_FlowPrepares(t *testing.T) { }, }, { - name: "prepareFlowQuery sql err", - prepare: prepareFlowQuery, + name: "prepareFlowQuery sql err", + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) + }, want: want{ sqlExpectations: mockQueryErr( regexp.QuoteMeta(`SELECT projections.actions.id,`+