Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 135 additions & 25 deletions sdk/bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ import (

// BulkTDF: Reader is TDF Content. Writer writes encrypted data. Error is the error that occurs if decrypting fails.
type BulkTDF struct {
Reader io.ReadSeeker
Writer io.Writer
Error error
Reader io.ReadSeeker
Writer io.Writer
Error error
TriggeredObligations Obligations
}

type BulkDecryptRequest struct {
Expand All @@ -26,6 +27,13 @@ type BulkDecryptRequest struct {
ignoreAllowList bool
}

// BulkDecryptPrepared holds the prepared state for bulk decryption
type BulkDecryptPrepared struct {
PolicyTDF map[string]*BulkTDF
tdfDecryptors map[string]decryptor
allRewrapResp map[string]policyResult
}

// BulkErrors List of Errors that Failed during Bulk Decryption
type BulkErrors []error

Expand Down Expand Up @@ -116,17 +124,9 @@ func (s SDK) createDecryptor(tdf *BulkTDF, req *BulkDecryptRequest) (decryptor,
return nil, fmt.Errorf("unknown tdf type: %s", req.TDFType)
}

// BulkDecrypt Decrypts a list of BulkTDF and if a partial failure of TDFs unable to be decrypted, BulkErrors would be returned.
func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
bulkReq, createError := createBulkRewrapRequest(opts...)
if createError != nil {
return fmt.Errorf("failed to create bulk rewrap request: %w", createError)
}
kasRewrapRequests := make(map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest)
tdfDecryptors := make(map[string]decryptor)
policyTDF := make(map[string]*BulkTDF)

if !bulkReq.ignoreAllowList && len(bulkReq.kasAllowlist) == 0 { //nolint:nestif // if kasAllowlist is not set, we get it from the registry
// setupKasAllowlist configures the KAS allowlist for the bulk request
func (s SDK) setupKasAllowlist(ctx context.Context, bulkReq *BulkDecryptRequest) error {
if !bulkReq.ignoreAllowList && len(bulkReq.kasAllowlist) == 0 { //nolint:nestif // not complex
if s.KeyAccessServerRegistry != nil {
platformEndpoint, err := s.PlatformConfiguration.platformEndpoint()
if err != nil {
Expand All @@ -145,10 +145,18 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
return errors.New("no KAS allowlist provided and no KeyAccessServerRegistry available")
}
}
return nil
}

// prepareDecryptors creates decryptors and rewrap requests for all TDFs
func (s SDK) prepareDecryptors(ctx context.Context, bulkReq *BulkDecryptRequest) (map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, map[string]decryptor, map[string]*BulkTDF) {
kasRewrapRequests := make(map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest)
tdfDecryptors := make(map[string]decryptor)
policyTDF := make(map[string]*BulkTDF)

for i, tdf := range bulkReq.TDFs {
policyID := fmt.Sprintf("policy-%d", i)
decryptor, err := s.createDecryptor(tdf, bulkReq) //nolint:contextcheck // dont want to change signature of LoadTDF
decryptor, err := s.createDecryptor(tdf, bulkReq) //nolint:contextcheck // context is not used in createDecryptor
if err != nil {
tdf.Error = err
continue
Expand All @@ -167,9 +175,15 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
}
}

kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey)
allRewrapResp := make(map[string][]kaoResult)
return kasRewrapRequests, tdfDecryptors, policyTDF
}

// performRewraps executes all rewrap requests with KAS servers
func (s SDK) performRewraps(ctx context.Context, bulkReq *BulkDecryptRequest, kasRewrapRequests map[string][]*kas.UnsignedRewrapRequest_WithPolicyRequest, fulfillableObligations []string) (map[string]policyResult, error) {
kasClient := newKASClient(s.conn.Client, s.conn.Options, s.tokenSource, s.kasSessionKey, fulfillableObligations)
allRewrapResp := make(map[string]policyResult)
var err error

for kasurl, rewrapRequests := range kasRewrapRequests {
if bulkReq.ignoreAllowList {
s.Logger().Warn("kasAllowlist is ignored, kas url is allowed", slog.String("kas_url", kasurl))
Expand All @@ -178,15 +192,21 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
for _, req := range rewrapRequests {
id := req.GetPolicy().GetId()
for _, kao := range req.GetKeyAccessObjects() {
allRewrapResp[id] = append(allRewrapResp[id], kaoResult{
policyRewrapResp, ok := allRewrapResp[id]
if !ok {
policyRewrapResp = policyResult{policyID: id, obligations: []string{}, kaoRes: []kaoResult{}}
}
policyRewrapResp.kaoRes = append(policyRewrapResp.kaoRes, kaoResult{
Error: fmt.Errorf("KasAllowlist: kas url %s is not allowed", kasurl),
KeyAccessObjectID: kao.GetKeyAccessObjectId(),
})
allRewrapResp[id] = policyRewrapResp
}
}
continue
}
var rewrapResp map[string][]kaoResult

var rewrapResp map[string]policyResult
switch bulkReq.TDFType {
case Nano:
rewrapResp, err = kasClient.nanoUnwrap(ctx, rewrapRequests...)
Expand All @@ -195,23 +215,86 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
}

for id, res := range rewrapResp {
allRewrapResp[id] = append(allRewrapResp[id], res...)
// ! It's possible that we already created a policyResult for the policy above for a specific KAS URL.
// ! Meaning for another kas url of the same policy we will end up with an empty list of obligations.
// ! This should be fine since we will error out anyways.
if existingResp, ok := allRewrapResp[id]; !ok {
allRewrapResp[id] = res
} else {
// ! Should not need to append obligations since they should be the same for all TDFs under a policy
existingResp.kaoRes = append(existingResp.kaoRes, res.kaoRes...)
allRewrapResp[id] = existingResp
}
}
}

if err != nil {
return fmt.Errorf("bulk rewrap failed: %w", err)
return nil, fmt.Errorf("bulk rewrap failed: %w", err)
}

return allRewrapResp, nil
}

// PrepareBulkDecrypt does everything except decrypt from the Bulk Decrypt
// ! Currently you cannot specify fulfillable obligations on an individual TDF basis
func (s SDK) PrepareBulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) (*BulkDecryptPrepared, error) {
bulkReq, createError := createBulkRewrapRequest(opts...)
if createError != nil {
return nil, fmt.Errorf("failed to create bulk rewrap request: %w", createError)
}

// Setup KAS allowlist
if err := s.setupKasAllowlist(ctx, bulkReq); err != nil {
return nil, err
}

// Prepare decryptors and rewrap requests
kasRewrapRequests, tdfDecryptors, policyTDF := s.prepareDecryptors(ctx, bulkReq)

// Use the default fulfillable obligations unless a decryptor is available to provide its own
fulfillableObligations := s.fulfillableObligationFQNs
if len(tdfDecryptors) > 0 {
for _, d := range tdfDecryptors {
fulfillableObligations = getFulfillableObligations(d)
break
}
}

// Perform rewraps
allRewrapResp, err := s.performRewraps(ctx, bulkReq, kasRewrapRequests, fulfillableObligations)
if err != nil {
return nil, err
}

var errList []error
for id, tdf := range policyTDF {
kaoRes, ok := allRewrapResp[id]
policyRes, ok := allRewrapResp[id]
if !ok {
tdf.Error = errors.New("rewrap did not create a response for this TDF")
continue
}
tdf.TriggeredObligations = Obligations{FQNs: policyRes.obligations}
}

return &BulkDecryptPrepared{
PolicyTDF: policyTDF,
tdfDecryptors: tdfDecryptors,
allRewrapResp: allRewrapResp,
}, nil
}

// Allow the bulk decryption to occur
func (bp *BulkDecryptPrepared) BulkDecrypt(ctx context.Context) error {
var errList []error
var err error
for id, tdf := range bp.PolicyTDF {
policyRes, ok := bp.allRewrapResp[id]
if !ok {
tdf.Error = errors.New("rewrap did not create a response for this TDF")
errList = append(errList, tdf.Error)
continue
}
decryptor := tdfDecryptors[id]
if _, err = decryptor.Decrypt(ctx, kaoRes); err != nil {
decryptor := bp.tdfDecryptors[id]
if _, err = decryptor.Decrypt(ctx, policyRes.kaoRes); err != nil {
tdf.Error = err
errList = append(errList, tdf.Error)
continue
Expand All @@ -225,9 +308,36 @@ func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
return nil
}

// BulkDecrypt Decrypts a list of BulkTDF and if a partial failure of TDFs unable to be decrypted, BulkErrors would be returned.
func (s SDK) BulkDecrypt(ctx context.Context, opts ...BulkDecryptOption) error {
prepared, err := s.PrepareBulkDecrypt(ctx, opts...)
if err != nil {
return err
}

return prepared.BulkDecrypt(ctx)
}

func (b *BulkDecryptRequest) appendTDFs(tdfs ...*BulkTDF) {
b.TDFs = append(
b.TDFs,
tdfs...,
)
}

func getFulfillableObligations(decryptor decryptor) []string {
if decryptor == nil {
slog.Warn("decryptor is nil, cannot populate obligations")
return make([]string, 0)
}

switch d := decryptor.(type) {
case *tdf3DecryptHandler:
return d.reader.config.fulfillableObligationFQNs
case *NanoTDFDecryptHandler:
return d.config.fulfillableObligationFQNs
default:
slog.Warn("unknown decryptor type, cannot populate obligations", slog.String("type", fmt.Sprintf("%T", d)))
return make([]string, 0)
}
}
27 changes: 27 additions & 0 deletions sdk/granter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (
specifiedKas = "https://attr.kas.com/"
evenMoreSpecificKas = "https://value.kas.com/"
lessSpecificKas = "https://namespace.kas.com/"
obligationKas = "https://obligation.kas.com/"
fakePem = mockRSAPublicKey1
)

Expand Down Expand Up @@ -75,6 +76,12 @@ var (
mpc, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/c")
mpd, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/d")
mpu, _ = NewAttributeValueFQN("https://virtru.com/attr/mapped/value/unspecified")

// Attributes for testing obligations
OBLIGATION, _ = NewAttributeNameFQN("https://virtru.com/attr/obligation")
obWatermark, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation/value/watermark")
obRedact, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation/value/redact")
obGeo, _ = NewAttributeValueFQN("https://virtru.com/attr/obligation/value/geofence")
)

func spongeCase(s string) string {
Expand Down Expand Up @@ -211,6 +218,14 @@ func mockAttributeFor(fqn AttributeNameFQN) *policy.Attribute {
Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF,
Fqn: fqn.String(),
}
case OBLIGATION.key:
return &policy.Attribute{
Id: "OBL",
Namespace: &nsOne,
Name: "obligation",
Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF,
Fqn: fqn.String(),
}
}
return nil
}
Expand Down Expand Up @@ -452,6 +467,18 @@ func mockValueFor(fqn AttributeValueFQN) *policy.Value {
p.Grants = make([]*policy.KeyAccessServer, 1)
p.Grants[0] = mockGrant(evenMoreSpecificKas, "r1")
}
case OBLIGATION.key:
switch strings.ToLower(fqn.Value()) {
case "watermark":
p.KasKeys = make([]*policy.SimpleKasKey, 1)
p.KasKeys[0] = mockSimpleKasKey(obligationKas, "r3")
case "redact":
p.KasKeys = make([]*policy.SimpleKasKey, 1)
p.KasKeys[0] = mockSimpleKasKey(obligationKas, "r3")
case "geofence":
p.KasKeys = make([]*policy.SimpleKasKey, 1)
p.KasKeys[0] = mockSimpleKasKey("https://d.kas/", "e1")
}
}
return &p
}
Expand Down
Loading
Loading