Refactor db package (#35380)

Remove unnecessary code
This commit is contained in:
wxiaoguang
2025-08-30 01:04:06 +08:00
committed by GitHub
parent aef4a3514c
commit 1f50048ac9
11 changed files with 37 additions and 52 deletions

View File

@@ -86,7 +86,7 @@ func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, er
} }
} }
if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil { if _, err := db.Exec(ctx, "UPDATE `session` SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
return nil, err return nil, err
} }

View File

@@ -8,7 +8,6 @@ import (
"testing" "testing"
auth_model "code.gitea.io/gitea/models/auth" auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/models/unittest"
"code.gitea.io/gitea/modules/json" "code.gitea.io/gitea/modules/json"
@@ -40,7 +39,7 @@ func (source *TestSource) ToDB() ([]byte, error) {
func TestDumpAuthSource(t *testing.T) { func TestDumpAuthSource(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase()) assert.NoError(t, unittest.PrepareTestDatabase())
authSourceSchema, err := db.TableInfo(new(auth_model.Source)) authSourceSchema, err := unittest.GetXORMEngine().TableInfo(new(auth_model.Source))
assert.NoError(t, err) assert.NoError(t, err)
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource)) auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))

View File

@@ -61,17 +61,19 @@ func contextSafetyCheck(e Engine) {
callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
for i := range callerNum { for i := range callerNum {
if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) { if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) {
panic(errors.New("using database context in an iterator would cause corrupted results")) panic(errors.New("using session context in an iterator would cause corrupted results"))
} }
} }
} }
// GetEngine gets an existing db Engine/Statement or creates a new Session // GetEngine gets an existing db Engine/Statement or creates a new Session
func GetEngine(ctx context.Context) (e Engine) { func GetEngine(ctx context.Context) Engine {
defer func() { contextSafetyCheck(e) }()
if engine, ok := ctx.Value(engineContextKey).(Engine); ok { if engine, ok := ctx.Value(engineContextKey).(Engine); ok {
// if reusing the existing session, need to do "contextSafetyCheck" because the Iterate creates a "autoResetStatement=false" session
contextSafetyCheck(engine)
return engine return engine
} }
// no need to do "contextSafetyCheck" because it's a new Session
return xormEngine.Context(ctx) return xormEngine.Context(ctx)
} }
@@ -301,11 +303,6 @@ func CountByBean(ctx context.Context, bean any) (int64, error) {
return GetEngine(ctx).Count(bean) return GetEngine(ctx).Count(bean)
} }
// TableName returns the table name according a bean object
func TableName(bean any) string {
return xormEngine.TableName(bean)
}
// InTransaction returns true if the engine is in a transaction otherwise return false // InTransaction returns true if the engine is in a transaction otherwise return false
func InTransaction(ctx context.Context) bool { func InTransaction(ctx context.Context) bool {
return getTransactionSession(ctx) != nil return getTransactionSession(ctx) != nil

View File

@@ -100,31 +100,36 @@ func TestContextSafety(t *testing.T) {
assert.NoError(t, db.Insert(t.Context(), &TestModel2{ID: int64(-i)})) assert.NoError(t, db.Insert(t.Context(), &TestModel2{ID: int64(-i)}))
} }
actualCount := 0 t.Run("Show-XORM-Bug", func(t *testing.T) {
// here: db.GetEngine(t.Context()) is a new *Session created from *Engine actualCount := 0
_ = db.WithTx(t.Context(), func(ctx context.Context) error { // here: db.GetEngine(t.Context()) is a new *Session created from *Engine
_ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error { _ = db.WithTx(t.Context(), func(ctx context.Context) error {
// here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false, _ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
// and the internal states (including "cond" and others) are always there and not be reset in this callback. // here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false,
m1 := bean.(*TestModel1) // and the internal states (including "cond" and others) are always there and not be reset in this callback.
assert.EqualValues(t, i+1, m1.ID) m1 := bean.(*TestModel1)
assert.EqualValues(t, i+1, m1.ID)
// here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ... // here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ...
// and it conflicts with the "Iterate"'s internal states. // and it conflicts with the "Iterate"'s internal states.
// has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID}) // has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID})
actualCount++ actualCount++
return nil
})
return nil return nil
}) })
return nil assert.Equal(t, testCount, actualCount)
}) })
assert.Equal(t, testCount, actualCount)
// deny the bad usages t.Run("DenyBadUsage", func(t *testing.T) {
assert.PanicsWithError(t, "using database context in an iterator would cause corrupted results", func() { assert.PanicsWithError(t, "using session context in an iterator would cause corrupted results", func() {
_ = unittest.GetXORMEngine().Iterate(&TestModel1{}, func(i int, bean any) error { _ = db.WithTx(t.Context(), func(ctx context.Context) error {
_ = db.GetEngine(t.Context()) return db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
return nil _ = db.GetEngine(ctx)
return nil
})
})
}) })
}) })
} }

View File

@@ -12,7 +12,6 @@ import (
"strings" "strings"
"xorm.io/xorm" "xorm.io/xorm"
"xorm.io/xorm/schemas"
_ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver _ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver
_ "github.com/lib/pq" // Needed for the Postgresql driver _ "github.com/lib/pq" // Needed for the Postgresql driver
@@ -67,11 +66,6 @@ var (
_ Engine = (*xorm.Session)(nil) _ Engine = (*xorm.Session)(nil)
) )
// TableInfo returns table's information via an object
func TableInfo(v any) (*schemas.Table, error) {
return xormEngine.TableInfo(v)
}
// RegisterModel registers model, if initFuncs provided, it will be invoked after data model sync // RegisterModel registers model, if initFuncs provided, it will be invoked after data model sync
func RegisterModel(bean any, initFunc ...func() error) { func RegisterModel(bean any, initFunc ...func() error) {
registeredModels = append(registeredModels, bean) registeredModels = append(registeredModels, bean)

View File

@@ -70,7 +70,7 @@ func TestPrimaryKeys(t *testing.T) {
} }
for _, bean := range beans { for _, bean := range beans {
table, err := db.TableInfo(bean) table, err := db.GetXORMEngineForTesting().TableInfo(bean)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -19,12 +19,7 @@ type ResourceIndex struct {
MaxIndex int64 `xorm:"index"` MaxIndex int64 `xorm:"index"`
} }
var ( var ErrGetResourceIndexFailed = errors.New("get resource index failed")
// ErrResouceOutdated represents an error when request resource outdated
ErrResouceOutdated = errors.New("resource outdated")
// ErrGetResourceIndexFailed represents an error when resource index retries 3 times
ErrGetResourceIndexFailed = errors.New("get resource index failed")
)
// SyncMaxResourceIndex sync the max index with the resource // SyncMaxResourceIndex sync the max index with the resource
func SyncMaxResourceIndex(ctx context.Context, tableName string, groupID, maxIndex int64) (err error) { func SyncMaxResourceIndex(ctx context.Context, tableName string, groupID, maxIndex int64) (err error) {

View File

@@ -105,11 +105,6 @@ type MinimalOrg = Organization
// GetUserOrgsList returns all organizations the given user has access to // GetUserOrgsList returns all organizations the given user has access to
func GetUserOrgsList(ctx context.Context, user *user_model.User) ([]*MinimalOrg, error) { func GetUserOrgsList(ctx context.Context, user *user_model.User) ([]*MinimalOrg, error) {
schema, err := db.TableInfo(new(user_model.User))
if err != nil {
return nil, err
}
outputCols := []string{ outputCols := []string{
"id", "id",
"name", "name",
@@ -122,7 +117,7 @@ func GetUserOrgsList(ctx context.Context, user *user_model.User) ([]*MinimalOrg,
selectColumns := &strings.Builder{} selectColumns := &strings.Builder{}
for i, col := range outputCols { for i, col := range outputCols {
fmt.Fprintf(selectColumns, "`%s`.%s", schema.Name, col) _, _ = fmt.Fprintf(selectColumns, "`user`.%s", col)
if i < len(outputCols)-1 { if i < len(outputCols)-1 {
selectColumns.WriteString(", ") selectColumns.WriteString(", ")
} }

View File

@@ -45,7 +45,7 @@ func CheckConsistencyFor(t TestingT, beansToCheck ...any) {
} }
func checkForConsistency(t TestingT, bean any) { func checkForConsistency(t TestingT, bean any) {
tb, err := db.TableInfo(bean) tb, err := GetXORMEngine().TableInfo(bean)
assert.NoError(t, err) assert.NoError(t, err)
f := consistencyCheckMap[tb.Name] f := consistencyCheckMap[tb.Name]
require.NotNil(t, f, "unknown bean type: %#v", bean) require.NotNil(t, f, "unknown bean type: %#v", bean)

View File

@@ -218,7 +218,7 @@ func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, er
xormBeans, _ := db.NamesToBean() xormBeans, _ := db.NamesToBean()
f.xormTableNames = map[string]bool{} f.xormTableNames = map[string]bool{}
for _, bean := range xormBeans { for _, bean := range xormBeans {
f.xormTableNames[db.TableName(bean)] = true f.xormTableNames[x.TableName(bean)] = true
} }
return f, nil return f, nil

View File

@@ -159,7 +159,7 @@ func DumpQueryResult(t require.TestingT, sqlOrBean any, sqlArgs ...any) {
goDB := x.DB().DB goDB := x.DB().DB
sql, ok := sqlOrBean.(string) sql, ok := sqlOrBean.(string)
if !ok { if !ok {
sql = "SELECT * FROM " + db.TableName(sqlOrBean) sql = "SELECT * FROM " + x.TableName(sqlOrBean)
} else if !strings.Contains(sql, " ") { } else if !strings.Contains(sql, " ") {
sql = "SELECT * FROM " + sql sql = "SELECT * FROM " + sql
} }