Refactor fixture loading for testing (#33024)

To help binary size and testing performance
This commit is contained in:
wxiaoguang 2024-12-30 12:06:57 +08:00 committed by GitHub
parent f4ccbd38dc
commit dafadf0028
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 377 additions and 248 deletions

View file

@ -5,88 +5,29 @@ package unittest
import (
"fmt"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/auth/password/hash"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"github.com/go-testfixtures/testfixtures/v3"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
var fixturesLoader *testfixtures.Loader
type FixturesLoader interface {
Load() error
}
var fixturesLoader FixturesLoader
// GetXORMEngine gets the XORM engine
func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) {
if len(engine) == 1 {
return engine[0]
}
func GetXORMEngine() (x *xorm.Engine) {
return db.GetEngine(db.DefaultContext).(*xorm.Engine)
}
// InitFixtures initialize test fixtures for a test database
func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) {
e := GetXORMEngine(engine...)
var fixtureOptionFiles func(*testfixtures.Loader) error
if opts.Dir != "" {
fixtureOptionFiles = testfixtures.Directory(opts.Dir)
} else {
fixtureOptionFiles = testfixtures.Files(opts.Files...)
}
var dialect string
switch e.Dialect().URI().DBType {
case schemas.POSTGRES:
dialect = "postgres"
case schemas.MYSQL:
dialect = "mysql"
case schemas.MSSQL:
dialect = "mssql"
case schemas.SQLITE:
dialect = "sqlite3"
default:
return fmt.Errorf("unsupported RDBMS for integration tests: %q", e.Dialect().URI().DBType)
}
loaderOptions := []func(loader *testfixtures.Loader) error{
testfixtures.Database(e.DB().DB),
testfixtures.Dialect(dialect),
testfixtures.DangerousSkipTestDatabaseCheck(),
fixtureOptionFiles,
}
if e.Dialect().URI().DBType == schemas.POSTGRES {
loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences())
}
fixturesLoader, err = testfixtures.New(loaderOptions...)
if err != nil {
return err
}
// register the dummy hash algorithm function used in the test fixtures
_ = hash.Register("dummy", hash.NewDummyHasher)
setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
return err
}
// LoadFixtures load fixtures for a test database
func LoadFixtures(engine ...*xorm.Engine) error {
e := GetXORMEngine(engine...)
var err error
// (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times.
for i := 0; i < 5; i++ {
if err = fixturesLoader.Load(); err == nil {
break
}
time.Sleep(200 * time.Millisecond)
}
if err != nil {
return fmt.Errorf("LoadFixtures failed after retries: %w", err)
}
// Now if we're running postgres we need to tell it to update the sequences
if e.Dialect().URI().DBType == schemas.POSTGRES {
results, err := e.QueryString(`SELECT 'SELECT SETVAL(' ||
func loadFixtureResetSeqPgsql(e *xorm.Engine) error {
results, err := e.QueryString(`SELECT 'SELECT SETVAL(' ||
quote_literal(quote_ident(PGT.schemaname) || '.' || quote_ident(S.relname)) ||
', COALESCE(MAX(' ||quote_ident(C.attname)|| '), 1) ) FROM ' ||
quote_ident(PGT.schemaname)|| '.'||quote_ident(T.relname)|| ';'
@ -102,19 +43,42 @@ func LoadFixtures(engine ...*xorm.Engine) error {
AND D.refobjsubid = C.attnum
AND T.relname = PGT.tablename
ORDER BY S.relname;`)
if err != nil {
return fmt.Errorf("failed to generate sequence update: %w", err)
}
for _, r := range results {
for _, value := range r {
_, err = e.Exec(value)
if err != nil {
return fmt.Errorf("failed to update sequence: %s, error: %w", value, err)
}
if err != nil {
return fmt.Errorf("failed to generate sequence update: %w", err)
}
for _, r := range results {
for _, value := range r {
_, err = e.Exec(value)
if err != nil {
return fmt.Errorf("failed to update sequence: %s, error: %w", value, err)
}
}
}
_ = hash.Register("dummy", hash.NewDummyHasher)
setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
return nil
}
// InitFixtures initialize test fixtures for a test database
func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) {
xormEngine := util.IfZero(util.OptionalArg(engine), GetXORMEngine())
fixturesLoader, err = NewFixturesLoader(xormEngine, opts)
// fixturesLoader = NewFixturesLoaderVendor(xormEngine, opts)
// register the dummy hash algorithm function used in the test fixtures
_ = hash.Register("dummy", hash.NewDummyHasher)
setting.PasswordHashAlgo, _ = hash.SetDefaultPasswordHashAlgorithm("dummy")
return err
}
// LoadFixtures load fixtures for a test database
func LoadFixtures() error {
if err := fixturesLoader.Load(); err != nil {
return err
}
// Now if we're running postgres we need to tell it to update the sequences
if GetXORMEngine().Dialect().URI().DBType == schemas.POSTGRES {
if err := loadFixtureResetSeqPgsql(GetXORMEngine()); err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,201 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package unittest
import (
"database/sql"
"encoding/hex"
"fmt"
"os"
"path/filepath"
"slices"
"strings"
"gopkg.in/yaml.v3"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
type fixtureItem struct {
tableName string
tableNameQuoted string
sqlInserts []string
sqlInsertArgs [][]any
mssqlHasIdentityColumn bool
}
type fixturesLoaderInternal struct {
db *sql.DB
dbType schemas.DBType
files []string
fixtures map[string]*fixtureItem
quoteObject func(string) string
paramPlaceholder func(idx int) string
}
func (f *fixturesLoaderInternal) mssqlTableHasIdentityColumn(db *sql.DB, tableName string) (bool, error) {
row := db.QueryRow(`SELECT COUNT(*) FROM sys.identity_columns WHERE OBJECT_ID = OBJECT_ID(?)`, tableName)
var count int
if err := row.Scan(&count); err != nil {
return false, err
}
return count > 0, nil
}
func (f *fixturesLoaderInternal) preprocessFixtureRow(row []map[string]any) (err error) {
for _, m := range row {
for k, v := range m {
if s, ok := v.(string); ok {
if strings.HasPrefix(s, "0x") {
if m[k], err = hex.DecodeString(s[2:]); err != nil {
return err
}
}
}
}
}
return nil
}
func (f *fixturesLoaderInternal) prepareFixtureItem(file string) (_ *fixtureItem, err error) {
fixture := &fixtureItem{}
fixture.tableName, _, _ = strings.Cut(filepath.Base(file), ".")
fixture.tableNameQuoted = f.quoteObject(fixture.tableName)
if f.dbType == schemas.MSSQL {
fixture.mssqlHasIdentityColumn, err = f.mssqlTableHasIdentityColumn(f.db, fixture.tableName)
if err != nil {
return nil, err
}
}
data, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("failed to read file %q: %w", file, err)
}
var rows []map[string]any
if err = yaml.Unmarshal(data, &rows); err != nil {
return nil, fmt.Errorf("failed to unmarshal yaml data from %q: %w", file, err)
}
if err = f.preprocessFixtureRow(rows); err != nil {
return nil, fmt.Errorf("failed to preprocess fixture rows from %q: %w", file, err)
}
var sqlBuf []byte
var sqlArguments []any
for _, row := range rows {
sqlBuf = append(sqlBuf, fmt.Sprintf("INSERT INTO %s (", fixture.tableNameQuoted)...)
for k, v := range row {
sqlBuf = append(sqlBuf, f.quoteObject(k)...)
sqlBuf = append(sqlBuf, ","...)
sqlArguments = append(sqlArguments, v)
}
sqlBuf = sqlBuf[:len(sqlBuf)-1]
sqlBuf = append(sqlBuf, ") VALUES ("...)
paramIdx := 1
for range row {
sqlBuf = append(sqlBuf, f.paramPlaceholder(paramIdx)...)
sqlBuf = append(sqlBuf, ',')
paramIdx++
}
sqlBuf[len(sqlBuf)-1] = ')'
fixture.sqlInserts = append(fixture.sqlInserts, string(sqlBuf))
fixture.sqlInsertArgs = append(fixture.sqlInsertArgs, slices.Clone(sqlArguments))
sqlBuf = sqlBuf[:0]
sqlArguments = sqlArguments[:0]
}
return fixture, nil
}
func (f *fixturesLoaderInternal) loadFixtures(tx *sql.Tx, file string) (err error) {
fixture := f.fixtures[file]
if fixture == nil {
if fixture, err = f.prepareFixtureItem(file); err != nil {
return err
}
f.fixtures[file] = fixture
}
_, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", fixture.tableNameQuoted)) // sqlite3 doesn't support truncate
if err != nil {
return err
}
if fixture.mssqlHasIdentityColumn {
_, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", fixture.tableNameQuoted))
if err != nil {
return err
}
defer func() { _, err = tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", fixture.tableNameQuoted)) }()
}
for i := range fixture.sqlInserts {
_, err = tx.Exec(fixture.sqlInserts[i], fixture.sqlInsertArgs[i]...)
}
if err != nil {
return err
}
return nil
}
func (f *fixturesLoaderInternal) Load() error {
tx, err := f.db.Begin()
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, file := range f.files {
if err := f.loadFixtures(tx, file); err != nil {
return fmt.Errorf("failed to load fixtures from %s: %w", file, err)
}
}
return tx.Commit()
}
func FixturesFileFullPaths(dir string, files []string) ([]string, error) {
if files != nil && len(files) == 0 {
return nil, nil // load nothing
}
files = slices.Clone(files)
if len(files) == 0 {
entries, err := os.ReadDir(dir)
if err != nil {
return nil, err
}
for _, e := range entries {
files = append(files, e.Name())
}
}
for i, file := range files {
if !filepath.IsAbs(file) {
files[i] = filepath.Join(dir, file)
}
}
return files, nil
}
func NewFixturesLoader(x *xorm.Engine, opts FixturesOptions) (FixturesLoader, error) {
files, err := FixturesFileFullPaths(opts.Dir, opts.Files)
if err != nil {
return nil, fmt.Errorf("failed to get fixtures files: %w", err)
}
f := &fixturesLoaderInternal{db: x.DB().DB, dbType: x.Dialect().URI().DBType, files: files, fixtures: map[string]*fixtureItem{}}
switch f.dbType {
case schemas.SQLITE:
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
f.paramPlaceholder = func(idx int) string { return "?" }
case schemas.POSTGRES:
f.quoteObject = func(s string) string { return fmt.Sprintf(`"%s"`, s) }
f.paramPlaceholder = func(idx int) string { return fmt.Sprintf(`$%d`, idx) }
case schemas.MYSQL:
f.quoteObject = func(s string) string { return fmt.Sprintf("`%s`", s) }
f.paramPlaceholder = func(idx int) string { return "?" }
case schemas.MSSQL:
f.quoteObject = func(s string) string { return fmt.Sprintf("[%s]", s) }
f.paramPlaceholder = func(idx int) string { return "?" }
}
return f, nil
}

View file

@ -0,0 +1,114 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package unittest_test
import (
"path/filepath"
"testing"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/test"
"github.com/stretchr/testify/require"
"xorm.io/xorm"
)
var NewFixturesLoaderVendor = func(e *xorm.Engine, opts unittest.FixturesOptions) (unittest.FixturesLoader, error) {
return nil, nil
}
/*
// the old code is kept here in case we are still interested in benchmarking the two implementations
func init() {
NewFixturesLoaderVendor = func(e *xorm.Engine, opts unittest.FixturesOptions) (unittest.FixturesLoader, error) {
return NewFixturesLoaderVendorGoTestfixtures(e, opts)
}
}
func NewFixturesLoaderVendorGoTestfixtures(e *xorm.Engine, opts unittest.FixturesOptions) (*testfixtures.Loader, error) {
files, err := unittest.FixturesFileFullPaths(opts.Dir, opts.Files)
if err != nil {
return nil, fmt.Errorf("failed to get fixtures files: %w", err)
}
var dialect string
switch e.Dialect().URI().DBType {
case schemas.POSTGRES:
dialect = "postgres"
case schemas.MYSQL:
dialect = "mysql"
case schemas.MSSQL:
dialect = "mssql"
case schemas.SQLITE:
dialect = "sqlite3"
default:
return nil, fmt.Errorf("unsupported RDBMS for integration tests: %q", e.Dialect().URI().DBType)
}
loaderOptions := []func(loader *testfixtures.Loader) error{
testfixtures.Database(e.DB().DB),
testfixtures.Dialect(dialect),
testfixtures.DangerousSkipTestDatabaseCheck(),
testfixtures.Files(files...),
}
if e.Dialect().URI().DBType == schemas.POSTGRES {
loaderOptions = append(loaderOptions, testfixtures.SkipResetSequences())
}
return testfixtures.New(loaderOptions...)
}
*/
func prepareTestFixturesLoaders(t testing.TB) unittest.FixturesOptions {
_ = user_model.User{}
opts := unittest.FixturesOptions{Dir: filepath.Join(test.SetupGiteaRoot(), "models", "fixtures"), Files: []string{
"user.yml",
}}
require.NoError(t, unittest.CreateTestEngine(opts))
return opts
}
func TestFixturesLoader(t *testing.T) {
opts := prepareTestFixturesLoaders(t)
loaderInternal, err := unittest.NewFixturesLoader(unittest.GetXORMEngine(), opts)
require.NoError(t, err)
loaderVendor, err := NewFixturesLoaderVendor(unittest.GetXORMEngine(), opts)
require.NoError(t, err)
t.Run("Internal", func(t *testing.T) {
require.NoError(t, loaderInternal.Load())
require.NoError(t, loaderInternal.Load())
})
t.Run("Vendor", func(t *testing.T) {
if loaderVendor == nil {
t.Skip()
}
require.NoError(t, loaderVendor.Load())
require.NoError(t, loaderVendor.Load())
})
}
func BenchmarkFixturesLoader(b *testing.B) {
opts := prepareTestFixturesLoaders(b)
require.NoError(b, unittest.CreateTestEngine(opts))
loaderInternal, err := unittest.NewFixturesLoader(unittest.GetXORMEngine(), opts)
require.NoError(b, err)
loaderVendor, err := NewFixturesLoaderVendor(unittest.GetXORMEngine(), opts)
require.NoError(b, err)
// BenchmarkFixturesLoader/Vendor
// BenchmarkFixturesLoader/Vendor-12 1696 719416 ns/op
// BenchmarkFixturesLoader/Internal
// BenchmarkFixturesLoader/Internal-12 1746 670457 ns/op
b.Run("Internal", func(b *testing.B) {
for i := 0; i < b.N; i++ {
require.NoError(b, loaderInternal.Load())
}
})
b.Run("Vendor", func(b *testing.B) {
if loaderVendor == nil {
b.Skip()
}
for i := 0; i < b.N; i++ {
require.NoError(b, loaderVendor.Load())
}
})
}

View file

@ -28,11 +28,7 @@ import (
"xorm.io/xorm/names"
)
// giteaRoot a path to the gitea root
var (
giteaRoot string
fixturesDir string
)
var giteaRoot string
func fatalTestError(fmtStr string, args ...any) {
_, _ = fmt.Fprintf(os.Stderr, fmtStr, args...)
@ -74,39 +70,14 @@ type TestOptions struct {
// MainTest a reusable TestMain(..) function for unit tests that need to use a
// test database. Creates the test database, and sets necessary settings.
func MainTest(m *testing.M, testOpts ...*TestOptions) {
searchDir, _ := os.Getwd()
for searchDir != "" {
if _, err := os.Stat(filepath.Join(searchDir, "go.mod")); err == nil {
break // The "go.mod" should be the one for Gitea repository
}
if dir := filepath.Dir(searchDir); dir == searchDir {
searchDir = "" // reaches the root of filesystem
} else {
searchDir = dir
}
}
if searchDir == "" {
panic("The tests should run in a Gitea repository, there should be a 'go.mod' in the root")
}
giteaRoot = searchDir
func MainTest(m *testing.M, testOptsArg ...*TestOptions) {
testOpts := util.OptionalArg(testOptsArg, &TestOptions{})
giteaRoot = test.SetupGiteaRoot()
setting.CustomPath = filepath.Join(giteaRoot, "custom")
InitSettings()
fixturesDir = filepath.Join(giteaRoot, "models", "fixtures")
var opts FixturesOptions
if len(testOpts) == 0 || len(testOpts[0].FixtureFiles) == 0 {
opts.Dir = fixturesDir
} else {
for _, f := range testOpts[0].FixtureFiles {
if len(f) != 0 {
opts.Files = append(opts.Files, filepath.Join(fixturesDir, f))
}
}
}
if err := CreateTestEngine(opts); err != nil {
fixturesOpts := FixturesOptions{Dir: filepath.Join(giteaRoot, "models", "fixtures"), Files: testOpts.FixtureFiles}
if err := CreateTestEngine(fixturesOpts); err != nil {
fatalTestError("Error creating test engine: %v\n", err)
}
@ -167,16 +138,16 @@ func MainTest(m *testing.M, testOpts ...*TestOptions) {
fatalTestError("git.Init: %v\n", err)
}
if len(testOpts) > 0 && testOpts[0].SetUp != nil {
if err := testOpts[0].SetUp(); err != nil {
if testOpts.SetUp != nil {
if err := testOpts.SetUp(); err != nil {
fatalTestError("set up failed: %v\n", err)
}
}
exitStatus := m.Run()
if len(testOpts) > 0 && testOpts[0].TearDown != nil {
if err := testOpts[0].TearDown(); err != nil {
if testOpts.TearDown != nil {
if err := testOpts.TearDown(); err != nil {
fatalTestError("tear down failed: %v\n", err)
}
}