Files
ai-workflow/inbox/internal/store/sqlite/sqlite.go
T

358 lines
8.9 KiB
Go

package sqlite
import (
"context"
"database/sql"
"embed"
"fmt"
"os"
"path"
"path/filepath"
"sort"
"strconv"
"strings"
_ "modernc.org/sqlite"
"inbox/internal/base/timeutil"
)
// migration is one embedded schema change loaded from the migrations
// directory.
type migration struct {
	Version   int    // numeric prefix parsed from the migration file name
	Name, SQL string // migration file name, and the full text of that file
}
// migrationFiles holds the *.sql files of the migrations directory, embedded
// into the binary at build time.
//
//go:embed migrations/*.sql
var migrationFiles embed.FS

// migrations is the list of embedded migrations, built once during package
// initialization by mustLoadMigrations (which panics on malformed input).
var migrations = mustLoadMigrations()
// Store bundles a database handle with the clock its constructors were given.
type Store struct {
	db    *sql.DB        // underlying database handle; Close treats nil as nothing to close
	clock timeutil.Clock // time source; constructors substitute timeutil.SystemClock for nil
}
// Open opens the SQLite database at dbPath and returns a Store backed by it.
//
// The parent directory of dbPath is created first. The handle is limited to a
// single open/idle connection, then configure, Migrate, ensureBuiltinRoles and
// ensureBuiltinSkills are run in that order. If any of those steps fails the
// handle is closed (best effort) and the step's error is returned. A nil clock
// is replaced with timeutil.SystemClock.
func Open(dbPath string, clock timeutil.Clock) (*Store, error) {
	if clock == nil {
		clock = timeutil.SystemClock{}
	}

	parentDir := filepath.Dir(dbPath)
	if mkErr := os.MkdirAll(parentDir, 0755); mkErr != nil {
		return nil, fmt.Errorf("create db directory: %w", mkErr)
	}

	handle, openErr := sql.Open("sqlite", dbPath)
	if openErr != nil {
		return nil, fmt.Errorf("open sqlite db: %w", openErr)
	}
	handle.SetMaxOpenConns(1)
	handle.SetMaxIdleConns(1)

	// abort releases the handle before reporting a setup failure; the Close
	// error is deliberately dropped so the original cause is what callers see.
	abort := func(cause error) (*Store, error) {
		_ = handle.Close()
		return nil, cause
	}

	ctx := context.Background()
	if err := configure(handle); err != nil {
		return abort(err)
	}
	if err := Migrate(ctx, handle, clock); err != nil {
		return abort(err)
	}

	st := &Store{db: handle, clock: clock}
	if err := st.ensureBuiltinRoles(ctx); err != nil {
		return abort(err)
	}
	if err := st.ensureBuiltinSkills(ctx); err != nil {
		return abort(err)
	}
	return st, nil
}
// OpenInMemory returns a Store backed by a ":memory:" SQLite database.
//
// It performs the same setup as a file-backed store minus directory creation:
// the handle is limited to one connection, then configure, Migrate,
// ensureBuiltinRoles and ensureBuiltinSkills run in order. The first failing
// step closes the handle (best effort) and its error is returned. A nil clock
// is replaced with timeutil.SystemClock.
func OpenInMemory(clock timeutil.Clock) (*Store, error) {
	if clock == nil {
		clock = timeutil.SystemClock{}
	}

	memDB, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		return nil, fmt.Errorf("open sqlite memory db: %w", err)
	}
	// NOTE(review): presumably pinned to one connection so every query sees
	// the same in-memory database — confirm against the driver docs.
	memDB.SetMaxOpenConns(1)
	memDB.SetMaxIdleConns(1)

	st := &Store{db: memDB, clock: clock}
	bg := context.Background()

	// Setup steps, executed strictly in this order.
	steps := []func() error{
		func() error { return configure(memDB) },
		func() error { return Migrate(bg, memDB, clock) },
		func() error { return st.ensureBuiltinRoles(bg) },
		func() error { return st.ensureBuiltinSkills(bg) },
	}
	for _, step := range steps {
		if stepErr := step(); stepErr != nil {
			_ = memDB.Close()
			return nil, stepErr
		}
	}
	return st, nil
}
// New wraps an already-opened database handle in a Store without running any
// configuration, migration or seeding. A nil clock is replaced with
// timeutil.SystemClock.
func New(db *sql.DB, clock timeutil.Clock) *Store {
	st := &Store{db: db, clock: clock}
	if st.clock == nil {
		st.clock = timeutil.SystemClock{}
	}
	return st
}
// DB returns the database handle held by the Store.
func (s *Store) DB() *sql.DB {
	return s.db
}
// Close closes the Store's database handle and returns the result. When the
// Store holds no handle it does nothing and returns nil.
func (s *Store) Close() error {
	if handle := s.db; handle != nil {
		return handle.Close()
	}
	return nil
}
// configure runs the PRAGMA statements this package requires, in a fixed
// order: journal_mode=WAL, foreign_keys=ON, busy_timeout=5000. It stops at the
// first failure and returns that error wrapped with a short description.
func configure(db *sql.DB) error {
	settings := [...]struct {
		pragma    string // statement to execute
		errFormat string // wrapping format used when the statement fails
	}{
		{"PRAGMA journal_mode=WAL", "enable WAL: %w"},
		{"PRAGMA foreign_keys=ON", "enable foreign keys: %w"},
		{"PRAGMA busy_timeout=5000", "set busy timeout: %w"},
	}
	for _, setting := range settings {
		if _, execErr := db.Exec(setting.pragma); execErr != nil {
			return fmt.Errorf(setting.errFormat, execErr)
		}
	}
	return nil
}
// Migrate brings db up to the newest embedded migration.
//
// It makes sure the schema_migrations bookkeeping table exists, loads the set
// of versions already recorded there, and applies every embedded migration
// that is not yet recorded, in ascending version order. A nil clock is
// replaced with timeutil.SystemClock; the clock supplies the applied_at
// timestamps.
//
// It returns an error if the database records a version newer than the newest
// embedded migration (the error names the highest recorded version), or if any
// bookkeeping query or migration fails. Migrations applied before a failure
// stay applied.
func Migrate(ctx context.Context, db *sql.DB, clock timeutil.Clock) error {
	if clock == nil {
		clock = timeutil.SystemClock{}
	}
	if err := ensureSchemaMigrationsTable(db); err != nil {
		return err
	}
	applied, err := loadAppliedVersions(ctx, db)
	if err != nil {
		return err
	}
	if len(migrations) == 0 {
		return nil
	}

	// migrations is sorted ascending, so the last entry is the newest version
	// this build understands.
	supported := migrations[len(migrations)-1].Version

	// Find the highest recorded version first instead of failing on whichever
	// too-new version map iteration happens to yield: the reported version is
	// then the same on every run.
	newest := 0
	for version := range applied {
		if version > newest {
			newest = version
		}
	}
	if newest > supported {
		return fmt.Errorf("database schema version %d is newer than supported version %d", newest, supported)
	}

	for _, pending := range migrations {
		if applied[pending.Version] {
			continue
		}
		if err := applyMigration(ctx, db, pending, clock); err != nil {
			return err
		}
	}
	return nil
}
// ensureSchemaMigrationsTable creates the schema_migrations bookkeeping table
// (version, name, applied_at) unless it already exists.
func ensureSchemaMigrationsTable(db *sql.DB) error {
	const createTable = `
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
name TEXT NOT NULL,
applied_at TEXT NOT NULL
)
`
	if _, execErr := db.Exec(createTable); execErr != nil {
		return fmt.Errorf("create schema_migrations table: %w", execErr)
	}
	return nil
}
// loadAppliedVersions reads every version stored in schema_migrations and
// returns them as a set (version -> true). The returned map is non-nil even
// when the table is empty. Query, scan and iteration failures are each
// returned wrapped with a distinct prefix.
func loadAppliedVersions(ctx context.Context, db *sql.DB) (map[int]bool, error) {
	cursor, queryErr := db.QueryContext(ctx, `SELECT version FROM schema_migrations`)
	if queryErr != nil {
		return nil, fmt.Errorf("load schema migrations: %w", queryErr)
	}
	defer cursor.Close()

	seen := map[int]bool{}
	for cursor.Next() {
		var v int
		scanErr := cursor.Scan(&v)
		if scanErr != nil {
			return nil, fmt.Errorf("scan schema migration version: %w", scanErr)
		}
		seen[v] = true
	}
	// Next returning false can mean "no more rows" or "iteration failed";
	// Err tells the two apart.
	if iterErr := cursor.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate schema migrations: %w", iterErr)
	}
	return seen, nil
}
// applyMigration applies one migration and records it in schema_migrations.
//
// Three versions are special-cased:
//   - 12 and 32 run a preflight check first. If the check says the migration
//     is not needed, only the schema_migrations row is written (outside any
//     transaction) and the migration's SQL is skipped.
//   - 30 is delegated entirely to applyRoleConfigCollapseMigration.
//
// Every other migration (and 12/32 when the preflight says to apply) executes
// its SQL and the bookkeeping insert inside one transaction, so either both
// take effect or neither does.
func applyMigration(ctx context.Context, db *sql.DB, m migration, clock timeutil.Clock) error {
	const recordSQL = `INSERT INTO schema_migrations(version, name, applied_at) VALUES(?, ?, ?)`

	// Select the version-specific handling, if any.
	var preflight func(context.Context, *sql.DB) (bool, error)
	switch m.Version {
	case 12:
		preflight = shouldApplyLaneRenameMigration
	case 32:
		preflight = shouldApplyDropSkillMetadataColumnsMigration
	case 30:
		return applyRoleConfigCollapseMigration(ctx, db, m, clock)
	}

	if preflight != nil {
		needed, checkErr := preflight(ctx, db)
		if checkErr != nil {
			return fmt.Errorf("preflight migration %s: %w", m.Name, checkErr)
		}
		if !needed {
			// Mark the migration as applied without running its SQL.
			_, skipErr := db.ExecContext(ctx, recordSQL, m.Version, m.Name, timeutil.FormatRFC3339(clock.Now()))
			if skipErr != nil {
				return fmt.Errorf("record skipped migration %s: %w", m.Name, skipErr)
			}
			return nil
		}
	}

	txn, beginErr := db.BeginTx(ctx, nil)
	if beginErr != nil {
		return fmt.Errorf("begin migration %s: %w", m.Name, beginErr)
	}
	// Undoes the work on every early return; its result is intentionally
	// ignored, including on the success path where Commit has already run.
	defer txn.Rollback()

	if _, execErr := txn.ExecContext(ctx, m.SQL); execErr != nil {
		return fmt.Errorf("apply migration %s: %w", m.Name, execErr)
	}
	appliedAt := timeutil.FormatRFC3339(clock.Now())
	if _, recordErr := txn.ExecContext(ctx, recordSQL, m.Version, m.Name, appliedAt); recordErr != nil {
		return fmt.Errorf("record migration %s: %w", m.Name, recordErr)
	}
	if commitErr := txn.Commit(); commitErr != nil {
		return fmt.Errorf("commit migration %s: %w", m.Name, commitErr)
	}
	return nil
}
// shouldApplyDropSkillMetadataColumnsMigration is the preflight check for
// migration 32. It inspects the columns of the skills table through
// PRAGMA table_info and reports true when a column named "source_ref" or
// "asset_root" is present.
func shouldApplyDropSkillMetadataColumnsMigration(ctx context.Context, db *sql.DB) (bool, error) {
	info, queryErr := db.QueryContext(ctx, `PRAGMA table_info(skills)`)
	if queryErr != nil {
		return false, fmt.Errorf("inspect skills table: %w", queryErr)
	}
	defer info.Close()

	// Scan targets for the six columns each result row is read into; only
	// name is inspected, the rest exist to satisfy Scan.
	var col struct {
		cid          int
		name         string
		dataType     string
		notNull      int
		defaultValue sql.NullString
		pk           int
	}

	legacyColumn := false
	// Every row is consumed (no early exit) so an iteration error is still
	// reported after a match.
	for info.Next() {
		scanErr := info.Scan(&col.cid, &col.name, &col.dataType, &col.notNull, &col.defaultValue, &col.pk)
		if scanErr != nil {
			return false, fmt.Errorf("scan skills table info: %w", scanErr)
		}
		if col.name == "source_ref" || col.name == "asset_root" {
			legacyColumn = true
		}
	}
	if iterErr := info.Err(); iterErr != nil {
		return false, fmt.Errorf("iterate skills table info: %w", iterErr)
	}
	return legacyColumn, nil
}
// shouldApplyLaneRenameMigration is the preflight check for migration 12. It
// reports true exactly when a table named "chains" exists.
//
// Whether a "lanes" table exists does not influence the answer: without
// "chains" the result is false either way, so only "chains" is looked up.
// NOTE(review): presumably migration 12 renames chains to lanes, which is why
// a database without chains has nothing to do — confirm against the
// migration's SQL.
func shouldApplyLaneRenameMigration(ctx context.Context, db *sql.DB) (bool, error) {
	return tableExists(ctx, db, "chains")
}
// tableExists reports whether sqlite_master contains an entry of type 'table'
// with exactly the given name. A failed lookup is returned wrapped with the
// table name so the caller's error chain says which check failed.
func tableExists(ctx context.Context, db *sql.DB, name string) (bool, error) {
	var count int
	err := db.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM sqlite_master
WHERE type = 'table' AND name = ?
`, name).Scan(&count)
	if err != nil {
		return false, fmt.Errorf("check table %q exists: %w", name, err)
	}
	return count > 0, nil
}
// mustLoadMigrations reads the embedded migrations directory and returns its
// .sql files as migrations sorted by ascending version.
//
// A file's version is the integer before the first underscore in its name.
// Directories and non-.sql entries are ignored. The function panics when the
// directory or a file cannot be read, a name has no underscore, the prefix is
// not a positive integer, two files share a version, or the sorted versions
// are not exactly 1, 2, 3, ... with no gaps.
func mustLoadMigrations() []migration {
	dirEntries, readErr := migrationFiles.ReadDir("migrations")
	if readErr != nil {
		panic(fmt.Sprintf("read migrations: %v", readErr))
	}

	nameByVersion := make(map[int]string, len(dirEntries))
	loaded := make([]migration, 0, len(dirEntries))
	for _, dirEntry := range dirEntries {
		fileName := dirEntry.Name()
		if dirEntry.IsDir() || path.Ext(fileName) != ".sql" {
			continue
		}

		// Everything before the first underscore is the version prefix.
		sep := strings.Index(fileName, "_")
		if sep < 0 {
			panic(fmt.Sprintf("migration file %q must start with a numeric prefix", fileName))
		}
		number, convErr := strconv.Atoi(fileName[:sep])
		if convErr != nil || number < 1 {
			panic(fmt.Sprintf("migration file %q has invalid version prefix", fileName))
		}
		if earlier, taken := nameByVersion[number]; taken {
			panic(fmt.Sprintf("duplicate migration version %d: %q and %q", number, earlier, fileName))
		}

		contents, fileErr := migrationFiles.ReadFile(path.Join("migrations", fileName))
		if fileErr != nil {
			panic(fmt.Sprintf("read migration %q: %v", fileName, fileErr))
		}

		nameByVersion[number] = fileName
		loaded = append(loaded, migration{Version: number, Name: fileName, SQL: string(contents)})
	}

	sort.Slice(loaded, func(a, b int) bool {
		return loaded[a].Version < loaded[b].Version
	})

	// After sorting, position i must hold version i+1; anything else means a
	// gap in the numbering.
	for pos := range loaded {
		if want := pos + 1; loaded[pos].Version != want {
			panic(fmt.Sprintf("migrations must be contiguous: expected version %d, found %d (%s)", want, loaded[pos].Version, loaded[pos].Name))
		}
	}
	return loaded
}