Skip to content

Commit

Permalink
feat: batch-insert relation tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Nov 18, 2024
1 parent b95a6cf commit 7719205
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 53 deletions.
20 changes: 0 additions & 20 deletions internal/persistence/sql/persister.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ package sql
import (
"context"
"embed"
"reflect"

"github.com/gobuffalo/pop/v6"
"github.com/gofrs/uuid"
"github.com/ory/x/otelx"
"github.com/ory/x/popx"
"github.com/pkg/errors"

Expand Down Expand Up @@ -70,24 +68,6 @@ func (p *Persister) Connection(ctx context.Context) *pop.Connection {
return popx.GetConnection(ctx, p.conn.WithContext(ctx))
}

func (p *Persister) createWithNetwork(ctx context.Context, v interface{}) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.createWithNetwork")
defer otelx.End(span, &err)

rv := reflect.ValueOf(v)

if rv.Kind() != reflect.Ptr && rv.Elem().Kind() != reflect.Struct {
panic("expected to get *struct in create")
}
nID := rv.Elem().FieldByName("NetworkID")
if !nID.IsValid() || !nID.CanSet() {
panic("expected struct to have a 'NetworkID uuid.UUID' field")
}
nID.Set(reflect.ValueOf(p.NetworkID(ctx)))

return p.Connection(ctx).Create(v)
}

func (p *Persister) queryWithNetwork(ctx context.Context) *pop.Query {
return p.Connection(ctx).Where("nid = ?", p.NetworkID(ctx))
}
Expand Down
91 changes: 58 additions & 33 deletions internal/persistence/sql/relationtuples.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func (r *RelationTuple) ToInternal() (*relationtuple.RelationTuple, error) {
return rt, nil
}

func (r *RelationTuple) insertSubject(_ context.Context, s relationtuple.Subject) error {
func (r *RelationTuple) insertSubject(s relationtuple.Subject) error {
switch st := s.(type) {
case *relationtuple.SubjectID:
r.SubjectID = uuid.NullUUID{
Expand All @@ -94,39 +94,12 @@ func (r *RelationTuple) insertSubject(_ context.Context, s relationtuple.Subject
return nil
}

func (r *RelationTuple) FromInternal(ctx context.Context, p *Persister, rt *relationtuple.RelationTuple) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FromInternal")
defer otelx.End(span, &err)

func (r *RelationTuple) FromInternal(rt *relationtuple.RelationTuple) (err error) {
r.Namespace = rt.Namespace
r.Object = rt.Object
r.Relation = rt.Relation

return r.insertSubject(ctx, rt.Subject)
}

func (p *Persister) InsertRelationTuple(ctx context.Context, rel *relationtuple.RelationTuple) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.InsertRelationTuple")
defer otelx.End(span, &err)

if rel.Subject == nil {
return errors.WithStack(ketoapi.ErrNilSubject)
}

rt := &RelationTuple{
ID: uuid.Must(uuid.NewV4()),
CommitTime: time.Now(),
}
if err := rt.FromInternal(ctx, p, rel); err != nil {
return err
}

if err := sqlcon.HandleError(
p.createWithNetwork(ctx, rt),
); err != nil {
return err
}
return nil
return r.insertSubject(rt.Subject)
}

func (p *Persister) whereSubject(_ context.Context, q *pop.Query, sub relationtuple.Subject) error {
Expand Down Expand Up @@ -292,15 +265,63 @@ func (p *Persister) ExistsRelationTuples(ctx context.Context, query *relationtup
return exists, sqlcon.HandleError(err)
}

func buildInsert(commitTime time.Time, nid uuid.UUID, rs []*relationtuple.RelationTuple) (query string, args []any, err error) {
if len(rs) == 0 {
return "", nil, errors.WithStack(ketoapi.ErrMalformedInput)
}

var q strings.Builder
fmt.Fprintf(&q, "INSERT INTO %s (shard_id, nid, namespace, object, relation, subject_id, subject_set_namespace, subject_set_object, subject_set_relation, commit_time) VALUES ", (&RelationTuple{}).TableName())
const placeholders = "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
const separator = ", "
q.Grow(len(rs) * (len(placeholders) + len(separator)))
args = make([]any, 0, 10*len(rs))

for i, r := range rs {
if r.Subject == nil {
return "", nil, errors.WithStack(ketoapi.ErrNilSubject)
}

rt := &RelationTuple{
ID: uuid.Must(uuid.NewV4()),
NetworkID: nid,
CommitTime: commitTime,
}
if err := rt.FromInternal(r); err != nil {
return "", nil, err
}

if i > 0 {
q.WriteString(separator)
}
q.WriteString(placeholders)
args = append(args, rt.ID, rt.NetworkID, rt.Namespace, rt.Object, rt.Relation, rt.SubjectID, rt.SubjectSetNamespace, rt.SubjectSetObject, rt.SubjectSetRelation, rt.CommitTime)
}

query = q.String()
return query, args, nil
}

func (p *Persister) WriteRelationTuples(ctx context.Context, rs ...*relationtuple.RelationTuple) (err error) {
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.WriteRelationTuples")
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.WriteRelationTuples",
trace.WithAttributes(attribute.Int("count", len(rs))))
defer otelx.End(span, &err)

if len(rs) == 0 {
return nil
}

commitTime := time.Now()

return p.Transaction(ctx, func(ctx context.Context) error {
for _, r := range rs {
if err := p.InsertRelationTuple(ctx, r); err != nil {
for chunk := range slices.Chunk(rs, 500) {
q, args, err := buildInsert(commitTime, p.NetworkID(ctx), chunk)
if err != nil {
return err
}
if err := p.Connection(ctx).RawQuery(q, args...).Exec(); err != nil {
return sqlcon.HandleError(err)
}
}
return nil
})
Expand All @@ -310,6 +331,10 @@ func (p *Persister) TransactRelationTuples(ctx context.Context, ins []*relationt
ctx, span := p.d.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.TransactRelationTuples")
defer otelx.End(span, &err)

if len(ins)+len(del) == 0 {
return nil
}

return p.Transaction(ctx, func(ctx context.Context) error {
if err := p.WriteRelationTuples(ctx, ins...); err != nil {
return err
Expand Down

0 comments on commit 7719205

Please sign in to comment.