diff --git a/core/appeal/service.go b/core/appeal/service.go index 50d69aab3..2fe80af0f 100644 --- a/core/appeal/service.go +++ b/core/appeal/service.go @@ -36,6 +36,8 @@ const ( var TimeNow = time.Now +type ContextKeyIsAdditionalAppealCreation struct{} + type repository interface { BulkUpsert([]*domain.Appeal) error Find(*domain.ListAppealsFilter) ([]*domain.Appeal, error) @@ -141,6 +143,8 @@ func (s *Service) Find(ctx context.Context, filters *domain.ListAppealsFilter) ( // Create record func (s *Service) Create(ctx context.Context, appeals []*domain.Appeal) error { + isAdditionalAppealCreation, _ := ctx.Value(ContextKeyIsAdditionalAppealCreation{}).(bool) + resourceIDs := []string{} for _, a := range appeals { resourceIDs = append(resourceIDs, a.ResourceID) @@ -197,9 +201,15 @@ func (s *Service) Create(ctx context.Context, appeals []*domain.Appeal) error { return fmt.Errorf("validating appeal based on provider: %w", err) } - policy, err := getPolicy(appeal, provider, policies) - if err != nil { - return fmt.Errorf("retrieving policy: %w", err) + var policy *domain.Policy + if isAdditionalAppealCreation && appeal.PolicyID != "" && appeal.PolicyVersion != 0 { + policy = policies[appeal.PolicyID][appeal.PolicyVersion] + } else { + var err error + policy, err = getPolicy(appeal, provider, policies) + if err != nil { + return fmt.Errorf("retrieving policy: %w", err) + } } if err := s.addCreatorDetails(appeal, policy); err != nil { @@ -776,6 +786,7 @@ func (s *Service) handleAppealRequirements(ctx context.Context, a *domain.Appeal additionalAppeal.PolicyID = aa.Policy.ID additionalAppeal.PolicyVersion = uint(aa.Policy.Version) } + ctx = context.WithValue(ctx, ContextKeyIsAdditionalAppealCreation{}, true) if err := s.Create(ctx, []*domain.Appeal{additionalAppeal}); err != nil { if errors.Is(err, ErrAppealDuplicate) { s.logger.Debug("skipping creating additional appeal", @@ -806,8 +817,11 @@ func (s *Service) CreateAccess(ctx context.Context, a *domain.Appeal) error { policy = p } - if err := s.handleAppealRequirements(ctx, a, policy); err != nil { - return fmt.Errorf("handling appeal requirements: %w", err) + isAdditionalAppealCreation, _ := ctx.Value(ContextKeyIsAdditionalAppealCreation{}).(bool) + if !isAdditionalAppealCreation { + if err := s.handleAppealRequirements(ctx, a, policy); err != nil { + return fmt.Errorf("handling appeal requirements: %w", err) + } } if err := s.providerService.GrantAccess(ctx, a); err != nil { diff --git a/core/appeal/service_test.go b/core/appeal/service_test.go index ab4d9136a..7752e6fef 100644 --- a/core/appeal/service_test.go +++ b/core/appeal/service_test.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-playground/validator/v10" + "github.com/google/uuid" "github.com/odpf/guardian/core/appeal" appealmocks "github.com/odpf/guardian/core/appeal/mocks" "github.com/odpf/guardian/core/provider" @@ -768,6 +769,84 @@ func (s *ServiceTestSuite) TestCreate() { s.Equal(expectedResult, appeals) s.Nil(actualError) }) + + s.Run("additional appeal creation", func() { + s.Run("should use the overridding policy", func() { + ctx := context.WithValue(context.Background(), appeal.ContextKeyIsAdditionalAppealCreation{}, true) + + input := &domain.Appeal{ + ResourceID: uuid.New().String(), + AccountID: "user@example.com", + AccountType: domain.DefaultAppealAccountType, + CreatedBy: "user@example.com", + Role: "test-role", + PolicyID: "test-policy", + PolicyVersion: 99, + } + dummyResource := &domain.Resource{ + ID: input.ResourceID, + ProviderType: "test-provider-type", + ProviderURN: "test-provider-urn", + Type: "test-type", + URN: "test-urn", + } + dummyProvider := &domain.Provider{ + Type: dummyResource.ProviderType, + URN: dummyResource.ProviderURN, + Config: &domain.ProviderConfig{ + Type: dummyResource.ProviderType, + URN: dummyResource.ProviderURN, + Resources: []*domain.ResourceConfig{ + { + Type: dummyResource.Type, + Policy: &domain.PolicyConfig{ + ID: "test-dummy-policy", + Version: 1, + }, + Roles: []*domain.Role{ + { + ID: input.Role, + }, + }, + }, + }, + }, + } + dummyPolicy := &domain.Policy{ + ID: "test-dummy-policy", + Version: 1, + } + overriddingPolicy := &domain.Policy{ + ID: input.PolicyID, + Version: input.PolicyVersion, + Steps: []*domain.Step{ + { + Name: "test-approval", + Strategy: "auto", + ApproveIf: "true", + }, + }, + } + + s.mockResourceService.On("Find", mock.Anything, mock.Anything).Return([]*domain.Resource{dummyResource}, nil).Once() + s.mockProviderService.On("Find", mock.Anything).Return([]*domain.Provider{dummyProvider}, nil).Once() + s.mockPolicyService.On("Find", mock.Anything).Return([]*domain.Policy{dummyPolicy, overriddingPolicy}, nil).Once() + s.mockRepository.On("Find", mock.Anything).Return([]*domain.Appeal{}, nil).Once() + s.mockProviderService.On("ValidateAppeal", mock.Anything, mock.Anything, mock.Anything).Return(nil) + s.mockIAMManager.On("ParseConfig", mock.Anything, mock.Anything).Return(nil, nil) + s.mockIAMManager.On("GetClient", mock.Anything, mock.Anything).Return(s.mockIAMClient, nil) + s.mockIAMClient.On("GetUser", input.AccountID).Return(map[string]interface{}{}, nil) + s.mockApprovalService.On("AdvanceApproval", mock.Anything, mock.Anything).Return(nil) + s.mockRepository.On("BulkUpsert", mock.Anything).Return(nil).Once() + s.mockNotifier.On("Notify", mock.Anything).Return(nil).Once() + s.mockAuditLogger.On("Log", mock.Anything, appeal.AuditKeyBulkInsert, mock.Anything).Return(nil).Once() + + err := s.service.Create(ctx, []*domain.Appeal{input}) + + s.NoError(err) + s.Equal("test-approval", input.Approvals[0].Name) + }) + }) } func (s *ServiceTestSuite) TestCreateAppeal__WithExistingAppealAndWithAutoApprovalSteps() {