Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial stab at establishing pgxpool #257

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ require (
github.com/brianvoe/gofakeit/v7 v7.1.2
github.com/danielgtaylor/huma/v2 v2.26.0
github.com/dustinkirkland/golang-petname v0.0.0-20240428194347-eebcea082ee0
github.com/exaring/otelpgx v0.7.0
github.com/gertd/go-pluralize v0.2.1
github.com/getkin/kin-openapi v0.128.0
github.com/go-viper/mapstructure/v2 v2.2.1
Expand Down Expand Up @@ -290,14 +291,14 @@ require (
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib v1.29.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.56.0 // indirect
go.opentelemetry.io/otel v1.32.0 // indirect
go.opentelemetry.io/otel v1.32.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.31.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.31.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.29.0 // indirect
go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.29.0 // indirect
go.opentelemetry.io/otel/metric v1.32.0 // indirect
go.opentelemetry.io/otel/sdk v1.31.0 // indirect
go.opentelemetry.io/otel/trace v1.32.0 // indirect
go.opentelemetry.io/otel/trace v1.32.0
go.opentelemetry.io/proto/otlp v1.3.1 // indirect
go.uber.org/goleak v1.3.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1m
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
github.com/envoyproxy/protoc-gen-validate v1.1.0 h1:tntQDh69XqOCOZsDz0lVJQez/2L6Uu2PdjCQwWCJ3bM=
github.com/envoyproxy/protoc-gen-validate v1.1.0/go.mod h1:sXRDRVmzEbkM7CVcM06s9shE/m23dg3wzjl0UWqJ2q4=
github.com/exaring/otelpgx v0.7.0 h1:Wv1x53y6zmmBsEPbWNae6XJAbMNC3KSJmpWRoZxtZr8=
github.com/exaring/otelpgx v0.7.0/go.mod h1:2oRpYkkPBXpvRqQqP0gqkkFPwITRObbpsrA8NT1Fu/I=
github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4=
github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
Expand Down
82 changes: 0 additions & 82 deletions internal/entdb/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ package entdb
import (
"context"
"database/sql"
"fmt"
"os"
"strconv"
"time"

"ariga.io/entcache"
entsql "entgo.io/ent/dialect/sql"
Expand All @@ -15,8 +11,6 @@ import (
"github.com/theopenlane/entx"
"github.com/theopenlane/riverboat/pkg/riverqueue"

"github.com/theopenlane/utils/testutils"

migratedb "github.com/theopenlane/core/db"
ent "github.com/theopenlane/core/internal/ent/generated"
"github.com/theopenlane/core/internal/ent/hooks"
Expand All @@ -25,11 +19,6 @@ import (
_ "github.com/jackc/pgx/v5/stdlib" // add pgx driver
)

const (
// defaultDBTestImage is the default docker image to use for testing
defaultDBTestImage = "docker://postgres:17-alpine"
)

type client struct {
// config is the entdb configuration
config *entx.Config
Expand Down Expand Up @@ -204,74 +193,3 @@ func (c *client) createEntDBClient(db *entsql.Driver) *ent.Client {

return ent.NewClient(cOpts...)
}

// NewTestFixture creates a test container for testing purposes
func NewTestFixture() *testutils.TestFixture {
// Grab the DB environment variable or use the default
testDBURI := os.Getenv("TEST_DB_URL")
testDBContainerExpiry := os.Getenv("TEST_DB_CONTAINER_EXPIRY")

// If the DB URI is not set, use the default docker image
if testDBURI == "" {
testDBURI = defaultDBTestImage
}

if testDBContainerExpiry == "" {
testDBContainerExpiry = "5" // default expiry of 5 minutes
}

expiry, err := strconv.Atoi(testDBContainerExpiry)
if err != nil {
panic(fmt.Sprintf("failed to convert TEST_DB_CONTAINER_EXPIRY to int: %v", err))
}

return testutils.GetTestURI(testutils.WithImage(testDBURI),
testutils.WithExpiryMinutes(expiry),
testutils.WithMaxConn(200)) // nolint:mnd
}

// NewTestClient creates a entdb client that can be used for TEST purposes ONLY
func NewTestClient(ctx context.Context, ctr *testutils.TestFixture, jobOpts []riverqueue.Option, entOpts []ent.Option) (*ent.Client, error) {
dbconf := entx.Config{
Debug: true,
DriverName: ctr.Dialect,
PrimaryDBSource: ctr.URI,
EnableHistory: true, // enable history so the code path is checked during unit tests
CacheTTL: 0 * time.Second, // do not cache results in tests
}

// Create the db client
var db *ent.Client

// Retry the connection to the database to ensure it is up and running
var err error

// run migrations for tests
jobOpts = append(jobOpts, riverqueue.WithRunMigrations(true))

// If a test container is used, retry the connection to the database to ensure it is up and running
if ctr.Pool != nil {
err = ctr.Pool.Retry(func() error {
log.Info().Msg("connecting to database...")

db, err = New(ctx, dbconf, jobOpts, entOpts...)
if err != nil {
log.Info().Err(err).Msg("retrying connection to database...")
}

return err
})
} else {
db, err = New(ctx, dbconf, jobOpts, entOpts...)
}

if err != nil {
return nil, err
}

if err := db.Schema.Create(ctx); err != nil {
return nil, err
}

return db, nil
}
47 changes: 47 additions & 0 deletions internal/entdb/clientpool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package entdb

import (
"context"

"ariga.io/entcache"
"github.com/exaring/otelpgx"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/theopenlane/entx"

ent "github.com/theopenlane/core/internal/ent/generated"
)

// NewDBPool creates a new database pool with the given configuration
func NewDBPool(ctx context.Context, cfg *entx.Config) (context.Context, error) {
poolConfig, err := pgxpool.ParseConfig(cfg.PrimaryDBSource)
if err != nil {
return nil, ErrFailedToParseConnectionString
}

// TODO build config from entx.Config

poolConfig.ConnConfig.Tracer = otelpgx.NewTracer()
poolConfig.MaxConns = 20 // nolint:gomnd

pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, ErrFfailedToConnectToDatabase
}

poolDriver := NewPgxPoolDriver(pool)

cacheDriver := entcache.NewDriver(
poolDriver,
entcache.ContextLevel(),
)

realClient := ent.NewClient(
ent.Driver(cacheDriver),
)

if debugEnabled {
realClient = realClient.Debug()
}

return context.WithValue(ctx, dbKey{}, &dbClient{Client: realClient}), nil
}
105 changes: 105 additions & 0 deletions internal/entdb/ent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package entdb

import (
"context"
"errors"
"fmt"

ent "github.com/theopenlane/core/internal/ent/generated"

// Required PGX driver
_ "github.com/jackc/pgx/v5/stdlib"
)

type (
dbKey struct{}
txKey struct{}
)

type dbClient struct {
Client *ent.Client
}

var debugEnabled = false

// From retrieves a database instance from the context
func From(ctx context.Context) *ent.Client {
tx := ctx.Value(txKey{})
if tx != nil {
return tx.(*ent.Tx).Client()
}

db := ctx.Value(dbKey{})
if db == nil {
return nil
}

return db.(*dbClient).Client
}

// TransferContext transfers a database instance from source to target context
func TransferContext(source context.Context, target context.Context) context.Context {
db := source.Value(dbKey{})
if db == nil {
return target
}

return context.WithValue(target, dbKey{}, db)
}

func Tx(ctx context.Context, f func(newCtx context.Context, tx *ent.Tx) error, onError func() error) error {
db := ctx.Value(dbKey{})
if db == nil {
return ErrDBKeyNotFound
}

client := db.(*dbClient).Client

tx, err := client.Tx(ctx)
if err != nil {
return ErrFailedToStartDatabaseTransaction
}

newCtx := context.WithValue(ctx, txKey{}, tx)

if err = f(newCtx, tx); err != nil {
finalError := err

func() {
defer func() {
if err := recover(); err != nil {
finalError = errors.Join(finalError, fmt.Errorf("panic when rolling back: %w", err.(error)))
}
}()

rollbackErr := tx.Rollback()
if rollbackErr != nil {
finalError = errors.Join(finalError, fmt.Errorf("failed rolling back transaction: %w", rollbackErr))
}

if onError != nil {
onErrorErr := onError()
if onErrorErr != nil {
finalError = errors.Join(finalError, onErrorErr)
}
}
}()

return finalError
}

err = tx.Commit()
if err != nil {
return ErrFailedToCommitDatabaseTransaction
}

return nil
}

func EnableDebug() {
debugEnabled = true
}

func DisableDebug() {
debugEnabled = false
}
28 changes: 28 additions & 0 deletions internal/entdb/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package entdb

import (
"errors"
)

var (
// ErrInvalidTypeArgs is returned when the type of args is invalid
ErrInvalidTypeArgs = errors.New("dialect/sql: invalid type %T. expect []any for args")
// ErrInvalidTypeResult is returned when the type of result is invalid
ErrInvalidTypeResult = errors.New("dialect/sql: invalid type %T. expect *sql.Result")
// ErrInvalidTypeRows is returned when the type of rows is invalid
ErrInvalidTypeRows = errors.New("dialect/sql: invalid type %T. expect *sql.Rows")
// ErrInvalidTypeIsolation is returned when the type of isolation level is invalid
ErrInvalidTypeIsolation = errors.New("unsupported isolation level: %v")
// ErrDBKeyNotFound is returned when the db key is not found in the context
ErrDBKeyNotFound = errors.New("db key not found in context")
// ErrTxKeyNotFound is returned when the tx key is not found in the context
ErrTxKeyNotFound = errors.New("tx key not found in context")
// ErrFailedToParseConnectionString is returned when the connection string is invalid
ErrFailedToParseConnectionString = errors.New("failed to parse connection string")
// ErrFfailedToConnectToDatabase is returned when the connection to the database fails
ErrFfailedToConnectToDatabase = errors.New("failed to connect to database")
// ErrFailedToStartDatabaseTransaction is returned when the database transaction fails to start
ErrFailedToStartDatabaseTransaction = errors.New("failed to start database transaction")
// ErrFailedToCommitDatabaseTransaction is returned when the database transaction fails to commit
ErrFailedToCommitDatabaseTransaction = errors.New("failed to commit database transaction")
)
Loading