1
0
Fork 0

Next round of `db.DefaultContext` refactor (#27089)

Part of #27065
This commit is contained in:
JakobDev 2023-09-16 16:39:12 +02:00 committed by GitHub
parent a1b2a11812
commit f91dbbba98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
90 changed files with 434 additions and 464 deletions

View File

@ -41,15 +41,15 @@ func init() {
}
// GetSchedulesMapByIDs returns the schedules by given id slice.
func GetSchedulesMapByIDs(ids []int64) (map[int64]*ActionSchedule, error) {
func GetSchedulesMapByIDs(ctx context.Context, ids []int64) (map[int64]*ActionSchedule, error) {
schedules := make(map[int64]*ActionSchedule, len(ids))
return schedules, db.GetEngine(db.DefaultContext).In("id", ids).Find(&schedules)
return schedules, db.GetEngine(ctx).In("id", ids).Find(&schedules)
}
// GetReposMapByIDs returns the repos by given id slice.
func GetReposMapByIDs(ids []int64) (map[int64]*repo_model.Repository, error) {
func GetReposMapByIDs(ctx context.Context, ids []int64) (map[int64]*repo_model.Repository, error) {
repos := make(map[int64]*repo_model.Repository, len(ids))
return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos)
return repos, db.GetEngine(ctx).In("id", ids).Find(&repos)
}
var cronParser = cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)

View File

@ -23,9 +23,9 @@ func (specs SpecList) GetScheduleIDs() []int64 {
return ids.Values()
}
func (specs SpecList) LoadSchedules() error {
func (specs SpecList) LoadSchedules(ctx context.Context) error {
scheduleIDs := specs.GetScheduleIDs()
schedules, err := GetSchedulesMapByIDs(scheduleIDs)
schedules, err := GetSchedulesMapByIDs(ctx, scheduleIDs)
if err != nil {
return err
}
@ -34,7 +34,7 @@ func (specs SpecList) LoadSchedules() error {
}
repoIDs := specs.GetRepoIDs()
repos, err := GetReposMapByIDs(repoIDs)
repos, err := GetReposMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
@ -95,7 +95,7 @@ func FindSpecs(ctx context.Context, opts FindSpecOptions) (SpecList, int64, erro
return nil, 0, err
}
if err := specs.LoadSchedules(); err != nil {
if err := specs.LoadSchedules(ctx); err != nil {
return nil, 0, err
}
return specs, total, nil

View File

@ -48,11 +48,7 @@ type TranslatableMessage struct {
}
// LoadRepo loads repository of the task
func (task *Task) LoadRepo() error {
return task.loadRepo(db.DefaultContext)
}
func (task *Task) loadRepo(ctx context.Context) error {
func (task *Task) LoadRepo(ctx context.Context) error {
if task.Repo != nil {
return nil
}
@ -70,13 +66,13 @@ func (task *Task) loadRepo(ctx context.Context) error {
}
// LoadDoer loads do user
func (task *Task) LoadDoer() error {
func (task *Task) LoadDoer(ctx context.Context) error {
if task.Doer != nil {
return nil
}
var doer user_model.User
has, err := db.GetEngine(db.DefaultContext).ID(task.DoerID).Get(&doer)
has, err := db.GetEngine(ctx).ID(task.DoerID).Get(&doer)
if err != nil {
return err
} else if !has {
@ -90,13 +86,13 @@ func (task *Task) LoadDoer() error {
}
// LoadOwner loads owner user
func (task *Task) LoadOwner() error {
func (task *Task) LoadOwner(ctx context.Context) error {
if task.Owner != nil {
return nil
}
var owner user_model.User
has, err := db.GetEngine(db.DefaultContext).ID(task.OwnerID).Get(&owner)
has, err := db.GetEngine(ctx).ID(task.OwnerID).Get(&owner)
if err != nil {
return err
} else if !has {
@ -110,8 +106,8 @@ func (task *Task) LoadOwner() error {
}
// UpdateCols updates some columns
func (task *Task) UpdateCols(cols ...string) error {
_, err := db.GetEngine(db.DefaultContext).ID(task.ID).Cols(cols...).Update(task)
func (task *Task) UpdateCols(ctx context.Context, cols ...string) error {
_, err := db.GetEngine(ctx).ID(task.ID).Cols(cols...).Update(task)
return err
}
@ -169,12 +165,12 @@ func (err ErrTaskDoesNotExist) Unwrap() error {
}
// GetMigratingTask returns the migrating task by repo's id
func GetMigratingTask(repoID int64) (*Task, error) {
func GetMigratingTask(ctx context.Context, repoID int64) (*Task, error) {
task := Task{
RepoID: repoID,
Type: structs.TaskTypeMigrateRepo,
}
has, err := db.GetEngine(db.DefaultContext).Get(&task)
has, err := db.GetEngine(ctx).Get(&task)
if err != nil {
return nil, err
} else if !has {
@ -184,13 +180,13 @@ func GetMigratingTask(repoID int64) (*Task, error) {
}
// GetMigratingTaskByID returns the migrating task by repo's id
func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, error) {
func GetMigratingTaskByID(ctx context.Context, id, doerID int64) (*Task, *migration.MigrateOptions, error) {
task := Task{
ID: id,
DoerID: doerID,
Type: structs.TaskTypeMigrateRepo,
}
has, err := db.GetEngine(db.DefaultContext).Get(&task)
has, err := db.GetEngine(ctx).Get(&task)
if err != nil {
return nil, nil, err
} else if !has {
@ -205,12 +201,12 @@ func GetMigratingTaskByID(id, doerID int64) (*Task, *migration.MigrateOptions, e
}
// CreateTask creates a task on database
func CreateTask(task *Task) error {
return db.Insert(db.DefaultContext, task)
func CreateTask(ctx context.Context, task *Task) error {
return db.Insert(ctx, task)
}
// FinishMigrateTask updates database when migrate task finished
func FinishMigrateTask(task *Task) error {
func FinishMigrateTask(ctx context.Context, task *Task) error {
task.Status = structs.TaskStatusFinished
task.EndTime = timeutil.TimeStampNow()
@ -231,6 +227,6 @@ func FinishMigrateTask(task *Task) error {
}
task.PayloadContent = string(confBytes)
_, err = db.GetEngine(db.DefaultContext).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task)
_, err = db.GetEngine(ctx).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task)
return err
}

View File

@ -4,6 +4,7 @@
package auth
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
@ -22,8 +23,8 @@ func init() {
}
// UpdateSession updates the session with provided id
func UpdateSession(key string, data []byte) error {
_, err := db.GetEngine(db.DefaultContext).ID(key).Update(&Session{
func UpdateSession(ctx context.Context, key string, data []byte) error {
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
Data: data,
Expiry: timeutil.TimeStampNow(),
})
@ -31,12 +32,12 @@ func UpdateSession(key string, data []byte) error {
}
// ReadSession reads the data for the provided session
func ReadSession(key string) (*Session, error) {
func ReadSession(ctx context.Context, key string) (*Session, error) {
session := Session{
Key: key,
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
@ -55,24 +56,24 @@ func ReadSession(key string) (*Session, error) {
}
// ExistSession checks if a session exists
func ExistSession(key string) (bool, error) {
func ExistSession(ctx context.Context, key string) (bool, error) {
session := Session{
Key: key,
}
return db.GetEngine(db.DefaultContext).Get(&session)
return db.GetEngine(ctx).Get(&session)
}
// DestroySession destroys a session
func DestroySession(key string) error {
_, err := db.GetEngine(db.DefaultContext).Delete(&Session{
func DestroySession(ctx context.Context, key string) error {
_, err := db.GetEngine(ctx).Delete(&Session{
Key: key,
})
return err
}
// RegenerateSession regenerates a session from the old id
func RegenerateSession(oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
@ -114,12 +115,12 @@ func RegenerateSession(oldKey, newKey string) (*Session, error) {
}
// CountSessions returns the number of sessions
func CountSessions() (int64, error) {
return db.GetEngine(db.DefaultContext).Count(&Session{})
func CountSessions(ctx context.Context) (int64, error) {
return db.GetEngine(ctx).Count(&Session{})
}
// CleanupSessions cleans up expired sessions
func CleanupSessions(maxLifetime int64) error {
_, err := db.GetEngine(db.DefaultContext).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
return err
}

View File

@ -67,11 +67,7 @@ func (cred WebAuthnCredential) TableName() string {
}
// UpdateSignCount will update the database value of SignCount
func (cred *WebAuthnCredential) UpdateSignCount() error {
return cred.updateSignCount(db.DefaultContext)
}
func (cred *WebAuthnCredential) updateSignCount(ctx context.Context) error {
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
return err
}
@ -113,30 +109,18 @@ func (list WebAuthnCredentialList) ToCredentials() []webauthn.Credential {
}
// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
func GetWebAuthnCredentialsByUID(uid int64) (WebAuthnCredentialList, error) {
return getWebAuthnCredentialsByUID(db.DefaultContext, uid)
}
func getWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
creds := make(WebAuthnCredentialList, 0)
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
}
// ExistsWebAuthnCredentialsForUID returns if the given user has credentials
func ExistsWebAuthnCredentialsForUID(uid int64) (bool, error) {
return existsWebAuthnCredentialsByUID(db.DefaultContext, uid)
}
func existsWebAuthnCredentialsByUID(ctx context.Context, uid int64) (bool, error) {
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}
// GetWebAuthnCredentialByName returns WebAuthn credential by id
func GetWebAuthnCredentialByName(uid int64, name string) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByName(db.DefaultContext, uid, name)
}
func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
return nil, err
@ -147,11 +131,7 @@ func getWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*
}
// GetWebAuthnCredentialByID returns WebAuthn credential by id
func GetWebAuthnCredentialByID(id int64) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByID(db.DefaultContext, id)
}
func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
return nil, err
@ -162,16 +142,12 @@ func getWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredenti
}
// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
func HasWebAuthnRegistrationsByUID(uid int64) (bool, error) {
return db.GetEngine(db.DefaultContext).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}
// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
func GetWebAuthnCredentialByCredID(userID int64, credID []byte) (*WebAuthnCredential, error) {
return getWebAuthnCredentialByCredID(db.DefaultContext, userID, credID)
}
func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
return nil, err
@ -182,11 +158,7 @@ func getWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []b
}
// CreateCredential will create a new WebAuthnCredential from the given Credential
func CreateCredential(userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
return createCredential(db.DefaultContext, userID, name, cred)
}
func createCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
c := &WebAuthnCredential{
UserID: userID,
Name: name,
@ -205,18 +177,14 @@ func createCredential(ctx context.Context, userID int64, name string, cred *weba
}
// DeleteCredential will delete WebAuthnCredential
func DeleteCredential(id, userID int64) (bool, error) {
return deleteCredential(db.DefaultContext, id, userID)
}
func deleteCredential(ctx context.Context, id, userID int64) (bool, error) {
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
return had > 0, err
}
// WebAuthnCredentials implementns the webauthn.User interface
func WebAuthnCredentials(userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(userID)
func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID)
if err != nil {
return nil, err
}

View File

@ -7,6 +7,7 @@ import (
"testing"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/go-webauthn/webauthn/webauthn"
@ -16,11 +17,11 @@ import (
func TestGetWebAuthnCredentialByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.GetWebAuthnCredentialByID(1)
res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, "WebAuthn credential", res.Name)
_, err = auth_model.GetWebAuthnCredentialByID(342432)
_, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432)
assert.Error(t, err)
assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err))
}
@ -28,7 +29,7 @@ func TestGetWebAuthnCredentialByID(t *testing.T) {
func TestGetWebAuthnCredentialsByUID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.GetWebAuthnCredentialsByUID(32)
res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32)
assert.NoError(t, err)
assert.Len(t, res, 1)
assert.Equal(t, "WebAuthn credential", res[0].Name)
@ -42,7 +43,7 @@ func TestWebAuthnCredential_UpdateSignCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
cred.SignCount = 1
assert.NoError(t, cred.UpdateSignCount())
assert.NoError(t, cred.UpdateSignCount(db.DefaultContext))
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1})
}
@ -50,14 +51,14 @@ func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
cred.SignCount = 0xffffffff
assert.NoError(t, cred.UpdateSignCount())
assert.NoError(t, cred.UpdateSignCount(db.DefaultContext))
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff})
}
func TestCreateCredential(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.CreateCredential(1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
assert.NoError(t, err)
assert.Equal(t, "WebAuthn Created Credential", res.Name)
assert.Equal(t, []byte("Test"), res.CredentialID)

View File

@ -385,7 +385,7 @@ func TestMilestoneList_LoadTotalTrackedTimes(t *testing.T) {
unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}),
}
assert.NoError(t, miles.LoadTotalTrackedTimes())
assert.NoError(t, miles.LoadTotalTrackedTimes(db.DefaultContext))
assert.Equal(t, int64(3682), miles[0].TotalTrackedTime)
}
@ -394,7 +394,7 @@ func TestLoadTotalTrackedTime(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1})
assert.NoError(t, milestone.LoadTotalTrackedTime())
assert.NoError(t, milestone.LoadTotalTrackedTime(db.DefaultContext))
assert.Equal(t, int64(3682), milestone.TotalTrackedTime)
}

View File

@ -30,8 +30,8 @@ func init() {
type IssueWatchList []*IssueWatch
// CreateOrUpdateIssueWatch set watching for a user and issue
func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error {
iw, exists, err := GetIssueWatch(db.DefaultContext, userID, issueID)
func CreateOrUpdateIssueWatch(ctx context.Context, userID, issueID int64, isWatching bool) error {
iw, exists, err := GetIssueWatch(ctx, userID, issueID)
if err != nil {
return err
}
@ -43,13 +43,13 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error {
IsWatching: isWatching,
}
if _, err := db.GetEngine(db.DefaultContext).Insert(iw); err != nil {
if _, err := db.GetEngine(ctx).Insert(iw); err != nil {
return err
}
} else {
iw.IsWatching = isWatching
if _, err := db.GetEngine(db.DefaultContext).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil {
if _, err := db.GetEngine(ctx).ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil {
return err
}
}
@ -69,15 +69,15 @@ func GetIssueWatch(ctx context.Context, userID, issueID int64) (iw *IssueWatch,
// CheckIssueWatch check if an user is watching an issue
// it takes participants and repo watch into account
func CheckIssueWatch(user *user_model.User, issue *Issue) (bool, error) {
iw, exist, err := GetIssueWatch(db.DefaultContext, user.ID, issue.ID)
func CheckIssueWatch(ctx context.Context, user *user_model.User, issue *Issue) (bool, error) {
iw, exist, err := GetIssueWatch(ctx, user.ID, issue.ID)
if err != nil {
return false, err
}
if exist {
return iw.IsWatching, nil
}
w, err := repo_model.GetWatch(db.DefaultContext, user.ID, issue.RepoID)
w, err := repo_model.GetWatch(ctx, user.ID, issue.RepoID)
if err != nil {
return false, err
}

View File

@ -16,11 +16,11 @@ import (
func TestCreateOrUpdateIssueWatch(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(3, 1, true))
assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(db.DefaultContext, 3, 1, true))
iw := unittest.AssertExistsAndLoadBean(t, &issues_model.IssueWatch{UserID: 3, IssueID: 1})
assert.True(t, iw.IsWatching)
assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(1, 1, false))
assert.NoError(t, issues_model.CreateOrUpdateIssueWatch(db.DefaultContext, 1, 1, false))
iw = unittest.AssertExistsAndLoadBean(t, &issues_model.IssueWatch{UserID: 1, IssueID: 1})
assert.False(t, iw.IsWatching)
}

View File

@ -199,8 +199,8 @@ func NewLabel(ctx context.Context, l *Label) error {
}
// NewLabels creates new labels
func NewLabels(labels ...*Label) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func NewLabels(ctx context.Context, labels ...*Label) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -221,19 +221,19 @@ func NewLabels(labels ...*Label) error {
}
// UpdateLabel updates label information.
func UpdateLabel(l *Label) error {
func UpdateLabel(ctx context.Context, l *Label) error {
color, err := label.NormalizeColor(l.Color)
if err != nil {
return err
}
l.Color = color
return updateLabelCols(db.DefaultContext, l, "name", "description", "color", "exclusive", "archived_unix")
return updateLabelCols(ctx, l, "name", "description", "color", "exclusive", "archived_unix")
}
// DeleteLabel delete a label
func DeleteLabel(id, labelID int64) error {
l, err := GetLabelByID(db.DefaultContext, labelID)
func DeleteLabel(ctx context.Context, id, labelID int64) error {
l, err := GetLabelByID(ctx, labelID)
if err != nil {
if IsErrLabelNotExist(err) {
return nil
@ -241,7 +241,7 @@ func DeleteLabel(id, labelID int64) error {
return err
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -289,9 +289,9 @@ func GetLabelByID(ctx context.Context, labelID int64) (*Label, error) {
}
// GetLabelsByIDs returns a list of labels by IDs
func GetLabelsByIDs(labelIDs []int64, cols ...string) ([]*Label, error) {
func GetLabelsByIDs(ctx context.Context, labelIDs []int64, cols ...string) ([]*Label, error) {
labels := make([]*Label, 0, len(labelIDs))
return labels, db.GetEngine(db.DefaultContext).Table("label").
return labels, db.GetEngine(ctx).Table("label").
In("id", labelIDs).
Asc("name").
Cols(cols...).
@ -339,9 +339,9 @@ func GetLabelInRepoByID(ctx context.Context, repoID, labelID int64) (*Label, err
// GetLabelIDsInRepoByNames returns a list of labelIDs by names in a given
// repository.
// it silently ignores label names that do not belong to the repository.
func GetLabelIDsInRepoByNames(repoID int64, labelNames []string) ([]int64, error) {
func GetLabelIDsInRepoByNames(ctx context.Context, repoID int64, labelNames []string) ([]int64, error) {
labelIDs := make([]int64, 0, len(labelNames))
return labelIDs, db.GetEngine(db.DefaultContext).Table("label").
return labelIDs, db.GetEngine(ctx).Table("label").
Where("repo_id = ?", repoID).
In("name", labelNames).
Asc("name").
@ -398,8 +398,8 @@ func GetLabelsByRepoID(ctx context.Context, repoID int64, sortType string, listO
}
// CountLabelsByRepoID count number of all labels that belong to given repository by ID.
func CountLabelsByRepoID(repoID int64) (int64, error) {
return db.GetEngine(db.DefaultContext).Where("repo_id = ?", repoID).Count(&Label{})
func CountLabelsByRepoID(ctx context.Context, repoID int64) (int64, error) {
return db.GetEngine(ctx).Where("repo_id = ?", repoID).Count(&Label{})
}
// GetLabelInOrgByName returns a label by name in given organization.
@ -442,13 +442,13 @@ func GetLabelInOrgByID(ctx context.Context, orgID, labelID int64) (*Label, error
// GetLabelIDsInOrgByNames returns a list of labelIDs by names in a given
// organization.
func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) {
func GetLabelIDsInOrgByNames(ctx context.Context, orgID int64, labelNames []string) ([]int64, error) {
if orgID <= 0 {
return nil, ErrOrgLabelNotExist{0, orgID}
}
labelIDs := make([]int64, 0, len(labelNames))
return labelIDs, db.GetEngine(db.DefaultContext).Table("label").
return labelIDs, db.GetEngine(ctx).Table("label").
Where("org_id = ?", orgID).
In("name", labelNames).
Asc("name").
@ -506,8 +506,8 @@ func GetLabelIDsByNames(ctx context.Context, labelNames []string) ([]int64, erro
}
// CountLabelsByOrgID count all labels that belong to given organization by ID.
func CountLabelsByOrgID(orgID int64) (int64, error) {
return db.GetEngine(db.DefaultContext).Where("org_id = ?", orgID).Count(&Label{})
func CountLabelsByOrgID(ctx context.Context, orgID int64) (int64, error) {
return db.GetEngine(ctx).Where("org_id = ?", orgID).Count(&Label{})
}
func updateLabelCols(ctx context.Context, l *Label, cols ...string) error {

View File

@ -48,7 +48,7 @@ func TestNewLabels(t *testing.T) {
for _, label := range labels {
unittest.AssertNotExistsBean(t, label)
}
assert.NoError(t, issues_model.NewLabels(labels...))
assert.NoError(t, issues_model.NewLabels(db.DefaultContext, labels...))
for _, label := range labels {
unittest.AssertExistsAndLoadBean(t, label, unittest.Cond("id = ?", label.ID))
}
@ -81,7 +81,7 @@ func TestGetLabelInRepoByName(t *testing.T) {
func TestGetLabelInRepoByNames(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
labelIDs, err := issues_model.GetLabelIDsInRepoByNames(1, []string{"label1", "label2"})
labelIDs, err := issues_model.GetLabelIDsInRepoByNames(db.DefaultContext, 1, []string{"label1", "label2"})
assert.NoError(t, err)
assert.Len(t, labelIDs, 2)
@ -93,7 +93,7 @@ func TestGetLabelInRepoByNames(t *testing.T) {
func TestGetLabelInRepoByNamesDiscardsNonExistentLabels(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
// label3 doesn't exists.. See labels.yml
labelIDs, err := issues_model.GetLabelIDsInRepoByNames(1, []string{"label1", "label2", "label3"})
labelIDs, err := issues_model.GetLabelIDsInRepoByNames(db.DefaultContext, 1, []string{"label1", "label2", "label3"})
assert.NoError(t, err)
assert.Len(t, labelIDs, 2)
@ -166,7 +166,7 @@ func TestGetLabelInOrgByName(t *testing.T) {
func TestGetLabelInOrgByNames(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
labelIDs, err := issues_model.GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4"})
labelIDs, err := issues_model.GetLabelIDsInOrgByNames(db.DefaultContext, 3, []string{"orglabel3", "orglabel4"})
assert.NoError(t, err)
assert.Len(t, labelIDs, 2)
@ -178,7 +178,7 @@ func TestGetLabelInOrgByNames(t *testing.T) {
func TestGetLabelInOrgByNamesDiscardsNonExistentLabels(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
// orglabel99 doesn't exists.. See labels.yml
labelIDs, err := issues_model.GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4", "orglabel99"})
labelIDs, err := issues_model.GetLabelIDsInOrgByNames(db.DefaultContext, 3, []string{"orglabel3", "orglabel4", "orglabel99"})
assert.NoError(t, err)
assert.Len(t, labelIDs, 2)
@ -269,7 +269,7 @@ func TestUpdateLabel(t *testing.T) {
}
label.Color = update.Color
label.Name = update.Name
assert.NoError(t, issues_model.UpdateLabel(update))
assert.NoError(t, issues_model.UpdateLabel(db.DefaultContext, update))
newLabel := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 1})
assert.EqualValues(t, label.ID, newLabel.ID)
assert.EqualValues(t, label.Color, newLabel.Color)
@ -282,13 +282,13 @@ func TestUpdateLabel(t *testing.T) {
func TestDeleteLabel(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
label := unittest.AssertExistsAndLoadBean(t, &issues_model.Label{ID: 1})
assert.NoError(t, issues_model.DeleteLabel(label.RepoID, label.ID))
assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, label.RepoID, label.ID))
unittest.AssertNotExistsBean(t, &issues_model.Label{ID: label.ID, RepoID: label.RepoID})
assert.NoError(t, issues_model.DeleteLabel(label.RepoID, label.ID))
assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, label.RepoID, label.ID))
unittest.AssertNotExistsBean(t, &issues_model.Label{ID: label.ID})
assert.NoError(t, issues_model.DeleteLabel(unittest.NonexistentID, unittest.NonexistentID))
assert.NoError(t, issues_model.DeleteLabel(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID))
unittest.CheckConsistencyFor(t, &issues_model.Label{}, &repo_model.Repository{})
}

View File

@ -103,8 +103,8 @@ func (m *Milestone) State() api.StateType {
}
// NewMilestone creates new milestone of repository.
func NewMilestone(m *Milestone) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func NewMilestone(ctx context.Context, m *Milestone) (err error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -140,9 +140,9 @@ func GetMilestoneByRepoID(ctx context.Context, repoID, id int64) (*Milestone, er
}
// GetMilestoneByRepoIDANDName return a milestone if one exist by name and repo
func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) {
func GetMilestoneByRepoIDANDName(ctx context.Context, repoID int64, name string) (*Milestone, error) {
var mile Milestone
has, err := db.GetEngine(db.DefaultContext).Where("repo_id=? AND name=?", repoID, name).Get(&mile)
has, err := db.GetEngine(ctx).Where("repo_id=? AND name=?", repoID, name).Get(&mile)
if err != nil {
return nil, err
}
@ -153,8 +153,8 @@ func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error)
}
// UpdateMilestone updates information of given milestone.
func UpdateMilestone(m *Milestone, oldIsClosed bool) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func UpdateMilestone(ctx context.Context, m *Milestone, oldIsClosed bool) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -211,8 +211,8 @@ func UpdateMilestoneCounters(ctx context.Context, id int64) error {
}
// ChangeMilestoneStatusByRepoIDAndID changes a milestone open/closed status if the milestone ID is in the repo.
func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func ChangeMilestoneStatusByRepoIDAndID(ctx context.Context, repoID, milestoneID int64, isClosed bool) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -238,8 +238,8 @@ func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool
}
// ChangeMilestoneStatus changes the milestone open/closed status.
func ChangeMilestoneStatus(m *Milestone, isClosed bool) (err error) {
ctx, committer, err := db.TxContext(db.DefaultContext)
func ChangeMilestoneStatus(ctx context.Context, m *Milestone, isClosed bool) (err error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -269,8 +269,8 @@ func changeMilestoneStatus(ctx context.Context, m *Milestone, isClosed bool) err
}
// DeleteMilestoneByRepoID deletes a milestone from a repository.
func DeleteMilestoneByRepoID(repoID, id int64) error {
m, err := GetMilestoneByRepoID(db.DefaultContext, repoID, id)
func DeleteMilestoneByRepoID(ctx context.Context, repoID, id int64) error {
m, err := GetMilestoneByRepoID(ctx, repoID, id)
if err != nil {
if IsErrMilestoneNotExist(err) {
return nil
@ -278,12 +278,12 @@ func DeleteMilestoneByRepoID(repoID, id int64) error {
return err
}
repo, err := repo_model.GetRepositoryByID(db.DefaultContext, m.RepoID)
repo, err := repo_model.GetRepositoryByID(ctx, m.RepoID)
if err != nil {
return err
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
@ -332,7 +332,8 @@ func updateRepoMilestoneNum(ctx context.Context, repoID int64) error {
return err
}
func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error {
// LoadTotalTrackedTime loads the tracked time for the milestone
func (m *Milestone) LoadTotalTrackedTime(ctx context.Context) error {
type totalTimesByMilestone struct {
MilestoneID int64
Time int64
@ -355,18 +356,13 @@ func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error {
return nil
}
// LoadTotalTrackedTime loads the tracked time for the milestone
func (m *Milestone) LoadTotalTrackedTime() error {
return m.loadTotalTrackedTime(db.DefaultContext)
}
// InsertMilestones creates milestones of repository.
func InsertMilestones(ms ...*Milestone) (err error) {
func InsertMilestones(ctx context.Context, ms ...*Milestone) (err error) {
if len(ms) == 0 {
return nil
}
ctx, committer, err := db.TxContext(db.DefaultContext)
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}

View File

@ -100,9 +100,9 @@ func GetMilestoneIDsByNames(ctx context.Context, names []string) ([]int64, error
}
// SearchMilestones search milestones
func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType, keyword string) (MilestoneList, error) {
func SearchMilestones(ctx context.Context, repoCond builder.Cond, page int, isClosed bool, sortType, keyword string) (MilestoneList, error) {
miles := make([]*Milestone, 0, setting.UI.IssuePagingNum)
sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed)
sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed)
if len(keyword) > 0 {
sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)})
}
@ -131,8 +131,9 @@ func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType,
}
// GetMilestonesByRepoIDs returns a list of milestones of given repositories and status.
func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) {
func GetMilestonesByRepoIDs(ctx context.Context, repoIDs []int64, page int, isClosed bool, sortType string) (MilestoneList, error) {
return SearchMilestones(
ctx,
builder.In("repo_id", repoIDs),
page,
isClosed,
@ -141,7 +142,8 @@ func GetMilestonesByRepoIDs(repoIDs []int64, page int, isClosed bool, sortType s
)
}
func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error {
// LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request
func (milestones MilestoneList) LoadTotalTrackedTimes(ctx context.Context) error {
type totalTimesByMilestone struct {
MilestoneID int64
Time int64
@ -181,11 +183,6 @@ func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error
return nil
}
// LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request
func (milestones MilestoneList) LoadTotalTrackedTimes() error {
return milestones.loadTotalTrackedTimes(db.DefaultContext)
}
// CountMilestones returns number of milestones in given repository with other options
func CountMilestones(ctx context.Context, opts GetMilestonesOption) (int64, error) {
return db.GetEngine(ctx).
@ -194,8 +191,8 @@ func CountMilestones(ctx context.Context, opts GetMilestonesOption) (int64, erro
}
// CountMilestonesByRepoCond map from repo conditions to number of milestones matching the options`
func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64]int64, error) {
sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed)
func CountMilestonesByRepoCond(ctx context.Context, repoCond builder.Cond, isClosed bool) (map[int64]int64, error) {
sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed)
if repoCond.IsValid() {
sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond))
}
@ -219,8 +216,8 @@ func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64]
}
// CountMilestonesByRepoCondAndKw map from repo conditions and the keyword of milestones' name to number of milestones matching the options`
func CountMilestonesByRepoCondAndKw(repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) {
sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", isClosed)
func CountMilestonesByRepoCondAndKw(ctx context.Context, repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) {
sess := db.GetEngine(ctx).Where("is_closed = ?", isClosed)
if len(keyword) > 0 {
sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)})
}
@ -257,11 +254,11 @@ func (m MilestonesStats) Total() int64 {
}
// GetMilestonesStatsByRepoCond returns milestone statistic information for dashboard by given conditions.
func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, error) {
func GetMilestonesStatsByRepoCond(ctx context.Context, repoCond builder.Cond) (*MilestonesStats, error) {
var err error
stats := &MilestonesStats{}
sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false)
sess := db.GetEngine(ctx).Where("is_closed = ?", false)
if repoCond.IsValid() {
sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond)))
}
@ -270,7 +267,7 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro
return nil, err
}
sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true)
sess = db.GetEngine(ctx).Where("is_closed = ?", true)
if repoCond.IsValid() {
sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond)))
}
@ -283,11 +280,11 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro
}
// GetMilestonesStatsByRepoCondAndKw returns milestone statistic information for dashboard by given repo conditions and name keyword.
func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (*MilestonesStats, error) {
func GetMilestonesStatsByRepoCondAndKw(ctx context.Context, repoCond builder.Cond, keyword string) (*MilestonesStats, error) {
var err error
stats := &MilestonesStats{}
sess := db.GetEngine(db.DefaultContext).Where("is_closed = ?", false)
sess := db.GetEngine(ctx).Where("is_closed = ?", false)
if len(keyword) > 0 {
sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)})
}
@ -299,7 +296,7 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (*
return nil, err
}
sess = db.GetEngine(db.DefaultContext).Where("is_closed = ?", true)
sess = db.GetEngine(ctx).Where("is_closed = ?", true)
if len(keyword) > 0 {
sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)})
}

View File

@ -201,12 +201,12 @@ func TestCountMilestonesByRepoIDs(t *testing.T) {
repo1OpenCount, repo1ClosedCount := milestonesCount(1)
repo2OpenCount, repo2ClosedCount := milestonesCount(2)
openCounts, err := issues_model.CountMilestonesByRepoCond(builder.In("repo_id", []int64{1, 2}), false)
openCounts, err := issues_model.CountMilestonesByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{1, 2}), false)
assert.NoError(t, err)
assert.EqualValues(t, repo1OpenCount, openCounts[1])
assert.EqualValues(t, repo2OpenCount, openCounts[2])
closedCounts, err := issues_model.CountMilestonesByRepoCond(builder.In("repo_id", []int64{1, 2}), true)
closedCounts, err := issues_model.CountMilestonesByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{1, 2}), true)
assert.NoError(t, err)
assert.EqualValues(t, repo1ClosedCount, closedCounts[1])
assert.EqualValues(t, repo2ClosedCount, closedCounts[2])
@ -218,7 +218,7 @@ func TestGetMilestonesByRepoIDs(t *testing.T) {
repo2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2})
test := func(sortType string, sortCond func(*issues_model.Milestone) int) {
for _, page := range []int{0, 1} {
openMilestones, err := issues_model.GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, false, sortType)
openMilestones, err := issues_model.GetMilestonesByRepoIDs(db.DefaultContext, []int64{repo1.ID, repo2.ID}, page, false, sortType)
assert.NoError(t, err)
assert.Len(t, openMilestones, repo1.NumOpenMilestones+repo2.NumOpenMilestones)
values := make([]int, len(openMilestones))
@ -227,7 +227,7 @@ func TestGetMilestonesByRepoIDs(t *testing.T) {
}
assert.True(t, sort.IntsAreSorted(values))
closedMilestones, err := issues_model.GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, true, sortType)
closedMilestones, err := issues_model.GetMilestonesByRepoIDs(db.DefaultContext, []int64{repo1.ID, repo2.ID}, page, true, sortType)
assert.NoError(t, err)
assert.Len(t, closedMilestones, repo1.NumClosedMilestones+repo2.NumClosedMilestones)
values = make([]int, len(closedMilestones))
@ -262,7 +262,7 @@ func TestGetMilestonesStats(t *testing.T) {
test := func(repoID int64) {
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID})
stats, err := issues_model.GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": repoID}))
stats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.And(builder.Eq{"repo_id": repoID}))
assert.NoError(t, err)
assert.EqualValues(t, repo.NumMilestones-repo.NumClosedMilestones, stats.OpenCount)
assert.EqualValues(t, repo.NumClosedMilestones, stats.ClosedCount)
@ -271,7 +271,7 @@ func TestGetMilestonesStats(t *testing.T) {
test(2)
test(3)
stats, err := issues_model.GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": unittest.NonexistentID}))
stats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.And(builder.Eq{"repo_id": unittest.NonexistentID}))
assert.NoError(t, err)
assert.EqualValues(t, 0, stats.OpenCount)
assert.EqualValues(t, 0, stats.ClosedCount)
@ -279,7 +279,7 @@ func TestGetMilestonesStats(t *testing.T) {
repo1 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
repo2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 2})
milestoneStats, err := issues_model.GetMilestonesStatsByRepoCond(builder.In("repo_id", []int64{repo1.ID, repo2.ID}))
milestoneStats, err := issues_model.GetMilestonesStatsByRepoCond(db.DefaultContext, builder.In("repo_id", []int64{repo1.ID, repo2.ID}))
assert.NoError(t, err)
assert.EqualValues(t, repo1.NumOpenMilestones+repo2.NumOpenMilestones, milestoneStats.OpenCount)
assert.EqualValues(t, repo1.NumClosedMilestones+repo2.NumClosedMilestones, milestoneStats.ClosedCount)
@ -293,7 +293,7 @@ func TestNewMilestone(t *testing.T) {
Content: "milestoneContent",
}
assert.NoError(t, issues_model.NewMilestone(milestone))
assert.NoError(t, issues_model.NewMilestone(db.DefaultContext, milestone))
unittest.AssertExistsAndLoadBean(t, milestone)
unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{})
}
@ -302,22 +302,22 @@ func TestChangeMilestoneStatus(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1})
assert.NoError(t, issues_model.ChangeMilestoneStatus(milestone, true))
assert.NoError(t, issues_model.ChangeMilestoneStatus(db.DefaultContext, milestone, true))
unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}, "is_closed=1")
unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{})
assert.NoError(t, issues_model.ChangeMilestoneStatus(milestone, false))
assert.NoError(t, issues_model.ChangeMilestoneStatus(db.DefaultContext, milestone, false))
unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1}, "is_closed=0")
unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: milestone.RepoID}, &issues_model.Milestone{})
}
func TestDeleteMilestoneByRepoID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.NoError(t, issues_model.DeleteMilestoneByRepoID(1, 1))
assert.NoError(t, issues_model.DeleteMilestoneByRepoID(db.DefaultContext, 1, 1))
unittest.AssertNotExistsBean(t, &issues_model.Milestone{ID: 1})
unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: 1})
assert.NoError(t, issues_model.DeleteMilestoneByRepoID(unittest.NonexistentID, unittest.NonexistentID))
assert.NoError(t, issues_model.DeleteMilestoneByRepoID(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID))
}
func TestUpdateMilestone(t *testing.T) {
@ -326,7 +326,7 @@ func TestUpdateMilestone(t *testing.T) {
milestone := unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1})
milestone.Name = " newMilestoneName "
milestone.Content = "newMilestoneContent"
assert.NoError(t, issues_model.UpdateMilestone(milestone, milestone.IsClosed))
assert.NoError(t, issues_model.UpdateMilestone(db.DefaultContext, milestone, milestone.IsClosed))
milestone = unittest.AssertExistsAndLoadBean(t, &issues_model.Milestone{ID: 1})
assert.EqualValues(t, "newMilestoneName", milestone.Name)
unittest.CheckConsistencyFor(t, &issues_model.Milestone{})
@ -361,7 +361,7 @@ func TestMigrate_InsertMilestones(t *testing.T) {
RepoID: repo.ID,
Name: name,
}
err := issues_model.InsertMilestones(ms)
err := issues_model.InsertMilestones(db.DefaultContext, ms)
assert.NoError(t, err)
unittest.AssertExistsAndLoadBean(t, ms)
repoModified := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repo.ID})

View File

@ -81,9 +81,9 @@ type UserStopwatch struct {
}
// GetUIDsAndNotificationCounts between the two provided times
func GetUIDsAndStopwatch() ([]*UserStopwatch, error) {
func GetUIDsAndStopwatch(ctx context.Context) ([]*UserStopwatch, error) {
sws := []*Stopwatch{}
if err := db.GetEngine(db.DefaultContext).Where("issue_id != 0").Find(&sws); err != nil {
if err := db.GetEngine(ctx).Where("issue_id != 0").Find(&sws); err != nil {
return nil, err
}
if len(sws) == 0 {
@ -107,9 +107,9 @@ func GetUIDsAndStopwatch() ([]*UserStopwatch, error) {
}
// GetUserStopwatches return list of all stopwatches of a user
func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) {
func GetUserStopwatches(ctx context.Context, userID int64, listOptions db.ListOptions) ([]*Stopwatch, error) {
sws := make([]*Stopwatch, 0, 8)
sess := db.GetEngine(db.DefaultContext).Where("stopwatch.user_id = ?", userID)
sess := db.GetEngine(ctx).Where("stopwatch.user_id = ?", userID)
if listOptions.Page != 0 {
sess = db.SetSessionPagination(sess, &listOptions)
}
@ -122,13 +122,13 @@ func GetUserStopwatches(userID int64, listOptions db.ListOptions) ([]*Stopwatch,
}
// CountUserStopwatches return count of all stopwatches of a user
func CountUserStopwatches(userID int64) (int64, error) {
return db.GetEngine(db.DefaultContext).Where("user_id = ?", userID).Count(&Stopwatch{})
func CountUserStopwatches(ctx context.Context, userID int64) (int64, error) {
return db.GetEngine(ctx).Where("user_id = ?", userID).Count(&Stopwatch{})
}
// StopwatchExists returns true if the stopwatch exists
func StopwatchExists(userID, issueID int64) bool {
_, exists, _ := getStopwatch(db.DefaultContext, userID, issueID)
func StopwatchExists(ctx context.Context, userID, issueID int64) bool {
_, exists, _ := getStopwatch(ctx, userID, issueID)
return exists
}
@ -168,15 +168,15 @@ func FinishIssueStopwatchIfPossible(ctx context.Context, user *user_model.User,
}
// CreateOrStopIssueStopwatch create an issue stopwatch if it's not exist, otherwise finish it
func CreateOrStopIssueStopwatch(user *user_model.User, issue *Issue) error {
_, exists, err := getStopwatch(db.DefaultContext, user.ID, issue.ID)
func CreateOrStopIssueStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error {
_, exists, err := getStopwatch(ctx, user.ID, issue.ID)
if err != nil {
return err
}
if exists {
return FinishIssueStopwatch(db.DefaultContext, user, issue)
return FinishIssueStopwatch(ctx, user, issue)
}
return CreateIssueStopwatch(db.DefaultContext, user, issue)
return CreateIssueStopwatch(ctx, user, issue)
}
// FinishIssueStopwatch if stopwatch exist then finish it otherwise return an error
@ -269,8 +269,8 @@ func CreateIssueStopwatch(ctx context.Context, user *user_model.User, issue *Iss
}
// CancelStopwatch removes the given stopwatch and logs it into issue's timeline.
func CancelStopwatch(user *user_model.User, issue *Issue) error {
ctx, committer, err := db.TxContext(db.DefaultContext)
func CancelStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}

View File

@ -26,20 +26,20 @@ func TestCancelStopwatch(t *testing.T) {
issue2, err := issues_model.GetIssueByID(db.DefaultContext, 2)