Skip to content

Commit

Permalink
lock role record before updating or deleting
Browse files Browse the repository at this point in the history
Signed-off-by: Mike Mason <[email protected]>
  • Loading branch information
mikemrm committed Jan 11, 2024
1 parent 199915c commit 39ec7f9
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 14 deletions.
47 changes: 34 additions & 13 deletions internal/query/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,22 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou
return types.Role{}, err
}

_, err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

span.RecordError(sErr)
span.SetStatus(codes.Error, sErr.Error())

logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return types.Role{}, err
}

role, err := e.GetRole(dbCtx, roleResource)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return types.Role{}, err
}

Expand Down Expand Up @@ -1003,14 +1017,30 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er

defer span.End()

var (
resActions map[types.Resource][]string
err error
)
dbCtx, err := e.store.BeginContext(ctx)
if err != nil {
return err
}

_, err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

span.RecordError(sErr)
span.SetStatus(codes.Error, sErr.Error())

logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return err
}

var resActions map[types.Resource][]string

for _, resType := range e.schemaRoleables {
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return err
}

Expand All @@ -1020,10 +1050,6 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
}
}

if len(resActions) == 0 {
return ErrRoleNotFound
}

roleType := e.namespace + "/role"

var filters []*pb.RelationshipFilter
Expand All @@ -1047,11 +1073,6 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
}
}

dbCtx, err := e.store.BeginContext(ctx)
if err != nil {
return err
}

_, err = e.store.DeleteRole(dbCtx, roleResource.ID)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))
Expand Down
3 changes: 2 additions & 1 deletion internal/query/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"go.infratographer.com/permissions-api/internal/iapl"
"go.infratographer.com/permissions-api/internal/spicedbx"
"go.infratographer.com/permissions-api/internal/storage"
"go.infratographer.com/permissions-api/internal/storage/teststore"
"go.infratographer.com/permissions-api/internal/testingx"
"go.infratographer.com/permissions-api/internal/types"
Expand Down Expand Up @@ -229,7 +230,7 @@ func TestRoleUpdate(t *testing.T) {
Input: gidx.MustNewID(RolePrefix),
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
require.Error(t, res.Err)
assert.ErrorIs(t, res.Err, ErrRoleNotFound)
assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound)
},
},
{
Expand Down
45 changes: 45 additions & 0 deletions internal/storage/roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type RoleService interface {
CreateRole(ctx context.Context, actorID gidx.PrefixedID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (Role, error)
UpdateRole(ctx context.Context, actorID, roleID gidx.PrefixedID, name string) (Role, error)
DeleteRole(ctx context.Context, roleID gidx.PrefixedID) (Role, error)
LockRoleForUpdate(ctx context.Context, roleID gidx.PrefixedID) (Role, error)
}

// Role represents a role in the database.
Expand Down Expand Up @@ -74,6 +75,50 @@ func (e *engine) GetRoleByID(ctx context.Context, id gidx.PrefixedID) (Role, err
return role, nil
}

// LockRoleForUpdate locks the provided role's record to be updated to ensure consistency.
// If no role exists an ErrNoRoleFound error is returned.
func (e *engine) LockRoleForUpdate(ctx context.Context, id gidx.PrefixedID) (Role, error) {
db, err := getContextDBQuery(ctx, e)
if err != nil {
return Role{}, err
}

var role Role

err = db.QueryRowContext(ctx, `
SELECT
id,
name,
resource_id,
created_by,
updated_by,
created_at,
updated_at
FROM roles
WHERE id = $1
FOR UPDATE
`, id.String(),
).Scan(
&role.ID,
&role.Name,
&role.ResourceID,
&role.CreatedBy,
&role.UpdatedBy,
&role.CreatedAt,
&role.UpdatedAt,
)

if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return Role{}, fmt.Errorf("%w: %s", ErrNoRoleFound, id.String())
}

return Role{}, fmt.Errorf("%w: %s", err, id.String())
}

return role, nil
}

// GetResourceRoleByName retrieves a role from the database by the provided resource ID and role name.
// If no role exists an ErrRoleNotFound error is returned.
func (e *engine) GetResourceRoleByName(ctx context.Context, resourceID gidx.PrefixedID, name string) (Role, error) {
Expand Down

0 comments on commit 39ec7f9

Please sign in to comment.