From c3314cd9cff952d2cf71a1937afb5a337236c056 Mon Sep 17 00:00:00 2001 From: kurihada Date: Thu, 19 Mar 2026 03:15:17 +0800 Subject: [PATCH] Complete inbox CLI implementation --- docs/implementation-roadmap.md | 33 +- internal/cli/inbox/body.go | 21 + internal/cli/inbox/cancel.go | 77 +++ internal/cli/inbox/done.go | 9 +- internal/cli/inbox/fetch.go | 3 + internal/cli/inbox/integration_test.go | 294 +++++++++- internal/cli/inbox/list.go | 81 +++ internal/cli/inbox/renew.go | 77 +++ internal/cli/inbox/reply.go | 9 +- internal/cli/inbox/root.go | 5 + internal/cli/inbox/send.go | 24 +- internal/cli/inbox/update.go | 9 +- internal/cli/inbox/wait_reply.go | 85 +++ internal/cli/inbox/watch.go | 92 +++ internal/store/inbox.go | 748 ++++++++++++++++++++++++- 15 files changed, 1524 insertions(+), 43 deletions(-) create mode 100644 internal/cli/inbox/body.go create mode 100644 internal/cli/inbox/cancel.go create mode 100644 internal/cli/inbox/list.go create mode 100644 internal/cli/inbox/renew.go create mode 100644 internal/cli/inbox/wait_reply.go create mode 100644 internal/cli/inbox/watch.go diff --git a/docs/implementation-roadmap.md b/docs/implementation-roadmap.md index 955f9bb..f9c317f 100644 --- a/docs/implementation-roadmap.md +++ b/docs/implementation-roadmap.md @@ -12,16 +12,18 @@ As of now: - architecture and workflow docs are written - CLI surfaces for `inbox`, `orch`, worktree execution, and `council-review` are defined -- SQLite schema drafts exist in the docs +- embedded SQLite schema and migrations exist in code - JSON output shapes are defined for the major flows - Go module and initial command skeletons exist - `inbox` and `orch` both compile - shared SQLite schema initialization exists -- `inbox init` works and creates the database schema +- `inbox` is implemented end-to-end, including send/fetch/claim/renew/update/reply/done/fail/cancel/list/show/watch/wait-reply +- `inbox` supports blocking waits, lease renewal, unread fetches, and `--body-file` +- integration tests cover the main inbox lifecycle plus wait/watch flows - `orch` currently exists as a command skeleton only -- no higher-level inbox or orch workflows have been implemented yet +- no scheduler workflows have been implemented yet -This means the project is past design discovery and ready for code implementation. +This means the project is past design discovery and ready for `orch` implementation. ## Source Of Truth @@ -59,10 +61,11 @@ Build a Go-based local agent orchestration stack with: Current implementation status: - `Milestone 1: Go Skeleton` is complete -- `Milestone 2: Shared DB Layer` is partially complete -- `Milestone 3: Inbox Happy Path` has started only through `inbox init` +- `Milestone 2: Shared DB Layer` is complete enough for both CLIs +- `Milestone 3: Inbox Happy Path` is complete +- `Milestone 6: Waiting Primitives` is partially complete through `inbox wait-reply` -The next practical coding target is the rest of the inbox happy path. +The next practical coding target is `Milestone 4: Orch Core Scheduling`. ### Milestone 1: Go Skeleton @@ -113,7 +116,7 @@ Definition of done: Status: -- partially completed +- completed for current inbox needs Completed so far: @@ -124,7 +127,7 @@ Completed so far: Remaining: -- decide whether `orch` should gain an explicit DB bootstrap check or reuse `inbox init` +- decide whether `orch` should gain an explicit DB bootstrap check or continue to rely on `inbox init` ### Milestone 3: Inbox Happy Path @@ -158,20 +161,24 @@ Definition of done: Status: -- not complete +- completed Completed so far: - `inbox init` - -Next commands to implement: - - `inbox send` - `inbox fetch` - `inbox claim` +- `inbox renew` - `inbox update` +- `inbox reply` - `inbox done` +- `inbox fail` +- `inbox cancel` +- `inbox list` - `inbox show` +- `inbox watch` +- `inbox wait-reply` ### Milestone 4: Orch Core Scheduling diff --git a/internal/cli/inbox/body.go b/internal/cli/inbox/body.go new file mode 100644 index 0000000..f646688 --- /dev/null +++ b/internal/cli/inbox/body.go @@ -0,0 +1,21 @@ +package inbox + +import ( + "fmt" + "os" +) + +func resolveBodyValue(body, bodyFile string) (string, error) { + if body != "" && bodyFile != "" { + return "", fmt.Errorf("body and body-file are mutually exclusive") + } + if bodyFile == "" { + return body, nil + } + + content, err := os.ReadFile(bodyFile) + if err != nil { + return "", fmt.Errorf("read body file %q: %w", bodyFile, err) + } + return string(content), nil +} diff --git a/internal/cli/inbox/cancel.go b/internal/cli/inbox/cancel.go new file mode 100644 index 0000000..b4bc7a9 --- /dev/null +++ b/internal/cli/inbox/cancel.go @@ -0,0 +1,77 @@ +package inbox + +import ( + "fmt" + + "ai-workflow-skill/internal/db" + "ai-workflow-skill/internal/protocol" + "ai-workflow-skill/internal/store" + + "github.com/spf13/cobra" +) + +type cancelOptions struct { + agent string + threadID string + reason string +} + +func newCancelCmd(root *rootOptions) *cobra.Command { + opts := &cancelOptions{} + + cmd := &cobra.Command{ + Use: "cancel", + Short: "Cancel a thread", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + agent := opts.agent + if agent == "" { + agent = root.agent + } + if agent == "" { + return fmt.Errorf("agent is required") + } + + sqlDB, err := db.Open(ctx, root.dbPath) + if err != nil { + return err + } + defer sqlDB.Close() + + s := store.NewInboxStore(sqlDB) + thread, message, err := s.CancelThread(ctx, store.CancelInput{ + ThreadID: opts.threadID, + Agent: agent, + Reason: opts.reason, + }) + if err != nil { + return err + } + + resp := protocol.Success{ + OK: true, + Command: "cancel", + Data: map[string]any{ + "thread": thread, + "message": message, + }, + } + + if root.json { + return protocol.WriteJSON(cmd.OutOrStdout(), resp) + } + + _, err = fmt.Fprintf(cmd.OutOrStdout(), "cancelled thread %s\n", thread.ThreadID) + return err + }, + } + + cmd.Flags().StringVar(&opts.agent, "agent", "", "Acting agent") + cmd.Flags().StringVar(&opts.threadID, "thread", "", "Thread ID") + cmd.Flags().StringVar(&opts.reason, "reason", "", "Cancellation reason") + + _ = cmd.MarkFlagRequired("thread") + + return cmd +} diff --git a/internal/cli/inbox/done.go b/internal/cli/inbox/done.go index 66717f7..6f25f9c 100644 --- a/internal/cli/inbox/done.go +++ b/internal/cli/inbox/done.go @@ -15,6 +15,7 @@ type completeOptions struct { threadID string summary string body string + bodyFile string payloadJSON string } @@ -43,6 +44,11 @@ func newCompleteCmd(root *rootOptions, mode string) *cobra.Command { return fmt.Errorf("agent is required") } + body, err := resolveBodyValue(opts.body, opts.bodyFile) + if err != nil { + return err + } + sqlDB, err := db.Open(ctx, root.dbPath) if err != nil { return err @@ -54,7 +60,7 @@ func newCompleteCmd(root *rootOptions, mode string) *cobra.Command { ThreadID: opts.threadID, Agent: agent, Summary: opts.summary, - Body: opts.body, + Body: body, PayloadJSON: opts.payloadJSON, Failed: mode == "fail", }) @@ -84,6 +90,7 @@ func newCompleteCmd(root *rootOptions, mode string) *cobra.Command { cmd.Flags().StringVar(&opts.threadID, "thread", "", "Thread ID") cmd.Flags().StringVar(&opts.summary, "summary", "", "Short completion summary") cmd.Flags().StringVar(&opts.body, "body", "", "Completion body") + cmd.Flags().StringVar(&opts.bodyFile, "body-file", "", "Read completion body from file") cmd.Flags().StringVar(&opts.payloadJSON, "payload-json", "", "Structured payload JSON string") _ = cmd.MarkFlagRequired("thread") diff --git a/internal/cli/inbox/fetch.go b/internal/cli/inbox/fetch.go index c5dc157..3a75430 100644 --- a/internal/cli/inbox/fetch.go +++ b/internal/cli/inbox/fetch.go @@ -15,6 +15,7 @@ type fetchOptions struct { agent string statuses string limit int + unread bool } func newFetchCmd(root *rootOptions) *cobra.Command { @@ -42,6 +43,7 @@ func newFetchCmd(root *rootOptions) *cobra.Command { Agent: agent, Statuses: parseCSV(opts.statuses), Limit: opts.limit, + Unread: opts.unread, }) if err != nil { return err @@ -71,6 +73,7 @@ func newFetchCmd(root *rootOptions) *cobra.Command { cmd.Flags().StringVar(&opts.agent, "agent", "", "Assigned agent filter") cmd.Flags().StringVar(&opts.statuses, "status", "pending", "Comma-separated status filter") cmd.Flags().IntVar(&opts.limit, "limit", 20, "Maximum number of threads") + cmd.Flags().BoolVar(&opts.unread, "unread", false, "Only return threads whose latest message is unread by the agent") return cmd } diff --git a/internal/cli/inbox/integration_test.go b/internal/cli/inbox/integration_test.go index 178d0f2..552ebc5 100644 --- a/internal/cli/inbox/integration_test.go +++ b/internal/cli/inbox/integration_test.go @@ -3,8 +3,10 @@ package inbox import ( "bytes" "encoding/json" + "os" "path/filepath" "testing" + "time" ) func TestInboxLifecycle(t *testing.T) { @@ -238,9 +240,294 @@ func TestInboxFailLifecycle(t *testing.T) { } } +func TestInboxRenewWaitReplyAndCancel(t *testing.T) { + t.Parallel() + + dbPath := filepath.Join(t.TempDir(), "coord.db") + + runInboxCommand(t, "--db", dbPath, "--json", "init") + + sendOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "send", + "--from", "leader", + "--to", "worker-c", + "--subject", "Investigate auth edge case", + "--summary", "Check auth redirect behavior", + "--run", "run_blog_003", + "--task", "T3", + ) + + var sendResp map[string]any + mustDecodeJSON(t, sendOut, &sendResp) + threadID := nestedString(t, sendResp, "data", "thread", "thread_id") + + runInboxCommand( + t, + "--db", dbPath, + "--json", + "claim", + "--agent", "worker-c", + "--thread", threadID, + "--lease-seconds", "300", + ) + + renewOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "renew", + "--agent", "worker-c", + "--thread", threadID, + "--lease-seconds", "600", + ) + + var renewResp map[string]any + mustDecodeJSON(t, renewOut, &renewResp) + if got := nestedString(t, renewResp, "data", "message", "summary"); got != "lease renewed" { + t.Fatalf("expected lease renewed summary, got %q", got) + } + + blockedOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "update", + "--agent", "worker-c", + "--thread", threadID, + "--status", "blocked", + "--summary", "Need policy decision", + "--body", "Should guest users be redirected to login or shown a 403 page?", + ) + + var blockedResp map[string]any + mustDecodeJSON(t, blockedOut, &blockedResp) + blockedMessageID := nestedString(t, blockedResp, "data", "message", "message_id") + + type commandResult struct { + stdout string + stderr string + err error + } + + waitCh := make(chan commandResult, 1) + go func() { + stdout, stderr, err := executeInboxCommand( + "--db", dbPath, + "--json", + "wait-reply", + "--thread", threadID, + "--after-message", blockedMessageID, + "--timeout-seconds", "2", + ) + waitCh <- commandResult{stdout: stdout, stderr: stderr, err: err} + }() + + time.Sleep(200 * time.Millisecond) + + runInboxCommand( + t, + "--db", dbPath, + "--json", + "reply", + "--from", "leader", + "--to", "worker-c", + "--thread", threadID, + "--summary", "Redirect to login", + "--body", "Redirect guests to login for the MVP.", + ) + + var waitResult commandResult + select { + case waitResult = <-waitCh: + case <-time.After(3 * time.Second): + t.Fatal("wait-reply command did not return") + } + + if waitResult.err != nil { + t.Fatalf("wait-reply failed: %v\nstderr:\n%s", waitResult.err, waitResult.stderr) + } + + var waitResp map[string]any + mustDecodeJSON(t, waitResult.stdout, &waitResp) + if woke, ok := nestedValue(t, waitResp, "data", "woke").(bool); !ok || !woke { + t.Fatalf("expected wait-reply to wake, got %#v", nestedValue(t, waitResp, "data", "woke")) + } + if kind := nestedString(t, waitResp, "data", "message", "kind"); kind != "answer" { + t.Fatalf("expected answer wake message, got %q", kind) + } + + cancelOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "cancel", + "--agent", "leader", + "--thread", threadID, + "--reason", "Task superseded by a larger refactor", + ) + + var cancelResp map[string]any + mustDecodeJSON(t, cancelOut, &cancelResp) + if status := nestedString(t, cancelResp, "data", "thread", "status"); status != "cancelled" { + t.Fatalf("expected cancelled thread, got %q", status) + } +} + +func TestInboxWatchListUnreadAndAppend(t *testing.T) { + t.Parallel() + + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "coord.db") + bodyPath := filepath.Join(tempDir, "task.md") + + if err := os.WriteFile(bodyPath, []byte("Implement the initial admin post editor."), 0o644); err != nil { + t.Fatalf("write body file: %v", err) + } + + runInboxCommand(t, "--db", dbPath, "--json", "init") + + type commandResult struct { + stdout string + stderr string + err error + } + + watchCh := make(chan commandResult, 1) + go func() { + stdout, stderr, err := executeInboxCommand( + "--db", dbPath, + "--json", + "watch", + "--agent", "worker-d", + "--status", "pending", + "--timeout-seconds", "2", + ) + watchCh <- commandResult{stdout: stdout, stderr: stderr, err: err} + }() + + time.Sleep(200 * time.Millisecond) + + sendOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "send", + "--from", "leader", + "--to", "worker-d", + "--subject", "Build admin editor", + "--summary", "Create the first editor screen", + "--body-file", bodyPath, + "--run", "run_blog_004", + "--task", "T4", + ) + + var sendResp map[string]any + mustDecodeJSON(t, sendOut, &sendResp) + threadID := nestedString(t, sendResp, "data", "thread", "thread_id") + + var watchResult commandResult + select { + case watchResult = <-watchCh: + case <-time.After(3 * time.Second): + t.Fatal("watch command did not return") + } + + if watchResult.err != nil { + t.Fatalf("watch failed: %v\nstderr:\n%s", watchResult.err, watchResult.stderr) + } + + var watchResp map[string]any + mustDecodeJSON(t, watchResult.stdout, &watchResp) + if woke, ok := nestedValue(t, watchResp, "data", "woke").(bool); !ok || !woke { + t.Fatalf("expected watch to wake, got %#v", nestedValue(t, watchResp, "data", "woke")) + } + if watchedThreadID := nestedString(t, watchResp, "data", "thread", "thread_id"); watchedThreadID != threadID { + t.Fatalf("expected watch on thread %s, got %s", threadID, watchedThreadID) + } + + fetchOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "fetch", + "--agent", "worker-d", + "--status", "pending", + "--unread", + ) + + var fetchResp map[string]any + mustDecodeJSON(t, fetchOut, &fetchResp) + fetchedThreads, ok := nestedValue(t, fetchResp, "data", "threads").([]any) + if !ok || len(fetchedThreads) != 1 { + t.Fatalf("expected one unread pending thread, got %#v", nestedValue(t, fetchResp, "data", "threads")) + } + + listOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "list", + "--assigned-to", "worker-d", + "--status", "pending", + ) + + var listResp map[string]any + mustDecodeJSON(t, listOut, &listResp) + listedThreads, ok := nestedValue(t, listResp, "data", "threads").([]any) + if !ok || len(listedThreads) != 1 { + t.Fatalf("expected one listed thread, got %#v", nestedValue(t, listResp, "data", "threads")) + } + + runInboxCommand( + t, + "--db", dbPath, + "--json", + "send", + "--from", "leader", + "--to", "worker-d", + "--thread", threadID, + "--summary", "Use a markdown editor", + "--body", "Prefer a textarea-based markdown editor for v1.", + ) + + showOut := runInboxCommand( + t, + "--db", dbPath, + "--json", + "show", + "--thread", threadID, + ) + + var showResp map[string]any + mustDecodeJSON(t, showOut, &showResp) + messages, ok := nestedValue(t, showResp, "data", "messages").([]any) + if !ok || len(messages) != 2 { + t.Fatalf("expected two messages after append, got %#v", nestedValue(t, showResp, "data", "messages")) + } + firstMessage, ok := messages[0].(map[string]any) + if !ok { + t.Fatalf("expected first message object, got %#v", messages[0]) + } + if firstMessage["body"] != "Implement the initial admin post editor." { + t.Fatalf("expected body-file content in first message, got %#v", firstMessage["body"]) + } +} + func runInboxCommand(t *testing.T, args ...string) string { t.Helper() + stdout, stderr, err := executeInboxCommand(args...) + if err != nil { + t.Fatalf("execute inbox command %v: %v\nstderr:\n%s", args, err, stderr) + } + + return stdout +} + +func executeInboxCommand(args ...string) (string, string, error) { cmd := NewRootCmd() var stdout bytes.Buffer var stderr bytes.Buffer @@ -248,11 +535,8 @@ func runInboxCommand(t *testing.T, args ...string) string { cmd.SetErr(&stderr) cmd.SetArgs(args) - if err := cmd.Execute(); err != nil { - t.Fatalf("execute inbox command %v: %v\nstderr:\n%s", args, err, stderr.String()) - } - - return stdout.String() + err := cmd.Execute() + return stdout.String(), stderr.String(), err } func mustDecodeJSON(t *testing.T, raw string, target any) { diff --git a/internal/cli/inbox/list.go b/internal/cli/inbox/list.go new file mode 100644 index 0000000..65c9bbe --- /dev/null +++ b/internal/cli/inbox/list.go @@ -0,0 +1,81 @@ +package inbox + +import ( + "fmt" + + "ai-workflow-skill/internal/db" + "ai-workflow-skill/internal/protocol" + "ai-workflow-skill/internal/store" + + "github.com/spf13/cobra" +) + +type listOptions struct { + agent string + statuses string + createdBy string + assignedTo string + limit int +} + +func newListCmd(root *rootOptions) *cobra.Command { + opts := &listOptions{} + + cmd := &cobra.Command{ + Use: "list", + Short: "List threads with filters", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + agent := opts.agent + if agent == "" { + agent = root.agent + } + + sqlDB, err := db.Open(ctx, root.dbPath) + if err != nil { + return err + } + defer sqlDB.Close() + + s := store.NewInboxStore(sqlDB) + threads, err := s.ListThreads(ctx, store.ListInput{ + Agent: agent, + Statuses: parseCSV(opts.statuses), + CreatedBy: opts.createdBy, + AssignedTo: opts.assignedTo, + Limit: opts.limit, + }) + if err != nil { + return err + } + + resp := protocol.Success{ + OK: true, + Command: "list", + Data: map[string]any{ + "threads": threads, + }, + } + + if root.json { + return protocol.WriteJSON(cmd.OutOrStdout(), resp) + } + + for _, thread := range threads { + if _, err := fmt.Fprintf(cmd.OutOrStdout(), "%s\t%s\t%s\t%s\n", thread.ThreadID, thread.Status, thread.AssignedTo, thread.Subject); err != nil { + return err + } + } + return nil + }, + } + + cmd.Flags().StringVar(&opts.agent, "agent", "", "Assigned agent filter shortcut") + cmd.Flags().StringVar(&opts.statuses, "status", "", "Comma-separated status filter") + cmd.Flags().StringVar(&opts.createdBy, "created-by", "", "Created-by filter") + cmd.Flags().StringVar(&opts.assignedTo, "assigned-to", "", "Assigned-to filter") + cmd.Flags().IntVar(&opts.limit, "limit", 20, "Maximum number of threads") + + return cmd +} diff --git a/internal/cli/inbox/renew.go b/internal/cli/inbox/renew.go new file mode 100644 index 0000000..5b26761 --- /dev/null +++ b/internal/cli/inbox/renew.go @@ -0,0 +1,77 @@ +package inbox + +import ( + "fmt" + + "ai-workflow-skill/internal/db" + "ai-workflow-skill/internal/protocol" + "ai-workflow-skill/internal/store" + + "github.com/spf13/cobra" +) + +type renewOptions struct { + agent string + threadID string + leaseSeconds int +} + +func newRenewCmd(root *rootOptions) *cobra.Command { + opts := &renewOptions{} + + cmd := &cobra.Command{ + Use: "renew", + Short: "Extend an existing lease", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + agent := opts.agent + if agent == "" { + agent = root.agent + } + if agent == "" { + return fmt.Errorf("agent is required") + } + + sqlDB, err := db.Open(ctx, root.dbPath) + if err != nil { + return err + } + defer sqlDB.Close() + + s := store.NewInboxStore(sqlDB) + result, err := s.RenewLease(ctx, store.RenewInput{ + ThreadID: opts.threadID, + Agent: agent, + LeaseSeconds: opts.leaseSeconds, + }) + if err != nil { + return err + } + + resp := protocol.Success{ + OK: true, + Command: "renew", + Data: map[string]any{ + "thread": result.Thread, + "message": result.Message, + }, + } + + if root.json { + return protocol.WriteJSON(cmd.OutOrStdout(), resp) + } + + _, err = fmt.Fprintf(cmd.OutOrStdout(), "renewed lease on thread %s\n", result.Thread.ThreadID) + return err + }, + } + + cmd.Flags().StringVar(&opts.agent, "agent", "", "Lease owner") + cmd.Flags().StringVar(&opts.threadID, "thread", "", "Thread ID") + cmd.Flags().IntVar(&opts.leaseSeconds, "lease-seconds", 900, "Lease duration in seconds") + + _ = cmd.MarkFlagRequired("thread") + + return cmd +} diff --git a/internal/cli/inbox/reply.go b/internal/cli/inbox/reply.go index 154a3e8..3c835c8 100644 --- a/internal/cli/inbox/reply.go +++ b/internal/cli/inbox/reply.go @@ -17,6 +17,7 @@ type replyOptions struct { kind string summary string body string + bodyFile string payloadJSON string } @@ -37,6 +38,11 @@ func newReplyCmd(root *rootOptions) *cobra.Command { return fmt.Errorf("from agent is required") } + body, err := resolveBodyValue(opts.body, opts.bodyFile) + if err != nil { + return err + } + sqlDB, err := db.Open(ctx, root.dbPath) if err != nil { return err @@ -50,7 +56,7 @@ func newReplyCmd(root *rootOptions) *cobra.Command { ToAgent: opts.to, Kind: opts.kind, Summary: opts.summary, - Body: opts.body, + Body: body, PayloadJSON: opts.payloadJSON, }) if err != nil { @@ -81,6 +87,7 @@ func newReplyCmd(root *rootOptions) *cobra.Command { cmd.Flags().StringVar(&opts.kind, "kind", "answer", "Reply kind") cmd.Flags().StringVar(&opts.summary, "summary", "", "Short reply summary") cmd.Flags().StringVar(&opts.body, "body", "", "Reply body") + cmd.Flags().StringVar(&opts.bodyFile, "body-file", "", "Read reply body from file") cmd.Flags().StringVar(&opts.payloadJSON, "payload-json", "", "Structured payload JSON string") _ = cmd.MarkFlagRequired("thread") diff --git a/internal/cli/inbox/root.go b/internal/cli/inbox/root.go index 259a15e..f5a88c9 100644 --- a/internal/cli/inbox/root.go +++ b/internal/cli/inbox/root.go @@ -26,10 +26,15 @@ func NewRootCmd() *cobra.Command { cmd.AddCommand(newSendCmd(opts)) cmd.AddCommand(newFetchCmd(opts)) cmd.AddCommand(newClaimCmd(opts)) + cmd.AddCommand(newRenewCmd(opts)) cmd.AddCommand(newUpdateCmd(opts)) cmd.AddCommand(newReplyCmd(opts)) cmd.AddCommand(newDoneCmd(opts)) cmd.AddCommand(newFailCmd(opts)) + cmd.AddCommand(newCancelCmd(opts)) + cmd.AddCommand(newListCmd(opts)) + cmd.AddCommand(newWatchCmd(opts)) + cmd.AddCommand(newWaitReplyCmd(opts)) cmd.AddCommand(newShowCmd(opts)) return cmd diff --git a/internal/cli/inbox/send.go b/internal/cli/inbox/send.go index 1371b5d..4168002 100644 --- a/internal/cli/inbox/send.go +++ b/internal/cli/inbox/send.go @@ -20,6 +20,7 @@ type sendOptions struct { kind string summary string body string + bodyFile string payloadJSON string priority string } @@ -33,6 +34,22 @@ func newSendCmd(root *rootOptions) *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() + from := opts.from + if from == "" { + from = root.agent + } + if from == "" { + return fmt.Errorf("from agent is required") + } + if opts.threadID == "" && opts.subject == "" { + return fmt.Errorf("subject is required when creating a new thread") + } + + body, err := resolveBodyValue(opts.body, opts.bodyFile) + if err != nil { + return err + } + sqlDB, err := db.Open(ctx, root.dbPath) if err != nil { return err @@ -45,11 +62,11 @@ func newSendCmd(root *rootOptions) *cobra.Command { RunID: opts.runID, TaskID: opts.taskID, Subject: opts.subject, - FromAgent: opts.from, + FromAgent: from, ToAgent: opts.to, Kind: opts.kind, Summary: opts.summary, - Body: opts.body, + Body: body, PayloadJSON: opts.payloadJSON, Priority: opts.priority, }) @@ -84,12 +101,11 @@ func newSendCmd(root *rootOptions) *cobra.Command { cmd.Flags().StringVar(&opts.kind, "kind", "task", "Initial message kind") cmd.Flags().StringVar(&opts.summary, "summary", "", "Short message summary") cmd.Flags().StringVar(&opts.body, "body", "", "Message body") + cmd.Flags().StringVar(&opts.bodyFile, "body-file", "", "Read message body from file") cmd.Flags().StringVar(&opts.payloadJSON, "payload-json", "", "Structured payload JSON string") cmd.Flags().StringVar(&opts.priority, "priority", "normal", "Thread priority") - _ = cmd.MarkFlagRequired("from") _ = cmd.MarkFlagRequired("to") - _ = cmd.MarkFlagRequired("subject") return cmd } diff --git a/internal/cli/inbox/update.go b/internal/cli/inbox/update.go index b22a960..f1f6229 100644 --- a/internal/cli/inbox/update.go +++ b/internal/cli/inbox/update.go @@ -16,6 +16,7 @@ type updateOptions struct { status string summary string body string + bodyFile string payloadJSON string } @@ -36,6 +37,11 @@ func newUpdateCmd(root *rootOptions) *cobra.Command { return fmt.Errorf("agent is required") } + body, err := resolveBodyValue(opts.body, opts.bodyFile) + if err != nil { + return err + } + sqlDB, err := db.Open(ctx, root.dbPath) if err != nil { return err @@ -48,7 +54,7 @@ func newUpdateCmd(root *rootOptions) *cobra.Command { Agent: agent, Status: opts.status, Summary: opts.summary, - Body: opts.body, + Body: body, PayloadJSON: opts.payloadJSON, }) if err != nil { @@ -78,6 +84,7 @@ func newUpdateCmd(root *rootOptions) *cobra.Command { cmd.Flags().StringVar(&opts.status, "status", "", "New status: in_progress or blocked") cmd.Flags().StringVar(&opts.summary, "summary", "", "Short update summary") cmd.Flags().StringVar(&opts.body, "body", "", "Update body") + cmd.Flags().StringVar(&opts.bodyFile, "body-file", "", "Read update body from file") cmd.Flags().StringVar(&opts.payloadJSON, "payload-json", "", "Structured payload JSON string") _ = cmd.MarkFlagRequired("thread") diff --git a/internal/cli/inbox/wait_reply.go b/internal/cli/inbox/wait_reply.go new file mode 100644 index 0000000..27f364d --- /dev/null +++ b/internal/cli/inbox/wait_reply.go @@ -0,0 +1,85 @@ +package inbox + +import ( + "fmt" + "time" + + "ai-workflow-skill/internal/db" + "ai-workflow-skill/internal/protocol" + "ai-workflow-skill/internal/store" + + "github.com/spf13/cobra" +) + +type waitReplyOptions struct { + threadID string + afterMessageID string + afterEventID int64 + kinds string + timeoutSeconds int +} + +func newWaitReplyCmd(root *rootOptions) *cobra.Command { + opts := &waitReplyOptions{} + + cmd := &cobra.Command{ + Use: "wait-reply", + Short: "Block until a reply-like message appears in a thread", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + sqlDB, err := db.Open(ctx, root.dbPath) + if err != nil { + return err + } + defer sqlDB.Close() + + s := store.NewInboxStore(sqlDB) + result, err := s.WaitReply(ctx, store.WaitReplyInput{ + ThreadID: opts.threadID, + AfterMessageID: opts.afterMessageID, + AfterEventID: opts.afterEventID, + Kinds: parseCSV(opts.kinds), + Timeout: time.Duration(opts.timeoutSeconds) * time.Second, + }) + if err != nil { + return err + } + + data := map[string]any{ + "woke": result.Woke, + "next_event_id": result.NextEventID, + } + if result.Message != nil { + data["message"] = result.Message + } + + resp := protocol.Success{ + OK: true, + Command: "wait-reply", + Data: data, + } + + if root.json { + return protocol.WriteJSON(cmd.OutOrStdout(), resp) + } + if !result.Woke { + _, err = fmt.Fprintln(cmd.OutOrStdout(), "wait-reply timed out") + return err + } + + _, err = fmt.Fprintf(cmd.OutOrStdout(), "reply received on thread %s at event %d\n", result.Message.ThreadID, result.NextEventID) + return err + }, + } + + cmd.Flags().StringVar(&opts.threadID, "thread", "", "Thread ID") + cmd.Flags().StringVar(&opts.afterMessageID, "after-message", "", "Resume after a known message ID") + cmd.Flags().Int64Var(&opts.afterEventID, "after-event", 0, "Resume after a known event ID") + cmd.Flags().StringVar(&opts.kinds, "kinds", "answer,control,result", "Comma-separated message kinds to wake on") + cmd.Flags().IntVar(&opts.timeoutSeconds, "timeout-seconds", 0, "Maximum time to wait; 0 waits forever") + + _ = cmd.MarkFlagRequired("thread") + + return cmd +} diff --git a/internal/cli/inbox/watch.go b/internal/cli/inbox/watch.go new file mode 100644 index 0000000..addd35a --- /dev/null +++ b/internal/cli/inbox/watch.go @@ -0,0 +1,92 @@ +package inbox + +import ( + "fmt" + "time" + + "ai-workflow-skill/internal/db" + "ai-workflow-skill/internal/protocol" + "ai-workflow-skill/internal/store" + + "github.com/spf13/cobra" +) + +type watchOptions struct { + agent string + statuses string + timeoutSeconds int + afterEventID int64 +} + +func newWatchCmd(root *rootOptions) *cobra.Command { + opts := &watchOptions{} + + cmd := &cobra.Command{ + Use: "watch", + Short: "Block until new matching activity appears", + RunE: func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + agent := opts.agent + if agent == "" { + agent = root.agent + } + + sqlDB, err := db.Open(ctx, root.dbPath) + if err != nil { + return err + } + defer sqlDB.Close() + + s := store.NewInboxStore(sqlDB) + result, err := s.WatchThreads(ctx, store.WatchInput{ + Agent: agent, + Statuses: parseCSV(opts.statuses), + AfterEventID: opts.afterEventID, + StartFromNow: !cmd.Flags().Changed("after-event"), + Timeout: time.Duration(opts.timeoutSeconds) * time.Second, + }) + if err != nil { + return err + } + + data := map[string]any{ + "woke": result.Woke, + "next_event_id": result.NextEventID, + } + if result.Thread != nil { + data["thread"] = result.Thread + } + if result.Message != nil { + data["message"] = result.Message + } + if result.Event != nil { + data["event"] = result.Event + } + + resp := protocol.Success{ + OK: true, + Command: "watch", + Data: data, + } + + if root.json { + return protocol.WriteJSON(cmd.OutOrStdout(), resp) + } + if !result.Woke { + _, err = fmt.Fprintln(cmd.OutOrStdout(), "watch timed out") + return err + } + + _, err = fmt.Fprintf(cmd.OutOrStdout(), "watch woke on thread %s at event %d\n", result.Thread.ThreadID, result.NextEventID) + return err + }, + } + + cmd.Flags().StringVar(&opts.agent, "agent", "", "Assigned agent filter") + cmd.Flags().StringVar(&opts.statuses, "status", "pending,blocked,done,failed", "Comma-separated status filter") + cmd.Flags().IntVar(&opts.timeoutSeconds, "timeout-seconds", 0, "Maximum time to wait; 0 waits forever") + cmd.Flags().Int64Var(&opts.afterEventID, "after-event", 0, "Resume after a known event ID") + + return cmd +} diff --git a/internal/store/inbox.go b/internal/store/inbox.go index 62bda53..ea93e28 100644 --- a/internal/store/inbox.go +++ b/internal/store/inbox.go @@ -13,6 +13,8 @@ import ( ) var ErrLeaseConflict = errors.New("thread already claimed by another worker") +var ErrThreadNotFound = errors.New("thread not found") +var ErrNoActiveLease = errors.New("no active lease") type InboxStore struct { db *sql.DB @@ -49,6 +51,19 @@ type ThreadDetail struct { Messages []Message `json:"messages"` } +type Event struct { + EventID int64 `json:"event_id"` + RunID string `json:"run_id"` + TaskID string `json:"task_id"` + ThreadID string `json:"thread_id,omitempty"` + Source string `json:"source"` + EventType string `json:"event_type"` + MessageID string `json:"message_id,omitempty"` + Summary string `json:"summary"` + PayloadJSON json.RawMessage `json:"payload_json"` + CreatedAt time.Time `json:"created_at"` +} + type SendInput struct { ThreadID string RunID string @@ -67,6 +82,7 @@ type FetchInput struct { Agent string Statuses []string Limit int + Unread bool } type ClaimInput struct { @@ -75,6 +91,12 @@ type ClaimInput struct { LeaseSeconds int } +type RenewInput struct { + ThreadID string + Agent string + LeaseSeconds int +} + type ClaimResult struct { Thread Thread `json:"thread"` Message Message `json:"message"` @@ -108,11 +130,70 @@ type CompleteInput struct { Failed bool } +type CancelInput struct { + ThreadID string + Agent string + Reason string +} + +type ListInput struct { + Agent string + Statuses []string + CreatedBy string + AssignedTo string + Limit int + Unread bool +} + +type WatchInput struct { + Agent string + Statuses []string + AfterEventID int64 + StartFromNow bool + Timeout time.Duration +} + +type WatchResult struct { + Woke bool `json:"woke"` + NextEventID int64 `json:"next_event_id"` + Thread *Thread `json:"thread,omitempty"` + Message *Message `json:"message,omitempty"` + Event *Event `json:"event,omitempty"` +} + +type WaitReplyInput struct { + ThreadID string + AfterMessageID string + AfterEventID int64 + Kinds []string + Timeout time.Duration +} + +type WaitReplyResult struct { + Woke bool `json:"woke"` + NextEventID int64 `json:"next_event_id"` + Message *Message `json:"message,omitempty"` +} + func NewInboxStore(db *sql.DB) *InboxStore { return &InboxStore{db: db} } func (s *InboxStore) Send(ctx context.Context, input SendInput) (Thread, Message, error) { + if input.ThreadID != "" { + thread, err := selectThread(ctx, s.db, input.ThreadID) + if err == nil { + return s.appendThreadMessage(ctx, thread, input) + } + if !errors.Is(err, ErrThreadNotFound) { + return Thread{}, Message{}, err + } + } + + return s.createThread(ctx, input) +} + +func (s *InboxStore) createThread(ctx context.Context, input SendInput) (Thread, Message, error) { now := nowUTC() threadID := defaultID(input.ThreadID, "thr") @@ -217,43 +298,146 @@ func (s *InboxStore) Send(ctx context.Context, input SendInput) (Thread, Message return thread, message, nil } +func (s *InboxStore) appendThreadMessage(ctx context.Context, existing Thread, input SendInput) (Thread, Message, error) { + now := nowUTC() + messageID := newID("msg") + payload := normalizeJSON(input.PayloadJSON) + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return Thread{}, Message{}, fmt.Errorf("begin append transaction: %w", err) + } + defer tx.Rollback() + + thread, err := selectThreadForUpdate(ctx, tx, existing.ThreadID) + if err != nil { + return Thread{}, Message{}, err + } + if isTerminalStatus(thread.Status) { + return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", thread.ThreadID) + } + + assignedTo := thread.AssignedTo + if input.ToAgent != "" { + assignedTo = input.ToAgent + } + + message := Message{ + MessageID: messageID, + ThreadID: thread.ThreadID, + FromAgent: input.FromAgent, + ToAgent: defaultString(input.ToAgent, thread.AssignedTo), + Kind: defaultString(input.Kind, "task"), + Summary: defaultString(input.Summary, thread.Subject), + Body: input.Body, + PayloadJSON: json.RawMessage(payload), + CreatedAt: now, + } + + if err := insertMessage(ctx, tx, message); err != nil { + return Thread{}, Message{}, err + } + + if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, assignedTo, message.MessageID, now); err != nil { + return Thread{}, Message{}, err + } + + if err := insertEvent(ctx, tx, eventInput{ + RunID: thread.RunID, + TaskID: thread.TaskID, + ThreadID: thread.ThreadID, + Source: "inbox", + EventType: "thread_message_sent", + MessageID: message.MessageID, + Summary: message.Summary, + PayloadJSON: payload, + CreatedAt: now, + }); err != nil { + return Thread{}, Message{}, err + } + + if err := tx.Commit(); err != nil { + return Thread{}, Message{}, fmt.Errorf("commit append transaction: %w", err) + } + + thread.AssignedTo = assignedTo + thread.LatestMessageID = message.MessageID + thread.UpdatedAt = now + return thread, message, nil +} + func (s *InboxStore) FetchThreads(ctx context.Context, input FetchInput) ([]Thread, error) { statuses := input.Statuses if len(statuses) == 0 { statuses = []string{"pending"} } + return s.ListThreads(ctx, ListInput{ + Agent: input.Agent, + Statuses: statuses, + Limit: input.Limit, + Unread: input.Unread, + }) +} + +func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread, error) { limit := input.Limit if limit <= 0 { limit = 20 } - var args []any - var conditions []string + var ( + args []any + conditions []string + joins []string + ) - if input.Agent != "" { - conditions = append(conditions, "assigned_to = ?") + assignedTo := input.AssignedTo + if assignedTo == "" { + assignedTo = input.Agent + } + + if assignedTo != "" { + conditions = append(conditions, "t.assigned_to = ?") + args = append(args, assignedTo) + } + if input.CreatedBy != "" { + conditions = append(conditions, "t.created_by = ?") + args = append(args, input.CreatedBy) + } + if len(input.Statuses) > 0 { + conditions = append(conditions, "t.status IN ("+placeholders(len(input.Statuses))+")") + for _, status := range input.Statuses { + args = append(args, status) + } + } + if input.Unread { + if input.Agent == "" { + return nil, fmt.Errorf("agent is required when filtering unread threads") + } + joins = append(joins, "JOIN messages lm ON lm.message_id = t.latest_message_id") + conditions = append(conditions, "lm.to_agent = ?") + args = append(args, input.Agent) + conditions = append(conditions, "lm.from_agent <> ?") args = append(args, input.Agent) } - conditions = append(conditions, "status IN ("+placeholders(len(statuses))+")") - for _, status := range statuses { - args = append(args, status) - } - args = append(args, limit) - query := `SELECT - thread_id, run_id, task_id, subject, created_by, assigned_to, status, - priority, latest_message_id, created_at, updated_at - FROM threads` + t.thread_id, t.run_id, t.task_id, t.subject, t.created_by, t.assigned_to, t.status, + t.priority, t.latest_message_id, t.created_at, t.updated_at + FROM threads t` + if len(joins) > 0 { + query += " " + strings.Join(joins, " ") + } if len(conditions) > 0 { query += " WHERE " + strings.Join(conditions, " AND ") } - query += " ORDER BY updated_at DESC LIMIT ?" + query += " ORDER BY t.updated_at DESC LIMIT ?" + args = append(args, limit) rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("fetch threads: %w", err) + return nil, fmt.Errorf("list threads: %w", err) } defer rows.Close() @@ -409,6 +593,92 @@ func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimRe }, nil } +func (s *InboxStore) RenewLease(ctx context.Context, input RenewInput) (ClaimResult, error) { + if input.LeaseSeconds <= 0 { + input.LeaseSeconds = 900 + } + + now := nowUTC() + expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second) + leaseToken := newID("lease") + messageID := newID("msg") + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return ClaimResult{}, fmt.Errorf("begin renew transaction: %w", err) + } + defer tx.Rollback() + + thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID) + if err != nil { + return ClaimResult{}, err + } + if isTerminalStatus(thread.Status) { + return ClaimResult{}, fmt.Errorf("thread %s is already terminal", input.ThreadID) + } + + if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil { + return ClaimResult{}, err + } + + if _, err := tx.ExecContext( + ctx, + `UPDATE leases + SET lease_token = ?, expires_at = ?, released_at = NULL + WHERE thread_id = ?`, + leaseToken, + formatTime(expiresAt), + input.ThreadID, + ); err != nil { + return ClaimResult{}, fmt.Errorf("renew lease: %w", err) + } + + message := Message{ + MessageID: messageID, + ThreadID: input.ThreadID, + FromAgent: input.Agent, + ToAgent: input.Agent, + Kind: "event", + Summary: "lease renewed", + Body: "", + PayloadJSON: json.RawMessage(fmt.Sprintf(`{"lease_seconds":%d,"lease_token":"%s"}`, input.LeaseSeconds, leaseToken)), + CreatedAt: now, + } + + if err := insertMessage(ctx, tx, message); err != nil { + return ClaimResult{}, err + } + + if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil { + return ClaimResult{}, err + } + + if err := insertEvent(ctx, tx, eventInput{ + RunID: thread.RunID, + TaskID: thread.TaskID, + ThreadID: thread.ThreadID, + Source: "inbox", + EventType: "thread_renewed", + MessageID: message.MessageID, + Summary: message.Summary, + PayloadJSON: string(message.PayloadJSON), + CreatedAt: now, + }); err != nil { + return ClaimResult{}, err + } + + if err := tx.Commit(); err != nil { + return ClaimResult{}, fmt.Errorf("commit renew transaction: %w", err) + } + + thread.LatestMessageID = message.MessageID + thread.UpdatedAt = now + return ClaimResult{ + Thread: thread, + Message: message, + }, nil +} + func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) (Thread, Message, error) { now := nowUTC() messageID := newID("msg") @@ -427,10 +697,12 @@ func (s *InboxStore) UpdateThreadStatus(ctx context.Context, input UpdateInput) if err != nil { return Thread{}, Message{}, err } - - if thread.Status == "done" || thread.Status == "failed" || thread.Status == "cancelled" { + if isTerminalStatus(thread.Status) { return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID) } + if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil { + return Thread{}, Message{}, err + } kind := "progress" if input.Status == "blocked" { @@ -495,6 +767,9 @@ func (s *InboxStore) ReplyToThread(ctx context.Context, input ReplyInput) (Threa if err != nil { return Thread{}, Message{}, err } + if isTerminalStatus(thread.Status) { + return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID) + } message := Message{ MessageID: messageID, @@ -561,6 +836,12 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T if err != nil { return Thread{}, Message{}, err } + if isTerminalStatus(thread.Status) { + return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID) + } + if _, err := requireActiveLease(ctx, tx, input.ThreadID, input.Agent, now); err != nil { + return Thread{}, Message{}, err + } message := Message{ MessageID: messageID, @@ -618,6 +899,81 @@ func (s *InboxStore) CompleteThread(ctx context.Context, input CompleteInput) (T return thread, message, nil } +func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Thread, Message, error) { + now := nowUTC() + messageID := newID("msg") + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return Thread{}, Message{}, fmt.Errorf("begin cancel transaction: %w", err) + } + defer tx.Rollback() + + thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID) + if err != nil { + return Thread{}, Message{}, err + } + if isTerminalStatus(thread.Status) { + return Thread{}, Message{}, fmt.Errorf("thread %s is already terminal", input.ThreadID) + } + + summary := defaultString(input.Reason, "thread cancelled") + message := Message{ + MessageID: messageID, + ThreadID: thread.ThreadID, + FromAgent: input.Agent, + ToAgent: thread.AssignedTo, + Kind: "control", + Summary: summary, + Body: input.Reason, + PayloadJSON: json.RawMessage(`{}`), + CreatedAt: now, + } + + if err := insertMessage(ctx, tx, message); err != nil { + return Thread{}, Message{}, err + } + + if err := updateThreadState(ctx, tx, thread.ThreadID, "cancelled", thread.AssignedTo, message.MessageID, now); err != nil { + return Thread{}, Message{}, err + } + + if _, err := tx.ExecContext( + ctx, + `UPDATE leases + SET released_at = ? + WHERE thread_id = ? + AND released_at IS NULL`, + formatTime(now), + thread.ThreadID, + ); err != nil { + return Thread{}, Message{}, fmt.Errorf("release lease on cancel: %w", err) + } + + if err := insertEvent(ctx, tx, eventInput{ + RunID: thread.RunID, + TaskID: thread.TaskID, + ThreadID: thread.ThreadID, + Source: "inbox", + EventType: "thread_cancelled", + MessageID: message.MessageID, + Summary: message.Summary, + PayloadJSON: string(message.PayloadJSON), + CreatedAt: now, + }); err != nil { + return Thread{}, Message{}, err + } + + if err := tx.Commit(); err != nil { + return Thread{}, Message{}, fmt.Errorf("commit cancel transaction: %w", err) + } + + thread.Status = "cancelled" + thread.LatestMessageID = message.MessageID + thread.UpdatedAt = now + return thread, message, nil +} + func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) { thread, err := selectThread(ctx, s.db, threadID) if err != nil { @@ -658,6 +1014,107 @@ func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDeta }, nil } +func (s *InboxStore) WatchThreads(ctx context.Context, input WatchInput) (WatchResult, error) { + cursor := input.AfterEventID + if input.StartFromNow && cursor == 0 { + current, err := s.currentMaxEventID(ctx) + if err != nil { + return WatchResult{}, err + } + cursor = current + } + + waitCtx := ctx + cancel := func() {} + if input.Timeout > 0 { + waitCtx, cancel = context.WithTimeout(ctx, input.Timeout) + } + defer cancel() + + for { + thread, message, event, found, err := s.findWatchEventAfter(waitCtx, input, cursor) + if err != nil { + if isDeadlineExceeded(waitCtx) { + return WatchResult{Woke: false, NextEventID: cursor}, nil + } + return WatchResult{}, err + } + if found { + return WatchResult{ + Woke: true, + NextEventID: event.EventID, + Thread: &thread, + Message: &message, + Event: &event, + }, nil + } + + ok, err := waitForNextPoll(waitCtx, 200*time.Millisecond) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return WatchResult{Woke: false, NextEventID: cursor}, nil + } + return WatchResult{}, err + } + if !ok { + return WatchResult{Woke: false, NextEventID: cursor}, nil + } + } +} + +func (s *InboxStore) WaitReply(ctx context.Context, input WaitReplyInput) (WaitReplyResult, error) { + cursor := input.AfterEventID + if input.AfterMessageID != "" { + eventID, err := s.lookupEventIDForMessage(ctx, input.ThreadID, input.AfterMessageID) + if err != nil { + return WaitReplyResult{}, err + } + if eventID > cursor { + cursor = eventID + } + } + + kinds := input.Kinds + if len(kinds) == 0 { + kinds = []string{"answer", "control", "result"} + } + + waitCtx := ctx + cancel := func() {} + if input.Timeout > 0 { + waitCtx, cancel = context.WithTimeout(ctx, input.Timeout) + } + defer cancel() + + for { + message, eventID, found, err := s.findReplyAfter(waitCtx, input.ThreadID, cursor, kinds) + if err != nil { + if isDeadlineExceeded(waitCtx) { + return WaitReplyResult{Woke: false, NextEventID: cursor}, nil + } + return WaitReplyResult{}, err + } + if found { + return WaitReplyResult{ + Woke: true, + NextEventID: eventID, + Message: &message, + }, nil + } + + ok, err := waitForNextPoll(waitCtx, 200*time.Millisecond) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + return WaitReplyResult{Woke: false, NextEventID: cursor}, nil + } + return WaitReplyResult{}, err + } + if !ok { + return WaitReplyResult{Woke: false, NextEventID: cursor}, nil + } + } +} + type threadScanner interface { Scan(dest ...any) error } @@ -719,6 +1176,36 @@ func scanMessage(scanner threadScanner) (Message, error) { return message, nil } +func scanEvent(scanner threadScanner) (Event, error) { + var ( + event Event + messageID sql.NullString + payload, createdAt string + ) + + if err := scanner.Scan( + &event.EventID, + &event.RunID, + &event.TaskID, + &event.ThreadID, + &event.Source, + &event.EventType, + &messageID, + &event.Summary, + &payload, + &createdAt, + ); err != nil { + return Event{}, fmt.Errorf("scan event: %w", err) + } + + if messageID.Valid { + event.MessageID = messageID.String + } + event.PayloadJSON = json.RawMessage(payload) + event.CreatedAt = parseTime(createdAt) + return event, nil +} + func selectThread(ctx context.Context, db queryRower, threadID string) (Thread, error) { row := db.QueryRowContext( ctx, @@ -732,7 +1219,7 @@ func selectThread(ctx context.Context, db queryRower, threadID string) (Thread, thread, err := scanThread(row) if errors.Is(err, sql.ErrNoRows) { - return Thread{}, fmt.Errorf("thread %s not found", threadID) + return Thread{}, fmt.Errorf("%w: %s", ErrThreadNotFound, threadID) } return thread, err } @@ -821,6 +1308,231 @@ func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assign return nil } +func requireActiveLease(ctx context.Context, tx *sql.Tx, threadID, agent string, now time.Time) (string, error) { + var ( + activeAgent string + leaseToken string + expiresAt string + releasedAt sql.NullString + ) + + err := tx.QueryRowContext( + ctx, + `SELECT agent_id, lease_token, expires_at, released_at + FROM leases + WHERE thread_id = ?`, + threadID, + ).Scan(&activeAgent, &leaseToken, &expiresAt, &releasedAt) + if errors.Is(err, sql.ErrNoRows) { + return "", ErrNoActiveLease + } + if err != nil { + return "", fmt.Errorf("read lease: %w", err) + } + + if releasedAt.Valid || !parseTime(expiresAt).After(now) { + return "", ErrNoActiveLease + } + if activeAgent != agent { + return "", ErrLeaseConflict + } + + return leaseToken, nil +} + +func (s *InboxStore) lookupEventIDForMessage(ctx context.Context, threadID, messageID string) (int64, error) { + var eventID int64 + err := s.db.QueryRowContext( + ctx, + `SELECT event_id + FROM events + WHERE thread_id = ? + AND message_id = ? + ORDER BY event_id DESC + LIMIT 1`, + threadID, + messageID, + ).Scan(&eventID) + if errors.Is(err, sql.ErrNoRows) { + return 0, fmt.Errorf("message %s not found in thread %s", messageID, threadID) + } + if err != nil { + return 0, fmt.Errorf("lookup message event: %w", err) + } + return eventID, nil +} + +func (s *InboxStore) currentMaxEventID(ctx context.Context) (int64, error) { + var maxEventID int64 + if err := s.db.QueryRowContext(ctx, `SELECT COALESCE(MAX(event_id), 0) FROM events`).Scan(&maxEventID); err != nil { + return 0, fmt.Errorf("query max event id: %w", err) + } + return maxEventID, nil +} + +func (s *InboxStore) findReplyAfter(ctx context.Context, threadID string, afterEventID int64, kinds []string) (Message, int64, bool, error) { + args := []any{threadID, afterEventID} + query := `SELECT + e.event_id, + m.message_id, m.thread_id, m.from_agent, m.to_agent, m.kind, m.summary, m.body, m.payload_json, m.created_at + FROM events e + JOIN messages m ON m.message_id = e.message_id + WHERE e.thread_id = ? + AND e.event_id > ?` + if len(kinds) > 0 { + query += " AND m.kind IN (" + placeholders(len(kinds)) + ")" + for _, kind := range kinds { + args = append(args, kind) + } + } + query += " ORDER BY e.event_id ASC LIMIT 1" + + row := s.db.QueryRowContext(ctx, query, args...) + + var ( + eventID int64 + message Message + payload string + created string + ) + err := row.Scan( + &eventID, + &message.MessageID, + &message.ThreadID, + &message.FromAgent, + &message.ToAgent, + &message.Kind, + &message.Summary, + &message.Body, + &payload, + &created, + ) + if errors.Is(err, sql.ErrNoRows) { + return Message{}, 0, false, nil + } + if err != nil { + return Message{}, 0, false, fmt.Errorf("query reply after event %d: %w", afterEventID, err) + } + + message.PayloadJSON = json.RawMessage(payload) + message.CreatedAt = parseTime(created) + return message, eventID, true, nil +} + +func (s *InboxStore) findWatchEventAfter(ctx context.Context, input WatchInput, afterEventID int64) (Thread, Message, Event, bool, error) { + args := []any{afterEventID} + query := `SELECT + t.thread_id, t.run_id, t.task_id, t.subject, t.created_by, t.assigned_to, t.status, + t.priority, t.latest_message_id, t.created_at, t.updated_at, + e.event_id, e.run_id, e.task_id, e.thread_id, e.source, e.event_type, e.message_id, e.summary, e.payload_json, e.created_at, + m.message_id, m.thread_id, m.from_agent, m.to_agent, m.kind, m.summary, m.body, m.payload_json, m.created_at + FROM events e + JOIN threads t ON t.thread_id = e.thread_id + JOIN messages m ON m.message_id = e.message_id + WHERE e.event_id > ?` + + if input.Agent != "" { + query += " AND t.assigned_to = ?" + args = append(args, input.Agent) + } + if len(input.Statuses) > 0 { + query += " AND t.status IN (" + placeholders(len(input.Statuses)) + ")" + for _, status := range input.Statuses { + args = append(args, status) + } + } + query += " ORDER BY e.event_id ASC LIMIT 1" + + row := s.db.QueryRowContext(ctx, query, args...) + + var ( + thread Thread + threadCreatedAt string + threadUpdatedAt string + threadLatestMessage sql.NullString + event Event + eventMessageID sql.NullString + eventPayload string + eventCreatedAt string + message Message + messagePayload string + messageCreatedAt string + ) + + err := row.Scan( + &thread.ThreadID, + &thread.RunID, + &thread.TaskID, + &thread.Subject, + &thread.CreatedBy, + &thread.AssignedTo, + &thread.Status, + &thread.Priority, + &threadLatestMessage, + &threadCreatedAt, + &threadUpdatedAt, + &event.EventID, + &event.RunID, + &event.TaskID, + &event.ThreadID, + &event.Source, + &event.EventType, + &eventMessageID, + &event.Summary, + &eventPayload, + &eventCreatedAt, + &message.MessageID, + &message.ThreadID, + &message.FromAgent, + &message.ToAgent, + &message.Kind, + &message.Summary, + &message.Body, + &messagePayload, + &messageCreatedAt, + ) + if errors.Is(err, sql.ErrNoRows) { + return Thread{}, Message{}, Event{}, false, nil + } + if err != nil { + return Thread{}, Message{}, Event{}, false, fmt.Errorf("query watch event after %d: %w", afterEventID, err) + } + + if threadLatestMessage.Valid { + thread.LatestMessageID = threadLatestMessage.String + } + thread.CreatedAt = parseTime(threadCreatedAt) + thread.UpdatedAt = parseTime(threadUpdatedAt) + if eventMessageID.Valid { + event.MessageID = eventMessageID.String + } + event.PayloadJSON = json.RawMessage(eventPayload) + event.CreatedAt = parseTime(eventCreatedAt) + message.PayloadJSON = json.RawMessage(messagePayload) + message.CreatedAt = parseTime(messageCreatedAt) + return thread, message, event, true, nil +} + +func waitForNextPoll(ctx context.Context, interval time.Duration) (bool, error) { + timer := time.NewTimer(interval) + defer timer.Stop() + + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-timer.C: + return true, nil + } +} + +func isTerminalStatus(status string) bool { + return status == "done" || status == "failed" || status == "cancelled" +} + +func isDeadlineExceeded(ctx context.Context) bool { + return ctx.Err() != nil && errors.Is(ctx.Err(), context.DeadlineExceeded) +} + func defaultID(value, prefix string) string { if value != "" { return value