diff --git a/internal/cli/inbox/integration_test.go b/internal/cli/inbox/integration_test.go index b916fb2..2ce6827 100644 --- a/internal/cli/inbox/integration_test.go +++ b/internal/cli/inbox/integration_test.go @@ -608,14 +608,28 @@ func TestInboxJSONErrorsAndExitCodes(t *testing.T) { "send", "--from", "leader", "--to", "worker-z", - "--subject", "Invalid body flags", - "--body", "inline", - "--body-file", filepath.Join(t.TempDir(), "missing.md"), + "--subject", "Invalid payload json", + "--payload-json", "not-json", ) if exitCode != 30 { t.Fatalf("expected invalid input exit code 30, got %d", exitCode) } assertErrorJSON(t, stdout, "invalid_input") + + stdout, _, exitCode = executeInboxCommand( + "--db", dbPath, + "--json", + "send", + "--from", "leader", + "--to", "worker-z", + "--subject", "Invalid artifact json", + "--artifact", "/tmp/report.md", + "--artifact-metadata-json", "not-json", + ) + if exitCode != 30 { + t.Fatalf("expected invalid artifact metadata exit code 30, got %d", exitCode) + } + assertErrorJSON(t, stdout, "invalid_input") } func runInboxCommand(t *testing.T, args ...string) string { diff --git a/internal/store/inbox.go b/internal/store/inbox.go index 23f364c..5120cd6 100644 --- a/internal/store/inbox.go +++ b/internal/store/inbox.go @@ -1,6 +1,7 @@ package store import ( + "bytes" "context" "database/sql" "encoding/json" @@ -226,7 +227,10 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread, kind := defaultString(input.Kind, "task") priority := defaultString(input.Priority, "normal") summary := defaultString(input.Summary, input.Subject) - payload := normalizeJSON(input.PayloadJSON) + payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON) + if err != nil { + return Thread{}, Message{}, err + } messageID := newID("msg") tx, err := s.db.BeginTx(ctx, nil) @@ -314,7 +318,10 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread, func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, input SendInput) (Thread, Message, error) { now := nowUTC() messageID := newID("msg") - payload := normalizeJSON(input.PayloadJSON) + payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON) + if err != nil { + return Thread{}, Message{}, err + } tx, err := s.db.BeginTx(ctx, nil) if err != nil { @@ -480,6 +487,38 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe input.LeaseSeconds = 900 } + var lastBusyErr error + for attempt := 0; attempt < 20; attempt++ { + result, err := s.claimThreadOnce(ctx, input) + if err == nil { + return result, nil + } + if !isSQLiteBusyError(err) { + return ClaimResult{}, err + } + lastBusyErr = err + + ok, waitErr := waitForNextPoll(ctx, 25*time.Millisecond) + if waitErr != nil { + return ClaimResult{}, waitErr + } + if !ok { + break + } + } + + if resolvedErr := s.classifyClaimConflict(ctx, input.ThreadID); resolvedErr != nil { + return ClaimResult{}, resolvedErr + } + + return ClaimResult{}, fmt.Errorf("claim thread: %w", lastBusyErr) +} + +func (s *InboxStore) claimThreadOnce(ctx context.Context, input ClaimInput) (ClaimResult, error) { + if input.LeaseSeconds <= 0 { + input.LeaseSeconds = 900 + } + now := nowUTC() expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second) leaseToken := newID("lease") @@ -519,7 +558,7 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID) } - if _, err := tx.ExecContext( + result, err := tx.ExecContext( ctx, `INSERT INTO leases ( thread_id, agent_id, lease_token, claimed_at, expires_at, released_at @@ -529,29 +568,41 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe lease_token = excluded.lease_token, claimed_at = excluded.claimed_at, expires_at = excluded.expires_at, - released_at = NULL`, + released_at = NULL + WHERE leases.released_at IS NOT NULL + OR leases.expires_at <= excluded.claimed_at`, input.ThreadID, input.Agent, leaseToken, formatTime(now), formatTime(expiresAt), - ); err != nil { + ) + if err != nil { return ClaimResult{}, fmt.Errorf("upsert lease: %w", err) } + if affected, err := result.RowsAffected(); err == nil && affected == 0 { + return ClaimResult{}, ErrLeaseConflict + } - if _, err := tx.ExecContext( + result, err = tx.ExecContext( ctx, `UPDATE threads SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ? - WHERE thread_id = ?`, + WHERE thread_id = ? + AND status = ?`, "claimed", input.Agent, messageID, formatTime(now), input.ThreadID, - ); err != nil { + "pending", + ) + if err != nil { return ClaimResult{}, fmt.Errorf("update thread claim status: %w", err) } + if affected, err := result.RowsAffected(); err == nil && affected == 0 { + return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID) + } message := Message{ MessageID: messageID, @@ -702,6 +753,10 @@ func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimRes func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) { now := nowUTC() messageID := newID("msg") + payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON) + if err != nil { + return Thread{}, Message{}, err + } if input.Status != "in_progress" && input.Status != "blocked" { return Thread{}, Message{}, fmt.Errorf("%w: unsupported update status %q", ErrInvalidInput, input.Status) @@ -737,7 +792,7 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) Kind: kind, Summary: input.Summary, Body: input.Body, - PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)), + PayloadJSON: json.RawMessage(payload), CreatedAt: now, } @@ -781,6 +836,10 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Thread, Message, error) { now := nowUTC() messageID := newID("msg") + payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON) + if err != nil { + return Thread{}, Message{}, err + } tx, err := s.db.BeginTx(ctx, nil) if err != nil { @@ -804,7 +863,7 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa Kind: defaultString(input.Kind, "answer"), Summary: input.Summary, Body: input.Body, - PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)), + PayloadJSON: json.RawMessage(payload), CreatedAt: now, } @@ -847,6 +906,10 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (Thread, Message, error) { now := nowUTC() messageID := newID("msg") + payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON) + if err != nil { + return Thread{}, Message{}, err + } nextStatus := "done" eventType := "thread_done" @@ -881,7 +944,7 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T Kind: "result", Summary: summary, Body: input.Body, - PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)), + PayloadJSON: json.RawMessage(payload), CreatedAt: now, } @@ -1365,16 +1428,21 @@ func insertArtifacts(ctx context.Context, tx *sql.Tx, messageID string, inputs [ artifacts := make([]Artifact, 0, len(inputs)) for _, input := range inputs { + metadataJSON, err := validateAndNormalizeJSON("artifact-metadata-json", input.MetadataJSON) + if err != nil { + return nil, err + } + artifact := Artifact{ ArtifactID: newID("art"), MessageID: messageID, Path: input.Path, Kind: defaultString(input.Kind, "file"), - MetadataJSON: json.RawMessage(normalizeJSON(input.MetadataJSON)), + MetadataJSON: json.RawMessage(metadataJSON), CreatedAt: createdAt, } - _, err := tx.ExecContext( + _, err = tx.ExecContext( ctx, `INSERT INTO artifacts ( artifact_id, message_id, path, kind, metadata_json, created_at @@ -1468,6 +1536,37 @@ func messageIDs(messages []Message) []string { return ids } +func (s *InboxStore) classifyClaimConflict(ctx context.Context, threadID string) error { + thread, err := selectThread(ctx, s.db, threadID) + if err != nil { + return err + } + + now := nowUTC() + var activeLease string + err = s.db.QueryRowContext( + ctx, + `SELECT agent_id + FROM leases + WHERE thread_id = ? + AND released_at IS NULL + AND expires_at > ?`, + threadID, + formatTime(now), + ).Scan(&activeLease) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("check active lease after busy claim: %w", err) + } + if activeLease != "" { + return ErrLeaseConflict + } + if thread.Status != "pending" { + return fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, threadID) + } + + return nil +} + func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) { var ( activeAgent string @@ -1703,6 +1802,13 @@ func isDeadlineExceeded(ctx context.Context) bool { return ctx.Err() != nil && errors.Is(ctx.Err(), context.DeadlineExceeded) } +func isSQLiteBusyError(err error) bool { + message := strings.ToLower(err.Error()) + return strings.Contains(message, "sqlite_busy") || + strings.Contains(message, "database is locked") || + strings.Contains(message, "database table is locked") +} + func defaultID(value, prefix string) string { if value != "" { return value @@ -1728,6 +1834,20 @@ func normalizeJSON(value string) string { return value } +func validateAndNormalizeJSON(fieldName, value string) (string, error) { + normalized := normalizeJSON(value) + if !json.Valid([]byte(normalized)) { + return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName) + } + + var compact bytes.Buffer + if err := json.Compact(&compact, []byte(normalized)); err != nil { + return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName) + } + + return compact.String(), nil +} + func placeholders(n int) string { if n <= 0 { return "" diff --git a/internal/store/inbox_test.go b/internal/store/inbox_test.go new file mode 100644 index 0000000..53ed9c9 --- /dev/null +++ b/internal/store/inbox_test.go @@ -0,0 +1,107 @@ +package store + +import ( + "context" + "errors" + "path/filepath" + "testing" + "time" + + dbpkg "ai-workflow-skill/internal/db" +) + +func TestClaimThreadReturnsLeaseConflictAfterBusyWrite(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + dbPath := filepath.Join(t.TempDir(), "coord.db") + + sqlDB, err := dbpkg.Open(ctx, dbPath) + if err != nil { + t.Fatalf("open base db: %v", err) + } + defer sqlDB.Close() + + if err := dbpkg.ApplyMigrations(ctx, sqlDB); err != nil { + t.Fatalf("apply migrations: %v", err) + } + + baseStore := NewInboxStore(sqlDB) + thread, _, err := baseStore.Send(ctx, SendInput{ + FromAgent: "leader", + ToAgent: "worker-a", + Subject: "race claim", + Summary: "race claim", + }) + if err != nil { + t.Fatalf("seed thread: %v", err) + } + + lockerDB, err := dbpkg.Open(ctx, dbPath) + if err != nil { + t.Fatalf("open locker db: %v", err) + } + defer lockerDB.Close() + + lockTx, err := lockerDB.BeginTx(ctx, nil) + if err != nil { + t.Fatalf("begin locker tx: %v", err) + } + + now := nowUTC() + if _, err := lockTx.ExecContext( + ctx, + `INSERT INTO leases ( + thread_id, agent_id, lease_token, claimed_at, expires_at, released_at + ) VALUES (?, ?, ?, ?, ?, NULL)`, + thread.ThreadID, + "worker-a", + "lease_locked", + formatTime(now), + formatTime(now.Add(5*time.Minute)), + ); err != nil { + t.Fatalf("seed active lease in tx: %v", err) + } + + if _, err := lockTx.ExecContext( + ctx, + `UPDATE threads + SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ? + WHERE thread_id = ?`, + "claimed", + "worker-a", + "msg_locked", + formatTime(now), + thread.ThreadID, + ); err != nil { + t.Fatalf("seed claimed thread in tx: %v", err) + } + + commitDone := make(chan error, 1) + go func() { + time.Sleep(100 * time.Millisecond) + commitDone <- lockTx.Commit() + }() + + claimDB, err := dbpkg.Open(ctx, dbPath) + if err != nil { + t.Fatalf("open claim db: %v", err) + } + defer claimDB.Close() + + claimStore := NewInboxStore(claimDB) + _, err = claimStore.ClaimThread(ctx, ClaimInput{ + ThreadID: thread.ThreadID, + Agent: "worker-b", + LeaseSeconds: 300, + }) + if !errors.Is(err, ErrLeaseConflict) { + t.Fatalf("expected lease conflict after busy retry, got %v", err) + } + + if err := <-commitDone; err != nil { + t.Fatalf("commit locker tx: %v", err) + } +}