migrate/migrate.go
2021-06-04 17:57:45 -07:00

101 lines
2.2 KiB
Go

package migrate
import (
"database/sql"
"fmt"
"time"
)
// table is the name of the bookkeeping table that records which migrations
// have already been applied to the database.
const table = "dbversion"

// SQL statements used to manage the version table. Each one is assembled
// from the table constant so the table name is spelled out in one place.
var (
	// tableCreateSql creates the version table: one row per applied
	// migration holding an auto-incrementing id, the migration's
	// description and the time it was applied.
	tableCreateSql = "CREATE TABLE " + table + ` (
id INTEGER PRIMARY KEY AUTOINCREMENT,
description VARCHAR,
applied TIMESTAMP
);`

	// tableCheckSql probes the version table; noVersionTable treats a
	// failure of this query as "the table does not exist".
	tableCheckSql = "SELECT * FROM " + table + ";"

	// versionCheckSql fetches the highest id (LIMIT offset 0, count 1),
	// which dbVersion reports as the current schema version.
	versionCheckSql = "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"

	// versionInsertSql records one applied migration; the two placeholders
	// are the description and the application timestamp.
	versionInsertSql = "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
)
// Error reports a failure that happened while migrating. It pairs a
// human-readable description of what was being attempted with the
// underlying cause, which stays reachable through errors.Unwrap,
// errors.Is and errors.As.
type Error struct {
	description string // what was being attempted when the failure occurred
	wrapped     error  // the underlying cause
}

// Error renders the failure as "<description>: <cause>".
func (e Error) Error() string {
	prefix := e.description + ": "
	return prefix + fmt.Sprint(e.wrapped)
}

// Unwrap exposes the underlying cause to the errors package helpers.
func (e Error) Unwrap() error { return e.wrapped }
// Record is one migration step.
//
// Description is a human-readable label. Once the step succeeds it is
// stored in the version table, and if the step fails it is used in the
// returned error message.
//
// F performs the migration using the supplied Context. A non-nil error
// stops Apply before any later migrations run.
type Record struct {
	Description string
	F           func(Context) error
}
// Context is the minimal set of database operations the migrator needs.
// Both *sql.DB and *sql.Tx provide these methods, so migrations can run
// directly against a database handle or inside a transaction the caller
// manages.
type Context interface {
	// Query runs a statement that returns rows.
	Query(query string, args ...interface{}) (*sql.Rows, error)
	// Exec runs a statement that does not return rows.
	Exec(query string, args ...interface{}) (sql.Result, error)
}
// Apply brings the database reachable through ctx up to date with migrations.
//
// It first makes sure the version table exists. It then reads the current
// version, which is the highest id in that table (0 when empty). Finally it
// runs, in order, every migration at or beyond that position in the slice.
//
// This assumes the version table's ids are contiguous from 1, so that the
// highest id equals the number of migrations already applied. The slice
// must therefore be treated as append-only history.
//
// Each migration records its row in the version table before the next one
// starts. Apply stops at the first failing migration and returns an Error
// that names it and wraps the cause.
//
// If the stored version is negative or greater than len(migrations), for
// example an older binary running against a newer schema, Apply returns an
// Error without running anything.
//
// NOTE: a migration and its version-table insert are separate statements.
// Pass a transaction as ctx if they must commit atomically.
func Apply(ctx Context, migrations []Record) (err error) {
	if err = initialize(ctx); err != nil {
		return err
	}

	var currentVersion int64
	if currentVersion, err = dbVersion(ctx); err != nil {
		return err
	}

	// Slicing with an out-of-range version would panic. Report it as an
	// error instead.
	if currentVersion < 0 || currentVersion > int64(len(migrations)) {
		return Error{
			description: "cannot determine pending migrations",
			wrapped: fmt.Errorf("database is at version %d but only %d migrations are known",
				currentVersion, len(migrations)),
		}
	}

	// Only apply what has not been applied already.
	for _, m := range migrations[currentVersion:] {
		if err = apply(ctx, m); err != nil {
			return Error{
				description: fmt.Sprintf("error performing migration \"%s\"", m.Description),
				wrapped:     err,
			}
		}
	}
	return nil
}
// initialize makes sure the version table exists. It creates the table when
// noVersionTable reports it missing, and returns any error from that CREATE
// TABLE statement.
func initialize(ctx Context) (err error) {
	if !noVersionTable(ctx) {
		return nil
	}
	return createVersionTable(ctx)
}
// noVersionTable reports whether the version table appears to be missing.
// It probes the table with tableCheckSql and treats any query error as
// "missing".
//
// NOTE(review): connection or permission failures are therefore also
// reported as a missing table. The CREATE TABLE that follows will then
// surface the real error.
func noVersionTable(ctx Context) bool {
	rows, err := ctx.Query(tableCheckSql)
	missing := err != nil
	// The probe's result set is never read. Release it straight away; its
	// close error carries no useful information here.
	if rows != nil {
		rows.Close()
	}
	return missing
}
// apply runs a single migration. Only if it succeeds is the migration
// recorded in the version table under r.Description. The first error
// encountered is returned.
func apply(ctx Context, r Record) (err error) {
	if err = r.F(ctx); err != nil {
		return err
	}
	return incrementVersion(ctx, r.Description)
}
// createVersionTable executes tableCreateSql. The sql.Result is discarded;
// only the error is reported.
func createVersionTable(ctx Context) (err error) {
	if _, err = ctx.Exec(tableCreateSql); err != nil {
		return err
	}
	return nil
}
// incrementVersion records a successfully applied migration. It inserts a
// row with the migration's description and the current time. The id column
// auto-increments, so the new row carries the highest id, which dbVersion
// reports as the version.
func incrementVersion(ctx Context, description string) (err error) {
	appliedAt := time.Now()
	_, err = ctx.Exec(versionInsertSql, description, appliedAt)
	return err
}
// dbVersion returns the current schema version: the highest id in the
// version table, or 0 when the table has no rows yet.
//
// Errors from the query, from scanning, and from row iteration (rows.Err)
// are all returned. The result set is always closed so the underlying
// connection is released.
func dbVersion(ctx Context) (id int64, err error) {
	rows, err := ctx.Query(versionCheckSql)
	if err != nil {
		// rows is nil on error; calling Next on it would panic.
		return 0, err
	}
	// Release the connection. A close error is irrelevant for a read-only
	// query whose result has already been consumed.
	defer rows.Close()

	if !rows.Next() {
		// Either the table is empty (version 0) or iteration failed.
		return 0, rows.Err()
	}
	if err = rows.Scan(&id); err != nil {
		return 0, err
	}
	return id, nil
}