package migration import ( "context" "crypto/sha256" "encoding/base64" "errors" "fmt" "github.com/jmoiron/sqlx" "io/fs" "log/slog" "sort" "strings" "time" ) const migrationsTable = ` CREATE TABLE IF NOT EXISTS migration ( filename varchar, hash varchar, created timestamp, should_validate boolean );` const fetchMigrationsTable = `SELECT * FROM "migration";` const createMigration = `INSERT INTO "migration" (filename, hash, created, should_validate) VALUES ($1, $2, current_timestamp, TRUE);` type Migration struct { Filename string `db:"filename"` Hash string `db:"hash"` Created time.Time `db:"created"` ShouldValidate bool `db:"should_validate"` } func InitializeMigrations(db *sqlx.DB, migrationFiles fs.FS) error { // create table if it doesn't exist already if _, err := db.Exec(migrationsTable); err != nil { return err } // fetch current migrations var entries []Migration if err := db.Select(&entries, fetchMigrationsTable); err != nil { return err } migrations := map[string]Migration{} for _, entry := range entries { migrations[entry.Filename] = entry } return validateMigrations(db, migrations, migrationFiles) } type version struct { major int minor int } func validateMigrations(db *sqlx.DB, migrations map[string]Migration, migrationFiles fs.FS) error { scripts := map[string]string{} var scriptNames []string // fetch scripts err := fs.WalkDir(migrationFiles, ".", func(path string, d fs.DirEntry, outerErr error) error { if !d.IsDir() { if bytearr, err := fs.ReadFile(migrationFiles, path); err == nil { scripts[d.Name()] = strings.TrimSpace(string(bytearr)) scriptNames = append(scriptNames, d.Name()) } } return nil }) if err != nil { return err } // sort scripts by version sort.Slice(scriptNames, func(i, j int) bool { return isLessThan(parseVersion(scriptNames[i]), parseVersion(scriptNames[j])) }) for _, name := range scriptNames { if _, exists := migrations[name]; exists { if err := validateMigration(name, migrations[name], scripts[name]); err != nil { return err } } else { if err := executeMigration(db, name, scripts[name]); err != nil { return err } } } return nil } func executeMigration(db *sqlx.DB, name string, script string) error { logger := slog.Default().With(slog.String("script", name)) logger.Info("migrations - executing") tx := db.MustBeginTx(context.Background(), nil) var err error = nil if _, e := tx.Exec(script); e != nil { err = e } if _, e := tx.Exec(createMigration, name, hash(script)); e != nil { err = e } if err != nil { logger.Error("migrations - failed executing", slog.String("err", err.Error())) tx.Rollback() } else { logger.Info("migrations - successfully executed") tx.Commit() } return err } func validateMigration(name string, migration Migration, script string) error { if !migration.ShouldValidate { return nil } calculatedHash := hash(script) if calculatedHash != migration.Hash { err := errors.New(fmt.Sprintf("migrations - mismatch in hash for %s (expected '%s', calculated '%s')", name, migration.Hash, calculatedHash)) slog.Error("migrations - failed validation", slog.String("script", name), slog.String("err", err.Error())) return err } return nil } func hash(script string) string { hash := sha256.New() hash.Write([]byte(script)) return base64.StdEncoding.EncodeToString(hash.Sum(nil)) } func isLessThan(v1 version, v2 version) bool { if v1.major > v2.major { return false } if v1.major == v2.major && v1.minor > v2.minor { return false } return true } func parseVersion(filename string) version { ver := version{} if _, err := fmt.Sscanf(filename, "v%d_%d", &ver.major, &ver.minor); err != nil { panic(err) } return ver }