migrate/migrate.go

92 lines
1.9 KiB
Go

package migrate
import (
"database/sql"
"fmt"
"time"
)
// table is the name of the bookkeeping table in which applied migrations are
// recorded. It is the argument passed to every Dialect statement builder used
// in this file (TableExists, CreateTable, InsertVersion, CheckVersion).
const table = "dbversion"
// Error is returned by Apply when an individual migration fails. It pairs a
// human-readable description of what was being attempted with the underlying
// cause, which stays reachable through errors.Is and errors.As via Unwrap.
type Error struct {
	description string // what the runner was doing when the failure occurred
	wrapped     error  // the underlying cause
}

// Error renders the failure as "<description>: <cause>".
func (me Error) Error() string {
	msg := fmt.Sprintf("%s: %v", me.description, me.wrapped)
	return msg
}

// Unwrap returns the underlying cause so standard error inspection can see it.
func (me Error) Unwrap() error {
	cause := me.wrapped
	return cause
}
// Record describes a single migration step. Apply runs records in slice
// order, so a record's index is its position in the migration history.
type Record struct {
	// Description is a human-readable label for the migration. It is passed
	// as an argument to the dialect's InsertVersion statement when the
	// migration succeeds, and is quoted in the error Apply returns if it fails.
	Description string
	// F performs the migration against ctx. It must be non-nil (it is called
	// unconditionally). A non-nil error stops Apply, and the migration is not
	// recorded as applied.
	F func(ctx Context) error
}
// Context is the minimal set of database operations the migration runner
// needs. Both *sql.DB and *sql.Tx satisfy it, so migrations can run either
// directly against a database or inside a transaction the caller controls.
type Context interface {
	// Query runs a statement that returns rows.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	// Exec runs a statement that does not return rows.
	Exec(query string, args ...interface{}) (sql.Result, error)
}
// Apply brings the database up to date by running, in order, every migration
// in migrations that has not been applied yet.
//
// The version table is created first if it does not exist. The value read
// back by the dialect's CheckVersion query is treated as the number of
// migrations already applied, so migrations must only ever be appended to —
// never reordered or removed.
//
// For each pending migration, its F is run against ctx and, on success, the
// migration is recorded in the version table. Apply stops at the first
// failure and returns an Error naming the migration that failed. Apply itself
// never rolls anything back; whether earlier migrations persist depends on
// ctx (e.g. a *sql.Tx the caller later rolls back).
//
// Errors from creating the version table or reading the current version are
// returned unchanged. A stored version outside [0, len(migrations)] is
// reported as an error rather than causing a panic.
func Apply(ctx Context, d Dialect, migrations []Record) error {
	if err := initialize(ctx, d); err != nil {
		return err
	}

	currentVersion, err := dbVersion(ctx, d)
	if err != nil {
		return err
	}

	// Guard the slice expression below: a version past the end of the known
	// migrations (e.g. the database was migrated by a newer build) or a
	// negative one would otherwise panic with a slice-bounds error.
	if currentVersion < 0 || currentVersion > int64(len(migrations)) {
		return fmt.Errorf("database version %d is outside the range of known migrations (0-%d)",
			currentVersion, len(migrations))
	}

	// Only apply what hasn't been applied already.
	for _, m := range migrations[currentVersion:] {
		if err := apply(ctx, d, m); err != nil {
			return Error{
				description: fmt.Sprintf("error performing migration \"%s\"", m.Description),
				wrapped:     err,
			}
		}
	}
	return nil
}
// initialize makes sure the version table exists. It creates the table when
// noVersionTable reports it missing and returns any error from that creation;
// when the table is already present it does nothing and returns nil.
func initialize(ctx Context, d Dialect) error {
	if !noVersionTable(ctx, d) {
		return nil
	}
	return createVersionTable(ctx, d)
}
// noVersionTable reports whether the version table appears to be missing. It
// runs the dialect's TableExists query. Any error from that query is taken to
// mean "missing"; a successful query means "present", regardless of how many
// rows it returns.
//
// NOTE(review): unrelated failures (e.g. a dropped connection) also read as
// "missing"; the follow-up CREATE then typically surfaces the real error.
// On databases where a failed statement aborts the enclosing transaction,
// probing by error inside a *sql.Tx may make that CREATE fail — confirm
// against the dialects in use.
func noVersionTable(ctx Context, d Dialect) bool {
	rows, err := ctx.Query(d.TableExists(table))
	if rows != nil {
		// Only the error matters; release the result set immediately. The
		// Close error is deliberately ignored — this is a best-effort probe.
		rows.Close()
	}
	return err != nil
}
// apply runs a single migration and, only if it succeeds, records it in the
// version table. The first error encountered is returned unchanged.
func apply(ctx Context, d Dialect, r Record) error {
	if err := r.F(ctx); err != nil {
		return err
	}
	return incrementVersion(ctx, d, r.Description)
}
// createVersionTable executes the dialect's CREATE statement for the version
// table. The sql.Result is discarded; only the error is returned.
func createVersionTable(ctx Context, d Dialect) error {
	stmt := d.CreateTable(table)
	if _, err := ctx.Exec(stmt); err != nil {
		return err
	}
	return nil
}
// incrementVersion records a successfully applied migration. It executes the
// dialect's InsertVersion statement with two arguments: the migration's
// description and the current local time. The sql.Result is discarded.
func incrementVersion(ctx Context, d Dialect, description string) error {
	stmt := d.InsertVersion(table)
	appliedAt := time.Now()
	_, err := ctx.Exec(stmt, description, appliedAt)
	return err
}
// dbVersion reads the current schema version. It runs the dialect's
// CheckVersion query and scans the first column of the first row into an
// int64. Apply treats this value as the number of migrations already applied.
// If the query returns no rows, the version is 0.
//
// Errors from the query, the scan, and row iteration are all returned (with a
// zero version), and the result set is always closed so its connection is
// released.
//
// NOTE(review): if CheckVersion is an aggregate such as MAX, confirm it cannot
// yield NULL on an empty table — scanning NULL into an int64 fails.
func dbVersion(ctx Context, d Dialect) (version int64, err error) {
	rows, err := ctx.Query(d.CheckVersion(table))
	if err != nil {
		// database/sql returns a nil *Rows alongside an error; touching it
		// (as the previous version did with Next) would panic.
		return 0, err
	}
	// Without Close the result set, and the connection behind it, would leak.
	defer rows.Close()

	if rows.Next() {
		if err := rows.Scan(&version); err != nil {
			return 0, err
		}
	}
	// Next returns false on an iteration error as well as at end of rows.
	if err := rows.Err(); err != nil {
		return 0, err
	}
	return version, nil
}