Finalize inbox artifacts and error protocol

This commit is contained in:
2026-03-19 03:25:06 +08:00
parent c3314cd9cf
commit f315d2330d
22 changed files with 659 additions and 86 deletions
+203 -33
View File
@@ -14,7 +14,10 @@ import (
var ErrLeaseConflict = errors.New("thread already claimed by another worker")
var ErrThreadNotFound = errors.New("thread not found")
var ErrMessageNotFound = errors.New("message not found")
var ErrNoActiveLease = errors.New("no active lease")
var ErrInvalidInput = errors.New("invalid input")
var ErrInvalidState = errors.New("invalid state")
type InboxStore struct {
db *sql.DB
@@ -44,6 +47,22 @@ type Message struct {
Body string `json:"body"`
PayloadJSON json.RawMessage `json:"payload_json"`
CreatedAt time.Time `json:"created_at"`
Artifacts []Artifact `json:"artifacts,omitempty"`
}
type Artifact struct {
ArtifactID string `json:"artifact_id"`
MessageID string `json:"message_id"`
Path string `json:"path"`
Kind string `json:"kind"`
MetadataJSON json.RawMessage `json:"metadata_json"`
CreatedAt time.Time `json:"created_at"`
}
type ArtifactInput struct {
Path string
Kind string
MetadataJSON string
}
type ThreadDetail struct {
@@ -76,6 +95,7 @@ type SendInput struct {
Body string
PayloadJSON string
Priority string
Artifacts []ArtifactInput
}
type FetchInput struct {
@@ -109,6 +129,7 @@ type UpdateInput struct {
Summary string
Body string
PayloadJSON string
Artifacts []ArtifactInput
}
type ReplyInput struct {
@@ -119,6 +140,7 @@ type ReplyInput struct {
Summary string
Body string
PayloadJSON string
Artifacts []ArtifactInput
}
type CompleteInput struct {
@@ -128,12 +150,14 @@ type CompleteInput struct {
Body string
PayloadJSON string
Failed bool
Artifacts []ArtifactInput
}
type CancelInput struct {
ThreadID string
Agent string
Reason string
ThreadID string
Agent string
Reason string
Artifacts []ArtifactInput
}
type ListInput struct {
@@ -257,25 +281,14 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread,
PayloadJSON: json.RawMessage(payload),
CreatedAt: now,
}
if _, err := tx.ExecContext(
ctx,
`INSERT INTO messages (
message_id, thread_id, from_agent, to_agent, kind, summary, body,
payload_json, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
message.MessageID,
message.ThreadID,
message.FromAgent,
message.ToAgent,
message.Kind,
message.Summary,
message.Body,
string(message.PayloadJSON),
formatTime(message.CreatedAt),
); err != nil {
return Thread{}, Message{}, fmt.Errorf("insert message: %w", err)
if err := insertMessage(ctx, tx, message); err != nil {
return Thread{}, Message{}, err
}
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
if err != nil {
return Thread{}, Message{}, err
}
message.Artifacts = artifacts
if err := insertEvent(ctx, tx, eventInput{
RunID: thread.RunID,
@@ -314,7 +327,7 @@ func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, i
return Thread{}, Message{}, err
}
if isTerminalStatus(thread.Status) {
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", thread.ThreadID)
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, thread.ThreadID)
}
assignedTo := thread.AssignedTo
@@ -337,6 +350,11 @@ func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, i
if err := insertMessage(ctx, tx, message); err != nil {
return Thread{}, Message{}, err
}
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
if err != nil {
return Thread{}, Message{}, err
}
message.Artifacts = artifacts
if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, assignedTo, message.MessageID, now); err != nil {
return Thread{}, Message{}, err
@@ -413,7 +431,7 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
}
if input.Unread {
if input.Agent == "" {
return nil, fmt.Errorf("agent is required when filtering unread threads")
return nil, fmt.Errorf("%w: agent is required when filtering unread threads", ErrInvalidInput)
}
joins = append(joins, "JOIN messages lm ON lm.message_id = t.latest_message_id")
conditions = append(conditions, "lm.to_agent = ?")
@@ -477,9 +495,8 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
if err != nil {
return ClaimResult{}, err
}
if thread.Status != "pending" {
return ClaimResult{}, fmt.Errorf("thread %s is not pending", input.ThreadID)
if isTerminalStatus(thread.Status) {
return ClaimResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
}
var activeLease string
@@ -498,6 +515,9 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
if activeLease != "" {
return ClaimResult{}, ErrLeaseConflict
}
if thread.Status != "pending" {
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
}
if _, err := tx.ExecContext(
ctx,
@@ -614,7 +634,7 @@ func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimRes
return ClaimResult{}, err
}
if isTerminalStatus(thread.Status) {
return ClaimResult{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
return ClaimResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
}
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
@@ -684,7 +704,7 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
messageID := newID("msg")
if input.Status != "in_progress" && input.Status != "blocked" {
return Thread{}, Message{}, fmt.Errorf("unsupported update status %q", input.Status)
return Thread{}, Message{}, fmt.Errorf("%w: unsupported update status %q", ErrInvalidInput, input.Status)
}
tx, err := s.db.BeginTx(ctx, nil)
@@ -698,7 +718,7 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
return Thread{}, Message{}, err
}
if isTerminalStatus(thread.Status) {
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
}
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
return Thread{}, Message{}, err
@@ -724,6 +744,11 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
if err := insertMessage(ctx, tx, message); err != nil {
return Thread{}, Message{}, err
}
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
if err != nil {
return Thread{}, Message{}, err
}
message.Artifacts = artifacts
if err := updateThreadState(ctx, tx, thread.ThreadID, input.Status, thread.AssignedTo, message.MessageID, now); err != nil {
return Thread{}, Message{}, err
@@ -768,7 +793,7 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
return Thread{}, Message{}, err
}
if isTerminalStatus(thread.Status) {
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
}
message := Message{
@@ -786,6 +811,11 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
if err := insertMessage(ctx, tx, message); err != nil {
return Thread{}, Message{}, err
}
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
if err != nil {
return Thread{}, Message{}, err
}
message.Artifacts = artifacts
if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
return Thread{}, Message{}, err
@@ -837,7 +867,7 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T
return Thread{}, Message{}, err
}
if isTerminalStatus(thread.Status) {
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
}
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
return Thread{}, Message{}, err
@@ -858,6 +888,11 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T
if err := insertMessage(ctx, tx, message); err != nil {
return Thread{}, Message{}, err
}
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
if err != nil {
return Thread{}, Message{}, err
}
message.Artifacts = artifacts
if err := updateThreadState(ctx, tx, thread.ThreadID, nextStatus, thread.AssignedTo, message.MessageID, now); err != nil {
return Thread{}, Message{}, err
@@ -914,7 +949,7 @@ func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Threa
return Thread{}, Message{}, err
}
if isTerminalStatus(thread.Status) {
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
}
summary := defaultString(input.Reason, "thread cancelled")
@@ -933,6 +968,11 @@ func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Threa
if err := insertMessage(ctx, tx, message); err != nil {
return Thread{}, Message{}, err
}
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
if err != nil {
return Thread{}, Message{}, err
}
message.Artifacts = artifacts
if err := updateThreadState(ctx, tx, thread.ThreadID, "cancelled", thread.AssignedTo, message.MessageID, now); err != nil {
return Thread{}, Message{}, err
@@ -1008,6 +1048,12 @@ func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDeta
return ThreadDetail{}, fmt.Errorf("iterate thread messages: %w", err)
}
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, messageIDs(messages))
if err != nil {
return ThreadDetail{}, err
}
attachArtifacts(messages, artifactsByMessageID)
return ThreadDetail{
Thread: thread,
Messages: messages,
@@ -1176,6 +1222,28 @@ func scanMessage(scanner threadScanner) (Message, error) {
return message, nil
}
func scanArtifact(scanner threadScanner) (Artifact, error) {
var (
artifact Artifact
metadata, created string
)
if err := scanner.Scan(
&artifact.ArtifactID,
&artifact.MessageID,
&artifact.Path,
&artifact.Kind,
&metadata,
&created,
); err != nil {
return Artifact{}, fmt.Errorf("scan artifact: %w", err)
}
artifact.MetadataJSON = json.RawMessage(metadata)
artifact.CreatedAt = parseTime(created)
return artifact, nil
}
func scanEvent(scanner threadScanner) (Event, error) {
var (
event Event
@@ -1290,6 +1358,44 @@ func insertMessage(ctx context.Context, tx *sql.Tx, message Message) error {
return nil
}
func insertArtifacts(ctx context.Context, tx *sql.Tx, messageID string, inputs []ArtifactInput, createdAt time.Time) ([]Artifact, error) {
if len(inputs) == 0 {
return nil, nil
}
artifacts := make([]Artifact, 0, len(inputs))
for _, input := range inputs {
artifact := Artifact{
ArtifactID: newID("art"),
MessageID: messageID,
Path: input.Path,
Kind: defaultString(input.Kind, "file"),
MetadataJSON: json.RawMessage(normalizeJSON(input.MetadataJSON)),
CreatedAt: createdAt,
}
_, err := tx.ExecContext(
ctx,
`INSERT INTO artifacts (
artifact_id, message_id, path, kind, metadata_json, created_at
) VALUES (?, ?, ?, ?, ?, ?)`,
artifact.ArtifactID,
artifact.MessageID,
artifact.Path,
artifact.Kind,
string(artifact.MetadataJSON),
formatTime(artifact.CreatedAt),
)
if err != nil {
return nil, fmt.Errorf("insert artifact: %w", err)
}
artifacts = append(artifacts, artifact)
}
return artifacts, nil
}
func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assignedTo, latestMessageID string, updatedAt time.Time) error {
_, err := tx.ExecContext(
ctx,
@@ -1308,6 +1414,60 @@ func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assign
return nil
}
func loadArtifactsForMessageIDs(ctx context.Context, db *sql.DB, messageIDs []string) (map[string][]Artifact, error) {
result := make(map[string][]Artifact)
if len(messageIDs) == 0 {
return result, nil
}
args := make([]any, 0, len(messageIDs))
for _, messageID := range messageIDs {
args = append(args, messageID)
}
rows, err := db.QueryContext(
ctx,
`SELECT
artifact_id, message_id, path, kind, metadata_json, created_at
FROM artifacts
WHERE message_id IN (`+placeholders(len(messageIDs))+`)
ORDER BY created_at ASC`,
args...,
)
if err != nil {
return nil, fmt.Errorf("query artifacts: %w", err)
}
defer rows.Close()
for rows.Next() {
artifact, err := scanArtifact(rows)
if err != nil {
return nil, err
}
result[artifact.MessageID] = append(result[artifact.MessageID], artifact)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate artifacts: %w", err)
}
return result, nil
}
func attachArtifacts(messages []Message, artifactsByMessageID map[string][]Artifact) {
for i := range messages {
messages[i].Artifacts = artifactsByMessageID[messages[i].MessageID]
}
}
func messageIDs(messages []Message) []string {
ids := make([]string, 0, len(messages))
for _, message := range messages {
ids = append(ids, message.MessageID)
}
return ids
}
func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) {
var (
activeAgent string
@@ -1354,7 +1514,7 @@ func (s *InboxStore) lookupEventIDForMessage(ctx context.Context, threadID, mess
messageID,
).Scan(&eventID)
if errors.Is(err, sql.ErrNoRows) {
return 0, fmt.Errorf("message %s not found in thread %s", messageID, threadID)
return 0, fmt.Errorf("%w: message %s not found in thread %s", ErrMessageNotFound, messageID, threadID)
}
if err != nil {
return 0, fmt.Errorf("lookup message event: %w", err)
@@ -1416,6 +1576,11 @@ func (s *InboxStore) findReplyAfter(ctx context.Context, threadID string, afterE
message.PayloadJSON = json.RawMessage(payload)
message.CreatedAt = parseTime(created)
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, []string{message.MessageID})
if err != nil {
return Message{}, 0, false, err
}
message.Artifacts = artifactsByMessageID[message.MessageID]
return message, eventID, true, nil
}
@@ -1510,6 +1675,11 @@ func (s *InboxStore) findWatchEventAfter(ctx context.Context, input WatchInput,
event.CreatedAt = parseTime(eventCreatedAt)
message.PayloadJSON = json.RawMessage(messagePayload)
message.CreatedAt = parseTime(messageCreatedAt)
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, []string{message.MessageID})
if err != nil {
return Thread{}, Message{}, Event{}, false, err
}
message.Artifacts = artifactsByMessageID[message.MessageID]
return thread, message, event, true, nil
}