Move almost all functions' parameter db.Engine to context.Context (#19748)

* Move almost all functions' parameter db.Engine to context.Context
* remove some unnecessary wrap functions
This commit is contained in:
Lunny Xiao 2022-05-20 22:08:52 +08:00 committed by GitHub
parent d81e31ad78
commit fd7d83ace6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
232 changed files with 1463 additions and 2108 deletions

View file

@ -92,13 +92,9 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
}
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID)
}
func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -107,17 +103,13 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant
}
// CreateGrant generates a grant for an user
func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope)
}
func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
grant := &OAuth2Grant{
ApplicationID: app.ID,
UserID: userID,
Scope: scope,
}
_, err := e.Insert(grant)
err := db.Insert(ctx, grant)
if err != nil {
return nil, err
}
@ -125,13 +117,9 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin
}
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID)
}
func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := e.Where("client_id = ?", clientID).Get(app)
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
if !has {
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
}
@ -139,13 +127,9 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap
}
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id)
}
func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := e.ID(id).Get(app)
has, err := db.GetEngine(ctx).ID(id).Get(app)
if err != nil {
return nil, err
}
@ -156,13 +140,9 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er
}
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID)
}
func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
func GetOAuth2ApplicationsByUserID(ctx context.Context, userID int64) (apps []*OAuth2Application, err error) {
apps = make([]*OAuth2Application, 0)
err = e.Where("uid = ?", userID).Find(&apps)
err = db.GetEngine(ctx).Where("uid = ?", userID).Find(&apps)
return
}
@ -174,11 +154,7 @@ type CreateOAuth2ApplicationOptions struct {
}
// CreateOAuth2Application inserts a new oauth2 application
func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
return createOAuth2Application(db.GetEngine(db.DefaultContext), opts)
}
func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
clientID := uuid.New().String()
app := &OAuth2Application{
UID: opts.UserID,
@ -186,7 +162,7 @@ func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (
ClientID: clientID,
RedirectURIs: opts.RedirectURIs,
}
if _, err := e.Insert(app); err != nil {
if err := db.Insert(ctx, app); err != nil {
return nil, err
}
return app, nil
@ -207,9 +183,8 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return nil, err
}
defer committer.Close()
sess := db.GetEngine(ctx)
app, err := getOAuth2ApplicationByID(sess, opts.ID)
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
if err != nil {
return nil, err
}
@ -220,7 +195,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
app.Name = opts.Name
app.RedirectURIs = opts.RedirectURIs
if err = updateOAuth2Application(sess, app); err != nil {
if err = updateOAuth2Application(ctx, app); err != nil {
return nil, err
}
app.ClientSecret = ""
@ -228,14 +203,15 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return app, committer.Commit()
}
func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
if _, err := e.ID(app.ID).Update(app); err != nil {
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
if _, err := db.GetEngine(ctx).ID(app.ID).Update(app); err != nil {
return err
}
return nil
}
func deleteOAuth2Application(sess db.Engine, id, userid int64) error {
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
sess := db.GetEngine(ctx)
if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil {
return err
} else if deleted == 0 {
@ -269,7 +245,7 @@ func DeleteOAuth2Application(id, userid int64) error {
return err
}
defer committer.Close()
if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil {
if err := deleteOAuth2Application(ctx, id, userid); err != nil {
return err
}
return committer.Commit()
@ -328,21 +304,13 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect
}
// Invalidate deletes the auth code from the database to invalidate this code
func (code *OAuth2AuthorizationCode) Invalidate() error {
return code.invalidate(db.GetEngine(db.DefaultContext))
}
func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
_, err := e.Delete(code)
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
return err
}
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
return code.validateCodeChallenge(verifier)
}
func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool {
switch code.CodeChallengeMethod {
case "S256":
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
@ -360,19 +328,15 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool
}
// GetOAuth2AuthorizationByCode returns an authorization by its code
func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code)
}
func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
auth = new(OAuth2AuthorizationCode)
if has, err := e.Where("code = ?", code).Get(auth); err != nil {
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
return nil, err
} else if !has {
return nil, nil
}
auth.Grant = new(OAuth2Grant)
if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil {
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -401,11 +365,7 @@ func (grant *OAuth2Grant) TableName() string {
}
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod)
}
func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
rBytes, err := util.CryptoRandomBytes(32)
if err != nil {
return &OAuth2AuthorizationCode{}, err
@ -422,23 +382,19 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
}
if _, err := e.Insert(code); err != nil {
if err := db.Insert(ctx, code); err != nil {
return nil, err
}
return code, nil
}
// IncreaseCounter increases the counter and updates the grant
func (grant *OAuth2Grant) IncreaseCounter() error {
return grant.increaseCount(db.GetEngine(db.DefaultContext))
}
func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
_, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
if err != nil {
return err
}
updatedGrant, err := getOAuth2GrantByID(e, grant.ID)
updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
if err != nil {
return err
}
@ -457,13 +413,9 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool {
}
// SetNonce updates the current nonce value of a grant
func (grant *OAuth2Grant) SetNonce(nonce string) error {
return grant.setNonce(db.GetEngine(db.DefaultContext), nonce)
}
func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
grant.Nonce = nonce
_, err := e.ID(grant.ID).Cols("nonce").Update(grant)
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
if err != nil {
return err
}
@ -471,13 +423,9 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
}
// GetOAuth2GrantByID returns the grant with the given ID
func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id)
}
func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := e.ID(id).Get(grant); err != nil {
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
@ -486,18 +434,14 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
}
// GetOAuth2GrantsByUserID lists all grants of a certain user
func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid)
}
func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
type joinedOAuth2Grant struct {
Grant *OAuth2Grant `xorm:"extends"`
Application *OAuth2Application `xorm:"extends"`
}
var results *xorm.Rows
var err error
if results, err = e.
if results, err = db.GetEngine(ctx).
Table("oauth2_grant").
Where("user_id = ?", uid).
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
@ -518,12 +462,8 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
}
// RevokeOAuth2Grant deletes the grant with grantID and userID
func RevokeOAuth2Grant(grantID, userID int64) error {
return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID)
}
func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
_, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
_, err := db.DeleteByBean(ctx, &OAuth2Grant{ID: grantID, UserID: userID})
return err
}

View file

@ -7,6 +7,7 @@ package auth
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
@ -52,18 +53,18 @@ func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138")
app, err := GetOAuth2ApplicationByClientID(db.DefaultContext, "da7da3ba-9a13-4167-856f-3899de0b0138")
assert.NoError(t, err)
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
app, err = GetOAuth2ApplicationByClientID("invalid client id")
app, err = GetOAuth2ApplicationByClientID(db.DefaultContext, "invalid client id")
assert.Error(t, err)
assert.Nil(t, app)
}
func TestCreateOAuth2Application(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
app, err := CreateOAuth2Application(db.DefaultContext, CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
assert.NoError(t, err)
assert.Equal(t, "newapp", app.Name)
assert.Len(t, app.ClientID, 36)
@ -77,11 +78,11 @@ func TestOAuth2Application_TableName(t *testing.T) {
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
grant, err := app.GetGrantByUserID(1)
grant, err := app.GetGrantByUserID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.UserID)
grant, err = app.GetGrantByUserID(34923458)
grant, err = app.GetGrantByUserID(db.DefaultContext, 34923458)
assert.NoError(t, err)
assert.Nil(t, grant)
}
@ -89,7 +90,7 @@ func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
func TestOAuth2Application_CreateGrant(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
grant, err := app.CreateGrant(2, "")
grant, err := app.CreateGrant(db.DefaultContext, 2, "")
assert.NoError(t, err)
assert.NotNil(t, grant)
assert.Equal(t, int64(2), grant.UserID)
@ -101,11 +102,11 @@ func TestOAuth2Application_CreateGrant(t *testing.T) {
func TestGetOAuth2GrantByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant, err := GetOAuth2GrantByID(1)
grant, err := GetOAuth2GrantByID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.ID)
grant, err = GetOAuth2GrantByID(34923458)
grant, err = GetOAuth2GrantByID(db.DefaultContext, 34923458)
assert.NoError(t, err)
assert.Nil(t, grant)
}
@ -113,7 +114,7 @@ func TestGetOAuth2GrantByID(t *testing.T) {
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant)
assert.NoError(t, grant.IncreaseCounter())
assert.NoError(t, grant.IncreaseCounter(db.DefaultContext))
assert.Equal(t, int64(2), grant.Counter)
unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2})
}
@ -130,7 +131,7 @@ func TestOAuth2Grant_ScopeContains(t *testing.T) {
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant)
code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
code, err := grant.GenerateNewAuthorizationCode(db.DefaultContext, "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.True(t, len(code.Code) > 32) // secret length > 32
@ -142,20 +143,20 @@ func TestOAuth2Grant_TableName(t *testing.T) {
func TestGetOAuth2GrantsByUserID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
result, err := GetOAuth2GrantsByUserID(1)
result, err := GetOAuth2GrantsByUserID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, int64(1), result[0].ID)
assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
result, err = GetOAuth2GrantsByUserID(34134)
result, err = GetOAuth2GrantsByUserID(db.DefaultContext, 34134)
assert.NoError(t, err)
assert.Empty(t, result)
}
func TestRevokeOAuth2Grant(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.NoError(t, RevokeOAuth2Grant(1, 1))
assert.NoError(t, RevokeOAuth2Grant(db.DefaultContext, 1, 1))
unittest.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1})
}
@ -163,13 +164,13 @@ func TestRevokeOAuth2Grant(t *testing.T) {
func TestGetOAuth2AuthorizationByCode(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
code, err := GetOAuth2AuthorizationByCode("authcode")
code, err := GetOAuth2AuthorizationByCode(db.DefaultContext, "authcode")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.Equal(t, "authcode", code.Code)
assert.Equal(t, int64(1), code.ID)
code, err = GetOAuth2AuthorizationByCode("does not exist")
code, err = GetOAuth2AuthorizationByCode(db.DefaultContext, "does not exist")
assert.NoError(t, err)
assert.Nil(t, code)
}
@ -224,7 +225,7 @@ func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
code := unittest.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode)
assert.NoError(t, code.Invalidate())
assert.NoError(t, code.Invalidate(db.DefaultContext))
unittest.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"})
}