Skip to content

Commit

Permalink
fix(policies): Fix policy reference (chainloop-dev#1449)
Browse files Browse the repository at this point in the history
Signed-off-by: Jose I. Paris <[email protected]>
  • Loading branch information
jiparis authored Oct 29, 2024
1 parent 0748db4 commit 02ed0e0
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 125 deletions.
79 changes: 33 additions & 46 deletions app/controlplane/pkg/policies/policyprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,16 @@ func (p *PolicyProvider) Resolve(policyName, orgName, token string) (*schemaapi.
if err != nil {
return nil, nil, fmt.Errorf("failed to resolve policy: %w", err)
}
ref, err := p.queryProvider(endpoint, digest, orgName, token, &policy)
url, err := url.Parse(endpoint)
if err != nil {
return nil, nil, fmt.Errorf("error parsing policy provider URL: %w", err)
}
providerDigest, err := p.queryProvider(url, digest, orgName, token, &policy)
if err != nil {
return nil, nil, fmt.Errorf("failed to resolve policy: %w", err)
}

return &policy, ref, nil
return &policy, createRef(url, policyName, providerDigest, orgName), nil
}

// ResolveGroup calls remote provider for retrieving a policy group definition
Expand All @@ -83,29 +87,27 @@ func (p *PolicyProvider) ResolveGroup(groupName, orgName, token string) (*schema
}

// the policy name might include a digest in the form of <name>@sha256:<digest>
policyName, digest := policies.ExtractDigest(groupName)
groupName, digest := policies.ExtractDigest(groupName)

var group schemaapi.PolicyGroup
endpoint, err := url.JoinPath(p.url, groupsEndpoint, policyName)
endpoint, err := url.JoinPath(p.url, groupsEndpoint, groupName)
if err != nil {
return nil, nil, fmt.Errorf("failed to resolve group: %w", err)
}
ref, err := p.queryProvider(endpoint, digest, orgName, token, &group)
url, err := url.Parse(endpoint)
if err != nil {
return nil, nil, fmt.Errorf("error parsing policy provider URL: %w", err)
}
providerDigest, err := p.queryProvider(url, digest, orgName, token, &group)
if err != nil {
return nil, nil, fmt.Errorf("failed to resolve group: %w", err)
}

return &group, ref, nil
return &group, createRef(url, groupName, providerDigest, orgName), nil
}

func (p *PolicyProvider) queryProvider(path, digest, orgName, token string, out proto.Message) (*PolicyReference, error) {
// craft the URL
uri, err := url.Parse(path)
if err != nil {
return nil, fmt.Errorf("error parsing policy provider URL: %w", err)
}

query := uri.Query()
func (p *PolicyProvider) queryProvider(url *url.URL, digest, orgName, token string, out proto.Message) (string, error) {
query := url.Query()
if digest != "" {
query.Set(digestParam, digest)
}
Expand All @@ -114,75 +116,60 @@ func (p *PolicyProvider) queryProvider(path, digest, orgName, token string, out
query.Set(orgNameParam, orgName)
}

uri.RawQuery = query.Encode()
url.RawQuery = query.Encode()

req, err := http.NewRequest("GET", uri.String(), nil)
req, err := http.NewRequest("GET", url.String(), nil)
if err != nil {
return nil, fmt.Errorf("error creating policy request: %w", err)
return "", fmt.Errorf("error creating policy request: %w", err)
}

req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

// make the request
resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("error executing policy request: %w", err)
return "", fmt.Errorf("error executing policy request: %w", err)
}

if resp.StatusCode != http.StatusOK {
if resp.StatusCode == http.StatusNotFound {
return nil, ErrNotFound
return "", ErrNotFound
}

return nil, fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
return "", fmt.Errorf("expected status code 200 but got %d", resp.StatusCode)
}

resBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading policy response: %w", err)
return "", fmt.Errorf("error reading policy response: %w", err)
}

// unmarshall response
var response ProviderResponse
if err := json.Unmarshal(resBytes, &response); err != nil {
return nil, fmt.Errorf("error unmarshalling policy response: %w", err)
}

ref, err := p.resolveRef(path, response.Digest)
if err != nil {
return nil, fmt.Errorf("error resolving policy reference: %w", err)
return "", fmt.Errorf("error unmarshalling policy response: %w", err)
}

// extract the policy payload from the query response
jsonPolicy, err := json.Marshal(response.Data)
if err != nil {
return nil, fmt.Errorf("error marshalling policy response: %w", err)
return "", fmt.Errorf("error marshalling policy response: %w", err)
}

if err := protojson.Unmarshal(jsonPolicy, out); err != nil {
return nil, fmt.Errorf("error unmarshalling policy response: %w", err)
return "", fmt.Errorf("error unmarshalling policy response: %w", err)
}

return ref, nil
return response.Digest, nil
}

func (p *PolicyProvider) resolveRef(path, digest string) (*PolicyReference, error) {
// Extract hostname from the policy provider URL
uri, err := url.Parse(p.url)
if err != nil {
return nil, fmt.Errorf("error parsing policy provider URL: %w", err)
}

if uri.Host == "" {
return nil, fmt.Errorf("invalid policy provider URL")
}

if path == "" || digest == "" {
return nil, fmt.Errorf("both path and digest are mandatory")
func createRef(policyURL *url.URL, name, digest, orgName string) *PolicyReference {
refURL := fmt.Sprintf("chainloop://%s/%s", policyURL.Host, name)
if orgName != "" {
refURL = fmt.Sprintf("%s?org=%s", refURL, orgName)
}

return &PolicyReference{
URL: fmt.Sprintf("chainloop://%s/%s", uri.Host, path),
URL: refURL,
Digest: digest,
}, nil
}
}
54 changes: 23 additions & 31 deletions app/controlplane/pkg/policies/policyprovider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,53 +16,45 @@
package policies

import (
"net/url"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestResolveRef(t *testing.T) {
func TestCreateRef(t *testing.T) {
testCases := []struct {
name string
providerURL string
policyName string
digest string
want *PolicyReference
wantErr bool
name string
policyURL string
policyName string
digest string
orgName string
want *PolicyReference
}{
{
name: "valid",
providerURL: "https://p1host.com/foo",
policyName: "my-policy",
digest: "my-digest",
want: &PolicyReference{URL: "chainloop://p1host.com/my-policy", Digest: "my-digest"},
name: "base",
policyURL: "https://p1host.com/foo",
policyName: "my-policy",
digest: "my-digest",
want: &PolicyReference{URL: "chainloop://p1host.com/my-policy", Digest: "my-digest"},
},
{
name: "missing digest",
providerURL: "https://p1host.com/foo",
policyName: "my-policy",
wantErr: true,
},
{
name: "missing schema",
providerURL: "p1host.com/foo",
policyName: "my-policy",
wantErr: true,
name: "with org",
policyURL: "https://p1host.com/foo",
policyName: "my-policy",
digest: "my-digest",
orgName: "my-org",
want: &PolicyReference{URL: "chainloop://p1host.com/my-policy?org=my-org", Digest: "my-digest"},
},
}

for _, tc := range testCases {
t.Run(tc.providerURL, func(t *testing.T) {
provider := &PolicyProvider{url: tc.providerURL}

got, err := provider.resolveRef(tc.policyName, tc.digest)
if tc.wantErr {
require.Error(t, err)
return
}

t.Run(tc.name, func(t *testing.T) {
policyURL, err := url.Parse(tc.policyURL)
require.NoError(t, err)
got := createRef(policyURL, tc.policyName, tc.digest, tc.orgName)

assert.Equal(t, tc.want, got)
})
}
Expand Down
13 changes: 6 additions & 7 deletions pkg/policies/group_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,18 @@ import (

pb "github.com/chainloop-dev/chainloop/app/controlplane/api/controlplane/v1"
v1 "github.com/chainloop-dev/chainloop/app/controlplane/api/workflowcontract/v1"
v12 "github.com/chainloop-dev/chainloop/pkg/attestation/crafter/api/attestation/v1"
crv1 "github.com/google/go-containerregistry/pkg/v1"
)

// GroupLoader defines the interface for policy loaders from contract attachments
type GroupLoader interface {
Load(context.Context, *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *v12.ResourceDescriptor, error)
Load(context.Context, *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *PolicyDescriptor, error)
}

// FileGroupLoader loader loads policies from filesystem and HTTPS references using Cosign's blob package
type FileGroupLoader struct{}

func (l *FileGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *v12.ResourceDescriptor, error) {
func (l *FileGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *PolicyDescriptor, error) {
var (
raw []byte
err error
Expand Down Expand Up @@ -68,7 +67,7 @@ func (l *FileGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAtta
// HTTPSGroupLoader loader loads policies from HTTP or HTTPS references
type HTTPSGroupLoader struct{}

func (l *HTTPSGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *v12.ResourceDescriptor, error) {
func (l *HTTPSGroupLoader) Load(_ context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *PolicyDescriptor, error) {
ref, wantDigest := ExtractDigest(attachment.GetRef())

// and do not remove the scheme since we need http(s):// to make the request
Expand Down Expand Up @@ -105,7 +104,7 @@ type ChainloopGroupLoader struct {

type groupWithReference struct {
group *v1.PolicyGroup
reference *v12.ResourceDescriptor
reference *PolicyDescriptor
}

var remoteGroupCache = make(map[string]*groupWithReference)
Expand All @@ -114,7 +113,7 @@ func NewChainloopGroupLoader(client pb.AttestationServiceClient) *ChainloopGroup
return &ChainloopGroupLoader{Client: client}
}

func (c *ChainloopGroupLoader) Load(ctx context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *v12.ResourceDescriptor, error) {
func (c *ChainloopGroupLoader) Load(ctx context.Context, attachment *v1.PolicyGroupAttachment) (*v1.PolicyGroup, *PolicyDescriptor, error) {
ref := attachment.GetRef()

c.cacheMutex.Lock()
Expand Down Expand Up @@ -144,7 +143,7 @@ func (c *ChainloopGroupLoader) Load(ctx context.Context, attachment *v1.PolicyGr
return nil, nil, fmt.Errorf("parsing digest: %w", err)
}

reference := policyReferenceResourceDescriptor(resp.Reference.GetUrl(), h)
reference := policyReferenceResourceDescriptor(providerRef.Name, resp.Reference.GetUrl(), providerRef.OrgName, h)
// cache result
remoteGroupCache[ref] = &groupWithReference{group: resp.GetGroup(), reference: reference}
return resp.GetGroup(), reference, nil
Expand Down
Loading

0 comments on commit 02ed0e0

Please sign in to comment.