151 lines
3.7 KiB
Go
151 lines
3.7 KiB
Go
package migration
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"github.com/jmoiron/sqlx"
|
|
"io/fs"
|
|
"log"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
const (
	// migrationsTable creates the bookkeeping table that records which
	// migration scripts have been applied and the hash of each one.
	migrationsTable = `
CREATE TABLE IF NOT EXISTS migration (
	filename varchar,
	hash varchar,
	created timestamp,
	should_validate boolean
);`

	// fetchMigrationsTable loads every recorded migration. The columns are
	// listed explicitly rather than using SELECT * because sqlx refuses to
	// scan a result column that has no matching `db` tag in Migration, so a
	// future extra column would otherwise break start-up.
	fetchMigrationsTable = `SELECT filename, hash, created, should_validate FROM "migration";`

	// createMigration records a freshly applied script ($1 = filename,
	// $2 = hash). The creation time comes from the database clock, and
	// validation is always enabled for new rows.
	createMigration = `INSERT INTO "migration" (filename, hash, created, should_validate) VALUES ($1, $2, current_timestamp, TRUE);`
)
|
|
|
|
// Migration is one row of the "migration" bookkeeping table: a script that
// has already been applied, together with the hash it had at that time.
type Migration struct {
	// Filename is the script's base name, following the "v<major>_<minor>"
	// naming scheme used to order scripts.
	Filename string `db:"filename"`
	// Hash is the base64-encoded SHA-256 of the trimmed script contents.
	Hash string `db:"hash"`
	// Created is when the row was inserted (set by the database on insert).
	Created time.Time `db:"created"`
	// ShouldValidate controls whether the stored Hash is compared against
	// the current script contents on every start-up.
	ShouldValidate bool `db:"should_validate"`
}
|
|
|
|
func InitializeMigrations(db *sqlx.DB, migrationFiles fs.FS) error {
|
|
// create table if it doesn't exist already
|
|
if _, err := db.Exec(migrationsTable); err != nil {
|
|
return err
|
|
}
|
|
|
|
// fetch current migrations
|
|
var entries []Migration
|
|
if err := db.Select(&entries, fetchMigrationsTable); err != nil {
|
|
return err
|
|
}
|
|
|
|
migrations := map[string]Migration{}
|
|
for _, entry := range entries {
|
|
migrations[entry.Filename] = entry
|
|
}
|
|
return validateMigrations(db, migrations, migrationFiles)
|
|
}
|
|
|
|
// version is the numeric "<major>_<minor>" prefix parsed from a migration
// script's file name; scripts are applied in ascending version order.
type version struct {
	major, minor int
}
|
|
|
|
func validateMigrations(db *sqlx.DB, migrations map[string]Migration, migrationFiles fs.FS) error {
|
|
scripts := map[string]string{}
|
|
var scriptNames []string
|
|
|
|
// fetch scripts
|
|
err := fs.WalkDir(migrationFiles, ".", func(path string, d fs.DirEntry, outerErr error) error {
|
|
if !d.IsDir() {
|
|
if bytearr, err := fs.ReadFile(migrationFiles, path); err == nil {
|
|
scripts[d.Name()] = strings.TrimSpace(string(bytearr))
|
|
scriptNames = append(scriptNames, d.Name())
|
|
}
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// sort scripts by version
|
|
sort.Slice(scriptNames, func(i, j int) bool {
|
|
return isLessThan(parseVersion(scriptNames[i]), parseVersion(scriptNames[j]))
|
|
})
|
|
|
|
for _, name := range scriptNames {
|
|
if _, exists := migrations[name]; exists {
|
|
if err := validateMigration(name, migrations[name], scripts[name]); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := executeMigration(db, name, scripts[name]); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func executeMigration(db *sqlx.DB, name string, script string) error {
|
|
log.Printf("[INFO] script='%s' | migrations - executing", name)
|
|
tx := db.MustBeginTx(context.Background(), nil)
|
|
var err error = nil
|
|
if _, e := tx.Exec(script); e != nil {
|
|
err = e
|
|
} else if _, e := tx.Exec(createMigration, name, hash(script)); e != nil {
|
|
err = e
|
|
}
|
|
if err != nil {
|
|
log.Printf("[ERROR] script='%s' | migrations - failed executing", name)
|
|
tx.Rollback()
|
|
} else {
|
|
log.Printf("[INFO] script='%s' | migrations - succesfully executed", name)
|
|
tx.Commit()
|
|
}
|
|
return err
|
|
}
|
|
|
|
func validateMigration(name string, migration Migration, script string) error {
|
|
if !migration.ShouldValidate {
|
|
return nil
|
|
}
|
|
|
|
calculatedHash := hash(script)
|
|
|
|
if calculatedHash != migration.Hash {
|
|
err := fmt.Sprintf("migrations - mismatch in hash for %s (expected '%s', calculated '%s')", name, migration.Hash, calculatedHash)
|
|
log.Printf("[ERROR] script='%s' err='%s' | migrations - failed executing", script, err)
|
|
return fmt.Errorf("migrations - mismatch in hashes for %s", name)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// hash returns the SHA-256 digest of script, encoded as padded standard
// base64. It is the fingerprint stored per migration to detect later edits.
func hash(script string) string {
	digest := sha256.Sum256([]byte(script))
	return base64.StdEncoding.EncodeToString(digest[:])
}
|
|
|
|
func isLessThan(v1 version, v2 version) bool {
|
|
if v1.major > v2.major {
|
|
return false
|
|
}
|
|
if v1.major == v2.major && v1.minor > v2.minor {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func parseVersion(filename string) version {
|
|
ver := version{}
|
|
if _, err := fmt.Sscanf(filename, "v%d_%d", &ver.major, &ver.minor); err != nil {
|
|
panic(err)
|
|
}
|
|
return ver
|
|
}
|