introduced the concept of a dialect, to provide the exact SQL needed.
This commit is contained in:
parent
f7c677a3f5
commit
a3d0263698
33
dialect.go
Normal file
33
dialect.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
type Dialect interface {
|
||||||
|
CreateTable(table string) string
|
||||||
|
TableExists(table string) string
|
||||||
|
CheckVersion(table string) string
|
||||||
|
InsertVersion(table string) string
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sqlite3() Dialect {
|
||||||
|
return sqlite3{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlite3 struct{}
|
||||||
|
|
||||||
|
func (s sqlite3) CreateTable(table string) string {
|
||||||
|
return "CREATE TABLE " + table + ` (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
description VARCHAR,
|
||||||
|
applied TIMESTAMP);`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) TableExists(table string) string {
|
||||||
|
return "SELECT * FROM " + table + ";"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) CheckVersion(table string) string {
|
||||||
|
return "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s sqlite3) InsertVersion(table string) string {
|
||||||
|
return "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
|
||||||
|
}
|
2
doc.go
2
doc.go
@ -1,6 +1,6 @@
|
|||||||
package migrate
|
package migrate
|
||||||
|
|
||||||
// migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore)
|
// migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore)
|
||||||
// it is intended to keep its footprint small, requiring only an addiutional table in the database
|
// it is intended to keep its footprint small, requiring only an additional table in the database
|
||||||
// there is no rollback support as you should only ever roll forward.
|
// there is no rollback support as you should only ever roll forward.
|
||||||
// uses SQL99 compatible SQL only.
|
// uses SQL99 compatible SQL only.
|
||||||
|
45
migrate.go
45
migrate.go
@ -8,15 +8,6 @@ import (
|
|||||||
|
|
||||||
const table = "dbversion"
|
const table = "dbversion"
|
||||||
|
|
||||||
var tableCreateSql = "CREATE TABLE " + table + ` (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
description VARCHAR,
|
|
||||||
applied TIMESTAMP
|
|
||||||
);`
|
|
||||||
var tableCheckSql = "SELECT * FROM " + table + ";"
|
|
||||||
var versionCheckSql = "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
|
|
||||||
var versionInsertSql = "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
|
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
description string
|
description string
|
||||||
wrapped error
|
wrapped error
|
||||||
@ -40,13 +31,13 @@ type Context interface {
|
|||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Apply(ctx Context, migrations []Record) (err error) {
|
func Apply(ctx Context, d Dialect, migrations []Record) (err error) {
|
||||||
if err = initialize(ctx); err == nil {
|
if err = initialize(ctx, d); err == nil {
|
||||||
var currentVersion int64
|
var currentVersion int64
|
||||||
if currentVersion, err = dbVersion(ctx); err == nil {
|
if currentVersion, err = dbVersion(ctx, d); err == nil {
|
||||||
migrations = migrations[currentVersion:] // only apply what hasnt been been applied already
|
migrations = migrations[currentVersion:] // only apply what hasnt been been applied already
|
||||||
for i, m := range migrations {
|
for i, m := range migrations {
|
||||||
if err = apply(ctx, m); err != nil {
|
if err = apply(ctx, d, m); err != nil {
|
||||||
err = Error{
|
err = Error{
|
||||||
description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description),
|
description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description),
|
||||||
wrapped: err,
|
wrapped: err,
|
||||||
@ -59,40 +50,40 @@ func Apply(ctx Context, migrations []Record) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func initialize(ctx Context) (err error) {
|
func initialize(ctx Context, d Dialect) (err error) {
|
||||||
if noVersionTable(ctx) {
|
if noVersionTable(ctx, d) {
|
||||||
return createVersionTable(ctx)
|
return createVersionTable(ctx, d)
|
||||||
}
|
}
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func noVersionTable(ctx Context) bool {
|
func noVersionTable(ctx Context, d Dialect) bool {
|
||||||
rows, table_check := ctx.Query(tableCheckSql)
|
rows, table_check := ctx.Query(d.TableExists(table))
|
||||||
if rows != nil {
|
if rows != nil {
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
}
|
}
|
||||||
return table_check != nil
|
return table_check != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func apply(ctx Context, r Record) (err error) {
|
func apply(ctx Context, d Dialect, r Record) (err error) {
|
||||||
if err = r.F(ctx); err == nil {
|
if err = r.F(ctx); err == nil {
|
||||||
err = incrementVersion(ctx, r.Description)
|
err = incrementVersion(ctx, d, r.Description)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createVersionTable(ctx Context) (err error) {
|
func createVersionTable(ctx Context, d Dialect) (err error) {
|
||||||
_, err = ctx.Exec(tableCreateSql)
|
_, err = ctx.Exec(d.CreateTable(table))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func incrementVersion(ctx Context, description string) (err error) {
|
func incrementVersion(ctx Context, d Dialect, description string) (err error) {
|
||||||
_, err = ctx.Exec(versionInsertSql, description, time.Now())
|
_, err = ctx.Exec(d.InsertVersion(table), description, time.Now())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func dbVersion(ctx Context) (id int64, err error) {
|
func dbVersion(ctx Context, d Dialect) (id int64, err error) {
|
||||||
row, err := ctx.Query(versionCheckSql)
|
row, err := ctx.Query(d.CheckVersion(table))
|
||||||
if row.Next() {
|
if row.Next() {
|
||||||
err = row.Scan(&id)
|
err = row.Scan(&id)
|
||||||
}
|
}
|
||||||
|
@ -27,7 +27,7 @@ func TestCreateVersionTable(t *testing.T) {
|
|||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = createVersionTable(db)
|
err = createVersionTable(db, Sqlite3())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -43,7 +43,9 @@ func TestIncrementVersion(t *testing.T) {
|
|||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = createVersionTable(db)
|
sl3 := Sqlite3()
|
||||||
|
|
||||||
|
err = createVersionTable(db, sl3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -54,7 +56,7 @@ func TestIncrementVersion(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, d := range descriptions {
|
for _, d := range descriptions {
|
||||||
err = incrementVersion(db, d)
|
err = incrementVersion(db, sl3, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@ -88,18 +90,20 @@ func TestDbVersion(t *testing.T) {
|
|||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
err = createVersionTable(db)
|
sl3 := Sqlite3()
|
||||||
|
|
||||||
|
err = createVersionTable(db, sl3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ver, err := dbVersion(db)
|
ver, err := dbVersion(db, sl3)
|
||||||
if ver != 0 || err != nil {
|
if ver != 0 || err != nil {
|
||||||
t.Fatalf("version not 0 as expected (actual %d) or err: %#v", ver, err)
|
t.Fatalf("version not 0 as expected (actual %d) or err: %#v", ver, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = incrementVersion(db, "Test 1")
|
err = incrementVersion(db, sl3, "Test 1")
|
||||||
ver, err = dbVersion(db)
|
ver, err = dbVersion(db, sl3)
|
||||||
if ver != 1 {
|
if ver != 1 {
|
||||||
t.Fatalf("version not 1 as expected (actual %d)", ver)
|
t.Fatalf("version not 1 as expected (actual %d)", ver)
|
||||||
}
|
}
|
||||||
@ -120,6 +124,8 @@ func TestApply(t *testing.T) {
|
|||||||
t.Fail()
|
t.Fail()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sl3 := Sqlite3()
|
||||||
|
|
||||||
records :=
|
records :=
|
||||||
[]Record{
|
[]Record{
|
||||||
{
|
{
|
||||||
@ -144,7 +150,7 @@ func TestApply(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = Apply(db, records)
|
err = Apply(db, sl3, records)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -160,8 +166,8 @@ func TestApply(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// reapply and make sure we dont re-run anything
|
// reapply and make sure we dont re-run anything
|
||||||
err = Apply(db, records)
|
err = Apply(db, sl3, records)
|
||||||
ver, err := dbVersion(db)
|
ver, err := dbVersion(db, sl3)
|
||||||
if ver != 2 {
|
if ver != 2 {
|
||||||
t.Fatalf("version not 2 as expected (actual %d)", ver)
|
t.Fatalf("version not 2 as expected (actual %d)", ver)
|
||||||
}
|
}
|
||||||
@ -180,12 +186,12 @@ func TestApply(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
err = Apply(db, records)
|
err = Apply(db, sl3, records)
|
||||||
|
|
||||||
if errors.Unwrap(err) != ishouldntHideUserErrors {
|
if errors.Unwrap(err) != ishouldntHideUserErrors {
|
||||||
t.Fatalf("unexpected error returned that should have been record function error: %#v", err)
|
t.Fatalf("unexpected error returned that should have been record function error: %#v", err)
|
||||||
}
|
}
|
||||||
ver, err = dbVersion(db)
|
ver, err = dbVersion(db, sl3)
|
||||||
if ver != 2 {
|
if ver != 2 {
|
||||||
t.Fatalf("version not 2 as expected (actual %d) after bad record apply", ver)
|
t.Fatalf("version not 2 as expected (actual %d) after bad record apply", ver)
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user