introduced the concept of a dialect, to provide the exact SQL needed.
This commit is contained in:
parent
f7c677a3f5
commit
a3d0263698
33
dialect.go
Normal file
33
dialect.go
Normal file
@ -0,0 +1,33 @@
|
||||
package migrate
|
||||
|
||||
type Dialect interface {
|
||||
CreateTable(table string) string
|
||||
TableExists(table string) string
|
||||
CheckVersion(table string) string
|
||||
InsertVersion(table string) string
|
||||
}
|
||||
|
||||
func Sqlite3() Dialect {
|
||||
return sqlite3{}
|
||||
}
|
||||
|
||||
type sqlite3 struct{}
|
||||
|
||||
func (s sqlite3) CreateTable(table string) string {
|
||||
return "CREATE TABLE " + table + ` (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
description VARCHAR,
|
||||
applied TIMESTAMP);`
|
||||
}
|
||||
|
||||
func (s sqlite3) TableExists(table string) string {
|
||||
return "SELECT * FROM " + table + ";"
|
||||
}
|
||||
|
||||
func (s sqlite3) CheckVersion(table string) string {
|
||||
return "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
|
||||
}
|
||||
|
||||
func (s sqlite3) InsertVersion(table string) string {
|
||||
return "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
|
||||
}
|
2
doc.go
2
doc.go
@ -1,6 +1,6 @@
|
||||
package migrate
|
||||
|
||||
// migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore)
|
||||
// it is intended to keep its footprint small, requiring only an addiutional table in the database
|
||||
// it is intended to keep its footprint small, requiring only an additional table in the database
|
||||
// there is no rollback support as you should only ever roll forward.
|
||||
// uses SQL99 compatible SQL only.
|
||||
|
45
migrate.go
45
migrate.go
@ -8,15 +8,6 @@ import (
|
||||
|
||||
const table = "dbversion"
|
||||
|
||||
var tableCreateSql = "CREATE TABLE " + table + ` (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
description VARCHAR,
|
||||
applied TIMESTAMP
|
||||
);`
|
||||
var tableCheckSql = "SELECT * FROM " + table + ";"
|
||||
var versionCheckSql = "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
|
||||
var versionInsertSql = "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
|
||||
|
||||
type Error struct {
|
||||
description string
|
||||
wrapped error
|
||||
@ -40,13 +31,13 @@ type Context interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
func Apply(ctx Context, migrations []Record) (err error) {
|
||||
if err = initialize(ctx); err == nil {
|
||||
func Apply(ctx Context, d Dialect, migrations []Record) (err error) {
|
||||
if err = initialize(ctx, d); err == nil {
|
||||
var currentVersion int64
|
||||
if currentVersion, err = dbVersion(ctx); err == nil {
|
||||
if currentVersion, err = dbVersion(ctx, d); err == nil {
|
||||
migrations = migrations[currentVersion:] // only apply what hasnt been been applied already
|
||||
for i, m := range migrations {
|
||||
if err = apply(ctx, m); err != nil {
|
||||
if err = apply(ctx, d, m); err != nil {
|
||||
err = Error{
|
||||
description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description),
|
||||
wrapped: err,
|
||||
@ -59,40 +50,40 @@ func Apply(ctx Context, migrations []Record) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func initialize(ctx Context) (err error) {
|
||||
if noVersionTable(ctx) {
|
||||
return createVersionTable(ctx)
|
||||
func initialize(ctx Context, d Dialect) (err error) {
|
||||
if noVersionTable(ctx, d) {
|
||||
return createVersionTable(ctx, d)
|
||||
}
|
||||
return nil
|
||||
return
|
||||
}
|
||||
|
||||
func noVersionTable(ctx Context) bool {
|
||||
rows, table_check := ctx.Query(tableCheckSql)
|
||||
func noVersionTable(ctx Context, d Dialect) bool {
|
||||
rows, table_check := ctx.Query(d.TableExists(table))
|
||||
if rows != nil {
|
||||
defer rows.Close()
|
||||
}
|
||||
return table_check != nil
|
||||
}
|
||||
|
||||
func apply(ctx Context, r Record) (err error) {
|
||||
func apply(ctx Context, d Dialect, r Record) (err error) {
|
||||
if err = r.F(ctx); err == nil {
|
||||
err = incrementVersion(ctx, r.Description)
|
||||
err = incrementVersion(ctx, d, r.Description)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createVersionTable(ctx Context) (err error) {
|
||||
_, err = ctx.Exec(tableCreateSql)
|
||||
func createVersionTable(ctx Context, d Dialect) (err error) {
|
||||
_, err = ctx.Exec(d.CreateTable(table))
|
||||
return
|
||||
}
|
||||
|
||||
func incrementVersion(ctx Context, description string) (err error) {
|
||||
_, err = ctx.Exec(versionInsertSql, description, time.Now())
|
||||
func incrementVersion(ctx Context, d Dialect, description string) (err error) {
|
||||
_, err = ctx.Exec(d.InsertVersion(table), description, time.Now())
|
||||
return
|
||||
}
|
||||
|
||||
func dbVersion(ctx Context) (id int64, err error) {
|
||||
row, err := ctx.Query(versionCheckSql)
|
||||
func dbVersion(ctx Context, d Dialect) (id int64, err error) {
|
||||
row, err := ctx.Query(d.CheckVersion(table))
|
||||
if row.Next() {
|
||||
err = row.Scan(&id)
|
||||
}
|
||||
|
@ -27,7 +27,7 @@ func TestCreateVersionTable(t *testing.T) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
err = createVersionTable(db)
|
||||
err = createVersionTable(db, Sqlite3())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -43,7 +43,9 @@ func TestIncrementVersion(t *testing.T) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
err = createVersionTable(db)
|
||||
sl3 := Sqlite3()
|
||||
|
||||
err = createVersionTable(db, sl3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -54,7 +56,7 @@ func TestIncrementVersion(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, d := range descriptions {
|
||||
err = incrementVersion(db, d)
|
||||
err = incrementVersion(db, sl3, d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -88,18 +90,20 @@ func TestDbVersion(t *testing.T) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
err = createVersionTable(db)
|
||||
sl3 := Sqlite3()
|
||||
|
||||
err = createVersionTable(db, sl3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ver, err := dbVersion(db)
|
||||
ver, err := dbVersion(db, sl3)
|
||||
if ver != 0 || err != nil {
|
||||
t.Fatalf("version not 0 as expected (actual %d) or err: %#v", ver, err)
|
||||
}
|
||||
|
||||
err = incrementVersion(db, "Test 1")
|
||||
ver, err = dbVersion(db)
|
||||
err = incrementVersion(db, sl3, "Test 1")
|
||||
ver, err = dbVersion(db, sl3)
|
||||
if ver != 1 {
|
||||
t.Fatalf("version not 1 as expected (actual %d)", ver)
|
||||
}
|
||||
@ -120,6 +124,8 @@ func TestApply(t *testing.T) {
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
sl3 := Sqlite3()
|
||||
|
||||
records :=
|
||||
[]Record{
|
||||
{
|
||||
@ -144,7 +150,7 @@ func TestApply(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
err = Apply(db, records)
|
||||
err = Apply(db, sl3, records)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -160,8 +166,8 @@ func TestApply(t *testing.T) {
|
||||
}
|
||||
|
||||
// reapply and make sure we dont re-run anything
|
||||
err = Apply(db, records)
|
||||
ver, err := dbVersion(db)
|
||||
err = Apply(db, sl3, records)
|
||||
ver, err := dbVersion(db, sl3)
|
||||
if ver != 2 {
|
||||
t.Fatalf("version not 2 as expected (actual %d)", ver)
|
||||
}
|
||||
@ -180,12 +186,12 @@ func TestApply(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
err = Apply(db, records)
|
||||
err = Apply(db, sl3, records)
|
||||
|
||||
if errors.Unwrap(err) != ishouldntHideUserErrors {
|
||||
t.Fatalf("unexpected error returned that should have been record function error: %#v", err)
|
||||
}
|
||||
ver, err = dbVersion(db)
|
||||
ver, err = dbVersion(db, sl3)
|
||||
if ver != 2 {
|
||||
t.Fatalf("version not 2 as expected (actual %d) after bad record apply", ver)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user