Skip to content

Commit

Permalink
fix: upsert race condition
Browse files Browse the repository at this point in the history
  • Loading branch information
Muhammad Luthfi Fahlevi committed Feb 8, 2025
1 parent 357f88f commit 9a0deeb
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 51 deletions.
123 changes: 80 additions & 43 deletions internal/store/postgres/asset_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,11 @@ func (r *AssetRepository) GetByID(ctx context.Context, id string) (asset.Asset,
}

func (r *AssetRepository) GetByURN(ctx context.Context, urn string) (asset.Asset, error) {
ast, err := r.getWithPredicate(ctx, sq.Eq{"a.urn": urn})
return r.GetByURNWithTx(ctx, nil, urn)
}

func (r *AssetRepository) GetByURNWithTx(ctx context.Context, tx *sqlx.Tx, urn string) (asset.Asset, error) {
ast, err := r.getWithPredicateWithTx(ctx, tx, sq.Eq{"a.urn": urn})
if errors.Is(err, sql.ErrNoRows) {
return asset.Asset{}, asset.NotFoundError{URN: urn}
}
Expand All @@ -176,6 +180,10 @@ func (r *AssetRepository) GetByURN(ctx context.Context, urn string) (asset.Asset
}

func (r *AssetRepository) getWithPredicate(ctx context.Context, pred sq.Eq) (asset.Asset, error) {
return r.getWithPredicateWithTx(ctx, nil, pred)
}

func (r *AssetRepository) getWithPredicateWithTx(ctx context.Context, tx *sqlx.Tx, pred sq.Eq) (asset.Asset, error) {
query, args, err := r.getAssetSQL().
Where(pred).
Limit(1).
Expand All @@ -186,12 +194,16 @@ func (r *AssetRepository) getWithPredicate(ctx context.Context, pred sq.Eq) (ass
}

var am AssetModel
err = r.client.db.GetContext(ctx, &am, query, args...)
if tx == nil {
err = r.client.db.GetContext(ctx, &am, query, args...)
} else {
err = tx.GetContext(ctx, &am, query, args...)
}
if err != nil {
return asset.Asset{}, err
}

owners, err := r.getOwners(ctx, am.ID)
owners, err := r.getOwnersWithTx(ctx, tx, am.ID)
if err != nil {
return asset.Asset{}, err
}
Expand Down Expand Up @@ -315,7 +327,7 @@ func (r *AssetRepository) getByVersion(
func (r *AssetRepository) Upsert(ctx context.Context, ast *asset.Asset) (string, error) {
var id string
err := r.client.RunWithinTx(ctx, func(tx *sqlx.Tx) error {
fetchedAsset, err := r.GetByURN(ctx, ast.URN)
fetchedAsset, err := r.GetByURNWithTx(ctx, tx, ast.URN)
if errors.As(err, new(asset.NotFoundError)) {
err = nil
}
Expand All @@ -338,7 +350,7 @@ func (r *AssetRepository) Upsert(ctx context.Context, ast *asset.Asset) (string,
return fmt.Errorf("error diffing two assets: %w", err)
}

err = r.update(ctx, fetchedAsset.ID, ast, &fetchedAsset, changelog)
err = r.update(ctx, tx, fetchedAsset.ID, ast, &fetchedAsset, changelog)
if err != nil {
return fmt.Errorf("error updating asset to DB: %w", err)
}
Expand Down Expand Up @@ -598,12 +610,14 @@ func (r *AssetRepository) insert(ctx context.Context, ast *asset.Asset) (string,
return id, nil
}

func (r *AssetRepository) update(ctx context.Context, assetID string, newAsset, oldAsset *asset.Asset, clog diff.Changelog) error {
func (r *AssetRepository) update(ctx context.Context, tx *sqlx.Tx, assetID string, newAsset, oldAsset *asset.Asset, clog diff.Changelog) error {

Check failure on line 613 in internal/store/postgres/asset_repository.go

View workflow job for this annotation

GitHub Actions / golangci

argument-limit: maximum number of arguments per function exceeded; max 5 but got 6 (revive)
if !isValidUUID(assetID) {
return asset.InvalidError{AssetID: assetID}
}

currentTime := time.Now()
// for Upsert API calls, to make currentTime value same for both Postgresql and Elasticsearch,
// the currentTime already filled in UpsertAssetWithoutLineage
if newAsset.RefreshedAt != nil {
currentTime = *newAsset.RefreshedAt
}
Expand All @@ -613,47 +627,42 @@ func (r *AssetRepository) update(ctx context.Context, assetID string, newAsset,
return nil
}

return r.client.RunWithinTx(ctx, func(tx *sqlx.Tx) error {
newAsset.RefreshedAt = &currentTime
return r.updateAsset(ctx, tx, assetID, newAsset)
})
return r.updateAssetRefreshedAt(ctx, tx, assetID, currentTime)
}

return r.client.RunWithinTx(ctx, func(tx *sqlx.Tx) error {
// update assets
newVersion, err := asset.IncreaseMinorVersion(oldAsset.Version)
if err != nil {
return err
}
newAsset.Version = newVersion
newAsset.ID = oldAsset.ID
newAsset.UpdatedAt = currentTime
newAsset.RefreshedAt = &currentTime
// update assets
newVersion, err := asset.IncreaseMinorVersion(oldAsset.Version)
if err != nil {
return err
}
newAsset.Version = newVersion
newAsset.ID = oldAsset.ID
newAsset.UpdatedAt = currentTime
newAsset.RefreshedAt = &currentTime

if err := r.updateAsset(ctx, tx, assetID, newAsset); err != nil {
return err
}
if err := r.updateAsset(ctx, tx, assetID, newAsset); err != nil {
return err
}

// insert versions
if err := r.insertAssetVersion(ctx, tx, newAsset, clog); err != nil {
return err
}
// insert versions
if err := r.insertAssetVersion(ctx, tx, newAsset, clog); err != nil {
return err
}

// managing owners
newAssetOwners, err := r.createOrFetchUsers(ctx, tx, newAsset.Owners)
if err != nil {
return fmt.Errorf("error creating and fetching owners: %w", err)
}
toInserts, toRemoves := r.compareOwners(oldAsset.Owners, newAssetOwners)
if err := r.insertOwners(ctx, tx, assetID, toInserts); err != nil {
return fmt.Errorf("error inserting asset's new owners: %w", err)
}
if err := r.removeOwners(ctx, tx, assetID, toRemoves); err != nil {
return fmt.Errorf("error removing asset's old owners: %w", err)
}
// managing owners
newAssetOwners, err := r.createOrFetchUsers(ctx, tx, newAsset.Owners)
if err != nil {
return fmt.Errorf("error creating and fetching owners: %w", err)
}
toInserts, toRemoves := r.compareOwners(oldAsset.Owners, newAssetOwners)
if err := r.insertOwners(ctx, tx, assetID, toInserts); err != nil {
return fmt.Errorf("error inserting asset's new owners: %w", err)
}
if err := r.removeOwners(ctx, tx, assetID, toRemoves); err != nil {
return fmt.Errorf("error removing asset's old owners: %w", err)
}

return nil
})
return nil
}

func (r *AssetRepository) updateAsset(ctx context.Context, tx *sqlx.Tx, assetID string, newAsset *asset.Asset) error {
Expand Down Expand Up @@ -684,6 +693,23 @@ func (r *AssetRepository) updateAsset(ctx context.Context, tx *sqlx.Tx, assetID
return nil
}

func (r *AssetRepository) updateAssetRefreshedAt(ctx context.Context, tx *sqlx.Tx, assetID string, currentTime time.Time) error {
query, args, err := sq.Update("assets").
Set("refreshed_at", currentTime).
Where(sq.Eq{"id": assetID}).
PlaceholderFormat(sq.Dollar).
ToSql()
if err != nil {
return fmt.Errorf("build query: %w", err)
}

if err := r.execContext(ctx, tx, query, args...); err != nil {
return fmt.Errorf("error running update asset query: %w", err)
}

return nil
}

func (r *AssetRepository) insertAssetVersion(ctx context.Context, execer sqlx.ExecerContext, oldAsset *asset.Asset, clog diff.Changelog) error {
if oldAsset == nil {
return asset.ErrNilAsset
Expand Down Expand Up @@ -715,6 +741,10 @@ func (r *AssetRepository) insertAssetVersion(ctx context.Context, execer sqlx.Ex
}

func (r *AssetRepository) getOwners(ctx context.Context, assetID string) ([]user.User, error) {
return r.getOwnersWithTx(ctx, nil, assetID)
}

func (r *AssetRepository) getOwnersWithTx(ctx context.Context, tx *sqlx.Tx, assetID string) ([]user.User, error) {
if !isValidUUID(assetID) {
return nil, asset.InvalidError{AssetID: assetID}
}
Expand All @@ -730,7 +760,13 @@ func (r *AssetRepository) getOwners(ctx context.Context, assetID string) ([]user
JOIN users u on ao.user_id = u.id
WHERE asset_id = $1`

if err := r.client.db.SelectContext(ctx, &userModels, query, assetID); err != nil {
var err error
if tx == nil {
err = r.client.db.SelectContext(ctx, &userModels, query, assetID)
} else {
err = tx.SelectContext(ctx, &userModels, query, assetID)
}
if err != nil {
return nil, fmt.Errorf("get asset owners: %w", err)
}

Expand Down Expand Up @@ -913,7 +949,8 @@ func (r *AssetRepository) getAssetSQL() sq.SelectBuilder {
u.updated_at as "updated_by.updated_at"
`).
From("assets a").
LeftJoin("users u ON a.updated_by = u.id")
LeftJoin("users u ON a.updated_by = u.id").
Suffix("FOR UPDATE OF a")
}

func (r *AssetRepository) getAssetVersionSQL() sq.SelectBuilder {
Expand Down
5 changes: 2 additions & 3 deletions internal/store/postgres/asset_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1298,7 +1298,7 @@ func (r *AssetRepositoryTestSuite) TestUpsertRaceCondition() {

localAst := ast
localAst.URL = fmt.Sprintf("https://sample-url-%d.com", index)
_, err := r.repository.Upsert(r.ctx, &localAst)
_, err := r.repository.Upsert(context.Background(), &localAst)

mu.Lock()
results = append(results, err)
Expand All @@ -1309,8 +1309,7 @@ func (r *AssetRepositoryTestSuite) TestUpsertRaceCondition() {
wg.Wait()

// Check for errors
for i, err := range results {
fmt.Println("err", i, ": ", err)
for _, err := range results {
assert.NoError(r.T(), err, "Upsert should not fail under race conditions")
}
})
Expand Down
6 changes: 1 addition & 5 deletions internal/store/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@ type Client struct {
}

func (c *Client) RunWithinTx(ctx context.Context, f func(tx *sqlx.Tx) error) error {
return c.RunWithinTxWithOption(ctx, nil, f)
}

func (c *Client) RunWithinTxWithOption(ctx context.Context, opts *sql.TxOptions, f func(tx *sqlx.Tx) error) error {
tx, err := c.db.BeginTxx(ctx, opts)
tx, err := c.db.BeginTxx(ctx, nil)
if err != nil {
return fmt.Errorf("starting transaction: %w", err)
}
Expand Down

0 comments on commit 9a0deeb

Please sign in to comment.