Skip to content

Commit 0bde1a6

Browse files
authored
chore: implement best practices for transaction through helper method (chainloop-dev#1605)
Signed-off-by: Jose I. Paris <[email protected]>
1 parent 8a07f99 commit 0bde1a6

File tree

8 files changed

+284
-348
lines changed

8 files changed

+284
-348
lines changed

app/controlplane/pkg/data/attestationstate.go

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -53,50 +53,39 @@ func (r *AttestationStateRepo) Initialized(ctx context.Context, runID uuid.UUID)
5353

5454
// baseDigest, when provided will be used to check that it matches the digest of the state currently in the DB
5555
// if the digests do not match, the state has been modified and the caller should retry
56-
func (r *AttestationStateRepo) Save(ctx context.Context, runID uuid.UUID, state []byte, baseDigest string) (err error) {
57-
tx, err := r.data.DB.Tx(ctx)
58-
if err != nil {
59-
return fmt.Errorf("failed to create transaction: %w", err)
60-
}
61-
62-
defer func() {
63-
// Unblock the row if there was an error
64-
if err != nil {
65-
_ = tx.Rollback()
56+
func (r *AttestationStateRepo) Save(ctx context.Context, runID uuid.UUID, state []byte, baseDigest string) error {
57+
return WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
58+
// compared the provided digest with the digest of the state in the DB
59+
// TODO: make digest check mandatory on updates
60+
if baseDigest != "" {
61+
// Get the run but BLOCK IT for update
62+
run, err := tx.WorkflowRun.Query().ForUpdate().Where(workflowrun.ID(runID)).Only(ctx)
63+
if err != nil && !ent.IsNotFound(err) {
64+
return fmt.Errorf("failed to read attestation state: %w", err)
65+
} else if run == nil || run.AttestationState == nil {
66+
return biz.NewErrNotFound("attestation state")
67+
}
68+
69+
// calculate the digest of the current state
70+
storedDigest, err := digest(run.AttestationState)
71+
if err != nil {
72+
return fmt.Errorf("failed to calculate digest: %w", err)
73+
}
74+
75+
if baseDigest != storedDigest {
76+
return biz.NewErrAttestationStateConflict(storedDigest, baseDigest)
77+
}
6678
}
67-
}()
6879

69-
// compared the provided digest with the digest of the state in the DB
70-
// TODO: make digest check mandatory on updates
71-
if baseDigest != "" {
72-
// Get the run but BLOCK IT for update
73-
run, err := tx.WorkflowRun.Query().ForUpdate().Where(workflowrun.ID(runID)).Only(ctx)
80+
// Update it in the DB if the digest matches
81+
err := tx.WorkflowRun.UpdateOneID(runID).SetAttestationState(state).Exec(ctx)
7482
if err != nil && !ent.IsNotFound(err) {
75-
return fmt.Errorf("failed to read attestation state: %w", err)
76-
} else if run == nil || run.AttestationState == nil {
77-
return biz.NewErrNotFound("attestation state")
78-
}
79-
80-
// calculate the digest of the current state
81-
storedDigest, err := digest(run.AttestationState)
82-
if err != nil {
83-
return fmt.Errorf("failed to calculate digest: %w", err)
84-
}
85-
86-
if baseDigest != storedDigest {
87-
return biz.NewErrAttestationStateConflict(storedDigest, baseDigest)
83+
return fmt.Errorf("failed to store attestation state: %w", err)
84+
} else if err != nil {
85+
return biz.NewErrNotFound("workflow run")
8886
}
89-
}
90-
91-
// Update it in the DB if the digest matches
92-
err = tx.WorkflowRun.UpdateOneID(runID).SetAttestationState(state).Exec(ctx)
93-
if err != nil && !ent.IsNotFound(err) {
94-
return fmt.Errorf("failed to store attestation state: %w", err)
95-
} else if err != nil {
96-
return biz.NewErrNotFound("workflow run")
97-
}
98-
99-
return tx.Commit()
87+
return nil
88+
})
10089
}
10190

10291
func (r *AttestationStateRepo) Read(ctx context.Context, runID uuid.UUID) ([]byte, string, error) {

app/controlplane/pkg/data/casbackend.go

Lines changed: 67 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -83,106 +83,90 @@ func (r *CASBackendRepo) FindFallbackBackend(ctx context.Context, orgID uuid.UUI
8383

8484
// Create creates a new CAS backend in the given organization
8585
// If it's set as default, it will unset the previous default backend
86-
func (r *CASBackendRepo) Create(ctx context.Context, opts *biz.CASBackendCreateOpts) (b *biz.CASBackend, err error) {
87-
tx, err := r.data.DB.Tx(ctx)
88-
if err != nil {
89-
return nil, fmt.Errorf("failed to create transaction: %w", err)
90-
}
86+
func (r *CASBackendRepo) Create(ctx context.Context, opts *biz.CASBackendCreateOpts) (*biz.CASBackend, error) {
87+
var (
88+
backend *ent.CASBackend
89+
err error
90+
)
91+
if err := WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
92+
// 1 - unset default backend for all the other backends in the org
93+
if opts.Default {
94+
if err := tx.CASBackend.Update().
95+
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
96+
Where(casbackend.Default(true)).
97+
SetDefault(false).
98+
Exec(ctx); err != nil {
99+
return fmt.Errorf("failed to clear previous default backend: %w", err)
100+
}
101+
}
91102

92-
defer func() {
93-
// Unblock the row if there was an error
103+
// 2 - create the new backend and set it as default if needed
104+
backend, err = tx.CASBackend.Create().
105+
SetName(opts.Name).
106+
SetOrganizationID(opts.OrgID).
107+
SetLocation(opts.Location).
108+
SetDescription(opts.Description).
109+
SetFallback(opts.Fallback).
110+
SetProvider(opts.Provider).
111+
SetDefault(opts.Default).
112+
SetSecretName(opts.SecretName).
113+
SetMaxBlobSizeBytes(opts.MaxBytes).
114+
Save(ctx)
94115
if err != nil {
95-
_ = tx.Rollback()
96-
}
97-
}()
98-
99-
// 1 - unset default backend for all the other backends in the org
100-
if opts.Default {
101-
if err := tx.CASBackend.Update().
102-
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
103-
Where(casbackend.Default(true)).
104-
SetDefault(false).
105-
Exec(ctx); err != nil {
106-
return nil, fmt.Errorf("failed to clear previous default backend: %w", err)
107-
}
108-
}
116+
if ent.IsConstraintError(err) {
117+
return biz.NewErrAlreadyExists(err)
118+
}
109119

110-
// 2 - create the new backend and set it as default if needed
111-
backend, err := tx.CASBackend.Create().
112-
SetName(opts.Name).
113-
SetOrganizationID(opts.OrgID).
114-
SetLocation(opts.Location).
115-
SetDescription(opts.Description).
116-
SetFallback(opts.Fallback).
117-
SetProvider(opts.Provider).
118-
SetDefault(opts.Default).
119-
SetSecretName(opts.SecretName).
120-
SetMaxBlobSizeBytes(opts.MaxBytes).
121-
Save(ctx)
122-
if err != nil {
123-
if ent.IsConstraintError(err) {
124-
return nil, biz.NewErrAlreadyExists(err)
120+
return fmt.Errorf("failed to create backend: %w", err)
125121
}
126-
127-
return nil, fmt.Errorf("failed to create backend: %w", err)
128-
}
129-
130-
// 3 - commit the transaction
131-
if err := tx.Commit(); err != nil {
132-
return nil, fmt.Errorf("failed to commit transaction: %w", err)
122+
return nil
123+
}); err != nil {
124+
return nil, err
133125
}
134126

135127
// Return the backend from the DB to have consistent marshalled object
136128
return r.FindByID(ctx, backend.ID)
137129
}
138130

139-
func (r *CASBackendRepo) Update(ctx context.Context, opts *biz.CASBackendUpdateOpts) (b *biz.CASBackend, err error) {
140-
tx, err := r.data.DB.Tx(ctx)
141-
if err != nil {
142-
return nil, fmt.Errorf("failed to create transaction: %w", err)
143-
}
144-
145-
defer func() {
146-
// Unblock the row if there was an error
147-
if err != nil {
148-
_ = tx.Rollback()
131+
func (r *CASBackendRepo) Update(ctx context.Context, opts *biz.CASBackendUpdateOpts) (*biz.CASBackend, error) {
132+
var (
133+
backend *ent.CASBackend
134+
err error
135+
)
136+
if err = WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
137+
// 1 - unset default backend for all the other backends in the org
138+
if opts.Default {
139+
if err := tx.CASBackend.Update().
140+
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
141+
Where(casbackend.Default(true)).
142+
SetDefault(false).
143+
Exec(ctx); err != nil {
144+
return fmt.Errorf("failed to clear previous default backend: %w", err)
145+
}
149146
}
150-
}()
151-
152-
// 1 - unset default backend for all the other backends in the org
153-
if opts.Default {
154-
if err := tx.CASBackend.Update().
155-
Where(casbackend.HasOrganizationWith(organization.ID(opts.OrgID))).
156-
Where(casbackend.Default(true)).
157-
SetDefault(false).
158-
Exec(ctx); err != nil {
159-
return nil, fmt.Errorf("failed to clear previous default backend: %w", err)
160-
}
161-
}
162147

163-
// 2 - Chain the list of updates
164-
// TODO: allow setting values as empty, currently it's not possible.
165-
// We do it in other models by providing pointers to string + setNillableX methods
166-
updateChain := tx.CASBackend.UpdateOneID(opts.ID).SetDefault(opts.Default)
167-
if opts.Description != "" {
168-
updateChain = updateChain.SetDescription(opts.Description)
169-
}
148+
// 2 - Chain the list of updates
149+
// TODO: allow setting values as empty, currently it's not possible.
150+
// We do it in other models by providing pointers to string + setNillableX methods
151+
updateChain := tx.CASBackend.UpdateOneID(opts.ID).SetDefault(opts.Default)
152+
if opts.Description != "" {
153+
updateChain = updateChain.SetDescription(opts.Description)
154+
}
170155

171-
// If secretName is provided we set it
172-
if opts.SecretName != "" {
173-
updateChain = updateChain.SetSecretName(opts.SecretName)
174-
}
156+
// If secretName is provided we set it
157+
if opts.SecretName != "" {
158+
updateChain = updateChain.SetSecretName(opts.SecretName)
159+
}
175160

176-
backend, err := updateChain.Save(ctx)
177-
if err != nil {
161+
backend, err = updateChain.Save(ctx)
162+
if err != nil {
163+
return err
164+
}
165+
return nil
166+
}); err != nil {
178167
return nil, err
179168
}
180169

181-
// 3 - commit the transaction
182-
if err := tx.Commit(); err != nil {
183-
return nil, fmt.Errorf("failed to commit transaction: %w", err)
184-
}
185-
186170
return r.FindByID(ctx, backend.ID)
187171
}
188172

app/controlplane/pkg/data/data.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,27 @@ func toTimePtr(t time.Time) *time.Time {
151151
func orgScopedQuery(client *ent.Client, orgID uuid.UUID) *ent.OrganizationQuery {
152152
return client.Organization.Query().Where(organization.ID(orgID))
153153
}
154+
155+
// WithTx initiates a transaction and wraps the DB function
156+
func WithTx(ctx context.Context, client *ent.Client, fn func(tx *ent.Tx) error) error {
157+
tx, err := client.Tx(ctx)
158+
if err != nil {
159+
return err
160+
}
161+
defer func() {
162+
if v := recover(); v != nil {
163+
_ = tx.Rollback()
164+
panic(v)
165+
}
166+
}()
167+
if err = fn(tx); err != nil {
168+
if rerr := tx.Rollback(); rerr != nil {
169+
err = fmt.Errorf("%w: rolling back transaction: %w", err, rerr)
170+
}
171+
return err
172+
}
173+
if err = tx.Commit(); err != nil {
174+
return fmt.Errorf("committing transaction: %w", err)
175+
}
176+
return nil
177+
}

app/controlplane/pkg/data/integration.go

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -112,29 +112,18 @@ func (r *IntegrationRepo) FindByNameInOrg(ctx context.Context, orgID uuid.UUID,
112112
return entIntegrationToBiz(integration), nil
113113
}
114114

115-
func (r *IntegrationRepo) SoftDelete(ctx context.Context, id uuid.UUID) (err error) {
116-
tx, err := r.data.DB.Tx(ctx)
117-
if err != nil {
118-
return err
119-
}
120-
121-
defer func() {
122-
// Unblock the row if there was an error
123-
if err != nil {
124-
_ = tx.Rollback()
115+
func (r *IntegrationRepo) SoftDelete(ctx context.Context, id uuid.UUID) error {
116+
return WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
117+
// soft-delete attachments associated with this workflow
118+
if err := tx.IntegrationAttachment.Update().Where(integrationattachment.HasIntegrationWith(integration.ID(id))).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
119+
return err
125120
}
126-
}()
127-
128-
// soft-delete attachments associated with this workflow
129-
if err := tx.IntegrationAttachment.Update().Where(integrationattachment.HasIntegrationWith(integration.ID(id))).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
130-
return err
131-
}
132121

133-
if err := tx.Integration.UpdateOneID(id).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
134-
return err
135-
}
136-
137-
return tx.Commit()
122+
if err := tx.Integration.UpdateOneID(id).SetDeletedAt(time.Now()).Exec(ctx); err != nil {
123+
return err
124+
}
125+
return nil
126+
})
138127
}
139128

140129
func entIntegrationToBiz(i *ent.Integration) *biz.Integration {

app/controlplane/pkg/data/membership.go

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -152,31 +152,19 @@ func (r *MembershipRepo) SetCurrent(ctx context.Context, membershipID uuid.UUID)
152152
return nil, err
153153
}
154154

155-
// For the found user, we must, in a transaction.
156-
tx, err := r.data.DB.Tx(ctx)
157-
if err != nil {
158-
return nil, err
159-
}
160-
161-
defer func() {
162-
// Unblock the row if there was an error
163-
if err != nil {
164-
_ = tx.Rollback()
155+
if err = WithTx(ctx, r.data.DB, func(tx *ent.Tx) error {
156+
// 1 - Set all the memberships to current=false
157+
if err = tx.Membership.Update().Where(membership.HasUserWith(user.ID(m.Edges.User.ID))).
158+
SetCurrent(false).Exec(ctx); err != nil {
159+
return err
165160
}
166-
}()
167-
168-
// 1 - Set all the memberships to current=false
169-
if err = tx.Membership.Update().Where(membership.HasUserWith(user.ID(m.Edges.User.ID))).
170-
SetCurrent(false).Exec(ctx); err != nil {
171-
return nil, err
172-
}
173-
174-
// 2 - Set the referenced membership to current=true
175-
if err = tx.Membership.UpdateOneID(membershipID).SetCurrent(true).Exec(ctx); err != nil {
176-
return nil, err
177-
}
178161

179-
if err := tx.Commit(); err != nil {
162+
// 2 - Set the referenced membership to current=true
163+
if err = tx.Membership.UpdateOneID(membershipID).SetCurrent(true).Exec(ctx); err != nil {
164+
return err
165+
}
166+
return nil
167+
}); err != nil {
180168
return nil, err
181169
}
182170

0 commit comments

Comments
 (0)