92 lines
1.9 KiB
Go
92 lines
1.9 KiB
Go
package migrate
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"time"
|
|
)
|
|
|
|
const (
	// table is the name of the bookkeeping table that tracks which migrations
	// have been applied. Every Dialect query issued by this package is built
	// against this name.
	table = "dbversion"
)
|
|
|
|
// Error reports a failed migration step. It pairs a human-readable description
// of what was being attempted with the underlying cause, which remains
// reachable through Unwrap (and therefore errors.Is / errors.As).
type Error struct {
	description string // what was being attempted when the failure occurred
	wrapped     error  // the underlying cause; may be nil
}

// Error renders the description, then ": ", then the wrapped cause formatted
// with fmt's default verb (a nil cause renders as "<nil>").
func (e Error) Error() string {
	return e.description + ": " + fmt.Sprint(e.wrapped)
}

// Unwrap exposes the underlying cause so the standard errors helpers can
// inspect the chain.
func (e Error) Unwrap() error { return e.wrapped }
|
|
|
|
// Record describes a single migration step.
type Record struct {
	// Description is a human-readable label for the migration. It is passed
	// to the version-table insert when the migration succeeds and quoted in
	// the error Apply returns when it fails.
	Description string

	// F performs the migration against ctx; a non-nil error aborts Apply.
	F func(ctx Context) error
}

// Context is the minimal set of database operations a migration needs.
// Both *sql.DB and *sql.Tx have methods with these signatures.
type Context interface {
	// Exec runs a statement that returns no rows.
	Exec(statement string, params ...interface{}) (sql.Result, error)

	// Query runs a statement that returns rows; the caller must close them.
	Query(statement string, params ...interface{}) (*sql.Rows, error)
}
|
|
|
|
func Apply(ctx Context, d Dialect, migrations []Record) (err error) {
|
|
if err = initialize(ctx, d); err == nil {
|
|
var currentVersion int64
|
|
if currentVersion, err = dbVersion(ctx, d); err == nil {
|
|
migrations = migrations[currentVersion:] // only apply what hasnt been been applied already
|
|
for i, m := range migrations {
|
|
if err = apply(ctx, d, m); err != nil {
|
|
err = Error{
|
|
description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description),
|
|
wrapped: err,
|
|
}
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func initialize(ctx Context, d Dialect) (err error) {
|
|
if noVersionTable(ctx, d) {
|
|
return createVersionTable(ctx, d)
|
|
}
|
|
return
|
|
}
|
|
|
|
func noVersionTable(ctx Context, d Dialect) bool {
|
|
rows, table_check := ctx.Query(d.TableExists(table))
|
|
if rows != nil {
|
|
defer rows.Close()
|
|
}
|
|
return table_check != nil
|
|
}
|
|
|
|
func apply(ctx Context, d Dialect, r Record) (err error) {
|
|
if err = r.F(ctx); err == nil {
|
|
err = incrementVersion(ctx, d, r.Description)
|
|
}
|
|
return
|
|
}
|
|
|
|
func createVersionTable(ctx Context, d Dialect) (err error) {
|
|
_, err = ctx.Exec(d.CreateTable(table))
|
|
return
|
|
}
|
|
|
|
func incrementVersion(ctx Context, d Dialect, description string) (err error) {
|
|
_, err = ctx.Exec(d.InsertVersion(table), description, time.Now())
|
|
return
|
|
}
|
|
|
|
func dbVersion(ctx Context, d Dialect) (id int64, err error) {
|
|
row, err := ctx.Query(d.CheckVersion(table))
|
|
if row.Next() {
|
|
err = row.Scan(&id)
|
|
}
|
|
return
|
|
}
|