diff --git a/internal/cli/inbox/done.go b/internal/cli/inbox/done.go new file mode 100644 index 0000000..66717f7 --- /dev/null +++ b/internal/cli/inbox/done.go @@ -0,0 +1,93 @@ +package inbox + +import ( + "fmt" + + "ai-workflow-skill/internal/db" + "ai-workflow-skill/internal/protocol" + "ai-workflow-skill/internal/store" + + "github.com/spf13/cobra" +) + +type completeOptions struct { + agent string + threadID string + summary string + body string + payloadJSON string +} + +func newDoneCmd(root *rootOptions) *cobra.Command { + return newCompleteCmd(root, "done") +} + +func newFailCmd(root *rootOptions) *cobra.Command { + return newCompleteCmd(root, "fail") +} + +func newCompleteCmd(root *rootOptions, mode string) *cobra.Command { + opts := &completeOptions{} + + cmd := &cobra.Command{ + Use: mode, + Short: map[string]string{"done": "Mark a thread complete", "fail": "Mark a thread failed"}[mode], + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + agent := opts.agent + if agent == "" { + agent = root.agent + } + if agent == "" { + return fmt.Errorf("agent is required") + } + + sqlDB, err := db.Open(ctx, root.dbPath) + if err != nil { + return err + } + defer sqlDB.Close() + + s := store.NewInboxStore(sqlDB) + thread, message, err := s.CompleteThread(ctx, store.CompleteInput{ + ThreadID: opts.threadID, + Agent: agent, + Summary: opts.summary, + Body: opts.body, + PayloadJSON: opts.payloadJSON, + Failed: mode == "fail", + }) + if err != nil { + return err + } + + resp := protocol.Success{ + OK: true, + Command: mode, + Data: map[string]any{ + "thread": thread, + "message": message, + }, + } + + if root.json { + return protocol.WriteJSON(cmd.OutOrStdout(), resp) + } + + _, err = fmt.Fprintf(cmd.OutOrStdout(), "%s thread %s\n", mode, thread.ThreadID) + return err + }, + } + + cmd.Flags().StringVar(&opts.agent, "agent", "", "Acting agent") + cmd.Flags().StringVar(&opts.threadID, "thread", "", "Thread ID") + cmd.Flags().StringVar(&opts.summary, "summary", "", "Short completion summary") + cmd.Flags().StringVar(&opts.body, "body", "", "Completion body") + cmd.Flags().StringVar(&opts.payloadJSON, "payload-json", "", "Structured payload JSON string") + + _ = cmd.MarkFlagRequired("thread") + _ = cmd.MarkFlagRequired("summary") + + return cmd +} diff --git a/internal/cli/inbox/integration_test.go b/internal/cli/inbox/integration_test.go index 8f4eff4..178d0f2 100644 --- a/internal/cli/inbox/integration_test.go +++ b/internal/cli/inbox/integration_test.go @@ -75,6 +75,81 @@ func TestInboxLifecycle(t *testing.T) { t.Fatalf("expected claimed thread, got %q", claimedStatus) } + updateOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "update", + "--agent", "worker-a", + "--thread", threadID, + "--status", "in_progress", + "--summary", "Implementation started", + "--body", "Scanning current HTTP client usage.", + ) + + var updateResp map[string]any + mustDecodeJSON(t, updateOut, &updateResp) + updatedStatus := nestedString(t, updateResp, "data", "thread", "status") + if updatedStatus != "in_progress" { + t.Fatalf("expected in_progress thread, got %q", updatedStatus) + } + + blockedOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "update", + "--agent", "worker-a", + "--thread", threadID, + "--status", "blocked", + "--summary", "Need timeout decision", + "--payload-json", `{"question":"Should retries apply to read timeouts?"}`, + ) + + var blockedResp map[string]any + mustDecodeJSON(t, blockedOut, &blockedResp) + blockedStatus := nestedString(t, blockedResp, "data", "thread", "status") + if blockedStatus != "blocked" { + t.Fatalf("expected blocked thread, got %q", blockedStatus) + } + + replyOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "reply", + "--from", "leader", + "--to", "worker-a", + "--thread", threadID, + "--summary", "Retry read timeouts", + "--body", "Yes, include read timeouts in the retry policy.", + ) + + var replyResp map[string]any + mustDecodeJSON(t, replyOut, &replyResp) + replyKind := nestedString(t, replyResp, "data", "message", "kind") + if replyKind != "answer" { + t.Fatalf("expected answer reply, got %q", replyKind) + } + + doneOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "done", + "--agent", "worker-a", + "--thread", threadID, + "--summary", "Retry policy implemented", + "--body", "The HTTP client now retries the selected transient failures.", + ) + + var doneResp map[string]any + mustDecodeJSON(t, doneOut, &doneResp) + doneStatus := nestedString(t, doneResp, "data", "thread", "status") + if doneStatus != "done" { + t.Fatalf("expected done thread, got %q", doneStatus) + } + showOut := runInboxCommand( t, "--db", dbPath, @@ -86,13 +161,80 @@ func TestInboxLifecycle(t *testing.T) { var showResp map[string]any mustDecodeJSON(t, showOut, &showResp) showStatus := nestedString(t, showResp, "data", "thread", "status") - if showStatus != "claimed" { - t.Fatalf("expected show status claimed, got %q", showStatus) + if showStatus != "done" { + t.Fatalf("expected show status done, got %q", showStatus) } messagesValue := nestedValue(t, showResp, "data", "messages") messages, ok := messagesValue.([]any) - if !ok || len(messages) != 2 { - t.Fatalf("expected two messages in thread history, got %#v", messagesValue) + if !ok || len(messages) != 6 { + t.Fatalf("expected six messages in thread history, got %#v", messagesValue) + } +} + +func TestInboxFailLifecycle(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "coord.db") + + runInboxCommand(t, "--db", dbPath, "--json", "init") + + sendOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "send", + "--from", "leader", + "--to", "worker-b", + "--subject", "Investigate failing migration", + "--summary", "Check migration failure", + "--run", "run_blog_002", + "--task", "T2", + ) + + var sendResp map[string]any + mustDecodeJSON(t, sendOut, &sendResp) + threadID := nestedString(t, sendResp, "data", "thread", "thread_id") + + runInboxCommand( + t, + "--db", dbPath, + "--json", + "claim", + "--agent", "worker-b", + "--thread", threadID, + ) + + failOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "fail", + "--agent", "worker-b", + "--thread", threadID, + "--summary", "Migration failed", + "--body", "The migration cannot proceed because the prior schema is inconsistent.", + ) + + var failResp map[string]any + mustDecodeJSON(t, failOut, &failResp) + failStatus := nestedString(t, failResp, "data", "thread", "status") + if failStatus != "failed" { + t.Fatalf("expected failed thread, got %q", failStatus) + } + + showOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "show", + "--thread", threadID, + ) + + var showResp map[string]any + mustDecodeJSON(t, showOut, &showResp) + showStatus := nestedString(t, showResp, "data", "thread", "status") + if showStatus != "failed" { + t.Fatalf("expected show status failed, got %q", showStatus) } } diff --git a/internal/cli/inbox/reply.go b/internal/cli/inbox/reply.go new file mode 100644 index 0000000..154a3e8 --- /dev/null +++ b/internal/cli/inbox/reply.go @@ -0,0 +1,91 @@ +package inbox + +import ( + "fmt" + + "ai-workflow-skill/internal/db" + "ai-workflow-skill/internal/protocol" + "ai-workflow-skill/internal/store" + + "github.com/spf13/cobra" +) + +type replyOptions struct { + from string + to string + threadID string + kind string + summary string + body string + payloadJSON string +} + +func newReplyCmd(root *rootOptions) *cobra.Command { + opts := &replyOptions{} + + cmd := &cobra.Command{ + Use: "reply", + Short: "Reply inside an existing thread", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + from := opts.from + if from == "" { + from = root.agent + } + if from == "" { + return fmt.Errorf("from agent is required") + } + + sqlDB, err := db.Open(ctx, root.dbPath) + if err != nil { + return err + } + defer sqlDB.Close() + + s := store.NewInboxStore(sqlDB) + thread, message, err := s.ReplyToThread(ctx, store.ReplyInput{ + ThreadID: opts.threadID, + FromAgent: from, + ToAgent: opts.to, + Kind: opts.kind, + Summary: opts.summary, + Body: opts.body, + PayloadJSON: opts.payloadJSON, + }) + if err != nil { + return err + } + + resp := protocol.Success{ + OK: true, + Command: "reply", + Data: map[string]any{ + "thread": thread, + "message": message, + }, + } + + if root.json { + return protocol.WriteJSON(cmd.OutOrStdout(), resp) + } + + _, err = fmt.Fprintf(cmd.OutOrStdout(), "replied on thread %s\n", thread.ThreadID) + return err + }, + } + + cmd.Flags().StringVar(&opts.from, "from", "", "Replying agent") + cmd.Flags().StringVar(&opts.to, "to", "", "Receiving agent") + cmd.Flags().StringVar(&opts.threadID, "thread", "", "Thread ID") + cmd.Flags().StringVar(&opts.kind, "kind", "answer", "Reply kind") + cmd.Flags().StringVar(&opts.summary, "summary", "", "Short reply summary") + cmd.Flags().StringVar(&opts.body, "body", "", "Reply body") + cmd.Flags().StringVar(&opts.payloadJSON, "payload-json", "", "Structured payload JSON string") + + _ = cmd.MarkFlagRequired("thread") + _ = cmd.MarkFlagRequired("to") + _ = cmd.MarkFlagRequired("summary") + + return cmd +} diff --git a/internal/cli/inbox/root.go b/internal/cli/inbox/root.go index 1cb2918..259a15e 100644 --- a/internal/cli/inbox/root.go +++ b/internal/cli/inbox/root.go @@ -26,6 +26,10 @@ func NewRootCmd() *cobra.Command { cmd.AddCommand(newSendCmd(opts)) cmd.AddCommand(newFetchCmd(opts)) cmd.AddCommand(newClaimCmd(opts)) + cmd.AddCommand(newUpdateCmd(opts)) + cmd.AddCommand(newReplyCmd(opts)) + cmd.AddCommand(newDoneCmd(opts)) + cmd.AddCommand(newFailCmd(opts)) cmd.AddCommand(newShowCmd(opts)) return cmd diff --git a/internal/cli/inbox/update.go b/internal/cli/inbox/update.go new file mode 100644 index 0000000..b22a960 --- /dev/null +++ b/internal/cli/inbox/update.go @@ -0,0 +1,88 @@ +package inbox + +import ( + "fmt" + + "ai-workflow-skill/internal/db" + "ai-workflow-skill/internal/protocol" + "ai-workflow-skill/internal/store" + + "github.com/spf13/cobra" +) + +type updateOptions struct { + agent string + threadID string + status string + summary string + body string + payloadJSON string +} + +func newUpdateCmd(root *rootOptions) *cobra.Command { + opts := &updateOptions{} + + cmd := &cobra.Command{ + Use: "update", + Short: "Append a progress or blocked update to a thread", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + agent := opts.agent + if agent == "" { + agent = root.agent + } + if agent == "" { + return fmt.Errorf("agent is required") + } + + sqlDB, err := db.Open(ctx, root.dbPath) + if err != nil { + return err + } + defer sqlDB.Close() + + s := store.NewInboxStore(sqlDB) + thread, message, err := s.UpdateThreadStatus(ctx, store.UpdateInput{ + ThreadID: opts.threadID, + Agent: agent, + Status: opts.status, + Summary: opts.summary, + Body: opts.body, + PayloadJSON: opts.payloadJSON, + }) + if err != nil { + return err + } + + resp := protocol.Success{ + OK: true, + Command: "update", + Data: map[string]any{ + "thread": thread, + "message": message, + }, + } + + if root.json { + return protocol.WriteJSON(cmd.OutOrStdout(), resp) + } + + _, err = fmt.Fprintf(cmd.OutOrStdout(), "updated thread %s to %s\n", thread.ThreadID, thread.Status) + return err + }, + } + + cmd.Flags().StringVar(&opts.agent, "agent", "", "Updating agent") + cmd.Flags().StringVar(&opts.threadID, "thread", "", "Thread ID") + cmd.Flags().StringVar(&opts.status, "status", "", "New status: in_progress or blocked") + cmd.Flags().StringVar(&opts.summary, "summary", "", "Short update summary") + cmd.Flags().StringVar(&opts.body, "body", "", "Update body") + cmd.Flags().StringVar(&opts.payloadJSON, "payload-json", "", "Structured payload JSON string") + + _ = cmd.MarkFlagRequired("thread") + _ = cmd.MarkFlagRequired("status") + _ = cmd.MarkFlagRequired("summary") + + return cmd +} diff --git a/internal/store/inbox.go b/internal/store/inbox.go index ffcae45..62bda53 100644 --- a/internal/store/inbox.go +++ b/internal/store/inbox.go @@ -80,6 +80,34 @@ type ClaimResult struct { Message Message `json:"message"` } +type UpdateInput struct { + ThreadID string + Agent string + Status string + Summary string + Body string + PayloadJSON string +} + +type ReplyInput struct { + ThreadID string + FromAgent string + ToAgent string + Kind string + Summary string + Body string + PayloadJSON string +} + +type CompleteInput struct { + ThreadID string + Agent string + Summary string + Body string + PayloadJSON string + Failed bool +} + func NewInboxStore(db *sql.DB) *InboxStore { return &InboxStore{db: db} } @@ -381,6 +409,215 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe }, nil } +func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) { + now := nowUTC() + messageID := newID("msg") + + if input.Status != "in_progress" && input.Status != "blocked" { + return Thread{}, Message{}, fmt.Errorf("unsupported update status %q", input.Status) + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return Thread{}, Message{}, fmt.Errorf("begin update transaction: %w", err) + } + defer tx.Rollback() + + thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID) + if err != nil { + return Thread{}, Message{}, err + } + + if thread.Status == "done" || thread.Status == "failed" || thread.Status == "cancelled" { + return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID) + } + + kind := "progress" + if input.Status == "blocked" { + kind = "question" + } + + message := Message{ + MessageID: messageID, + ThreadID: thread.ThreadID, + FromAgent: input.Agent, + ToAgent: thread.CreatedBy, + Kind: kind, + Summary: input.Summary, + Body: input.Body, + PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)), + CreatedAt: now, + } + + if err := insertMessage(ctx, tx, message); err != nil { + return Thread{}, Message{}, err + } + + if err := updateThreadState(ctx, tx, thread.ThreadID, input.Status, thread.AssignedTo, message.MessageID, now); err != nil { + return Thread{}, Message{}, err + } + + if err := insertEvent(ctx, tx, eventInput{ + RunID: thread.RunID, + TaskID: thread.TaskID, + ThreadID: thread.ThreadID, + Source: "inbox", + EventType: "thread_" + input.Status, + MessageID: message.MessageID, + Summary: message.Summary, + PayloadJSON: string(message.PayloadJSON), + CreatedAt: now, + }); err != nil { + return Thread{}, Message{}, err + } + + if err := tx.Commit(); err != nil { + return Thread{}, Message{}, fmt.Errorf("commit update transaction: %w", err) + } + + thread.Status = input.Status + thread.LatestMessageID = message.MessageID + thread.UpdatedAt = now + return thread, message, nil +} + +func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Thread, Message, error) { + now := nowUTC() + messageID := newID("msg") + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return Thread{}, Message{}, fmt.Errorf("begin reply transaction: %w", err) + } + defer tx.Rollback() + + thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID) + if err != nil { + return Thread{}, Message{}, err + } + + message := Message{ + MessageID: messageID, + ThreadID: thread.ThreadID, + FromAgent: input.FromAgent, + ToAgent: input.ToAgent, + Kind: defaultString(input.Kind, "answer"), + Summary: input.Summary, + Body: input.Body, + PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)), + CreatedAt: now, + } + + if err := insertMessage(ctx, tx, message); err != nil { + return Thread{}, Message{}, err + } + + if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil { + return Thread{}, Message{}, err + } + + if err := insertEvent(ctx, tx, eventInput{ + RunID: thread.RunID, + TaskID: thread.TaskID, + ThreadID: thread.ThreadID, + Source: "inbox", + EventType: "thread_replied", + MessageID: message.MessageID, + Summary: message.Summary, + PayloadJSON: string(message.PayloadJSON), + CreatedAt: now, + }); err != nil { + return Thread{}, Message{}, err + } + + if err := tx.Commit(); err != nil { + return Thread{}, Message{}, fmt.Errorf("commit reply transaction: %w", err) + } + + thread.LatestMessageID = message.MessageID + thread.UpdatedAt = now + return thread, message, nil +} + +func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (Thread, Message, error) { + now := nowUTC() + messageID := newID("msg") + + nextStatus := "done" + eventType := "thread_done" + summary := input.Summary + if input.Failed { + nextStatus = "failed" + eventType = "thread_failed" + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return Thread{}, Message{}, fmt.Errorf("begin complete transaction: %w", err) + } + defer tx.Rollback() + + thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID) + if err != nil { + return Thread{}, Message{}, err + } + + message := Message{ + MessageID: messageID, + ThreadID: thread.ThreadID, + FromAgent: input.Agent, + ToAgent: thread.CreatedBy, + Kind: "result", + Summary: summary, + Body: input.Body, + PayloadJSON: json.RawMessage(normalizeJSON(input.PayloadJSON)), + CreatedAt: now, + } + + if err := insertMessage(ctx, tx, message); err != nil { + return Thread{}, Message{}, err + } + + if err := updateThreadState(ctx, tx, thread.ThreadID, nextStatus, thread.AssignedTo, message.MessageID, now); err != nil { + return Thread{}, Message{}, err + } + + if _, err := tx.ExecContext( + ctx, + `UPDATE leases + SET released_at = ? + WHERE thread_id = ? + AND released_at IS NULL`, + formatTime(now), + thread.ThreadID, + ); err != nil { + return Thread{}, Message{}, fmt.Errorf("release lease: %w", err) + } + + if err := insertEvent(ctx, tx, eventInput{ + RunID: thread.RunID, + TaskID: thread.TaskID, + ThreadID: thread.ThreadID, + Source: "inbox", + EventType: eventType, + MessageID: message.MessageID, + Summary: message.Summary, + PayloadJSON: string(message.PayloadJSON), + CreatedAt: now, + }); err != nil { + return Thread{}, Message{}, err + } + + if err := tx.Commit(); err != nil { + return Thread{}, Message{}, fmt.Errorf("commit complete transaction: %w", err) + } + + thread.Status = nextStatus + thread.LatestMessageID = message.MessageID + thread.UpdatedAt = now + return thread, message, nil +} + func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) { thread, err := selectThread(ctx, s.db, threadID) if err != nil { @@ -427,9 +664,9 @@ type threadScanner interface { func scanThread(scanner threadScanner) (Thread, error) { var ( - thread Thread - createdAt, updatedAt string - latestMessageID sql.NullString + thread Thread + createdAt, updatedAt string + latestMessageID sql.NullString ) if err := scanner.Scan( @@ -459,8 +696,8 @@ func scanThread(scanner threadScanner) (Thread, error) { func scanMessage(scanner threadScanner) (Message, error) { var ( - message Message - payload, createdAt string + message Message + payload, createdAt string ) if err := scanner.Scan( @@ -543,6 +780,47 @@ func insertEvent(ctx context.Context, tx *sql.Tx, input eventInput) error { return nil } +func insertMessage(ctx context.Context, tx *sql.Tx, message Message) error { + _, err := tx.ExecContext( + ctx, + `INSERT INTO messages ( + message_id, thread_id, from_agent, to_agent, kind, summary, body, + payload_json, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`, + message.MessageID, + message.ThreadID, + message.FromAgent, + message.ToAgent, + message.Kind, + message.Summary, + message.Body, + string(message.PayloadJSON), + formatTime(message.CreatedAt), + ) + if err != nil { + return fmt.Errorf("insert message: %w", err) + } + return nil +} + +func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assignedTo, latestMessageID string, updatedAt time.Time) error { + _, err := tx.ExecContext( + ctx, + `UPDATE threads + SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ? + WHERE thread_id = ?`, + status, + assignedTo, + latestMessageID, + formatTime(updatedAt), + threadID, + ) + if err != nil { + return fmt.Errorf("update thread state: %w", err) + } + return nil +} + func defaultID(value, prefix string) string { if value != "" { return value