From 1f50048ac92a9624ec703500a0291e8f46a0a8c1 Mon Sep 17 00:00:00 2001 From: wxiaoguang Date: Sat, 30 Aug 2025 01:04:06 +0800 Subject: [PATCH] Refactor db package (#35380) Remove unnecessary code --- models/auth/session.go | 2 +- models/auth/source_test.go | 3 +-- models/db/context.go | 13 ++++----- models/db/context_test.go | 43 +++++++++++++++++------------- models/db/engine.go | 6 ----- models/db/engine_test.go | 2 +- models/db/index.go | 7 +---- models/organization/org_list.go | 7 +---- models/unittest/consistency.go | 2 +- models/unittest/fixtures_loader.go | 2 +- models/unittest/unit_tests.go | 2 +- 11 files changed, 37 insertions(+), 52 deletions(-) diff --git a/models/auth/session.go b/models/auth/session.go index 0378d0ec6f0..dbdcde03a0b 100644 --- a/models/auth/session.go +++ b/models/auth/session.go @@ -86,7 +86,7 @@ func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, er } } - if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil { + if _, err := db.Exec(ctx, "UPDATE `session` SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil { return nil, err } diff --git a/models/auth/source_test.go b/models/auth/source_test.go index 285f55a24b0..ebc462c5811 100644 --- a/models/auth/source_test.go +++ b/models/auth/source_test.go @@ -8,7 +8,6 @@ import ( "testing" auth_model "code.gitea.io/gitea/models/auth" - "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/modules/json" @@ -40,7 +39,7 @@ func (source *TestSource) ToDB() ([]byte, error) { func TestDumpAuthSource(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - authSourceSchema, err := db.TableInfo(new(auth_model.Source)) + authSourceSchema, err := unittest.GetXORMEngine().TableInfo(new(auth_model.Source)) assert.NoError(t, err) auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource)) diff --git a/models/db/context.go b/models/db/context.go index 0938fdeced9..8bb14f1389b 100644 --- a/models/db/context.go +++ b/models/db/context.go @@ -61,17 +61,19 @@ func contextSafetyCheck(e Engine) { callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine for i := range callerNum { if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) { - panic(errors.New("using database context in an iterator would cause corrupted results")) + panic(errors.New("using session context in an iterator would cause corrupted results")) } } } // GetEngine gets an existing db Engine/Statement or creates a new Session -func GetEngine(ctx context.Context) (e Engine) { - defer func() { contextSafetyCheck(e) }() +func GetEngine(ctx context.Context) Engine { if engine, ok := ctx.Value(engineContextKey).(Engine); ok { + // if reusing the existing session, need to do "contextSafetyCheck" because the Iterate creates a "autoResetStatement=false" session + contextSafetyCheck(engine) return engine } + // no need to do "contextSafetyCheck" because it's a new Session return xormEngine.Context(ctx) } @@ -301,11 +303,6 @@ func CountByBean(ctx context.Context, bean any) (int64, error) { return GetEngine(ctx).Count(bean) } -// TableName returns the table name according a bean object -func TableName(bean any) string { - return xormEngine.TableName(bean) -} - // InTransaction returns true if the engine is in a transaction otherwise return false func InTransaction(ctx context.Context) bool { return getTransactionSession(ctx) != nil diff --git a/models/db/context_test.go b/models/db/context_test.go index 1bc84ffbf93..1719a7bfe85 100644 --- a/models/db/context_test.go +++ b/models/db/context_test.go @@ -100,31 +100,36 @@ func TestContextSafety(t *testing.T) { assert.NoError(t, db.Insert(t.Context(), &TestModel2{ID: int64(-i)})) } - actualCount := 0 - // here: db.GetEngine(t.Context()) is a new *Session created from *Engine - _ = db.WithTx(t.Context(), func(ctx context.Context) error { - _ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error { - // here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false, - // and the internal states (including "cond" and others) are always there and not be reset in this callback. - m1 := bean.(*TestModel1) - assert.EqualValues(t, i+1, m1.ID) + t.Run("Show-XORM-Bug", func(t *testing.T) { + actualCount := 0 + // here: db.GetEngine(t.Context()) is a new *Session created from *Engine + _ = db.WithTx(t.Context(), func(ctx context.Context) error { + _ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error { + // here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false, + // and the internal states (including "cond" and others) are always there and not be reset in this callback. + m1 := bean.(*TestModel1) + assert.EqualValues(t, i+1, m1.ID) - // here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ... - // and it conflicts with the "Iterate"'s internal states. - // has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID}) + // here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ... + // and it conflicts with the "Iterate"'s internal states. + // has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID}) - actualCount++ + actualCount++ + return nil + }) return nil }) - return nil + assert.Equal(t, testCount, actualCount) }) - assert.Equal(t, testCount, actualCount) - // deny the bad usages - assert.PanicsWithError(t, "using database context in an iterator would cause corrupted results", func() { - _ = unittest.GetXORMEngine().Iterate(&TestModel1{}, func(i int, bean any) error { - _ = db.GetEngine(t.Context()) - return nil + t.Run("DenyBadUsage", func(t *testing.T) { + assert.PanicsWithError(t, "using session context in an iterator would cause corrupted results", func() { + _ = db.WithTx(t.Context(), func(ctx context.Context) error { + return db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error { + _ = db.GetEngine(ctx) + return nil + }) + }) }) }) } diff --git a/models/db/engine.go b/models/db/engine.go index 4b12925b1cd..b08799210e8 100755 --- a/models/db/engine.go +++ b/models/db/engine.go @@ -12,7 +12,6 @@ import ( "strings" "xorm.io/xorm" - "xorm.io/xorm/schemas" _ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver _ "github.com/lib/pq" // Needed for the Postgresql driver @@ -67,11 +66,6 @@ var ( _ Engine = (*xorm.Session)(nil) ) -// TableInfo returns table's information via an object -func TableInfo(v any) (*schemas.Table, error) { - return xormEngine.TableInfo(v) -} - // RegisterModel registers model, if initFuncs provided, it will be invoked after data model sync func RegisterModel(bean any, initFunc ...func() error) { registeredModels = append(registeredModels, bean) diff --git a/models/db/engine_test.go b/models/db/engine_test.go index 167525a5a8a..1c218df77f3 100644 --- a/models/db/engine_test.go +++ b/models/db/engine_test.go @@ -70,7 +70,7 @@ func TestPrimaryKeys(t *testing.T) { } for _, bean := range beans { - table, err := db.TableInfo(bean) + table, err := db.GetXORMEngineForTesting().TableInfo(bean) if err != nil { t.Fatal(err) } diff --git a/models/db/index.go b/models/db/index.go index 29254b1f07a..7a11645bd45 100644 --- a/models/db/index.go +++ b/models/db/index.go @@ -19,12 +19,7 @@ type ResourceIndex struct { MaxIndex int64 `xorm:"index"` } -var ( - // ErrResouceOutdated represents an error when request resource outdated - ErrResouceOutdated = errors.New("resource outdated") - // ErrGetResourceIndexFailed represents an error when resource index retries 3 times - ErrGetResourceIndexFailed = errors.New("get resource index failed") -) +var ErrGetResourceIndexFailed = errors.New("get resource index failed") // SyncMaxResourceIndex sync the max index with the resource func SyncMaxResourceIndex(ctx context.Context, tableName string, groupID, maxIndex int64) (err error) { diff --git a/models/organization/org_list.go b/models/organization/org_list.go index 81457191fe6..f37961b5f62 100644 --- a/models/organization/org_list.go +++ b/models/organization/org_list.go @@ -105,11 +105,6 @@ type MinimalOrg = Organization // GetUserOrgsList returns all organizations the given user has access to func GetUserOrgsList(ctx context.Context, user *user_model.User) ([]*MinimalOrg, error) { - schema, err := db.TableInfo(new(user_model.User)) - if err != nil { - return nil, err - } - outputCols := []string{ "id", "name", @@ -122,7 +117,7 @@ func GetUserOrgsList(ctx context.Context, user *user_model.User) ([]*MinimalOrg, selectColumns := &strings.Builder{} for i, col := range outputCols { - fmt.Fprintf(selectColumns, "`%s`.%s", schema.Name, col) + _, _ = fmt.Fprintf(selectColumns, "`user`.%s", col) if i < len(outputCols)-1 { selectColumns.WriteString(", ") } diff --git a/models/unittest/consistency.go b/models/unittest/consistency.go index b51cadce0aa..8447bd93ba7 100644 --- a/models/unittest/consistency.go +++ b/models/unittest/consistency.go @@ -45,7 +45,7 @@ func CheckConsistencyFor(t TestingT, beansToCheck ...any) { } func checkForConsistency(t TestingT, bean any) { - tb, err := db.TableInfo(bean) + tb, err := GetXORMEngine().TableInfo(bean) assert.NoError(t, err) f := consistencyCheckMap[tb.Name] require.NotNil(t, f, "unknown bean type: %#v", bean) diff --git a/models/unittest/fixtures_loader.go b/models/unittest/fixtures_loader.go index 0560da83492..d92b0cdb14d 100644 --- a/models/unittest/fixtures_loader.go +++ b/models/unittest/fixtures_loader.go @@ -218,7 +218,7 @@ func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, er xormBeans, _ := db.NamesToBean() f.xormTableNames = map[string]bool{} for _, bean := range xormBeans { - f.xormTableNames[db.TableName(bean)] = true + f.xormTableNames[x.TableName(bean)] = true } return f, nil diff --git a/models/unittest/unit_tests.go b/models/unittest/unit_tests.go index 49d42d7fe60..c49b26fea45 100644 --- a/models/unittest/unit_tests.go +++ b/models/unittest/unit_tests.go @@ -159,7 +159,7 @@ func DumpQueryResult(t require.TestingT, sqlOrBean any, sqlArgs ...any) { goDB := x.DB().DB sql, ok := sqlOrBean.(string) if !ok { - sql = "SELECT * FROM " + db.TableName(sqlOrBean) + sql = "SELECT * FROM " + x.TableName(sqlOrBean) } else if !strings.Contains(sql, " ") { sql = "SELECT * FROM " + sql }