Skip to content

Commit

Permalink
chore: implement best practices for transaction through helper method (
Browse files Browse the repository at this point in the history
…chainloop-dev#1605)

Signed-off-by: Jose I. Paris <[email protected]>
  • Loading branch information
jiparis authored Dec 2, 2024
1 parent 8a07f99 commit 0bde1a6
Show file tree
Hide file tree
Showing 8 changed files with 284 additions and 348 deletions.
69 changes: 29 additions & 40 deletions app/controlplane/pkg/data/attestationstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,50 +53,39 @@ func (r *AttestationStateRepo) Initialized(ctx context.Context, runID uuid.UUID)

// baseDigest, when provided will be used to check that it matches the digest of the state currently in the DB
// if the digests do not match, the state has been modified and the caller should retry
func (r *AttestationStateRepo) Save(ctx context.Context, runID uuid.UUID, state []byte, baseDigest string) (err error) {
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return fmt.Errorf("failed to create transaction: %w", err)
}

defer func() {
// Unblock the row if there was an error
if err != nil {
_ = tx.Rollback()
func (r *AttestationStateRepo) Save(ctx context.Context, runID uuid.UUID, state []byte, baseDigest string) error {
return WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// compared the provided digest with the digest of the state in the DB
// TODO: make digest check mandatory on updates
if baseDigest != "" {
// Get the run but BLOCK IT for update
run, err := tx.WorkflowRun.Query().ForUpdate().Where(workflowrun.ID(runID)).Only(ctx)
if err != nil && !ent.IsNotFound(err) {
return fmt.Errorf("failed to read attestation state: %w", err)
} else if run == nil || run.AttestationState == nil {
return biz.NewErrNotFound("attestation state")
}

// calculate the digest of the current state
storedDigest, err := digest(run.AttestationState)
if err != nil {
return fmt.Errorf("failed to calculate digest: %w", err)
}

if baseDigest != storedDigest {
return biz.NewErrAttestationStateConflict(storedDigest, baseDigest)
}
}
}()

// compared the provided digest with the digest of the state in the DB
// TODO: make digest check mandatory on updates
if baseDigest != "" {
// Get the run but BLOCK IT for update
run, err := tx.WorkflowRun.Query().ForUpdate().Where(workflowrun.ID(runID)).Only(ctx)
// Update it in the DB if the digest matches
err := tx.WorkflowRun.UpdateOneID(runID).SetAttestationState(state).Exec(ctx)
if err != nil && !ent.IsNotFound(err) {
return fmt.Errorf("failed to read attestation state: %w", err)
} else if run == nil || run.AttestationState == nil {
return biz.NewErrNotFound("attestation state")
}

// calculate the digest of the current state
storedDigest, err := digest(run.AttestationState)
if err != nil {
return fmt.Errorf("failed to calculate digest: %w", err)
}

if baseDigest != storedDigest {
return biz.NewErrAttestationStateConflict(storedDigest, baseDigest)
return fmt.Errorf("failed to store attestation state: %w", err)
} else if err != nil {
return biz.NewErrNotFound("workflow run")
}
}

// Update it in the DB if the digest matches
err = tx.WorkflowRun.UpdateOneID(runID).SetAttestationState(state).Exec(ctx)
if err != nil && !ent.IsNotFound(err) {
return fmt.Errorf("failed to store attestation state: %w", err)
} else if err != nil {
return biz.NewErrNotFound("workflow run")
}

return tx.Commit()
return nil
})
}

func (r *AttestationStateRepo) Read(ctx context.Context, runID uuid.UUID) ([]byte, string, error) {
Expand Down
150 changes: 67 additions & 83 deletions app/controlplane/pkg/data/casbackend.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,106 +83,90 @@ func (r *CASBackendRepo) FindFallbackBackend(ctx context.Context, orgID uuid.UUI

// Create creates a new CAS backend in the given organization
// If it's set as default, it will unset the previous default backend
func (r *CASBackendRepo) Create(ctx context.Context, opts *biz.CASBackendCreateOpts) (b *biz.CASBackend, err error) {
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create transaction: %w", err)
}
func (r *CASBackendRepo) Create(ctx context.Context, opts *biz.CASBackendCreateOpts) (*biz.CASBackend, error) {
var (
backend *ent.CASBackend
err error
)
if err := WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// 1 - unset default backend for all the other backends in the org
if opts.Default {
if err := tx.CASBackend.Update().
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
Where(casbackend.Default(true)).
SetDefault(false).
Exec(ctx); err != nil {
return fmt.Errorf("failed to clear previous default backend: %w", err)
}
}

defer func() {
// Unblock the row if there was an error
// 2 - create the new backend and set it as default if needed
backend, err = tx.CASBackend.Create().
SetName(opts.Name).
SetOrganizationID(opts.OrgID).
SetLocation(opts.Location).
SetDescription(opts.Description).
SetFallback(opts.Fallback).
SetProvider(opts.Provider).
SetDefault(opts.Default).
SetSecretName(opts.SecretName).
SetMaxBlobSizeBytes(opts.MaxBytes).
Save(ctx)
if err != nil {
_ = tx.Rollback()
}
}()

// 1 - unset default backend for all the other backends in the org
if opts.Default {
if err := tx.CASBackend.Update().
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
Where(casbackend.Default(true)).
SetDefault(false).
Exec(ctx); err != nil {
return nil, fmt.Errorf("failed to clear previous default backend: %w", err)
}
}
if ent.IsConstraintError(err) {
return biz.NewErrAlreadyExists(err)
}

// 2 - create the new backend and set it as default if needed
backend, err := tx.CASBackend.Create().
SetName(opts.Name).
SetOrganizationID(opts.OrgID).
SetLocation(opts.Location).
SetDescription(opts.Description).
SetFallback(opts.Fallback).
SetProvider(opts.Provider).
SetDefault(opts.Default).
SetSecretName(opts.SecretName).
SetMaxBlobSizeBytes(opts.MaxBytes).
Save(ctx)
if err != nil {
if ent.IsConstraintError(err) {
return nil, biz.NewErrAlreadyExists(err)
return fmt.Errorf("failed to create backend: %w", err)
}

return nil, fmt.Errorf("failed to create backend: %w", err)
}

// 3 - commit the transaction
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
return nil
}); err != nil {
return nil, err
}

// Return the backend from the DB to have consistent marshalled object
return r.FindByID(ctx, backend.ID)
}

func (r *CASBackendRepo) Update(ctx context.Context, opts *biz.CASBackendUpdateOpts) (b *biz.CASBackend, err error) {
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("failed to create transaction: %w", err)
}

defer func() {
// Unblock the row if there was an error
if err != nil {
_ = tx.Rollback()
func (r *CASBackendRepo) Update(ctx context.Context, opts *biz.CASBackendUpdateOpts) (*biz.CASBackend, error) {
var (
backend *ent.CASBackend
err error
)
if err = WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// 1 - unset default backend for all the other backends in the org
if opts.Default {
if err := tx.CASBackend.Update().
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
Where(casbackend.Default(true)).
SetDefault(false).
Exec(ctx); err != nil {
return fmt.Errorf("failed to clear previous default backend: %w", err)
}
}
}()

// 1 - unset default backend for all the other backends in the org
if opts.Default {
if err := tx.CASBackend.Update().
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
Where(casbackend.Default(true)).
SetDefault(false).
Exec(ctx); err != nil {
return nil, fmt.Errorf("failed to clear previous default backend: %w", err)
}
}

// 2 - Chain the list of updates
// TODO: allow setting values as empty, currently it's not possible.
// We do it in other models by providing pointers to string + setNillableX methods
updateChain := tx.CASBackend.UpdateOneID(opts.ID).SetDefault(opts.Default)
if opts.Description != "" {
updateChain = updateChain.SetDescription(opts.Description)
}
// 2 - Chain the list of updates
// TODO: allow setting values as empty, currently it's not possible.
// We do it in other models by providing pointers to string + setNillableX methods
updateChain := tx.CASBackend.UpdateOneID(opts.ID).SetDefault(opts.Default)
if opts.Description != "" {
updateChain = updateChain.SetDescription(opts.Description)
}

// If secretName is provided we set it
if opts.SecretName != "" {
updateChain = updateChain.SetSecretName(opts.SecretName)
}
// If secretName is provided we set it
if opts.SecretName != "" {
updateChain = updateChain.SetSecretName(opts.SecretName)
}

backend, err := updateChain.Save(ctx)
if err != nil {
backend, err = updateChain.Save(ctx)
if err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}

// 3 - commit the transaction
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("failed to commit transaction: %w", err)
}

return r.FindByID(ctx, backend.ID)
}

Expand Down
24 changes: 24 additions & 0 deletions app/controlplane/pkg/data/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,27 @@ func toTimePtr(t time.Time) *time.Time {
func orgScopedQuery(client *ent.Client, orgID uuid.UUID) *ent.OrganizationQuery {
return client.Organization.Query().Where(organization.ID(orgID))
}

// WithTx initiates a transaction and wraps the DB function
func WithTx(ctx context.Context, client *ent.Client, fn func(tx *ent.Tx) error) error {
tx, err := client.Tx(ctx)
if err != nil {
return err
}
defer func() {
if v := recover(); v != nil {
_ = tx.Rollback()
panic(v)
}
}()
if err = fn(tx); err != nil {
if rerr := tx.Rollback(); rerr != nil {
err = fmt.Errorf("%w: rolling back transaction: %w", err, rerr)
}
return err
}
if err = tx.Commit(); err != nil {
return fmt.Errorf("committing transaction: %w", err)
}
return nil
}
31 changes: 10 additions & 21 deletions app/controlplane/pkg/data/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,29 +112,18 @@ func (r *IntegrationRepo) FindByNameInOrg(ctx context.Context, orgID uuid.UUID,
return entIntegrationToBiz(integration), nil
}

func (r *IntegrationRepo) SoftDelete(ctx context.Context, id uuid.UUID) (err error) {
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return err
}

defer func() {
// Unblock the row if there was an error
if err != nil {
_ = tx.Rollback()
func (r *IntegrationRepo) SoftDelete(ctx context.Context, id uuid.UUID) error {
return WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// soft-delete attachments associated with this workflow
if err := tx.IntegrationAttachment.Update().Where(integrationattachment.HasIntegrationWith(integration.ID(id))).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
return err
}
}()

// soft-delete attachments associated with this workflow
if err := tx.IntegrationAttachment.Update().Where(integrationattachment.HasIntegrationWith(integration.ID(id))).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
return err
}

if err := tx.Integration.UpdateOneID(id).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
return err
}

return tx.Commit()
if err := tx.Integration.UpdateOneID(id).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
return err
}
return nil
})
}

func entIntegrationToBiz(i *ent.Integration) *biz.Integration {
Expand Down
34 changes: 11 additions & 23 deletions app/controlplane/pkg/data/membership.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,31 +152,19 @@ func (r *MembershipRepo) SetCurrent(ctx context.Context, membershipID uuid.UUID)
return nil, err
}

// For the found user, we must, in a transaction.
tx, err := r.data.DB.Tx(ctx)
if err != nil {
return nil, err
}

defer func() {
// Unblock the row if there was an error
if err != nil {
_ = tx.Rollback()
if err = WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
// 1 - Set all the memberships to current=false
if err = tx.Membership.Update().Where(membership.HasUserWith(user.ID(m.Edges.User.ID))).
SetCurrent(false).Exec(ctx); err != nil {
return err
}
}()

// 1 - Set all the memberships to current=false
if err = tx.Membership.Update().Where(membership.HasUserWith(user.ID(m.Edges.User.ID))).
SetCurrent(false).Exec(ctx); err != nil {
return nil, err
}

// 2 - Set the referenced membership to current=true
if err = tx.Membership.UpdateOneID(membershipID).SetCurrent(true).Exec(ctx); err != nil {
return nil, err
}

if err := tx.Commit(); err != nil {
// 2 - Set the referenced membership to current=true
if err = tx.Membership.UpdateOneID(membershipID).SetCurrent(true).Exec(ctx); err != nil {
return err
}
return nil
}); err != nil {
return nil, err
}

Expand Down
Loading

0 comments on commit 0bde1a6

Please sign in to comment.