1757 lines
44 KiB
Go
1757 lines
44 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Sentinel errors reported by InboxStore. Several call sites wrap them with
// fmt.Errorf("%w: ...") to add context, so match them with errors.Is.
var (
	// ErrLeaseConflict is returned by ClaimThread when the thread still holds
	// an unreleased lease that has not yet expired.
	ErrLeaseConflict = errors.New("thread already claimed by another worker")

	// ErrThreadNotFound reports a thread ID with no matching row; Send treats
	// it as "create the thread instead".
	ErrThreadNotFound = errors.New("thread not found")

	// ErrMessageNotFound reports a message ID with no matching row (not
	// referenced in the visible part of this file — presumably used by lookup
	// helpers; verify).
	ErrMessageNotFound = errors.New("message not found")

	// ErrNoActiveLease reports that the caller holds no live lease (presumably
	// produced by requireActiveLease; verify).
	ErrNoActiveLease = errors.New("no active lease")

	// ErrInvalidInput marks requests rejected before touching the database,
	// e.g. an unsupported status or a missing agent.
	ErrInvalidInput = errors.New("invalid input")

	// ErrInvalidState marks operations that conflict with the thread's current
	// status, e.g. acting on a terminal thread.
	ErrInvalidState = errors.New("invalid state")
)
|
|
|
|
// InboxStore is the SQL-backed inbox. Its methods read and write the threads,
// messages and leases tables directly, and record artifacts and events through
// helper functions. Build one with NewInboxStore.
type InboxStore struct {
	// db is the handle every query and transaction runs against.
	db *sql.DB
}
|
|
|
|
// Records returned by InboxStore. The json tags define the wire shape seen by
// callers; encoding/json emits fields in declaration order, so both the tags
// and the field order are part of that shape.
type (
	// Thread is a unit of work sent by CreatedBy and assigned to AssignedTo.
	// Status values set by InboxStore include "pending", "claimed",
	// "in_progress", "blocked", "done", "failed" and "cancelled".
	Thread struct {
		ThreadID   string `json:"thread_id"`
		RunID      string `json:"run_id"`
		TaskID     string `json:"task_id"`
		Subject    string `json:"subject"`
		CreatedBy  string `json:"created_by"`
		AssignedTo string `json:"assigned_to"`
		Status     string `json:"status"`
		Priority   string `json:"priority"`
		// LatestMessageID points at the most recent message on the thread.
		LatestMessageID string    `json:"latest_message_id,omitempty"`
		CreatedAt       time.Time `json:"created_at"`
		UpdatedAt       time.Time `json:"updated_at"`
	}

	// Message is one entry in a thread. Kind values used by InboxStore include
	// "task", "event", "progress", "question", "answer", "result" and
	// "control". PayloadJSON is embedded as raw JSON, not as a string.
	Message struct {
		MessageID   string          `json:"message_id"`
		ThreadID    string          `json:"thread_id"`
		FromAgent   string          `json:"from_agent"`
		ToAgent     string          `json:"to_agent"`
		Kind        string          `json:"kind"`
		Summary     string          `json:"summary"`
		Body        string          `json:"body"`
		PayloadJSON json.RawMessage `json:"payload_json"`
		CreatedAt   time.Time       `json:"created_at"`
		// Artifacts is populated only by calls that load or insert them.
		Artifacts []Artifact `json:"artifacts,omitempty"`
	}

	// Artifact is a stored reference attached to a message. Path is not
	// interpreted here (presumably a file path or URI — verify with callers).
	Artifact struct {
		ArtifactID   string          `json:"artifact_id"`
		MessageID    string          `json:"message_id"`
		Path         string          `json:"path"`
		Kind         string          `json:"kind"`
		MetadataJSON json.RawMessage `json:"metadata_json"`
		CreatedAt    time.Time       `json:"created_at"`
	}

	// ArtifactInput describes an artifact to attach to a new message.
	// MetadataJSON is JSON text (normalization, if any, happens in
	// insertArtifacts — verify).
	ArtifactInput struct {
		Path         string
		Kind         string
		MetadataJSON string
	}

	// ThreadDetail is a thread with its message history, as returned by
	// GetThread.
	ThreadDetail struct {
		Thread   Thread    `json:"thread"`
		Messages []Message `json:"messages"`
	}

	// Event is one entry of the event log written alongside inbox mutations.
	// EventID serves as the cursor for WatchThreads and WaitReply.
	Event struct {
		EventID     int64           `json:"event_id"`
		RunID       string          `json:"run_id"`
		TaskID      string          `json:"task_id"`
		ThreadID    string          `json:"thread_id,omitempty"`
		Source      string          `json:"source"`
		EventType   string          `json:"event_type"`
		MessageID   string          `json:"message_id,omitempty"`
		Summary     string          `json:"summary"`
		PayloadJSON json.RawMessage `json:"payload_json"`
		CreatedAt   time.Time       `json:"created_at"`
	}
)
|
|
|
|
type SendInput struct {
|
|
ThreadID string
|
|
RunID string
|
|
TaskID string
|
|
Subject string
|
|
FromAgent string
|
|
ToAgent string
|
|
Kind string
|
|
Summary string
|
|
Body string
|
|
PayloadJSON string
|
|
Priority string
|
|
Artifacts []ArtifactInput
|
|
}
|
|
|
|
type FetchInput struct {
|
|
Agent string
|
|
Statuses []string
|
|
Limit int
|
|
Unread bool
|
|
}
|
|
|
|
type ClaimInput struct {
|
|
ThreadID string
|
|
Agent string
|
|
LeaseSeconds int
|
|
}
|
|
|
|
type RenewInput struct {
|
|
ThreadID string
|
|
Agent string
|
|
LeaseSeconds int
|
|
}
|
|
|
|
type ClaimResult struct {
|
|
Thread Thread `json:"thread"`
|
|
Message Message `json:"message"`
|
|
}
|
|
|
|
type UpdateInput struct {
|
|
ThreadID string
|
|
Agent string
|
|
Status string
|
|
Summary string
|
|
Body string
|
|
PayloadJSON string
|
|
Artifacts []ArtifactInput
|
|
}
|
|
|
|
type ReplyInput struct {
|
|
ThreadID string
|
|
FromAgent string
|
|
ToAgent string
|
|
Kind string
|
|
Summary string
|
|
Body string
|
|
PayloadJSON string
|
|
Artifacts []ArtifactInput
|
|
}
|
|
|
|
type CompleteInput struct {
|
|
ThreadID string
|
|
Agent string
|
|
Summary string
|
|
Body string
|
|
PayloadJSON string
|
|
Failed bool
|
|
Artifacts []ArtifactInput
|
|
}
|
|
|
|
type CancelInput struct {
|
|
ThreadID string
|
|
Agent string
|
|
Reason string
|
|
Artifacts []ArtifactInput
|
|
}
|
|
|
|
type ListInput struct {
|
|
Agent string
|
|
Statuses []string
|
|
CreatedBy string
|
|
AssignedTo string
|
|
Limit int
|
|
Unread bool
|
|
}
|
|
|
|
type WatchInput struct {
|
|
Agent string
|
|
Statuses []string
|
|
AfterEventID int64
|
|
StartFromNow bool
|
|
Timeout time.Duration
|
|
}
|
|
|
|
type WatchResult struct {
|
|
Woke bool `json:"woke"`
|
|
NextEventID int64 `json:"next_event_id"`
|
|
Thread *Thread `json:"thread,omitempty"`
|
|
Message *Message `json:"message,omitempty"`
|
|
Event *Event `json:"event,omitempty"`
|
|
}
|
|
|
|
type WaitReplyInput struct {
|
|
ThreadID string
|
|
AfterMessageID string
|
|
AfterEventID int64
|
|
Kinds []string
|
|
Timeout time.Duration
|
|
}
|
|
|
|
type WaitReplyResult struct {
|
|
Woke bool `json:"woke"`
|
|
NextEventID int64 `json:"next_event_id"`
|
|
Message *Message `json:"message,omitempty"`
|
|
}
|
|
|
|
func NewInboxStore(db *sql.DB) *InboxStore {
|
|
return &InboxStore{db: db}
|
|
}
|
|
|
|
func (s *InboxStore) Send(ctx context.Context, input SendInput) (Thread, Message, error) {
|
|
if input.ThreadID != "" {
|
|
thread, err := selectThread(ctx, s.db, input.ThreadID)
|
|
if err == nil {
|
|
return s.appendThreadMessage(ctx, thread, input)
|
|
}
|
|
if !errors.Is(err, ErrThreadNotFound) {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
}
|
|
|
|
return s.createThread(ctx, input)
|
|
}
|
|
|
|
func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread, Message, error) {
|
|
now := nowUTC()
|
|
|
|
threadID := defaultID(input.ThreadID, "thr")
|
|
runID := defaultID(input.RunID, "run")
|
|
taskID := defaultID(input.TaskID, "task")
|
|
kind := defaultString(input.Kind, "task")
|
|
priority := defaultString(input.Priority, "normal")
|
|
summary := defaultString(input.Summary, input.Subject)
|
|
payload := normalizeJSON(input.PayloadJSON)
|
|
messageID := newID("msg")
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("begin send transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
thread := Thread{
|
|
ThreadID: threadID,
|
|
RunID: runID,
|
|
TaskID: taskID,
|
|
Subject: input.Subject,
|
|
CreatedBy: input.FromAgent,
|
|
AssignedTo: input.ToAgent,
|
|
Status: "pending",
|
|
Priority: priority,
|
|
LatestMessageID: messageID,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO threads (
|
|
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
|
|
priority, latest_message_id, created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
thread.ThreadID,
|
|
thread.RunID,
|
|
thread.TaskID,
|
|
thread.Subject,
|
|
thread.CreatedBy,
|
|
thread.AssignedTo,
|
|
thread.Status,
|
|
thread.Priority,
|
|
thread.LatestMessageID,
|
|
formatTime(thread.CreatedAt),
|
|
formatTime(thread.UpdatedAt),
|
|
); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("insert thread: %w", err)
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: threadID,
|
|
FromAgent: input.FromAgent,
|
|
ToAgent: input.ToAgent,
|
|
Kind: kind,
|
|
Summary: summary,
|
|
Body: input.Body,
|
|
PayloadJSON: json.RawMessage(payload),
|
|
CreatedAt: now,
|
|
}
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
message.Artifacts = artifacts
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_created",
|
|
MessageID: message.MessageID,
|
|
Summary: summary,
|
|
PayloadJSON: payload,
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("commit send transaction: %w", err)
|
|
}
|
|
|
|
return thread, message, nil
|
|
}
|
|
|
|
func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, input SendInput) (Thread, Message, error) {
|
|
now := nowUTC()
|
|
messageID := newID("msg")
|
|
payload := normalizeJSON(input.PayloadJSON)
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("begin append transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
thread, err := selectThreadForUpdate(ctx, tx, existing.ThreadID)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
if isTerminalStatus(thread.Status) {
|
|
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, thread.ThreadID)
|
|
}
|
|
|
|
assignedTo := thread.AssignedTo
|
|
if input.ToAgent != "" {
|
|
assignedTo = input.ToAgent
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: thread.ThreadID,
|
|
FromAgent: input.FromAgent,
|
|
ToAgent: defaultString(input.ToAgent, thread.AssignedTo),
|
|
Kind: defaultString(input.Kind, "task"),
|
|
Summary: defaultString(input.Summary, thread.Subject),
|
|
Body: input.Body,
|
|
PayloadJSON: json.RawMessage(payload),
|
|
CreatedAt: now,
|
|
}
|
|
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
message.Artifacts = artifacts
|
|
|
|
if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, assignedTo, message.MessageID, now); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_message_sent",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: payload,
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("commit append transaction: %w", err)
|
|
}
|
|
|
|
thread.AssignedTo = assignedTo
|
|
thread.LatestMessageID = message.MessageID
|
|
thread.UpdatedAt = now
|
|
return thread, message, nil
|
|
}
|
|
|
|
func (s *InboxStore) FetchThreads(ctx context.Context, input FetchInput) ([]Thread, error) {
|
|
statuses := input.Statuses
|
|
if len(statuses) == 0 {
|
|
statuses = []string{"pending"}
|
|
}
|
|
|
|
return s.ListThreads(ctx, ListInput{
|
|
Agent: input.Agent,
|
|
Statuses: statuses,
|
|
Limit: input.Limit,
|
|
Unread: input.Unread,
|
|
})
|
|
}
|
|
|
|
func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread, error) {
|
|
limit := input.Limit
|
|
if limit <= 0 {
|
|
limit = 20
|
|
}
|
|
|
|
var (
|
|
args []any
|
|
conditions []string
|
|
joins []string
|
|
)
|
|
|
|
assignedTo := input.AssignedTo
|
|
if assignedTo == "" {
|
|
assignedTo = input.Agent
|
|
}
|
|
|
|
if assignedTo != "" {
|
|
conditions = append(conditions, "t.assigned_to = ?")
|
|
args = append(args, assignedTo)
|
|
}
|
|
if input.CreatedBy != "" {
|
|
conditions = append(conditions, "t.created_by = ?")
|
|
args = append(args, input.CreatedBy)
|
|
}
|
|
if len(input.Statuses) > 0 {
|
|
conditions = append(conditions, "t.status IN ("+placeholders(len(input.Statuses))+")")
|
|
for _, status := range input.Statuses {
|
|
args = append(args, status)
|
|
}
|
|
}
|
|
if input.Unread {
|
|
if input.Agent == "" {
|
|
return nil, fmt.Errorf("%w: agent is required when filtering unread threads", ErrInvalidInput)
|
|
}
|
|
joins = append(joins, "JOIN messages lm ON lm.message_id = t.latest_message_id")
|
|
conditions = append(conditions, "lm.to_agent = ?")
|
|
args = append(args, input.Agent)
|
|
conditions = append(conditions, "lm.from_agent <> ?")
|
|
args = append(args, input.Agent)
|
|
}
|
|
|
|
query := `SELECT
|
|
t.thread_id, t.run_id, t.task_id, t.subject, t.created_by, t.assigned_to, t.status,
|
|
t.priority, t.latest_message_id, t.created_at, t.updated_at
|
|
FROM threads t`
|
|
if len(joins) > 0 {
|
|
query += " " + strings.Join(joins, " ")
|
|
}
|
|
if len(conditions) > 0 {
|
|
query += " WHERE " + strings.Join(conditions, " AND ")
|
|
}
|
|
query += " ORDER BY t.updated_at DESC LIMIT ?"
|
|
args = append(args, limit)
|
|
|
|
rows, err := s.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list threads: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var threads []Thread
|
|
for rows.Next() {
|
|
thread, err := scanThread(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
threads = append(threads, thread)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate threads: %w", err)
|
|
}
|
|
|
|
return threads, nil
|
|
}
|
|
|
|
func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimResult, error) {
|
|
if input.LeaseSeconds <= 0 {
|
|
input.LeaseSeconds = 900
|
|
}
|
|
|
|
now := nowUTC()
|
|
expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
|
|
leaseToken := newID("lease")
|
|
messageID := newID("msg")
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return ClaimResult{}, fmt.Errorf("begin claim transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
|
|
if err != nil {
|
|
return ClaimResult{}, err
|
|
}
|
|
if isTerminalStatus(thread.Status) {
|
|
return ClaimResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
|
}
|
|
|
|
var activeLease string
|
|
err = tx.QueryRowContext(
|
|
ctx,
|
|
`SELECT agent_id FROM leases
|
|
WHERE thread_id = ?
|
|
AND released_at IS NULL
|
|
AND expires_at > ?`,
|
|
input.ThreadID,
|
|
formatTime(now),
|
|
).Scan(&activeLease)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return ClaimResult{}, fmt.Errorf("check active lease: %w", err)
|
|
}
|
|
if activeLease != "" {
|
|
return ClaimResult{}, ErrLeaseConflict
|
|
}
|
|
if thread.Status != "pending" {
|
|
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO leases (
|
|
thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
|
|
) VALUES (?, ?, ?, ?, ?, NULL)
|
|
ON CONFLICT(thread_id) DO UPDATE SET
|
|
agent_id = excluded.agent_id,
|
|
lease_token = excluded.lease_token,
|
|
claimed_at = excluded.claimed_at,
|
|
expires_at = excluded.expires_at,
|
|
released_at = NULL`,
|
|
input.ThreadID,
|
|
input.Agent,
|
|
leaseToken,
|
|
formatTime(now),
|
|
formatTime(expiresAt),
|
|
); err != nil {
|
|
return ClaimResult{}, fmt.Errorf("upsert lease: %w", err)
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`UPDATE threads
|
|
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
|
|
WHERE thread_id = ?`,
|
|
"claimed",
|
|
input.Agent,
|
|
messageID,
|
|
formatTime(now),
|
|
input.ThreadID,
|
|
); err != nil {
|
|
return ClaimResult{}, fmt.Errorf("update thread claim status: %w", err)
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: input.ThreadID,
|
|
FromAgent: input.Agent,
|
|
ToAgent: input.Agent,
|
|
Kind: "event",
|
|
Summary: "thread claimed",
|
|
Body: "",
|
|
PayloadJSON: json.RawMessage(fmt.Sprintf(`{"lease_seconds":%d,"lease_token":"%s"}`, input.LeaseSeconds, leaseToken)),
|
|
CreatedAt: now,
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO messages (
|
|
message_id, thread_id, from_agent, to_agent, kind, summary, body,
|
|
payload_json, created_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
message.MessageID,
|
|
message.ThreadID,
|
|
message.FromAgent,
|
|
message.ToAgent,
|
|
message.Kind,
|
|
message.Summary,
|
|
message.Body,
|
|
string(message.PayloadJSON),
|
|
formatTime(message.CreatedAt),
|
|
); err != nil {
|
|
return ClaimResult{}, fmt.Errorf("insert claim event message: %w", err)
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_claimed",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: string(message.PayloadJSON),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return ClaimResult{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return ClaimResult{}, fmt.Errorf("commit claim transaction: %w", err)
|
|
}
|
|
|
|
thread.Status = "claimed"
|
|
thread.AssignedTo = input.Agent
|
|
thread.LatestMessageID = messageID
|
|
thread.UpdatedAt = now
|
|
|
|
return ClaimResult{
|
|
Thread: thread,
|
|
Message: message,
|
|
}, nil
|
|
}
|
|
|
|
func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimResult, error) {
|
|
if input.LeaseSeconds <= 0 {
|
|
input.LeaseSeconds = 900
|
|
}
|
|
|
|
now := nowUTC()
|
|
expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
|
|
leaseToken := newID("lease")
|
|
messageID := newID("msg")
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return ClaimResult{}, fmt.Errorf("begin renew transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
|
|
if err != nil {
|
|
return ClaimResult{}, err
|
|
}
|
|
if isTerminalStatus(thread.Status) {
|
|
return ClaimResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
|
}
|
|
|
|
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
|
|
return ClaimResult{}, err
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`UPDATE leases
|
|
SET lease_token = ?, expires_at = ?, released_at = NULL
|
|
WHERE thread_id = ?`,
|
|
leaseToken,
|
|
formatTime(expiresAt),
|
|
input.ThreadID,
|
|
); err != nil {
|
|
return ClaimResult{}, fmt.Errorf("renew lease: %w", err)
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: input.ThreadID,
|
|
FromAgent: input.Agent,
|
|
ToAgent: input.Agent,
|
|
Kind: "event",
|
|
Summary: "lease renewed",
|
|
Body: "",
|
|
PayloadJSON: json.RawMessage(fmt.Sprintf(`{"lease_seconds":%d,"lease_token":"%s"}`, input.LeaseSeconds, leaseToken)),
|
|
CreatedAt: now,
|
|
}
|
|
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return ClaimResult{}, err
|
|
}
|
|
|
|
if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
|
|
return ClaimResult{}, err
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_renewed",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: string(message.PayloadJSON),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return ClaimResult{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return ClaimResult{}, fmt.Errorf("commit renew transaction: %w", err)
|
|
}
|
|
|
|
thread.LatestMessageID = message.MessageID
|
|
thread.UpdatedAt = now
|
|
return ClaimResult{
|
|
Thread: thread,
|
|
Message: message,
|
|
}, nil
|
|
}
|
|
|
|
func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) {
|
|
now := nowUTC()
|
|
messageID := newID("msg")
|
|
|
|
if input.Status != "in_progress" && input.Status != "blocked" {
|
|
return Thread{}, Message{}, fmt.Errorf("%w: unsupported update status %q", ErrInvalidInput, input.Status)
|
|
}
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("begin update transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
if isTerminalStatus(thread.Status) {
|
|
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
|
}
|
|
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
kind := "progress"
|
|
if input.Status == "blocked" {
|
|
kind = "question"
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: thread.ThreadID,
|
|
FromAgent: input.Agent,
|
|
ToAgent: thread.CreatedBy,
|
|
Kind: kind,
|
|
Summary: input.Summary,
|
|
Body: input.Body,
|
|
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
|
CreatedAt: now,
|
|
}
|
|
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
message.Artifacts = artifacts
|
|
|
|
if err := updateThreadState(ctx, tx, thread.ThreadID, input.Status, thread.AssignedTo, message.MessageID, now); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_" + input.Status,
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: string(message.PayloadJSON),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("commit update transaction: %w", err)
|
|
}
|
|
|
|
thread.Status = input.Status
|
|
thread.LatestMessageID = message.MessageID
|
|
thread.UpdatedAt = now
|
|
return thread, message, nil
|
|
}
|
|
|
|
func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Thread, Message, error) {
|
|
now := nowUTC()
|
|
messageID := newID("msg")
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("begin reply transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
if isTerminalStatus(thread.Status) {
|
|
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: thread.ThreadID,
|
|
FromAgent: input.FromAgent,
|
|
ToAgent: input.ToAgent,
|
|
Kind: defaultString(input.Kind, "answer"),
|
|
Summary: input.Summary,
|
|
Body: input.Body,
|
|
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
|
CreatedAt: now,
|
|
}
|
|
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
message.Artifacts = artifacts
|
|
|
|
if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_replied",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: string(message.PayloadJSON),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("commit reply transaction: %w", err)
|
|
}
|
|
|
|
thread.LatestMessageID = message.MessageID
|
|
thread.UpdatedAt = now
|
|
return thread, message, nil
|
|
}
|
|
|
|
func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (Thread, Message, error) {
|
|
now := nowUTC()
|
|
messageID := newID("msg")
|
|
|
|
nextStatus := "done"
|
|
eventType := "thread_done"
|
|
summary := input.Summary
|
|
if input.Failed {
|
|
nextStatus = "failed"
|
|
eventType = "thread_failed"
|
|
}
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("begin complete transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
if isTerminalStatus(thread.Status) {
|
|
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
|
}
|
|
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: thread.ThreadID,
|
|
FromAgent: input.Agent,
|
|
ToAgent: thread.CreatedBy,
|
|
Kind: "result",
|
|
Summary: summary,
|
|
Body: input.Body,
|
|
PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)),
|
|
CreatedAt: now,
|
|
}
|
|
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
message.Artifacts = artifacts
|
|
|
|
if err := updateThreadState(ctx, tx, thread.ThreadID, nextStatus, thread.AssignedTo, message.MessageID, now); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`UPDATE leases
|
|
SET released_at = ?
|
|
WHERE thread_id = ?
|
|
AND released_at IS NULL`,
|
|
formatTime(now),
|
|
thread.ThreadID,
|
|
); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("release lease: %w", err)
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: eventType,
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: string(message.PayloadJSON),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("commit complete transaction: %w", err)
|
|
}
|
|
|
|
thread.Status = nextStatus
|
|
thread.LatestMessageID = message.MessageID
|
|
thread.UpdatedAt = now
|
|
return thread, message, nil
|
|
}
|
|
|
|
func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Thread, Message, error) {
|
|
now := nowUTC()
|
|
messageID := newID("msg")
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("begin cancel transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
if isTerminalStatus(thread.Status) {
|
|
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
|
}
|
|
|
|
summary := defaultString(input.Reason, "thread cancelled")
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: thread.ThreadID,
|
|
FromAgent: input.Agent,
|
|
ToAgent: thread.AssignedTo,
|
|
Kind: "control",
|
|
Summary: summary,
|
|
Body: input.Reason,
|
|
PayloadJSON: json.RawMessage(`{}`),
|
|
CreatedAt: now,
|
|
}
|
|
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
|
if err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
message.Artifacts = artifacts
|
|
|
|
if err := updateThreadState(ctx, tx, thread.ThreadID, "cancelled", thread.AssignedTo, message.MessageID, now); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if _, err := tx.ExecContext(
|
|
ctx,
|
|
`UPDATE leases
|
|
SET released_at = ?
|
|
WHERE thread_id = ?
|
|
AND released_at IS NULL`,
|
|
formatTime(now),
|
|
thread.ThreadID,
|
|
); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("release lease on cancel: %w", err)
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_cancelled",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: string(message.PayloadJSON),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return Thread{}, Message{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Thread{}, Message{}, fmt.Errorf("commit cancel transaction: %w", err)
|
|
}
|
|
|
|
thread.Status = "cancelled"
|
|
thread.LatestMessageID = message.MessageID
|
|
thread.UpdatedAt = now
|
|
return thread, message, nil
|
|
}
|
|
|
|
func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) {
|
|
thread, err := selectThread(ctx, s.db, threadID)
|
|
if err != nil {
|
|
return ThreadDetail{}, err
|
|
}
|
|
|
|
rows, err := s.db.QueryContext(
|
|
ctx,
|
|
`SELECT
|
|
message_id, thread_id, from_agent, to_agent, kind, summary, body,
|
|
payload_json, created_at
|
|
FROM messages
|
|
WHERE thread_id = ?
|
|
ORDER BY created_at ASC`,
|
|
threadID,
|
|
)
|
|
if err != nil {
|
|
return ThreadDetail{}, fmt.Errorf("query thread messages: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var messages []Message
|
|
for rows.Next() {
|
|
message, err := scanMessage(rows)
|
|
if err != nil {
|
|
return ThreadDetail{}, err
|
|
}
|
|
messages = append(messages, message)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return ThreadDetail{}, fmt.Errorf("iterate thread messages: %w", err)
|
|
}
|
|
|
|
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, messageIDs(messages))
|
|
if err != nil {
|
|
return ThreadDetail{}, err
|
|
}
|
|
attachArtifacts(messages, artifactsByMessageID)
|
|
|
|
return ThreadDetail{
|
|
Thread: thread,
|
|
Messages: messages,
|
|
}, nil
|
|
}
|
|
|
|
func (s *InboxStore) WatchThreads(ctx context.Context, input WatchInput) (WatchResult, error) {
|
|
cursor := input.AfterEventID
|
|
if input.StartFromNow && cursor == 0 {
|
|
current, err := s.currentMaxEventID(ctx)
|
|
if err != nil {
|
|
return WatchResult{}, err
|
|
}
|
|
cursor = current
|
|
}
|
|
|
|
waitCtx := ctx
|
|
cancel := func() {}
|
|
if input.Timeout > 0 {
|
|
waitCtx, cancel = context.WithTimeout(ctx, input.Timeout)
|
|
}
|
|
defer cancel()
|
|
|
|
for {
|
|
thread, message, event, found, err := s.findWatchEventAfter(waitCtx, input, cursor)
|
|
if err != nil {
|
|
if isDeadlineExceeded(waitCtx) {
|
|
return WatchResult{Woke: false, NextEventID: cursor}, nil
|
|
}
|
|
return WatchResult{}, err
|
|
}
|
|
if found {
|
|
return WatchResult{
|
|
Woke: true,
|
|
NextEventID: event.EventID,
|
|
Thread: &thread,
|
|
Message: &message,
|
|
Event: &event,
|
|
}, nil
|
|
}
|
|
|
|
ok, err := waitForNextPoll(waitCtx, 200*time.Millisecond)
|
|
if err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
return WatchResult{Woke: false, NextEventID: cursor}, nil
|
|
}
|
|
return WatchResult{}, err
|
|
}
|
|
if !ok {
|
|
return WatchResult{Woke: false, NextEventID: cursor}, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *InboxStore) WaitReply(ctx context.Context, input WaitReplyInput) (WaitReplyResult, error) {
|
|
cursor := input.AfterEventID
|
|
if input.AfterMessageID != "" {
|
|
eventID, err := s.lookupEventIDForMessage(ctx, input.ThreadID, input.AfterMessageID)
|
|
if err != nil {
|
|
return WaitReplyResult{}, err
|
|
}
|
|
if eventID > cursor {
|
|
cursor = eventID
|
|
}
|
|
}
|
|
|
|
kinds := input.Kinds
|
|
if len(kinds) == 0 {
|
|
kinds = []string{"answer", "control", "result"}
|
|
}
|
|
|
|
waitCtx := ctx
|
|
cancel := func() {}
|
|
if input.Timeout > 0 {
|
|
waitCtx, cancel = context.WithTimeout(ctx, input.Timeout)
|
|
}
|
|
defer cancel()
|
|
|
|
for {
|
|
message, eventID, found, err := s.findReplyAfter(waitCtx, input.ThreadID, cursor, kinds)
|
|
if err != nil {
|
|
if isDeadlineExceeded(waitCtx) {
|
|
return WaitReplyResult{Woke: false, NextEventID: cursor}, nil
|
|
}
|
|
return WaitReplyResult{}, err
|
|
}
|
|
if found {
|
|
return WaitReplyResult{
|
|
Woke: true,
|
|
NextEventID: eventID,
|
|
Message: &message,
|
|
}, nil
|
|
}
|
|
|
|
ok, err := waitForNextPoll(waitCtx, 200*time.Millisecond)
|
|
if err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
return WaitReplyResult{Woke: false, NextEventID: cursor}, nil
|
|
}
|
|
return WaitReplyResult{}, err
|
|
}
|
|
if !ok {
|
|
return WaitReplyResult{Woke: false, NextEventID: cursor}, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// threadScanner is the Scan method shared by *sql.Row and *sql.Rows. The
// row-decoding helpers (scanThread, scanMessage, scanArtifact, scanEvent)
// accept it, so they work for both single-row and multi-row queries.
type threadScanner interface {
	Scan(targets ...any) error
}
|
|
|
|
func scanThread(scanner threadScanner) (Thread, error) {
|
|
var (
|
|
thread Thread
|
|
createdAt, updatedAt string
|
|
latestMessageID sql.NullString
|
|
)
|
|
|
|
if err := scanner.Scan(
|
|
&thread.ThreadID,
|
|
&thread.RunID,
|
|
&thread.TaskID,
|
|
&thread.Subject,
|
|
&thread.CreatedBy,
|
|
&thread.AssignedTo,
|
|
&thread.Status,
|
|
&thread.Priority,
|
|
&latestMessageID,
|
|
&createdAt,
|
|
&updatedAt,
|
|
); err != nil {
|
|
return Thread{}, fmt.Errorf("scan thread: %w", err)
|
|
}
|
|
|
|
thread.CreatedAt = parseTime(createdAt)
|
|
thread.UpdatedAt = parseTime(updatedAt)
|
|
if latestMessageID.Valid {
|
|
thread.LatestMessageID = latestMessageID.String
|
|
}
|
|
|
|
return thread, nil
|
|
}
|
|
|
|
func scanMessage(scanner threadScanner) (Message, error) {
|
|
var (
|
|
message Message
|
|
payload, createdAt string
|
|
)
|
|
|
|
if err := scanner.Scan(
|
|
&message.MessageID,
|
|
&message.ThreadID,
|
|
&message.FromAgent,
|
|
&message.ToAgent,
|
|
&message.Kind,
|
|
&message.Summary,
|
|
&message.Body,
|
|
&payload,
|
|
&createdAt,
|
|
); err != nil {
|
|
return Message{}, fmt.Errorf("scan message: %w", err)
|
|
}
|
|
|
|
message.PayloadJSON = json.RawMessage(payload)
|
|
message.CreatedAt = parseTime(createdAt)
|
|
return message, nil
|
|
}
|
|
|
|
func scanArtifact(scanner threadScanner) (Artifact, error) {
|
|
var (
|
|
artifact Artifact
|
|
metadata, created string
|
|
)
|
|
|
|
if err := scanner.Scan(
|
|
&artifact.ArtifactID,
|
|
&artifact.MessageID,
|
|
&artifact.Path,
|
|
&artifact.Kind,
|
|
&metadata,
|
|
&created,
|
|
); err != nil {
|
|
return Artifact{}, fmt.Errorf("scan artifact: %w", err)
|
|
}
|
|
|
|
artifact.MetadataJSON = json.RawMessage(metadata)
|
|
artifact.CreatedAt = parseTime(created)
|
|
return artifact, nil
|
|
}
|
|
|
|
func scanEvent(scanner threadScanner) (Event, error) {
|
|
var (
|
|
event Event
|
|
messageID sql.NullString
|
|
payload, createdAt string
|
|
)
|
|
|
|
if err := scanner.Scan(
|
|
&event.EventID,
|
|
&event.RunID,
|
|
&event.TaskID,
|
|
&event.ThreadID,
|
|
&event.Source,
|
|
&event.EventType,
|
|
&messageID,
|
|
&event.Summary,
|
|
&payload,
|
|
&createdAt,
|
|
); err != nil {
|
|
return Event{}, fmt.Errorf("scan event: %w", err)
|
|
}
|
|
|
|
if messageID.Valid {
|
|
event.MessageID = messageID.String
|
|
}
|
|
event.PayloadJSON = json.RawMessage(payload)
|
|
event.CreatedAt = parseTime(createdAt)
|
|
return event, nil
|
|
}
|
|
|
|
func selectThread(ctx context.Context, db queryRower, threadID string) (Thread, error) {
|
|
row := db.QueryRowContext(
|
|
ctx,
|
|
`SELECT
|
|
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
|
|
priority, latest_message_id, created_at, updated_at
|
|
FROM threads
|
|
WHERE thread_id = ?`,
|
|
threadID,
|
|
)
|
|
|
|
thread, err := scanThread(row)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return Thread{}, fmt.Errorf("%w: %s", ErrThreadNotFound, threadID)
|
|
}
|
|
return thread, err
|
|
}
|
|
|
|
func selectThreadForUpdate(ctx context.Context, tx *sql.Tx, threadID string) (Thread, error) {
|
|
return selectThread(ctx, tx, threadID)
|
|
}
|
|
|
|
// queryRower is the QueryRowContext method that *sql.DB and *sql.Tx both
// provide. Read helpers take it so they can run either standalone or inside
// a transaction.
type queryRower interface {
	QueryRowContext(ctx context.Context, statement string, params ...any) *sql.Row
}
|
|
|
|
// eventInput holds the column values insertEvent writes for one events row.
type eventInput struct {
	// Identifiers scoping the event.
	RunID, TaskID, ThreadID string
	// Origin of the event and its type label.
	Source, EventType string
	// Message the event refers to, if any.
	MessageID string
	// Short description plus a raw JSON payload (a blank payload is stored as "{}").
	Summary, PayloadJSON string
	// Timestamp of the event, serialized with formatTime.
	CreatedAt time.Time
}
|
|
|
|
func insertEvent(ctx context.Context, tx *sql.Tx, input eventInput) error {
|
|
_, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO events (
|
|
run_id, task_id, thread_id, source, event_type, message_id, summary,
|
|
payload_json, created_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
input.RunID,
|
|
input.TaskID,
|
|
input.ThreadID,
|
|
input.Source,
|
|
input.EventType,
|
|
input.MessageID,
|
|
input.Summary,
|
|
normalizeJSON(input.PayloadJSON),
|
|
formatTime(input.CreatedAt),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("insert event: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func insertMessage(ctx context.Context, tx *sql.Tx, message Message) error {
|
|
_, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO messages (
|
|
message_id, thread_id, from_agent, to_agent, kind, summary, body,
|
|
payload_json, created_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
message.MessageID,
|
|
message.ThreadID,
|
|
message.FromAgent,
|
|
message.ToAgent,
|
|
message.Kind,
|
|
message.Summary,
|
|
message.Body,
|
|
string(message.PayloadJSON),
|
|
formatTime(message.CreatedAt),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("insert message: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func insertArtifacts(ctx context.Context, tx *sql.Tx, messageID string, inputs []ArtifactInput, createdAt time.Time) ([]Artifact, error) {
|
|
if len(inputs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
artifacts := make([]Artifact, 0, len(inputs))
|
|
for _, input := range inputs {
|
|
artifact := Artifact{
|
|
ArtifactID: newID("art"),
|
|
MessageID: messageID,
|
|
Path: input.Path,
|
|
Kind: defaultString(input.Kind, "file"),
|
|
MetadataJSON: json.RawMessage(normalizeJSON(input.MetadataJSON)),
|
|
CreatedAt: createdAt,
|
|
}
|
|
|
|
_, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO artifacts (
|
|
artifact_id, message_id, path, kind, metadata_json, created_at
|
|
) VALUES (?, ?, ?, ?, ?, ?)`,
|
|
artifact.ArtifactID,
|
|
artifact.MessageID,
|
|
artifact.Path,
|
|
artifact.Kind,
|
|
string(artifact.MetadataJSON),
|
|
formatTime(artifact.CreatedAt),
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("insert artifact: %w", err)
|
|
}
|
|
|
|
artifacts = append(artifacts, artifact)
|
|
}
|
|
|
|
return artifacts, nil
|
|
}
|
|
|
|
func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assignedTo, latestMessageID string, updatedAt time.Time) error {
|
|
_, err := tx.ExecContext(
|
|
ctx,
|
|
`UPDATE threads
|
|
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
|
|
WHERE thread_id = ?`,
|
|
status,
|
|
assignedTo,
|
|
latestMessageID,
|
|
formatTime(updatedAt),
|
|
threadID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update thread state: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func loadArtifactsForMessageIDs(ctx context.Context, db *sql.DB, messageIDs []string) (map[string][]Artifact, error) {
|
|
result := make(map[string][]Artifact)
|
|
if len(messageIDs) == 0 {
|
|
return result, nil
|
|
}
|
|
|
|
args := make([]any, 0, len(messageIDs))
|
|
for _, messageID := range messageIDs {
|
|
args = append(args, messageID)
|
|
}
|
|
|
|
rows, err := db.QueryContext(
|
|
ctx,
|
|
`SELECT
|
|
artifact_id, message_id, path, kind, metadata_json, created_at
|
|
FROM artifacts
|
|
WHERE message_id IN (`+placeholders(len(messageIDs))+`)
|
|
ORDER BY created_at ASC`,
|
|
args...,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query artifacts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
artifact, err := scanArtifact(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result[artifact.MessageID] = append(result[artifact.MessageID], artifact)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate artifacts: %w", err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func attachArtifacts(messages []Message, artifactsByMessageID map[string][]Artifact) {
|
|
for i := range messages {
|
|
messages[i].Artifacts = artifactsByMessageID[messages[i].MessageID]
|
|
}
|
|
}
|
|
|
|
func messageIDs(messages []Message) []string {
|
|
ids := make([]string, 0, len(messages))
|
|
for _, message := range messages {
|
|
ids = append(ids, message.MessageID)
|
|
}
|
|
return ids
|
|
}
|
|
|
|
func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) {
|
|
var (
|
|
activeAgent string
|
|
leaseToken string
|
|
expiresAt string
|
|
releasedAt sql.NullString
|
|
)
|
|
|
|
err := tx.QueryRowContext(
|
|
ctx,
|
|
`SELECT agent_id, lease_token, expires_at, released_at
|
|
FROM leases
|
|
WHERE thread_id = ?`,
|
|
threadID,
|
|
).Scan(&activeAgent, &leaseToken, &expiresAt, &releasedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return "", ErrNoActiveLease
|
|
}
|
|
if err != nil {
|
|
return "", fmt.Errorf("read lease: %w", err)
|
|
}
|
|
|
|
if releasedAt.Valid || !parseTime(expiresAt).After(now) {
|
|
return "", ErrNoActiveLease
|
|
}
|
|
if activeAgent != agent {
|
|
return "", ErrLeaseConflict
|
|
}
|
|
|
|
return leaseToken, nil
|
|
}
|
|
|
|
func (s *InboxStore) lookupEventIDForMessage(ctx context.Context, threadID, messageID string) (int64, error) {
|
|
var eventID int64
|
|
err := s.db.QueryRowContext(
|
|
ctx,
|
|
`SELECT event_id
|
|
FROM events
|
|
WHERE thread_id = ?
|
|
AND message_id = ?
|
|
ORDER BY event_id DESC
|
|
LIMIT 1`,
|
|
threadID,
|
|
messageID,
|
|
).Scan(&eventID)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return 0, fmt.Errorf("%w: message %s not found in thread %s", ErrMessageNotFound, messageID, threadID)
|
|
}
|
|
if err != nil {
|
|
return 0, fmt.Errorf("lookup message event: %w", err)
|
|
}
|
|
return eventID, nil
|
|
}
|
|
|
|
func (s *InboxStore) currentMaxEventID(ctx context.Context) (int64, error) {
|
|
var maxEventID int64
|
|
if err := s.db.QueryRowContext(ctx, `SELECT COALESCE(MAX(event_id), 0) FROM events`).Scan(&maxEventID); err != nil {
|
|
return 0, fmt.Errorf("query max event id: %w", err)
|
|
}
|
|
return maxEventID, nil
|
|
}
|
|
|
|
func (s *InboxStore) findReplyAfter(ctx context.Context, threadID string, afterEventID int64, kinds []string) (Message, int64, bool, error) {
|
|
args := []any{threadID, afterEventID}
|
|
query := `SELECT
|
|
e.event_id,
|
|
m.message_id, m.thread_id, m.from_agent, m.to_agent, m.kind, m.summary, m.body, m.payload_json, m.created_at
|
|
FROM events e
|
|
JOIN messages m ON m.message_id = e.message_id
|
|
WHERE e.thread_id = ?
|
|
AND e.event_id > ?`
|
|
if len(kinds) > 0 {
|
|
query += " AND m.kind IN (" + placeholders(len(kinds)) + ")"
|
|
for _, kind := range kinds {
|
|
args = append(args, kind)
|
|
}
|
|
}
|
|
query += " ORDER BY e.event_id ASC LIMIT 1"
|
|
|
|
row := s.db.QueryRowContext(ctx, query, args...)
|
|
|
|
var (
|
|
eventID int64
|
|
message Message
|
|
payload string
|
|
created string
|
|
)
|
|
err := row.Scan(
|
|
&eventID,
|
|
&message.MessageID,
|
|
&message.ThreadID,
|
|
&message.FromAgent,
|
|
&message.ToAgent,
|
|
&message.Kind,
|
|
&message.Summary,
|
|
&message.Body,
|
|
&payload,
|
|
&created,
|
|
)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return Message{}, 0, false, nil
|
|
}
|
|
if err != nil {
|
|
return Message{}, 0, false, fmt.Errorf("query reply after event %d: %w", afterEventID, err)
|
|
}
|
|
|
|
message.PayloadJSON = json.RawMessage(payload)
|
|
message.CreatedAt = parseTime(created)
|
|
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, []string{message.MessageID})
|
|
if err != nil {
|
|
return Message{}, 0, false, err
|
|
}
|
|
message.Artifacts = artifactsByMessageID[message.MessageID]
|
|
return message, eventID, true, nil
|
|
}
|
|
|
|
func (s *InboxStore) findWatchEventAfter(ctx context.Context, input WatchInput, afterEventID int64) (Thread, Message, Event, bool, error) {
|
|
args := []any{afterEventID}
|
|
query := `SELECT
|
|
t.thread_id, t.run_id, t.task_id, t.subject, t.created_by, t.assigned_to, t.status,
|
|
t.priority, t.latest_message_id, t.created_at, t.updated_at,
|
|
e.event_id, e.run_id, e.task_id, e.thread_id, e.source, e.event_type, e.message_id, e.summary, e.payload_json, e.created_at,
|
|
m.message_id, m.thread_id, m.from_agent, m.to_agent, m.kind, m.summary, m.body, m.payload_json, m.created_at
|
|
FROM events e
|
|
JOIN threads t ON t.thread_id = e.thread_id
|
|
JOIN messages m ON m.message_id = e.message_id
|
|
WHERE e.event_id > ?`
|
|
|
|
if input.Agent != "" {
|
|
query += " AND t.assigned_to = ?"
|
|
args = append(args, input.Agent)
|
|
}
|
|
if len(input.Statuses) > 0 {
|
|
query += " AND t.status IN (" + placeholders(len(input.Statuses)) + ")"
|
|
for _, status := range input.Statuses {
|
|
args = append(args, status)
|
|
}
|
|
}
|
|
query += " ORDER BY e.event_id ASC LIMIT 1"
|
|
|
|
row := s.db.QueryRowContext(ctx, query, args...)
|
|
|
|
var (
|
|
thread Thread
|
|
threadCreatedAt string
|
|
threadUpdatedAt string
|
|
threadLatestMessage sql.NullString
|
|
event Event
|
|
eventMessageID sql.NullString
|
|
eventPayload string
|
|
eventCreatedAt string
|
|
message Message
|
|
messagePayload string
|
|
messageCreatedAt string
|
|
)
|
|
|
|
err := row.Scan(
|
|
&thread.ThreadID,
|
|
&thread.RunID,
|
|
&thread.TaskID,
|
|
&thread.Subject,
|
|
&thread.CreatedBy,
|
|
&thread.AssignedTo,
|
|
&thread.Status,
|
|
&thread.Priority,
|
|
&threadLatestMessage,
|
|
&threadCreatedAt,
|
|
&threadUpdatedAt,
|
|
&event.EventID,
|
|
&event.RunID,
|
|
&event.TaskID,
|
|
&event.ThreadID,
|
|
&event.Source,
|
|
&event.EventType,
|
|
&eventMessageID,
|
|
&event.Summary,
|
|
&eventPayload,
|
|
&eventCreatedAt,
|
|
&message.MessageID,
|
|
&message.ThreadID,
|
|
&message.FromAgent,
|
|
&message.ToAgent,
|
|
&message.Kind,
|
|
&message.Summary,
|
|
&message.Body,
|
|
&messagePayload,
|
|
&messageCreatedAt,
|
|
)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return Thread{}, Message{}, Event{}, false, nil
|
|
}
|
|
if err != nil {
|
|
return Thread{}, Message{}, Event{}, false, fmt.Errorf("query watch event after %d: %w", afterEventID, err)
|
|
}
|
|
|
|
if threadLatestMessage.Valid {
|
|
thread.LatestMessageID = threadLatestMessage.String
|
|
}
|
|
thread.CreatedAt = parseTime(threadCreatedAt)
|
|
thread.UpdatedAt = parseTime(threadUpdatedAt)
|
|
if eventMessageID.Valid {
|
|
event.MessageID = eventMessageID.String
|
|
}
|
|
event.PayloadJSON = json.RawMessage(eventPayload)
|
|
event.CreatedAt = parseTime(eventCreatedAt)
|
|
message.PayloadJSON = json.RawMessage(messagePayload)
|
|
message.CreatedAt = parseTime(messageCreatedAt)
|
|
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, []string{message.MessageID})
|
|
if err != nil {
|
|
return Thread{}, Message{}, Event{}, false, err
|
|
}
|
|
message.Artifacts = artifactsByMessageID[message.MessageID]
|
|
return thread, message, event, true, nil
|
|
}
|
|
|
|
// waitForNextPoll sleeps for interval unless ctx finishes first.
// It returns (true, nil) when the interval elapsed, and (false, ctx.Err())
// when ctx was done first.
func waitForNextPoll(ctx context.Context, interval time.Duration) (bool, error) {
	delay := time.NewTimer(interval)
	defer delay.Stop()

	select {
	case <-delay.C:
		return true, nil
	case <-ctx.Done():
		return false, ctx.Err()
	}
}
|
|
|
|
// isTerminalStatus reports whether status is one of the final thread states:
// "done", "failed" or "cancelled".
func isTerminalStatus(status string) bool {
	switch status {
	case "done", "failed", "cancelled":
		return true
	default:
		return false
	}
}
|
|
|
|
// isDeadlineExceeded reports whether ctx has ended because its deadline
// passed. It returns false for a live context and for plain cancellation.
func isDeadlineExceeded(ctx context.Context) bool {
	// errors.Is(nil, target) is false, so a live context needs no special case.
	return errors.Is(ctx.Err(), context.DeadlineExceeded)
}
|
|
|
|
func defaultID(value, prefix string) string {
|
|
if value != "" {
|
|
return value
|
|
}
|
|
return newID(prefix)
|
|
}
|
|
|
|
func newID(prefix string) string {
|
|
return prefix + "_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
|
}
|
|
|
|
// defaultString returns fallback when value is empty, and value otherwise.
// Whitespace-only values count as non-empty.
func defaultString(value, fallback string) string {
	if value == "" {
		return fallback
	}
	return value
}
|
|
|
|
// normalizeJSON replaces a blank (empty or whitespace-only) JSON string with
// "{}". Any other value is returned untouched and is not validated.
func normalizeJSON(value string) string {
	if len(strings.TrimSpace(value)) > 0 {
		return value
	}
	return "{}"
}
|
|
|
|
// placeholders returns n comma-separated "?" bind markers (for example
// "?,?,?"), for building SQL IN (...) lists. It returns "" when n <= 0.
func placeholders(n int) string {
	if n < 1 {
		return ""
	}
	return strings.Repeat("?,", n-1) + "?"
}
|
|
|
|
// nowUTC returns the current wall-clock time in the UTC location.
func nowUTC() time.Time {
	return time.Now().In(time.UTC)
}
|
|
|
|
// formatTime serializes t as an RFC 3339 timestamp in UTC, keeping
// sub-second precision with trailing zeros trimmed (time.RFC3339Nano).
func formatTime(t time.Time) string {
	return string(t.UTC().AppendFormat(nil, time.RFC3339Nano))
}
|
|
|
|
// parseTime decodes an RFC 3339 timestamp, with optional fractional
// seconds. Input that fails to parse yields the zero time.Time instead of an
// error.
func parseTime(value string) time.Time {
	if parsed, err := time.Parse(time.RFC3339Nano, value); err == nil {
		return parsed
	}
	return time.Time{}
}
|