Architecture Changes:
- Split dialect implementations into separate files for better organization
- Move SQLite dialect from dialect.go to sqlite.go
- Keep only the Dialect interface in dialect.go, with comprehensive documentation
- Each dialect now lives in its own file, following the single-responsibility principle

New Features:
- Add PostgreSQL dialect support (Postgres() function)
- PostgreSQL uses SERIAL PRIMARY KEY (auto-incrementing integer)
- PostgreSQL uses $1, $2 placeholders instead of ? for parameters
- PostgreSQL uses SELECT 1 for the table existence check (more efficient)
- Both dialects implement proper SQL identifier quoting for security

Testing:
- Add comprehensive dialect-specific tests in sqlite_test.go
- Add comprehensive dialect-specific tests in postgres_test.go
- Test SQL generation for all dialect methods
- Test SQL injection protection via identifier escaping
- All tests pass (8 test functions, 10 subtests)

Documentation:
- Update README with a PostgreSQL usage example
- Add a "Supported Databases" section listing SQLite and PostgreSQL
- Improve code examples with proper imports and error handling
- Document how to implement the Dialect interface for other databases

File Structure:
- dialect.go: Interface definition only (18 lines)
- sqlite.go: SQLite dialect implementation (39 lines)
- postgres.go: PostgreSQL dialect implementation (42 lines)
- sqlite_test.go: SQLite dialect tests (67 lines)
- postgres_test.go: PostgreSQL dialect tests (67 lines)

Security:
- Both dialects use quoteIdentifier() for SQL injection protection
- Identifiers are wrapped in double quotes, and any embedded double quotes are escaped
- Follows the SQL-standard quoting mechanism (an embedded double quote is escaped by doubling it: " becomes "")

This change maintains backward compatibility while adding PostgreSQL support and improving code organization for future dialect additions.
68 lines
1.8 KiB
Go
68 lines
1.8 KiB
Go
package migrate
|
|
|
|
import (
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestPostgresDialect(t *testing.T) {
|
|
pg := Postgres()
|
|
|
|
t.Run("CreateTable", func(t *testing.T) {
|
|
sql := pg.CreateTable("dbversion")
|
|
if !strings.Contains(sql, "SERIAL PRIMARY KEY") {
|
|
t.Errorf("Expected SERIAL PRIMARY KEY, got: %s", sql)
|
|
}
|
|
if !strings.Contains(sql, `"dbversion"`) {
|
|
t.Errorf("Expected quoted table name, got: %s", sql)
|
|
}
|
|
})
|
|
|
|
t.Run("TableExists", func(t *testing.T) {
|
|
sql := pg.TableExists("dbversion")
|
|
if !strings.Contains(sql, `"dbversion"`) {
|
|
t.Errorf("Expected quoted table name, got: %s", sql)
|
|
}
|
|
if !strings.Contains(sql, "SELECT 1") {
|
|
t.Errorf("Expected SELECT 1 for existence check, got: %s", sql)
|
|
}
|
|
})
|
|
|
|
t.Run("CheckVersion", func(t *testing.T) {
|
|
sql := pg.CheckVersion("dbversion")
|
|
if !strings.Contains(sql, "ORDER BY id DESC") {
|
|
t.Errorf("Expected ORDER BY id DESC, got: %s", sql)
|
|
}
|
|
if !strings.Contains(sql, "LIMIT 1") {
|
|
t.Errorf("Expected LIMIT 1, got: %s", sql)
|
|
}
|
|
})
|
|
|
|
t.Run("InsertVersion", func(t *testing.T) {
|
|
sql := pg.InsertVersion("dbversion")
|
|
// PostgreSQL uses $1, $2 placeholders
|
|
if !strings.Contains(sql, "$1") || !strings.Contains(sql, "$2") {
|
|
t.Errorf("Expected PostgreSQL placeholders ($1, $2), got: %s", sql)
|
|
}
|
|
if !strings.Contains(sql, `"dbversion"`) {
|
|
t.Errorf("Expected quoted table name, got: %s", sql)
|
|
}
|
|
})
|
|
|
|
t.Run("QuoteIdentifier", func(t *testing.T) {
|
|
pg := postgres{}
|
|
|
|
// Test normal identifier
|
|
quoted := pg.quoteIdentifier("tablename")
|
|
if quoted != `"tablename"` {
|
|
t.Errorf("Expected quoted identifier, got: %s", quoted)
|
|
}
|
|
|
|
// Test identifier with quotes (SQL injection attempt)
|
|
quoted = pg.quoteIdentifier(`table"; DROP TABLE users; --`)
|
|
if !strings.Contains(quoted, `""`) {
|
|
t.Errorf("Expected escaped quotes, got: %s", quoted)
|
|
}
|
|
})
|
|
}
|