From f7c677a3f56724bd4297239ba9fe12a548c6261f Mon Sep 17 00:00:00 2001 From: Colin Henry Date: Fri, 4 Jun 2021 17:57:45 -0700 Subject: [PATCH] temporary holding ground for migrate --- doc.go | 6 ++ go.mod | 5 ++ go.sum | 2 + migrate.go | 100 ++++++++++++++++++++++ migrate_test.go | 216 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 329 insertions(+) create mode 100644 doc.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 migrate.go create mode 100644 migrate_test.go diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..84b7878 --- /dev/null +++ b/doc.go @@ -0,0 +1,6 @@ +package migrate + +// migrate is a package for SQL datbase migrations in the spirit of dbstore(rsc.io/dbstore) +// it is intended to keep its footprint small, requiring only an addiutional table in the database +// there is no rollback support as you should only ever roll forward. +// uses SQL99 compatible SQL only. diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..10c8bc6 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module github.com/jchenry/tmp/migrate + +go 1.16 + +require github.com/mattn/go-sqlite3 v1.14.7 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..96ff824 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= +github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= diff --git a/migrate.go b/migrate.go new file mode 100644 index 0000000..405aa63 --- /dev/null +++ b/migrate.go @@ -0,0 +1,100 @@ +package migrate + +import ( + "database/sql" + "fmt" + "time" +) + +const table = "dbversion" + +var tableCreateSql = "CREATE TABLE " + table + ` ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + description VARCHAR, + applied TIMESTAMP +);` +var tableCheckSql = "SELECT * FROM " + table + ";" +var versionCheckSql = "SELECT id FROM " + table + " ORDER BY id DESC LIMIT 0, 1;" +var versionInsertSql = "INSERT INTO " + table + "(description, applied) VALUES (?,?);" + +type Error struct { + description string + wrapped error +} + +func (e Error) Error() string { + return fmt.Sprintf("%s: %v", e.description, e.wrapped) +} + +func (e Error) Unwrap() error { + return e.wrapped +} + +type Record struct { + Description string + F func(ctx Context) error +} + +type Context interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Query(query string, args ...interface{}) (*sql.Rows, error) +} + +func Apply(ctx Context, migrations []Record) (err error) { + if err = initialize(ctx); err == nil { + var currentVersion int64 + if currentVersion, err = dbVersion(ctx); err == nil { + migrations = migrations[currentVersion:] // only apply what hasnt been been applied already + for i, m := range migrations { + if err = apply(ctx, m); err != nil { + err = Error{ + description: fmt.Sprintf("error performing migration \"%s\"", migrations[i].Description), + wrapped: err, + } + break + } + } + } + } + return +} + +func initialize(ctx Context) (err error) { + if noVersionTable(ctx) { + return createVersionTable(ctx) + } + return nil +} + +func noVersionTable(ctx Context) bool { + rows, table_check := ctx.Query(tableCheckSql) + if rows != nil { + defer rows.Close() + } + return table_check != nil +} + +func apply(ctx Context, r Record) (err error) { + if err = r.F(ctx); err == nil { + err = incrementVersion(ctx, r.Description) + } + return +} + +func createVersionTable(ctx Context) (err error) { + _, err = ctx.Exec(tableCreateSql) + return +} + +func incrementVersion(ctx Context, description string) (err error) { + _, err = ctx.Exec(versionInsertSql, description, time.Now()) + return +} + +func dbVersion(ctx Context) (id int64, err error) { + row, err := ctx.Query(versionCheckSql) + if row.Next() { + err = row.Scan(&id) + } + return +} diff --git a/migrate_test.go b/migrate_test.go new file mode 100644 index 0000000..e6278ce --- /dev/null +++ b/migrate_test.go @@ -0,0 +1,216 @@ +package migrate + +import ( + "database/sql" + "errors" + "io/ioutil" + "os" + "strings" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestHelperFuncs(t *testing.T) { + path, db, err := createTestDB() + if err != nil { + t.Fail() + } + if err = teardownTestDB(path, db); err != nil { + t.Fail() + } +} + +func TestCreateVersionTable(t *testing.T) { + path, db, err := createTestDB() + if err != nil { + t.Fail() + } + + err = createVersionTable(db) + if err != nil { + t.Fatal(err) + } + + if err = teardownTestDB(path, db); err != nil { + t.Fail() + } +} + +func TestIncrementVersion(t *testing.T) { + path, db, err := createTestDB() + if err != nil { + t.Fail() + } + + err = createVersionTable(db) + if err != nil { + t.Fatal(err) + } + + descriptions := []string{ + "this is a test", + "this is another test", + } + + for _, d := range descriptions { + err = incrementVersion(db, d) + if err != nil { + t.Fatal(err) + } + } + + rows, err := db.Query("SELECT id, description from dbversion") + if err != nil { + t.Fatal(err) + } + var id int + var description string + for r := 1; rows.Next(); r++ { + err = rows.Scan(&id, &description) + if err != nil { + t.Fatal(err) + } + if id != r || !strings.EqualFold(description, descriptions[r-1]) { + t.Fatalf("first row does not match %d %s: %d %s", id, descriptions[r-1], r, description) + } + + } + + if err = teardownTestDB(path, db); err != nil { + t.Fail() + } +} + +func TestDbVersion(t *testing.T) { + path, db, err := createTestDB() + if err != nil { + t.Fail() + } + + err = createVersionTable(db) + if err != nil { + t.Fatal(err) + } + + ver, err := dbVersion(db) + if ver != 0 || err != nil { + t.Fatalf("version not 0 as expected (actual %d) or err: %#v", ver, err) + } + + err = incrementVersion(db, "Test 1") + ver, err = dbVersion(db) + if ver != 1 { + t.Fatalf("version not 1 as expected (actual %d)", ver) + } + if err != nil { + t.Fatalf("err on dbversion of first increment: %#v", err) + } + + // err = incrementVersion(db, d) + + if err = teardownTestDB(path, db); err != nil { + t.Fail() + } +} + +func TestApply(t *testing.T) { + path, db, err := createTestDB() + if err != nil { + t.Fail() + } + + records := + []Record{ + { + Description: "create people table", + F: func(ctx Context) (err error) { + _, err = ctx.Exec(` + CREATE TABLE people ( + given_name VARCHAR(20), + surname VARCHAR(30), + sex CHAR(1), + age SMALLINT); + `) + return + }, + }, + { + Description: "Insert a person into people", + F: func(ctx Context) (err error) { + _, err = ctx.Exec(`INSERT INTO people VALUES('Henry','Colin','M', 42)`) + return + }, + }, + } + + err = Apply(db, records) + + if err != nil { + t.Fatal(err) + } + + r := db.QueryRow("SELECT given_name FROM people") + + var given_name string + r.Scan(&given_name) + + if given_name != "Henry" { + t.Fatalf("second migration did not complete: %s != %s", given_name, "Henry") + } + + // reapply and make sure we dont re-run anything + err = Apply(db, records) + ver, err := dbVersion(db) + if ver != 2 { + t.Fatalf("version not 2 as expected (actual %d)", ver) + } + if err != nil { + t.Fatalf("err on dbversion of re-apply: %#v", err) + } + + // add bad (causes migrate.Error) case here. + + ishouldntHideUserErrors := errors.New("I should fail") + + records = append(records, Record{ + Description: "Insert a person into people", + F: func(ctx Context) (err error) { + return ishouldntHideUserErrors + }, + }) + + err = Apply(db, records) + + if errors.Unwrap(err) != ishouldntHideUserErrors { + t.Fatalf("unexpected error returned that should have been record function error: %#v", err) + } + ver, err = dbVersion(db) + if ver != 2 { + t.Fatalf("version not 2 as expected (actual %d) after bad record apply", ver) + } + if err != nil { + t.Fatalf("err on dbversion of re-apply with bad record: %#v", err) + } + + if err = teardownTestDB(path, db); err != nil { + t.Fail() + } + +} + +func createTestDB() (path string, db *sql.DB, err error) { + if f, err := ioutil.TempFile(os.TempDir(), "migrate-test-db"); err == nil { + f.Close() + if db, err := sql.Open("sqlite3", f.Name()); err == nil { + return f.Name(), db, err + } + } + return +} +func teardownTestDB(path string, db *sql.DB) (err error) { + if err = db.Close(); err == nil { + err = os.Remove(path) + } + return +}