package migration import ( "context" "crypto/sha256" "encoding/base64" "fmt" "github.com/jmoiron/sqlx" "io/fs" "log" "sort" "strings" "time" ) const migrationsTable = ` CREATE TABLE IF NOT EXISTS migration ( filename varchar, hash varchar, created timestamp, should_validate boolean );` const fetchMigrationsTable = `SELECT * FROM "migration";` const createMigration = `INSERT INTO "migration" (filename, hash, created, should_validate) VALUES ($1, $2, current_timestamp, TRUE);` type Migration struct { Filename string `db:"filename"` Hash string `db:"hash"` Created time.Time `db:"created"` ShouldValidate bool `db:"should_validate"` } func InitializeMigrations(db *sqlx.DB, migrationFiles fs.FS) error { // create table if it doesn't exist already if _, err := db.Exec(migrationsTable); err != nil { return err } // fetch current migrations var entries []Migration if err := db.Select(&entries, fetchMigrationsTable); err != nil { return err } migrations := map[string]Migration{} for _, entry := range entries { migrations[entry.Filename] = entry } return validateMigrations(db, migrations, migrationFiles) } type version struct { major int minor int } func validateMigrations(db *sqlx.DB, migrations map[string]Migration, migrationFiles fs.FS) error { scripts := map[string]string{} var scriptNames []string // fetch scripts err := fs.WalkDir(migrationFiles, ".", func(path string, d fs.DirEntry, outerErr error) error { if !d.IsDir() { if bytearr, err := fs.ReadFile(migrationFiles, path); err == nil { scripts[d.Name()] = strings.TrimSpace(string(bytearr)) scriptNames = append(scriptNames, d.Name()) } } return nil }) if err != nil { return err } // sort scripts by version sort.Slice(scriptNames, func(i, j int) bool { return isLessThan(parseVersion(scriptNames[i]), parseVersion(scriptNames[j])) }) for _, name := range scriptNames { if _, exists := migrations[name]; exists { if err := validateMigration(name, migrations[name], scripts[name]); err != nil { return err } } else { if err := executeMigration(db, name, scripts[name]); err != nil { return err } } } return nil } func executeMigration(db *sqlx.DB, name string, script string) error { log.Printf("[INFO] script='%s' | migrations - executing", name) tx := db.MustBeginTx(context.Background(), nil) var err error = nil if _, e := tx.Exec(script); e != nil { err = e } if _, e := tx.Exec(createMigration, name, hash(script)); e != nil { err = e } if err != nil { log.Printf("[ERROR] script='%s' | migrations - failed executing", name) tx.Rollback() } else { log.Printf("[INFO] script='%s' | migrations - succesfully executed", name) tx.Commit() } return err } func validateMigration(name string, migration Migration, script string) error { if !migration.ShouldValidate { return nil } calculatedHash := hash(script) if calculatedHash != migration.Hash { err := fmt.Sprintf("migrations - mismatch in hash for %s (expected '%s', calculated '%s')", name, migration.Hash, calculatedHash) log.Printf("[ERROR] script='%s' err='%s' | migrations - failed executing", script, err) return fmt.Errorf("migrations - mismatch in hashes for %s", name) } return nil } func hash(script string) string { hash := sha256.New() hash.Write([]byte(script)) return base64.StdEncoding.EncodeToString(hash.Sum(nil)) } func isLessThan(v1 version, v2 version) bool { if v1.major > v2.major { return false } if v1.major == v2.major && v1.minor > v2.minor { return false } return true } func parseVersion(filename string) version { ver := version{} if _, err := fmt.Sscanf(filename, "v%d_%d", &ver.major, &ver.minor); err != nil { panic(err) } return ver }