From 71cf48da5250e49ef092fc0490604cf29400f8fd Mon Sep 17 00:00:00 2001 From: Matthew Anderson <42154938+matoszz@users.noreply.github.com> Date: Sun, 1 Dec 2024 08:13:37 -0700 Subject: [PATCH] initial stab at establishing pgxpool --- go.mod | 5 +- go.sum | 2 + internal/entdb/client.go | 82 ---------- internal/entdb/clientpool.go | 47 ++++++ internal/entdb/ent.go | 105 +++++++++++++ internal/entdb/errors.go | 28 ++++ internal/entdb/pgxpool.go | 296 +++++++++++++++++++++++++++++++++++ internal/entdb/testclient.go | 91 +++++++++++ internal/entdb/tx.go | 100 ++++++++++++ 9 files changed, 672 insertions(+), 84 deletions(-) create mode 100644 internal/entdb/clientpool.go create mode 100644 internal/entdb/ent.go create mode 100644 internal/entdb/errors.go create mode 100644 internal/entdb/pgxpool.go create mode 100644 internal/entdb/testclient.go create mode 100644 internal/entdb/tx.go diff --git a/go.mod b/go.mod index b916b51a..be6e5853 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/brianvoe/gofakeit/v7 v7.1.2 github.com/danielgtaylor/huma/v2 v2.26.0 github.com/dustinkirkland/golang-petname v0.0.0-20240428194347-eebcea082ee0 + github.com/exaring/otelpgx v0.7.0 github.com/gertd/go-pluralize v0.2.1 github.com/getkin/kin-openapi v0.128.0 github.com/go-viper/mapstructure/v2 v2.2.1 @@ -290,14 +291,14 @@ require ( go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/contrib v1.29.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect - go.opentelemetry.io/otel v1.32.0 // indirect + go.opentelemetry.io/otel v1.32.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.29.0 // indirect go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.29.0 // indirect go.opentelemetry.io/otel/metric v1.32.0 // indirect go.opentelemetry.io/otel/sdk v1.31.0 // indirect - go.opentelemetry.io/otel/trace v1.32.0 // indirect + go.opentelemetry.io/otel/trace v1.32.0 go.opentelemetry.io/proto/otlp v1.3.1 // indirect go.uber.org/goleak v1.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect diff --git a/go.sum b/go.sum index c81d3e9d..d550a5c0 100644 --- a/go.sum +++ b/go.sum @@ -193,6 +193,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM= github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4= +github.com/exaring/otelpgx v0.7.0 h1:Wv1x53y6zmmBsEPbWNae6XJAbMNC3KSJmpWRoZxtZr8= +github.com/exaring/otelpgx v0.7.0/go.mod h1:2oRpYkkPBXpvRqQqP0gqkkFPwITRObbpsrA8NT1Fu/I= github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= diff --git a/internal/entdb/client.go b/internal/entdb/client.go index 6548221c..5b9a5e5f 100644 --- a/internal/entdb/client.go +++ b/internal/entdb/client.go @@ -3,10 +3,6 @@ package entdb import ( "context" "database/sql" - "fmt" - "os" - "strconv" - "time" "ariga.io/entcache" entsql "entgo.io/ent/dialect/sql" @@ -15,8 +11,6 @@ import ( "github.com/theopenlane/entx" "github.com/theopenlane/riverboat/pkg/riverqueue" - "github.com/theopenlane/utils/testutils" - migratedb "github.com/theopenlane/core/db" ent "github.com/theopenlane/core/internal/ent/generated" "github.com/theopenlane/core/internal/ent/hooks" @@ -25,11 +19,6 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" // add pgx driver ) -const ( - // defaultDBTestImage is the default docker image to use for testing - defaultDBTestImage = "docker://postgres:17-alpine" -) - type client struct { // config is the entdb configuration config *entx.Config @@ -204,74 +193,3 @@ func (c *client) createEntDBClient(db *entsql.Driver) *ent.Client { return ent.NewClient(cOpts...) } - -// NewTestFixture creates a test container for testing purposes -func NewTestFixture() *testutils.TestFixture { - // Grab the DB environment variable or use the default - testDBURI := os.Getenv("TEST_DB_URL") - testDBContainerExpiry := os.Getenv("TEST_DB_CONTAINER_EXPIRY") - - // If the DB URI is not set, use the default docker image - if testDBURI == "" { - testDBURI = defaultDBTestImage - } - - if testDBContainerExpiry == "" { - testDBContainerExpiry = "5" // default expiry of 5 minutes - } - - expiry, err := strconv.Atoi(testDBContainerExpiry) - if err != nil { - panic(fmt.Sprintf("failed to convert TEST_DB_CONTAINER_EXPIRY to int: %v", err)) - } - - return testutils.GetTestURI(testutils.WithImage(testDBURI), - testutils.WithExpiryMinutes(expiry), - testutils.WithMaxConn(200)) // nolint:mnd -} - -// NewTestClient creates a entdb client that can be used for TEST purposes ONLY -func NewTestClient(ctx context.Context, ctr *testutils.TestFixture, jobOpts []riverqueue.Option, entOpts []ent.Option) (*ent.Client, error) { - dbconf := entx.Config{ - Debug: true, - DriverName: ctr.Dialect, - PrimaryDBSource: ctr.URI, - EnableHistory: true, // enable history so the code path is checked during unit tests - CacheTTL: 0 * time.Second, // do not cache results in tests - } - - // Create the db client - var db *ent.Client - - // Retry the connection to the database to ensure it is up and running - var err error - - // run migrations for tests - jobOpts = append(jobOpts, riverqueue.WithRunMigrations(true)) - - // If a test container is used, retry the connection to the database to ensure it is up and running - if ctr.Pool != nil { - err = ctr.Pool.Retry(func() error { - log.Info().Msg("connecting to database...") - - db, err = New(ctx, dbconf, jobOpts, entOpts...) - if err != nil { - log.Info().Err(err).Msg("retrying connection to database...") - } - - return err - }) - } else { - db, err = New(ctx, dbconf, jobOpts, entOpts...) - } - - if err != nil { - return nil, err - } - - if err := db.Schema.Create(ctx); err != nil { - return nil, err - } - - return db, nil -} diff --git a/internal/entdb/clientpool.go b/internal/entdb/clientpool.go new file mode 100644 index 00000000..2d0e292a --- /dev/null +++ b/internal/entdb/clientpool.go @@ -0,0 +1,47 @@ +package entdb + +import ( + "context" + + "ariga.io/entcache" + "github.com/exaring/otelpgx" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/theopenlane/entx" + + ent "github.com/theopenlane/core/internal/ent/generated" +) + +// NewDBPool creates a new database pool with the given configuration +func NewDBPool(ctx context.Context, cfg *entx.Config) (context.Context, error) { + poolConfig, err := pgxpool.ParseConfig(cfg.PrimaryDBSource) + if err != nil { + return nil, ErrFailedToParseConnectionString + } + + // TODO build config from entx.Config + + poolConfig.ConnConfig.Tracer = otelpgx.NewTracer() + poolConfig.MaxConns = 20 // nolint:gomnd + + pool, err := pgxpool.NewWithConfig(ctx, poolConfig) + if err != nil { + return nil, ErrFfailedToConnectToDatabase + } + + poolDriver := NewPgxPoolDriver(pool) + + cacheDriver := entcache.NewDriver( + poolDriver, + entcache.ContextLevel(), + ) + + realClient := ent.NewClient( + ent.Driver(cacheDriver), + ) + + if debugEnabled { + realClient = realClient.Debug() + } + + return context.WithValue(ctx, dbKey{}, &dbClient{Client: realClient}), nil +} diff --git a/internal/entdb/ent.go b/internal/entdb/ent.go new file mode 100644 index 00000000..a08e459c --- /dev/null +++ b/internal/entdb/ent.go @@ -0,0 +1,105 @@ +package entdb + +import ( + "context" + "errors" + "fmt" + + ent "github.com/theopenlane/core/internal/ent/generated" + + // Required PGX driver + _ "github.com/jackc/pgx/v5/stdlib" +) + +type ( + dbKey struct{} + txKey struct{} +) + +type dbClient struct { + Client *ent.Client +} + +var debugEnabled = false + +// From retrieves a database instance from the context +func From(ctx context.Context) *ent.Client { + tx := ctx.Value(txKey{}) + if tx != nil { + return tx.(*ent.Tx).Client() + } + + db := ctx.Value(dbKey{}) + if db == nil { + return nil + } + + return db.(*dbClient).Client +} + +// TransferContext transfers a database instance from source to target context +func TransferContext(source context.Context, target context.Context) context.Context { + db := source.Value(dbKey{}) + if db == nil { + return target + } + + return context.WithValue(target, dbKey{}, db) +} + +func Tx(ctx context.Context, f func(newCtx context.Context, tx *ent.Tx) error, onError func() error) error { + db := ctx.Value(dbKey{}) + if db == nil { + return ErrDBKeyNotFound + } + + client := db.(*dbClient).Client + + tx, err := client.Tx(ctx) + if err != nil { + return ErrFailedToStartDatabaseTransaction + } + + newCtx := context.WithValue(ctx, txKey{}, tx) + + if err = f(newCtx, tx); err != nil { + finalError := err + + func() { + defer func() { + if err := recover(); err != nil { + finalError = errors.Join(finalError, fmt.Errorf("panic when rolling back: %w", err.(error))) + } + }() + + rollbackErr := tx.Rollback() + if rollbackErr != nil { + finalError = errors.Join(finalError, fmt.Errorf("failed rolling back transaction: %w", rollbackErr)) + } + + if onError != nil { + onErrorErr := onError() + if onErrorErr != nil { + finalError = errors.Join(finalError, onErrorErr) + } + } + }() + + return finalError + } + + err = tx.Commit() + if err != nil { + return ErrFailedToCommitDatabaseTransaction + } + + return nil +} + +func EnableDebug() { + debugEnabled = true +} + +func DisableDebug() { + debugEnabled = false +} diff --git a/internal/entdb/errors.go b/internal/entdb/errors.go new file mode 100644 index 00000000..4e2f4060 --- /dev/null +++ b/internal/entdb/errors.go @@ -0,0 +1,28 @@ +package entdb + +import ( + "errors" +) + +var ( + // ErrInvalidTypeArgs is returned when the type of args is invalid + ErrInvalidTypeArgs = errors.New("dialect/sql: invalid type %T. expect []any for args") + // ErrInvalidTypeResult is returned when the type of result is invalid + ErrInvalidTypeResult = errors.New("dialect/sql: invalid type %T. expect *sql.Result") + // ErrInvalidTypeRows is returned when the type of rows is invalid + ErrInvalidTypeRows = errors.New("dialect/sql: invalid type %T. expect *sql.Rows") + // ErrInvalidTypeIsolation is returned when the type of isolation level is invalid + ErrInvalidTypeIsolation = errors.New("unsupported isolation level: %v") + // ErrDBKeyNotFound is returned when the db key is not found in the context + ErrDBKeyNotFound = errors.New("db key not found in context") + // ErrTxKeyNotFound is returned when the tx key is not found in the context + ErrTxKeyNotFound = errors.New("tx key not found in context") + // ErrFailedToParseConnectionString is returned when the connection string is invalid + ErrFailedToParseConnectionString = errors.New("failed to parse connection string") + // ErrFfailedToConnectToDatabase is returned when the connection to the database fails + ErrFfailedToConnectToDatabase = errors.New("failed to connect to database") + // ErrFailedToStartDatabaseTransaction is returned when the database transaction fails to start + ErrFailedToStartDatabaseTransaction = errors.New("failed to start database transaction") + // ErrFailedToCommitDatabaseTransaction is returned when the database transaction fails to commit + ErrFailedToCommitDatabaseTransaction = errors.New("failed to commit database transaction") +) diff --git a/internal/entdb/pgxpool.go b/internal/entdb/pgxpool.go new file mode 100644 index 00000000..7f4d189f --- /dev/null +++ b/internal/entdb/pgxpool.go @@ -0,0 +1,296 @@ +// ported / adopted originally from: https://github.com/ent/ent/discussions/1797#discussioncomment-5111111 + +package entdb + +import ( + "context" + stdsql "database/sql" + + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" +) + +func NewPgxPoolDriver(pool *pgxpool.Pool) dialect.Driver { + return &EntPgxpoolDriver{ + pool: pool, + tracer: otel.Tracer("pgxpool"), + } +} + +type EntPgxpoolDriver struct { + pool *pgxpool.Pool + tracer trace.Tracer +} + +func (e *EntPgxpoolDriver) Exec(ctx context.Context, query string, args, result any) error { + var _ stdsql.Result + + argv, ok := args.([]any) + if !ok { + return ErrInvalidTypeArgs + } + + switch result := result.(type) { + case nil: + if _, err := e.pool.Exec(ctx, query, argv...); err != nil { + return err + } + case *sql.Result: + commandTag, err := e.pool.Exec(ctx, query, argv...) + if err != nil { + return err + } + + *result = execResult{rowsAffected: commandTag.RowsAffected()} + default: + return ErrInvalidTypeResult + } + + return nil +} + +func (e *EntPgxpoolDriver) Query(ctx context.Context, query string, args, v any) error { + vr, ok := v.(*sql.Rows) + if !ok { + return ErrInvalidTypeRows + } + + argv, ok := args.([]any) + if !ok { + return ErrInvalidTypeArgs + } + + pgxRows, err := e.pool.Query(ctx, query, argv...) + if err != nil { + return err + } + + columnScanner := &entPgxRows{pgxRows: pgxRows} + *vr = sql.Rows{ + ColumnScanner: columnScanner, + } + + return nil +} + +func (e *EntPgxpoolDriver) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { + commandTag, err := e.pool.Exec(ctx, query, args...) + if err != nil { + return nil, err + } + + return &execResult{rowsAffected: commandTag.RowsAffected()}, nil +} + +func (e *EntPgxpoolDriver) Tx(ctx context.Context) (dialect.Tx, error) { + return e.BeginTx(ctx, nil) +} + +func (e *EntPgxpoolDriver) BeginTx(ctx context.Context, opts *sql.TxOptions) (dialect.Tx, error) { + ctx, span := e.tracer.Start(ctx, "BeginTx", trace.WithAttributes()) + + defer span.End() + + pgxOpts, err := getPgxTxOptions(opts) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + span.RecordError(err) + + return nil, err + } + + tx, err := e.pool.BeginTx(ctx, *pgxOpts) + if err != nil { + span.SetStatus(codes.Error, err.Error()) + span.RecordError(err) + + return nil, err + } + + return &EntPgxPoolTx{ + tx: tx, + }, nil +} + +func getPgxTxOptions(opts *sql.TxOptions) (*pgx.TxOptions, error) { + var pgxOpts pgx.TxOptions + if opts == nil { + return &pgxOpts, nil + } + + switch opts.Isolation { + case stdsql.LevelDefault: + case stdsql.LevelReadUncommitted: + pgxOpts.IsoLevel = pgx.ReadUncommitted + case stdsql.LevelReadCommitted: + pgxOpts.IsoLevel = pgx.ReadCommitted + case stdsql.LevelRepeatableRead, stdsql.LevelSnapshot: + pgxOpts.IsoLevel = pgx.RepeatableRead + case stdsql.LevelSerializable: + pgxOpts.IsoLevel = pgx.Serializable + default: + return nil, ErrInvalidTypeIsolation + } + + if opts.ReadOnly { + pgxOpts.AccessMode = pgx.ReadOnly + } + + return &pgxOpts, nil +} + +func (e *EntPgxpoolDriver) Close() error { + e.pool.Close() + + return nil +} + +func (e *EntPgxpoolDriver) Dialect() string { + return dialect.Postgres +} + +type EntPgxPoolTx struct { + tx pgx.Tx +} + +func (e *EntPgxPoolTx) Exec(ctx context.Context, query string, args, result any) error { + var _ stdsql.Result + + argv, ok := args.([]any) + if !ok { + return ErrInvalidTypeArgs + } + + switch result := result.(type) { + case nil: + if _, err := e.tx.Exec(ctx, query, argv...); err != nil { + return err + } + case *sql.Result: + commandTag, err := e.tx.Exec(ctx, query, argv...) + if err != nil { + return err + } + + *result = execResult{rowsAffected: commandTag.RowsAffected()} + default: + return ErrInvalidTypeResult + } + + return nil +} + +func (e *EntPgxPoolTx) ExecContext(ctx context.Context, query string, args ...any) (stdsql.Result, error) { + commandTag, err := e.tx.Exec(ctx, query, args...) + if err != nil { + return nil, err + } + + return &execResult{rowsAffected: commandTag.RowsAffected()}, nil +} + +func (e *EntPgxPoolTx) Query(ctx context.Context, query string, args, v any) error { + vr, ok := v.(*sql.Rows) + + if !ok { + return ErrInvalidTypeRows + } + + argv, ok := args.([]any) + if !ok { + return ErrInvalidTypeArgs + } + + pgxRows, err := e.tx.Query(ctx, query, argv...) + if err != nil { + return err + } + + columnScanner := &entPgxRows{pgxRows: pgxRows} + + *vr = sql.Rows{ + ColumnScanner: columnScanner, + } + + return nil +} + +func (e *EntPgxPoolTx) Commit() error { + return e.tx.Commit(context.TODO()) +} + +func (e *EntPgxPoolTx) Rollback() error { + return e.tx.Rollback(context.TODO()) +} + +func (e *EntPgxPoolTx) PGXTransaction() pgx.Tx { + return e.tx +} + +type entPgxRows struct { + pgxRows pgx.Rows +} + +func (e entPgxRows) Close() error { + e.pgxRows.Close() + + return nil +} + +// ColumnTypes returns column information such as column type, length, and nullable +func (e entPgxRows) ColumnTypes() ([]*stdsql.ColumnType, error) { + return []*stdsql.ColumnType{}, nil +} + +// Columns returns the column names +func (e entPgxRows) Columns() ([]string, error) { + fieldDescs := e.pgxRows.FieldDescriptions() + columnNames := make([]string, len(fieldDescs)) + + for i, fd := range fieldDescs { + columnNames[i] = fd.Name + } + + return columnNames, nil +} + +func (e entPgxRows) Err() error { + return e.pgxRows.Err() +} + +func (e entPgxRows) Next() bool { + return e.pgxRows.Next() +} + +// NextResultSet prepares the next result set for reading; it reports whether +// there is further result sets, or false if there is no further result set +// or if there is an error advancing to it +func (e entPgxRows) NextResultSet() bool { + // For now this does not seem like a must have for normal database functionality. + // This seems to be useful if we want to send 2 sql statements in a single query + // and when the results of the first query are exhausted, then check if the NextResultSet + // has values + return e.pgxRows.Next() +} + +func (e entPgxRows) Scan(dest ...any) error { + return e.pgxRows.Scan(dest...) +} + +type execResult struct { + lastInsertID int64 + rowsAffected int64 +} + +func (e execResult) LastInsertId() (int64, error) { + return e.lastInsertID, nil +} + +func (e execResult) RowsAffected() (int64, error) { + return e.rowsAffected, nil +} diff --git a/internal/entdb/testclient.go b/internal/entdb/testclient.go new file mode 100644 index 00000000..10fa52ba --- /dev/null +++ b/internal/entdb/testclient.go @@ -0,0 +1,91 @@ +package entdb + +import ( + "context" + "fmt" + "os" + "strconv" + "time" + + "github.com/rs/zerolog/log" + ent "github.com/theopenlane/core/internal/ent/generated" + "github.com/theopenlane/entx" + "github.com/theopenlane/riverboat/pkg/riverqueue" + "github.com/theopenlane/utils/testutils" +) + +const ( + // defaultDBTestImage is the default docker image to use for testing + defaultDBTestImage = "docker://postgres:17-alpine" +) + +// NewTestFixture creates a test container for testing purposes +func NewTestFixture() *testutils.TestFixture { + // Grab the DB environment variable or use the default + testDBURI := os.Getenv("TEST_DB_URL") + testDBContainerExpiry := os.Getenv("TEST_DB_CONTAINER_EXPIRY") + + // If the DB URI is not set, use the default docker image + if testDBURI == "" { + testDBURI = defaultDBTestImage + } + + if testDBContainerExpiry == "" { + testDBContainerExpiry = "5" // default expiry of 5 minutes + } + + expiry, err := strconv.Atoi(testDBContainerExpiry) + if err != nil { + panic(fmt.Sprintf("failed to convert TEST_DB_CONTAINER_EXPIRY to int: %v", err)) + } + + return testutils.GetTestURI(testutils.WithImage(testDBURI), + testutils.WithExpiryMinutes(expiry), + testutils.WithMaxConn(200)) // nolint:mnd +} + +// NewTestClient creates a entdb client that can be used for TEST purposes ONLY +func NewTestClient(ctx context.Context, ctr *testutils.TestFixture, jobOpts []riverqueue.Option, entOpts []ent.Option) (*ent.Client, error) { + dbconf := entx.Config{ + Debug: true, + DriverName: ctr.Dialect, + PrimaryDBSource: ctr.URI, + EnableHistory: true, // enable history so the code path is checked during unit tests + CacheTTL: 0 * time.Second, // do not cache results in tests + } + + // Create the db client + var db *ent.Client + + // Retry the connection to the database to ensure it is up and running + var err error + + // run migrations for tests + jobOpts = append(jobOpts, riverqueue.WithRunMigrations(true)) + + // If a test container is used, retry the connection to the database to ensure it is up and running + if ctr.Pool != nil { + err = ctr.Pool.Retry(func() error { + log.Info().Msg("connecting to database...") + + db, err = New(ctx, dbconf, jobOpts, entOpts...) + if err != nil { + log.Info().Err(err).Msg("retrying connection to database...") + } + + return err + }) + } else { + db, err = New(ctx, dbconf, jobOpts, entOpts...) + } + + if err != nil { + return nil, err + } + + if err := db.Schema.Create(ctx); err != nil { + return nil, err + } + + return db, nil +} diff --git a/internal/entdb/tx.go b/internal/entdb/tx.go new file mode 100644 index 00000000..8a3cfcfc --- /dev/null +++ b/internal/entdb/tx.go @@ -0,0 +1,100 @@ +package entdb + +import ( + "context" + "fmt" + + "github.com/rs/zerolog/log" + ent "github.com/theopenlane/core/internal/ent/generated" +) + +// WithTxResult wraps the given function with a transaction. +// If the function returns an error, the transaction is rolled back. +func WithTxResult[T any](ctx context.Context, client *ent.Client, fn func(tx *ent.Tx) (T, error)) (T, error) { + var zero T + + // Start a new transaction + log.Ctx(ctx).Info().Msg("Starting transaction") + + tx, err := client.Tx(ctx) + if err != nil { + return zero, err + } + + // Flag to keep track of transaction finalization + transactionCompleted := false + + defer func() { + if transactionCompleted { + return + } + + log.Ctx(ctx).Info().Msg("Transaction not completed, attempting to rollback") + + if v := recover(); v != nil { + // Attempt to rollback on panic + err := tx.Rollback() // Ignore rollback error here as panic takes precedence + log.Ctx(ctx).Info().Msgf("Rollback failed: %v", err) + panic(v) + } + }() + + // Execute the function within the transaction + log.Ctx(ctx).Info().Msg("Executing function within transaction") + + result, err := fn(tx) + if err != nil { + // Rollback transaction on error + log.Ctx(ctx).Info().Msgf("Rolling back transaction on error: %v", err) + + if rerr := tx.Rollback(); rerr != nil { + err = fmt.Errorf("%w: rolling back transaction: %v", err, rerr) + } + + return zero, err + } + + // Commit the transaction + log.Ctx(ctx).Info().Msg("Committing transaction") + + if err := tx.Commit(); err != nil { + log.Ctx(ctx).Info().Msgf("Error committing transaction: %v", err) + + return zero, fmt.Errorf("committing transaction: %w", err) + } + + // Mark the transaction as completed to prevent deferred rollback + log.Ctx(ctx).Info().Msg("Transaction completed successfully") + + transactionCompleted = true + + return result, nil +} + +func WithTx(ctx context.Context, client *ent.Client, fn func(tx *ent.Tx) error) error { + tx, err := client.Tx(ctx) + if err != nil { + return err + } + + defer func() { + if v := recover(); v != nil { + tx.Rollback() + panic(v) + } + }() + + if err := fn(tx); err != nil { + if rerr := tx.Rollback(); rerr != nil { + err = fmt.Errorf("%w: rolling back transaction: %v", err, rerr) + } + + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +}