package store

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
)

// Sentinel errors returned by InboxStore. Callers should compare with
// errors.Is, since several call sites wrap them with extra context.
var (
	// ErrLeaseConflict means another agent holds the active lease on a thread.
	ErrLeaseConflict = errors.New("thread already claimed by another worker")
	// ErrThreadNotFound means no row exists for the requested thread ID.
	ErrThreadNotFound = errors.New("thread not found")
	// ErrMessageNotFound means no event references the requested message in the thread.
	ErrMessageNotFound = errors.New("message not found")
	// ErrNoActiveLease means the thread has no lease, or its lease was released or has expired.
	ErrNoActiveLease = errors.New("no active lease")
	// ErrInvalidInput means the caller supplied arguments the operation cannot accept.
	ErrInvalidInput = errors.New("invalid input")
	// ErrInvalidState means the thread's current status does not permit the operation.
	ErrInvalidState = errors.New("invalid state")
)

// InboxStore persists threads, messages, artifacts, leases and events in a
// SQL database and exposes the operations agents use to exchange work.
type InboxStore struct {
	db *sql.DB
}

// Thread is one row of the threads table. Status values written by this
// file are "pending", "claimed", "in_progress", "blocked", "done",
// "failed" and "cancelled".
type Thread struct {
	ThreadID        string    `json:"thread_id"`
	RunID           string    `json:"run_id"`
	TaskID          string    `json:"task_id"`
	Subject         string    `json:"subject"`
	CreatedBy       string    `json:"created_by"`
	AssignedTo      string    `json:"assigned_to"`
	Status          string    `json:"status"`
	Priority        string    `json:"priority"`
	LatestMessageID string    `json:"latest_message_id,omitempty"`
	CreatedAt       time.Time `json:"created_at"`
	UpdatedAt       time.Time `json:"updated_at"`
}

// Message is one row of the messages table, optionally with the artifacts
// attached to it.
type Message struct {
	MessageID   string          `json:"message_id"`
	ThreadID    string          `json:"thread_id"`
	FromAgent   string          `json:"from_agent"`
	ToAgent     string          `json:"to_agent"`
	Kind        string          `json:"kind"`
	Summary     string          `json:"summary"`
	Body        string          `json:"body"`
	PayloadJSON json.RawMessage `json:"payload_json"`
	CreatedAt   time.Time       `json:"created_at"`
	Artifacts   []Artifact      `json:"artifacts,omitempty"`
}

// Artifact is one row of the artifacts table; it belongs to a single message.
type Artifact struct {
	ArtifactID   string          `json:"artifact_id"`
	MessageID    string          `json:"message_id"`
	Path         string          `json:"path"`
	Kind         string          `json:"kind"`
	MetadataJSON json.RawMessage `json:"metadata_json"`
	CreatedAt    time.Time       `json:"created_at"`
}

// ArtifactInput describes an artifact to attach to a new message. Kind
// defaults to "file"; MetadataJSON is passed through normalizeJSON.
type ArtifactInput struct {
	Path         string
	Kind         string
	MetadataJSON string
}

// ThreadDetail is a thread together with all of its messages.
type ThreadDetail struct {
	Thread   Thread    `json:"thread"`
	Messages []Message `json:"messages"`
}

// Event is one row of the events table. EventID is not supplied by
// insertEvent, so it is presumably assigned by the database; watchers use
// it as a monotonically increasing cursor.
type Event struct {
	EventID     int64           `json:"event_id"`
	RunID       string          `json:"run_id"`
	TaskID      string          `json:"task_id"`
	ThreadID    string          `json:"thread_id,omitempty"`
	Source      string          `json:"source"`
	EventType   string          `json:"event_type"`
	MessageID   string          `json:"message_id,omitempty"`
	Summary     string          `json:"summary"`
	PayloadJSON json.RawMessage `json:"payload_json"`
	CreatedAt   time.Time       `json:"created_at"`
}

// SendInput is the argument to Send. Empty IDs are generated via
// defaultID; Kind defaults to "task" and Priority to "normal" for new threads.
type SendInput struct {
	ThreadID    string
	RunID       string
	TaskID      string
	Subject     string
	FromAgent   string
	ToAgent     string
	Kind        string
	Summary     string
	Body        string
	PayloadJSON string
	Priority    string
	Artifacts   []ArtifactInput
}

// FetchInput is the argument to FetchThreads.
type FetchInput struct {
	Agent    string
	Statuses []string
	Limit    int
	Unread   bool
}

// ClaimInput is the argument to ClaimThread. LeaseSeconds <= 0 means 900.
type ClaimInput struct {
	ThreadID     string
	Agent        string
	LeaseSeconds int
}

// RenewInput is the argument to RenewLease. LeaseSeconds <= 0 means 900.
type RenewInput struct {
	ThreadID     string
	Agent        string
	LeaseSeconds int
}

// ClaimResult is the updated thread plus the event message recorded for a
// claim or renewal.
type ClaimResult struct {
	Thread  Thread  `json:"thread"`
	Message Message `json:"message"`
}

// UpdateInput is the argument to UpdateThreadStatus. Status must be
// "in_progress" or "blocked".
type UpdateInput struct {
	ThreadID    string
	Agent       string
	Status      string
	Summary     string
	Body        string
	PayloadJSON string
	Artifacts   []ArtifactInput
}

// ReplyInput is the argument to ReplyToThread. Kind defaults to "answer".
type ReplyInput struct {
	ThreadID    string
	FromAgent   string
	ToAgent     string
	Kind        string
	Summary     string
	Body        string
	PayloadJSON string
	Artifacts   []ArtifactInput
}

// CompleteInput is the argument to CompleteThread. Failed selects the
// "failed" terminal status instead of "done".
type CompleteInput struct {
	ThreadID    string
	Agent       string
	Summary     string
	Body        string
	PayloadJSON string
	Failed      bool
	Artifacts   []ArtifactInput
}

// CancelInput is the argument to CancelThread.
type CancelInput struct {
	ThreadID  string
	Agent     string
	Reason    string
	Artifacts []ArtifactInput
}

// ListInput filters ListThreads. AssignedTo falls back to Agent; Limit <= 0
// means 20; Unread requires Agent.
type ListInput struct {
	Agent      string
	Statuses   []string
	CreatedBy  string
	AssignedTo string
	Limit      int
	Unread     bool
}

// WatchInput is the argument to WatchThreads. A Timeout <= 0 means the
// watch is bounded only by the caller's context.
type WatchInput struct {
	Agent        string
	Statuses     []string
	AfterEventID int64
	StartFromNow bool
	Timeout      time.Duration
}

// WatchResult reports whether WatchThreads observed an event. NextEventID
// is the cursor to pass as AfterEventID on the next call.
type WatchResult struct {
	Woke        bool     `json:"woke"`
	NextEventID int64    `json:"next_event_id"`
	Thread      *Thread  `json:"thread,omitempty"`
	Message     *Message `json:"message,omitempty"`
	Event       *Event   `json:"event,omitempty"`
}

// WaitReplyInput is the argument to WaitReply. Kinds defaults to
// "answer", "control" and "result".
type WaitReplyInput struct {
	ThreadID       string
	AfterMessageID string
	AfterEventID   int64
	Kinds          []string
	Timeout        time.Duration
}

// WaitReplyResult reports whether WaitReply observed a matching message.
// NextEventID is the cursor to pass as AfterEventID on the next call.
type WaitReplyResult struct {
	Woke        bool     `json:"woke"`
	NextEventID int64    `json:"next_event_id"`
	Message     *Message `json:"message,omitempty"`
}

// NewInboxStore returns an InboxStore backed by db.
func NewInboxStore(db *sql.DB) *InboxStore {
	return &InboxStore{db: db}
}

// Send posts a message. When input.ThreadID names an existing thread the
// message is appended to it; when ThreadID is empty or unknown a new thread
// is created (createThread passes ThreadID to defaultID, which presumably
// keeps a non-empty value). Lookup errors other than ErrThreadNotFound are
// returned as-is.
func (s *InboxStore) Send(ctx context.Context, input SendInput) (Thread, Message, error) {
	if input.ThreadID != "" {
		thread, err := selectThread(ctx, s.db, input.ThreadID)
		if err == nil {
			return s.appendThreadMessage(ctx, thread, input)
		}
		if !errors.Is(err, ErrThreadNotFound) {
			return Thread{}, Message{}, err
		}
	}
	return s.createThread(ctx, input)
}

// createThread inserts a new "pending" thread, its first message, any
// artifacts and a "thread_created" event in a single transaction.
func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread, Message, error) {
	now := nowUTC()
	threadID := defaultID(input.ThreadID, "thr")
	runID := defaultID(input.RunID, "run")
	taskID := defaultID(input.TaskID, "task")
	kind := defaultString(input.Kind, "task")
	priority := defaultString(input.Priority, "normal")
	summary := defaultString(input.Summary, input.Subject)
	payload := normalizeJSON(input.PayloadJSON)
	messageID := newID("msg")

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin send transaction: %w", err)
	}
	// Rollback is a no-op once Commit has succeeded.
	defer tx.Rollback()

	thread := Thread{
		ThreadID:        threadID,
		RunID:           runID,
		TaskID:          taskID,
		Subject:         input.Subject,
		CreatedBy:       input.FromAgent,
		AssignedTo:      input.ToAgent,
		Status:          "pending",
		Priority:        priority,
		LatestMessageID: messageID,
		CreatedAt:       now,
		UpdatedAt:       now,
	}
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO threads ( thread_id, run_id, task_id, subject, created_by, assigned_to, status, priority, latest_message_id, created_at, updated_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		thread.ThreadID,
		thread.RunID,
		thread.TaskID,
		thread.Subject,
		thread.CreatedBy,
		thread.AssignedTo,
		thread.Status,
		thread.Priority,
		thread.LatestMessageID,
		formatTime(thread.CreatedAt),
		formatTime(thread.UpdatedAt),
	); err != nil {
		return Thread{}, Message{}, fmt.Errorf("insert thread: %w", err)
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    threadID,
		FromAgent:   input.FromAgent,
		ToAgent:     input.ToAgent,
		Kind:        kind,
		Summary:     summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(payload),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
	if err != nil {
		return Thread{}, Message{}, err
	}
	message.Artifacts = artifacts

	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_created",
		MessageID:   message.MessageID,
		Summary:     summary,
		PayloadJSON: payload,
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}
	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit send transaction: %w", err)
	}
	return thread, message, nil
}

// appendThreadMessage adds a message to an existing, non-terminal thread.
// A non-empty input.ToAgent reassigns the thread; the status is unchanged.
// The thread is re-read inside the transaction because Send looked it up
// outside of one.
func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, input SendInput) (Thread, Message, error) {
	now := nowUTC()
	messageID := newID("msg")
	payload := normalizeJSON(input.PayloadJSON)

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin append transaction: %w", err)
	}
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, existing.ThreadID)
	if err != nil {
		return Thread{}, Message{}, err
	}
	if isTerminalStatus(thread.Status) {
		return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, thread.ThreadID)
	}

	assignedTo := thread.AssignedTo
	if input.ToAgent != "" {
		assignedTo = input.ToAgent
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.FromAgent,
		ToAgent:     defaultString(input.ToAgent, thread.AssignedTo),
		Kind:        defaultString(input.Kind, "task"),
		Summary:     defaultString(input.Summary, thread.Subject),
		Body:        input.Body,
		PayloadJSON: json.RawMessage(payload),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
	if err != nil {
		return Thread{}, Message{}, err
	}
	message.Artifacts = artifacts

	if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, assignedTo, message.MessageID, now); err != nil {
		return Thread{}, Message{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_message_sent",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: payload,
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}
	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit append transaction: %w", err)
	}

	thread.AssignedTo = assignedTo
	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return thread, message, nil
}

// FetchThreads lists threads assigned to input.Agent. Statuses defaults to
// "pending" only.
func (s *InboxStore) FetchThreads(ctx context.Context, input FetchInput) ([]Thread, error) {
	statuses := input.Statuses
	if len(statuses) == 0 {
		statuses = []string{"pending"}
	}
	return s.ListThreads(ctx, ListInput{
		Agent:    input.Agent,
		Statuses: statuses,
		Limit:    input.Limit,
		Unread:   input.Unread,
	})
}

// ListThreads returns up to input.Limit threads (default 20), most recently
// updated first, filtered by assignee (AssignedTo, else Agent), creator and
// status. Unread keeps only threads whose latest message is addressed to
// Agent and was not sent by Agent. It returns ErrInvalidInput when Unread is
// set without Agent. With no matches the result is a nil slice.
func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread, error) {
	limit := input.Limit
	if limit <= 0 {
		limit = 20
	}

	var (
		args       []any
		conditions []string
		joins      []string
	)

	assignedTo := input.AssignedTo
	if assignedTo == "" {
		assignedTo = input.Agent
	}
	if assignedTo != "" {
		conditions = append(conditions, "t.assigned_to = ?")
		args = append(args, assignedTo)
	}
	if input.CreatedBy != "" {
		conditions = append(conditions, "t.created_by = ?")
		args = append(args, input.CreatedBy)
	}
	if len(input.Statuses) > 0 {
		conditions = append(conditions, "t.status IN ("+placeholders(len(input.Statuses))+")")
		for _, status := range input.Statuses {
			args = append(args, status)
		}
	}
	if input.Unread {
		if input.Agent == "" {
			return nil, fmt.Errorf("%w: agent is required when filtering unread threads", ErrInvalidInput)
		}
		// "Unread" = the newest message is for this agent and did not come from it.
		joins = append(joins, "JOIN messages lm ON lm.message_id = t.latest_message_id")
		conditions = append(conditions, "lm.to_agent = ?")
		args = append(args, input.Agent)
		conditions = append(conditions, "lm.from_agent <> ?")
		args = append(args, input.Agent)
	}

	query := `SELECT t.thread_id, t.run_id, t.task_id, t.subject, t.created_by, t.assigned_to, t.status, t.priority, t.latest_message_id, t.created_at, t.updated_at FROM threads t`
	if len(joins) > 0 {
		query += " " + strings.Join(joins, " ")
	}
	if len(conditions) > 0 {
		query += " WHERE " + strings.Join(conditions, " AND ")
	}
	query += " ORDER BY t.updated_at DESC LIMIT ?"
	args = append(args, limit)

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("list threads: %w", err)
	}
	defer rows.Close()

	var threads []Thread
	for rows.Next() {
		thread, err := scanThread(rows)
		if err != nil {
			return nil, err
		}
		threads = append(threads, thread)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate threads: %w", err)
	}
	return threads, nil
}

// ClaimThread takes a lease on a pending thread for input.Agent (default
// lease 900 seconds), marks the thread "claimed" and assigned to the agent,
// and records a "thread claimed" event message.
//
// Errors: ErrInvalidInput when Agent is empty; ErrThreadNotFound;
// ErrInvalidState when the thread is terminal or not pending;
// ErrLeaseConflict when an unreleased, unexpired lease exists.
func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimResult, error) {
	// A lease without an owner cannot be renewed, updated or completed,
	// because requireActiveLease matches on the agent ID.
	if input.Agent == "" {
		return ClaimResult{}, fmt.Errorf("%w: agent is required to claim a thread", ErrInvalidInput)
	}
	if input.LeaseSeconds <= 0 {
		input.LeaseSeconds = 900
	}
	now := nowUTC()
	expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
	leaseToken := newID("lease")
	messageID := newID("msg")
	payload, err := encodeLeasePayload(input.LeaseSeconds, leaseToken)
	if err != nil {
		return ClaimResult{}, err
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return ClaimResult{}, fmt.Errorf("begin claim transaction: %w", err)
	}
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return ClaimResult{}, err
	}
	if isTerminalStatus(thread.Status) {
		return ClaimResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
	}

	// NOTE(review): expires_at is compared as stored text here, whereas
	// requireActiveLease parses it; this assumes formatTime output sorts
	// lexically in time order — confirm.
	var activeLease string
	err = tx.QueryRowContext(
		ctx,
		`SELECT agent_id FROM leases WHERE thread_id = ?
AND released_at IS NULL AND expires_at > ?`,
		input.ThreadID,
		formatTime(now),
	).Scan(&activeLease)
	switch {
	case err == nil:
		// Any live lease row blocks the claim, whatever its agent_id value.
		return ClaimResult{}, ErrLeaseConflict
	case !errors.Is(err, sql.ErrNoRows):
		return ClaimResult{}, fmt.Errorf("check active lease: %w", err)
	}
	if thread.Status != "pending" {
		return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
	}

	// One lease row per thread: a stale or released lease is overwritten.
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO leases ( thread_id, agent_id, lease_token, claimed_at, expires_at, released_at ) VALUES (?, ?, ?, ?, ?, NULL) ON CONFLICT(thread_id) DO UPDATE SET agent_id = excluded.agent_id, lease_token = excluded.lease_token, claimed_at = excluded.claimed_at, expires_at = excluded.expires_at, released_at = NULL`,
		input.ThreadID,
		input.Agent,
		leaseToken,
		formatTime(now),
		formatTime(expiresAt),
	); err != nil {
		return ClaimResult{}, fmt.Errorf("upsert lease: %w", err)
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    input.ThreadID,
		FromAgent:   input.Agent,
		ToAgent:     input.Agent,
		Kind:        "event",
		Summary:     "thread claimed",
		Body:        "",
		PayloadJSON: json.RawMessage(payload),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return ClaimResult{}, err
	}
	if err := updateThreadState(ctx, tx, input.ThreadID, "claimed", input.Agent, message.MessageID, now); err != nil {
		return ClaimResult{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_claimed",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: payload,
		CreatedAt:   now,
	}); err != nil {
		return ClaimResult{}, err
	}
	if err := tx.Commit(); err != nil {
		return ClaimResult{}, fmt.Errorf("commit claim transaction: %w", err)
	}

	thread.Status = "claimed"
	thread.AssignedTo = input.Agent
	thread.LatestMessageID = messageID
	thread.UpdatedAt = now
	return ClaimResult{
		Thread:  thread,
		Message: message,
	}, nil
}

// RenewLease extends the caller's active lease on a non-terminal thread
// (default 900 seconds from now), issues a fresh lease token, and records a
// "lease renewed" event message. The thread's status and assignee are
// unchanged. It returns ErrNoActiveLease or ErrLeaseConflict when the caller
// does not hold a live lease.
func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimResult, error) {
	if input.LeaseSeconds <= 0 {
		input.LeaseSeconds = 900
	}
	now := nowUTC()
	expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
	leaseToken := newID("lease")
	messageID := newID("msg")
	payload, err := encodeLeasePayload(input.LeaseSeconds, leaseToken)
	if err != nil {
		return ClaimResult{}, err
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return ClaimResult{}, fmt.Errorf("begin renew transaction: %w", err)
	}
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return ClaimResult{}, err
	}
	if isTerminalStatus(thread.Status) {
		return ClaimResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
	}
	if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
		return ClaimResult{}, err
	}

	if _, err := tx.ExecContext(
		ctx,
		`UPDATE leases SET lease_token = ?, expires_at = ?, released_at = NULL WHERE thread_id = ?`,
		leaseToken,
		formatTime(expiresAt),
		input.ThreadID,
	); err != nil {
		return ClaimResult{}, fmt.Errorf("renew lease: %w", err)
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    input.ThreadID,
		FromAgent:   input.Agent,
		ToAgent:     input.Agent,
		Kind:        "event",
		Summary:     "lease renewed",
		Body:        "",
		PayloadJSON: json.RawMessage(payload),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return ClaimResult{}, err
	}
	if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
		return ClaimResult{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_renewed",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: payload,
		CreatedAt:   now,
	}); err != nil {
		return ClaimResult{}, err
	}
	if err := tx.Commit(); err != nil {
		return ClaimResult{}, fmt.Errorf("commit renew transaction: %w", err)
	}

	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return ClaimResult{
		Thread:  thread,
		Message: message,
	}, nil
}

// encodeLeasePayload renders the JSON payload recorded on claim and renew
// event messages. json.Marshal escapes the token, which the previous
// fmt.Sprintf-based construction did not.
func encodeLeasePayload(leaseSeconds int, leaseToken string) (string, error) {
	payload, err := json.Marshal(struct {
		LeaseSeconds int    `json:"lease_seconds"`
		LeaseToken   string `json:"lease_token"`
	}{LeaseSeconds: leaseSeconds, LeaseToken: leaseToken})
	if err != nil {
		return "", fmt.Errorf("encode lease payload: %w", err)
	}
	return string(payload), nil
}

// UpdateThreadStatus moves a leased thread to "in_progress" or "blocked".
// It records a message to the thread's creator: kind "question" when
// blocked, otherwise "progress". The caller must hold the active lease.
// It also emits a "thread_<status>" event.
func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) {
	now := nowUTC()
	messageID := newID("msg")
	if input.Status != "in_progress" && input.Status != "blocked" {
		return Thread{}, Message{}, fmt.Errorf("%w: unsupported update status %q", ErrInvalidInput, input.Status)
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin update transaction: %w", err)
	}
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return Thread{}, Message{}, err
	}
	if isTerminalStatus(thread.Status) {
		return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
	}
	if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
		return Thread{}, Message{}, err
	}

	kind := "progress"
	if input.Status == "blocked" {
		kind = "question"
	}
	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.Agent,
		ToAgent:     thread.CreatedBy,
		Kind:        kind,
		Summary:     input.Summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
	if err != nil {
		return Thread{}, Message{}, err
	}
	message.Artifacts = artifacts

	if err := updateThreadState(ctx, tx, thread.ThreadID, input.Status, thread.AssignedTo, message.MessageID, now); err != nil {
		return Thread{}, Message{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_" + input.Status,
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}
	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit update transaction: %w", err)
	}

	thread.Status = input.Status
	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return thread, message, nil
}

// ReplyToThread appends a reply (kind defaults to "answer") to a
// non-terminal thread and emits a "thread_replied" event. No lease is
// required, and the thread's status and assignee are left unchanged.
func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Thread, Message, error) {
	now := nowUTC()
	messageID := newID("msg")

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin reply transaction: %w", err)
	}
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return Thread{}, Message{}, err
	}
	if isTerminalStatus(thread.Status) {
		return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.FromAgent,
		ToAgent:     input.ToAgent,
		Kind:        defaultString(input.Kind, "answer"),
		Summary:     input.Summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
	if err != nil {
		return Thread{}, Message{}, err
	}
	message.Artifacts = artifacts

	if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
		return Thread{}, Message{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_replied",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}
	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit reply transaction: %w", err)
	}

	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return thread, message, nil
}

// CompleteThread finishes a leased thread as "done" (or "failed" when
// input.Failed). It records a "result" message to the thread's creator,
// releases the lease, and emits "thread_done" or "thread_failed". The
// caller must hold the active lease.
func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (Thread, Message, error) {
	now := nowUTC()
	messageID := newID("msg")
	nextStatus := "done"
	eventType := "thread_done"
	summary := input.Summary
	if input.Failed {
		nextStatus = "failed"
		eventType = "thread_failed"
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin complete transaction: %w", err)
	}
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return Thread{}, Message{}, err
	}
	if isTerminalStatus(thread.Status) {
		return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
	}
	if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
		return Thread{}, Message{}, err
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.Agent,
		ToAgent:     thread.CreatedBy,
		Kind:        "result",
		Summary:     summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
	if err != nil {
		return Thread{}, Message{}, err
	}
	message.Artifacts = artifacts

	if err := updateThreadState(ctx, tx, thread.ThreadID, nextStatus, thread.AssignedTo, message.MessageID, now); err != nil {
		return Thread{}, Message{}, err
	}
	if _, err := tx.ExecContext(
		ctx,
		`UPDATE leases SET released_at = ? WHERE thread_id = ?
AND released_at IS NULL`,
		formatTime(now),
		thread.ThreadID,
	); err != nil {
		return Thread{}, Message{}, fmt.Errorf("release lease: %w", err)
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   eventType,
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}
	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit complete transaction: %w", err)
	}

	thread.Status = nextStatus
	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return thread, message, nil
}

// CancelThread marks a non-terminal thread "cancelled". It records a
// "control" message to the current assignee (summary defaults to "thread
// cancelled"), releases any open lease, and emits "thread_cancelled". No
// lease or ownership check is performed on input.Agent.
func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Thread, Message, error) {
	now := nowUTC()
	messageID := newID("msg")

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin cancel transaction: %w", err)
	}
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return Thread{}, Message{}, err
	}
	if isTerminalStatus(thread.Status) {
		return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
	}

	summary := defaultString(input.Reason, "thread cancelled")
	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.Agent,
		ToAgent:     thread.AssignedTo,
		Kind:        "control",
		Summary:     summary,
		Body:        input.Reason,
		PayloadJSON: json.RawMessage(`{}`),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
	if err != nil {
		return Thread{}, Message{}, err
	}
	message.Artifacts = artifacts

	if err := updateThreadState(ctx, tx, thread.ThreadID, "cancelled", thread.AssignedTo, message.MessageID, now); err != nil {
		return Thread{}, Message{}, err
	}
	if _, err := tx.ExecContext(
		ctx,
		`UPDATE
leases SET released_at = ? WHERE thread_id = ? AND released_at IS NULL`,
		formatTime(now),
		thread.ThreadID,
	); err != nil {
		return Thread{}, Message{}, fmt.Errorf("release lease on cancel: %w", err)
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_cancelled",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}
	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit cancel transaction: %w", err)
	}

	thread.Status = "cancelled"
	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return thread, message, nil
}

// GetThread returns a thread with all of its messages, oldest first (by
// created_at), each with its artifacts attached. The thread and its
// messages are read in separate queries, not in one transaction.
func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) {
	thread, err := selectThread(ctx, s.db, threadID)
	if err != nil {
		return ThreadDetail{}, err
	}

	rows, err := s.db.QueryContext(
		ctx,
		`SELECT message_id, thread_id, from_agent, to_agent, kind, summary, body, payload_json, created_at FROM messages WHERE thread_id = ?
ORDER BY created_at ASC`,
		threadID,
	)
	if err != nil {
		return ThreadDetail{}, fmt.Errorf("query thread messages: %w", err)
	}
	defer rows.Close()

	var messages []Message
	for rows.Next() {
		message, err := scanMessage(rows)
		if err != nil {
			return ThreadDetail{}, err
		}
		messages = append(messages, message)
	}
	if err := rows.Err(); err != nil {
		return ThreadDetail{}, fmt.Errorf("iterate thread messages: %w", err)
	}

	artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, messageIDs(messages))
	if err != nil {
		return ThreadDetail{}, err
	}
	attachArtifacts(messages, artifactsByMessageID)

	return ThreadDetail{
		Thread:   thread,
		Messages: messages,
	}, nil
}

// WatchThreads polls every 200ms (via findWatchEventAfter) for the first
// matching event after the cursor. The cursor is AfterEventID, or the
// current maximum event ID when StartFromNow is set and AfterEventID is 0.
// It returns Woke=true with the event, its thread and its message. On
// timeout it returns Woke=false and the unchanged cursor with a nil error.
func (s *InboxStore) WatchThreads(ctx context.Context, input WatchInput) (WatchResult, error) {
	cursor := input.AfterEventID
	if input.StartFromNow && cursor == 0 {
		current, err := s.currentMaxEventID(ctx)
		if err != nil {
			return WatchResult{}, err
		}
		cursor = current
	}

	waitCtx := ctx
	cancel := func() {}
	if input.Timeout > 0 {
		waitCtx, cancel = context.WithTimeout(ctx, input.Timeout)
	}
	defer cancel()

	for {
		thread, message, event, found, err := s.findWatchEventAfter(waitCtx, input, cursor)
		if err != nil {
			// A query cut short by the watch deadline is a timeout, not a failure.
			if isDeadlineExceeded(waitCtx) {
				return WatchResult{Woke: false, NextEventID: cursor}, nil
			}
			return WatchResult{}, err
		}
		if found {
			return WatchResult{
				Woke:        true,
				NextEventID: event.EventID,
				Thread:      &thread,
				Message:     &message,
				Event:       &event,
			}, nil
		}
		// NOTE(review): waitForNextPoll is defined elsewhere; presumably it
		// sleeps for the interval and reports ok=false or an error when
		// waitCtx ends — confirm.
		ok, err := waitForNextPoll(waitCtx, 200*time.Millisecond)
		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) {
				return WatchResult{Woke: false, NextEventID: cursor}, nil
			}
			return WatchResult{}, err
		}
		if !ok {
			return WatchResult{Woke: false, NextEventID: cursor}, nil
		}
	}
}

// WaitReply polls every 200ms for the first message on input.ThreadID
// whose kind is in Kinds (default "answer", "control", "result") and whose
// event comes after the cursor. The cursor is the later of AfterEventID and
// the event referencing AfterMessageID. An unknown AfterMessageID yields
// ErrMessageNotFound. On timeout it returns Woke=false and the unchanged
// cursor with a nil error.
func (s *InboxStore) WaitReply(ctx context.Context, input WaitReplyInput) (WaitReplyResult, error) {
	cursor := input.AfterEventID
	if input.AfterMessageID != "" {
		eventID, err := s.lookupEventIDForMessage(ctx, input.ThreadID, input.AfterMessageID)
		if err != nil {
			return WaitReplyResult{}, err
		}
		if eventID > cursor {
			cursor = eventID
		}
	}

	kinds := input.Kinds
	if len(kinds) == 0 {
		kinds = []string{"answer", "control", "result"}
	}

	waitCtx := ctx
	cancel := func() {}
	if input.Timeout > 0 {
		waitCtx, cancel = context.WithTimeout(ctx, input.Timeout)
	}
	defer cancel()

	for {
		message, eventID, found, err := s.findReplyAfter(waitCtx, input.ThreadID, cursor, kinds)
		if err != nil {
			if isDeadlineExceeded(waitCtx) {
				return WaitReplyResult{Woke: false, NextEventID: cursor}, nil
			}
			return WaitReplyResult{}, err
		}
		if found {
			return WaitReplyResult{
				Woke:        true,
				NextEventID: eventID,
				Message:     &message,
			}, nil
		}
		ok, err := waitForNextPoll(waitCtx, 200*time.Millisecond)
		if err != nil {
			if errors.Is(err, context.DeadlineExceeded) {
				return WaitReplyResult{Woke: false, NextEventID: cursor}, nil
			}
			return WaitReplyResult{}, err
		}
		if !ok {
			return WaitReplyResult{Woke: false, NextEventID: cursor}, nil
		}
	}
}

// threadScanner is the Scan method shared by *sql.Row and *sql.Rows, so
// the scan helpers below work for single-row and multi-row queries.
type threadScanner interface {
	Scan(dest ...any) error
}

// scanThread reads one threads row in selectThread/ListThreads column order.
// A NULL latest_message_id becomes "". Errors are wrapped, so sql.ErrNoRows
// remains detectable with errors.Is.
func scanThread(scanner threadScanner) (Thread, error) {
	var (
		thread               Thread
		createdAt, updatedAt string
		latestMessageID      sql.NullString
	)
	if err := scanner.Scan(
		&thread.ThreadID,
		&thread.RunID,
		&thread.TaskID,
		&thread.Subject,
		&thread.CreatedBy,
		&thread.AssignedTo,
		&thread.Status,
		&thread.Priority,
		&latestMessageID,
		&createdAt,
		&updatedAt,
	); err != nil {
		return Thread{}, fmt.Errorf("scan thread: %w", err)
	}
	thread.CreatedAt = parseTime(createdAt)
	thread.UpdatedAt = parseTime(updatedAt)
	if latestMessageID.Valid {
		thread.LatestMessageID = latestMessageID.String
	}
	return thread, nil
}

// scanMessage reads one messages row (without artifacts).
func scanMessage(scanner threadScanner) (Message, error) {
	var (
		message            Message
		payload, createdAt string
	)
	if err := scanner.Scan(
		&message.MessageID,
		&message.ThreadID,
		&message.FromAgent,
		&message.ToAgent,
		&message.Kind,
		&message.Summary,
		&message.Body,
		&payload,
		&createdAt,
	); err != nil {
		return Message{}, fmt.Errorf("scan message: %w", err)
	}
	message.PayloadJSON = json.RawMessage(payload)
	message.CreatedAt = parseTime(createdAt)
	return message, nil
}

// scanArtifact reads one artifacts row.
func scanArtifact(scanner threadScanner) (Artifact, error) {
	var (
		artifact          Artifact
		metadata, created string
	)
	if err := scanner.Scan(
		&artifact.ArtifactID,
		&artifact.MessageID,
		&artifact.Path,
		&artifact.Kind,
		&metadata,
		&created,
	); err != nil {
		return Artifact{}, fmt.Errorf("scan artifact: %w", err)
	}
	artifact.MetadataJSON = json.RawMessage(metadata)
	artifact.CreatedAt = parseTime(created)
	return artifact, nil
}

// scanEvent reads one events row; a NULL message_id becomes "".
func scanEvent(scanner threadScanner) (Event, error) {
	var (
		event              Event
		messageID          sql.NullString
		payload, createdAt string
	)
	if err := scanner.Scan(
		&event.EventID,
		&event.RunID,
		&event.TaskID,
		&event.ThreadID,
		&event.Source,
		&event.EventType,
		&messageID,
		&event.Summary,
		&payload,
		&createdAt,
	); err != nil {
		return Event{}, fmt.Errorf("scan event: %w", err)
	}
	if messageID.Valid {
		event.MessageID = messageID.String
	}
	event.PayloadJSON = json.RawMessage(payload)
	event.CreatedAt = parseTime(createdAt)
	return event, nil
}

// selectThread loads a thread by ID through a *sql.DB or *sql.Tx. It
// returns an error wrapping ErrThreadNotFound when no row matches.
func selectThread(ctx context.Context, db queryRower, threadID string) (Thread, error) {
	row := db.QueryRowContext(
		ctx,
		`SELECT thread_id, run_id, task_id, subject, created_by, assigned_to, status, priority, latest_message_id, created_at, updated_at FROM threads WHERE thread_id = ?`,
		threadID,
	)
	thread, err := scanThread(row)
	if errors.Is(err, sql.ErrNoRows) {
		return Thread{}, fmt.Errorf("%w: %s", ErrThreadNotFound, threadID)
	}
	return thread, err
}

// selectThreadForUpdate loads a thread inside tx. It currently issues the
// same plain SELECT as selectThread, with no explicit row lock.
// NOTE(review): serialization of concurrent mutations therefore relies on
// the database's transaction/locking behaviour — confirm for the driver in use.
func selectThreadForUpdate(ctx context.Context, tx *sql.Tx, threadID string) (Thread, error) {
	return selectThread(ctx, tx, threadID)
}

// queryRower is satisfied by both *sql.DB and *sql.Tx.
type queryRower interface {
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// eventInput holds the columns written by insertEvent.
type eventInput struct {
	RunID       string
	TaskID      string
	ThreadID    string
	Source      string
	EventType   string
	MessageID   string
	Summary     string
	PayloadJSON string
	CreatedAt   time.Time
}

// insertEvent appends a row to the events table inside tx. The payload is
// passed through normalizeJSON.
func insertEvent(ctx context.Context, tx *sql.Tx, input eventInput) error {
	_, err := tx.ExecContext(
		ctx,
		`INSERT INTO events ( run_id, task_id, thread_id, source, event_type,
message_id, summary, payload_json, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		input.RunID,
		input.TaskID,
		input.ThreadID,
		input.Source,
		input.EventType,
		input.MessageID,
		input.Summary,
		normalizeJSON(input.PayloadJSON),
		formatTime(input.CreatedAt),
	)
	if err != nil {
		return fmt.Errorf("insert event: %w", err)
	}
	return nil
}

// insertMessage writes message (not its artifacts) to the messages table.
func insertMessage(ctx context.Context, tx *sql.Tx, message Message) error {
	_, err := tx.ExecContext(
		ctx,
		`INSERT INTO messages ( message_id, thread_id, from_agent, to_agent, kind, summary, body, payload_json, created_at ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		message.MessageID,
		message.ThreadID,
		message.FromAgent,
		message.ToAgent,
		message.Kind,
		message.Summary,
		message.Body,
		string(message.PayloadJSON),
		formatTime(message.CreatedAt),
	)
	if err != nil {
		return fmt.Errorf("insert message: %w", err)
	}
	return nil
}

// insertArtifacts writes one artifacts row per input, attached to
// messageID, and returns the stored artifacts in input order. Kind defaults
// to "file". It returns nil, nil for no inputs.
func insertArtifacts(ctx context.Context, tx *sql.Tx, messageID string, inputs []ArtifactInput, createdAt time.Time) ([]Artifact, error) {
	if len(inputs) == 0 {
		return nil, nil
	}
	artifacts := make([]Artifact, 0, len(inputs))
	for _, input := range inputs {
		artifact := Artifact{
			ArtifactID:   newID("art"),
			MessageID:    messageID,
			Path:         input.Path,
			Kind:         defaultString(input.Kind, "file"),
			MetadataJSON: json.RawMessage(normalizeJSON(input.MetadataJSON)),
			CreatedAt:    createdAt,
		}
		_, err := tx.ExecContext(
			ctx,
			`INSERT INTO artifacts ( artifact_id, message_id, path, kind, metadata_json, created_at ) VALUES (?, ?, ?, ?, ?, ?)`,
			artifact.ArtifactID,
			artifact.MessageID,
			artifact.Path,
			artifact.Kind,
			string(artifact.MetadataJSON),
			formatTime(artifact.CreatedAt),
		)
		if err != nil {
			return nil, fmt.Errorf("insert artifact: %w", err)
		}
		artifacts = append(artifacts, artifact)
	}
	return artifacts, nil
}

// updateThreadState overwrites a thread's status, assignee, latest message
// and updated_at inside tx.
func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assignedTo, latestMessageID string, updatedAt time.Time) error {
	_, err := tx.ExecContext(
		ctx,
		`UPDATE threads SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
WHERE thread_id = ?`,
		status,
		assignedTo,
		latestMessageID,
		formatTime(updatedAt),
		threadID,
	)
	if err != nil {
		return fmt.Errorf("update thread state: %w", err)
	}
	return nil
}

// loadArtifactsForMessageIDs fetches the artifacts of all given messages in
// one query and groups them by message ID, ordered by created_at within
// each group. It returns an empty (non-nil) map for no IDs.
func loadArtifactsForMessageIDs(ctx context.Context, db *sql.DB, messageIDs []string) (map[string][]Artifact, error) {
	result := make(map[string][]Artifact)
	if len(messageIDs) == 0 {
		return result, nil
	}

	args := make([]any, 0, len(messageIDs))
	for _, messageID := range messageIDs {
		args = append(args, messageID)
	}

	rows, err := db.QueryContext(
		ctx,
		`SELECT artifact_id, message_id, path, kind, metadata_json, created_at FROM artifacts WHERE message_id IN (`+placeholders(len(messageIDs))+`) ORDER BY created_at ASC`,
		args...,
	)
	if err != nil {
		return nil, fmt.Errorf("query artifacts: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		artifact, err := scanArtifact(rows)
		if err != nil {
			return nil, err
		}
		result[artifact.MessageID] = append(result[artifact.MessageID], artifact)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate artifacts: %w", err)
	}
	return result, nil
}

// attachArtifacts sets each message's Artifacts from the grouped map, in place.
func attachArtifacts(messages []Message, artifactsByMessageID map[string][]Artifact) {
	for i := range messages {
		messages[i].Artifacts = artifactsByMessageID[messages[i].MessageID]
	}
}

// messageIDs returns the IDs of messages in order.
func messageIDs(messages []Message) []string {
	ids := make([]string, 0, len(messages))
	for _, message := range messages {
		ids = append(ids, message.MessageID)
	}
	return ids
}

// requireActiveLease checks that agent holds a live lease on threadID at
// time now and returns its lease token. It returns ErrNoActiveLease when
// there is no lease row, or the lease was released, or it expires at or
// before now. It returns ErrLeaseConflict when another agent holds the lease.
func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) {
	var (
		activeAgent string
		leaseToken  string
		expiresAt   string
		releasedAt  sql.NullString
	)
	err := tx.QueryRowContext(
		ctx,
		`SELECT agent_id, lease_token, expires_at, released_at FROM leases WHERE thread_id = ?`,
		threadID,
	).Scan(&activeAgent, &leaseToken, &expiresAt, &releasedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return "", ErrNoActiveLease
	}
	if err != nil {
		return "", fmt.Errorf("read lease: %w", err)
	}
	if releasedAt.Valid || !parseTime(expiresAt).After(now) {
		return "", ErrNoActiveLease
	}
	if activeAgent != agent {
		return "", ErrLeaseConflict
	}
	return leaseToken, nil
}

// lookupEventIDForMessage returns the highest event ID on threadID that
// references messageID, or an error wrapping ErrMessageNotFound.
func (s *InboxStore) lookupEventIDForMessage(ctx context.Context, threadID, messageID string) (int64, error) {
	var eventID int64
	err := s.db.QueryRowContext(
		ctx,
		`SELECT event_id FROM events WHERE thread_id = ? AND message_id = ? ORDER BY event_id DESC LIMIT 1`,
		threadID,
		messageID,
	).Scan(&eventID)
	if errors.Is(err, sql.ErrNoRows) {
		return 0, fmt.Errorf("%w: message %s not found in thread %s", ErrMessageNotFound, messageID, threadID)
	}
	if err != nil {
		return 0, fmt.Errorf("lookup message event: %w", err)
	}
	return eventID, nil
}

// currentMaxEventID returns the largest event ID, or 0 when there are no events.
func (s *InboxStore) currentMaxEventID(ctx context.Context) (int64, error) {
	var maxEventID int64
	if err := s.db.QueryRowContext(ctx, `SELECT COALESCE(MAX(event_id), 0) FROM events`).Scan(&maxEventID); err != nil {
		return 0, fmt.Errorf("query max event id: %w", err)
	}
	return maxEventID, nil
}

// findReplyAfter queries the earliest event (by event_id) on threadID after
// afterEventID whose joined message kind is in kinds. NOTE(review): the rest
// of this function continues beyond this chunk; the meaning of its return
// values (message, event ID, found, error) is inferred from WaitReply.
func (s *InboxStore) findReplyAfter(ctx context.Context, threadID string, afterEventID int64, kinds []string) (Message, int64, bool, error) {
	args := []any{threadID, afterEventID}
	query := `SELECT e.event_id, m.message_id, m.thread_id, m.from_agent, m.to_agent, m.kind, m.summary, m.body, m.payload_json, m.created_at FROM events e JOIN messages m ON m.message_id = e.message_id WHERE e.thread_id = ? AND e.event_id > ?`
	if len(kinds) > 0 {
		query += " AND m.kind IN (" + placeholders(len(kinds)) + ")"
		for _, kind := range kinds {
			args = append(args, kind)
		}
	}
	query += " ORDER BY e.event_id ASC LIMIT 1"
	row := s.db.QueryRowContext(ctx, query, args...)
var ( eventID int64 message Message payload string created string ) err := row.Scan( &eventID, &message.MessageID, &message.ThreadID, &message.FromAgent, &message.ToAgent, &message.Kind, &message.Summary, &message.Body, &payload, &created, ) if errors.Is(err, sql.ErrNoRows) { return Message{}, 0, false, nil } if err != nil { return Message{}, 0, false, fmt.Errorf("query reply after event %d: %w", afterEventID, err) } message.PayloadJSON = json.RawMessage(payload) message.CreatedAt = parseTime(created) artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, []string{message.MessageID}) if err != nil { return Message{}, 0, false, err } message.Artifacts = artifactsByMessageID[message.MessageID] return message, eventID, true, nil } func (s *InboxStore) findWatchEventAfter(ctx context.Context, input WatchInput, afterEventID int64) (Thread, Message, Event, bool, error) { args := []any{afterEventID} query := `SELECT t.thread_id, t.run_id, t.task_id, t.subject, t.created_by, t.assigned_to, t.status, t.priority, t.latest_message_id, t.created_at, t.updated_at, e.event_id, e.run_id, e.task_id, e.thread_id, e.source, e.event_type, e.message_id, e.summary, e.payload_json, e.created_at, m.message_id, m.thread_id, m.from_agent, m.to_agent, m.kind, m.summary, m.body, m.payload_json, m.created_at FROM events e JOIN threads t ON t.thread_id = e.thread_id JOIN messages m ON m.message_id = e.message_id WHERE e.event_id > ?` if input.Agent != "" { query += " AND t.assigned_to = ?" args = append(args, input.Agent) } if len(input.Statuses) > 0 { query += " AND t.status IN (" + placeholders(len(input.Statuses)) + ")" for _, status := range input.Statuses { args = append(args, status) } } query += " ORDER BY e.event_id ASC LIMIT 1" row := s.db.QueryRowContext(ctx, query, args...) 
var ( thread Thread threadCreatedAt string threadUpdatedAt string threadLatestMessage sql.NullString event Event eventMessageID sql.NullString eventPayload string eventCreatedAt string message Message messagePayload string messageCreatedAt string ) err := row.Scan( &thread.ThreadID, &thread.RunID, &thread.TaskID, &thread.Subject, &thread.CreatedBy, &thread.AssignedTo, &thread.Status, &thread.Priority, &threadLatestMessage, &threadCreatedAt, &threadUpdatedAt, &event.EventID, &event.RunID, &event.TaskID, &event.ThreadID, &event.Source, &event.EventType, &eventMessageID, &event.Summary, &eventPayload, &eventCreatedAt, &message.MessageID, &message.ThreadID, &message.FromAgent, &message.ToAgent, &message.Kind, &message.Summary, &message.Body, &messagePayload, &messageCreatedAt, ) if errors.Is(err, sql.ErrNoRows) { return Thread{}, Message{}, Event{}, false, nil } if err != nil { return Thread{}, Message{}, Event{}, false, fmt.Errorf("query watch event after %d: %w", afterEventID, err) } if threadLatestMessage.Valid { thread.LatestMessageID = threadLatestMessage.String } thread.CreatedAt = parseTime(threadCreatedAt) thread.UpdatedAt = parseTime(threadUpdatedAt) if eventMessageID.Valid { event.MessageID = eventMessageID.String } event.PayloadJSON = json.RawMessage(eventPayload) event.CreatedAt = parseTime(eventCreatedAt) message.PayloadJSON = json.RawMessage(messagePayload) message.CreatedAt = parseTime(messageCreatedAt) artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, []string{message.MessageID}) if err != nil { return Thread{}, Message{}, Event{}, false, err } message.Artifacts = artifactsByMessageID[message.MessageID] return thread, message, event, true, nil } func waitForNextPoll(ctx context.Context, interval time.Duration) (bool, error) { timer := time.NewTimer(interval) defer timer.Stop() select { case <-ctx.Done(): return false, ctx.Err() case <-timer.C: return true, nil } } func isTerminalStatus(status string) bool { return status == "done" || 
status == "failed" || status == "cancelled" } func isDeadlineExceeded(ctx context.Context) bool { return ctx.Err() != nil && errors.Is(ctx.Err(), context.DeadlineExceeded) } func defaultID(value, prefix string) string { if value != "" { return value } return newID(prefix) } func newID(prefix string) string { return prefix + "_" + strings.ReplaceAll(uuid.NewString(), "-", "") } func defaultString(value, fallback string) string { if value != "" { return value } return fallback } func normalizeJSON(value string) string { if strings.TrimSpace(value) == "" { return "{}" } return value } func placeholders(n int) string { if n <= 0 { return "" } parts := make([]string, n) for i := range parts { parts[i] = "?" } return strings.Join(parts, ",") } func nowUTC() time.Time { return time.Now().UTC() } func formatTime(t time.Time) string { return t.UTC().Format(time.RFC3339Nano) } func parseTime(value string) time.Time { parsed, err := time.Parse(time.RFC3339Nano, value) if err != nil { return time.Time{} } return parsed }