Harden inbox JSON validation and claim conflicts
This commit is contained in:
+133
-13
@@ -1,6 +1,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
@@ -226,7 +227,10 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread,
|
||||
kind := defaultString(input.Kind, "task")
|
||||
priority := defaultString(input.Priority, "normal")
|
||||
summary := defaultString(input.Summary, input.Subject)
|
||||
payload := normalizeJSON(input.PayloadJSON)
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
messageID := newID("msg")
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
@@ -314,7 +318,10 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread,
|
||||
func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, input SendInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
messageID := newID("msg")
|
||||
payload := normalizeJSON(input.PayloadJSON)
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -480,6 +487,38 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
|
||||
input.LeaseSeconds = 900
|
||||
}
|
||||
|
||||
var lastBusyErr error
|
||||
for attempt := 0; attempt < 20; attempt++ {
|
||||
result, err := s.claimThreadOnce(ctx, input)
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
if !isSQLiteBusyError(err) {
|
||||
return ClaimResult{}, err
|
||||
}
|
||||
lastBusyErr = err
|
||||
|
||||
ok, waitErr := waitForNextPoll(ctx, 25*time.Millisecond)
|
||||
if waitErr != nil {
|
||||
return ClaimResult{}, waitErr
|
||||
}
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if resolvedErr := s.classifyClaimConflict(ctx, input.ThreadID); resolvedErr != nil {
|
||||
return ClaimResult{}, resolvedErr
|
||||
}
|
||||
|
||||
return ClaimResult{}, fmt.Errorf("claim thread: %w", lastBusyErr)
|
||||
}
|
||||
|
||||
func (s *InboxStore) claimThreadOnce(ctx context.Context, input ClaimInput) (ClaimResult, error) {
|
||||
if input.LeaseSeconds <= 0 {
|
||||
input.LeaseSeconds = 900
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
|
||||
leaseToken := newID("lease")
|
||||
@@ -519,7 +558,7 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
|
||||
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
result, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO leases (
|
||||
thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
|
||||
@@ -529,29 +568,41 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
|
||||
lease_token = excluded.lease_token,
|
||||
claimed_at = excluded.claimed_at,
|
||||
expires_at = excluded.expires_at,
|
||||
released_at = NULL`,
|
||||
released_at = NULL
|
||||
WHERE leases.released_at IS NOT NULL
|
||||
OR leases.expires_at <= excluded.claimed_at`,
|
||||
input.ThreadID,
|
||||
input.Agent,
|
||||
leaseToken,
|
||||
formatTime(now),
|
||||
formatTime(expiresAt),
|
||||
); err != nil {
|
||||
)
|
||||
if err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("upsert lease: %w", err)
|
||||
}
|
||||
if affected, err := result.RowsAffected(); err == nil && affected == 0 {
|
||||
return ClaimResult{}, ErrLeaseConflict
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
result, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE threads
|
||||
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
|
||||
WHERE thread_id = ?`,
|
||||
WHERE thread_id = ?
|
||||
AND status = ?`,
|
||||
"claimed",
|
||||
input.Agent,
|
||||
messageID,
|
||||
formatTime(now),
|
||||
input.ThreadID,
|
||||
); err != nil {
|
||||
"pending",
|
||||
)
|
||||
if err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("update thread claim status: %w", err)
|
||||
}
|
||||
if affected, err := result.RowsAffected(); err == nil && affected == 0 {
|
||||
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
message := Message{
|
||||
MessageID: messageID,
|
||||
@@ -702,6 +753,10 @@ func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimRes
|
||||
func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
messageID := newID("msg")
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
if input.Status != "in_progress" && input.Status != "blocked" {
|
||||
return Thread{}, Message{}, fmt.Errorf("%w: unsupported update status %q", ErrInvalidInput, input.Status)
|
||||
@@ -737,7 +792,7 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
|
||||
Kind: kind,
|
||||
Summary: input.Summary,
|
||||
Body: input.Body,
|
||||
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
||||
PayloadJSON: json.RawMessage(payload),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
@@ -781,6 +836,10 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
|
||||
func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
messageID := newID("msg")
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
@@ -804,7 +863,7 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
|
||||
Kind: defaultString(input.Kind, "answer"),
|
||||
Summary: input.Summary,
|
||||
Body: input.Body,
|
||||
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
||||
PayloadJSON: json.RawMessage(payload),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
@@ -847,6 +906,10 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
|
||||
func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
messageID := newID("msg")
|
||||
payload, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
nextStatus := "done"
|
||||
eventType := "thread_done"
|
||||
@@ -881,7 +944,7 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T
|
||||
Kind: "result",
|
||||
Summary: summary,
|
||||
Body: input.Body,
|
||||
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
||||
PayloadJSON: json.RawMessage(payload),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
@@ -1365,16 +1428,21 @@ func insertArtifacts(ctx context.Context, tx *sql.Tx, messageID string, inputs [
|
||||
|
||||
artifacts := make([]Artifact, 0, len(inputs))
|
||||
for _, input := range inputs {
|
||||
metadataJSON, err := validateAndNormalizeJSON("artifact-metadata-json", input.MetadataJSON)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
artifact := Artifact{
|
||||
ArtifactID: newID("art"),
|
||||
MessageID: messageID,
|
||||
Path: input.Path,
|
||||
Kind: defaultString(input.Kind, "file"),
|
||||
MetadataJSON: json.RawMessage(normalizeJSON(input.MetadataJSON)),
|
||||
MetadataJSON: json.RawMessage(metadataJSON),
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
_, err := tx.ExecContext(
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO artifacts (
|
||||
artifact_id, message_id, path, kind, metadata_json, created_at
|
||||
@@ -1468,6 +1536,37 @@ func messageIDs(messages []Message) []string {
|
||||
return ids
|
||||
}
|
||||
|
||||
func (s *InboxStore) classifyClaimConflict(ctx context.Context, threadID string) error {
|
||||
thread, err := selectThread(ctx, s.db, threadID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
var activeLease string
|
||||
err = s.db.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT agent_id
|
||||
FROM leases
|
||||
WHERE thread_id = ?
|
||||
AND released_at IS NULL
|
||||
AND expires_at > ?`,
|
||||
threadID,
|
||||
formatTime(now),
|
||||
).Scan(&activeLease)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return fmt.Errorf("check active lease after busy claim: %w", err)
|
||||
}
|
||||
if activeLease != "" {
|
||||
return ErrLeaseConflict
|
||||
}
|
||||
if thread.Status != "pending" {
|
||||
return fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, threadID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) {
|
||||
var (
|
||||
activeAgent string
|
||||
@@ -1703,6 +1802,13 @@ func isDeadlineExceeded(ctx context.Context) bool {
|
||||
return ctx.Err() != nil && errors.Is(ctx.Err(), context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func isSQLiteBusyError(err error) bool {
|
||||
message := strings.ToLower(err.Error())
|
||||
return strings.Contains(message, "sqlite_busy") ||
|
||||
strings.Contains(message, "database is locked") ||
|
||||
strings.Contains(message, "database table is locked")
|
||||
}
|
||||
|
||||
func defaultID(value, prefix string) string {
|
||||
if value != "" {
|
||||
return value
|
||||
@@ -1728,6 +1834,20 @@ func normalizeJSON(value string) string {
|
||||
return value
|
||||
}
|
||||
|
||||
func validateAndNormalizeJSON(fieldName, value string) (string, error) {
|
||||
normalized := normalizeJSON(value)
|
||||
if !json.Valid([]byte(normalized)) {
|
||||
return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
|
||||
}
|
||||
|
||||
var compact bytes.Buffer
|
||||
if err := json.Compact(&compact, []byte(normalized)); err != nil {
|
||||
return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
|
||||
}
|
||||
|
||||
return compact.String(), nil
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user