diff --git a/sdk/bulk.go b/sdk/bulk.go index 4e1f3f8e3..e0a7405ed 100644 --- a/sdk/bulk.go +++ b/sdk/bulk.go @@ -12,9 +12,10 @@ import ( // BulkTDF: Reader is TDF Content. Writer writes encrypted data. Error is the error that occurs if decrypting fails. type BulkTDF struct { - Reader io.ReadSeeker - Writer io.Writer - Error error + Reader io.ReadSeeker + Writer io.Writer + Error error + TriggeredObligations Obligations } type BulkDecryptRequest struct { @@ -26,6 +27,15 @@ type BulkDecryptRequest struct { ignoreAllowList bool } +// BulkDecryptPrepared holds the prepared state for bulk decryption +// The PolicyTDF is a map of created policy IDs to their corresponding BulkTDF +// The policy IDs are generated during the prepareDecryptors function +type BulkDecryptPrepared struct { + PolicyTDF map[string]*BulkTDF + tdfDecryptors map[string]decryptor + allRewrapResp map[string][]kaoResult +} + // BulkErrors List of Errors that Failed during Bulk Decryption type BulkErrors []error @@ -116,17 +126,9 @@ func (s SDK) createDecryptor(tdf *BulkTDF, req *BulkDecryptRequest) (decryptor, return nil, fmt.Errorf("unknown tdf type: %s", req.TDFType) } -// BulkDecrypt Decrypts a list of BulkTDF and if a partial failure of TDFs unable to be decrypted, BulkErrors would be returned. -func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { - bulkReq, createError := createBulkRewrapRequest(opts...) 
- if createError != nil { - return fmt.Errorf("failed to create bulk rewrap request: %w", createError) - } - kasRewrapRequests := make(map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest) - tdfDecryptors := make(map[string]decryptor) - policyTDF := make(map[string]*BulkTDF) - - if !bulkReq.ignoreAllowList && len(bulkReq.kasAllowlist) == 0 { //nolint:nestif // if kasAllowlist is not set, we get it from the registry +// setupKasAllowlist configures the KAS allowlist for the bulk request +func (s SDK) setupKasAllowlist(ctx context.Context, bulkReq *BulkDecryptRequest) error { + if !bulkReq.ignoreAllowList && len(bulkReq.kasAllowlist) == 0 { //nolint:nestif // not complex if s.KeyAccessServerRegistry != nil { platformEndpoint, err := s.PlatformConfiguration.platformEndpoint() if err != nil { @@ -145,10 +147,18 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { return errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available") } } + return nil +} + +// prepareDecryptors creates decryptors and rewrap requests for all TDFs +func (s SDK) prepareDecryptors(ctx context.Context, bulkReq *BulkDecryptRequest) (map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, map[string]decryptor, map[string]*BulkTDF) { + kasRewrapRequests := make(map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest) + tdfDecryptors := make(map[string]decryptor) + policyTDF := make(map[string]*BulkTDF) for i, tdf := range bulkReq.TDFs { policyID := fmt.Sprintf("policy-%d", i) - decryptor, err := s.createDecryptor(tdf, bulkReq) //nolint:contextcheck // dont want to change signature of LoadTDF + decryptor, err := s.createDecryptor(tdf, bulkReq) //nolint:contextcheck // context is not used in createDecryptor if err != nil { tdf.Error = err continue @@ -167,9 +177,15 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { } } - kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey) 
+ return kasRewrapRequests, tdfDecryptors, policyTDF +} + +// performRewraps executes all rewrap requests with KAS servers +func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string][]kaoResult, error) { + kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey, fulfillableObligations) allRewrapResp := make(map[string][]kaoResult) var err error + for kasurl, rewrapRequests := range kasRewrapRequests { if bulkReq.ignoreAllowList { s.Logger().Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl)) @@ -186,6 +202,7 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { } continue } + var rewrapResp map[string][]kaoResult switch bulkReq.TDFType { case Nano: @@ -198,19 +215,73 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { allRewrapResp[id] = append(allRewrapResp[id], res...) } } + if err != nil { - return fmt.Errorf("bulk rewrap failed: %w", err) + return nil, fmt.Errorf("bulk rewrap failed: %w", err) + } + + return allRewrapResp, nil +} + +// PrepareBulkDecrypt does everything except decrypt from the Bulk Decrypt +// ! Currently you cannot specify fulfillable obligations on an individual TDF basis +func (s SDK) PrepareBulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) (*BulkDecryptPrepared, error) { + bulkReq, createError := createBulkRewrapRequest(opts...) 
+ if createError != nil { + return nil, fmt.Errorf("failed to create bulk rewrap request: %w", createError) + } + + // Setup KAS allowlist + if err := s.setupKasAllowlist(ctx, bulkReq); err != nil { + return nil, err + } + + // Prepare decryptors and rewrap requests + kasRewrapRequests, tdfDecryptors, policyTDF := s.prepareDecryptors(ctx, bulkReq) + + // Use the default fulfillable obligations unless a decryptor is available to provide its own + fulfillableObligations := s.fulfillableObligationFQNs + if len(tdfDecryptors) > 0 { + for _, d := range tdfDecryptors { + fulfillableObligations = getFulfillableObligations(d, s.logger) + break + } + } + + // Perform rewraps + allRewrapResp, err := s.performRewraps(ctx, bulkReq, kasRewrapRequests, fulfillableObligations) + if err != nil { + return nil, err } - var errList []error for id, tdf := range policyTDF { - kaoRes, ok := allRewrapResp[id] + policyRes, ok := allRewrapResp[id] + if !ok { + tdf.Error = errors.New("rewrap did not create a response for this TDF") + continue + } + tdf.TriggeredObligations = Obligations{FQNs: dedupRequiredObligations(policyRes)} + } + + return &BulkDecryptPrepared{ + PolicyTDF: policyTDF, + tdfDecryptors: tdfDecryptors, + allRewrapResp: allRewrapResp, + }, nil +} + +// Allow the bulk decryption to occur +func (bp *BulkDecryptPrepared) BulkDecrypt(ctx context.Context) error { + var errList []error + var err error + for id, tdf := range bp.PolicyTDF { + kaoRes, ok := bp.allRewrapResp[id] if !ok { tdf.Error = errors.New("rewrap did not create a response for this TDF") errList = append(errList, tdf.Error) continue } - decryptor := tdfDecryptors[id] + decryptor := bp.tdfDecryptors[id] if _, err = decryptor.Decrypt(ctx, kaoRes); err != nil { tdf.Error = err errList = append(errList, tdf.Error) @@ -225,9 +296,36 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { return nil } +// BulkDecrypt Decrypts a list of BulkTDF and if a partial failure of TDFs unable to be 
decrypted, BulkErrors would be returned. +func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error { + prepared, err := s.PrepareBulkDecrypt(ctx, opts...) + if err != nil { + return err + } + + return prepared.BulkDecrypt(ctx) +} + func (b *BulkDecryptRequest) appendTDFs(tdfs ...*BulkTDF) { b.TDFs = append( b.TDFs, tdfs..., ) } + +func getFulfillableObligations(decryptor decryptor, logger *slog.Logger) []string { + if decryptor == nil { + logger.Warn("decryptor is nil, cannot populate obligations") + return make([]string, 0) + } + + switch d := decryptor.(type) { + case *tdf3DecryptHandler: + return d.reader.config.fulfillableObligationFQNs + case *NanoTDFDecryptHandler: + return d.config.fulfillableObligationFQNs + default: + logger.Warn("unknown decryptor type, cannot populate obligations", slog.String("type", fmt.Sprintf("%T", d))) + return make([]string, 0) + } +} diff --git a/sdk/granter_test.go b/sdk/granter_test.go index 19e9519e6..a47c7d6a5 100644 --- a/sdk/granter_test.go +++ b/sdk/granter_test.go @@ -30,6 +30,7 @@ const ( specifiedKas = "https://attr.kas.com/" evenMoreSpecificKas = "https://value.kas.com/" lessSpecificKas = "https://namespace.kas.com/" + obligationKas = "https://obligation.kas.com/" fakePem = mockRSAPublicKey1 ) @@ -75,6 +76,21 @@ var ( mpc, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/c") mpd, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/d") mpu, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/unspecified") + + // Attributes for testing obligations + + OBLIGATIONATTR, _ = NewAttributeNameFQN("https://virtru.com/attr/obligation_test") + oa1, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation_test/value/value1") + oa2, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation_test/value/value2") + oa3, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation_test/value/value3") + obligationWatermark = 
"https://virtru.com/obl/obligation_test/value/watermark" + obligationGeofence = "https://virtru.com/obl/obligation_test/value/geofence" + obligationRedact = "https://virtru.com/obl/obligation_test/value/redact" + obligationMap = map[string]string{ + oa1.key: obligationWatermark, + oa2.key: obligationGeofence, + oa3.key: obligationRedact, + } ) func spongeCase(s string) string { @@ -211,6 +227,14 @@ func mockAttributeFor(fqn AttributeNameFQN) *policy.Attribute { Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF, Fqn: fqn.String(), } + case OBLIGATIONATTR.key: + return &policy.Attribute{ + Id: "OBL", + Namespace: &nsOne, + Name: "obligation", + Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF, + Fqn: fqn.String(), + } } return nil } @@ -452,6 +476,18 @@ func mockValueFor(fqn AttributeValueFQN) *policy.Value { p.Grants = make([]*policy.KeyAccessServer, 1) p.Grants[0] = mockGrant(evenMoreSpecificKas, "r1") } + case OBLIGATIONATTR.key: + switch strings.ToLower(fqn.Value()) { + case "value1": + p.KasKeys = make([]*policy.SimpleKasKey, 1) + p.KasKeys[0] = mockSimpleKasKey(obligationKas, "r3") + case "value2": + p.KasKeys = make([]*policy.SimpleKasKey, 1) + p.KasKeys[0] = mockSimpleKasKey(obligationKas, "r3") + case "value3": + p.KasKeys = make([]*policy.SimpleKasKey, 1) + p.KasKeys[0] = mockSimpleKasKey("https://d.kas/", "e1") + } } return &p } diff --git a/sdk/kas_client.go b/sdk/kas_client.go index e8e84e7cf..d54fa06fd 100644 --- a/sdk/kas_client.go +++ b/sdk/kas_client.go @@ -3,6 +3,8 @@ package sdk import ( "context" "crypto/sha256" + "encoding/base64" + "encoding/json" "errors" "fmt" "net" @@ -12,6 +14,7 @@ import ( "connectrpc.com/connect" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/structpb" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" @@ -22,9 +25,11 @@ import ( ) const ( - secondsPerMinute = 60 - statusFail = "fail" - statusPermit = "permit" + 
secondsPerMinute = 60 + statusFail = "fail" + statusPermit = "permit" + additionalRewrapContextHeader = "X-Rewrap-Additional-Context" + triggeredObligationsHeader = "X-Required-Obligations" ) type KASClient struct { @@ -35,12 +40,14 @@ type KASClient struct { // Set this to enable legacy, non-batch rewrap requests supportSingleRewrapEndpoint bool + fulfillableObligations []string } type kaoResult struct { - SymmetricKey []byte - Error error - KeyAccessObjectID string + SymmetricKey []byte + Error error + KeyAccessObjectID string + RequiredObligations []string } type decryptor interface { @@ -48,13 +55,22 @@ type decryptor interface { Decrypt(ctx context.Context, results []kaoResult) (int, error) } -func newKASClient(httpClient *http.Client, options []connect.ClientOption, accessTokenSource auth.AccessTokenSource, sessionKey ocrypto.KeyPair) *KASClient { +type obligationContext struct { + FulfillableFQNs []string `json:"fulfillableFQNs"` +} + +type additionalRewrapContext struct { + Obligations obligationContext `json:"obligations"` +} + +func newKASClient(httpClient *http.Client, options []connect.ClientOption, accessTokenSource auth.AccessTokenSource, sessionKey ocrypto.KeyPair, fulfillableObligations []string) *KASClient { return &KASClient{ accessTokenSource: accessTokenSource, httpClient: httpClient, connectOptions: options, sessionKey: sessionKey, supportSingleRewrapEndpoint: true, + fulfillableObligations: fulfillableObligations, } } @@ -72,7 +88,11 @@ func (k *KASClient) makeRewrapRequest(ctx context.Context, requests []*kas.Unsig serviceClient := kasconnect.NewAccessServiceClient(k.httpClient, parsedURL, k.connectOptions...) 
- response, err := serviceClient.Rewrap(ctx, connect.NewRequest(rewrapRequest)) + rewrapReq, err := k.newConnectRewrapRequest(rewrapRequest) + if err != nil { + return nil, fmt.Errorf("error creating rewrap request: %w", err) + } + response, err := serviceClient.Rewrap(ctx, rewrapReq) if err != nil { return upgradeRewrapErrorV1(err, requests) } @@ -82,6 +102,23 @@ func (k *KASClient) makeRewrapRequest(ctx context.Context, requests []*kas.Unsig return response.Msg, nil } +func (k *KASClient) newConnectRewrapRequest(rewrapReq *kas.RewrapRequest) (*connect.Request[kas.RewrapRequest], error) { + req := connect.NewRequest(rewrapReq) + rewrapContext := &additionalRewrapContext{ + Obligations: obligationContext{ + FulfillableFQNs: k.fulfillableObligations, + }, + } + rewrapContextJSON, err := json.Marshal(rewrapContext) + if err != nil { + return nil, fmt.Errorf("error marshaling additional rewrap context: %w", err) + } + + rewrapContextBase64 := base64.StdEncoding.EncodeToString(rewrapContextJSON) + req.Header().Set(additionalRewrapContextHeader, rewrapContextBase64) + return req, nil +} + // convert v1 responses to v2 func upgradeRewrapResponseV1(response *kas.RewrapResponse, requests []*kas.UnsignedRewrapRequest_WithPolicyRequest) { if len(response.GetResponses()) > 0 { @@ -161,10 +198,11 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew for _, results := range response.GetResponses() { var kaoKeys []kaoResult for _, kao := range results.GetResults() { + requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: 
errors.New(kao.GetError())}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO}) } } policyResults[results.GetPolicyId()] = kaoKeys @@ -192,16 +230,17 @@ func (k *KASClient) nanoUnwrap(ctx context.Context, requests ...*kas.UnsignedRew for _, results := range response.GetResponses() { var kaoKeys []kaoResult for _, kao := range results.GetResults() { + requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { wrappedKey := kao.GetKasWrappedKey() key, err := aesGcm.Decrypt(wrappedKey) if err != nil { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO}) } } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO}) } } policyResults[results.GetPolicyId()] = kaoKeys @@ -261,15 +300,16 @@ func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocryp for _, results := range response.GetResponses() { var kaoKeys []kaoResult for _, kao := range results.GetResults() { + requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { key, err := aesGcm.Decrypt(kao.GetKasWrappedKey()) if err != nil { - kaoKeys = append(kaoKeys, 
kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO}) } } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO}) } } policyResults[results.GetPolicyId()] = kaoKeys @@ -277,6 +317,35 @@ func (k *KASClient) processECResponse(response *kas.RewrapResponse, aesGcm ocryp return policyResults, nil } +/* +Metadata will be in the following form, per kao: + + { + "metadata": { + "X-Required-Obligations": [] + } + } +*/ +func (k *KASClient) retrieveObligationsFromMetadata(metadata map[string]*structpb.Value) []string { + var requiredObligations []string + + if metadata == nil { + return requiredObligations + } + + triggerOblsValue, ok := metadata[triggeredObligationsHeader] + if !ok { + return requiredObligations + } + + triggerOblsList := triggerOblsValue.GetListValue().GetValues() + for _, v := range triggerOblsList { + requiredObligations = append(requiredObligations, v.GetStringValue()) + } + + return requiredObligations +} + func (k *KASClient) handleRSAKeyResponse(response *kas.RewrapResponse) (map[string][]kaoResult, error) { clientPrivateKey, err := k.sessionKey.PrivateKeyInPemFormat() if err != nil { @@ -296,15 +365,16 @@ func (k *KASClient) processRSAResponse(response *kas.RewrapResponse, asymDecrypt for _, results := range response.GetResponses() { var kaoKeys []kaoResult for _, kao := range 
results.GetResults() { + requiredObligationsForKAO := k.retrieveObligationsFromMetadata(kao.GetMetadata()) if kao.GetStatus() == statusPermit { key, err := asymDecryption.Decrypt(kao.GetKasWrappedKey()) if err != nil { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: err, RequiredObligations: requiredObligationsForKAO}) } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), SymmetricKey: key, RequiredObligations: requiredObligationsForKAO}) } } else { - kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError())}) + kaoKeys = append(kaoKeys, kaoResult{KeyAccessObjectID: kao.GetKeyAccessObjectId(), Error: errors.New(kao.GetError()), RequiredObligations: requiredObligationsForKAO}) } } policyResults[results.GetPolicyId()] = kaoKeys diff --git a/sdk/kas_client_test.go b/sdk/kas_client_test.go index 001c99dc2..e1f0ecd8c 100644 --- a/sdk/kas_client_test.go +++ b/sdk/kas_client_test.go @@ -1,7 +1,12 @@ package sdk import ( + "bytes" "context" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "io" "net/http" "testing" "time" @@ -18,6 +23,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/structpb" ) type FakeAccessTokenSource struct { @@ -64,7 +70,7 @@ func TestCreatingRequest(t *testing.T) { kasKey, err := ocrypto.NewRSAKeyPair(tdf3KeySize) require.NoError(t, err, "error creating RSA Key") - client := newKASClient(nil, options, tokenSource, &kasKey) + client := newKASClient(nil, options, tokenSource, &kasKey, []string{}) require.NoError(t, err) keyAccess := 
[]*kaspb.UnsignedRewrapRequest_WithPolicyRequest{ @@ -386,3 +392,547 @@ func TestKasKeyCache_Expiration(t *testing.T) { _, exists := cache.c[cacheKey] assert.False(t, exists, "Expired key should be removed from cache") } + +func Test_newConnectRewrapRequest(t *testing.T) { + c := newKASClient(nil, nil, nil, nil, []string{"https://example.com/attr/attr1/value/val1"}) + req, err := c.newConnectRewrapRequest(&kaspb.RewrapRequest{}) + require.NoError(t, err) + actualHeader := req.Header().Get(additionalRewrapContextHeader) + require.NotEmpty(t, actualHeader) + decoded, err := base64.StdEncoding.DecodeString(actualHeader) + require.NoError(t, err) + var rewrapContext additionalRewrapContext + err = json.Unmarshal(decoded, &rewrapContext) + require.NoError(t, err) + require.Len(t, rewrapContext.Obligations.FulfillableFQNs, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", rewrapContext.Obligations.FulfillableFQNs[0]) +} + +func Test_retrieveObligationsFromMetadata(t *testing.T) { + c := newKASClient(nil, nil, nil, nil, nil) + metadata := createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val1", + "https://example.com/attr/attr2/value/val2", + }) + + fqns := c.retrieveObligationsFromMetadata(metadata) + require.Len(t, fqns, 2) + require.Equal(t, "https://example.com/attr/attr1/value/val1", fqns[0]) + require.Equal(t, "https://example.com/attr/attr2/value/val2", fqns[1]) +} + +func Test_retrieveObligationsFromMetadata_NoObligations(t *testing.T) { + c := newKASClient(nil, nil, nil, nil, nil) + fqns := c.retrieveObligationsFromMetadata(createMetadataWithObligations(nil)) + require.Empty(t, fqns) +} + +func Test_retrieveObligationsFromMetadata_NotListValue(t *testing.T) { + c := newKASClient(nil, nil, nil, nil, nil) + metadata := make(map[string]*structpb.Value) + metadata[triggeredObligationsHeader] = &structpb.Value{ + Kind: &structpb.Value_BoolValue{ + BoolValue: true, + }, + } + fqns := 
c.retrieveObligationsFromMetadata(metadata) + require.Empty(t, fqns) +} + +func Test_retrieveObligationsFromMetadata_EmptyList(t *testing.T) { + c := newKASClient(nil, nil, nil, nil, nil) + metadata := make(map[string]*structpb.Value) + metadata[triggeredObligationsHeader] = &structpb.Value{ + Kind: &structpb.Value_ListValue{ + ListValue: &structpb.ListValue{Values: []*structpb.Value{}}, + }, + } + fqns := c.retrieveObligationsFromMetadata(metadata) + require.Empty(t, fqns) +} + +func Test_processRSAResponse(t *testing.T) { + c := newKASClient(nil, nil, nil, nil, nil) + + // Create a mock AsymDecryption + mockPrivateKey, err := ocrypto.NewRSAKeyPair(2048) + require.NoError(t, err) + privateKeyPEM, err := mockPrivateKey.PrivateKeyInPemFormat() + require.NoError(t, err) + mockDecryptor, err := ocrypto.NewAsymDecryption(privateKeyPEM) + require.NoError(t, err) + + // Create a mock AsymEncryption to create the wrapped key + publicKeyPEM, err := mockPrivateKey.PublicKeyInPemFormat() + require.NoError(t, err) + mockEncryptor, err := ocrypto.NewAsymEncryption(publicKeyPEM) + require.NoError(t, err) + + symmetricKey := []byte("supersecretkey") + wrappedKey, err := mockEncryptor.Encrypt(symmetricKey) + require.NoError(t, err) + + response := &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao1", + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "Access denied", + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val1", + }), + }, + { + KeyAccessObjectId: "kao2", + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "Access denied", + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val2", + }), + }, + }, + }, + { + PolicyId: "policy2", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao1", + Status: "permit", + 
Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{ + KasWrappedKey: wrappedKey, + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val3", + }), + }, + }, + }, + }, + } + + policyResults, err := c.processRSAResponse(response, mockDecryptor) + require.NoError(t, err) + require.Len(t, policyResults, 2) + + result, ok := policyResults["policy1"] + require.True(t, ok) + require.Len(t, result, 2) + require.Nil(t, result[0].SymmetricKey) + require.Nil(t, result[1].SymmetricKey) + require.Len(t, result[0].RequiredObligations, 1) + require.Len(t, result[1].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", result[0].RequiredObligations[0]) + require.Equal(t, "https://example.com/attr/attr1/value/val2", result[1].RequiredObligations[0]) + + result2, ok := policyResults["policy2"] + require.True(t, ok) + require.Len(t, result2, 1) + require.Equal(t, symmetricKey, result2[0].SymmetricKey) + require.Len(t, result2[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val3", result2[0].RequiredObligations[0]) +} + +func Test_processECResponse(t *testing.T) { + c := newKASClient(nil, nil, nil, nil, nil) + + // 1. Set up keys for encryption + kasPublicKey, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) + require.NoError(t, err) + kasPublicKeyPEM, err := kasPublicKey.PublicKeyInPemFormat() + require.NoError(t, err) + + clientPrivateKey, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) + require.NoError(t, err) + clientPrivateKeyPEM, err := clientPrivateKey.PrivateKeyInPemFormat() + require.NoError(t, err) + + // 2. 
Compute shared secret and derive session key (for encryption) + ecdhKey, err := ocrypto.ComputeECDHKey([]byte(clientPrivateKeyPEM), []byte(kasPublicKeyPEM)) + require.NoError(t, err) + + digest := sha256.New() + digest.Write([]byte("TDF")) + salt := digest.Sum(nil) + sessionKey, err := ocrypto.CalculateHKDF(salt, ecdhKey) + require.NoError(t, err) + + // 3. Create AES-GCM cipher for encryption + encryptor, err := ocrypto.NewAESGcm(sessionKey) + require.NoError(t, err) + + symmetricKey2 := []byte("supersecretkey2") + wrappedKey2, err := encryptor.Encrypt(symmetricKey2) + require.NoError(t, err) + + // 5. Create mock response with multiple policies + response := &kaspb.RewrapResponse{ + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao1", + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "Access denied", + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val1", + }), + }, + { + KeyAccessObjectId: "kao2", + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "Access denied", + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val2", + }), + }, + }, + }, + { + PolicyId: "policy2", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao1", + Status: "permit", + Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{ + KasWrappedKey: wrappedKey2, + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr2/value/val2", + }), + }, + }, + }, + }, + } + + // 6. Create AES-GCM cipher for decryption (using the same session key) + decryptor, err := ocrypto.NewAESGcm(sessionKey) + require.NoError(t, err) + + // 7. Process the response + policyResults, err := c.processECResponse(response, decryptor) + require.NoError(t, err) + require.Len(t, policyResults, 2) + + // 8. 
Assertions for policy1 + result1, ok := policyResults["policy1"] + require.True(t, ok) + require.Len(t, result1, 2) + require.Nil(t, result1[0].SymmetricKey) + require.Nil(t, result1[1].SymmetricKey) + require.Len(t, result1[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", result1[0].RequiredObligations[0]) + require.Len(t, result1[1].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val2", result1[1].RequiredObligations[0]) + + // 9. Assertions for policy2 + result2, ok := policyResults["policy2"] + require.True(t, ok) + require.Len(t, result2, 1) + require.Equal(t, symmetricKey2, result2[0].SymmetricKey) + require.Len(t, result2[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr2/value/val2", result2[0].RequiredObligations[0]) +} + +type mockService interface { + Process(req *http.Request) (*http.Response, error) +} + +type MockKas struct { + t *testing.T + obligations map[string][]string // policyID -> obligations + policyDecisions map[string]string // policyID -> "permit" or "fail" +} + +func (f *MockKas) Process(req *http.Request) (*http.Response, error) { + // 1. KAS generates its own ephemeral keypair for the ECDH exchange. + kasKeypair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) + require.NoError(f.t, err) + kasPublicKeyPEM, err := kasKeypair.PublicKeyInPemFormat() + require.NoError(f.t, err) + kasPrivateKeyPEM, err := kasKeypair.PrivateKeyInPemFormat() + require.NoError(f.t, err) + + // 2. Extract the client's public key from the incoming request. 
+ bodyBytes, err := io.ReadAll(req.Body) + require.NoError(f.t, err) + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) // Restore body + + var bodyJSON map[string]interface{} + err = json.Unmarshal(bodyBytes, &bodyJSON) + require.NoError(f.t, err) + signedRequestToken, ok := bodyJSON["signedRequestToken"].(string) + require.True(f.t, ok) + + // We need a public key to verify the token, but for this mock we can parse without verification. + token, err := jwt.ParseString(signedRequestToken, jwt.WithVerify(false)) + require.NoError(f.t, err) + + requestBodyClaim, _ := token.Get("requestBody") + requestBodyJSON, _ := requestBodyClaim.(string) + var unsignedReq kaspb.UnsignedRewrapRequest + err = protojson.Unmarshal([]byte(requestBodyJSON), &unsignedReq) + require.NoError(f.t, err) + clientPublicKeyPEM := unsignedReq.GetClientPublicKey() + + // 3. Compute the shared secret (ECDH) and derive the session key (HKDF). + ecdhKey, err := ocrypto.ComputeECDHKey([]byte(kasPrivateKeyPEM), []byte(clientPublicKeyPEM)) + require.NoError(f.t, err) + sessionKey, err := ocrypto.CalculateHKDF(versionSalt(), ecdhKey) + require.NoError(f.t, err) + + // 4. Encrypt the symmetric key using the derived session key. + encryptor, err := ocrypto.NewAESGcm(sessionKey) + require.NoError(f.t, err) + symmetricKey := []byte("supersecretkey1") + wrappedKey, err := encryptor.Encrypt(symmetricKey) + require.NoError(f.t, err) + + // 5. Construct the KAS rewrap response. 
+ rewrapResponse := &kaspb.RewrapResponse{ + SessionPublicKey: kasPublicKeyPEM, + } + for _, req := range unsignedReq.GetRequests() { + policyID := req.GetPolicy().GetId() + + // Determine if this policy should be permitted or failed + decision := "permit" // default to permit + if f.policyDecisions != nil { + if d, exists := f.policyDecisions[policyID]; exists { + decision = d + } + } + + var kaoResult *kaspb.KeyAccessRewrapResult + var metadata map[string]*structpb.Value + if fqns, exists := f.obligations[policyID]; exists { + metadata = createMetadataWithObligations(fqns) + } + if decision == "permit" { + // For permitted policies: no metadata/obligations + kaoResult = &kaspb.KeyAccessRewrapResult{ + KeyAccessObjectId: req.GetKeyAccessObjects()[0].GetKeyAccessObjectId(), + Status: "permit", + Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{ + KasWrappedKey: wrappedKey, + }, + Metadata: metadata, + } + } else { + kaoResult = &kaspb.KeyAccessRewrapResult{ + KeyAccessObjectId: req.GetKeyAccessObjects()[0].GetKeyAccessObjectId(), + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "denied by policy", + }, + Metadata: metadata, + } + } + + rewrapResponse.Responses = append(rewrapResponse.Responses, &kaspb.PolicyRewrapResult{ + PolicyId: policyID, + Results: []*kaspb.KeyAccessRewrapResult{kaoResult}, + }) + } + + responseBody, err := protojson.Marshal(rewrapResponse) + require.NoError(f.t, err) + + mockHTTPResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(responseBody)), + Header: make(http.Header), + } + mockHTTPResponse.Header.Set("Content-Type", "application/json") + + return mockHTTPResponse, nil +} + +// mockRoundTripper is a mock implementation of http.RoundTripper for testing. +type mockRoundTripper struct { + Response *http.Response + mockService mockService + Err error +} + +// RoundTrip implements the http.RoundTripper interface. 
+func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if m.Err != nil || m.Response != nil { + return m.Response, m.Err + } + return m.mockService.Process(req) +} + +func Test_nanoUnwrap(t *testing.T) { + // 1. Set up the mock HTTP client + mockClient := &http.Client{ + Transport: &mockRoundTripper{mockService: &MockKas{ + t: t, + obligations: map[string][]string{ + "policy1": {"https://example.com/attr/attr1/value/val1"}, + "policy2": {"https://example.com/attr/attr2/value/val2"}, + }, + policyDecisions: map[string]string{ + "policy1": "permit", // policy1 should be permitted + "policy2": "fail", // policy2 should be failed + }, + }}, + } + + // 2. Create the KAS client with the mocked HTTP client + tokenSource := getTokenSource(t) + c := newKASClient(mockClient, []connect.ClientOption{connect.WithProtoJSON()}, tokenSource, nil, nil) + + // 3. Define a dummy request. + dummyKeyAccess := []*kaspb.UnsignedRewrapRequest_WithPolicyRequest{ + { + Policy: &kaspb.UnsignedRewrapRequest_WithPolicy{ + Id: "policy1", + }, + KeyAccessObjects: []*kaspb.UnsignedRewrapRequest_WithKeyAccessObject{ + { + KeyAccessObject: &kaspb.KeyAccess{KasUrl: "https://kas.example.com"}, + }, + }, + }, + { + Policy: &kaspb.UnsignedRewrapRequest_WithPolicy{ + Id: "policy2", + }, + KeyAccessObjects: []*kaspb.UnsignedRewrapRequest_WithKeyAccessObject{ + { + KeyAccessObject: &kaspb.KeyAccess{KasUrl: "https://kas.example.com"}, + }, + }, + }, + } + + // 4. Call nanoUnwrap + policyResults, err := c.nanoUnwrap(t.Context(), dummyKeyAccess...) + require.NoError(t, err) + require.Len(t, policyResults, 2) + + // 5. 
Assertions + // Policy1 should be permitted - has symmetric key, no error, no obligations + result1, ok := policyResults["policy1"] + require.True(t, ok) + require.Len(t, result1, 1) + require.Equal(t, []byte("supersecretkey1"), result1[0].SymmetricKey) + require.NoError(t, result1[0].Error) + require.Len(t, result1[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", result1[0].RequiredObligations[0]) + + // Policy2 should be failed - has error, no symmetric key, has obligations + result2, ok := policyResults["policy2"] + require.True(t, ok) + require.Len(t, result2, 1) + require.Nil(t, result2[0].SymmetricKey, "Failed policies should not have symmetric key") + require.Error(t, result2[0].Error) + require.Contains(t, result2[0].Error.Error(), "denied by policy") + require.Len(t, result2[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr2/value/val2", result2[0].RequiredObligations[0]) +} + +func Test_nanoUnwrap_EmptySPK_WithObligations(t *testing.T) { + // 1. Construct the KAS rewrap response with empty SPK and obligations + rewrapResponse := &kaspb.RewrapResponse{ + SessionPublicKey: "", // Empty Session Public Key + Responses: []*kaspb.PolicyRewrapResult{ + { + PolicyId: "policy1", + Results: []*kaspb.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao1", + Status: "fail", + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "denied by policy", + }, + Metadata: createMetadataWithObligations([]string{ + "https://example.com/attr/attr1/value/val1", + }), + }, + }, + }, + }, + } + + responseBody, err := protojson.Marshal(rewrapResponse) + require.NoError(t, err) + + mockHTTPResponse := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(responseBody)), + Header: make(http.Header), + } + mockHTTPResponse.Header.Set("Content-Type", "application/json") + + // 2. 
Set up the mock HTTP client to return the crafted response + mockClient := &http.Client{ + Transport: &mockRoundTripper{Response: mockHTTPResponse}, + } + + // 3. Create the KAS client with the mocked HTTP client + tokenSource := getTokenSource(t) + c := newKASClient(mockClient, []connect.ClientOption{connect.WithProtoJSON()}, tokenSource, nil, nil) + + // 4. Define a dummy request that matches the response + dummyKeyAccess := []*kaspb.UnsignedRewrapRequest_WithPolicyRequest{ + { + Policy: &kaspb.UnsignedRewrapRequest_WithPolicy{ + Id: "policy1", + }, + KeyAccessObjects: []*kaspb.UnsignedRewrapRequest_WithKeyAccessObject{ + { + KeyAccessObjectId: "kao1", + KeyAccessObject: &kaspb.KeyAccess{KasUrl: "https://kas.example.com"}, + }, + }, + }, + } + + // 5. Call nanoUnwrap + policyResults, err := c.nanoUnwrap(t.Context(), dummyKeyAccess...) + require.NoError(t, err, "nanoUnwrap should not return a top-level error in this case") + require.Len(t, policyResults, 1) + + // 6. Assertions + result, ok := policyResults["policy1"] + require.True(t, ok) + require.Len(t, result, 1) + + // Assert that the KAO result contains an error + require.Error(t, result[0].Error) + require.Contains(t, result[0].Error.Error(), "denied by policy") + require.Nil(t, result[0].SymmetricKey) + + // Assert that obligations are still present despite the KAO error + require.Len(t, result[0].RequiredObligations, 1) + require.Equal(t, "https://example.com/attr/attr1/value/val1", result[0].RequiredObligations[0]) +} + +func createMetadataWithObligations(obligations []string) map[string]*structpb.Value { + metadata := make(map[string]*structpb.Value) + if len(obligations) == 0 { + return metadata + } + + listValue := &structpb.ListValue{} + for _, fqn := range obligations { + listValue.Values = append(listValue.Values, structpb.NewStringValue(fqn)) + } + + metadata[triggeredObligationsHeader] = structpb.NewListValue(listValue) + return metadata +} diff --git a/sdk/nanotdf.go b/sdk/nanotdf.go index 
51af50668..5512d4c1a 100644 --- a/sdk/nanotdf.go +++ b/sdk/nanotdf.go @@ -11,11 +11,14 @@ import ( "fmt" "io" "log/slog" + "net/http" "sync" "time" + "connectrpc.com/connect" "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/protocol/go/policy" + "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/lib/ocrypto" ) @@ -872,6 +875,21 @@ type NanoTDFDecryptHandler struct { config *NanoTDFReaderConfig } +type NanoTDFReader struct { + reader io.ReadSeeker + tokenSource auth.AccessTokenSource + httpClient *http.Client + connectOptions []connect.ClientOption + collectionStore *collectionStore + + header NanoTDFHeader + headerBuf []byte + payloadKey []byte + + config *NanoTDFReaderConfig + requiredObligations *Obligations +} + func createNanoTDFDecryptHandler(reader io.ReadSeeker, writer io.Writer, opts ...NanoTDFReaderOption) (*NanoTDFDecryptHandler, error) { nanoTdfReaderConfig, err := newNanoTDFReaderConfig(opts...) if err != nil { @@ -886,107 +904,69 @@ func createNanoTDFDecryptHandler(reader io.ReadSeeker, writer io.Writer, opts .. 
func (n *NanoTDFDecryptHandler) CreateRewrapRequest(ctx context.Context) (map[string]*kas.UnsignedRewrapRequest_WithPolicyRequest, error) { var err error - var headerSize uint32 - n.header, headerSize, err = NewNanoTDFHeaderFromReader(n.reader) + n.header, n.headerBuf, err = getNanoTDFHeader(n.reader) if err != nil { return nil, err } - _, err = n.reader.Seek(0, io.SeekStart) - if err != nil { - return nil, fmt.Errorf("readSeeker.Seek failed: %w", err) - } - headerBuf := make([]byte, headerSize) - _, err = n.reader.Read(headerBuf) - if err != nil { - return nil, fmt.Errorf("readSeeker.Seek failed: %w", err) - } - kasURL, err := n.header.kasURL.GetURL() - if err != nil { - return nil, err - } - - if n.config.ignoreAllowList { - slog.WarnContext(ctx, "kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasURL)) - } else if !n.config.kasAllowlist.IsAllowed(kasURL) { - return nil, fmt.Errorf("KasAllowlist: kas url %s is not allowed", kasURL) - } - - req := &kas.UnsignedRewrapRequest_WithPolicyRequest{ - KeyAccessObjects: []*kas.UnsignedRewrapRequest_WithKeyAccessObject{ - { - KeyAccessObjectId: "kao-0", - KeyAccessObject: &kas.KeyAccess{KasUrl: kasURL, Header: headerBuf}, - }, - }, - Policy: &kas.UnsignedRewrapRequest_WithPolicy{ - Id: "policy", - }, - Algorithm: "ec:secp256r1", - } - return map[string]*kas.UnsignedRewrapRequest_WithPolicyRequest{kasURL: req}, nil + return createNanoRewrapRequest(ctx, n.config, n.header, n.headerBuf) } func (n *NanoTDFDecryptHandler) Decrypt(ctx context.Context, result []kaoResult) (int, error) { - var err error - if len(result) != 1 { - return 0, errors.New("improper result from kas") - } - - if result[0].Error != nil { - return 0, result[0].Error - } - key := result[0].SymmetricKey - - const ( - kPayloadLoadLengthBufLength = 4 - ) - payloadLengthBuf := make([]byte, kPayloadLoadLengthBufLength) - _, err = n.reader.Read(payloadLengthBuf[1:]) - if err != nil { - return 0, fmt.Errorf(" io.Reader.Read failed :%w", err) - 
} - - payloadLength := binary.BigEndian.Uint32(payloadLengthBuf) - getLogger().DebugContext(ctx, "decrypt", slog.Uint64("payload_length", uint64(payloadLength))) + return decryptNanoTDF(ctx, n.reader, n.writer, result, &n.header) +} - cipherData := make([]byte, payloadLength) - _, err = n.reader.Read(cipherData) +func (s SDK) LoadNanoTDF(ctx context.Context, reader io.ReadSeeker, opts ...NanoTDFReaderOption) (*NanoTDFReader, error) { + nanoTdfReaderConfig, err := newNanoTDFReaderConfig(opts...) if err != nil { - return 0, fmt.Errorf("readSeeker.Seek failed: %w", err) + return nil, fmt.Errorf("newNanoTDFReaderConfig failed: %w", err) } - aesGcm, err := ocrypto.NewAESGcm(key) - if err != nil { - return 0, fmt.Errorf("ocrypto.NewAESGcm failed:%w", err) + useGlobalFulfillableObligations := len(nanoTdfReaderConfig.fulfillableObligationFQNs) == 0 && len(s.fulfillableObligationFQNs) > 0 + if useGlobalFulfillableObligations { + nanoTdfReaderConfig.fulfillableObligationFQNs = s.fulfillableObligationFQNs } - ivPadded := make([]byte, 0, ocrypto.GcmStandardNonceSize) - noncePadding := make([]byte, kIvPadding) - ivPadded = append(ivPadded, noncePadding...) - iv := cipherData[:kNanoTDFIvSize] - ivPadded = append(ivPadded, iv...) 
- - tagSize, err := SizeOfAuthTagForCipher(n.header.sigCfg.cipher) + nanoTdfReaderConfig.kasAllowlist, err = getKasAllowList(ctx, nanoTdfReaderConfig.kasAllowlist, s, nanoTdfReaderConfig.ignoreAllowList) if err != nil { - return 0, fmt.Errorf("SizeOfAuthTagForCipher failed:%w", err) + return nil, err } - decryptedData, err := aesGcm.DecryptWithIVAndTagSize(ivPadded, cipherData[kNanoTDFIvSize:], tagSize) + header, headerBuf, err := getNanoTDFHeader(reader) if err != nil { - return 0, err - } + return nil, fmt.Errorf("getNanoTDFHeader: %w", err) + } + + return &NanoTDFReader{ + reader: reader, + tokenSource: s.tokenSource, + httpClient: s.conn.Client, + connectOptions: s.conn.Options, + config: nanoTdfReaderConfig, + collectionStore: s.collectionStore, + header: header, + headerBuf: headerBuf, + }, nil +} - writeLen, err := n.writer.Write(decryptedData) - if err != nil { - return 0, err +// Do all network behavior (Rewrap request) +func (n *NanoTDFReader) Init(ctx context.Context) error { + if n.payloadKey != nil { + return nil } - return writeLen, nil + return n.getNanoRewrapKey(ctx) } -func (n *NanoTDFDecryptHandler) getRawHeader() []byte { - return n.headerBuf +func (n *NanoTDFReader) DecryptNanoTDF(ctx context.Context, writer io.Writer) (int, error) { + if n.payloadKey == nil { + err := n.getNanoRewrapKey(ctx) + if err != nil { + return 0, err + } + } + + return decryptNanoTDF(ctx, n.reader, writer, []kaoResult{{SymmetricKey: n.payloadKey}}, &n.header) } // ReadNanoTDF - read the nano tdf and return the decrypted data from it @@ -996,70 +976,76 @@ func (s SDK) ReadNanoTDF(writer io.Writer, reader io.ReadSeeker, opts ...NanoTDF // ReadNanoTDFContext - allows cancelling the reader func (s SDK) ReadNanoTDFContext(ctx context.Context, writer io.Writer, reader io.ReadSeeker, opts ...NanoTDFReaderOption) (int, error) { - handler, err := createNanoTDFDecryptHandler(reader, writer, opts...) + r, err := s.LoadNanoTDF(ctx, reader, opts...) 
if err != nil { - return 0, fmt.Errorf("createNanoTDFDecryptHandler failed: %w", err) + return 0, fmt.Errorf("LoadNanoTDF: %w", err) } - if len(handler.config.kasAllowlist) == 0 && !handler.config.ignoreAllowList { //nolint:nestif // handling the case where kasAllowlist is not provided - if s.KeyAccessServerRegistry != nil { - platformEndpoint, err := s.PlatformConfiguration.platformEndpoint() - if err != nil { - return 0, fmt.Errorf("retrieving platformEndpoint failed: %w", err) - } - // retrieve the registered kases if not provided - allowList, err := allowListFromKASRegistry(ctx, s.logger, s.KeyAccessServerRegistry, platformEndpoint) - if err != nil { - return 0, fmt.Errorf("allowListFromKASRegistry failed: %w", err) - } - handler.config.kasAllowlist = allowList - } else { - slog.ErrorContext(ctx, "no KAS allowlist provided and no KeyAccessServerRegistry available") - return 0, errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available") - } - } - - symmetricKey, err := s.getNanoRewrapKey(ctx, handler) + err = r.getNanoRewrapKey(ctx) if err != nil { return 0, fmt.Errorf("getNanoRewrapKey: %w", err) } - return handler.Decrypt(ctx, []kaoResult{{SymmetricKey: symmetricKey}}) + + return r.DecryptNanoTDF(ctx, writer) } -func (s SDK) getNanoRewrapKey(ctx context.Context, decryptor *NanoTDFDecryptHandler) ([]byte, error) { - req, err := decryptor.CreateRewrapRequest(ctx) +/* +* Returns the obligations required for access to the TDF payload, assuming you +* have called Init() or DecryptNanoTDF() to populate obligations. +* +* If obligations are not populated an error is returned. 
+ */ +func (n *NanoTDFReader) Obligations(_ context.Context) (Obligations, error) { + if n.requiredObligations == nil { + return Obligations{}, ErrObligationsNotPopulated + } + + return *n.requiredObligations, nil +} + +func (n *NanoTDFReader) getNanoRewrapKey(ctx context.Context) error { + req, err := createNanoRewrapRequest(ctx, n.config, n.header, n.headerBuf) if err != nil { - return nil, fmt.Errorf("CreateRewrapRequest: %w", err) + return fmt.Errorf("CreateRewrapRequest: %w", err) } - if s.collectionStore != nil { - if key, found := s.collectionStore.get(decryptor.getRawHeader()); found { - return key, nil + if n.collectionStore != nil { + if key, found := n.collectionStore.get(n.headerBuf); found { + n.payloadKey = key + return nil } } - client := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, nil) - kasURL, err := decryptor.header.kasURL.GetURL() + client := newKASClient(n.httpClient, n.connectOptions, n.tokenSource, nil, n.config.fulfillableObligationFQNs) + kasURL, err := n.header.kasURL.GetURL() if err != nil { - return nil, fmt.Errorf("nano header kasUrl: %w", err) + return fmt.Errorf("nano header kasUrl: %w", err) } policyResult, err := client.nanoUnwrap(ctx, req[kasURL]) if err != nil { - return nil, fmt.Errorf("rewrap failed: %w", err) + return fmt.Errorf("rewrap failed: %w", err) } result, ok := policyResult["policy"] if !ok || len(result) != 1 { - return nil, errors.New("policy was not found in rewrap response") + return errors.New("policy was not found in rewrap response") } + + // Populate obligations after policy result is found. 
+ n.requiredObligations = &Obligations{FQNs: result[0].RequiredObligations} + if result[0].Error != nil { - return nil, fmt.Errorf("rewrapError: %w", result[0].Error) + errToReturn := fmt.Errorf("rewrapError: %w", result[0].Error) + return getKasErrorToReturn(result[0].Error, errToReturn) } - if s.collectionStore != nil { - s.collectionStore.store(decryptor.getRawHeader(), result[0].SymmetricKey) + if n.collectionStore != nil { + n.collectionStore.store(n.headerBuf, result[0].SymmetricKey) } - return result[0].SymmetricKey, nil + + n.payloadKey = result[0].SymmetricKey + + return nil } func versionSalt() []byte { @@ -1169,3 +1155,110 @@ func getNanoKasInfoFromBaseKey(s *SDK) (*KASInfo, error) { Algorithm: alg, }, nil } + +func getNanoTDFHeader(reader io.ReadSeeker) (NanoTDFHeader, []byte, error) { + var err error + var headerSize uint32 + var header NanoTDFHeader + header, headerSize, err = NewNanoTDFHeaderFromReader(reader) + if err != nil { + return header, []byte{}, err + } + _, err = reader.Seek(0, io.SeekStart) + if err != nil { + return header, []byte{}, fmt.Errorf("readSeeker.Seek failed: %w", err) + } + + headerBuf := make([]byte, headerSize) + _, err = reader.Read(headerBuf) + if err != nil { + return header, []byte{}, fmt.Errorf("readSeeker.Read failed: %w", err) + } + + return header, headerBuf, nil +} + +func createNanoRewrapRequest(ctx context.Context, config *NanoTDFReaderConfig, header NanoTDFHeader, headerBuf []byte) (map[string]*kas.UnsignedRewrapRequest_WithPolicyRequest, error) { + kasURL, err := header.kasURL.GetURL() + if err != nil { + return nil, err + } + + if config.ignoreAllowList { + slog.WarnContext(ctx, "kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasURL)) + } else if !config.kasAllowlist.IsAllowed(kasURL) { + return nil, fmt.Errorf("KasAllowlist: kas url %s is not allowed", kasURL) + } + + req := &kas.UnsignedRewrapRequest_WithPolicyRequest{ + KeyAccessObjects: []*kas.UnsignedRewrapRequest_WithKeyAccessObject{ 
+ { + KeyAccessObjectId: "kao-0", + KeyAccessObject: &kas.KeyAccess{KasUrl: kasURL, Header: headerBuf}, + }, + }, + Policy: &kas.UnsignedRewrapRequest_WithPolicy{ + Id: "policy", + }, + Algorithm: "ec:secp256r1", + } + return map[string]*kas.UnsignedRewrapRequest_WithPolicyRequest{kasURL: req}, nil +} + +func decryptNanoTDF(ctx context.Context, reader io.ReadSeeker, writer io.Writer, result []kaoResult, header *NanoTDFHeader) (int, error) { + var err error + if len(result) != 1 { + return 0, errors.New("improper result from kas") + } + + if result[0].Error != nil { + return 0, result[0].Error + } + key := result[0].SymmetricKey + + const ( + kPayloadLoadLengthBufLength = 4 + ) + payloadLengthBuf := make([]byte, kPayloadLoadLengthBufLength) + _, err = reader.Read(payloadLengthBuf[1:]) + if err != nil { + return 0, fmt.Errorf(" io.Reader.Read failed :%w", err) + } + + payloadLength := binary.BigEndian.Uint32(payloadLengthBuf) + slog.DebugContext(ctx, "decrypt", slog.Uint64("payload_length", uint64(payloadLength))) + + cipherData := make([]byte, payloadLength) + _, err = reader.Read(cipherData) + if err != nil { + return 0, fmt.Errorf("readSeeker.Seek failed: %w", err) + } + + aesGcm, err := ocrypto.NewAESGcm(key) + if err != nil { + return 0, fmt.Errorf("ocrypto.NewAESGcm failed:%w", err) + } + + ivPadded := make([]byte, 0, ocrypto.GcmStandardNonceSize) + noncePadding := make([]byte, kIvPadding) + ivPadded = append(ivPadded, noncePadding...) + iv := cipherData[:kNanoTDFIvSize] + ivPadded = append(ivPadded, iv...) 
+ + tagSize, err := SizeOfAuthTagForCipher(header.sigCfg.cipher) + if err != nil { + return 0, fmt.Errorf("SizeOfAuthTagForCipher failed:%w", err) + } + + decryptedData, err := aesGcm.DecryptWithIVAndTagSize(ivPadded, cipherData[kNanoTDFIvSize:], tagSize) + if err != nil { + return 0, err + } + + writeLen, err := writer.Write(decryptedData) + if err != nil { + return 0, err + } + + return writeLen, nil +} diff --git a/sdk/nanotdf_config.go b/sdk/nanotdf_config.go index 5efde5663..4e4686ea7 100644 --- a/sdk/nanotdf_config.go +++ b/sdk/nanotdf_config.go @@ -123,8 +123,9 @@ func WithECDSAPolicyBinding() NanoTDFOption { } type NanoTDFReaderConfig struct { - kasAllowlist AllowList - ignoreAllowList bool + kasAllowlist AllowList + ignoreAllowList bool + fulfillableObligationFQNs []string } func newNanoTDFReaderConfig(opt ...NanoTDFReaderOption) (*NanoTDFReaderConfig, error) { @@ -166,3 +167,10 @@ func WithNanoIgnoreAllowlist(ignore bool) NanoTDFReaderOption { return nil } } + +func WithNanoTDFFulfillableObligationFQNs(fqns []string) NanoTDFReaderOption { + return func(c *NanoTDFReaderConfig) error { + c.fulfillableObligationFQNs = fqns + return nil + } +} diff --git a/sdk/nanotdf_test.go b/sdk/nanotdf_test.go index 17f396343..96e0ca368 100644 --- a/sdk/nanotdf_test.go +++ b/sdk/nanotdf_test.go @@ -5,14 +5,23 @@ import ( "context" "crypto/ecdh" "crypto/rand" + "crypto/x509" "encoding/gob" + "encoding/json" + "encoding/pem" "errors" + "fmt" "io" "log/slog" + "net/http" "os" + "strings" "testing" + "connectrpc.com/connect" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/lib/ocrypto" + "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/protocol/go/policy" "github.com/opentdf/platform/protocol/go/wellknownconfiguration" "github.com/stretchr/testify/assert" @@ -20,6 +29,7 @@ import ( "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" + 
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/structpb" ) @@ -27,6 +37,13 @@ const ( nanoFakePem = "pem" ) +// mockTransport is a custom RoundTripper that intercepts HTTP requests +type mockTransport struct { + publicKey string + kid string + kasKeyPair ocrypto.KeyPair // Store the KAS key pair for consistent crypto operations +} + // nanotdfEqual compares two nanoTdf structures for equality. func nanoTDFEqual(a, b *NanoTDFHeader) bool { // Compare kasURL field @@ -400,12 +417,18 @@ func TestDataSet(t *testing.T) { type NanoSuite struct { suite.Suite + mockTransport *mockTransport } func TestNanoTDF(t *testing.T) { suite.Run(t, new(NanoSuite)) } +func (s *NanoSuite) SetupSuite() { + // Create a single mock transport instance for the entire test suite + s.mockTransport = newMockTransport() +} + // mockWellKnownServiceClient is a mock implementation of sdkconnect.WellKnownServiceClient type mockWellKnownServiceClient struct { mockResponse func() (*wellknownconfiguration.GetWellKnownConfigurationResponse, error) @@ -676,3 +699,546 @@ func createMockWellKnownServiceClient(s *suite.Suite, wellKnownConfig map[string }, } } + +// Test suite for NanoTDF Reader functionality +func (s *NanoSuite) Test_NanoTDFReader_LoadNanoTDF() { + // Create a real NanoTDF for testing + sdk, err := s.createTestSDK() + sdk.fulfillableObligationFQNs = []string{"https://example.com/obl/value/obl1"} + s.Require().NoError(err) + nanoTDFData, err := s.createRealNanoTDF(sdk) + s.Require().NoError(err) + reader := bytes.NewReader(nanoTDFData) + + // Test successful load with ignore allowlist + nanoReader, err := sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoIgnoreAllowlist(true)) + s.Require().NoError(err) + s.Require().NotNil(nanoReader) + s.Require().Equal(reader, nanoReader.reader) + s.Require().NotNil(nanoReader.config) + s.Require().True(nanoReader.config.ignoreAllowList) + s.Require().Len(nanoReader.config.fulfillableObligationFQNs, 1) + 
s.Require().Equal("https://example.com/obl/value/obl1", nanoReader.config.fulfillableObligationFQNs[0]) + + // Test with KAS allowlist + allowedURLs := []string{"https://kas.example.com"} + reader = bytes.NewReader(nanoTDFData) // Reset reader + nanoReader2, err := sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoKasAllowlist(allowedURLs)) + s.Require().NoError(err) + s.Require().NotNil(nanoReader2.config.kasAllowlist) + s.Require().True(nanoReader2.config.kasAllowlist.IsAllowed("https://kas.example.com")) + + // Test with fulfillable obligations + obligations := []string{"obligation1", "obligation2"} + reader = bytes.NewReader(nanoTDFData) // Reset reader + nanoReader3, err := sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoTDFFulfillableObligationFQNs(obligations), WithNanoIgnoreAllowlist(true)) + s.Require().NoError(err) + s.Require().Equal(obligations, nanoReader3.config.fulfillableObligationFQNs) + + // Test with invalid reader (nil) + _, err = sdk.LoadNanoTDF(s.T().Context(), nil) + s.Require().Error(err) +} + +func (s *NanoSuite) Test_NanoTDFReader_Init_WithPayloadKeySet() { + // Create a real NanoTDF for testing + sdk, err := s.createTestSDK() + s.Require().NoError(err) + nanoTDFData, err := s.createRealNanoTDF(sdk) + s.Require().NoError(err) + reader := bytes.NewReader(nanoTDFData) + nanoReader, err := sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoIgnoreAllowlist(true)) + s.Require().NoError(err) + + // Test that calling Init twice doesn't cause issues when payloadKey is set + nanoReader.payloadKey = []byte("mock-key") + err = nanoReader.Init(s.T().Context()) + s.Require().NoError(err) // Should return early since payloadKey is set +} + +func (s *NanoSuite) Test_NanoTDFReader_Init_WithoutPayloadKeySet() { + // Create a real NanoTDF for testing + sdk, err := s.createTestSDK() + s.Require().NoError(err) + nanoTDFData, err := s.createRealNanoTDF(sdk) + s.Require().NoError(err) + reader := bytes.NewReader(nanoTDFData) + + nanoReader, err := 
sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoIgnoreAllowlist(true)) + s.Require().NoError(err) + + err = nanoReader.Init(s.T().Context()) + s.Require().NoError(err) + s.Require().NotNil(nanoReader.payloadKey) +} + +func (s *NanoSuite) Test_NanoTDFReader_ObligationsSupport() { + // Create a real NanoTDF for testing + sdk, err := s.createTestSDK() + s.Require().NoError(err) + nanoTDFData, err := s.createRealNanoTDF(sdk) + s.Require().NoError(err) + reader := bytes.NewReader(nanoTDFData) + nanoReader, err := sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoIgnoreAllowlist(true)) + s.Require().NoError(err) + s.Require().Nil(nanoReader.requiredObligations) + + // Mock some triggered obligations as would happen during rewrap + mockObligations := Obligations{ + FQNs: []string{"obligation1", "obligation2"}, + } + nanoReader.requiredObligations = &mockObligations + + // Verify obligations are stored + s.Require().NotNil(nanoReader.requiredObligations) + s.Require().Len(nanoReader.requiredObligations.FQNs, 2) + s.Require().Contains(nanoReader.requiredObligations.FQNs, "obligation1") + s.Require().Contains(nanoReader.requiredObligations.FQNs, "obligation2") +} + +func (s *NanoSuite) Test_NanoTDFReader_DecryptNanoTDF() { + // Create a real NanoTDF for testing + sdk, err := s.createTestSDK() + s.Require().NoError(err) + nanoTDFData, err := s.createRealNanoTDF(sdk) + s.Require().NoError(err) + reader := bytes.NewReader(nanoTDFData) + writer := &bytes.Buffer{} + + nanoReader, err := sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoIgnoreAllowlist(true)) + s.Require().NoError(err) + + _, err = nanoReader.DecryptNanoTDF(s.T().Context(), writer) + s.Require().NoError(err) + s.Require().Equal([]byte("Virtru!!!!"), writer.Bytes()) +} + +func (s *NanoSuite) Test_NanoTDFReader_RealWorkflow() { + // Test the complete workflow: Create -> Load -> Parse Header + originalData := []byte("This is test data for NanoTDF encryption!") + + // Step 1: Create a real NanoTDF + input := 
bytes.NewReader(originalData) + output := &bytes.Buffer{} + + // Create SDK with consistent mock transport + sdk, err := s.createTestSDK() + s.Require().NoError(err) + + config, err := sdk.NewNanoTDFConfig() + s.Require().NoError(err) + + err = config.SetKasURL("https://kas.example.com") + s.Require().NoError(err) + + err = config.SetAttributes([]string{"https://example.com/attr/classification/value/secret"}) + s.Require().NoError(err) + + // The kasPublicKey will be fetched automatically from the mock HTTP client during CreateNanoTDF + + // Create the NanoTDF + tdfSize, err := sdk.CreateNanoTDF(output, input, *config) + s.Require().NoError(err) + s.Require().Positive(tdfSize) + + // Step 2: Load the created NanoTDF + tdfData := output.Bytes() + reader := bytes.NewReader(tdfData) + + nanoReader, err := sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoIgnoreAllowlist(true)) + s.Require().NoError(err) + s.Require().NotNil(nanoReader) + + // Step 3: Validate the header (it should be loaded automatically) + s.Require().NotNil(nanoReader.headerBuf) + s.Require().NotEmpty(nanoReader.headerBuf) + + // Check KAS URL + kasURL, err := nanoReader.header.kasURL.GetURL() + s.Require().NoError(err) + s.Require().Equal("https://kas.example.com", kasURL) + + // Check policy mode and other header fields + s.Require().Equal(PolicyType(2), nanoReader.header.PolicyMode) // Embedded encrypted policy + s.Require().NotNil(nanoReader.header.PolicyBody) + s.Require().NotEmpty(nanoReader.header.PolicyBody) + s.Require().NotNil(nanoReader.header.EphemeralKey) + s.Require().Len(nanoReader.header.EphemeralKey, 33) // secp256r1 compressed key + + _, err = nanoReader.Obligations(s.T().Context()) // Fails bc we don't setup fake authz client here. 
+ s.Require().Error(err) +} + +func (s *NanoSuite) Test_NanoTDF_Obligations() { + sdk, err := s.createTestSDK() + s.Require().NoError(err) + encryptedPolicyTDF, err := s.createRealNanoTDF(sdk) + s.Require().NoError(err) + + // Table-driven test for nano TDF obligations support + testCases := []struct { + name string + fulfillableObligations []string + requiredObligations []string + expectError error + populateObligations []string + }{ + { + name: "Rewrap not called - Error", + expectError: ErrObligationsNotPopulated, + }, + { + name: "Rewrap called - Obligations populated", + expectError: nil, + requiredObligations: []string{"https://example.com/attr/attr1/value/value1"}, + fulfillableObligations: []string{"https://example.com/attr/attr1/value/value1"}, + populateObligations: []string{"https://example.com/attr/attr1/value/value1"}, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + reader := bytes.NewReader(encryptedPolicyTDF) + nanoReader, err := sdk.LoadNanoTDF(s.T().Context(), reader, WithNanoTDFFulfillableObligationFQNs(tc.fulfillableObligations), WithNanoIgnoreAllowlist(true)) + s.Require().NoError(err) + // Check that it has fulfillable obligations set + if len(tc.fulfillableObligations) > 0 { + s.Require().NotNil(nanoReader.config.fulfillableObligationFQNs) + s.Require().Equal(tc.fulfillableObligations, nanoReader.config.fulfillableObligationFQNs) + } else { + s.Require().Empty(nanoReader.config.fulfillableObligationFQNs) + } + + if tc.populateObligations != nil { + nanoReader.requiredObligations = &Obligations{FQNs: tc.populateObligations} + } + + // Initialize the reader (this will parse the header) + obl, err := nanoReader.Obligations(s.T().Context()) + if tc.expectError != nil { + s.Require().Error(err) + s.Require().Empty(obl.FQNs) + s.Require().ErrorIs(err, tc.expectError) + return + } + s.Require().NoError(err) + s.Require().Equal(tc.requiredObligations, obl.FQNs) + + // Call again to verify caching + obl, err = 
nanoReader.Obligations(s.T().Context()) + s.Require().NoError(err) + s.Require().Equal(tc.requiredObligations, obl.FQNs) + }) + } +} + +// Helper function to create real NanoTDF data for testing +func (s *NanoSuite) createRealNanoTDF(sdk *SDK) ([]byte, error) { + // Read the test file content + input := bytes.NewReader([]byte("Virtru!!!!")) + output := &bytes.Buffer{} + + // Create a NanoTDF config + config, err := sdk.NewNanoTDFConfig() + if err != nil { + return nil, err + } + + // Set a test KAS URL + err = config.SetKasURL("https://kas.example.com") + if err != nil { + return nil, err + } + + // Set test attributes + err = config.SetAttributes([]string{"https://example.com/attr/attr1/value/value1"}) + if err != nil { + return nil, err + } + + err = config.SetPolicyMode(NanoTDFPolicyModeDefault) + if err != nil { + return nil, err + } + + // The kasPublicKey will be fetched automatically from the mock HTTP client during CreateNanoTDF + + // Create the NanoTDF + _, err = sdk.CreateNanoTDF(output, input, *config) + if err != nil { + return nil, err + } + + return output.Bytes(), nil +} + +func (s *NanoSuite) createMockHTTPClient() *http.Client { + return &http.Client{ + Transport: s.mockTransport, + } +} + +// Helper function to create a properly configured SDK for testing +func (s *NanoSuite) createTestSDK() (*SDK, error) { + sdk, err := New("http://localhost:8080", WithPlatformConfiguration(PlatformConfiguration{})) + if err != nil { + return nil, err + } + + sdk.conn.Client = s.createMockHTTPClient() + sdk.conn.Options = []connect.ClientOption{connect.WithProtoJSON()} + sdk.tokenSource = getTokenSource(s.T()) + + return sdk, nil +} + +func newMockTransport() *mockTransport { + // Generate a consistent KAS key pair for the mock + kasKeyPair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) + if err != nil { + panic(fmt.Sprintf("Failed to generate KAS key pair: %v", err)) + } + + publicKeyPEM, err := kasKeyPair.PublicKeyInPemFormat() + if err != nil { + 
panic(fmt.Sprintf("Failed to get public key PEM: %v", err)) + } + + return &mockTransport{ + publicKey: publicKeyPEM, + kid: "e1", + kasKeyPair: kasKeyPair, + } +} + +func (m *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // Check if this is a PublicKey request to KAS + if strings.Contains(req.URL.Path, "/kas.AccessService/PublicKey") { + // Create a mock PublicKeyResponse in the format expected by Connect RPC + response := &kas.PublicKeyResponse{ + PublicKey: m.publicKey, + Kid: m.kid, + } + + // Marshal the response to JSON using Connect protocol format + responseJSON, err := json.Marshal(response) + if err != nil { + return nil, fmt.Errorf("failed to marshal mock response: %w", err) + } + + // Create a mock HTTP response + resp := &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + Body: io.NopCloser(bytes.NewReader(responseJSON)), + ContentLength: int64(len(responseJSON)), + Request: req, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + } + + return resp, nil + } + + // Check if this is a Rewrap request to KAS + if strings.Contains(req.URL.Path, "/kas.AccessService/Rewrap") { + return m.handleRewrapRequest(req) + } + + // For any other requests, return an error + return nil, fmt.Errorf("unexpected request to %s", req.URL.String()) +} + +// handleRewrapRequest handles mock rewrap requests for testing +func (m *mockTransport) handleRewrapRequest(req *http.Request) (*http.Response, error) { + // Read the request body + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("failed to read request body: %w", err) + } + + // Parse the Connect RPC request + var bodyJSON map[string]interface{} + err = json.Unmarshal(bodyBytes, &bodyJSON) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal request body: %w", err) + } + + // Extract the signed request token + signedRequestToken, ok := 
bodyJSON["signedRequestToken"].(string) + if !ok { + return nil, errors.New("missing signedRequestToken in request") + } + + // Parse the JWT token without verification (for testing) + token, err := jwt.ParseString(signedRequestToken, jwt.WithVerify(false)) + if err != nil { + return nil, fmt.Errorf("failed to parse JWT token: %w", err) + } + + // Extract the request body from the JWT + requestBodyClaim, ok := token.Get("requestBody") + if !ok { + return nil, errors.New("missing requestBody in JWT") + } + + requestBodyJSON, ok := requestBodyClaim.(string) + if !ok { + return nil, errors.New("requestBody is not a string") + } + + // Parse the unsigned rewrap request + var unsignedReq kas.UnsignedRewrapRequest + err = protojson.Unmarshal([]byte(requestBodyJSON), &unsignedReq) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal unsigned request: %w", err) + } + + // Get the client's public key (for the rewrap session key) + clientPublicKeyPEM := unsignedReq.GetClientPublicKey() + + // Extract the NanoTDF header from the KeyAccessObject to get the ephemeral public key + if len(unsignedReq.GetRequests()) == 0 || len(unsignedReq.GetRequests()[0].GetKeyAccessObjects()) == 0 { + return nil, errors.New("no key access objects in request") + } + + headerBuf := unsignedReq.GetRequests()[0].GetKeyAccessObjects()[0].GetKeyAccessObject().GetHeader() + if len(headerBuf) == 0 { + return nil, errors.New("no header in key access object") + } + + // Parse the NanoTDF header to extract the ephemeral public key + headerReader := bytes.NewReader(headerBuf) + nanoHeader, _, err := NewNanoTDFHeaderFromReader(headerReader) + if err != nil { + return nil, fmt.Errorf("failed to parse NanoTDF header: %w", err) + } + + // Get the KAS private key for ECDH computation + kasPrivateKeyForECDH, err := m.kasKeyPair.PrivateKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("failed to get KAS private key: %w", err) + } + + // Convert ephemeral public key to PEM format for ECDH 
computation + curve, err := nanoHeader.ECCurve() + if err != nil { + return nil, fmt.Errorf("failed to get ECC curve: %w", err) + } + + ephemeralPublicKey, err := ocrypto.UncompressECPubKey(curve, nanoHeader.EphemeralKey) + if err != nil { + return nil, fmt.Errorf("failed to uncompress ephemeral public key: %w", err) + } + + // Convert to PEM format using the same method as the real KAS service + derBytes, err := x509.MarshalPKIXPublicKey(ephemeralPublicKey) + if err != nil { + return nil, fmt.Errorf("failed to marshal ECDSA public key: %w", err) + } + pemBlock := &pem.Block{ + Type: "PUBLIC KEY", + Bytes: derBytes, + } + ephemeralPublicKeyPEM := pem.EncodeToMemory(pemBlock) + + // Compute ECDH shared secret between KAS private key and ephemeral public key + // This recreates the symmetric key that was used during NanoTDF creation + ecdhSharedSecret, err := ocrypto.ComputeECDHKey([]byte(kasPrivateKeyForECDH), ephemeralPublicKeyPEM) + if err != nil { + return nil, fmt.Errorf("failed to compute ECDH shared secret: %w", err) + } + + // Derive the symmetric key using the same process as createNanoTDFSymmetricKey + originalSymmetricKey, err := ocrypto.CalculateHKDF(versionSalt(), ecdhSharedSecret) + if err != nil { + return nil, fmt.Errorf("failed to derive symmetric key: %w", err) + } + + // Now generate a new ephemeral key pair for the rewrap session + rewrapKasKeyPair, err := ocrypto.NewECKeyPair(ocrypto.ECCModeSecp256r1) + if err != nil { + return nil, fmt.Errorf("failed to generate rewrap KAS key pair: %w", err) + } + + rewrapKasPublicKeyPEM, err := rewrapKasKeyPair.PublicKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("failed to get rewrap KAS public key PEM: %w", err) + } + + rewrapKasPrivateKeyPEM, err := rewrapKasKeyPair.PrivateKeyInPemFormat() + if err != nil { + return nil, fmt.Errorf("failed to get rewrap KAS private key PEM: %w", err) + } + + // Compute ECDH shared secret between client's rewrap public key and new KAS ephemeral private key + 
rewrapEcdhKey, err := ocrypto.ComputeECDHKey([]byte(rewrapKasPrivateKeyPEM), []byte(clientPublicKeyPEM)) + if err != nil { + return nil, fmt.Errorf("failed to compute rewrap ECDH key: %w", err) + } + + // Derive session key using HKDF with version salt + sessionKey, err := ocrypto.CalculateHKDF(versionSalt(), rewrapEcdhKey) + if err != nil { + return nil, fmt.Errorf("failed to calculate rewrap session key: %w", err) + } + + // Create AES-GCM encryptor with session key + encryptor, err := ocrypto.NewAESGcm(sessionKey) + if err != nil { + return nil, fmt.Errorf("failed to create AES-GCM encryptor: %w", err) + } + + // Encrypt the original symmetric key with the rewrap session key + wrappedKey, err := encryptor.Encrypt(originalSymmetricKey) + if err != nil { + return nil, fmt.Errorf("failed to encrypt symmetric key: %w", err) + } + + // Build the rewrap response + rewrapResponse := &kas.RewrapResponse{ + SessionPublicKey: rewrapKasPublicKeyPEM, + Responses: []*kas.PolicyRewrapResult{ + { + PolicyId: "policy", + Results: []*kas.KeyAccessRewrapResult{ + { + KeyAccessObjectId: "kao-0", + Status: "permit", + Result: &kas.KeyAccessRewrapResult_KasWrappedKey{ + KasWrappedKey: wrappedKey, + }, + }, + }, + }, + }, + Metadata: make(map[string]*structpb.Value), + } + + // Marshal the response + responseJSON, err := protojson.Marshal(rewrapResponse) + if err != nil { + return nil, fmt.Errorf("failed to marshal rewrap response: %w", err) + } + + // Create HTTP response + resp := &http.Response{ + Status: http.StatusText(http.StatusOK), + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + }, + Body: io.NopCloser(bytes.NewReader(responseJSON)), + ContentLength: int64(len(responseJSON)), + Request: req, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + } + + return resp, nil +} diff --git a/sdk/options.go b/sdk/options.go index e62b8459a..df36a7a9f 100644 --- a/sdk/options.go +++ b/sdk/options.go @@ -44,6 +44,7 @@ type config 
struct { entityResolutionConn *ConnectRPCConnection collectionStore *collectionStore shouldValidatePlatformConnectivity bool + fulfillableObligationFQNs []string logger *slog.Logger } @@ -233,6 +234,13 @@ func WithNoKIDInNano() Option { } } +// WithFulfillableObligationFQNs sets the list of obligation FQNs that can be fulfilled by the client. +func WithFulfillableObligationFQNs(fqns []string) Option { + return func(c *config) { + c.fulfillableObligationFQNs = fqns + } +} + // WithLogger returns an Option that sets a custom slog.Logger for all SDK logging. func WithLogger(logger *slog.Logger) Option { return func(c *config) { diff --git a/sdk/tdf.go b/sdk/tdf.go index eb9de45ad..ec56966a6 100644 --- a/sdk/tdf.go +++ b/sdk/tdf.go @@ -48,6 +48,7 @@ const ( kAssertionSignature = "assertionSig" kAssertionHash = "assertionHash" hexSemverThreshold = "4.3.0" + readActionName = "read" ) // Loads and reads ZTDF files @@ -64,6 +65,11 @@ type Reader struct { payloadKey []byte kasSessionKey ocrypto.KeyPair config TDFReaderConfig + requiredObligations *Obligations +} + +type Obligations struct { + FQNs []string } type TDFObject struct { @@ -83,6 +89,8 @@ type ecKeyWrappedKeyInfo struct { wrappedKey string } +var ErrObligationsNotPopulated = errors.New("obligations not populated") + func (r *tdf3DecryptHandler) Decrypt(ctx context.Context, results []kaoResult) (int, error) { err := r.reader.buildKey(ctx, results) if err != nil { @@ -767,22 +775,14 @@ func (s SDK) LoadTDF(reader io.ReadSeeker, opts ...TDFReaderOption) (*Reader, er return nil, fmt.Errorf("newAssertionConfig failed: %w", err) } - if len(config.kasAllowlist) == 0 && !config.ignoreAllowList { //nolint:nestif // handle the case where kasAllowlist is empty - if s.KeyAccessServerRegistry != nil { - // retrieve the registered kases if not provided - platformEndpoint, err := s.PlatformConfiguration.platformEndpoint() - if err != nil { - return nil, fmt.Errorf("retrieving platformEndpoint failed: %w", err) - } - allowList, err := 
allowListFromKASRegistry(context.Background(), s.logger, s.KeyAccessServerRegistry, platformEndpoint) - if err != nil { - return nil, fmt.Errorf("allowListFromKASRegistry failed: %w", err) - } - config.kasAllowlist = allowList - } else { - slog.Error("no KAS allowlist provided and no KeyAccessServerRegistry available") - return nil, errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available") - } + useGlobalFulfillableObligations := len(config.fulfillableObligationFQNs) == 0 && len(s.fulfillableObligationFQNs) > 0 + if useGlobalFulfillableObligations { + config.fulfillableObligationFQNs = s.fulfillableObligationFQNs + } + + config.kasAllowlist, err = getKasAllowList(context.Background(), config.kasAllowlist, s, config.ignoreAllowList) + if err != nil { + return nil, err } manifest, err := tdfReader.Manifest() @@ -1092,6 +1092,20 @@ func (r *Reader) DataAttributes() ([]string, error) { return attributes, nil } +/* +* Returns the obligations required for access to the TDF payload, assuming you +* have called Init() or WriteTo() to populate obligations. +* +* If obligations are not populated an error is returned. + */ +func (r *Reader) Obligations(_ context.Context) (Obligations, error) { + if r.requiredObligations == nil { + return Obligations{}, ErrObligationsNotPopulated + } + + return *r.requiredObligations, nil +} + /* *WARNING:* Using this function is unsafe since KAS will no longer be able to prevent access to the key. 
@@ -1203,13 +1217,7 @@ func (r *Reader) buildKey(_ context.Context, results []kaoResult) error { if err != nil { errToReturn := fmt.Errorf("kao unwrap failed for split %v: %w", ss, err) - if strings.Contains(err.Error(), codes.InvalidArgument.String()) { - errToReturn = fmt.Errorf("%w: %w", ErrRewrapBadRequest, errToReturn) - } - if strings.Contains(err.Error(), codes.PermissionDenied.String()) { - errToReturn = fmt.Errorf("%w: %w", errRewrapForbidden, errToReturn) - } - skippedSplits[ss] = errToReturn + skippedSplits[ss] = getKasErrorToReturn(err, errToReturn) continue } @@ -1359,7 +1367,7 @@ func (r *Reader) buildKey(_ context.Context, results []kaoResult) error { // Unwraps the payload key, if possible, using the access service func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocognit // Better readability keeping it as is - kasClient := newKASClient(r.httpClient, r.connectOptions, r.tokenSource, r.kasSessionKey) + kasClient := newKASClient(r.httpClient, r.connectOptions, r.tokenSource, r.kasSessionKey, r.config.fulfillableObligationFQNs) var kaoResults []kaoResult reqFail := func(err error, req *kas.UnsignedRewrapRequest_WithPolicyRequest) { @@ -1398,6 +1406,8 @@ func (r *Reader) doPayloadKeyUnwrap(ctx context.Context) error { //nolint:gocogn kaoResults = append(kaoResults, result...) 
} } + // Deduplicate obligations for all kao results + r.requiredObligations = &Obligations{FQNs: dedupRequiredObligations(kaoResults)} return r.buildKey(ctx, kaoResults) } @@ -1515,3 +1525,55 @@ func createKaoTemplateFromKasInfo(kasInfoArr []KASInfo) []kaoTpl { return kaoTemplate } + +func getKasErrorToReturn(err error, defaultError error) error { + errToReturn := defaultError + if strings.Contains(err.Error(), codes.InvalidArgument.String()) { + errToReturn = errors.Join(ErrRewrapBadRequest, errToReturn) + } else if strings.Contains(err.Error(), codes.PermissionDenied.String()) { + errToReturn = errors.Join(ErrRewrapForbidden, errToReturn) + } + + return errToReturn +} + +func getKasAllowList(ctx context.Context, kasAllowList AllowList, s SDK, ignoreAllowList bool) (AllowList, error) { + allowList := kasAllowList + if len(allowList) == 0 && !ignoreAllowList { + if s.KeyAccessServerRegistry == nil { + slog.Error("no KAS allowlist provided and no KeyAccessServerRegistry available") + return nil, errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available") + } + + // retrieve the registered kases if not provided + platformEndpoint, err := s.PlatformConfiguration.platformEndpoint() + if err != nil { + return nil, fmt.Errorf("retrieving platformEndpoint failed: %w", err) + } + allowList, err = allowListFromKASRegistry(ctx, s.logger, s.KeyAccessServerRegistry, platformEndpoint) + if err != nil { + return nil, fmt.Errorf("allowListFromKASRegistry failed: %w", err) + } + } + + return allowList, nil +} + +func dedupRequiredObligations(kaoResults []kaoResult) []string { + seen := make(map[string]struct{}) + dedupedOblgs := make([]string, 0) + for _, kao := range kaoResults { + for _, oblg := range kao.RequiredObligations { + normalizedOblg := strings.TrimSpace(strings.ToLower(oblg)) + if len(normalizedOblg) == 0 { + continue + } + if _, ok := seen[normalizedOblg]; !ok { + seen[normalizedOblg] = struct{}{} + dedupedOblgs = append(dedupedOblgs, 
normalizedOblg) + } + } + } + + return dedupedOblgs +} diff --git a/sdk/tdf_config.go b/sdk/tdf_config.go index dc4ff8a05..081cf57e9 100644 --- a/sdk/tdf_config.go +++ b/sdk/tdf_config.go @@ -270,6 +270,7 @@ type TDFReaderConfig struct { kasSessionKey ocrypto.KeyPair kasAllowlist AllowList // KAS URLs that are allowed to be used for reading TDFs ignoreAllowList bool // If true, the kasAllowlist will be ignored, and all KAS URLs will be allowed + fulfillableObligationFQNs []string } type AllowList map[string]bool @@ -420,6 +421,13 @@ func WithIgnoreAllowlist(ignore bool) TDFReaderOption { } } +func WithTDFFulfillableObligationFQNs(fqns []string) TDFReaderOption { + return func(c *TDFReaderConfig) error { + c.fulfillableObligationFQNs = fqns + return nil + } +} + func withSessionKey(k ocrypto.KeyPair) TDFReaderOption { return func(c *TDFReaderConfig) error { c.kasSessionKey = k diff --git a/sdk/tdf_test.go b/sdk/tdf_test.go index cbd117266..d8274eea8 100644 --- a/sdk/tdf_test.go +++ b/sdk/tdf_test.go @@ -4,11 +4,13 @@ import ( "archive/zip" "bytes" "context" + "crypto" "crypto/ecdsa" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" + "encoding/base64" "encoding/hex" "encoding/json" "encoding/pem" @@ -73,6 +75,7 @@ type tdfTest struct { splitPlan []keySplitStep policy []AttributeValueFQN expectedPlanSize int + opts []TDFReaderOption } type baseKeyTest struct { @@ -304,6 +307,24 @@ type TDFSuite struct { fakeWellKnown map[string]interface{} } +type Policy struct { + UUID string `json:"uuid"` + Body KasPolicyBody `json:"body"` +} + +type KasPolicyBody struct { + DataAttributes []Attribute `json:"dataAttributes"` + Dissem []string `json:"dissem"` +} + +type Attribute struct { + URI string `json:"attribute"` // attribute + PublicKey crypto.PublicKey `json:"pubKey"` // pubKey + ProviderURI string `json:"kasUrl"` // kasUrl + SchemaVersion string `json:"tdf_spec_version,omitempty"` + Name string `json:"displayName"` // displayName +} + func (s *TDFSuite) SetupSuite() 
{ // Set up the test environment s.startBackend() @@ -1905,6 +1926,404 @@ func (s *TDFSuite) Test_KeySplits() { } } +func (s *TDFSuite) Test_Obligations_Decrypt() { + for _, test := range []struct { + n string + fileSize int64 + tdfFileSize float64 + checksum string + requiredObligationFQNs []string + opts []TDFOption + fulfillableObligations []string + attrValueFQNs []AttributeValueFQN + expectError bool + }{ + { + n: "two-attributes-same-kas-with-fulfillable-obligations", + fileSize: 5, + tdfFileSize: 1909, + checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", + requiredObligationFQNs: []string{obligationWatermark, obligationGeofence}, + opts: []TDFOption{WithDataAttributes(oa1.key, oa2.key)}, // Both go to obligationKas + fulfillableObligations: []string{obligationWatermark, obligationGeofence}, + attrValueFQNs: []AttributeValueFQN{oa1, oa2}, + expectError: false, + }, + { + n: "two-attributes-same-kas-no-fulfillable-obligations", + fileSize: 5, + tdfFileSize: 1909, + checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", + requiredObligationFQNs: []string{obligationWatermark, obligationGeofence}, + opts: []TDFOption{WithDataAttributes(oa1.key, oa2.key)}, // Both go to obligationKas + fulfillableObligations: []string{}, // No fulfillable obligations + expectError: true, + }, + { + n: "fulfill-one-of-two-attributes-same-kas", + fileSize: 5, + tdfFileSize: 1909, + checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", + requiredObligationFQNs: []string{obligationWatermark, obligationGeofence}, + opts: []TDFOption{WithDataAttributes(oa1.key, oa2.key)}, + fulfillableObligations: []string{obligationWatermark}, + expectError: true, + }, + } { + s.Run(test.n, func() { + // Create a new SDK instance with limited fulfillable obligations + plainTextFileName := test.n + ".txt" + tdfFileName := plainTextFileName + ".tdf" + decryptedTdfFileName := tdfFileName + ".txt" + + defer func() { + // Remove the 
test files + _ = os.Remove(plainTextFileName) + _ = os.Remove(tdfFileName) + _ = os.Remove(decryptedTdfFileName) + }() + + // test encrypt using the default SDK (which has all fulfillable obligations) + s.testEncrypt(s.sdk, test.opts, plainTextFileName, tdfFileName, tdfTest{ + n: test.n, + fileSize: test.fileSize, + tdfFileSize: test.tdfFileSize, + checksum: test.checksum, + }) + + readSeeker, err := os.Open(tdfFileName) + s.Require().NoError(err) + defer func(readSeeker *os.File) { + err := readSeeker.Close() + s.Require().NoError(err) + }(readSeeker) + + r, err := s.sdk.LoadTDF(readSeeker) + s.Require().NoError(err) + r.config.fulfillableObligationFQNs = test.fulfillableObligations + + if !test.expectError { + // Validate successful decryption + s.testDecryptWithReader(s.sdk, tdfFileName, decryptedTdfFileName, tdfTest{ + n: test.n, + fileSize: test.fileSize, + checksum: test.checksum, + policy: test.attrValueFQNs, + opts: []TDFReaderOption{WithTDFFulfillableObligationFQNs(test.fulfillableObligations)}, + }) + + _, err = r.WriteTo(io.Discard) + s.Require().NoError(err) + } else { + // The decryption should fail due to unmet obligations + _, err = r.WriteTo(io.Discard) + s.Require().Error(err, "Decryption should fail when obligations are not met") + } + + obligations, err := r.Obligations(s.T().Context()) + s.Require().NoError(err) + s.Require().NotNil(obligations, "Obligations should not be nil") + s.Require().Len(obligations.FQNs, len(test.requiredObligationFQNs), "Should have correct number of obligations") + actualObligations := obligations + for _, ob := range test.requiredObligationFQNs { + s.Require().Contains(actualObligations.FQNs, ob, "Actual obligations should contain "+ob) + } + }) + } +} + +func (s *TDFSuite) Test_Obligations() { + originalV2 := s.sdk.AuthorizationV2 + defer func() { + s.sdk.AuthorizationV2 = originalV2 + }() + + // Define test cases covering all code paths in Obligations() + testCases := []struct { + name string + requiredObligations 
[]string + fulfillableObligationFQNs []string + shouldReturnError bool + expectedError error + prepopulatedObligations []string + }{ + { + name: "Rewrap not called - Error", + fulfillableObligationFQNs: []string{obligationWatermark}, + shouldReturnError: true, + expectedError: ErrObligationsNotPopulated, + }, + { + name: "Rewrap called - No Error", + requiredObligations: []string{obligationGeofence}, + fulfillableObligationFQNs: []string{obligationGeofence}, + shouldReturnError: false, + expectedError: nil, + prepopulatedObligations: []string{obligationGeofence}, + }, + } + + for _, tc := range testCases { + s.Run(tc.name, func() { + // Create test files for each test case + plainTextFileName := fmt.Sprintf("obligations_%s.txt", strings.ReplaceAll(tc.name, " ", "_")) + tdfFileName := plainTextFileName + ".tdf" + + defer func() { + _ = os.Remove(plainTextFileName) + _ = os.Remove(tdfFileName) + }() + + // Encrypt the TDF file for testing + opts := []TDFOption{WithKasInformation(s.kases[0].KASInfo), WithDataAttributes(oa1.key, oa2.key, oa3.key)} + s.testEncrypt(s.sdk, opts, plainTextFileName, tdfFileName, tdfTest{ + n: strings.ReplaceAll(tc.name, " ", "_"), + fileSize: 5, + tdfFileSize: 2690, + checksum: "ed968e840d10d2d313a870bc131a4e2c311d7ad09bdf32b3418147221f51a6e2", + }) + + // Load TDF with specified fulfillable obligations + readSeeker, err := os.Open(tdfFileName) + s.Require().NoError(err) + defer func(readSeeker *os.File) { + err := readSeeker.Close() + s.Require().NoError(err) + }(readSeeker) + + var loadOpts []TDFReaderOption + if len(tc.fulfillableObligationFQNs) > 0 { + loadOpts = append(loadOpts, WithTDFFulfillableObligationFQNs(tc.fulfillableObligationFQNs)) + } + + r, err := s.sdk.LoadTDF(readSeeker, loadOpts...) 
+ s.Require().NoError(err) + + // Verify fulfillable obligations were set correctly + if len(tc.fulfillableObligationFQNs) > 0 { + s.Require().Len(r.config.fulfillableObligationFQNs, len(tc.fulfillableObligationFQNs), "Should have correct number of fulfillable obligations") + for _, ob := range tc.fulfillableObligationFQNs { + s.Require().Contains(r.config.fulfillableObligationFQNs, ob, "Should contain fulfillable obligation "+ob) + } + } + + if tc.prepopulatedObligations != nil { + r.requiredObligations = &Obligations{FQNs: tc.prepopulatedObligations} + } + + // First call to Obligations() - returns obligations populated during rewrap, or ErrObligationsNotPopulated + obligations, err := r.Obligations(s.T().Context()) + + if tc.shouldReturnError { + s.Require().Error(err, "Expected error for test case: %s", tc.name) + if tc.expectedError != nil { + s.Require().ErrorIs(err, tc.expectedError, "Error should be of expected type") + } + return + } + + s.Require().NoError(err, "Should not return error for test case: %s", tc.name) + s.Require().NotNil(obligations, "Obligations should not be nil") + s.Require().Len(obligations.FQNs, len(tc.requiredObligations), "Should have correct number of obligations") + for _, ob := range tc.requiredObligations { + s.Require().Contains(obligations.FQNs, ob, "Actual obligations should contain "+ob) + } + + // Second call to Obligations() - this should use cached result + obligations2, err := r.Obligations(s.T().Context()) + s.Require().NoError(err, "Second call should not return error") + s.Require().NotNil(obligations2, "Second call obligations should not be nil") + s.Require().Equal(obligations, obligations2, "Second call should return same obligations") + }) + } +} + +func TestDedupRequiredObligations(t *testing.T) { + testCases := []struct { + name string + kaoResults []kaoResult + expectedResult []string + }{ + { + name: "empty input", + kaoResults: []kaoResult{}, + expectedResult: []string{}, + }, + { + name: "single kao with no obligations", + kaoResults: []kaoResult{ + { 
+ KeyAccessObjectID: "kao-1", + RequiredObligations: []string{}, + }, + }, + expectedResult: []string{}, + }, + { + name: "single kao with single obligation", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedResult: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + name: "single kao with multiple obligations", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + }, + { + name: "multiple kaos with same obligations - should dedupe", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedResult: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + name: "multiple kaos with different obligations", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{"https://demo.com/obl/test/value/geofence"}, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + }, + { + name: "case insensitive deduplication", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{"https://demo.com/obl/test/value/WATERMARK"}, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedResult: 
[]string{"https://demo.com/obl/test/value/watermark"}, + }, + { + name: "whitespace trimming and deduplication", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{" https://demo.com/obl/test/value/watermark "}, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{"https://demo.com/obl/test/value/watermark"}, + }, + }, + expectedResult: []string{"https://demo.com/obl/test/value/watermark"}, + }, + { + name: "complex case - mixed duplicates with case and whitespace variations", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/WATERMARK", + "https://demo.com/obl/test/value/geofence", + }, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{ + " https://demo.com/obl/test/value/watermark ", + "https://demo.com/obl/test/value/ENCRYPTION", + }, + }, + { + KeyAccessObjectID: "kao-3", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/geofence", + "https://demo.com/obl/test/value/encryption", + }, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + "https://demo.com/obl/test/value/encryption", + }, + }, + { + name: "empty string obligations should be normalized", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{ + "", + " ", + "https://demo.com/obl/test/value/watermark", + }, + }, + }, + expectedResult: []string{ + "https://demo.com/obl/test/value/watermark", + }, + }, + { + name: "preserve order of first occurrence", + kaoResults: []kaoResult{ + { + KeyAccessObjectID: "kao-1", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/geofence", + "https://demo.com/obl/test/value/watermark", + }, + }, + { + KeyAccessObjectID: "kao-2", + RequiredObligations: []string{ + "https://demo.com/obl/test/value/watermark", + "https://demo.com/obl/test/value/geofence", + }, + }, 
+ }, + expectedResult: []string{ + "https://demo.com/obl/test/value/geofence", + "https://demo.com/obl/test/value/watermark", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := dedupRequiredObligations(tc.kaoResults) + assert.Equal(t, tc.expectedResult, result, "Deduplication result should match expected") + }) + } +} + func (s *TDFSuite) Test_Autoconfigure() { for index, test := range []tdfTest{ { @@ -2047,7 +2466,7 @@ func (s *TDFSuite) testDecryptWithReader(sdk *SDK, tdfFile, decryptedTdfFileName s.Require().NoError(err) }(readSeeker) - r, err := sdk.LoadTDF(readSeeker) + r, err := sdk.LoadTDF(readSeeker, test.opts...) s.Require().NoError(err) ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(300*time.Minute)) @@ -2145,7 +2564,7 @@ func (s *TDFSuite) startBackend() { {"https://a.kas/", mockRSAPrivateKey1, mockRSAPublicKey1, defaultKID}, {"https://b.kas/", mockRSAPrivateKey2, mockRSAPublicKey2, defaultKID}, {"https://c.kas/", mockRSAPrivateKey3, mockRSAPublicKey3, defaultKID}, - {"https://d.kas/", mockECPrivateKey1, mockECPublicKey1, defaultKID}, + {"https://d.kas/", mockECPrivateKey1, mockECPublicKey1, "e1"}, {"https://e.kas/", mockECPrivateKey2, mockECPublicKey2, defaultKID}, {kasAu, mockRSAPrivateKey1, mockRSAPublicKey1, defaultKID}, {kasCa, mockRSAPrivateKey2, mockRSAPublicKey2, defaultKID}, @@ -2154,6 +2573,7 @@ func (s *TDFSuite) startBackend() { {kasUs, mockRSAPrivateKey1, mockRSAPublicKey1, defaultKID}, {baseKeyURL, mockRSAPrivateKey1, mockRSAPublicKey1, baseKeyKID}, {evenMoreSpecificKas, mockRSAPrivateKey3, mockRSAPublicKey3, "r3"}, + {obligationKas, mockRSAPrivateKey3, mockRSAPublicKey3, "r3"}, } fkar := &FakeKASRegistry{kases: kasesToMake, s: s} @@ -2170,7 +2590,8 @@ func (s *TDFSuite) startBackend() { s: s, privateKey: ki.private, KASInfo: KASInfo{ URL: ki.url, PublicKey: ki.public, KID: ki.kid, Algorithm: "rsa:2048", }, - legakeys: map[string]keyInfo{}, + legakeys: 
map[string]keyInfo{}, + attrToRequiredObligations: obligationMap, } path, handler := attributesconnect.NewAttributesServiceHandler(fa) mux.Handle(path, handler) @@ -2199,7 +2620,8 @@ func (s *TDFSuite) startBackend() { WithClientCredentials("test", "test", nil), withCustomAccessTokenSource(&ats), WithTokenEndpoint("http://localhost:65432/auth/token"), - WithInsecurePlaintextConn()) + WithInsecurePlaintextConn(), + ) s.Require().NoError(err) s.sdk = sdk } @@ -2277,9 +2699,10 @@ func (f *FakeKASRegistry) ListKeyAccessServers(_ context.Context, _ *connect.Req type FakeKas struct { kasconnect.UnimplementedAccessServiceHandler KASInfo - privateKey string - s *TDFSuite - legakeys map[string]keyInfo + privateKey string + s *TDFSuite + legakeys map[string]keyInfo + attrToRequiredObligations map[string]string } func (f *FakeKas) Rewrap(_ context.Context, in *connect.Request[kaspb.RewrapRequest]) (*connect.Response[kaspb.RewrapResponse], error) { @@ -2299,7 +2722,24 @@ func (f *FakeKas) Rewrap(_ context.Context, in *connect.Request[kaspb.RewrapRequ if !ok { return nil, errors.New("requestBody not a string") } - result := f.getRewrapResponse(requestBodyStr) + + // Extract fulfillable obligations from header + var fulfillableObligations []string + if val := in.Header().Get("X-Rewrap-Additional-Context"); val != "" { + decoded, err := base64.StdEncoding.DecodeString(val) + if err == nil { + var rewrapContext struct { + Obligations struct { + FulfillableFQNs []string `json:"fulfillableFQNs"` + } `json:"obligations"` + } + if json.Unmarshal(decoded, &rewrapContext) == nil { + fulfillableObligations = rewrapContext.Obligations.FulfillableFQNs + } + } + } + + result := f.getRewrapResponse(requestBodyStr, fulfillableObligations) return connect.NewResponse(result), nil } @@ -2308,13 +2748,35 @@ func (f *FakeKas) PublicKey(_ context.Context, _ *connect.Request[kaspb.PublicKe return connect.NewResponse(&kaspb.PublicKeyResponse{PublicKey: f.KASInfo.PublicKey, Kid: f.KID}), nil } -func 
(f *FakeKas) getRewrapResponse(rewrapRequest string) *kaspb.RewrapResponse { +func (f *FakeKas) getRewrapResponse(rewrapRequest string, fulfillableObligations []string) *kaspb.RewrapResponse { bodyData := kaspb.UnsignedRewrapRequest{} err := protojson.Unmarshal([]byte(rewrapRequest), &bodyData) f.s.Require().NoError(err, "json.Unmarshal failed") resp := &kaspb.RewrapResponse{} for _, req := range bodyData.GetRequests() { + requiredObligations := f.s.checkPolicyObligations(f.attrToRequiredObligations, req) + if f.KASInfo.URL == f.s.kasTestURLLookup[obligationKas] { + // Only return failures for obligation kas URL + if !f.s.checkObligationsFulfillment(requiredObligations, fulfillableObligations) { + // Return a deny response if obligations are not fulfilled + results := &kaspb.PolicyRewrapResult{PolicyId: req.GetPolicy().GetId()} + for _, kaoReq := range req.GetKeyAccessObjects() { + kaoResult := &kaspb.KeyAccessRewrapResult{ + Result: &kaspb.KeyAccessRewrapResult_Error{ + Error: "forbidden", + }, + Status: "deny", + KeyAccessObjectId: kaoReq.GetKeyAccessObjectId(), + Metadata: createMetadataWithObligations(requiredObligations), + } + results.Results = append(results.Results, kaoResult) + } + resp.Responses = append(resp.Responses, results) + continue + } + } + results := &kaspb.PolicyRewrapResult{PolicyId: req.GetPolicy().GetId()} resp.Responses = append(resp.Responses, results) for _, kaoReq := range req.GetKeyAccessObjects() { @@ -2405,13 +2867,50 @@ func (f *FakeKas) getRewrapResponse(rewrapRequest string) *kaspb.RewrapResponse Result: &kaspb.KeyAccessRewrapResult_KasWrappedKey{KasWrappedKey: entityWrappedKey}, Status: "permit", KeyAccessObjectId: kaoReq.GetKeyAccessObjectId(), + Metadata: createMetadataWithObligations(requiredObligations), } results.Results = append(results.Results, kaoResult) } } + return resp } +func (s *TDFSuite) checkPolicyObligations(obligationsMap map[string]string, req *kaspb.UnsignedRewrapRequest_WithPolicyRequest) []string { + var 
requiredObligations []string + sDecPolicy, policyErr := base64.StdEncoding.DecodeString(req.GetPolicy().GetBody()) + policy := &Policy{} + if policyErr == nil { + policyErr = json.Unmarshal(sDecPolicy, policy) + if policyErr != nil { + return requiredObligations + } + } + for _, attr := range policy.Body.DataAttributes { + if val, found := obligationsMap[attr.URI]; found { + requiredObligations = append(requiredObligations, val) + } + } + return requiredObligations +} + +func (s *TDFSuite) checkObligationsFulfillment(requiredObligations, fulfillableObligations []string) bool { + // Create a set of fulfillable obligations for fast lookup + fulfillableSet := make(map[string]bool) + for _, obligation := range fulfillableObligations { + fulfillableSet[obligation] = true + } + + // Check if all required obligations are in the fulfillable set + for _, required := range requiredObligations { + if !fulfillableSet[required] { + return false + } + } + + return true +} + func (s *TDFSuite) checkIdentical(file, checksum string) bool { f, err := os.Open(file) s.Require().NoError(err, "os.Open failed") @@ -2560,3 +3059,27 @@ func TestIsLessThanSemver(t *testing.T) { }) } } + +func TestGetKasErrorToReturn(t *testing.T) { + defaultError := errors.New("default KAS error") + + t.Run("InvalidArgument error returns ErrRewrapBadRequest", func(t *testing.T) { + inputError := errors.New("rpc error: code = InvalidArgument desc = invalid request") + result := getKasErrorToReturn(inputError, defaultError) + require.ErrorIs(t, result, ErrRewrapBadRequest) + require.ErrorIs(t, result, defaultError) + }) + + t.Run("PermissionDenied error returns ErrRewrapForbidden", func(t *testing.T) { + inputError := errors.New("rpc error: code = PermissionDenied desc = access denied") + result := getKasErrorToReturn(inputError, defaultError) + require.ErrorIs(t, result, ErrRewrapForbidden) + require.ErrorIs(t, result, defaultError) + }) + + t.Run("Other error returns default error unchanged", func(t 
*testing.T) { + inputError := errors.New("rpc error: code = Internal desc = internal server error") + result := getKasErrorToReturn(inputError, defaultError) + require.Equal(t, defaultError, result) + }) +} diff --git a/sdk/tdferrors.go b/sdk/tdferrors.go index 0ecd659e8..f19331fa8 100644 --- a/sdk/tdferrors.go +++ b/sdk/tdferrors.go @@ -10,7 +10,6 @@ var ( errWriteFailed = errors.New("tdf: io.writer fail to write all bytes") errInvalidKasInfo = errors.New("tdf: kas information is missing") errKasPubKeyMissing = errors.New("tdf: kas public key is missing") - errRewrapForbidden = errors.New("tdf: rewrap request 403") // Exposed tamper detection errors, Catch all possible tamper errors with errors.Is(ErrTampered) ErrTampered = errors.New("tamper detected") @@ -21,6 +20,7 @@ var ( ErrTDFPayloadInvalidOffset = fmt.Errorf("[%w] sdk.Reader.ReadAt: negative offset", ErrTampered) ErrRewrapBadRequest = fmt.Errorf("[%w] tdf: rewrap request 400", ErrTampered) ErrRootSignatureFailure = fmt.Errorf("[%w] tdf: issue verifying root signature", ErrTampered) + ErrRewrapForbidden = errors.New("tdf: rewrap request 403") ) // Custom error struct for Assertion errors