Skip to content

Commit a920fcf

Browse files
authored
Fix db engine (#32351)
Fix #32349
1 parent d70af38 commit a920fcf

File tree

10 files changed

+172
-74
lines changed

10 files changed

+172
-74
lines changed

Diff for: models/db/context.go

+67-47
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ package db
66
import (
77
"context"
88
"database/sql"
9+
"errors"
10+
"runtime"
11+
"slices"
12+
"sync"
13+
14+
"code.gitea.io/gitea/modules/setting"
915

1016
"xorm.io/builder"
1117
"xorm.io/xorm"
@@ -15,76 +21,90 @@ import (
1521
// will be overwritten by Init with HammerContext
1622
var DefaultContext context.Context
1723

18-
// contextKey is a value for use with context.WithValue.
19-
type contextKey struct {
20-
name string
21-
}
24+
type engineContextKeyType struct{}
2225

23-
// enginedContextKey is a context key. It is used with context.Value() to get the current Engined for the context
24-
var (
25-
enginedContextKey = &contextKey{"engined"}
26-
_ Engined = &Context{}
27-
)
26+
var engineContextKey = engineContextKeyType{}
2827

2928
// Context represents a db context
3029
type Context struct {
3130
context.Context
32-
e Engine
33-
transaction bool
34-
}
35-
36-
func newContext(ctx context.Context, e Engine, transaction bool) *Context {
37-
return &Context{
38-
Context: ctx,
39-
e: e,
40-
transaction: transaction,
41-
}
42-
}
43-
44-
// InTransaction if context is in a transaction
45-
func (ctx *Context) InTransaction() bool {
46-
return ctx.transaction
31+
engine Engine
4732
}
4833

49-
// Engine returns db engine
50-
func (ctx *Context) Engine() Engine {
51-
return ctx.e
34+
func newContext(ctx context.Context, e Engine) *Context {
35+
return &Context{Context: ctx, engine: e}
5236
}
5337

5438
// Value shadows Value for context.Context but allows us to get ourselves and an Engined object
5539
func (ctx *Context) Value(key any) any {
56-
if key == enginedContextKey {
40+
if key == engineContextKey {
5741
return ctx
5842
}
5943
return ctx.Context.Value(key)
6044
}
6145

6246
// WithContext returns this engine tied to this context
6347
func (ctx *Context) WithContext(other context.Context) *Context {
64-
return newContext(ctx, ctx.e.Context(other), ctx.transaction)
48+
return newContext(ctx, ctx.engine.Context(other))
6549
}
6650

67-
// Engined structs provide an Engine
68-
type Engined interface {
69-
Engine() Engine
51+
var (
52+
contextSafetyOnce sync.Once
53+
contextSafetyDeniedFuncPCs []uintptr
54+
)
55+
56+
func contextSafetyCheck(e Engine) {
57+
if setting.IsProd && !setting.IsInTesting {
58+
return
59+
}
60+
if e == nil {
61+
return
62+
}
63+
// Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed.
64+
contextSafetyOnce.Do(func() {
65+
// try to figure out the bad functions to deny
66+
type m struct{}
67+
_ = e.SQL("SELECT 1").Iterate(&m{}, func(int, any) error {
68+
callers := make([]uintptr, 32)
69+
callerNum := runtime.Callers(1, callers)
70+
for i := 0; i < callerNum; i++ {
71+
if funcName := runtime.FuncForPC(callers[i]).Name(); funcName == "xorm.io/xorm.(*Session).Iterate" {
72+
contextSafetyDeniedFuncPCs = append(contextSafetyDeniedFuncPCs, callers[i])
73+
}
74+
}
75+
return nil
76+
})
77+
if len(contextSafetyDeniedFuncPCs) != 1 {
78+
panic(errors.New("unable to determine the functions to deny"))
79+
}
80+
})
81+
82+
// it should be very fast: xxxx ns/op
83+
callers := make([]uintptr, 32)
84+
callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
85+
for i := 0; i < callerNum; i++ {
86+
if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) {
87+
panic(errors.New("using database context in an iterator would cause corrupted results"))
88+
}
89+
}
7090
}
7191

72-
// GetEngine will get a db Engine from this context or return an Engine restricted to this context
92+
// GetEngine gets an existing db Engine/Statement or creates a new Session
7393
func GetEngine(ctx context.Context) Engine {
74-
if e := getEngine(ctx); e != nil {
94+
if e := getExistingEngine(ctx); e != nil {
7595
return e
7696
}
7797
return x.Context(ctx)
7898
}
7999

80-
// getEngine will get a db Engine from this context or return nil
81-
func getEngine(ctx context.Context) Engine {
82-
if engined, ok := ctx.(Engined); ok {
83-
return engined.Engine()
100+
// getExistingEngine gets an existing db Engine/Statement from this context or returns nil
101+
func getExistingEngine(ctx context.Context) (e Engine) {
102+
defer func() { contextSafetyCheck(e) }()
103+
if engined, ok := ctx.(*Context); ok {
104+
return engined.engine
84105
}
85-
enginedInterface := ctx.Value(enginedContextKey)
86-
if enginedInterface != nil {
87-
return enginedInterface.(Engined).Engine()
106+
if engined, ok := ctx.Value(engineContextKey).(*Context); ok {
107+
return engined.engine
88108
}
89109
return nil
90110
}
@@ -132,23 +152,23 @@ func (c *halfCommitter) Close() error {
132152
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
133153
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
134154
if sess, ok := inTransaction(parentCtx); ok {
135-
return newContext(parentCtx, sess, true), &halfCommitter{committer: sess}, nil
155+
return newContext(parentCtx, sess), &halfCommitter{committer: sess}, nil
136156
}
137157

138158
sess := x.NewSession()
139159
if err := sess.Begin(); err != nil {
140-
sess.Close()
160+
_ = sess.Close()
141161
return nil, nil, err
142162
}
143163

144-
return newContext(DefaultContext, sess, true), sess, nil
164+
return newContext(DefaultContext, sess), sess, nil
145165
}
146166

147167
// WithTx represents executing database operations on a transaction, if the transaction exist,
148168
// this function will reuse it otherwise will create a new one and close it when finished.
149169
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
150170
if sess, ok := inTransaction(parentCtx); ok {
151-
err := f(newContext(parentCtx, sess, true))
171+
err := f(newContext(parentCtx, sess))
152172
if err != nil {
153173
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
154174
_ = sess.Close()
@@ -165,7 +185,7 @@ func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error)
165185
return err
166186
}
167187

168-
if err := f(newContext(parentCtx, sess, true)); err != nil {
188+
if err := f(newContext(parentCtx, sess)); err != nil {
169189
return err
170190
}
171191

@@ -312,7 +332,7 @@ func InTransaction(ctx context.Context) bool {
312332
}
313333

314334
func inTransaction(ctx context.Context) (*xorm.Session, bool) {
315-
e := getEngine(ctx)
335+
e := getExistingEngine(ctx)
316336
if e == nil {
317337
return nil, false
318338
}

Diff for: models/db/context_test.go

+44
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,47 @@ func TestTxContext(t *testing.T) {
8484
}))
8585
}
8686
}
87+
88+
func TestContextSafety(t *testing.T) {
89+
type TestModel1 struct {
90+
ID int64
91+
}
92+
type TestModel2 struct {
93+
ID int64
94+
}
95+
assert.NoError(t, unittest.GetXORMEngine().Sync(&TestModel1{}, &TestModel2{}))
96+
assert.NoError(t, db.TruncateBeans(db.DefaultContext, &TestModel1{}, &TestModel2{}))
97+
testCount := 10
98+
for i := 1; i <= testCount; i++ {
99+
assert.NoError(t, db.Insert(db.DefaultContext, &TestModel1{ID: int64(i)}))
100+
assert.NoError(t, db.Insert(db.DefaultContext, &TestModel2{ID: int64(-i)}))
101+
}
102+
103+
actualCount := 0
104+
// here: db.GetEngine(db.DefaultContext) is a new *Session created from *Engine
105+
_ = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
106+
_ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
107+
// here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false,
108+
// and the internal states (including "cond" and others) are always there and not be reset in this callback.
109+
m1 := bean.(*TestModel1)
110+
assert.EqualValues(t, i+1, m1.ID)
111+
112+
// here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ...
113+
// and it conflicts with the "Iterate"'s internal states.
114+
// has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID})
115+
116+
actualCount++
117+
return nil
118+
})
119+
return nil
120+
})
121+
assert.EqualValues(t, testCount, actualCount)
122+
123+
// deny the bad usages
124+
assert.PanicsWithError(t, "using database context in an iterator would cause corrupted results", func() {
125+
_ = unittest.GetXORMEngine().Iterate(&TestModel1{}, func(i int, bean any) error {
126+
_ = db.GetEngine(db.DefaultContext)
127+
return nil
128+
})
129+
})
130+
}

Diff for: models/db/engine.go

+1-4
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,7 @@ func InitEngine(ctx context.Context) error {
161161
// SetDefaultEngine sets the default engine for db
162162
func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) {
163163
x = eng
164-
DefaultContext = &Context{
165-
Context: ctx,
166-
e: x,
167-
}
164+
DefaultContext = &Context{Context: ctx, engine: x}
168165
}
169166

170167
// UnsetDefaultEngine closes and unsets the default engine

Diff for: models/db/install/db.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
func getXORMEngine() *xorm.Engine {
14-
return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine)
14+
return db.GetEngine(db.DefaultContext).(*xorm.Engine)
1515
}
1616

1717
// CheckDatabaseConnection checks the database connection

Diff for: models/db/iterate.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
"xorm.io/builder"
1212
)
1313

14-
// Iterate iterate all the Bean object
14+
// Iterate iterates all the Bean object
1515
func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error {
1616
var start int
1717
batchSize := setting.Database.IterateBufferSize

Diff for: models/packages/debian/search.go

+16-15
Original file line numberDiff line numberDiff line change
@@ -75,26 +75,27 @@ func ExistPackages(ctx context.Context, opts *PackageSearchOptions) (bool, error
7575
}
7676

7777
// SearchPackages gets the packages matching the search options
78-
func SearchPackages(ctx context.Context, opts *PackageSearchOptions, iter func(*packages.PackageFileDescriptor)) error {
79-
return db.GetEngine(ctx).
78+
func SearchPackages(ctx context.Context, opts *PackageSearchOptions) ([]*packages.PackageFileDescriptor, error) {
79+
var pkgFiles []*packages.PackageFile
80+
err := db.GetEngine(ctx).
8081
Table("package_file").
8182
Select("package_file.*").
8283
Join("INNER", "package_version", "package_version.id = package_file.version_id").
8384
Join("INNER", "package", "package.id = package_version.package_id").
8485
Where(opts.toCond()).
85-
Asc("package.lower_name", "package_version.created_unix").
86-
Iterate(new(packages.PackageFile), func(_ int, bean any) error {
87-
pf := bean.(*packages.PackageFile)
88-
89-
pfd, err := packages.GetPackageFileDescriptor(ctx, pf)
90-
if err != nil {
91-
return err
92-
}
93-
94-
iter(pfd)
95-
96-
return nil
97-
})
86+
Asc("package.lower_name", "package_version.created_unix").Find(&pkgFiles)
87+
if err != nil {
88+
return nil, err
89+
}
90+
pfds := make([]*packages.PackageFileDescriptor, 0, len(pkgFiles))
91+
for _, pf := range pkgFiles {
92+
pfd, err := packages.GetPackageFileDescriptor(ctx, pf)
93+
if err != nil {
94+
return nil, err
95+
}
96+
pfds = append(pfds, pfd)
97+
}
98+
return pfds, nil
9899
}
99100

100101
// GetDistributions gets all available distributions

Diff for: models/unittest/fixtures.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) {
2525
if len(engine) == 1 {
2626
return engine[0]
2727
}
28-
return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine)
28+
return db.GetEngine(db.DefaultContext).(*xorm.Engine)
2929
}
3030

3131
// InitFixtures initialize test fixtures for a test database

Diff for: services/packages/cleanup/cleanup.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import (
2222
rpm_service "code.gitea.io/gitea/services/packages/rpm"
2323
)
2424

25-
// Task method to execute cleanup rules and cleanup expired package data
25+
// CleanupTask executes cleanup rules and cleanup expired package data
2626
func CleanupTask(ctx context.Context, olderThan time.Duration) error {
2727
if err := ExecuteCleanupRules(ctx); err != nil {
2828
return err

Diff for: services/packages/debian/repository.go

+5-4
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,11 @@ func buildPackagesIndices(ctx context.Context, ownerID int64, repoVersion *packa
206206
w := io.MultiWriter(packagesContent, gzw, xzw)
207207

208208
addSeparator := false
209-
if err := debian_model.SearchPackages(ctx, opts, func(pfd *packages_model.PackageFileDescriptor) {
209+
pfds, err := debian_model.SearchPackages(ctx, opts)
210+
if err != nil {
211+
return err
212+
}
213+
for _, pfd := range pfds {
210214
if addSeparator {
211215
fmt.Fprintln(w)
212216
}
@@ -220,10 +224,7 @@ func buildPackagesIndices(ctx context.Context, ownerID int64, repoVersion *packa
220224
fmt.Fprintf(w, "SHA1: %s\n", pfd.Blob.HashSHA1)
221225
fmt.Fprintf(w, "SHA256: %s\n", pfd.Blob.HashSHA256)
222226
fmt.Fprintf(w, "SHA512: %s\n", pfd.Blob.HashSHA512)
223-
}); err != nil {
224-
return err
225227
}
226-
227228
gzw.Close()
228229
xzw.Close()
229230

0 commit comments

Comments
 (0)