// Package store persists inbox threads, the messages exchanged on them,
// worker leases and an event log through database/sql.
package store

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
)

// ErrLeaseConflict is returned by ClaimThread when another unexpired,
// unreleased lease already exists for the thread.
var ErrLeaseConflict = errors.New("thread already claimed by another worker")

// InboxStore provides the inbox operations on top of a *sql.DB.
type InboxStore struct {
	db *sql.DB
}

// Thread is one row of the threads table.
type Thread struct {
	ThreadID        string    `json:"thread_id"`
	RunID           string    `json:"run_id"`
	TaskID          string    `json:"task_id"`
	Subject         string    `json:"subject"`
	CreatedBy       string    `json:"created_by"`
	AssignedTo      string    `json:"assigned_to"`
	Status          string    `json:"status"`
	Priority        string    `json:"priority"`
	LatestMessageID string    `json:"latest_message_id,omitempty"`
	CreatedAt       time.Time `json:"created_at"`
	UpdatedAt       time.Time `json:"updated_at"`
}

// Message is one row of the messages table.
type Message struct {
	MessageID   string          `json:"message_id"`
	ThreadID    string          `json:"thread_id"`
	FromAgent   string          `json:"from_agent"`
	ToAgent     string          `json:"to_agent"`
	Kind        string          `json:"kind"`
	Summary     string          `json:"summary"`
	Body        string          `json:"body"`
	PayloadJSON json.RawMessage `json:"payload_json"`
	CreatedAt   time.Time       `json:"created_at"`
}

// ThreadDetail is a thread together with its messages, oldest first.
type ThreadDetail struct {
	Thread   Thread    `json:"thread"`
	Messages []Message `json:"messages"`
}

// SendInput describes a new thread and its first message. Empty ThreadID,
// RunID and TaskID are replaced by generated IDs; empty Kind, Priority and
// Summary default to "task", "normal" and Subject respectively.
type SendInput struct {
	ThreadID    string
	RunID       string
	TaskID      string
	Subject     string
	FromAgent   string
	ToAgent     string
	Kind        string
	Summary     string
	Body        string
	PayloadJSON string
	Priority    string
}

// FetchInput filters FetchThreads. An empty Agent matches every assignee,
// empty Statuses means only "pending", and a non-positive Limit means 20.
type FetchInput struct {
	Agent    string
	Statuses []string
	Limit    int
}

// ClaimInput identifies the thread to claim and the claiming agent. A
// non-positive LeaseSeconds is replaced by 900.
type ClaimInput struct {
	ThreadID     string
	Agent        string
	LeaseSeconds int
}

// ClaimResult is the claimed thread plus the "thread claimed" event message
// recorded for the claim.
type ClaimResult struct {
	Thread  Thread  `json:"thread"`
	Message Message `json:"message"`
}

// UpdateInput is the argument of UpdateThreadStatus. Status must be
// "in_progress" or "blocked".
type UpdateInput struct {
	ThreadID    string
	Agent       string
	Status      string
	Summary     string
	Body        string
	PayloadJSON string
}

// ReplyInput is the argument of ReplyToThread. An empty Kind defaults to
// "answer".
type ReplyInput struct {
	ThreadID    string
	FromAgent   string
	ToAgent     string
	Kind        string
	Summary     string
	Body        string
	PayloadJSON string
}

// CompleteInput is the argument of CompleteThread. Failed selects the
// "failed" status instead of "done".
type CompleteInput struct {
	ThreadID    string
	Agent       string
	Summary     string
	Body        string
	PayloadJSON string
	Failed      bool
}

// NewInboxStore returns an InboxStore that uses db for all queries.
func NewInboxStore(db *sql.DB) *InboxStore {
	return &InboxStore{db: db}
}

// Send creates a new "pending" thread, its first message and a
// "thread_created" event in a single transaction, and returns the thread and
// message as written.
func (s *InboxStore) Send(ctx context.Context, input SendInput) (Thread, Message, error) {
	now := nowUTC()
	summary := defaultString(input.Summary, input.Subject)
	payload := normalizeJSON(input.PayloadJSON)
	messageID := newID("msg")

	thread := Thread{
		ThreadID:        defaultID(input.ThreadID, "thr"),
		RunID:           defaultID(input.RunID, "run"),
		TaskID:          defaultID(input.TaskID, "task"),
		Subject:         input.Subject,
		CreatedBy:       input.FromAgent,
		AssignedTo:      input.ToAgent,
		Status:          "pending",
		Priority:        defaultString(input.Priority, "normal"),
		LatestMessageID: messageID,
		CreatedAt:       now,
		UpdatedAt:       now,
	}
	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.FromAgent,
		ToAgent:     input.ToAgent,
		Kind:        defaultString(input.Kind, "task"),
		Summary:     summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(payload),
		CreatedAt:   now,
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin send transaction: %w", err)
	}
	// Rollback after a successful Commit only reports sql.ErrTxDone, so the
	// result is deliberately not checked.
	defer tx.Rollback()

	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO threads (
			thread_id, run_id, task_id, subject, created_by, assigned_to,
			status, priority, latest_message_id, created_at, updated_at
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		thread.ThreadID,
		thread.RunID,
		thread.TaskID,
		thread.Subject,
		thread.CreatedBy,
		thread.AssignedTo,
		thread.Status,
		thread.Priority,
		thread.LatestMessageID,
		formatTime(thread.CreatedAt),
		formatTime(thread.UpdatedAt),
	); err != nil {
		return Thread{}, Message{}, fmt.Errorf("insert thread: %w", err)
	}

	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}

	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_created",
		MessageID:   message.MessageID,
		Summary:     summary,
		PayloadJSON: payload,
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}

	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit send transaction: %w", err)
	}
	return thread, message, nil
}

// FetchThreads returns threads whose status is one of input.Statuses and,
// when input.Agent is set, that are assigned to that agent. Results are
// ordered by updated_at descending and capped at input.Limit. The returned
// slice is nil when nothing matches.
func (s *InboxStore) FetchThreads(ctx context.Context, input FetchInput) ([]Thread, error) {
	statuses := input.Statuses
	if len(statuses) == 0 {
		statuses = []string{"pending"}
	}
	limit := input.Limit
	if limit <= 0 {
		limit = 20
	}

	// At most: agent + every status + limit.
	args := make([]any, 0, len(statuses)+2)
	conditions := make([]string, 0, 2)
	if input.Agent != "" {
		conditions = append(conditions, "assigned_to = ?")
		args = append(args, input.Agent)
	}
	conditions = append(conditions, "status IN ("+placeholders(len(statuses))+")")
	for _, status := range statuses {
		args = append(args, status)
	}
	args = append(args, limit)

	// The status condition is always present, so a WHERE clause is always
	// emitted.
	query := `SELECT thread_id, run_id, task_id, subject, created_by, assigned_to, status, priority, latest_message_id, created_at, updated_at FROM threads` +
		" WHERE " + strings.Join(conditions, " AND ") +
		" ORDER BY updated_at DESC LIMIT ?"

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("fetch threads: %w", err)
	}
	defer rows.Close()

	var threads []Thread
	for rows.Next() {
		thread, err := scanThread(rows)
		if err != nil {
			return nil, err
		}
		threads = append(threads, thread)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate threads: %w", err)
	}
	return threads, nil
}

// ClaimThread assigns a "pending" thread to input.Agent. In one transaction
// it verifies the thread is pending, rejects the claim with ErrLeaseConflict
// if an unreleased, unexpired lease exists, upserts the lease, marks the
// thread "claimed", and records a "thread claimed" event message and a
// "thread_claimed" event.
func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimResult, error) {
	if input.LeaseSeconds <= 0 {
		input.LeaseSeconds = 900
	}
	now := nowUTC()
	expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
	leaseToken := newID("lease")
	messageID := newID("msg")

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return ClaimResult{}, fmt.Errorf("begin claim transaction: %w", err)
	}
	// Rollback after a successful Commit only reports sql.ErrTxDone, so the
	// result is deliberately not checked.
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return ClaimResult{}, err
	}
	if thread.Status != "pending" {
		return ClaimResult{}, fmt.Errorf("thread %s is not pending", input.ThreadID)
	}

	// Any row returned here is a live lease. The decision is based on the row
	// existing rather than on the scanned agent_id being non-empty, so a lease
	// stored with an empty agent_id still blocks the claim.
	var leaseHolder string
	err = tx.QueryRowContext(
		ctx,
		`SELECT agent_id FROM leases
		 WHERE thread_id = ? AND released_at IS NULL AND expires_at > ?`,
		input.ThreadID,
		formatTime(now),
	).Scan(&leaseHolder)
	switch {
	case err == nil:
		return ClaimResult{}, ErrLeaseConflict
	case !errors.Is(err, sql.ErrNoRows):
		return ClaimResult{}, fmt.Errorf("check active lease: %w", err)
	}

	// The upsert reuses an existing (released or expired) lease row for the
	// thread instead of inserting a second one.
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO leases (
			thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
		) VALUES (?, ?, ?, ?, ?, NULL)
		ON CONFLICT(thread_id) DO UPDATE SET
			agent_id = excluded.agent_id,
			lease_token = excluded.lease_token,
			claimed_at = excluded.claimed_at,
			expires_at = excluded.expires_at,
			released_at = NULL`,
		input.ThreadID,
		input.Agent,
		leaseToken,
		formatTime(now),
		formatTime(expiresAt),
	); err != nil {
		return ClaimResult{}, fmt.Errorf("upsert lease: %w", err)
	}

	if _, err := tx.ExecContext(
		ctx,
		`UPDATE threads
		 SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
		 WHERE thread_id = ?`,
		"claimed",
		input.Agent,
		messageID,
		formatTime(now),
		input.ThreadID,
	); err != nil {
		return ClaimResult{}, fmt.Errorf("update thread claim status: %w", err)
	}

	message := Message{
		MessageID: messageID,
		ThreadID:  input.ThreadID,
		FromAgent: input.Agent,
		ToAgent:   input.Agent,
		Kind:      "event",
		Summary:   "thread claimed",
		Body:      "",
		// leaseToken comes from newID, so it contains nothing that needs JSON
		// escaping.
		PayloadJSON: json.RawMessage(fmt.Sprintf(`{"lease_seconds":%d,"lease_token":"%s"}`, input.LeaseSeconds, leaseToken)),
		CreatedAt:   now,
	}
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO messages (
			message_id, thread_id, from_agent, to_agent, kind, summary, body, payload_json, created_at
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		message.MessageID,
		message.ThreadID,
		message.FromAgent,
		message.ToAgent,
		message.Kind,
		message.Summary,
		message.Body,
		string(message.PayloadJSON),
		formatTime(message.CreatedAt),
	); err != nil {
		return ClaimResult{}, fmt.Errorf("insert claim event message: %w", err)
	}

	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_claimed",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return ClaimResult{}, err
	}

	if err := tx.Commit(); err != nil {
		return ClaimResult{}, fmt.Errorf("commit claim transaction: %w", err)
	}

	// Mirror the committed column changes on the value handed back.
	thread.Status = "claimed"
	thread.AssignedTo = input.Agent
	thread.LatestMessageID = messageID
	thread.UpdatedAt = now
	return ClaimResult{Thread: thread, Message: message}, nil
}

// UpdateThreadStatus moves a non-terminal thread to "in_progress" or
// "blocked". It records a message addressed to the thread's creator (kind
// "progress", or "question" when blocking) and a "thread_<status>" event.
// Threads that are "done", "failed" or "cancelled" are rejected.
func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) {
	if input.Status != "in_progress" && input.Status != "blocked" {
		return Thread{}, Message{}, fmt.Errorf("unsupported update status %q", input.Status)
	}
	now := nowUTC()
	messageID := newID("msg")

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin update transaction: %w", err)
	}
	// Rollback after a successful Commit only reports sql.ErrTxDone, so the
	// result is deliberately not checked.
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return Thread{}, Message{}, err
	}
	switch thread.Status {
	case "done", "failed", "cancelled":
		return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
	}

	kind := "progress"
	if input.Status == "blocked" {
		kind = "question"
	}
	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.Agent,
		ToAgent:     thread.CreatedBy,
		Kind:        kind,
		Summary:     input.Summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	if err := updateThreadState(ctx, tx, thread.ThreadID, input.Status, thread.AssignedTo, message.MessageID, now); err != nil {
		return Thread{}, Message{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_" + input.Status,
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}

	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit update transaction: %w", err)
	}

	thread.Status = input.Status
	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return thread, message, nil
}

// ReplyToThread appends a message to an existing thread and records a
// "thread_replied" event. The thread keeps its status and assignee; only
// latest_message_id and updated_at change.
func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Thread, Message, error) {
	now := nowUTC()
	messageID := newID("msg")

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin reply transaction: %w", err)
	}
	// Rollback after a successful Commit only reports sql.ErrTxDone, so the
	// result is deliberately not checked.
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return Thread{}, Message{}, err
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.FromAgent,
		ToAgent:     input.ToAgent,
		Kind:        defaultString(input.Kind, "answer"),
		Summary:     input.Summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
		return Thread{}, Message{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_replied",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}

	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit reply transaction: %w", err)
	}

	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return thread, message, nil
}

// CompleteThread marks a thread "done" (or "failed" when input.Failed is
// set), records a "result" message addressed to the thread's creator,
// releases any unreleased lease on the thread, and records a "thread_done" or
// "thread_failed" event.
//
// NOTE(review): unlike UpdateThreadStatus, this neither rejects threads that
// are already terminal nor checks that input.Agent holds the lease — confirm
// whether that is intended.
func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (Thread, Message, error) {
	now := nowUTC()
	messageID := newID("msg")

	nextStatus, eventType := "done", "thread_done"
	if input.Failed {
		nextStatus, eventType = "failed", "thread_failed"
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin complete transaction: %w", err)
	}
	// Rollback after a successful Commit only reports sql.ErrTxDone, so the
	// result is deliberately not checked.
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return Thread{}, Message{}, err
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    thread.ThreadID,
		FromAgent:   input.Agent,
		ToAgent:     thread.CreatedBy,
		Kind:        "result",
		Summary:     input.Summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return Thread{}, Message{}, err
	}
	if err := updateThreadState(ctx, tx, thread.ThreadID, nextStatus, thread.AssignedTo, message.MessageID, now); err != nil {
		return Thread{}, Message{}, err
	}
	if _, err := tx.ExecContext(
		ctx,
		`UPDATE leases SET released_at = ?
		 WHERE thread_id = ? AND released_at IS NULL`,
		formatTime(now),
		thread.ThreadID,
	); err != nil {
		return Thread{}, Message{}, fmt.Errorf("release lease: %w", err)
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   eventType,
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}

	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit complete transaction: %w", err)
	}

	thread.Status = nextStatus
	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	return thread, message, nil
}

// GetThread loads a thread and all of its messages ordered by created_at
// ascending. The two reads are issued directly on the *sql.DB, not inside a
// transaction.
func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) {
	thread, err := selectThread(ctx, s.db, threadID)
	if err != nil {
		return ThreadDetail{}, err
	}

	rows, err := s.db.QueryContext(
		ctx,
		`SELECT message_id, thread_id, from_agent, to_agent, kind, summary, body, payload_json, created_at
		 FROM messages
		 WHERE thread_id = ?
		 ORDER BY created_at ASC`,
		threadID,
	)
	if err != nil {
		return ThreadDetail{}, fmt.Errorf("query thread messages: %w", err)
	}
	defer rows.Close()

	var messages []Message
	for rows.Next() {
		message, err := scanMessage(rows)
		if err != nil {
			return ThreadDetail{}, err
		}
		messages = append(messages, message)
	}
	if err := rows.Err(); err != nil {
		return ThreadDetail{}, fmt.Errorf("iterate thread messages: %w", err)
	}

	return ThreadDetail{Thread: thread, Messages: messages}, nil
}

// threadScanner is the Scan method shared by *sql.Row and *sql.Rows, letting
// the scan helpers serve both single-row and multi-row queries.
type threadScanner interface {
	Scan(dest ...any) error
}

// scanThread reads one threads row in the column order used by this file's
// SELECT statements. A NULL latest_message_id becomes the empty string, and
// timestamps that fail to parse become the zero time (see parseTime).
func scanThread(scanner threadScanner) (Thread, error) {
	var (
		thread               Thread
		createdAt, updatedAt string
		latestMessageID      sql.NullString
	)
	if err := scanner.Scan(
		&thread.ThreadID,
		&thread.RunID,
		&thread.TaskID,
		&thread.Subject,
		&thread.CreatedBy,
		&thread.AssignedTo,
		&thread.Status,
		&thread.Priority,
		&latestMessageID,
		&createdAt,
		&updatedAt,
	); err != nil {
		return Thread{}, fmt.Errorf("scan thread: %w", err)
	}
	thread.CreatedAt = parseTime(createdAt)
	thread.UpdatedAt = parseTime(updatedAt)
	if latestMessageID.Valid {
		thread.LatestMessageID = latestMessageID.String
	}
	return thread, nil
}

// scanMessage reads one messages row in the column order used by this file's
// SELECT statements. The payload is stored verbatim without validation.
func scanMessage(scanner threadScanner) (Message, error) {
	var (
		message            Message
		payload, createdAt string
	)
	if err := scanner.Scan(
		&message.MessageID,
		&message.ThreadID,
		&message.FromAgent,
		&message.ToAgent,
		&message.Kind,
		&message.Summary,
		&message.Body,
		&payload,
		&createdAt,
	); err != nil {
		return Message{}, fmt.Errorf("scan message: %w", err)
	}
	message.PayloadJSON = json.RawMessage(payload)
	message.CreatedAt = parseTime(createdAt)
	return message, nil
}

// selectThread loads a single thread by ID using db, which may be a *sql.DB
// or a *sql.Tx. A missing row is reported as a "thread <id> not found" error
// that does not wrap sql.ErrNoRows.
func selectThread(ctx context.Context, db queryRower, threadID string) (Thread, error) {
	row := db.QueryRowContext(
		ctx,
		`SELECT thread_id, run_id, task_id, subject, created_by, assigned_to, status, priority, latest_message_id, created_at, updated_at
		 FROM threads
		 WHERE thread_id = ?`,
		threadID,
	)
	thread, err := scanThread(row)
	// scanThread wraps with %w, so errors.Is still sees sql.ErrNoRows.
	if errors.Is(err, sql.ErrNoRows) {
		return Thread{}, fmt.Errorf("thread %s not found", threadID)
	}
	return thread, err
}

// selectThreadForUpdate loads a thread inside tx. Despite the name it issues
// the same plain SELECT as selectThread; no explicit row-locking clause is
// added.
func selectThreadForUpdate(ctx context.Context, tx *sql.Tx, threadID string) (Thread, error) {
	return selectThread(ctx, tx, threadID)
}

// queryRower is the single-row query method shared by *sql.DB and *sql.Tx.
type queryRower interface {
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// eventInput carries the column values for one events row.
type eventInput struct {
	RunID       string
	TaskID      string
	ThreadID    string
	Source      string
	EventType   string
	MessageID   string
	Summary     string
	PayloadJSON string
	CreatedAt   time.Time
}

// insertEvent writes one events row inside tx. A blank payload is stored as
// "{}".
func insertEvent(ctx context.Context, tx *sql.Tx, input eventInput) error {
	_, err := tx.ExecContext(
		ctx,
		`INSERT INTO events (
			run_id, task_id, thread_id, source, event_type, message_id, summary, payload_json, created_at
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		input.RunID,
		input.TaskID,
		input.ThreadID,
		input.Source,
		input.EventType,
		input.MessageID,
		input.Summary,
		normalizeJSON(input.PayloadJSON),
		formatTime(input.CreatedAt),
	)
	if err != nil {
		return fmt.Errorf("insert event: %w", err)
	}
	return nil
}

// insertMessage writes one messages row inside tx. The payload is stored
// exactly as given.
func insertMessage(ctx context.Context, tx *sql.Tx, message Message) error {
	_, err := tx.ExecContext(
		ctx,
		`INSERT INTO messages (
			message_id, thread_id, from_agent, to_agent, kind, summary, body, payload_json, created_at
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		message.MessageID,
		message.ThreadID,
		message.FromAgent,
		message.ToAgent,
		message.Kind,
		message.Summary,
		message.Body,
		string(message.PayloadJSON),
		formatTime(message.CreatedAt),
	)
	if err != nil {
		return fmt.Errorf("insert message: %w", err)
	}
	return nil
}

// updateThreadState overwrites status, assigned_to, latest_message_id and
// updated_at on the given thread inside tx. It does not check how many rows
// were affected.
func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assignedTo, latestMessageID string, updatedAt time.Time) error {
	_, err := tx.ExecContext(
		ctx,
		`UPDATE threads
		 SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
		 WHERE thread_id = ?`,
		status,
		assignedTo,
		latestMessageID,
		formatTime(updatedAt),
		threadID,
	)
	if err != nil {
		return fmt.Errorf("update thread state: %w", err)
	}
	return nil
}

// defaultID returns value when it is non-empty and a freshly generated ID
// with the given prefix otherwise.
func defaultID(value, prefix string) string {
	if value != "" {
		return value
	}
	return newID(prefix)
}

// newID returns prefix, an underscore, and a new UUID with its dashes
// removed.
func newID(prefix string) string {
	return prefix + "_" + strings.ReplaceAll(uuid.NewString(), "-", "")
}

// defaultString returns value when it is non-empty and fallback otherwise.
func defaultString(value, fallback string) string {
	if value != "" {
		return value
	}
	return fallback
}

// normalizeJSON returns "{}" for an empty or whitespace-only value and the
// value unchanged otherwise. It does not check that the value is valid JSON.
func normalizeJSON(value string) string {
	if strings.TrimSpace(value) == "" {
		return "{}"
	}
	return value
}

// placeholders returns n comma-separated "?" markers for an SQL IN list, or
// the empty string when n is not positive.
func placeholders(n int) string {
	if n <= 0 {
		return ""
	}
	parts := make([]string, n)
	for i := range parts {
		parts[i] = "?"
	}
	return strings.Join(parts, ",")
}

// nowUTC returns the current time in UTC.
func nowUTC() time.Time {
	return time.Now().UTC()
}

// formatTime renders t in UTC as an RFC 3339 timestamp with exactly nine
// fractional-second digits.
//
// Timestamps are bound as strings and compared or sorted by the database
// (expires_at > ?, ORDER BY updated_at, ORDER BY created_at). time.RFC3339Nano
// drops trailing zeros, so its output has variable width and "…:00Z" sorts
// after "…:00.5Z". A fixed-width layout keeps lexical order equal to
// chronological order.
func formatTime(t time.Time) string {
	const layout = "2006-01-02T15:04:05.000000000Z07:00"
	return t.UTC().Format(layout)
}

// parseTime parses a stored timestamp and returns the zero time.Time when the
// value cannot be parsed. Parsing with time.RFC3339Nano accepts any number of
// fractional digits, so both the fixed-width values written by formatTime and
// older trimmed values are read.
func parseTime(value string) time.Time {
	parsed, err := time.Parse(time.RFC3339Nano, value)
	if err != nil {
		return time.Time{}
	}
	return parsed
}