introduced the concept of a dialect, to provide the exact SQL needed.

This commit is contained in:
Colin Henry 2021-07-10 13:09:26 -07:00
parent f7c677a3f5
commit a3d0263698
4 changed files with 70 additions and 40 deletions

33
dialect.go Normal file
View File

@ -0,0 +1,33 @@
package migrate
type Dialect interface {
CreateTable(table string) string
TableExists(table string) string
CheckVersion(table string) string
InsertVersion(table string) string
}
func Sqlite3() Dialect {
return sqlite3{}
}
type sqlite3 struct{}
func (s sqlite3) CreateTable(table string) string {
return "CREATE TABLE " + table + ` (
id INTEGER PRIMARY KEY AUTOINCREMENT,
description VARCHAR,
applied TIMESTAMP);`
}
func (s sqlite3) TableExists(table string) string {
return "SELECT * FROM " + table + ";"
}
func (s sqlite3) CheckVersion(table string) string {
return "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
}
func (s sqlite3) InsertVersion(table string) string {
return "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
}

2
doc.go
View File

@ -1,6 +1,6 @@
package migrate package migrate
// migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore) // migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore)
// it is intended to keep its footprint small, requiring only an addiutional table in the database // it is intended to keep its footprint small, requiring only an additional table in the database
// there is no rollback support as you should only ever roll forward. // there is no rollback support as you should only ever roll forward.
// uses SQL99 compatible SQL only. // uses SQL99 compatible SQL only.

View File

@ -8,15 +8,6 @@ import (
const table = "dbversion" const table = "dbversion"
var tableCreateSql = "CREATE TABLE " + table + ` (
id INTEGER PRIMARY KEY AUTOINCREMENT,
description VARCHAR,
applied TIMESTAMP
);`
var tableCheckSql = "SELECT * FROM " + table + ";"
var versionCheckSql = "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
var versionInsertSql = "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
type Error struct { type Error struct {
description string description string
wrapped error wrapped error
@ -40,13 +31,13 @@ type Context interface {
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)
} }
func Apply(ctx Context, migrations []Record) (err error) { func Apply(ctx Context, d Dialect, migrations []Record) (err error) {
if err = initialize(ctx); err == nil { if err = initialize(ctx, d); err == nil {
var currentVersion int64 var currentVersion int64
if currentVersion, err = dbVersion(ctx); err == nil { if currentVersion, err = dbVersion(ctx, d); err == nil {
migrations = migrations[currentVersion:] // only apply what hasnt been been applied already migrations = migrations[currentVersion:] // only apply what hasnt been been applied already
for i, m := range migrations { for i, m := range migrations {
if err = apply(ctx, m); err != nil { if err = apply(ctx, d, m); err != nil {
err = Error{ err = Error{
description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description), description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description),
wrapped: err, wrapped: err,
@ -59,40 +50,40 @@ func Apply(ctx Context, migrations []Record) (err error) {
return return
} }
func initialize(ctx Context) (err error) { func initialize(ctx Context, d Dialect) (err error) {
if noVersionTable(ctx) { if noVersionTable(ctx, d) {
return createVersionTable(ctx) return createVersionTable(ctx, d)
} }
return nil return
} }
func noVersionTable(ctx Context) bool { func noVersionTable(ctx Context, d Dialect) bool {
rows, table_check := ctx.Query(tableCheckSql) rows, table_check := ctx.Query(d.TableExists(table))
if rows != nil { if rows != nil {
defer rows.Close() defer rows.Close()
} }
return table_check != nil return table_check != nil
} }
func apply(ctx Context, r Record) (err error) { func apply(ctx Context, d Dialect, r Record) (err error) {
if err = r.F(ctx); err == nil { if err = r.F(ctx); err == nil {
err = incrementVersion(ctx, r.Description) err = incrementVersion(ctx, d, r.Description)
} }
return return
} }
func createVersionTable(ctx Context) (err error) { func createVersionTable(ctx Context, d Dialect) (err error) {
_, err = ctx.Exec(tableCreateSql) _, err = ctx.Exec(d.CreateTable(table))
return return
} }
func incrementVersion(ctx Context, description string) (err error) { func incrementVersion(ctx Context, d Dialect, description string) (err error) {
_, err = ctx.Exec(versionInsertSql, description, time.Now()) _, err = ctx.Exec(d.InsertVersion(table), description, time.Now())
return return
} }
func dbVersion(ctx Context) (id int64, err error) { func dbVersion(ctx Context, d Dialect) (id int64, err error) {
row, err := ctx.Query(versionCheckSql) row, err := ctx.Query(d.CheckVersion(table))
if row.Next() { if row.Next() {
err = row.Scan(&id) err = row.Scan(&id)
} }

View File

@ -27,7 +27,7 @@ func TestCreateVersionTable(t *testing.T) {
t.Fail() t.Fail()
} }
err = createVersionTable(db) err = createVersionTable(db, Sqlite3())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -43,7 +43,9 @@ func TestIncrementVersion(t *testing.T) {
t.Fail() t.Fail()
} }
err = createVersionTable(db) sl3 := Sqlite3()
err = createVersionTable(db, sl3)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -54,7 +56,7 @@ func TestIncrementVersion(t *testing.T) {
} }
for _, d := range descriptions { for _, d := range descriptions {
err = incrementVersion(db, d) err = incrementVersion(db, sl3, d)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -88,18 +90,20 @@ func TestDbVersion(t *testing.T) {
t.Fail() t.Fail()
} }
err = createVersionTable(db) sl3 := Sqlite3()
err = createVersionTable(db, sl3)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
ver, err := dbVersion(db) ver, err := dbVersion(db, sl3)
if ver != 0 || err != nil { if ver != 0 || err != nil {
t.Fatalf("version not 0 as expected (actual %d) or err: %#v", ver, err) t.Fatalf("version not 0 as expected (actual %d) or err: %#v", ver, err)
} }
err = incrementVersion(db, "Test 1") err = incrementVersion(db, sl3, "Test 1")
ver, err = dbVersion(db) ver, err = dbVersion(db, sl3)
if ver != 1 { if ver != 1 {
t.Fatalf("version not 1 as expected (actual %d)", ver) t.Fatalf("version not 1 as expected (actual %d)", ver)
} }
@ -120,6 +124,8 @@ func TestApply(t *testing.T) {
t.Fail() t.Fail()
} }
sl3 := Sqlite3()
records := records :=
[]Record{ []Record{
{ {
@ -144,7 +150,7 @@ func TestApply(t *testing.T) {
}, },
} }
err = Apply(db, records) err = Apply(db, sl3, records)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -160,8 +166,8 @@ func TestApply(t *testing.T) {
} }
// reapply and make sure we dont re-run anything // reapply and make sure we dont re-run anything
err = Apply(db, records) err = Apply(db, sl3, records)
ver, err := dbVersion(db) ver, err := dbVersion(db, sl3)
if ver != 2 { if ver != 2 {
t.Fatalf("version not 2 as expected (actual %d)", ver) t.Fatalf("version not 2 as expected (actual %d)", ver)
} }
@ -180,12 +186,12 @@ func TestApply(t *testing.T) {
}, },
}) })
err = Apply(db, records) err = Apply(db, sl3, records)
if errors.Unwrap(err) != ishouldntHideUserErrors { if errors.Unwrap(err) != ishouldntHideUserErrors {
t.Fatalf("unexpected error returned that should have been record function error: %#v", err) t.Fatalf("unexpected error returned that should have been record function error: %#v", err)
} }
ver, err = dbVersion(db) ver, err = dbVersion(db, sl3)
if ver != 2 { if ver != 2 {
t.Fatalf("version not 2 as expected (actual %d) after bad record apply", ver) t.Fatalf("version not 2 as expected (actual %d) after bad record apply", ver)
} }