introduced the concept of a dialect, to provide the exact SQL needed.

This commit is contained in:
Colin Henry 2021-07-10 13:09:26 -07:00
parent f7c677a3f5
commit a3d0263698
4 changed files with 70 additions and 40 deletions

33
dialect.go Normal file
View File

@ -0,0 +1,33 @@
package migrate
type Dialect interface {
CreateTable(table string) string
TableExists(table string) string
CheckVersion(table string) string
InsertVersion(table string) string
}
func Sqlite3() Dialect {
return sqlite3{}
}
type sqlite3 struct{}
func (s sqlite3) CreateTable(table string) string {
return "CREATE TABLE " + table + ` (
id INTEGER PRIMARY KEY AUTOINCREMENT,
description VARCHAR,
applied TIMESTAMP);`
}
func (s sqlite3) TableExists(table string) string {
return "SELECT * FROM " + table + ";"
}
func (s sqlite3) CheckVersion(table string) string {
return "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
}
func (s sqlite3) InsertVersion(table string) string {
return "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
}

2
doc.go
View File

@ -1,6 +1,6 @@
package migrate
// migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore)
// it is intended to keep its footprint small, requiring only an addiutional table in the database
// it is intended to keep its footprint small, requiring only an additional table in the database
// there is no rollback support as you should only ever roll forward.
// uses SQL99 compatible SQL only.

View File

@ -8,15 +8,6 @@ import (
const table = "dbversion"
var tableCreateSql = "CREATE TABLE " + table + ` (
id INTEGER PRIMARY KEY AUTOINCREMENT,
description VARCHAR,
applied TIMESTAMP
);`
var tableCheckSql = "SELECT * FROM " + table + ";"
var versionCheckSql = "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;"
var versionInsertSql = "INSERT INTO " + table + "(description, applied) VALUES (?,?);"
type Error struct {
description string
wrapped error
@ -40,13 +31,13 @@ type Context interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
}
func Apply(ctx Context, migrations []Record) (err error) {
if err = initialize(ctx); err == nil {
func Apply(ctx Context, d Dialect, migrations []Record) (err error) {
if err = initialize(ctx, d); err == nil {
var currentVersion int64
if currentVersion, err = dbVersion(ctx); err == nil {
if currentVersion, err = dbVersion(ctx, d); err == nil {
migrations = migrations[currentVersion:] // only apply what hasnt been been applied already
for i, m := range migrations {
if err = apply(ctx, m); err != nil {
if err = apply(ctx, d, m); err != nil {
err = Error{
description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description),
wrapped: err,
@ -59,40 +50,40 @@ func Apply(ctx Context, migrations []Record) (err error) {
return
}
func initialize(ctx Context) (err error) {
if noVersionTable(ctx) {
return createVersionTable(ctx)
func initialize(ctx Context, d Dialect) (err error) {
if noVersionTable(ctx, d) {
return createVersionTable(ctx, d)
}
return nil
return
}
func noVersionTable(ctx Context) bool {
rows, table_check := ctx.Query(tableCheckSql)
func noVersionTable(ctx Context, d Dialect) bool {
rows, table_check := ctx.Query(d.TableExists(table))
if rows != nil {
defer rows.Close()
}
return table_check != nil
}
func apply(ctx Context, r Record) (err error) {
func apply(ctx Context, d Dialect, r Record) (err error) {
if err = r.F(ctx); err == nil {
err = incrementVersion(ctx, r.Description)
err = incrementVersion(ctx, d, r.Description)
}
return
}
func createVersionTable(ctx Context) (err error) {
_, err = ctx.Exec(tableCreateSql)
func createVersionTable(ctx Context, d Dialect) (err error) {
_, err = ctx.Exec(d.CreateTable(table))
return
}
func incrementVersion(ctx Context, description string) (err error) {
_, err = ctx.Exec(versionInsertSql, description, time.Now())
func incrementVersion(ctx Context, d Dialect, description string) (err error) {
_, err = ctx.Exec(d.InsertVersion(table), description, time.Now())
return
}
func dbVersion(ctx Context) (id int64, err error) {
row, err := ctx.Query(versionCheckSql)
func dbVersion(ctx Context, d Dialect) (id int64, err error) {
row, err := ctx.Query(d.CheckVersion(table))
if row.Next() {
err = row.Scan(&id)
}

View File

@ -27,7 +27,7 @@ func TestCreateVersionTable(t *testing.T) {
t.Fail()
}
err = createVersionTable(db)
err = createVersionTable(db, Sqlite3())
if err != nil {
t.Fatal(err)
}
@ -43,7 +43,9 @@ func TestIncrementVersion(t *testing.T) {
t.Fail()
}
err = createVersionTable(db)
sl3 := Sqlite3()
err = createVersionTable(db, sl3)
if err != nil {
t.Fatal(err)
}
@ -54,7 +56,7 @@ func TestIncrementVersion(t *testing.T) {
}
for _, d := range descriptions {
err = incrementVersion(db, d)
err = incrementVersion(db, sl3, d)
if err != nil {
t.Fatal(err)
}
@ -88,18 +90,20 @@ func TestDbVersion(t *testing.T) {
t.Fail()
}
err = createVersionTable(db)
sl3 := Sqlite3()
err = createVersionTable(db, sl3)
if err != nil {
t.Fatal(err)
}
ver, err := dbVersion(db)
ver, err := dbVersion(db, sl3)
if ver != 0 || err != nil {
t.Fatalf("version not 0 as expected (actual %d) or err: %#v", ver, err)
}
err = incrementVersion(db, "Test 1")
ver, err = dbVersion(db)
err = incrementVersion(db, sl3, "Test 1")
ver, err = dbVersion(db, sl3)
if ver != 1 {
t.Fatalf("version not 1 as expected (actual %d)", ver)
}
@ -120,6 +124,8 @@ func TestApply(t *testing.T) {
t.Fail()
}
sl3 := Sqlite3()
records :=
[]Record{
{
@ -144,7 +150,7 @@ func TestApply(t *testing.T) {
},
}
err = Apply(db, records)
err = Apply(db, sl3, records)
if err != nil {
t.Fatal(err)
@ -160,8 +166,8 @@ func TestApply(t *testing.T) {
}
// reapply and make sure we dont re-run anything
err = Apply(db, records)
ver, err := dbVersion(db)
err = Apply(db, sl3, records)
ver, err := dbVersion(db, sl3)
if ver != 2 {
t.Fatalf("version not 2 as expected (actual %d)", ver)
}
@ -180,12 +186,12 @@ func TestApply(t *testing.T) {
},
})
err = Apply(db, records)
err = Apply(db, sl3, records)
if errors.Unwrap(err) != ishouldntHideUserErrors {
t.Fatalf("unexpected error returned that should have been record function error: %#v", err)
}
ver, err = dbVersion(db)
ver, err = dbVersion(db, sl3)
if ver != 2 {
t.Fatalf("version not 2 as expected (actual %d) after bad record apply", ver)
}