From a3d0263698f4ea6d84f67bcd251725e4c24987b6 Mon Sep 17 00:00:00 2001 From: Colin Henry Date: Sat, 10 Jul 2021 13:09:26 -0700 Subject: [PATCH] introduced the concept of a dialect, to provide the exact SQL needed. --- dialect.go | 33 +++++++++++++++++++++++++++++++++ doc.go | 2 +- migrate.go | 45 ++++++++++++++++++--------------------------- migrate_test.go | 30 ++++++++++++++++++------------ 4 files changed, 70 insertions(+), 40 deletions(-) create mode 100644 dialect.go diff --git a/dialect.go b/dialect.go new file mode 100644 index 0000000..469af54 --- /dev/null +++ b/dialect.go @@ -0,0 +1,33 @@ +package migrate + +type Dialect interface { + CreateTable(table string) string + TableExists(table string) string + CheckVersion(table string) string + InsertVersion(table string) string +} + +func Sqlite3() Dialect { + return sqlite3{} +} + +type sqlite3 struct{} + +func (s sqlite3) CreateTable(table string) string { + return "CREATE TABLE " + table + ` ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + description VARCHAR, + applied TIMESTAMP);` +} + +func (s sqlite3) TableExists(table string) string { + return "SELECT * FROM " + table + ";" +} + +func (s sqlite3) CheckVersion(table string) string { + return "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;" +} + +func (s sqlite3) InsertVersion(table string) string { + return "INSERT INTO " + table + "(description, applied) VALUES (?,?);" +} diff --git a/doc.go b/doc.go index 84b7878..c5406ae 100644 --- a/doc.go +++ b/doc.go @@ -1,6 +1,6 @@ package migrate // migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore) -// it is intended to keep its footprint small, requiring only an addiutional table in the database +// it is intended to keep its footprint small, requiring only an additional table in the database // there is no rollback support as you should only ever roll forward. // uses SQL99 compatible SQL only. diff --git a/migrate.go b/migrate.go index 405aa63..11a3fd1 100644 --- a/migrate.go +++ b/migrate.go @@ -8,15 +8,6 @@ import ( const table = "dbversion" -var tableCreateSql = "CREATE TABLE " + table + ` ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - description VARCHAR, - applied TIMESTAMP -);` -var tableCheckSql = "SELECT * FROM " + table + ";" -var versionCheckSql = "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;" -var versionInsertSql = "INSERT INTO " + table + "(description, applied) VALUES (?,?);" - type Error struct { description string wrapped error @@ -40,13 +31,13 @@ type Context interface { Query(query string, args ...interface{}) (*sql.Rows, error) } -func Apply(ctx Context, migrations []Record) (err error) { - if err = initialize(ctx); err == nil { +func Apply(ctx Context, d Dialect, migrations []Record) (err error) { + if err = initialize(ctx, d); err == nil { var currentVersion int64 - if currentVersion, err = dbVersion(ctx); err == nil { + if currentVersion, err = dbVersion(ctx, d); err == nil { migrations = migrations[currentVersion:] // only apply what hasnt been been applied already for i, m := range migrations { - if err = apply(ctx, m); err != nil { + if err = apply(ctx, d, m); err != nil { err = Error{ description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description), wrapped: err, @@ -59,40 +50,40 @@ func Apply(ctx Context, migrations []Record) (err error) { return } -func initialize(ctx Context) (err error) { - if noVersionTable(ctx) { - return createVersionTable(ctx) +func initialize(ctx Context, d Dialect) (err error) { + if noVersionTable(ctx, d) { + return createVersionTable(ctx, d) } - return nil + return } -func noVersionTable(ctx Context) bool { - rows, table_check := ctx.Query(tableCheckSql) +func noVersionTable(ctx Context, d Dialect) bool { + rows, table_check := ctx.Query(d.TableExists(table)) if rows != nil { defer rows.Close() } return table_check != nil } -func apply(ctx Context, r Record) (err error) { +func apply(ctx Context, d Dialect, r Record) (err error) { if err = r.F(ctx); err == nil { - err = incrementVersion(ctx, r.Description) + err = incrementVersion(ctx, d, r.Description) } return } -func createVersionTable(ctx Context) (err error) { - _, err = ctx.Exec(tableCreateSql) +func createVersionTable(ctx Context, d Dialect) (err error) { + _, err = ctx.Exec(d.CreateTable(table)) return } -func incrementVersion(ctx Context, description string) (err error) { - _, err = ctx.Exec(versionInsertSql, description, time.Now()) +func incrementVersion(ctx Context, d Dialect, description string) (err error) { + _, err = ctx.Exec(d.InsertVersion(table), description, time.Now()) return } -func dbVersion(ctx Context) (id int64, err error) { - row, err := ctx.Query(versionCheckSql) +func dbVersion(ctx Context, d Dialect) (id int64, err error) { + row, err := ctx.Query(d.CheckVersion(table)) if row.Next() { err = row.Scan(&id) } diff --git a/migrate_test.go b/migrate_test.go index e6278ce..ab1b215 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -27,7 +27,7 @@ func TestCreateVersionTable(t *testing.T) { t.Fail() } - err = createVersionTable(db) + err = createVersionTable(db, Sqlite3()) if err != nil { t.Fatal(err) } @@ -43,7 +43,9 @@ func TestIncrementVersion(t *testing.T) { t.Fail() } - err = createVersionTable(db) + sl3 := Sqlite3() + + err = createVersionTable(db, sl3) if err != nil { t.Fatal(err) } @@ -54,7 +56,7 @@ func TestIncrementVersion(t *testing.T) { } for _, d := range descriptions { - err = incrementVersion(db, d) + err = incrementVersion(db, sl3, d) if err != nil { t.Fatal(err) } @@ -88,18 +90,20 @@ func TestDbVersion(t *testing.T) { t.Fail() } - err = createVersionTable(db) + sl3 := Sqlite3() + + err = createVersionTable(db, sl3) if err != nil { t.Fatal(err) } - ver, err := dbVersion(db) + ver, err := dbVersion(db, sl3) if ver != 0 || err != nil { t.Fatalf("version not 0 as expected (actual %d) or err: %#v", ver, err) } - err = incrementVersion(db, "Test 1") - ver, err = dbVersion(db) + err = incrementVersion(db, sl3, "Test 1") + ver, err = dbVersion(db, sl3) if ver != 1 { t.Fatalf("version not 1 as expected (actual %d)", ver) } @@ -120,6 +124,8 @@ func TestApply(t *testing.T) { t.Fail() } + sl3 := Sqlite3() + records := []Record{ { @@ -144,7 +150,7 @@ func TestApply(t *testing.T) { }, } - err = Apply(db, records) + err = Apply(db, sl3, records) if err != nil { t.Fatal(err) @@ -160,8 +166,8 @@ func TestApply(t *testing.T) { } // reapply and make sure we dont re-run anything - err = Apply(db, records) - ver, err := dbVersion(db) + err = Apply(db, sl3, records) + ver, err := dbVersion(db, sl3) if ver != 2 { t.Fatalf("version not 2 as expected (actual %d)", ver) } @@ -180,12 +186,12 @@ func TestApply(t *testing.T) { }, }) - err = Apply(db, records) + err = Apply(db, sl3, records) if errors.Unwrap(err) != ishouldntHideUserErrors { t.Fatalf("unexpected error returned that should have been record function error: %#v", err) } - ver, err = dbVersion(db) + ver, err = dbVersion(db, sl3) if ver != 2 { t.Fatalf("version not 2 as expected (actual %d) after bad record apply", ver) }