package sqlite

import (
	"context"
	"database/sql"
	"embed"
	"fmt"
	"os"
	"path"
	"path/filepath"
	"sort"
	"strconv"
	"strings"

	_ "modernc.org/sqlite"

	"inbox/internal/base/timeutil"
)

// migration is one embedded schema migration. Version is parsed from the
// numeric file-name prefix (e.g. "0001_init.sql" -> 1), Name is the file name,
// and SQL is the file's contents.
type migration struct {
	Version int
	Name    string
	SQL     string
}

// migrationFiles embeds every *.sql file under migrations/.
//
//go:embed migrations/*.sql
var migrationFiles embed.FS

// migrations holds all embedded migrations sorted by Version. It is built at
// package initialization; malformed migration files cause a panic there.
var migrations = mustLoadMigrations()

// Store is a SQLite-backed store. It holds the database handle and the clock
// used for timestamps it writes.
type Store struct {
	db    *sql.DB
	clock timeutil.Clock
}

// Open creates the parent directory of dbPath if needed, opens the SQLite
// database at dbPath, applies connection pragmas, runs pending migrations, and
// calls ensureBuiltinRoles and ensureBuiltinSkills. A nil clock defaults to
// timeutil.SystemClock. On any failure the database handle is closed.
func Open(dbPath string, clock timeutil.Clock) (*Store, error) {
	if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
		return nil, fmt.Errorf("create db directory: %w", err)
	}
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, fmt.Errorf("open sqlite db: %w", err)
	}
	return bootstrapOpenedDB(db, clock)
}

// OpenInMemory is like Open but uses a private ":memory:" database. The data
// lives only as long as the store's single connection.
func OpenInMemory(clock timeutil.Clock) (*Store, error) {
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		return nil, fmt.Errorf("open sqlite memory db: %w", err)
	}
	return bootstrapOpenedDB(db, clock)
}

// bootstrapOpenedDB performs the setup shared by Open and OpenInMemory:
// pin the pool to one connection, configure pragmas, migrate, and seed
// built-ins. It takes ownership of db and closes it if any step fails.
func bootstrapOpenedDB(db *sql.DB, clock timeutil.Clock) (store *Store, err error) {
	defer func() {
		if err != nil {
			// The setup error is the one worth reporting; a close failure
			// here would only obscure it.
			_ = db.Close()
		}
	}()

	if clock == nil {
		clock = timeutil.SystemClock{}
	}

	// One connection only. Pragmas such as foreign_keys and busy_timeout are
	// per-connection in SQLite. For ":memory:", every new connection would
	// otherwise see a separate, empty database.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)

	if err := configure(db); err != nil {
		return nil, err
	}
	ctx := context.Background()
	if err := Migrate(ctx, db, clock); err != nil {
		return nil, err
	}
	s := &Store{db: db, clock: clock}
	if err := s.ensureBuiltinRoles(ctx); err != nil {
		return nil, err
	}
	if err := s.ensureBuiltinSkills(ctx); err != nil {
		return nil, err
	}
	return s, nil
}

// New wraps an already-open database handle without configuring, migrating,
// or seeding it. A nil clock defaults to timeutil.SystemClock.
func New(db *sql.DB, clock timeutil.Clock) *Store {
	if clock == nil {
		clock = timeutil.SystemClock{}
	}
	return &Store{db: db, clock: clock}
}

// DB returns the underlying database handle.
func (s *Store) DB() *sql.DB {
	return s.db
}

// Close closes the underlying database handle. It is a no-op when the store
// has no handle.
func (s *Store) Close() error {
	if s.db == nil {
		return nil
	}
	return s.db.Close()
}

// configure applies the connection pragmas the store relies on.
// busy_timeout is set first so that the journal-mode switch waits for a
// competing lock instead of failing immediately with SQLITE_BUSY.
func configure(db *sql.DB) error {
	pragmas := []struct {
		stmt string
		what string
	}{
		{"PRAGMA busy_timeout=5000", "set busy timeout"},
		{"PRAGMA journal_mode=WAL", "enable WAL"},
		{"PRAGMA foreign_keys=ON", "enable foreign keys"},
	}
	for _, p := range pragmas {
		if _, err := db.Exec(p.stmt); err != nil {
			return fmt.Errorf("%s: %w", p.what, err)
		}
	}
	return nil
}

// Migrate applies every embedded migration not yet recorded in
// schema_migrations, in ascending version order. It refuses to run when the
// database records a version newer than the newest embedded migration. A nil
// clock defaults to timeutil.SystemClock.
func Migrate(ctx context.Context, db *sql.DB, clock timeutil.Clock) error {
	if clock == nil {
		clock = timeutil.SystemClock{}
	}
	if err := ensureSchemaMigrationsTable(db); err != nil {
		return err
	}
	applied, err := loadAppliedVersions(ctx, db)
	if err != nil {
		return err
	}
	if len(migrations) == 0 {
		return nil
	}

	current := migrations[len(migrations)-1].Version
	// Report the highest applied version so the error is deterministic
	// (map iteration order is random).
	newest := 0
	for version := range applied {
		if version > newest {
			newest = version
		}
	}
	if newest > current {
		return fmt.Errorf("database schema version %d is newer than supported version %d", newest, current)
	}

	for _, m := range migrations {
		if applied[m.Version] {
			continue
		}
		if err := applyMigration(ctx, db, m, clock); err != nil {
			return err
		}
	}
	return nil
}

// ensureSchemaMigrationsTable creates the schema_migrations bookkeeping table
// if it does not already exist.
func ensureSchemaMigrationsTable(db *sql.DB) error {
	_, err := db.Exec(`
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			name TEXT NOT NULL,
			applied_at TEXT NOT NULL
		)
	`)
	if err != nil {
		return fmt.Errorf("create schema_migrations table: %w", err)
	}
	return nil
}

// loadAppliedVersions returns the set of migration versions recorded in
// schema_migrations.
func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) {
	rows, err := db.QueryContext(ctx, `SELECT version FROM schema_migrations`)
	if err != nil {
		return nil, fmt.Errorf("load schema migrations: %w", err)
	}
	defer rows.Close()

	applied := make(map[int]bool)
	for rows.Next() {
		var version int
		if err := rows.Scan(&version); err != nil {
			return nil, fmt.Errorf("scan schema migration version: %w", err)
		}
		applied[version] = true
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate schema migrations: %w", err)
	}
	return applied, nil
}

// applyMigration applies a single migration and records it in
// schema_migrations.
//
// Special cases:
//   - Versions 12 and 32 have a preflight check. When the check says there is
//     nothing to do, the migration is recorded as skipped without running
//     its SQL.
//   - Version 30 is delegated entirely to applyRoleConfigCollapseMigration
//     (defined elsewhere in the package). NOTE(review): presumably it records
//     the migration itself; confirm.
//
// Ordinary migrations run their SQL and the bookkeeping insert in one
// transaction.
func applyMigration(ctx context.Context, db *sql.DB, m migration, clock timeutil.Clock) error {
	var preflight func(context.Context, *sql.DB) (bool, error)
	switch m.Version {
	case 12:
		preflight = shouldApplyLaneRenameMigration
	case 30:
		return applyRoleConfigCollapseMigration(ctx, db, m, clock)
	case 32:
		preflight = shouldApplyDropSkillMetadataColumnsMigration
	}
	if preflight != nil {
		shouldApply, err := preflight(ctx, db)
		if err != nil {
			return fmt.Errorf("preflight migration %s: %w", m.Name, err)
		}
		if !shouldApply {
			return recordSkippedMigration(ctx, db, m, clock)
		}
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin migration %s: %w", m.Name, err)
	}
	// After a successful Commit this Rollback is a no-op returning
	// sql.ErrTxDone, so its error is intentionally ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.ExecContext(ctx, m.SQL); err != nil {
		return fmt.Errorf("apply migration %s: %w", m.Name, err)
	}
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO schema_migrations(version, name, applied_at) VALUES(?, ?, ?)`,
		m.Version,
		m.Name,
		timeutil.FormatRFC3339(clock.Now()),
	); err != nil {
		return fmt.Errorf("record migration %s: %w", m.Name, err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit migration %s: %w", m.Name, err)
	}
	return nil
}

// recordSkippedMigration marks m as applied without running its SQL. It is
// used when a preflight check finds the schema already in the target shape.
func recordSkippedMigration(ctx context.Context, db *sql.DB, m migration, clock timeutil.Clock) error {
	if _, err := db.ExecContext(
		ctx,
		`INSERT INTO schema_migrations(version, name, applied_at) VALUES(?, ?, ?)`,
		m.Version,
		m.Name,
		timeutil.FormatRFC3339(clock.Now()),
	); err != nil {
		return fmt.Errorf("record skipped migration %s: %w", m.Name, err)
	}
	return nil
}

// shouldApplyDropSkillMetadataColumnsMigration reports whether the skills
// table still has a source_ref or asset_root column. If the skills table
// does not exist, PRAGMA table_info yields no rows and this returns false.
func shouldApplyDropSkillMetadataColumnsMigration(ctx context.Context, db *sql.DB) (bool, error) {
	rows, err := db.QueryContext(ctx, `PRAGMA table_info(skills)`)
	if err != nil {
		return false, fmt.Errorf("inspect skills table: %w", err)
	}
	defer rows.Close()

	hasLegacyColumn := false
	for rows.Next() {
		// PRAGMA table_info columns: cid, name, type, notnull, dflt_value, pk.
		var (
			cid          int
			name         string
			dataType     string
			notNull      int
			defaultValue sql.NullString
			pk           int
		)
		if err := rows.Scan(&cid, &name, &dataType, &notNull, &defaultValue, &pk); err != nil {
			return false, fmt.Errorf("scan skills table info: %w", err)
		}
		if name == "source_ref" || name == "asset_root" {
			hasLegacyColumn = true
		}
	}
	if err := rows.Err(); err != nil {
		return false, fmt.Errorf("iterate skills table info: %w", err)
	}
	return hasLegacyColumn, nil
}

// shouldApplyLaneRenameMigration reports whether the lane-rename migration
// (version 12 in applyMigration) has work to do. It applies only when the
// legacy "chains" table exists. Otherwise, whether "lanes" already exists or
// neither table does, the migration is recorded as skipped.
func shouldApplyLaneRenameMigration(ctx context.Context, db *sql.DB) (bool, error) {
	return tableExists(ctx, db, "chains")
}

// tableExists reports whether sqlite_master lists a table with the given name.
func tableExists(ctx context.Context, db *sql.DB, name string) (bool, error) {
	var count int
	if err := db.QueryRowContext(ctx, `
		SELECT COUNT(*)
		FROM sqlite_master
		WHERE type = 'table' AND name = ?
	`, name).Scan(&count); err != nil {
		return false, fmt.Errorf("check table %q exists: %w", name, err)
	}
	return count > 0, nil
}

// mustLoadMigrations reads every embedded migrations/*.sql file and parses
// its version from the numeric prefix before the first underscore. It
// returns the migrations sorted by version.
//
// It panics on a missing or non-positive prefix, a duplicate version, an
// unreadable file, or any gap in the sequence. Versions must be exactly
// 1, 2, 3, ...
func mustLoadMigrations() []migration {
	entries, err := migrationFiles.ReadDir("migrations")
	if err != nil {
		panic(fmt.Sprintf("read migrations: %v", err))
	}

	out := make([]migration, 0, len(entries))
	seen := make(map[int]string, len(entries))
	for _, entry := range entries {
		if entry.IsDir() || path.Ext(entry.Name()) != ".sql" {
			continue
		}
		parts := strings.SplitN(entry.Name(), "_", 2)
		if len(parts) != 2 {
			panic(fmt.Sprintf("migration file %q must start with a numeric prefix", entry.Name()))
		}
		version, err := strconv.Atoi(parts[0])
		if err != nil || version <= 0 {
			panic(fmt.Sprintf("migration file %q has invalid version prefix", entry.Name()))
		}
		if prior, exists := seen[version]; exists {
			panic(fmt.Sprintf("duplicate migration version %d: %q and %q", version, prior, entry.Name()))
		}
		// embed.FS always uses forward slashes, hence path (not filepath).
		body, err := migrationFiles.ReadFile(path.Join("migrations", entry.Name()))
		if err != nil {
			panic(fmt.Sprintf("read migration %q: %v", entry.Name(), err))
		}
		seen[version] = entry.Name()
		out = append(out, migration{
			Version: version,
			Name:    entry.Name(),
			SQL:     string(body),
		})
	}

	sort.Slice(out, func(i, j int) bool { return out[i].Version < out[j].Version })
	for idx, item := range out {
		if expected := idx + 1; item.Version != expected {
			panic(fmt.Sprintf("migrations must be contiguous: expected version %d, found %d (%s)", expected, item.Version, item.Name))
		}
	}
	return out
}