Skip to content

Commit

Permalink
feat(integrations): propagate current org name to policy providers (c…
Browse files Browse the repository at this point in the history
…hainloop-dev#1830)

Signed-off-by: Jose I. Paris <[email protected]>
  • Loading branch information
jiparis authored Feb 18, 2025
1 parent 8e7f5d5 commit 09a9d0a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 19 deletions.
14 changes: 12 additions & 2 deletions app/controlplane/internal/service/attestation.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,12 @@ func (s *AttestationService) GetPolicy(ctx context.Context, req *cpAPI.Attestati
return nil, errors.Forbidden("forbidden", "token not found")
}

remotePolicy, err := s.workflowContractUseCase.GetPolicy(req.GetProvider(), req.GetPolicyName(), req.GetOrgName(), token.Token)
org, err := requireCurrentOrg(ctx)
if err != nil {
return nil, errors.Forbidden("forbidden", "organization not found")
}

remotePolicy, err := s.workflowContractUseCase.GetPolicy(req.GetProvider(), req.GetPolicyName(), req.GetOrgName(), org.Name, token.Token)
if err != nil {
return nil, handleUseCaseErr(err, s.log)
}
Expand All @@ -421,7 +426,12 @@ func (s *AttestationService) GetPolicyGroup(ctx context.Context, req *cpAPI.Atte
return nil, errors.Forbidden("forbidden", "token not found")
}

remoteGroup, err := s.workflowContractUseCase.GetPolicyGroup(req.GetProvider(), req.GetGroupName(), req.GetOrgName(), token.Token)
org, err := requireCurrentOrg(ctx)
if err != nil {
return nil, errors.Forbidden("forbidden", "organization not found")
}

remoteGroup, err := s.workflowContractUseCase.GetPolicyGroup(req.GetProvider(), req.GetGroupName(), req.GetOrgName(), org.Name, token.Token)
if err != nil {
return nil, handleUseCaseErr(err, s.log)
}
Expand Down
18 changes: 12 additions & 6 deletions app/controlplane/pkg/biz/workflowcontract.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ func (uc *WorkflowContractUseCase) findAndValidatePolicy(att *schemav1.PolicyAtt
return nil, err
}

remotePolicy, err := uc.GetPolicy(pr.Provider, pr.Name, pr.OrgName, token)
remotePolicy, err := uc.GetPolicy(pr.Provider, pr.Name, pr.OrgName, "", token)
if err != nil {
return nil, err
}
Expand All @@ -415,7 +415,7 @@ func (uc *WorkflowContractUseCase) findPolicyGroup(att *schemav1.PolicyGroupAtta
// [chainloop://][provider/]name
if loader.IsProviderScheme(att.GetRef()) {
pr := loader.ProviderParts(att.GetRef())
remoteGroup, err := uc.GetPolicyGroup(pr.Provider, pr.Name, pr.OrgName, token)
remoteGroup, err := uc.GetPolicyGroup(pr.Provider, pr.Name, pr.OrgName, "", token)
if err != nil {
return nil, NewErrValidation(fmt.Errorf("failed to get policy group: %w", err))
}
Expand Down Expand Up @@ -492,13 +492,16 @@ type RemotePolicyGroup struct {
}

// GetPolicy retrieves a policy from a policy provider
func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, orgName, token string) (*RemotePolicy, error) {
func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, policyOrgName, currentOrgName, token string) (*RemotePolicy, error) {
provider, err := uc.findProvider(providerName)
if err != nil {
return nil, err
}

policy, ref, err := provider.Resolve(policyName, orgName, token)
policy, ref, err := provider.Resolve(policyName, policyOrgName, policies.ProviderAuthOpts{
Token: token,
OrgName: currentOrgName,
})
if err != nil {
if errors.Is(err, policies.ErrNotFound) {
return nil, NewErrNotFound(fmt.Sprintf("policy %q", policyName))
Expand All @@ -510,13 +513,16 @@ func (uc *WorkflowContractUseCase) GetPolicy(providerName, policyName, orgName,
return &RemotePolicy{Policy: policy, ProviderRef: ref}, nil
}

func (uc *WorkflowContractUseCase) GetPolicyGroup(providerName, groupName, orgName, token string) (*RemotePolicyGroup, error) {
func (uc *WorkflowContractUseCase) GetPolicyGroup(providerName, groupName, groupOrgName, currentOrgName, token string) (*RemotePolicyGroup, error) {
provider, err := uc.findProvider(providerName)
if err != nil {
return nil, err
}

group, ref, err := provider.ResolveGroup(groupName, orgName, token)
group, ref, err := provider.ResolveGroup(groupName, groupOrgName, policies.ProviderAuthOpts{
Token: token,
OrgName: currentOrgName,
})
if err != nil {
if errors.Is(err, policies.ErrNotFound) {
return nil, NewErrNotFound(fmt.Sprintf("policy group %q", groupName))
Expand Down
31 changes: 20 additions & 11 deletions app/controlplane/pkg/policies/policyprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ const (
validateAction = "validate"
groupsEndpoint = "groups"

digestParam = "digest"
orgNameParam = "organization_name"
digestParam = "digest"
orgNameParam = "organization_name"
organizationHeader = "Chainloop-Organization"
)

// PolicyProvider represents an external policy provider
Expand Down Expand Up @@ -72,12 +73,17 @@ type PolicyReference struct {
Digest string
}

type ProviderAuthOpts struct {
Token string
OrgName string
}

var ErrNotFound = fmt.Errorf("policy not found")

// Resolve calls the remote provider for retrieving a policy
func (p *PolicyProvider) Resolve(policyName, orgName, token string) (*schemaapi.Policy, *PolicyReference, error) {
if policyName == "" || token == "" {
return nil, nil, fmt.Errorf("both policyname and token are mandatory")
func (p *PolicyProvider) Resolve(policyName, policyOrgName string, authOpts ProviderAuthOpts) (*schemaapi.Policy, *PolicyReference, error) {
if policyName == "" || authOpts.Token == "" {
return nil, nil, fmt.Errorf("both policyname and auth opts are mandatory")
}

// the policy name might include a digest in the form of <name>@sha256:<digest>
Expand All @@ -94,7 +100,7 @@ func (p *PolicyProvider) Resolve(policyName, orgName, token string) (*schemaapi.
}
// we want to override the orgName with the one in the response
// since we might have resolved it implicitly
providerDigest, orgName, err := p.queryProvider(url, digest, orgName, token, &policy)
providerDigest, orgName, err := p.queryProvider(url, digest, policyOrgName, authOpts, &policy)
if err != nil {
return nil, nil, fmt.Errorf("failed to resolve policy: %w", err)
}
Expand Down Expand Up @@ -170,8 +176,8 @@ func (p *PolicyProvider) ValidateAttachment(att *schemaapi.PolicyAttachment, tok
}

// ResolveGroup calls remote provider for retrieving a policy group definition
func (p *PolicyProvider) ResolveGroup(groupName, orgName, token string) (*schemaapi.PolicyGroup, *PolicyReference, error) {
if groupName == "" || token == "" {
func (p *PolicyProvider) ResolveGroup(groupName, groupOrgName string, authOpts ProviderAuthOpts) (*schemaapi.PolicyGroup, *PolicyReference, error) {
if groupName == "" || authOpts.Token == "" {
return nil, nil, fmt.Errorf("both policyname and token are mandatory")
}

Expand All @@ -189,7 +195,7 @@ func (p *PolicyProvider) ResolveGroup(groupName, orgName, token string) (*schema
}
// we want to override the orgName with the one in the response
// since we might have resolved it implicitly
providerDigest, orgName, err := p.queryProvider(url, digest, orgName, token, &group)
providerDigest, orgName, err := p.queryProvider(url, digest, groupOrgName, authOpts, &group)
if err != nil {
return nil, nil, fmt.Errorf("failed to resolve group: %w", err)
}
Expand All @@ -198,7 +204,7 @@ func (p *PolicyProvider) ResolveGroup(groupName, orgName, token string) (*schema
}

// returns digest, orgname, error
func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName, token string, out proto.Message) (string, string, error) {
func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName string, authOpts ProviderAuthOpts, out proto.Message) (string, string, error) {
query := url.Query()
if digest != "" {
query.Set(digestParam, digest)
Expand All @@ -215,7 +221,10 @@ func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName, token stri
return "", "", fmt.Errorf("error creating policy request: %w", err)
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authOpts.Token))
if authOpts.OrgName != "" {
req.Header.Set(organizationHeader, authOpts.OrgName)
}

// make the request
resp, err := http.DefaultClient.Do(req)
Expand Down

0 comments on commit 09a9d0a

Please sign in to comment.