358 lines
8.9 KiB
Go
358 lines
8.9 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
"os"
|
|
"path"
|
|
"path/filepath"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
_ "modernc.org/sqlite"
|
|
|
|
"inbox/internal/base/timeutil"
|
|
)
|
|
|
|
// migration is a single embedded schema migration: the version number parsed
// from the file-name prefix, the file name itself, and the file's SQL body.
type migration struct {
	Version   int
	Name, SQL string
}
|
|
|
|
//go:embed migrations/*.sql
|
|
var migrationFiles embed.FS
|
|
|
|
var migrations = mustLoadMigrations()
|
|
|
|
type Store struct {
|
|
db *sql.DB
|
|
clock timeutil.Clock
|
|
}
|
|
|
|
func Open(dbPath string, clock timeutil.Clock) (*Store, error) {
|
|
if clock == nil {
|
|
clock = timeutil.SystemClock{}
|
|
}
|
|
if err := os.MkdirAll(filepath.Dir(dbPath), 0755); err != nil {
|
|
return nil, fmt.Errorf("create db directory: %w", err)
|
|
}
|
|
db, err := sql.Open("sqlite", dbPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open sqlite db: %w", err)
|
|
}
|
|
db.SetMaxOpenConns(1)
|
|
db.SetMaxIdleConns(1)
|
|
if err := configure(db); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
if err := Migrate(context.Background(), db, clock); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
store := &Store{db: db, clock: clock}
|
|
if err := store.ensureBuiltinRoles(context.Background()); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
if err := store.ensureBuiltinSkills(context.Background()); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
return store, nil
|
|
}
|
|
|
|
func OpenInMemory(clock timeutil.Clock) (*Store, error) {
|
|
if clock == nil {
|
|
clock = timeutil.SystemClock{}
|
|
}
|
|
db, err := sql.Open("sqlite", ":memory:")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open sqlite memory db: %w", err)
|
|
}
|
|
db.SetMaxOpenConns(1)
|
|
db.SetMaxIdleConns(1)
|
|
if err := configure(db); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
if err := Migrate(context.Background(), db, clock); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
store := &Store{db: db, clock: clock}
|
|
if err := store.ensureBuiltinRoles(context.Background()); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
if err := store.ensureBuiltinSkills(context.Background()); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
return store, nil
|
|
}
|
|
|
|
func New(db *sql.DB, clock timeutil.Clock) *Store {
|
|
if clock == nil {
|
|
clock = timeutil.SystemClock{}
|
|
}
|
|
return &Store{db: db, clock: clock}
|
|
}
|
|
|
|
func (s *Store) DB() *sql.DB {
|
|
return s.db
|
|
}
|
|
|
|
func (s *Store) Close() error {
|
|
if s.db == nil {
|
|
return nil
|
|
}
|
|
return s.db.Close()
|
|
}
|
|
|
|
// configure applies the SQLite settings the store relies on. It enables WAL
// journaling, turns on foreign-key enforcement, and sets a 5000 ms busy
// timeout, in that order. It stops at the first failing pragma and returns
// its error, prefixed with a description of the setting.
func configure(db *sql.DB) error {
	pragmas := []struct {
		stmt, what string
	}{
		{"PRAGMA journal_mode=WAL", "enable WAL"},
		{"PRAGMA foreign_keys=ON", "enable foreign keys"},
		{"PRAGMA busy_timeout=5000", "set busy timeout"},
	}
	for _, p := range pragmas {
		if _, err := db.Exec(p.stmt); err != nil {
			return fmt.Errorf("%s: %w", p.what, err)
		}
	}
	return nil
}
|
|
|
|
func Migrate(ctx context.Context, db *sql.DB, clock timeutil.Clock) error {
|
|
if clock == nil {
|
|
clock = timeutil.SystemClock{}
|
|
}
|
|
if err := ensureSchemaMigrationsTable(db); err != nil {
|
|
return err
|
|
}
|
|
applied, err := loadAppliedVersions(ctx, db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if len(migrations) == 0 {
|
|
return nil
|
|
}
|
|
current := migrations[len(migrations)-1].Version
|
|
for version := range applied {
|
|
if version > current {
|
|
return fmt.Errorf("database schema version %d is newer than supported version %d", version, current)
|
|
}
|
|
}
|
|
|
|
for _, migration := range migrations {
|
|
if applied[migration.Version] {
|
|
continue
|
|
}
|
|
if err := applyMigration(ctx, db, migration, clock); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ensureSchemaMigrationsTable creates the schema_migrations bookkeeping table
// (one row per applied migration version) if it does not already exist.
func ensureSchemaMigrationsTable(db *sql.DB) error {
	const ddl = `
		CREATE TABLE IF NOT EXISTS schema_migrations (
			version INTEGER PRIMARY KEY,
			name TEXT NOT NULL,
			applied_at TEXT NOT NULL
		)
	`
	if _, err := db.Exec(ddl); err != nil {
		return fmt.Errorf("create schema_migrations table: %w", err)
	}
	return nil
}
|
|
|
|
// loadAppliedVersions returns the set of migration versions recorded in
// schema_migrations, as a map from version to true. The map is non-nil even
// when no versions are recorded.
func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) {
	rows, err := db.QueryContext(ctx, `SELECT version FROM schema_migrations`)
	if err != nil {
		return nil, fmt.Errorf("load schema migrations: %w", err)
	}
	defer rows.Close()

	seen := map[int]bool{}
	for rows.Next() {
		var v int
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, fmt.Errorf("scan schema migration version: %w", scanErr)
		}
		seen[v] = true
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate schema migrations: %w", err)
	}
	return seen, nil
}
|
|
|
|
func applyMigration(ctx context.Context, db *sql.DB, m migration, clock timeutil.Clock) error {
|
|
if m.Version == 12 {
|
|
shouldApply, err := shouldApplyLaneRenameMigration(ctx, db)
|
|
if err != nil {
|
|
return fmt.Errorf("preflight migration %s: %w", m.Name, err)
|
|
}
|
|
if !shouldApply {
|
|
if _, err := db.ExecContext(ctx, `INSERT INTO schema_migrations(version, name, applied_at) VALUES(?, ?, ?)`,
|
|
m.Version, m.Name, timeutil.FormatRFC3339(clock.Now())); err != nil {
|
|
return fmt.Errorf("record skipped migration %s: %w", m.Name, err)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
if m.Version == 32 {
|
|
shouldApply, err := shouldApplyDropSkillMetadataColumnsMigration(ctx, db)
|
|
if err != nil {
|
|
return fmt.Errorf("preflight migration %s: %w", m.Name, err)
|
|
}
|
|
if !shouldApply {
|
|
if _, err := db.ExecContext(ctx, `INSERT INTO schema_migrations(version, name, applied_at) VALUES(?, ?, ?)`,
|
|
m.Version, m.Name, timeutil.FormatRFC3339(clock.Now())); err != nil {
|
|
return fmt.Errorf("record skipped migration %s: %w", m.Name, err)
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
if m.Version == 30 {
|
|
return applyRoleConfigCollapseMigration(ctx, db, m, clock)
|
|
}
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("begin migration %s: %w", m.Name, err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := tx.ExecContext(ctx, m.SQL); err != nil {
|
|
return fmt.Errorf("apply migration %s: %w", m.Name, err)
|
|
}
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO schema_migrations(version, name, applied_at) VALUES(?, ?, ?)`,
|
|
m.Version,
|
|
m.Name,
|
|
timeutil.FormatRFC3339(clock.Now()),
|
|
); err != nil {
|
|
return fmt.Errorf("record migration %s: %w", m.Name, err)
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return fmt.Errorf("commit migration %s: %w", m.Name, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// shouldApplyDropSkillMetadataColumnsMigration is the preflight check for
// migration 32. It reports true if the skills table still has a source_ref or
// asset_root column. It reports false otherwise, including when the skills
// table does not exist, because PRAGMA table_info then yields no rows.
func shouldApplyDropSkillMetadataColumnsMigration(ctx context.Context, db *sql.DB) (bool, error) {
	rows, err := db.QueryContext(ctx, `PRAGMA table_info(skills)`)
	if err != nil {
		return false, fmt.Errorf("inspect skills table: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		// PRAGMA table_info yields (cid, name, type, notnull, dflt_value, pk).
		// Only name is used, but Scan needs a destination for every column.
		var (
			cid          int
			name         string
			dataType     string
			notNull      int
			defaultValue sql.NullString
			pk           int
		)
		if err := rows.Scan(&cid, &name, &dataType, &notNull, &defaultValue, &pk); err != nil {
			return false, fmt.Errorf("scan skills table info: %w", err)
		}
		if name == "source_ref" || name == "asset_root" {
			return true, nil
		}
	}
	if err := rows.Err(); err != nil {
		return false, fmt.Errorf("iterate skills table info: %w", err)
	}
	return false, nil
}
|
|
|
|
func shouldApplyLaneRenameMigration(ctx context.Context, db *sql.DB) (bool, error) {
|
|
hasChains, err := tableExists(ctx, db, "chains")
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
hasLanes, err := tableExists(ctx, db, "lanes")
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if hasChains {
|
|
return true, nil
|
|
}
|
|
if hasLanes {
|
|
return false, nil
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
// tableExists reports whether sqlite_master lists a table with the given
// name. Query and scan errors are returned unwrapped.
func tableExists(ctx context.Context, db *sql.DB, name string) (bool, error) {
	const query = `
		SELECT COUNT(*)
		FROM sqlite_master
		WHERE type = 'table' AND name = ?
	`
	var matches int
	err := db.QueryRowContext(ctx, query, name).Scan(&matches)
	if err != nil {
		return false, err
	}
	return matches > 0, nil
}
|
|
|
|
func mustLoadMigrations() []migration {
|
|
entries, err := migrationFiles.ReadDir("migrations")
|
|
if err != nil {
|
|
panic(fmt.Sprintf("read migrations: %v", err))
|
|
}
|
|
|
|
out := make([]migration, 0, len(entries))
|
|
seen := make(map[int]string, len(entries))
|
|
for _, entry := range entries {
|
|
if entry.IsDir() || path.Ext(entry.Name()) != ".sql" {
|
|
continue
|
|
}
|
|
parts := strings.SplitN(entry.Name(), "_", 2)
|
|
if len(parts) != 2 {
|
|
panic(fmt.Sprintf("migration file %q must start with a numeric prefix", entry.Name()))
|
|
}
|
|
version, err := strconv.Atoi(parts[0])
|
|
if err != nil || version <= 0 {
|
|
panic(fmt.Sprintf("migration file %q has invalid version prefix", entry.Name()))
|
|
}
|
|
if prior, exists := seen[version]; exists {
|
|
panic(fmt.Sprintf("duplicate migration version %d: %q and %q", version, prior, entry.Name()))
|
|
}
|
|
body, err := migrationFiles.ReadFile(path.Join("migrations", entry.Name()))
|
|
if err != nil {
|
|
panic(fmt.Sprintf("read migration %q: %v", entry.Name(), err))
|
|
}
|
|
seen[version] = entry.Name()
|
|
out = append(out, migration{
|
|
Version: version,
|
|
Name: entry.Name(),
|
|
SQL: string(body),
|
|
})
|
|
}
|
|
sort.Slice(out, func(i, j int) bool {
|
|
return out[i].Version < out[j].Version
|
|
})
|
|
for idx, item := range out {
|
|
expected := idx + 1
|
|
if item.Version != expected {
|
|
panic(fmt.Sprintf("migrations must be contiguous: expected version %d, found %d (%s)", expected, item.Version, item.Name))
|
|
}
|
|
}
|
|
return out
|
|
}
|