From 0603b89613b3070c72bf17ff23f6da5f7a8e2c6f Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Thu, 29 Jan 2026 15:59:05 -0700 Subject: [PATCH] feat(cmd/sql-migrate/v2): store migrations in db, nix batches --- cmd/sql-migrate/README.md | 72 +-- cmd/sql-migrate/main.go | 1184 ++++++++++++++++++++++++++----------- 2 files changed, 882 insertions(+), 374 deletions(-) diff --git a/cmd/sql-migrate/README.md b/cmd/sql-migrate/README.md index d25e143..7280b5a 100644 --- a/cmd/sql-migrate/README.md +++ b/cmd/sql-migrate/README.md @@ -25,9 +25,11 @@ Lexicographically-sortable files in the format `_. -sql-migrate status -sql-migrate up -sql-migrate down -sql-migrate list +sql-migrate -d ./sql/migrations/ init --sql-command --migrations-log ./sql/migrations.log +sql-migrate -d ./sql/migrations/ create +sql-migrate -d ./sql/migrations/ status +sql-migrate -d ./sql/migrations/ up 99 +sql-migrate -d ./sql/migrations/ down 1 +sql-migrate -d ./sql/migrations/ list ``` See `sql-migrate help` for details. @@ -94,31 +93,36 @@ See `sql-migrate help` for details. ```text COMMANDS - init - inits sql dir and migration file, adding or updating the - default command + init - creates migrations directory, initial migration, log file, + and query for migrations create - creates a new, canonically-named up/down file pair in the - migrations directory + migrations directory, with corresponding insert status - shows the same output as if processing a forward-migration - for the most recent batch - up - processes the first 'up' migration file missing from the - migration state - down - rolls back the latest entry of the latest migration batch - (the whole batch if just one) + up [n] - create a script to run pending migrations (ALL by default) + down [n] - create a script to roll back migrations (ONE by default) list - lists migrations OPTIONS -d default: ./sql/migrations/ - -f default: ./sql/migrations.log + --help show command-specific help + +NOTES + Migrations files are in the following format: + -_..sql + 2020-01-01-1000_init-app.up.sql + + The initial migration file contains configuration variables: + -- migrations_log: ./sql/migrations.log + -- sql_command: psql "$PG_URL" -v ON_ERROR_STOP=on --no-align --file %s + + The log is generated on each migration file contains a list of all migrations: + 0001-01-01-001000_migrations.up.sql + 2020-12-31-001000_init-app.up.sql + 2020-12-31-001100_add-customer-tables.up.sql + 2020-12-31-002000_add-ALL-THE-TABLES.up.sql + + The 'create' generates an up/down pair of files using the current date and + the number 1000. If either file exists, the number is incremented by 1000 and + tried again. - The migration state file contains the client command template (defaults to - 'psql "$PG_URL" < %s'), followed by a list of batches identified by a batch - number comment and a list of migration file basenames and optional user - comments, such as: - # command: psql "$PG_URL" < %s - # batch: 1 - 2020-01-01-1000_init.up.sql # does a lot - 2020-01-01-1100_add-customer-tables.up.sql - # batch: 2 - # We did id! Finally! - 2020-01-01-2000_add-ALL-THE-TABLES.up.sql ``` diff --git a/cmd/sql-migrate/main.go b/cmd/sql-migrate/main.go index 5b33b6c..4d2771f 100644 --- a/cmd/sql-migrate/main.go +++ b/cmd/sql-migrate/main.go @@ -12,6 +12,9 @@ package main import ( + "bufio" + "crypto/rand" + "encoding/hex" "flag" "fmt" "log" @@ -26,88 +29,641 @@ import ( ) const ( - defaultMigrationDir = "./sql/migrations/" - defaultMigrationLog = "./sql/migrations.log" - defaultCommand = `psql "$PG_URL" < %s` + version = "2.0.0" ) +const ( + defaultMigrationDir = "./sql/migrations/" + defaultLogPath = "../migrations.log" + sqlCommandPSQL = `psql "$PG_URL" -v ON_ERROR_STOP=on -A -t --file %s` + sqlCommandMariaDB = `mariadb --defaults-extra-file="$MY_CNF" -s -N --raw < %s` + sqlCommandMySQL = `mysql --defaults-extra-file="$MY_CNF" -s -N --raw < %s` + LOG_QUERY_NAME = "_migrations.sql" + M_MIGRATOR_NAME = "0001-01-01-01000_init-migrations" + M_MIGRATOR_UP_NAME = "0001-01-01-01000_init-migrations.up.sql" + M_MIGRATOR_DOWN_NAME = "0001-01-01-01000_init-migrations.down.sql" + defaultMigratorUpTmpl = `-- Config variables for sql-migrate (do not delete) +-- sql_command: %s +-- migrations_log: %s +-- + +CREATE TABLE IF NOT EXISTS _migrations ( + id CHAR(8) PRIMARY KEY DEFAULT encode(gen_random_bytes(4), 'hex'), + name VARCHAR(80) NULL UNIQUE, + applied_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP +); + +INSERT INTO _migrations (id, name) VALUES ('00000001', '0001-01-01-01000_init-migrations'); +` + defaultMigratorDown = `DELETE FROM _migrations WHERE id = '00000001'; + +DROP TABLE IF EXISTS _migrations; +` + LOG_MIGRATIONS_QUERY = `-- note: --no-align must be passed via CLI to avoid extraneous output +SELECT name FROM _migrations ORDER BY name; +` + shHeader = `#/bin/sh +set -e +set -u + +if test -s ./.env; then + . ./.env +fi +` +) + +const helpText = ` +sql-migrate v` + version + ` - a feature-branch-friendly SQL migrator + +USAGE + sql-migrate [-d sqldir] [args] + +EXAMPLE + sql-migrate -d ./sql/migrations/ init --sql-command + sql-migrate -d ./sql/migrations/ create + sql-migrate -d ./sql/migrations/ status + sql-migrate -d ./sql/migrations/ up 99 + sql-migrate -d ./sql/migrations/ down 1 + sql-migrate -d ./sql/migrations/ list + +COMMANDS + init - creates migrations directory, initial migration, log file, + and query for migrations + create - creates a new, canonically-named up/down file pair in the + migrations directory, with corresponding insert + status - shows the same output as if processing a forward-migration + up [n] - create a script to run pending migrations (ALL by default) + down [n] - create a script to roll back migrations (ONE by default) + list - lists migrations + +OPTIONS + -d default: ./sql/migrations/ + --help show command-specific help + +NOTES + Migrations files are in the following format: + -_..sql + 2020-01-01-1000_init-app.up.sql + + The initial migration file contains configuration variables: + -- migrations_log: ./sql/migrations.log + -- sql_command: psql "$PG_URL" -v ON_ERROR_STOP=on --no-align --file %s + + The log is generated on each migration file contains a list of all migrations: + 0001-01-01-001000_migrations.up.sql + 2020-12-31-001000_init-app.up.sql + 2020-12-31-001100_add-customer-tables.up.sql + 2020-12-31-002000_add-ALL-THE-TABLES.up.sql + + The 'create' generates an up/down pair of files using the current date and + the number 1000. If either file exists, the number is incremented by 1000 and + tried again. +` + var ( nonWordRe = regexp.MustCompile(`\W+`) - commandStartRe = regexp.MustCompile(`^#\s*command:\s*`) - batchStartRe = regexp.MustCompile(`^#\s*batch:\s*`) commentStartRe = regexp.MustCompile(`(^|\s+)#.*`) ) type State struct { - Date time.Time - Command string - Current int - Lines []string - Migrated []string - SqlDir string - LogFile string + Date time.Time + SQLCommand string + Lines []string + Migrated []string + MigrationsDir string + LogPath string } -func parseLog(text string, date time.Time) *State { - state := &State{Date: date, Command: "", Current: 0, Lines: []string{}, Migrated: []string{}} - text = strings.TrimSpace(text) - if text == "" { - state.Command = defaultCommand - return state +type MainConfig struct { + migrationsDir string + logPath string + sqlCommand string +} + +func main() { + var cfg MainConfig + var date = time.Now() + + if len(os.Args) < 2 { + //nolint + fmt.Printf("%s\n", helpText) + os.Exit(0) } - state.Lines = strings.Split(text, "\n") - batchCount := 0 - for i := range state.Lines { - line := strings.TrimSpace(state.Lines[i]) - if commandStartRe.MatchString(line) { - if state.Command != "" { - log.Printf(" ignoring duplicate '%s'", line) - } else { - state.Command = commandStartRe.ReplaceAllString(line, "") - } + + switch os.Args[1] { + case "help", "--help", + "version", "--version", "-V": + fmt.Printf("%s\n", helpText) + os.Exit(0) + default: + // do nothing + } + + mainArgs := os.Args[1:] + fsMain := flag.NewFlagSet("", flag.ExitOnError) + fsMain.StringVar(&cfg.migrationsDir, "d", defaultMigrationDir, "directory for migrations (where 0001-01-01_init-migrations.up.sql will be added)") + if err := fsMain.Parse(mainArgs); err != nil { + os.Exit(2) + } + + var subcmd string + // note: Args() includes any flags after the first non-flag arg + // sql-migrate -d ./migs/ init --migrations-log ./migrations.log + // => init -f ./migrations.log + subArgs := fsMain.Args() + if len(subArgs) > 0 { + subcmd = subArgs[0] + subArgs = subArgs[1:] + } + + var fsSub *flag.FlagSet + switch subcmd { + case "init": + fsSub = flag.NewFlagSet("init", flag.ExitOnError) + fsSub.StringVar(&cfg.logPath, "migrations-log", "", fmt.Sprintf("migration log file (default: %s) relative to and saved in %s", defaultLogPath, M_MIGRATOR_NAME)) + fsSub.StringVar(&cfg.sqlCommand, "sql-command", sqlCommandPSQL, "construct scripts with this to execute SQL files: 'psql', 'mysql', 'mariadb', or custom arguments") + case "create", "up", "down", "status", "list": + fsSub = flag.NewFlagSet(subcmd, flag.ExitOnError) + default: + log.Printf("unknown command %s", subcmd) + fmt.Printf("%s\n", helpText) + os.Exit(1) + } + if err := fsSub.Parse(subArgs); err != nil { + os.Exit(2) + } + leafArgs := fsSub.Args() + + switch cfg.sqlCommand { + case "", "posgres", "posgresql", "pg", "psql", "plpgsql": + cfg.sqlCommand = sqlCommandPSQL + case "mariadb": + cfg.sqlCommand = sqlCommandMariaDB + case "mysql", "my": + cfg.sqlCommand = sqlCommandMySQL + } + + if !strings.HasSuffix(cfg.migrationsDir, "/") { + cfg.migrationsDir += "/" + } + + if subcmd == "init" { + mustInit(&cfg) + } + + entries, err := os.ReadDir(cfg.migrationsDir) + if err != nil { + if os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "Error: missing migrations directory. Run 'sql-migrate -d %q init' to create it.\n", cfg.migrationsDir) + os.Exit(1) } - if batchStartRe.MatchString(line) { - parts := strings.SplitN(line, ":", 2) - if len(parts) < 2 { - continue - } - n, err := strconv.Atoi(strings.TrimSpace(parts[1])) - if err != nil || n <= 0 { - log.Printf(" invalid '%s'", line) - n = -1 - } - batchCount++ - if n > state.Current { - state.Current = n - } - if batchCount > state.Current { - state.Current = batchCount - } + fmt.Fprintf(os.Stderr, "Error: couldn't list migrations directory: %v\n", err) + os.Exit(1) + } + + ups, downs := migrationsList(cfg.migrationsDir, entries) + if !slices.Contains(ups, M_MIGRATOR_NAME) { + fmt.Fprintf(os.Stderr, "Error: missing initial migration. Run 'sql-migrate -d %q init' to create %q.\n", cfg.migrationsDir, M_MIGRATOR_UP_NAME) + os.Exit(1) + } + if !slices.Contains(downs, M_MIGRATOR_NAME) { + fmt.Fprintf(os.Stderr, "Error: missing initial migration. Run 'sql-migrate -d %q init' to create %q.\n", cfg.migrationsDir, M_MIGRATOR_DOWN_NAME) + os.Exit(1) + } + + logQueryPath := filepath.Join(cfg.migrationsDir, LOG_QUERY_NAME) + if !fileExists(logQueryPath) { + fmt.Fprintf(os.Stderr, "Error: missing %q. Run 'sql-migrate -d %q init' to create it.\n", logQueryPath, cfg.migrationsDir) + os.Exit(1) + } + + mMigratorUpPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_UP_NAME) + // mMigratorDownPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_DOWN_NAME) + + state := State{ + Date: date, + MigrationsDir: cfg.migrationsDir, + } + state.SQLCommand, state.LogPath, err = extractVars(mMigratorUpPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: couldn't read config from initial migration: %v\n", err) + os.Exit(1) + } + + logText, err := os.ReadFile(state.LogPath) + if err != nil { + if !os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "Error: couldn't read migrations log %q: %v\n", state.LogPath, err) + os.Exit(1) } - migration := commentStartRe.ReplaceAllString(line, "") - migration = strings.TrimSpace(migration) - if migration != "" { - state.Migrated = append(state.Migrated, migration) + + if err := migrationsLogInit(&state, subcmd); err != nil { + fmt.Fprintf(os.Stderr, "Error: couldn't create log file directory: %v\n", err) + os.Exit(1) } - state.Lines[i] = line } - if state.Command == "" { - state.Command = defaultCommand + + if err := state.parseAndFixupBatches(string(logText)); err != nil { + fmt.Fprintf(os.Stderr, "Error: couldn't read migrations log or fixup batches: %v\n", err) + os.Exit(1) } - if !strings.Contains(state.Command, "%s") { - state.Command += " %s" + + switch subcmd { + case "init": + break + case "create": + if len(leafArgs) == 0 { + log.Fatal("create requires a description") + } + desc := strings.Join(leafArgs, " ") + desc = nonWordRe.ReplaceAllString(desc, " ") + desc = strings.TrimSpace(desc) + desc = nonWordRe.ReplaceAllString(desc, "-") + desc = strings.ToLower(desc) + err = create(&state, desc) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: couldn't create migration: %v\n", err) + os.Exit(1) + } + case "status": + if len(leafArgs) > 0 && subcmd != "create" { + fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " ")) + os.Exit(1) + } + err = status(&state, ups) + if err != nil { + log.Fatal(err) + } + case "list": + if len(leafArgs) > 0 && subcmd != "create" { + fmt.Fprintf(os.Stderr, "Error: unexpected args: %s\n", strings.Join(leafArgs, " ")) + os.Exit(1) + } + fmt.Println("Ups:") + if len(ups) == 0 { + fmt.Println(" (none)") + } + for _, u := range ups { + fmt.Println(" ", u) + } + fmt.Println("") + fmt.Println("Downs:") + if len(downs) == 0 { + fmt.Println(" (none)") + } + for _, d := range downs { + fmt.Println(" ", d) + } + case "up": + var upN int + switch len(leafArgs) { + case 0: + // ignore + case 1: + upN, err = strconv.Atoi(leafArgs[0]) + if err != nil || upN < 0 { + fmt.Fprintf(os.Stderr, "Error: %s is not a positive number\n", leafArgs[0]) + os.Exit(1) + } + default: + fmt.Fprintf(os.Stderr, "Error: unrecognized arguments %q \n", strings.Join(leafArgs, "\" \"")) + os.Exit(1) + } + + err = up(&state, ups, upN) + if err != nil { + log.Fatal(err) + } + case "down": + var downN int + switch len(leafArgs) { + case 0: + // ignore + case 1: + downN, err = strconv.Atoi(leafArgs[0]) + if err != nil || downN < 0 { + fmt.Fprintf(os.Stderr, "Error: %s is not a positive number\n", leafArgs[0]) + os.Exit(1) + } + default: + fmt.Fprintf(os.Stderr, "Error: unrecognized arguments %q \n", strings.Join(leafArgs, "\" \"")) + os.Exit(1) + } + + err = down(&state, downN) + if err != nil { + log.Fatal(err) + } + default: + log.Printf("unknown command %s", subcmd) + fmt.Printf("%s\n", helpText) + os.Exit(1) } - return state +} + +func migrationsList(migrationsDir string, entries []os.DirEntry) (ups, downs []string) { + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") { + if name != LOG_QUERY_NAME { + fmt.Fprintf(os.Stderr, " ignoring '%s'\n", filepathUnclean(filepath.Join(migrationsDir, name))) + } + continue + } + + if base, ok := strings.CutSuffix(name, ".up.sql"); ok { + ups = append(ups, base) + // TODO on ups add INSERT to file and to up migration if it doesn't exist + continue + } + + if base, ok := strings.CutSuffix(name, ".down.sql"); ok { + downs = append(downs, base) + continue + } + + fmt.Fprintf(os.Stderr, " unknown '%s'\n", filepath.Join(migrationsDir, name)) + } + for _, down := range downs { + // TODO on downs add INSERT to file and to up migration if it doesn't exist + upName := strings.TrimSuffix(down, ".down.sql") + ".up.sql" + companion := filepath.Join(migrationsDir, upName) + if !fileExists(companion) { + fmt.Fprintf(os.Stderr, " missing '%s'\n", companion) + } + } + + sort.Strings(ups) + sort.Strings(downs) + return ups, downs } func fileExists(path string) bool { _, err := os.Stat(path) - return err == nil + if err != nil { + if os.IsNotExist(err) { + return false + } + fmt.Fprintf(os.Stderr, "Error: can't access %q\n", path) + os.Exit(1) + return false + } + + return true +} + +func filepathUnclean(path string) string { + if !strings.HasPrefix(path, "/") { + if !strings.HasPrefix(path, "./") && !strings.HasPrefix(path, "../") { + path = "./" + path + } + } + return path +} + +// initializes all necessary files and directories +// - ./sql/migrations.log +// +// - ./sql/migrations +// +// - ./sql/migrations/0001-01-01-01000_init-migrations.up.sql +// - migrations_log: ./sql/migrations.log +// - sql_command: psql "$PG_URL" -v ON_ERROR_STOP=on --no-align --file %s +// +// - ./sql/migrations/0001-01-01-01000_init-migrations.down.sql +func mustInit(cfg *MainConfig) { + fmt.Fprintf(os.Stderr, "Initializing %q ...\n", cfg.migrationsDir) + + var resolvedLogPath = cfg.logPath + if cfg.sqlCommand != "" && !strings.Contains(cfg.sqlCommand, "%s") { + fmt.Fprintf(os.Stderr, "Error: --sql-command must contain a literal '%%s' to accept the path to the SQL file\n") + os.Exit(1) + } + + entries, err := os.ReadDir(cfg.migrationsDir) + if err != nil { + if !os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "Error: init failed to create %q: %v\n", cfg.migrationsDir, err) + os.Exit(1) + } + if err = os.MkdirAll(cfg.migrationsDir, 0755); err != nil { + fmt.Fprintf(os.Stderr, "Error: init failed to read %q: %v\n", cfg.migrationsDir, err) + os.Exit(1) + } + } + + ups, downs := migrationsList(cfg.migrationsDir, entries) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: init couldn't list existing migrations: %v\n", err) + os.Exit(1) + } + + mMigratorUpPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_UP_NAME) + mMigratorDownPath := filepath.Join(cfg.migrationsDir, M_MIGRATOR_DOWN_NAME) + + // write config + if slices.Contains(ups, M_MIGRATOR_NAME) { + fmt.Fprintf(os.Stderr, " found '%s'\n", filepath.Join(cfg.migrationsDir, M_MIGRATOR_UP_NAME)) + } else { + if cfg.logPath == "" { + migrationsParent := filepath.Dir(cfg.migrationsDir) + resolvedLogPath = filepath.Join(migrationsParent, defaultLogPath) + // resolvedLogPath, err = filepath.Rel(cfg.migrationsDir, cfg.logPath) + // if err != nil { + // fmt.Fprintf(os.Stderr, "Error: init couldn't resolve the migrations log relative to the migrations dir: %v\n", err) + // os.Exit(1) + // } + } + + migratorUpQuery := fmt.Sprintf(defaultMigratorUpTmpl, cfg.sqlCommand, resolvedLogPath) + if created, err := initFile(mMigratorUpPath, migratorUpQuery); err != nil { + fmt.Fprintf(os.Stderr, "Error: init couldn't create initial up migration: %v\n", err) + os.Exit(1) + } else if created { + fmt.Fprintf(os.Stderr, " created '%s'\n", mMigratorUpPath) + } + } + + state := State{ + MigrationsDir: cfg.migrationsDir, + } + state.SQLCommand, state.LogPath, err = extractVars(mMigratorUpPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: init couldn't read config from initial migration: %v\n", err) + os.Exit(1) + } + if cfg.logPath != "" && filepath.Clean(cfg.logPath) != filepath.Clean(state.LogPath) { + fmt.Fprintf(os.Stderr, + "--migrations-log %q does not match %q from %q\n(drop the --migrations-log flag, or update the add migrations file)\n", + cfg.logPath, state.LogPath, mMigratorUpPath, + ) + os.Exit(1) + } + if cfg.sqlCommand != "" && cfg.sqlCommand != state.SQLCommand { + fmt.Fprintf(os.Stderr, + "--sql-command %q does not match %q from %q\n(drop the --sql-command flag, or update the add migrations file)\n", + cfg.sqlCommand, state.SQLCommand, mMigratorUpPath, + ) + os.Exit(1) + } + + if slices.Contains(downs, M_MIGRATOR_NAME) { + fmt.Fprintf(os.Stderr, " found '%s'\n", mMigratorDownPath) + } else { + migratorDownQuery := defaultMigratorDown + if created, err := initFile(mMigratorDownPath, migratorDownQuery); err != nil { + fmt.Fprintf(os.Stderr, "Error: init couldn't create initial up migration: %v\n", err) + os.Exit(1) + } else if created { + fmt.Fprintf(os.Stderr, " created '%s'\n", mMigratorDownPath) + } + } + + logQueryPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) + if created, err := initFile(logQueryPath, LOG_MIGRATIONS_QUERY); err != nil { + fmt.Fprintf(os.Stderr, "Error: init couldn't create migrations query: %v\n", err) + os.Exit(1) + } else if created { + fmt.Fprintf(os.Stderr, " created '%s'\n", logQueryPath) + } else { + fmt.Fprintf(os.Stderr, " found '%s'\n", logQueryPath) + } + + if fileExists(state.LogPath) { + fmt.Fprintf(os.Stderr, " found '%s'\n", state.LogPath) + fmt.Fprintf(os.Stderr, "done\n") + return + } +} + +func initFile(path, contents string) (bool, error) { + if fileExists(path) { + return false, nil + } + + if err := os.WriteFile(path, []byte(contents), 0644); err != nil { + return false, err + } + + return true, nil +} + +func extractVars(curMigrationPath string) (sqlCommand string, logPath string, err error) { + f, err := os.Open(curMigrationPath) + if err != nil { + return "", "", err + } + defer f.Close() + + scanner := bufio.NewScanner(f) + + var logPathRel string + var logPathPrefix = "-- migrations_log:" + var commandPrefix = "-- sql_command:" + for scanner.Scan() { + txt := scanner.Text() + txt = strings.TrimSpace(txt) + if strings.HasPrefix(txt, logPathPrefix) { + logPathRel = strings.TrimSpace(txt[len(logPathPrefix):]) + continue + } else if strings.HasPrefix(txt, commandPrefix) { + sqlCommand = strings.TrimSpace(txt[len(commandPrefix):]) + continue + } + } + + if logPathRel == "" { + return "", "", fmt.Errorf("Could not find '-- migrations_log: ' in %q", curMigrationPath) + } + if sqlCommand == "" { + return "", "", fmt.Errorf("Could not find '-- sql_command: ' in %q", curMigrationPath) + } + + // migrationsDir := filepath.Dir(curMigrationPath) + // logPath = filepath.Join(migrationsDir, logPathRel) + // return sqlCommand, logPath, nil + return sqlCommand, logPathRel, nil +} + +func migrationsLogInit(state *State, subcmd string) error { + logDir := filepath.Dir(state.LogPath) + if err := os.MkdirAll(logDir, 0755); err != nil { + return err + } + + if subcmd != "up" { + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "Run the first migration to complete the initialization:\n") + fmt.Fprintf(os.Stderr, "(you'll need to provide DB credentials via .env or export)\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, " sql-migrate -d %q up > ./up.sh && sh ./up.sh\n", state.MigrationsDir) + fmt.Fprintf(os.Stderr, "\n") + } + return nil +} + +func (state *State) parseAndFixupBatches(text string) error { + text = strings.TrimSpace(text) + if text == "" { + return nil + } + + fixedUp := []string{} + fixedDown := []string{} + + state.Lines = strings.Split(text, "\n") + for i := range state.Lines { + line := strings.TrimSpace(state.Lines[i]) + migration := commentStartRe.ReplaceAllString(line, "") + migration = strings.TrimSpace(migration) + if migration != "" { + up, down, warn, err := fixupMigration(state.MigrationsDir, migration) + if warn != nil { + fmt.Fprintf(os.Stderr, "Warn: %s\n", warn) + } + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + } + state.Migrated = append(state.Migrated, migration) + if up { + fixedUp = append(fixedUp, migration) + } + if down { + fixedDown = append(fixedDown, migration) + } + + } + state.Lines[i] = line + } + showFixes(fixedUp, fixedDown) + + return nil +} + +func showFixes(fixedUp, fixedDown []string) { + if len(fixedUp) > 0 { + fmt.Fprintf(os.Stderr, "Fixup: prepended missing 'INSERT INTO _migrations ...' to:\n") + for _, up := range fixedUp { + fmt.Fprintf(os.Stderr, " %s\n", up) + } + fmt.Fprintf(os.Stderr, "\n") + } + + if len(fixedDown) > 0 { + fmt.Fprintf(os.Stderr, "Fixup: appended missing 'DELETE FROM _migrations ...' to:\n") + for _, down := range fixedDown { + fmt.Fprintf(os.Stderr, " %s\n", down) + } + fmt.Fprintf(os.Stderr, "\n") + } } func create(state *State, desc string) error { dateStr := state.Date.Format("2006-01-02") - entries, err := os.ReadDir(state.SqlDir) + entries, err := os.ReadDir(state.MigrationsDir) if err != nil { return err } @@ -127,7 +683,7 @@ func create(state *State, desc string) error { continue } if strings.HasSuffix(name, "_"+desc+".up.sql") { - return fmt.Errorf("migration for %q already exists:\n %s", desc, state.SqlDir+"/"+name) + return fmt.Errorf("migration for %q already exists:\n %s", desc, state.MigrationsDir+"/"+name) } if strings.HasSuffix(name, ".down.sql") { continue @@ -151,21 +707,31 @@ func create(state *State, desc string) error { } } - number := maxNumber / 1000 - number *= 1000 - number += 1000 - if number > 9000 { - return fmt.Errorf("it's over 9000! ") + number := maxNumber / 1_000 + number *= 1_000 + number += 1_000 + if number > 9_000 && number < 10_000 { + fmt.Fprintf(os.Stderr, "Achievement Unlocked: It's over 9000!\n") + } + if number >= 999_999 { + fmt.Fprintf(os.Stderr, "Error: cowardly refusing to generate such a suspiciously high number of migrations after running out of numbers\n") + os.Exit(1) } - baseFilename := fmt.Sprintf("%s-%06d_%s", dateStr, number, desc) - upPath := filepath.Join(state.SqlDir, baseFilename+".up.sql") - downPath := filepath.Join(state.SqlDir, baseFilename+".down.sql") + basename := fmt.Sprintf("%s-%06d_%s", dateStr, number, desc) + upPath := filepath.Join(state.MigrationsDir, basename+".up.sql") + downPath := filepath.Join(state.MigrationsDir, basename+".down.sql") - // Use fmt.Appendf to build byte slice, ignoring error as it can't fail with static format - upContent := fmt.Appendf(nil, "-- %s (up)\n", desc) + id := MustRandomHex(4) + + // Little Bobby Drop Tables says: + // We trust the person running the migrations to not use malicious names. + // (we don't want to embed db-specific logic here, and SQL doesn't define escaping) + migrationInsert := fmt.Sprintf("INSERT INTO _migrations (id, name) VALUES ('%s', '%s');", id, basename) + upContent := fmt.Appendf(nil, "-- leave this as the first line\n%s\n\n-- %s (up)\nSELECT 'place your UP migration here';\n", migrationInsert, desc) _ = os.WriteFile(upPath, upContent, 0644) - downContent := fmt.Appendf(nil, "-- %s (down)\n", desc) + migrationDelete := fmt.Sprintf("DELETE FROM _migrations WHERE id = '%s';", id) + downContent := fmt.Appendf(nil, "-- %s (down)\nSELECT 'place your DOWN migration here';\n\n-- leave this as the last line\n%s\n", desc, migrationDelete) _ = os.WriteFile(downPath, downContent, 0644) fmt.Fprintf(os.Stderr, " created pair %s\n", upPath) @@ -173,50 +739,104 @@ func create(state *State, desc string) error { return nil } -func listMigrations(state *State) (ups, downs []string, err error) { - entries, err := os.ReadDir(state.SqlDir) +func MustRandomHex(n int) string { + s, err := RandomHex(n) if err != nil { - return nil, nil, err + panic(err) } - for _, entry := range entries { - if entry.IsDir() { - continue - } - name := entry.Name() - if strings.HasPrefix(name, ".") || strings.HasPrefix(name, "_") { - log.Printf(" ignoring '%s'", name) - continue - } - if strings.HasSuffix(name, ".up.sql") { - base := strings.TrimSuffix(name, ".up.sql") - ups = append(ups, base) - companion := filepath.Join(state.SqlDir, base+".down.sql") - if !fileExists(companion) { - log.Printf(" missing '%s'", companion) - } - continue - } - if strings.HasSuffix(name, ".down.sql") { - base := strings.TrimSuffix(name, ".down.sql") - downs = append(downs, base) - companion := filepath.Join(state.SqlDir, base+".up.sql") - if !fileExists(companion) { - log.Printf(" missing '%s'", companion) - } - continue - } - log.Printf(" unknown '%s'", name) - } - sort.Strings(ups) - sort.Strings(downs) - return ups, downs, nil + return s } -func up(state *State) error { - ups, _, err := listMigrations(state) +func RandomHex(n int) (string, error) { + b := make([]byte, n) // 4 bytes = 8 hex chars + _, err := rand.Read(b) if err != nil { - return err + return "", err } + return hex.EncodeToString(b), nil +} + +// attempts to add missing INSERT and DELETE without breaking what already works +func fixupMigration(dir string, basename string) (up, down bool, warn error, err error) { + var id string + + var insertsOnUp bool + upPath := filepath.Join(dir, basename+".up.sql") + upScan, err := os.Open(upPath) + if err != nil { + return false, false, nil, fmt.Errorf("failed (up): %w", err) + } + defer upScan.Close() + scanner := bufio.NewScanner(upScan) + for scanner.Scan() { + txt := scanner.Text() + txt = strings.TrimSpace(txt) + txt = strings.ToLower(txt) + if strings.HasPrefix(txt, "insert into _migrations") { + insertsOnUp = true + break + } + } + if !insertsOnUp { + id = MustRandomHex(4) + upScan.Close() + upBytes, err := os.ReadFile(upPath) + if err != nil { + warn = fmt.Errorf("failed to add 'INSERT INTO _migrations ...' to %s: %w", upPath, err) + return false, false, warn, nil + } + + migrationInsertLn := fmt.Sprintf("INSERT INTO _migrations (id, name) VALUES ('%s', '%s');\n\n", id, basename) + upBytes = append([]byte(migrationInsertLn), upBytes...) + if err = os.WriteFile(upPath, upBytes, 0644); err != nil { + warn = fmt.Errorf("failed to prepend 'INSERT INTO _migrations ...' to %s: %w", upPath, err) + return false, false, warn, nil + } + up = true + } + + var deletesOnDown bool + downPath := filepath.Join(dir, basename+".down.sql") + downScan, err := os.Open(downPath) + if err != nil { + return false, false, fmt.Errorf("failed (down): %w", err), nil + } + defer downScan.Close() + scanner = bufio.NewScanner(downScan) + for scanner.Scan() { + txt := scanner.Text() + txt = strings.TrimSpace(txt) + txt = strings.ToLower(txt) + if strings.HasPrefix(txt, "delete from _migrations") { + deletesOnDown = true + break + } + } + if !deletesOnDown { + if id == "" { + return false, false, fmt.Errorf("must manually append \"DELETE FROM _migrations WHERE id = ''\" to %s with id from %s", downPath, basename+"up.sql"), nil + } + downScan.Close() + downFile, err := os.OpenFile(downPath, os.O_APPEND|os.O_WRONLY, 0o644) + if err != nil { + warn = fmt.Errorf("failed to append 'DELETE FROM _migrations ...' to %s: %v", downPath, err) + return false, false, warn, nil + } + defer downFile.Close() + + migrationInsertLn := fmt.Sprintf("\nDELETE FROM _migrations WHERE id = '%s';\n", id) + _, err = downFile.Write(([]byte(migrationInsertLn))) + if err != nil { + warn = fmt.Errorf("failed to add 'DELETE FROM _migrations ...' to %s: %w", downPath, err) + return false, false, warn, nil + } + down = true + } + + return up, down, nil, nil +} + +func up(state *State, ups []string, n int) error { var pending []string for _, mig := range ups { found := slices.Contains(state.Migrated, mig) @@ -224,87 +844,146 @@ func up(state *State) error { pending = append(pending, mig) } } + + getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) + getMigsPath = filepathUnclean(getMigsPath) + getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1) + if len(pending) == 0 { - log.Println(" already up-to-date") + fmt.Fprintf(os.Stderr, "# Already up-to-date\n") + fmt.Fprintf(os.Stderr, "#\n") + fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n") + fmt.Fprintf(os.Stderr, "%s\n", "# "+getMigs+" > "+filepathUnclean(state.LogPath)) return nil } - n := state.Current + 1 - fmt.Printf("echo '# batch: %d' >> %s\n", n, state.LogFile) - for _, mig := range pending { - fmt.Println("") - fmt.Printf("# INSERT INTO \"migrations\" ('%d', '%s')\n", n, mig) - fmt.Printf("echo '%s' >> %s\n", mig, state.LogFile) - path := filepath.Join(state.SqlDir, mig+".up.sql") - if !strings.HasPrefix(path, "/") { - if !strings.HasPrefix(path, "./") && !strings.HasPrefix(path, "../") { - path = "./" + path + if n == 0 { + n = len(pending) + } + + fixedUp := []string{} + fixedDown := []string{} + + fmt.Printf(shHeader) + fmt.Println("") + fmt.Println("# FORWARD / UP Migrations") + fmt.Println("") + for i, migration := range pending { + if i >= n { + break + } + + path := filepath.Join(state.MigrationsDir, migration+".up.sql") + path = filepathUnclean(path) + { + up, down, warn, err := fixupMigration(state.MigrationsDir, migration) + if warn != nil { + fmt.Fprintf(os.Stderr, "Warn: %s\n", warn) + } + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + } + if up { + fixedUp = append(fixedUp, migration) + } + if down { + fixedDown = append(fixedDown, migration) } } - cmd := strings.Replace(state.Command, "%s", path, 1) + cmd := strings.Replace(state.SQLCommand, "%s", path, 1) + fmt.Printf("# +%d %s\n", i+1, migration) fmt.Println(cmd) + fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath)) + fmt.Println("") } - fmt.Println("") + fmt.Println("cat", filepathUnclean(state.LogPath)) + + showFixes(fixedUp, fixedDown) return nil } -func down(state *State) error { +func down(state *State, n int) error { lines := make([]string, len(state.Lines)) copy(lines, state.Lines) - lineCount := len(lines) slices.Reverse(lines) - var batchLine string - var batch []string + + getMigsPath := filepath.Join(state.MigrationsDir, LOG_QUERY_NAME) + getMigsPath = filepathUnclean(getMigsPath) + getMigs := strings.Replace(state.SQLCommand, "%s", getMigsPath, 1) + + if len(lines) == 0 { + fmt.Fprintf(os.Stderr, "# No migration history\n") + fmt.Fprintf(os.Stderr, "#\n") + fmt.Fprintf(os.Stderr, "# To reload the migrations log:\n") + fmt.Fprintf(os.Stderr, "%s\n", "# "+getMigs+" > "+filepathUnclean(state.LogPath)) + return nil + } + if n == 0 { + n = 1 + } + + fixedUp := []string{} + fixedDown := []string{} + + var applied []string for _, line := range lines { - lineCount-- - if batchStartRe.MatchString(line) { - batchLine = line - break - } - mig := commentStartRe.ReplaceAllString(line, "") - mig = strings.TrimSpace(mig) - if mig == "" { - log.Printf(" ignoring '%s'", line) + migration := commentStartRe.ReplaceAllString(line, "") + migration = strings.TrimSpace(migration) + if migration == "" { continue } - batch = append(batch, mig) + applied = append(applied, migration) } - log.Printf("ROLLBACK %s", batchLine) - for _, mig := range batch { - fmt.Println("") - fmt.Printf("# DELETE FROM \"migrations\" WHERE \"name\" = '%s';\n", mig) - sqlfile := filepath.Join(state.SqlDir, mig+".down.sql") - if !fileExists(sqlfile) { - log.Printf(" missing '%s'", sqlfile) + + fmt.Printf(shHeader) + fmt.Println("") + fmt.Println("# ROLLBACK / DOWN Migration") + fmt.Println("") + for i, migration := range applied { + if i >= n { + break } - cmd := strings.Replace(state.Command, "%s", sqlfile, 1) + + downPath := filepath.Join(state.MigrationsDir, migration+".down.sql") + cmd := strings.Replace(state.SQLCommand, "%s", downPath, 1) + fmt.Printf("\n# -%d %s\n", i+1, migration) + if !fileExists(downPath) { + fmt.Fprintf(os.Stderr, "# Warn: missing '%s'\n", downPath) + fmt.Fprintf(os.Stderr, "# (the migration will fail to run)\n") + fmt.Printf("# ERROR: MISSING FILE\n") + } else { + up, down, warn, err := fixupMigration(state.MigrationsDir, migration) + if warn != nil { + fmt.Fprintf(os.Stderr, "Warn: %s\n", warn) + } + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + } + if up { + fixedUp = append(fixedUp, migration) + } + if down { + fixedDown = append(fixedDown, migration) + } + } + fmt.Println(cmd) + fmt.Println(getMigs + " > " + filepathUnclean(state.LogPath)) + fmt.Println("") } - fmt.Println("") - fmt.Println("# new file as to not overwrite the file while reading") - fmt.Printf("head -n '%d' %s > %s.new\n", lineCount, state.LogFile, state.LogFile) - fmt.Printf("mv %s.new %s\n", state.LogFile, state.LogFile) - fmt.Println("") + fmt.Println("cat", filepathUnclean(state.LogPath)) + + showFixes(fixedUp, fixedDown) return nil } -func status(state *State) error { - lines := make([]string, len(state.Lines)) - copy(lines, state.Lines) - hasCommand := commandStartRe.MatchString(lines[0]) - if hasCommand { - lines = lines[1:] - } - slices.Reverse(lines) - var previous []string - for _, line := range lines { - previous = append([]string{line}, previous...) - if batchStartRe.MatchString(line) { - break - } - } - fmt.Fprintf(os.Stderr, "sqldir: %s\n", state.SqlDir) - fmt.Fprintf(os.Stderr, "logfile: %s\n", state.LogFile) - fmt.Fprintf(os.Stderr, "command: %s\n", state.Command) +func status(state *State, ups []string) error { + previous := make([]string, len(state.Lines)) + copy(previous, state.Lines) + slices.Reverse(previous) + + fmt.Fprintf(os.Stderr, "migrations_dir: %s\n", state.MigrationsDir) + fmt.Fprintf(os.Stderr, "migrations_log: %s\n", state.LogPath) + fmt.Fprintf(os.Stderr, "sql_command: %s\n", state.SQLCommand) fmt.Fprintf(os.Stderr, "\n") fmt.Printf("# previous: %d\n", len(previous)) for _, mig := range previous { @@ -314,10 +993,6 @@ func status(state *State) error { fmt.Println(" # (no previous migrations)") } fmt.Println("") - ups, _, err := listMigrations(state) - if err != nil { - return err - } var pending []string for _, mig := range ups { found := slices.Contains(state.Migrated, mig) @@ -334,174 +1009,3 @@ func status(state *State) error { } return nil } - -const helpText = ` -sql-migrate v1.0.2 - a feature-branch-friendly SQL migrator - -USAGE - sql-migrate [-d sqldir] [-f logfile] [args] - -EXAMPLE - sql-migrate init -d ./sql/migrations/ -f ./sql/migrations.log - sql-migrate create - sql-migrate status - sql-migrate up - sql-migrate down - sql-migrate list - -COMMANDS - init - inits sql dir and migration file, adding or updating the - default command - create - creates a new, canonically-named up/down file pair in the - migrations directory - status - shows the same output as if processing a forward-migration - for the most recent batch - up - processes the first 'up' migration file missing from the - migration state - down - rolls back the latest entry of the latest migration batch - (the whole batch if just one) - list - lists migrations - -OPTIONS - -d default: ./sql/migrations/ - -f default: ./sql/migrations.log - -NOTES - Migrations files are in the following format: - -_..sql - 2020-01-01-1000_init.up.sql - - The migration state file contains the client command template (defaults to - 'psql "$PG_URL" < %s'), followed by a list of batches identified by a batch - number comment and a list of migration file basenames and optional user - comments, such as: - # command: psql "$PG_URL" < %s - # batch: 1 - 2020-01-01-1000_init.up.sql # does a lot - 2020-01-01-1100_add-customer-tables.up.sql - # batch: 2 - # We did id! Finally! - 2020-01-01-2000_add-ALL-THE-TABLES.up.sql - - The 'create' generates an up/down pair of files using the current date and - the number 1000. If either file exists, the number is incremented by 1000 and - tried again, up to 9000, or throws the error "it's over 9000!" on failure. -` - -func main() { - if len(os.Args) < 2 { - //nolint - fmt.Printf("%s\n", helpText) - os.Exit(0) - } - - command := os.Args[1] - switch command { - case "help", "--help", - "version", "--version", "-V": - fmt.Printf("%s\n", helpText) - os.Exit(0) - default: - // do nothing - } - - fs := flag.NewFlagSet(command, flag.ExitOnError) - sqlDir := fs.String("d", defaultMigrationDir, "migrations directory") - logFile := fs.String("f", defaultMigrationLog, "migration log file") - if err := fs.Parse(os.Args[2:]); err != nil { - os.Exit(2) - } - - date := time.Now() - var state *State - var err error - - logText, err := os.ReadFile(*logFile) - if os.IsNotExist(err) { - if command != "init" { - log.Printf(" run 'init' first: missing '%s'", *logFile) - os.Exit(1) - } - text := fmt.Sprintf("# command: %s\n", defaultCommand) - dir := filepath.Dir(*logFile) - err = os.MkdirAll(*sqlDir, 0755) - if err != nil { - log.Fatal(err) - } - err = os.MkdirAll(dir, 0755) - if err != nil { - log.Fatal(err) - } - err = os.WriteFile(*logFile, []byte(text), 0644) - if err != nil { - log.Fatal(err) - } - log.Printf(" created '%s'", *logFile) - logText = []byte{} - } else if err != nil { - log.Fatal(err) - } - - state = parseLog(string(logText), date) - state.SqlDir = *sqlDir - state.LogFile = *logFile - - switch command { - case "init": - if len(logText) > 0 { - log.Printf(" found '%s'", *logFile) - } - case "create": - args := fs.Args() - if len(args) == 0 { - log.Fatal("create requires a description") - } - desc := strings.Join(args, " ") - desc = nonWordRe.ReplaceAllString(desc, " ") - desc = strings.TrimSpace(desc) - desc = nonWordRe.ReplaceAllString(desc, "-") - desc = strings.ToLower(desc) - err = create(state, desc) - if err != nil { - log.Fatal(err) - } - case "status": - err = status(state) - if err != nil { - log.Fatal(err) - } - case "list": - ups, downs, err := listMigrations(state) - if err != nil { - log.Fatal(err) - } - fmt.Println("Ups:") - if len(ups) == 0 { - fmt.Println(" (none)") - } - for _, u := range ups { - fmt.Println(u) - } - fmt.Println("Downs:") - if len(downs) == 0 { - fmt.Println(" (none)") - } - for _, d := range downs { - fmt.Println(d) - } - case "up": - err = up(state) - if err != nil { - log.Fatal(err) - } - case "down": - err = down(state) - if err != nil { - log.Fatal(err) - } - default: - log.Printf("unknown command %s", command) - fmt.Printf("%s\n", helpText) - os.Exit(1) - } -}