payment-poc/migration/migration.go

154 lines
3.7 KiB
Go

package migration
import (
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"github.com/jmoiron/sqlx"
"io/fs"
"log/slog"
"sort"
"strings"
"time"
)
// SQL statements used to keep track of which migration scripts have been applied.
const (
	// migrationsTable creates the bookkeeping table on first use; it is a
	// no-op when the table already exists.
	migrationsTable = `
CREATE TABLE IF NOT EXISTS migration (
filename varchar,
hash varchar,
created timestamp,
should_validate boolean
);`

	// fetchMigrationsTable loads every recorded migration row.
	fetchMigrationsTable = `SELECT * FROM "migration";`

	// createMigration records a freshly executed script: $1 is the filename,
	// $2 the content hash. The row is stamped with the database's current
	// time and is marked for validation.
	createMigration = `INSERT INTO "migration" (filename, hash, created, should_validate) VALUES ($1, $2, current_timestamp, TRUE);`
)
// Migration is one row of the "migration" bookkeeping table, i.e. a script
// that has already been executed against the database.
type Migration struct {
// Filename is the script's file name; it is the key used to match a row
// against the scripts found on disk.
Filename string `db:"filename"`
// Hash is the hash of the script content recorded when it was executed.
Hash string `db:"hash"`
// Created is the timestamp stored when the row was inserted.
Created time.Time `db:"created"`
// ShouldValidate controls whether the stored hash is compared against the
// script's current content; when false, validation is skipped for this row.
ShouldValidate bool `db:"should_validate"`
}
// InitializeMigrations brings the database up to date with the scripts found
// in migrationFiles.
//
// It creates the "migration" bookkeeping table when missing, loads the rows of
// already-applied scripts, and then hands over to validateMigrations, which
// checks applied scripts against their recorded hash and executes the ones
// that have not been applied yet.
//
// The returned error wraps the underlying database error (usable with
// errors.Is / errors.As) and states which step failed.
func InitializeMigrations(db *sqlx.DB, migrationFiles fs.FS) error {
	// Create the bookkeeping table if it doesn't exist already.
	if _, err := db.Exec(migrationsTable); err != nil {
		return fmt.Errorf("migrations - creating migration table: %w", err)
	}

	// Fetch the migrations that were applied on previous runs.
	var entries []Migration
	if err := db.Select(&entries, fetchMigrationsTable); err != nil {
		return fmt.Errorf("migrations - fetching applied migrations: %w", err)
	}

	// Index the rows by filename so each script on disk can be looked up directly.
	applied := make(map[string]Migration, len(entries))
	for _, entry := range entries {
		applied[entry.Filename] = entry
	}
	return validateMigrations(db, applied, migrationFiles)
}
// version is the major/minor pair parsed from a migration script's filename;
// it determines the order in which scripts are processed.
type version struct {
// major is the first number in the filename (the part after the leading "v").
major int
// minor is the second number in the filename (the part after the underscore).
minor int
}
// validateMigrations walks migrationFiles, orders the scripts it finds by the
// version encoded in their filename, and processes them in that order:
// scripts already present in migrations are validated against their recorded
// hash, all others are executed.
//
// Scripts are keyed by base name (not by path), and their content is trimmed
// of surrounding whitespace before it is hashed or executed.
//
// Processing stops at the first error. Errors reported by the directory walk
// or by reading a script are returned rather than ignored, so an unreadable
// migration can never be silently skipped while later ones are still applied.
// Note that parseVersion panics on a filename that does not carry a version.
func validateMigrations(db *sqlx.DB, migrations map[string]Migration, migrationFiles fs.FS) error {
	scripts := map[string]string{}
	var scriptNames []string

	// Collect every script together with its content.
	err := fs.WalkDir(migrationFiles, ".", func(path string, d fs.DirEntry, walkErr error) error {
		// When walkErr is set, d may be nil (failed Stat on the root), so it
		// must be checked before d is touched.
		if walkErr != nil {
			return fmt.Errorf("migrations - walking %q: %w", path, walkErr)
		}
		if d.IsDir() {
			return nil
		}
		content, err := fs.ReadFile(migrationFiles, path)
		if err != nil {
			return fmt.Errorf("migrations - reading %q: %w", path, err)
		}
		scripts[d.Name()] = strings.TrimSpace(string(content))
		scriptNames = append(scriptNames, d.Name())
		return nil
	})
	if err != nil {
		return err
	}

	// Sort scripts by version so they are applied in order.
	sort.Slice(scriptNames, func(i, j int) bool {
		return isLessThan(parseVersion(scriptNames[i]), parseVersion(scriptNames[j]))
	})

	for _, name := range scriptNames {
		var stepErr error
		if applied, exists := migrations[name]; exists {
			stepErr = validateMigration(name, applied, scripts[name])
		} else {
			stepErr = executeMigration(db, name, scripts[name])
		}
		if stepErr != nil {
			return stepErr
		}
	}
	return nil
}
// executeMigration runs script and records it in the "migration" table under
// name, both inside a single transaction.
//
// The transaction is rolled back as soon as either statement fails, and the
// error of the statement that actually failed is returned (wrapped, so
// errors.Is / errors.As still reach the driver error). A failure to begin or
// to commit the transaction is returned as an error as well. Progress and
// failures are logged on the default slog logger with the script name
// attached.
func executeMigration(db *sqlx.DB, name string, script string) error {
	logger := slog.Default().With(slog.String("script", name))
	logger.Info("migrations - executing")

	// fail logs the error once and hands it back for returning.
	fail := func(err error) error {
		logger.Error("migrations - failed executing", slog.String("err", err.Error()))
		return err
	}

	tx, err := db.BeginTxx(context.Background(), nil)
	if err != nil {
		return fail(fmt.Errorf("migrations - beginning transaction for %s: %w", name, err))
	}

	// abort rolls the transaction back and reports cause; a rollback failure
	// is logged but must not hide the original cause.
	abort := func(cause error) error {
		if rbErr := tx.Rollback(); rbErr != nil {
			logger.Error("migrations - failed rolling back", slog.String("err", rbErr.Error()))
		}
		return fail(cause)
	}

	// Stop at the first failing statement: issuing further statements on a
	// failed transaction would only replace the real error with a follow-up one.
	if _, err := tx.Exec(script); err != nil {
		return abort(fmt.Errorf("migrations - running script %s: %w", name, err))
	}
	if _, err := tx.Exec(createMigration, name, hash(script)); err != nil {
		return abort(fmt.Errorf("migrations - recording %s: %w", name, err))
	}
	if err := tx.Commit(); err != nil {
		return fail(fmt.Errorf("migrations - committing %s: %w", name, err))
	}

	logger.Info("migrations - successfully executed")
	return nil
}
// validateMigration checks that an already-applied script still has the
// content it had when it was executed.
//
// Rows whose ShouldValidate flag is false are accepted without any check.
// Otherwise the hash of script is compared with the recorded one; a
// difference is logged and returned as an error naming both hashes.
func validateMigration(name string, migration Migration, script string) error {
	if migration.ShouldValidate {
		if actual := hash(script); actual != migration.Hash {
			message := fmt.Sprintf("migrations - mismatch in hash for %s (expected '%s', calculated '%s')", name, migration.Hash, actual)
			mismatch := errors.New(message)
			slog.Error("migrations - failed validation", slog.String("script", name), slog.String("err", mismatch.Error()))
			return mismatch
		}
	}
	return nil
}
// hash returns the SHA-256 digest of script, encoded as standard
// (padded) base64.
func hash(script string) string {
	digest := sha256.Sum256([]byte(script))
	return base64.StdEncoding.EncodeToString(digest[:])
}
// isLessThan reports whether v1 sorts strictly before v2, comparing the major
// number first and the minor number on a tie.
//
// Equal versions are NOT less than each other: the function is used as a sort
// comparator, which requires a strict ordering (less(a, a) must be false).
func isLessThan(v1 version, v2 version) bool {
	if v1.major != v2.major {
		return v1.major < v2.major
	}
	return v1.minor < v2.minor
}
// parseVersion extracts the version from a migration filename of the form
// "v<major>_<minor>..."; anything following the minor number is ignored.
//
// It panics when filename does not start with such a version. The panic value
// is an error that names the offending file and wraps the parse error, so the
// bad script can be identified from the crash output.
func parseVersion(filename string) version {
	var ver version
	if _, err := fmt.Sscanf(filename, "v%d_%d", &ver.major, &ver.minor); err != nil {
		panic(fmt.Errorf("migrations - cannot parse version from filename %q (expected v<major>_<minor>): %w", filename, err))
	}
	return ver
}