Finalize inbox artifacts and error protocol
This commit is contained in:
+203
-33
@@ -14,7 +14,10 @@ import (
|
||||
|
||||
var ErrLeaseConflict = errors.New("thread already claimed by another worker")
|
||||
var ErrThreadNotFound = errors.New("thread not found")
|
||||
var ErrMessageNotFound = errors.New("message not found")
|
||||
var ErrNoActiveLease = errors.New("no active lease")
|
||||
var ErrInvalidInput = errors.New("invalid input")
|
||||
var ErrInvalidState = errors.New("invalid state")
|
||||
|
||||
type InboxStore struct {
|
||||
db *sql.DB
|
||||
@@ -44,6 +47,22 @@ type Message struct {
|
||||
Body string `json:"body"`
|
||||
PayloadJSON json.RawMessage `json:"payload_json"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Artifacts []Artifact `json:"artifacts,omitempty"`
|
||||
}
|
||||
|
||||
type Artifact struct {
|
||||
ArtifactID string `json:"artifact_id"`
|
||||
MessageID string `json:"message_id"`
|
||||
Path string `json:"path"`
|
||||
Kind string `json:"kind"`
|
||||
MetadataJSON json.RawMessage `json:"metadata_json"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type ArtifactInput struct {
|
||||
Path string
|
||||
Kind string
|
||||
MetadataJSON string
|
||||
}
|
||||
|
||||
type ThreadDetail struct {
|
||||
@@ -76,6 +95,7 @@ type SendInput struct {
|
||||
Body string
|
||||
PayloadJSON string
|
||||
Priority string
|
||||
Artifacts []ArtifactInput
|
||||
}
|
||||
|
||||
type FetchInput struct {
|
||||
@@ -109,6 +129,7 @@ type UpdateInput struct {
|
||||
Summary string
|
||||
Body string
|
||||
PayloadJSON string
|
||||
Artifacts []ArtifactInput
|
||||
}
|
||||
|
||||
type ReplyInput struct {
|
||||
@@ -119,6 +140,7 @@ type ReplyInput struct {
|
||||
Summary string
|
||||
Body string
|
||||
PayloadJSON string
|
||||
Artifacts []ArtifactInput
|
||||
}
|
||||
|
||||
type CompleteInput struct {
|
||||
@@ -128,12 +150,14 @@ type CompleteInput struct {
|
||||
Body string
|
||||
PayloadJSON string
|
||||
Failed bool
|
||||
Artifacts []ArtifactInput
|
||||
}
|
||||
|
||||
type CancelInput struct {
|
||||
ThreadID string
|
||||
Agent string
|
||||
Reason string
|
||||
ThreadID string
|
||||
Agent string
|
||||
Reason string
|
||||
Artifacts []ArtifactInput
|
||||
}
|
||||
|
||||
type ListInput struct {
|
||||
@@ -257,25 +281,14 @@ func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread,
|
||||
PayloadJSON: json.RawMessage(payload),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO messages (
|
||||
message_id, thread_id, from_agent, to_agent, kind, summary, body,
|
||||
payload_json, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
message.MessageID,
|
||||
message.ThreadID,
|
||||
message.FromAgent,
|
||||
message.ToAgent,
|
||||
message.Kind,
|
||||
message.Summary,
|
||||
message.Body,
|
||||
string(message.PayloadJSON),
|
||||
formatTime(message.CreatedAt),
|
||||
); err != nil {
|
||||
return Thread{}, Message{}, fmt.Errorf("insert message: %w", err)
|
||||
if err := insertMessage(ctx, tx, message); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
message.Artifacts = artifacts
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: thread.RunID,
|
||||
@@ -314,7 +327,7 @@ func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, i
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
if isTerminalStatus(thread.Status) {
|
||||
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", thread.ThreadID)
|
||||
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, thread.ThreadID)
|
||||
}
|
||||
|
||||
assignedTo := thread.AssignedTo
|
||||
@@ -337,6 +350,11 @@ func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, i
|
||||
if err := insertMessage(ctx, tx, message); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
message.Artifacts = artifacts
|
||||
|
||||
if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, assignedTo, message.MessageID, now); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
@@ -413,7 +431,7 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
|
||||
}
|
||||
if input.Unread {
|
||||
if input.Agent == "" {
|
||||
return nil, fmt.Errorf("agent is required when filtering unread threads")
|
||||
return nil, fmt.Errorf("%w: agent is required when filtering unread threads", ErrInvalidInput)
|
||||
}
|
||||
joins = append(joins, "JOIN messages lm ON lm.message_id = t.latest_message_id")
|
||||
conditions = append(conditions, "lm.to_agent = ?")
|
||||
@@ -477,9 +495,8 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
|
||||
if err != nil {
|
||||
return ClaimResult{}, err
|
||||
}
|
||||
|
||||
if thread.Status != "pending" {
|
||||
return ClaimResult{}, fmt.Errorf("thread %s is not pending", input.ThreadID)
|
||||
if isTerminalStatus(thread.Status) {
|
||||
return ClaimResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
var activeLease string
|
||||
@@ -498,6 +515,9 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe
|
||||
if activeLease != "" {
|
||||
return ClaimResult{}, ErrLeaseConflict
|
||||
}
|
||||
if thread.Status != "pending" {
|
||||
return ClaimResult{}, fmt.Errorf("%w: thread %s is not pending", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
ctx,
|
||||
@@ -614,7 +634,7 @@ func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimRes
|
||||
return ClaimResult{}, err
|
||||
}
|
||||
if isTerminalStatus(thread.Status) {
|
||||
return ClaimResult{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
|
||||
return ClaimResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
|
||||
@@ -684,7 +704,7 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
|
||||
messageID := newID("msg")
|
||||
|
||||
if input.Status != "in_progress" && input.Status != "blocked" {
|
||||
return Thread{}, Message{}, fmt.Errorf("unsupported update status %q", input.Status)
|
||||
return Thread{}, Message{}, fmt.Errorf("%w: unsupported update status %q", ErrInvalidInput, input.Status)
|
||||
}
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
@@ -698,7 +718,7 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
if isTerminalStatus(thread.Status) {
|
||||
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
|
||||
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
@@ -724,6 +744,11 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput)
|
||||
if err := insertMessage(ctx, tx, message); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
message.Artifacts = artifacts
|
||||
|
||||
if err := updateThreadState(ctx, tx, thread.ThreadID, input.Status, thread.AssignedTo, message.MessageID, now); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
@@ -768,7 +793,7 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
if isTerminalStatus(thread.Status) {
|
||||
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
|
||||
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
message := Message{
|
||||
@@ -786,6 +811,11 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa
|
||||
if err := insertMessage(ctx, tx, message); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
message.Artifacts = artifacts
|
||||
|
||||
if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
@@ -837,7 +867,7 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
if isTerminalStatus(thread.Status) {
|
||||
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
|
||||
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
@@ -858,6 +888,11 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T
|
||||
if err := insertMessage(ctx, tx, message); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
message.Artifacts = artifacts
|
||||
|
||||
if err := updateThreadState(ctx, tx, thread.ThreadID, nextStatus, thread.AssignedTo, message.MessageID, now); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
@@ -914,7 +949,7 @@ func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Threa
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
if isTerminalStatus(thread.Status) {
|
||||
return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID)
|
||||
return Thread{}, Message{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, input.ThreadID)
|
||||
}
|
||||
|
||||
summary := defaultString(input.Reason, "thread cancelled")
|
||||
@@ -933,6 +968,11 @@ func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Threa
|
||||
if err := insertMessage(ctx, tx, message); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
artifacts, err := insertArtifacts(ctx, tx, message.MessageID, input.Artifacts, now)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
message.Artifacts = artifacts
|
||||
|
||||
if err := updateThreadState(ctx, tx, thread.ThreadID, "cancelled", thread.AssignedTo, message.MessageID, now); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
@@ -1008,6 +1048,12 @@ func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDeta
|
||||
return ThreadDetail{}, fmt.Errorf("iterate thread messages: %w", err)
|
||||
}
|
||||
|
||||
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, messageIDs(messages))
|
||||
if err != nil {
|
||||
return ThreadDetail{}, err
|
||||
}
|
||||
attachArtifacts(messages, artifactsByMessageID)
|
||||
|
||||
return ThreadDetail{
|
||||
Thread: thread,
|
||||
Messages: messages,
|
||||
@@ -1176,6 +1222,28 @@ func scanMessage(scanner threadScanner) (Message, error) {
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func scanArtifact(scanner threadScanner) (Artifact, error) {
|
||||
var (
|
||||
artifact Artifact
|
||||
metadata, created string
|
||||
)
|
||||
|
||||
if err := scanner.Scan(
|
||||
&artifact.ArtifactID,
|
||||
&artifact.MessageID,
|
||||
&artifact.Path,
|
||||
&artifact.Kind,
|
||||
&metadata,
|
||||
&created,
|
||||
); err != nil {
|
||||
return Artifact{}, fmt.Errorf("scan artifact: %w", err)
|
||||
}
|
||||
|
||||
artifact.MetadataJSON = json.RawMessage(metadata)
|
||||
artifact.CreatedAt = parseTime(created)
|
||||
return artifact, nil
|
||||
}
|
||||
|
||||
func scanEvent(scanner threadScanner) (Event, error) {
|
||||
var (
|
||||
event Event
|
||||
@@ -1290,6 +1358,44 @@ func insertMessage(ctx context.Context, tx *sql.Tx, message Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertArtifacts(ctx context.Context, tx *sql.Tx, messageID string, inputs []ArtifactInput, createdAt time.Time) ([]Artifact, error) {
|
||||
if len(inputs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
artifacts := make([]Artifact, 0, len(inputs))
|
||||
for _, input := range inputs {
|
||||
artifact := Artifact{
|
||||
ArtifactID: newID("art"),
|
||||
MessageID: messageID,
|
||||
Path: input.Path,
|
||||
Kind: defaultString(input.Kind, "file"),
|
||||
MetadataJSON: json.RawMessage(normalizeJSON(input.MetadataJSON)),
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
|
||||
_, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO artifacts (
|
||||
artifact_id, message_id, path, kind, metadata_json, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?)`,
|
||||
artifact.ArtifactID,
|
||||
artifact.MessageID,
|
||||
artifact.Path,
|
||||
artifact.Kind,
|
||||
string(artifact.MetadataJSON),
|
||||
formatTime(artifact.CreatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert artifact: %w", err)
|
||||
}
|
||||
|
||||
artifacts = append(artifacts, artifact)
|
||||
}
|
||||
|
||||
return artifacts, nil
|
||||
}
|
||||
|
||||
func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assignedTo, latestMessageID string, updatedAt time.Time) error {
|
||||
_, err := tx.ExecContext(
|
||||
ctx,
|
||||
@@ -1308,6 +1414,60 @@ func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assign
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadArtifactsForMessageIDs(ctx context.Context, db *sql.DB, messageIDs []string) (map[string][]Artifact, error) {
|
||||
result := make(map[string][]Artifact)
|
||||
if len(messageIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
args := make([]any, 0, len(messageIDs))
|
||||
for _, messageID := range messageIDs {
|
||||
args = append(args, messageID)
|
||||
}
|
||||
|
||||
rows, err := db.QueryContext(
|
||||
ctx,
|
||||
`SELECT
|
||||
artifact_id, message_id, path, kind, metadata_json, created_at
|
||||
FROM artifacts
|
||||
WHERE message_id IN (`+placeholders(len(messageIDs))+`)
|
||||
ORDER BY created_at ASC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query artifacts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
artifact, err := scanArtifact(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[artifact.MessageID] = append(result[artifact.MessageID], artifact)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate artifacts: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func attachArtifacts(messages []Message, artifactsByMessageID map[string][]Artifact) {
|
||||
for i := range messages {
|
||||
messages[i].Artifacts = artifactsByMessageID[messages[i].MessageID]
|
||||
}
|
||||
}
|
||||
|
||||
func messageIDs(messages []Message) []string {
|
||||
ids := make([]string, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
ids = append(ids, message.MessageID)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) {
|
||||
var (
|
||||
activeAgent string
|
||||
@@ -1354,7 +1514,7 @@ func (s *InboxStore) lookupEventIDForMessage(ctx context.Context, threadID, mess
|
||||
messageID,
|
||||
).Scan(&eventID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, fmt.Errorf("message %s not found in thread %s", messageID, threadID)
|
||||
return 0, fmt.Errorf("%w: message %s not found in thread %s", ErrMessageNotFound, messageID, threadID)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("lookup message event: %w", err)
|
||||
@@ -1416,6 +1576,11 @@ func (s *InboxStore) findReplyAfter(ctx context.Context, threadID string, afterE
|
||||
|
||||
message.PayloadJSON = json.RawMessage(payload)
|
||||
message.CreatedAt = parseTime(created)
|
||||
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, []string{message.MessageID})
|
||||
if err != nil {
|
||||
return Message{}, 0, false, err
|
||||
}
|
||||
message.Artifacts = artifactsByMessageID[message.MessageID]
|
||||
return message, eventID, true, nil
|
||||
}
|
||||
|
||||
@@ -1510,6 +1675,11 @@ func (s *InboxStore) findWatchEventAfter(ctx context.Context, input WatchInput,
|
||||
event.CreatedAt = parseTime(eventCreatedAt)
|
||||
message.PayloadJSON = json.RawMessage(messagePayload)
|
||||
message.CreatedAt = parseTime(messageCreatedAt)
|
||||
artifactsByMessageID, err := loadArtifactsForMessageIDs(ctx, s.db, []string{message.MessageID})
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, Event{}, false, err
|
||||
}
|
||||
message.Artifacts = artifactsByMessageID[message.MessageID]
|
||||
return thread, message, event, true, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user