Harden inbox JSON validation and claim conflicts
This commit is contained in:
@@ -608,14 +608,28 @@ func TestInboxJSONErrorsAndExitCodes(t *testing.T) {
|
||||
"send",
|
||||
"--from", "leader",
|
||||
"--to", "worker-z",
|
||||
"--subject", "Invalid body flags",
|
||||
"--body", "inline",
|
||||
"--body-file", filepath.Join(t.TempDir(), "missing.md"),
|
||||
"--subject", "Invalid payload json",
|
||||
"--payload-json", "not-json",
|
||||
)
|
||||
if exitCode != 30 {
|
||||
t.Fatalf("expected invalid input exit code 30, got %d", exitCode)
|
||||
}
|
||||
assertErrorJSON(t, stdout, "invalid_input")
|
||||
|
||||
stdout, _, exitCode = executeInboxCommand(
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"send",
|
||||
"--from", "leader",
|
||||
"--to", "worker-z",
|
||||
"--subject", "Invalid artifact json",
|
||||
"--artifact", "/tmp/report.md",
|
||||
"--artifact-metadata-json", "not-json",
|
||||
)
|
||||
if exitCode != 30 {
|
||||
t.Fatalf("expected invalid artifact metadata exit code 30, got %d", exitCode)
|
||||
}
|
||||
assertErrorJSON(t, stdout, "invalid_input")
|
||||
}
|
||||
|
||||
func runInboxCommand(t *testing.T, args ...string) string {
|
||||
|
||||
+133
-13
@@ -1,6 +1,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
@@ -226,7 +227,10 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread,
|
||||
kind := defaultString(input.Kind, "task")
|
||||
priority := defaultString(input.Priority, "normal")
|
||||
summary := defaultString(input.Summary, input.Subject)
|
||||
payload := normalizeJSON(input.PayloadJSON)
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
messageID := newID("msg")
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
@@ -314,7 +318,10 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread,
|
||||
func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, input SendInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
messageID := newID("msg")
|
||||
payload := normalizeJSON(input.PayloadJSON)
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -480,6 +487,38 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
|
||||
input.LeaseSeconds = 900
|
||||
}
|
||||
|
||||
var lastBusyErr error
|
||||
for attempt := 0; attempt < 20; attempt++ {
|
||||
result, err := s.claimThreadOnce(ctx, input)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
if !isSQLiteBusyError(err) {
|
||||
return ClaimResult{}, err
|
||||
}
|
||||
lastBusyErr = err
|
||||
|
||||
ok, waitErr := waitForNextPoll(ctx, 25*time.Millisecond)
|
||||
if waitErr != nil {
|
||||
return ClaimResult{}, waitErr
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if resolvedErr := s.classifyClaimConflict(ctx, input.ThreadID); resolvedErr != nil {
|
||||
return ClaimResult{}, resolvedErr
|
||||
}
|
||||
|
||||
return ClaimResult{}, fmt.Errorf("claim thread: %w", lastBusyErr)
|
||||
}
|
||||
|
||||
func (s *InboxStore) claimThreadOnce(ctx context.Context, input ClaimInput) (ClaimResult, error) {
|
||||
if input.LeaseSeconds <= 0 {
|
||||
input.LeaseSeconds = 900
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
|
||||
leaseToken := newID("lease")
|
||||
@@ -519,7 +558,7 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
|
||||
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
result, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO leases (
|
||||
thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
|
||||
@@ -529,29 +568,41 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
|
||||
lease_token = excluded.lease_token,
|
||||
claimed_at = excluded.claimed_at,
|
||||
expires_at = excluded.expires_at,
|
||||
released_at = NULL`,
|
||||
released_at = NULL
|
||||
WHERE leases.released_at IS NOT NULL
|
||||
OR leases.expires_at <= excluded.claimed_at`,
|
||||
input.ThreadID,
|
||||
input.Agent,
|
||||
leaseToken,
|
||||
formatTime(now),
|
||||
formatTime(expiresAt),
|
||||
); err != nil {
|
||||
)
|
||||
if err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("upsert lease: %w", err)
|
||||
}
|
||||
if affected, err := result.RowsAffected(); err == nil && affected == 0 {
|
||||
return ClaimResult{}, ErrLeaseConflict
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
result, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE threads
|
||||
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
|
||||
WHERE thread_id = ?`,
|
||||
WHERE thread_id = ?
|
||||
AND status = ?`,
|
||||
"claimed",
|
||||
input.Agent,
|
||||
messageID,
|
||||
formatTime(now),
|
||||
input.ThreadID,
|
||||
); err != nil {
|
||||
"pending",
|
||||
)
|
||||
if err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("update thread claim status: %w", err)
|
||||
}
|
||||
if affected, err := result.RowsAffected(); err == nil && affected == 0 {
|
||||
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
message := Message{
|
||||
MessageID: messageID,
|
||||
@@ -702,6 +753,10 @@ func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimRes
|
||||
func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
messageID := newID("msg")
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
if input.Status != "in_progress" && input.Status != "blocked" {
|
||||
return Thread{}, Message{}, fmt.Errorf("%w: unsupported update status %q", ErrInvalidInput, input.Status)
|
||||
@@ -737,7 +792,7 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
|
||||
Kind: kind,
|
||||
Summary: input.Summary,
|
||||
Body: input.Body,
|
||||
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
||||
PayloadJSON: json.RawMessage(payload),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
@@ -781,6 +836,10 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
|
||||
func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
messageID := newID("msg")
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -804,7 +863,7 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
|
||||
Kind: defaultString(input.Kind, "answer"),
|
||||
Summary: input.Summary,
|
||||
Body: input.Body,
|
||||
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
||||
PayloadJSON: json.RawMessage(payload),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
@@ -847,6 +906,10 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
|
||||
func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
messageID := newID("msg")
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
nextStatus := "done"
|
||||
eventType := "thread_done"
|
||||
@@ -881,7 +944,7 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T
|
||||
Kind: "result",
|
||||
Summary: summary,
|
||||
Body: input.Body,
|
||||
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
||||
PayloadJSON: json.RawMessage(payload),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
@@ -1365,16 +1428,21 @@ func insertArtifacts(ctx context.Context, tx *sql.Tx, messageID string, inputs [
|
||||
|
||||
artifacts := make([]Artifact, 0, len(inputs))
|
||||
for _, input := range inputs {
|
||||
metadataJSON, err := validateAndNormalizeJSON("artifact-metadata-json", input.MetadataJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
artifact := Artifact{
|
||||
ArtifactID: newID("art"),
|
||||
MessageID: messageID,
|
||||
Path: input.Path,
|
||||
Kind: defaultString(input.Kind, "file"),
|
||||
MetadataJSON: json.RawMessage(normalizeJSON(input.MetadataJSON)),
|
||||
MetadataJSON: json.RawMessage(metadataJSON),
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
_, err := tx.ExecContext(
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO artifacts (
|
||||
artifact_id, message_id, path, kind, metadata_json, created_at
|
||||
@@ -1468,6 +1536,37 @@ func messageIDs(messages []Message) []string {
|
||||
return ids
|
||||
}
|
||||
|
||||
func (s *InboxStore) classifyClaimConflict(ctx context.Context, threadID string) error {
|
||||
thread, err := selectThread(ctx, s.db, threadID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
var activeLease string
|
||||
err = s.db.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT agent_id
|
||||
FROM leases
|
||||
WHERE thread_id = ?
|
||||
AND released_at IS NULL
|
||||
AND expires_at > ?`,
|
||||
threadID,
|
||||
formatTime(now),
|
||||
).Scan(&activeLease)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return fmt.Errorf("check active lease after busy claim: %w", err)
|
||||
}
|
||||
if activeLease != "" {
|
||||
return ErrLeaseConflict
|
||||
}
|
||||
if thread.Status != "pending" {
|
||||
return fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, threadID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) {
|
||||
var (
|
||||
activeAgent string
|
||||
@@ -1703,6 +1802,13 @@ func isDeadlineExceeded(ctx context.Context) bool {
|
||||
return ctx.Err() != nil && errors.Is(ctx.Err(), context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func isSQLiteBusyError(err error) bool {
|
||||
message := strings.ToLower(err.Error())
|
||||
return strings.Contains(message, "sqlite_busy") ||
|
||||
strings.Contains(message, "database is locked") ||
|
||||
strings.Contains(message, "database table is locked")
|
||||
}
|
||||
|
||||
func defaultID(value, prefix string) string {
|
||||
if value != "" {
|
||||
return value
|
||||
@@ -1728,6 +1834,20 @@ func normalizeJSON(value string) string {
|
||||
return value
|
||||
}
|
||||
|
||||
func validateAndNormalizeJSON(fieldName, value string) (string, error) {
|
||||
normalized := normalizeJSON(value)
|
||||
if !json.Valid([]byte(normalized)) {
|
||||
return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
|
||||
}
|
||||
|
||||
var compact bytes.Buffer
|
||||
if err := json.Compact(&compact, []byte(normalized)); err != nil {
|
||||
return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
|
||||
}
|
||||
|
||||
return compact.String(), nil
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbpkg "ai-workflow-skill/internal/db"
|
||||
)
|
||||
|
||||
func TestClaimThreadReturnsLeaseConflictAfterBusyWrite(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "coord.db")
|
||||
|
||||
sqlDB, err := dbpkg.Open(ctx, dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open base db: %v", err)
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
|
||||
if err := dbpkg.ApplyMigrations(ctx, sqlDB); err != nil {
|
||||
t.Fatalf("apply migrations: %v", err)
|
||||
}
|
||||
|
||||
baseStore := NewInboxStore(sqlDB)
|
||||
thread, _, err := baseStore.Send(ctx, SendInput{
|
||||
FromAgent: "leader",
|
||||
ToAgent: "worker-a",
|
||||
Subject: "race claim",
|
||||
Summary: "race claim",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("seed thread: %v", err)
|
||||
}
|
||||
|
||||
lockerDB, err := dbpkg.Open(ctx, dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open locker db: %v", err)
|
||||
}
|
||||
defer lockerDB.Close()
|
||||
|
||||
lockTx, err := lockerDB.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("begin locker tx: %v", err)
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
if _, err := lockTx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO leases (
|
||||
thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
|
||||
) VALUES (?, ?, ?, ?, ?, NULL)`,
|
||||
thread.ThreadID,
|
||||
"worker-a",
|
||||
"lease_locked",
|
||||
formatTime(now),
|
||||
formatTime(now.Add(5*time.Minute)),
|
||||
); err != nil {
|
||||
t.Fatalf("seed active lease in tx: %v", err)
|
||||
}
|
||||
|
||||
if _, err := lockTx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE threads
|
||||
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
|
||||
WHERE thread_id = ?`,
|
||||
"claimed",
|
||||
"worker-a",
|
||||
"msg_locked",
|
||||
formatTime(now),
|
||||
thread.ThreadID,
|
||||
); err != nil {
|
||||
t.Fatalf("seed claimed thread in tx: %v", err)
|
||||
}
|
||||
|
||||
commitDone := make(chan error, 1)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
commitDone <- lockTx.Commit()
|
||||
}()
|
||||
|
||||
claimDB, err := dbpkg.Open(ctx, dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open claim db: %v", err)
|
||||
}
|
||||
defer claimDB.Close()
|
||||
|
||||
claimStore := NewInboxStore(claimDB)
|
||||
_, err = claimStore.ClaimThread(ctx, ClaimInput{
|
||||
ThreadID: thread.ThreadID,
|
||||
Agent: "worker-b",
|
||||
LeaseSeconds: 300,
|
||||
})
|
||||
if !errors.Is(err, ErrLeaseConflict) {
|
||||
t.Fatalf("expected lease conflict after busy retry, got %v", err)
|
||||
}
|
||||
|
||||
if err := <-commitDone; err != nil {
|
||||
t.Fatalf("commit locker tx: %v", err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user