Add orch wait command

This commit is contained in:
2026-03-19 14:02:33 +08:00
parent 1b0cd723d7
commit f1785b314f
7 changed files with 535 additions and 22 deletions
+153
View File
@@ -4,6 +4,7 @@ import (
"os"
"path/filepath"
"testing"
"time"
)
func TestOrchRunDispatchReconcileLifecycle(t *testing.T) {
@@ -726,3 +727,155 @@ func TestOrchStrictWorktreeAllowsExplicitBaseRefOnDirtyRepo(t *testing.T) {
t.Fatalf("expected base_commit %q, got %q", baseCommit, got)
}
}
// TestOrchWaitWakesOnBlockedEvent verifies that "orch wait --for task_blocked"
// returns woke=true with exactly one task_blocked event once the worker marks
// the dispatched thread as blocked, and that the event carries the worker's
// summary and payload.
func TestOrchWaitWakesOnBlockedEvent(t *testing.T) {
t.Parallel()
dbPath := filepath.Join(t.TempDir(), "coord.db")
// Arrange: create a run with one task and dispatch it so a thread exists.
runOrchCommand(
t,
"--db", dbPath,
"--json",
"run", "init",
"--run", "run_blog_wait_001",
"--goal", "Validate wait wake behavior",
)
runOrchCommand(
t,
"--db", dbPath,
"--json",
"task", "add",
"--run", "run_blog_wait_001",
"--task", "T1",
"--title", "Implement backend",
"--default-to", "worker-a",
)
dispatchOut := runOrchCommand(
t,
"--db", dbPath,
"--json",
"dispatch",
"--run", "run_blog_wait_001",
"--task", "T1",
)
var dispatchResp map[string]any
mustDecodeJSON(t, dispatchOut, &dispatchResp)
threadID := nestedString(t, dispatchResp, "data", "attempt", "thread_id")
type waitResult struct {
stdout string
stderr string
exitCode int
}
// The channel is buffered so the goroutine can deliver its result and exit
// even if the test goroutine has already given up on the select below.
resultCh := make(chan waitResult, 1)
// Run the wait command in the background. executeOrchCommand is called
// without t and hands back the exit code, so all assertions stay on the
// test goroutine.
go func() {
stdout, stderr, exitCode := executeOrchCommand(
"--db", dbPath,
"--json",
"wait",
"--run", "run_blog_wait_001",
"--for", "task_blocked",
"--after-event", "0",
"--timeout-seconds", "2",
)
resultCh <- waitResult{stdout: stdout, stderr: stderr, exitCode: exitCode}
}()
// Head start for the wait command before the worker acts; presumably so the
// wait is already polling when the blocked update lands.
time.Sleep(200 * time.Millisecond)
// Act: the worker claims the thread and then reports it as blocked.
runInboxCommand(
t,
"--db", dbPath,
"--json",
"claim",
"--agent", "worker-a",
"--thread", threadID,
)
runInboxCommand(
t,
"--db", dbPath,
"--json",
"update",
"--agent", "worker-a",
"--thread", threadID,
"--status", "blocked",
"--summary", "Need logging decision",
"--payload-json", `{"question":"stdout or stderr?"}`,
)
// Assert: the wait must finish successfully and report the blocked event.
select {
case result := <-resultCh:
if result.exitCode != 0 {
t.Fatalf("wait exited with %d\nstderr:\n%s\nstdout:\n%s", result.exitCode, result.stderr, result.stdout)
}
var waitResp map[string]any
mustDecodeJSON(t, result.stdout, &waitResp)
if woke, _ := nestedValue(t, waitResp, "data", "woke").(bool); !woke {
t.Fatalf("expected wait to wake, got %#v", waitResp)
}
events := nestedArray(t, waitResp, "data", "events")
if len(events) != 1 {
t.Fatalf("expected one wait event, got %#v", events)
}
event, ok := events[0].(map[string]any)
if !ok {
t.Fatalf("expected wait event object, got %#v", events[0])
}
if got, _ := event["type"].(string); got != "task_blocked" {
t.Fatalf("expected task_blocked event, got %#v", event["type"])
}
// The blocked event is expected to surface the worker's own summary and
// payload rather than a generic status-transition message.
if got, _ := event["summary"].(string); got != "Need logging decision" {
t.Fatalf("expected blocked summary to surface question summary, got %#v", event["summary"])
}
payload, ok := event["payload"].(map[string]any)
if !ok {
t.Fatalf("expected event payload object, got %#v", event["payload"])
}
if got, _ := payload["question"].(string); got != "stdout or stderr?" {
t.Fatalf("expected question payload, got %#v", payload["question"])
}
// Guard is longer than the command's own 2-second timeout, so reaching it
// means the command itself failed to return.
case <-time.After(3 * time.Second):
t.Fatal("timed out waiting for orch wait result")
}
}
// TestOrchWaitTimesOutWithoutMatchingEvent checks that "orch wait" on a run
// with no matching events exits successfully after its timeout, reporting
// woke=false and echoing the starting cursor as next_event_id.
func TestOrchWaitTimesOutWithoutMatchingEvent(t *testing.T) {
	t.Parallel()

	const runID = "run_blog_wait_002"
	db := filepath.Join(t.TempDir(), "coord.db")

	// A freshly initialised run has no tasks, so no task_done event can appear.
	runOrchCommand(
		t,
		"--db", db,
		"--json",
		"run", "init",
		"--run", runID,
		"--goal", "Validate wait timeout behavior",
	)

	out, errOut, code := executeOrchCommand(
		"--db", db,
		"--json",
		"wait",
		"--run", runID,
		"--for", "task_done",
		"--after-event", "0",
		"--timeout-seconds", "1",
	)
	// A timeout is a normal outcome for the command, not a failure.
	if code != 0 {
		t.Fatalf("wait exited with %d\nstderr:\n%s\nstdout:\n%s", code, errOut, out)
	}

	var resp map[string]any
	mustDecodeJSON(t, out, &resp)

	woke, _ := nestedValue(t, resp, "data", "woke").(bool)
	if woke {
		t.Fatalf("expected wait timeout, got %#v", resp)
	}

	next, _ := nestedValue(t, resp, "data", "next_event_id").(float64)
	if next != 0 {
		t.Fatalf("expected next_event_id 0 on timeout, got %#v", next)
	}
}
+1
View File
@@ -28,6 +28,7 @@ func NewRootCmd() *cobra.Command {
cmd.AddCommand(newReadyCmd(opts))
cmd.AddCommand(newDispatchCmd(opts))
cmd.AddCommand(newReconcileCmd(opts))
cmd.AddCommand(newWaitCmd(opts))
cmd.AddCommand(newBlockedCmd(opts))
cmd.AddCommand(newAnswerCmd(opts))
cmd.AddCommand(newStatusCmd(opts))
+97
View File
@@ -0,0 +1,97 @@
package orch
import (
"fmt"
"strings"
"time"
"ai-workflow-skill/internal/protocol"
"ai-workflow-skill/internal/store"
"github.com/spf13/cobra"
)
// waitOptions holds the raw flag values that newWaitCmd binds for the "wait"
// subcommand.
type waitOptions struct {
// runID is the value of the required --run flag.
runID string
// eventTypesRaw is the unparsed, comma-separated value of --for.
eventTypesRaw string
// afterEventID is the --after-event cursor passed through to the store.
afterEventID int64
// timeoutSeconds is the --timeout-seconds value; the command converts it to
// a time.Duration before calling the store.
timeoutSeconds int
}
// newWaitCmd builds the "wait" subcommand. The command opens the orchestrator
// database, asks the store to wait for matching events on the given run, and
// prints the outcome either as a JSON envelope (when root.json is set) or as
// plain text: one tab-separated line per event, or a timeout notice.
func newWaitCmd(root *rootOptions) *cobra.Command {
	opts := &waitOptions{}

	run := func(cmd *cobra.Command, args []string) error {
		ctx := cmd.Context()

		sqlDB, err := openOrchDB(ctx, root.dbPath)
		if err != nil {
			return err
		}
		defer sqlDB.Close()

		waitInput := store.WaitInput{
			RunID:        opts.runID,
			EventTypes:   splitCommaList(opts.eventTypesRaw),
			AfterEventID: opts.afterEventID,
			Timeout:      time.Duration(opts.timeoutSeconds) * time.Second,
		}
		result, err := store.NewOrchStore(sqlDB).WaitForEvents(ctx, waitInput)
		if err != nil {
			return err
		}

		out := cmd.OutOrStdout()

		if root.json {
			return protocol.WriteJSON(out, protocol.Success{
				OK:      true,
				Command: "wait",
				Data: map[string]any{
					"run_id":        opts.runID,
					"woke":          result.Woke,
					"next_event_id": result.NextEventID,
					"events":        result.Events,
				},
			})
		}

		// Plain-text output: a timeout prints the cursor the caller can resume from.
		if !result.Woke {
			_, err = fmt.Fprintf(out, "wait timed out after event %d\n", result.NextEventID)
			return err
		}
		for _, ev := range result.Events {
			_, writeErr := fmt.Fprintf(out, "%d\t%s\t%s\t%s\n", ev.EventID, ev.Type, ev.TaskID, ev.Summary)
			if writeErr != nil {
				return writeErr
			}
		}
		return nil
	}

	cmd := &cobra.Command{
		Use:   "wait",
		Short: "Block until matching run-scoped task events become available",
		RunE:  run,
	}

	flags := cmd.Flags()
	flags.StringVar(&opts.runID, "run", "", "Run ID")
	flags.StringVar(&opts.eventTypesRaw, "for", "task_ready,task_blocked,task_done,task_failed", "Comma-separated event types to wait for")
	flags.Int64Var(&opts.afterEventID, "after-event", 0, "Only wait for events after this event ID")
	flags.IntVar(&opts.timeoutSeconds, "timeout-seconds", 0, "Maximum time to wait before timing out")
	// The "run" flag is registered just above, so the returned error is discarded.
	_ = cmd.MarkFlagRequired("run")

	return cmd
}
// splitCommaList breaks a comma-separated string into its items, trimming
// surrounding whitespace from each and dropping items that are empty after
// trimming. A blank input yields nil; an input made only of separators yields
// an empty, non-nil slice.
func splitCommaList(value string) []string {
	if strings.TrimSpace(value) == "" {
		return nil
	}

	fields := strings.Split(value, ",")
	items := make([]string, 0, len(fields))
	for i := range fields {
		if item := strings.TrimSpace(fields[i]); item != "" {
			items = append(items, item)
		}
	}
	return items
}
+203 -13
View File
@@ -119,6 +119,30 @@ type ReconcileResult struct {
UpdatedTasks []Task `json:"updated_tasks"`
}
// RunEvent is a single row of the events table as loaded by
// findRunEventsAfter and serialized in wait results.
type RunEvent struct {
// EventID is the row's event_id; wait cursors compare against it.
EventID int64 `json:"event_id"`
// Type is the stored event_type, e.g. "task_blocked".
Type string `json:"type"`
RunID string `json:"run_id"`
TaskID string `json:"task_id"`
// ThreadID stays empty (and is omitted from JSON) when the stored thread_id is NULL.
ThreadID string `json:"thread_id,omitempty"`
Summary string `json:"summary"`
// Payload carries the stored payload_json text without decoding it.
Payload json.RawMessage `json:"payload"`
// CreatedAt is parsed from the stored created_at text.
CreatedAt time.Time `json:"created_at"`
}
// WaitInput describes what WaitForEvents should wait for.
type WaitInput struct {
// RunID selects the run; WaitForEvents rejects a blank value.
RunID string
// EventTypes restricts which event types match. An empty list falls back to
// task_ready, task_blocked, task_done and task_failed.
EventTypes []string
// AfterEventID is an exclusive cursor: only events with a larger ID match.
AfterEventID int64
// Timeout bounds the wait when positive. Zero or a negative value adds no
// deadline beyond the one carried by the caller's context.
Timeout time.Duration
}
// WaitResult is the outcome of WaitForEvents.
type WaitResult struct {
// Woke is true when a matching event was found and false when the wait timed out.
Woke bool `json:"woke"`
// NextEventID is the ID of the returned event when Woke is true; on timeout
// it echoes the AfterEventID the caller passed in.
NextEventID int64 `json:"next_event_id"`
// Events holds the matching event (the lookup fetches at most one row per
// call); it is omitted from JSON when empty.
Events []RunEvent `json:"events,omitempty"`
}
type DispatchWorkspace struct {
BaseRef string `json:"base_ref,omitempty"`
BaseCommit string `json:"base_commit,omitempty"`
@@ -753,20 +777,31 @@ func (s *OrchStore) ReconcileRun(ctx context.Context, runID string) (ReconcileRe
return ReconcileResult{}, fmt.Errorf("update reconciled attempt status: %w", err)
}
summary := fmt.Sprintf("%s -> %s", taskID, nextStatus)
payloadJSON := marshalJSON(map[string]any{
"thread_id": threadID,
"thread_status": threadStatus,
"previous_status": taskStatus,
"previous_attempt": attemptStatus,
})
if nextStatus == "blocked" {
question, err := selectLatestQuestionMessage(ctx, tx, threadID)
if err != nil {
return ReconcileResult{}, err
}
summary = question.Summary
payloadJSON = string(question.PayloadJSON)
}
if err := insertEvent(ctx, tx, eventInput{
RunID: runID,
TaskID: taskID,
ThreadID: threadID,
Source: "orch",
EventType: "task_" + nextStatus,
Summary: fmt.Sprintf("%s -> %s", taskID, nextStatus),
PayloadJSON: marshalJSON(map[string]any{
"thread_id": threadID,
"thread_status": threadStatus,
"previous_status": taskStatus,
"previous_attempt": attemptStatus,
}),
CreatedAt: now,
RunID: runID,
TaskID: taskID,
ThreadID: threadID,
Source: "orch",
EventType: "task_" + nextStatus,
Summary: summary,
PayloadJSON: payloadJSON,
CreatedAt: now,
}); err != nil {
return ReconcileResult{}, err
}
@@ -1060,6 +1095,88 @@ func (s *OrchStore) GetRunOverview(ctx context.Context, runID string) (RunOvervi
}, nil
}
// WaitForEvents blocks until the run named by input.RunID has an event newer
// than input.AfterEventID whose type matches input.EventTypes, or until the
// wait times out.
//
// Each polling round looks for a matching event, reconciles the run, looks
// again, and then pauses before the next round. A round whose reconcile hits
// a SQLite busy error is retried after a shorter pause.
//
// On a match the result has Woke=true, the matching event, and its ID as
// NextEventID. When the deadline passes the result has Woke=false and
// NextEventID equal to input.AfterEventID, with a nil error. A blank run ID
// yields an error wrapping ErrInvalidInput; a run lookup failure and any
// non-deadline error from polling are returned as is.
func (s *OrchStore) WaitForEvents(ctx context.Context, input WaitInput) (WaitResult, error) {
	if strings.TrimSpace(input.RunID) == "" {
		return WaitResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	wanted := normalizeWaitEventTypes(input.EventTypes)
	if _, err := s.GetRun(ctx, input.RunID); err != nil {
		return WaitResult{}, err
	}

	// The cursor never moves during a wait: the first match ends the call.
	after := input.AfterEventID
	timedOut := WaitResult{Woke: false, NextEventID: after}

	pollCtx := ctx
	var stop context.CancelFunc = func() {}
	if input.Timeout > 0 {
		pollCtx, stop = context.WithTimeout(ctx, input.Timeout)
	}
	defer stop()

	// lookup queries for a matching event. done reports whether the caller
	// must return (res, err) immediately.
	lookup := func() (res WaitResult, done bool, err error) {
		events, nextID, found, findErr := s.findRunEventsAfter(pollCtx, input.RunID, after, wanted)
		switch {
		case findErr == nil && found:
			return WaitResult{Woke: true, NextEventID: nextID, Events: events}, true, nil
		case findErr == nil:
			return WaitResult{}, false, nil
		case isDeadlineExceeded(pollCtx):
			return timedOut, true, nil
		default:
			return WaitResult{}, true, findErr
		}
	}

	// pause sleeps for d via waitForNextPoll. done reports whether the caller
	// must return (res, err) immediately instead of polling again.
	pause := func(d time.Duration) (res WaitResult, done bool, err error) {
		ok, waitErr := waitForNextPoll(pollCtx, d)
		switch {
		case waitErr == nil && ok:
			return WaitResult{}, false, nil
		case waitErr == nil, errors.Is(waitErr, context.DeadlineExceeded):
			return timedOut, true, nil
		default:
			return WaitResult{}, true, waitErr
		}
	}

	for {
		if res, done, err := lookup(); done {
			return res, err
		}

		if _, err := s.ReconcileRun(pollCtx, input.RunID); err != nil {
			if isSQLiteBusyError(err) {
				// Busy database: back off briefly and restart the round.
				if res, done, waitErr := pause(25 * time.Millisecond); done {
					return res, waitErr
				}
				continue
			}
			if isDeadlineExceeded(pollCtx) {
				return timedOut, nil
			}
			return WaitResult{}, err
		}

		// Reconciling may have produced the event being waited for.
		if res, done, err := lookup(); done {
			return res, err
		}

		if res, done, err := pause(200 * time.Millisecond); done {
			return res, err
		}
	}
}
func listTasksForRun(ctx context.Context, db queryRowsContexter, runID string) ([]Task, error) {
rows, err := db.QueryContext(
ctx,
@@ -1090,6 +1207,55 @@ func listTasksForRun(ctx context.Context, db queryRowsContexter, runID string) (
return tasks, nil
}
// findRunEventsAfter returns the oldest event of run runID whose event_id is
// greater than afterEventID and, when eventTypes is non-empty, whose type is
// one of eventTypes.
//
// On a match it returns the event in a one-element slice, that event's ID,
// and found=true. When no row matches it returns (nil, 0, false, nil). Query
// and scan failures are wrapped and returned with found=false.
func (s *OrchStore) findRunEventsAfter(ctx context.Context, runID string, afterEventID int64, eventTypes []string) ([]RunEvent, int64, bool, error) {
	args := make([]any, 0, 2+len(eventTypes))
	args = append(args, runID, afterEventID)

	query := `SELECT
event_id, event_type, run_id, task_id, thread_id, summary, payload_json, created_at
FROM events
WHERE run_id = ?
AND event_id > ?`
	if len(eventTypes) > 0 {
		query += " AND event_type IN (" + placeholders(len(eventTypes)) + ")"
		for _, eventType := range eventTypes {
			args = append(args, eventType)
		}
	}
	query += " ORDER BY event_id ASC LIMIT 1"

	var (
		event     RunEvent
		threadID  sql.NullString
		payload   string
		createdAt string
	)
	err := s.db.QueryRowContext(ctx, query, args...).Scan(
		&event.EventID,
		&event.Type,
		&event.RunID,
		&event.TaskID,
		&threadID,
		&event.Summary,
		&payload,
		&createdAt,
	)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, 0, false, nil
	}
	if err != nil {
		return nil, 0, false, fmt.Errorf("query run events after %d: %w", afterEventID, err)
	}

	if threadID.Valid {
		event.ThreadID = threadID.String
	}
	// Leave Payload nil for an empty column so it encodes as JSON null. A
	// non-nil but empty RawMessage is not valid JSON and would make encoding
	// the whole event fail.
	if payload != "" {
		event.Payload = json.RawMessage(payload)
	}
	event.CreatedAt = parseTime(createdAt)

	return []RunEvent{event}, event.EventID, true, nil
}
// queryRowsContexter is the single-method query surface that helpers such as
// listTasksForRun accept instead of a concrete handle. Its QueryContext has
// the same signature as the one on *sql.DB and *sql.Tx.
type queryRowsContexter interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
@@ -1601,6 +1767,30 @@ func normalizePriority(priority string) (string, error) {
}
}
// normalizeWaitEventTypes cleans a list of event type names for use as a wait
// filter: each name is trimmed, blank names are dropped, and repeated names
// keep only their first occurrence, preserving input order. When nothing
// usable remains (or the input is empty) it returns the default set
// task_ready, task_blocked, task_done, task_failed.
func normalizeWaitEventTypes(eventTypes []string) []string {
	// A function rather than a shared slice, so every caller gets its own copy.
	fallback := func() []string {
		return []string{"task_ready", "task_blocked", "task_done", "task_failed"}
	}
	if len(eventTypes) == 0 {
		return fallback()
	}

	seen := make(map[string]struct{}, len(eventTypes))
	unique := make([]string, 0, len(eventTypes))
	for _, raw := range eventTypes {
		name := strings.TrimSpace(raw)
		_, dup := seen[name]
		if name == "" || dup {
			continue
		}
		seen[name] = struct{}{}
		unique = append(unique, name)
	}

	if len(unique) == 0 {
		return fallback()
	}
	return unique
}
func validateAndNormalizeJSONDefault(fieldName, value, defaultValue string) (string, error) {
normalized := strings.TrimSpace(value)
if normalized == "" {