Harden inbox JSON validation and claim conflicts

This commit is contained in:
2026-03-19 03:35:37 +08:00
parent f315d2330d
commit 02d98a78dd
3 changed files with 257 additions and 16 deletions
+17 -3
View File
@@ -608,14 +608,28 @@ func TestInboxJSONErrorsAndExitCodes(t *testing.T) {
"send",
"--from", "leader",
"--to", "worker-z",
"--subject", "Invalid body flags",
"--body", "inline",
"--body-file", filepath.Join(t.TempDir(), "missing.md"),
"--subject", "Invalid payload json",
"--payload-json", "not-json",
)
if exitCode != 30 {
t.Fatalf("expected invalid input exit code 30, got %d", exitCode)
}
assertErrorJSON(t, stdout, "invalid_input")
stdout, _, exitCode = executeInboxCommand(
"--db", dbPath,
"--json",
"send",
"--from", "leader",
"--to", "worker-z",
"--subject", "Invalid artifact json",
"--artifact", "/tmp/report.md",
"--artifact-metadata-json", "not-json",
)
if exitCode != 30 {
t.Fatalf("expected invalid artifact metadata exit code 30, got %d", exitCode)
}
assertErrorJSON(t, stdout, "invalid_input")
}
func runInboxCommand(t *testing.T, args ...string) string {
+133 -13
View File
@@ -1,6 +1,7 @@
package store
import (
"bytes"
"context"
"database/sql"
"encoding/json"
@@ -226,7 +227,10 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread,
kind := defaultString(input.Kind, "task")
priority := defaultString(input.Priority, "normal")
summary := defaultString(input.Summary, input.Subject)
payload := normalizeJSON(input.PayloadJSON)
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
if err != nil {
return Thread{}, Message{}, err
}
messageID := newID("msg")
tx, err := s.db.BeginTx(ctx, nil)
@@ -314,7 +318,10 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread,
func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, input SendInput) (Thread, Message, error) {
now := nowUTC()
messageID := newID("msg")
payload := normalizeJSON(input.PayloadJSON)
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
if err != nil {
return Thread{}, Message{}, err
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
@@ -480,6 +487,38 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
input.LeaseSeconds = 900
}
var lastBusyErr error
for attempt := 0; attempt < 20; attempt++ {
result, err := s.claimThreadOnce(ctx, input)
if err == nil {
return result, nil
}
if !isSQLiteBusyError(err) {
return ClaimResult{}, err
}
lastBusyErr = err
ok, waitErr := waitForNextPoll(ctx, 25*time.Millisecond)
if waitErr != nil {
return ClaimResult{}, waitErr
}
if !ok {
break
}
}
if resolvedErr := s.classifyClaimConflict(ctx, input.ThreadID); resolvedErr != nil {
return ClaimResult{}, resolvedErr
}
return ClaimResult{}, fmt.Errorf("claim thread: %w", lastBusyErr)
}
func (s *InboxStore) claimThreadOnce(ctx context.Context, input ClaimInput) (ClaimResult, error) {
if input.LeaseSeconds <= 0 {
input.LeaseSeconds = 900
}
now := nowUTC()
expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
leaseToken := newID("lease")
@@ -519,7 +558,7 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
}
if _, err := tx.ExecContext(
result, err := tx.ExecContext(
ctx,
`INSERT INTO leases (
thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
@@ -529,29 +568,41 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
lease_token = excluded.lease_token,
claimed_at = excluded.claimed_at,
expires_at = excluded.expires_at,
released_at = NULL`,
released_at = NULL
WHERE leases.released_at IS NOT NULL
OR leases.expires_at <= excluded.claimed_at`,
input.ThreadID,
input.Agent,
leaseToken,
formatTime(now),
formatTime(expiresAt),
); err != nil {
)
if err != nil {
return ClaimResult{}, fmt.Errorf("upsert lease: %w", err)
}
if affected, err := result.RowsAffected(); err == nil && affected == 0 {
return ClaimResult{}, ErrLeaseConflict
}
if _, err := tx.ExecContext(
result, err = tx.ExecContext(
ctx,
`UPDATE threads
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
WHERE thread_id = ?`,
WHERE thread_id = ?
AND status = ?`,
"claimed",
input.Agent,
messageID,
formatTime(now),
input.ThreadID,
); err != nil {
"pending",
)
if err != nil {
return ClaimResult{}, fmt.Errorf("update thread claim status: %w", err)
}
if affected, err := result.RowsAffected(); err == nil && affected == 0 {
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
}
message := Message{
MessageID: messageID,
@@ -702,6 +753,10 @@ func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimRes
func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) {
now := nowUTC()
messageID := newID("msg")
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
if err != nil {
return Thread{}, Message{}, err
}
if input.Status != "in_progress" && input.Status != "blocked" {
return Thread{}, Message{}, fmt.Errorf("%w: unsupported update status %q", ErrInvalidInput, input.Status)
@@ -737,7 +792,7 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
Kind: kind,
Summary: input.Summary,
Body: input.Body,
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
PayloadJSON: json.RawMessage(payload),
CreatedAt: now,
}
@@ -781,6 +836,10 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Thread, Message, error) {
now := nowUTC()
messageID := newID("msg")
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
if err != nil {
return Thread{}, Message{}, err
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
@@ -804,7 +863,7 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
Kind: defaultString(input.Kind, "answer"),
Summary: input.Summary,
Body: input.Body,
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
PayloadJSON: json.RawMessage(payload),
CreatedAt: now,
}
@@ -847,6 +906,10 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (Thread, Message, error) {
now := nowUTC()
messageID := newID("msg")
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
if err != nil {
return Thread{}, Message{}, err
}
nextStatus := "done"
eventType := "thread_done"
@@ -881,7 +944,7 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T
Kind: "result",
Summary: summary,
Body: input.Body,
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
PayloadJSON: json.RawMessage(payload),
CreatedAt: now,
}
@@ -1365,16 +1428,21 @@ func insertArtifacts(ctx context.Context, tx *sql.Tx, messageID string, inputs [
artifacts := make([]Artifact, 0, len(inputs))
for _, input := range inputs {
metadataJSON, err := validateAndNormalizeJSON("artifact-metadata-json", input.MetadataJSON)
if err != nil {
return nil, err
}
artifact := Artifact{
ArtifactID: newID("art"),
MessageID: messageID,
Path: input.Path,
Kind: defaultString(input.Kind, "file"),
MetadataJSON: json.RawMessage(normalizeJSON(input.MetadataJSON)),
MetadataJSON: json.RawMessage(metadataJSON),
CreatedAt: createdAt,
}
_, err := tx.ExecContext(
_, err = tx.ExecContext(
ctx,
`INSERT INTO artifacts (
artifact_id, message_id, path, kind, metadata_json, created_at
@@ -1468,6 +1536,37 @@ func messageIDs(messages []Message) []string {
return ids
}
func (s *InboxStore) classifyClaimConflict(ctx context.Context, threadID string) error {
thread, err := selectThread(ctx, s.db, threadID)
if err != nil {
return err
}
now := nowUTC()
var activeLease string
err = s.db.QueryRowContext(
ctx,
`SELECT agent_id
FROM leases
WHERE thread_id = ?
AND released_at IS NULL
AND expires_at > ?`,
threadID,
formatTime(now),
).Scan(&activeLease)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("check active lease after busy claim: %w", err)
}
if activeLease != "" {
return ErrLeaseConflict
}
if thread.Status != "pending" {
return fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, threadID)
}
return nil
}
func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) {
var (
activeAgent string
@@ -1703,6 +1802,13 @@ func isDeadlineExceeded(ctx context.Context) bool {
return ctx.Err() != nil && errors.Is(ctx.Err(), context.DeadlineExceeded)
}
func isSQLiteBusyError(err error) bool {
message := strings.ToLower(err.Error())
return strings.Contains(message, "sqlite_busy") ||
strings.Contains(message, "database is locked") ||
strings.Contains(message, "database table is locked")
}
func defaultID(value, prefix string) string {
if value != "" {
return value
@@ -1728,6 +1834,20 @@ func normalizeJSON(value string) string {
return value
}
func validateAndNormalizeJSON(fieldName, value string) (string, error) {
normalized := normalizeJSON(value)
if !json.Valid([]byte(normalized)) {
return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
}
var compact bytes.Buffer
if err := json.Compact(&compact, []byte(normalized)); err != nil {
return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
}
return compact.String(), nil
}
func placeholders(n int) string {
if n <= 0 {
return ""
+107
View File
@@ -0,0 +1,107 @@
package store
import (
"context"
"errors"
"path/filepath"
"testing"
"time"
dbpkg "ai-workflow-skill/internal/db"
)
func TestClaimThreadReturnsLeaseConflictAfterBusyWrite(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
dbPath := filepath.Join(t.TempDir(), "coord.db")
sqlDB, err := dbpkg.Open(ctx, dbPath)
if err != nil {
t.Fatalf("open base db: %v", err)
}
defer sqlDB.Close()
if err := dbpkg.ApplyMigrations(ctx, sqlDB); err != nil {
t.Fatalf("apply migrations: %v", err)
}
baseStore := NewInboxStore(sqlDB)
thread, _, err := baseStore.Send(ctx, SendInput{
FromAgent: "leader",
ToAgent: "worker-a",
Subject: "race claim",
Summary: "race claim",
})
if err != nil {
t.Fatalf("seed thread: %v", err)
}
lockerDB, err := dbpkg.Open(ctx, dbPath)
if err != nil {
t.Fatalf("open locker db: %v", err)
}
defer lockerDB.Close()
lockTx, err := lockerDB.BeginTx(ctx, nil)
if err != nil {
t.Fatalf("begin locker tx: %v", err)
}
now := nowUTC()
if _, err := lockTx.ExecContext(
ctx,
`INSERT INTO leases (
thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
) VALUES (?, ?, ?, ?, ?, NULL)`,
thread.ThreadID,
"worker-a",
"lease_locked",
formatTime(now),
formatTime(now.Add(5*time.Minute)),
); err != nil {
t.Fatalf("seed active lease in tx: %v", err)
}
if _, err := lockTx.ExecContext(
ctx,
`UPDATE threads
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
WHERE thread_id = ?`,
"claimed",
"worker-a",
"msg_locked",
formatTime(now),
thread.ThreadID,
); err != nil {
t.Fatalf("seed claimed thread in tx: %v", err)
}
commitDone := make(chan error, 1)
go func() {
time.Sleep(100 * time.Millisecond)
commitDone <- lockTx.Commit()
}()
claimDB, err := dbpkg.Open(ctx, dbPath)
if err != nil {
t.Fatalf("open claim db: %v", err)
}
defer claimDB.Close()
claimStore := NewInboxStore(claimDB)
_, err = claimStore.ClaimThread(ctx, ClaimInput{
ThreadID: thread.ThreadID,
Agent: "worker-b",
LeaseSeconds: 300,
})
if !errors.Is(err, ErrLeaseConflict) {
t.Fatalf("expected lease conflict after busy retry, got %v", err)
}
if err := <-commitDone; err != nil {
t.Fatalf("commit locker tx: %v", err)
}
}