Add orch control commands
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
package orch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"ai-workflow-skill/internal/protocol"
|
||||
"ai-workflow-skill/internal/store"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type cancelOptions struct {
|
||||
runID string
|
||||
taskID string
|
||||
reason string
|
||||
}
|
||||
|
||||
func newCancelCmd(root *rootOptions) *cobra.Command {
|
||||
opts := &cancelOptions{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "cancel",
|
||||
Short: "Cancel a task or an entire run",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
|
||||
sqlDB, err := openOrchDB(ctx, root.dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
|
||||
result, err := store.NewOrchStore(sqlDB).Cancel(ctx, store.CancelControlInput{
|
||||
RunID: opts.runID,
|
||||
TaskID: opts.taskID,
|
||||
Reason: opts.reason,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := protocol.Success{
|
||||
OK: true,
|
||||
Command: "cancel",
|
||||
Data: map[string]any{
|
||||
"run": result.Run,
|
||||
"cancelled_tasks": result.CancelledTasks,
|
||||
},
|
||||
}
|
||||
if root.json {
|
||||
return protocol.WriteJSON(cmd.OutOrStdout(), resp)
|
||||
}
|
||||
|
||||
if opts.taskID != "" {
|
||||
_, err = fmt.Fprintf(cmd.OutOrStdout(), "cancelled task %s in run %s\n", opts.taskID, opts.runID)
|
||||
return err
|
||||
}
|
||||
_, err = fmt.Fprintf(cmd.OutOrStdout(), "cancelled run %s (%d tasks)\n", opts.runID, len(result.CancelledTasks))
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&opts.runID, "run", "", "Run ID")
|
||||
cmd.Flags().StringVar(&opts.taskID, "task", "", "Optional task ID")
|
||||
cmd.Flags().StringVar(&opts.reason, "reason", "", "Cancellation reason")
|
||||
_ = cmd.MarkFlagRequired("run")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,84 @@
|
||||
package orch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"ai-workflow-skill/internal/protocol"
|
||||
"ai-workflow-skill/internal/store"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type cleanupOptions struct {
|
||||
runID string
|
||||
taskID string
|
||||
attemptNo int
|
||||
allCompleted bool
|
||||
force bool
|
||||
}
|
||||
|
||||
func newCleanupCmd(root *rootOptions) *cobra.Command {
|
||||
opts := &cleanupOptions{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "cleanup",
|
||||
Short: "Remove completed or abandoned attempt worktrees",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
|
||||
sqlDB, err := openOrchDB(ctx, root.dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
|
||||
s := store.NewOrchStore(sqlDB)
|
||||
candidates, err := s.ListCleanupCandidates(ctx, store.CleanupInput{
|
||||
RunID: opts.runID,
|
||||
TaskID: opts.taskID,
|
||||
AttemptNo: opts.attemptNo,
|
||||
AllCompleted: opts.allCompleted,
|
||||
Force: opts.force,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
records := make([]store.CleanupRecord, 0, len(candidates))
|
||||
for _, candidate := range candidates {
|
||||
if err := cleanupAttemptWorktree(ctx, candidate.Attempt, opts.force); err != nil {
|
||||
return err
|
||||
}
|
||||
records = append(records, store.CleanupRecord{Attempt: candidate.Attempt})
|
||||
}
|
||||
|
||||
cleaned, err := s.MarkAttemptsCleaned(ctx, records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := protocol.Success{
|
||||
OK: true,
|
||||
Command: "cleanup",
|
||||
Data: map[string]any{
|
||||
"cleaned": cleaned,
|
||||
},
|
||||
}
|
||||
if root.json {
|
||||
return protocol.WriteJSON(cmd.OutOrStdout(), resp)
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(cmd.OutOrStdout(), "cleaned %d worktrees\n", len(cleaned))
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&opts.runID, "run", "", "Run ID")
|
||||
cmd.Flags().StringVar(&opts.taskID, "task", "", "Optional task ID")
|
||||
cmd.Flags().IntVar(&opts.attemptNo, "attempt", 0, "Specific attempt number")
|
||||
cmd.Flags().BoolVar(&opts.allCompleted, "all-completed", false, "Clean all completed or abandoned worktrees in the run")
|
||||
cmd.Flags().BoolVar(&opts.force, "force", false, "Force cleanup even for non-terminal worktrees")
|
||||
_ = cmd.MarkFlagRequired("run")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -879,3 +879,424 @@ func TestOrchWaitTimesOutWithoutMatchingEvent(t *testing.T) {
|
||||
t.Fatalf("expected next_event_id 0 on timeout, got %#v", nextEventID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchRetryCreatesNewAttempt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "coord.db")
|
||||
repoPath := initGitRepo(t)
|
||||
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"run", "init",
|
||||
"--run", "run_blog_retry_001",
|
||||
"--goal", "Validate retry behavior",
|
||||
)
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"task", "add",
|
||||
"--run", "run_blog_retry_001",
|
||||
"--task", "T1",
|
||||
"--title", "Implement backend",
|
||||
"--default-to", "worker-a",
|
||||
)
|
||||
|
||||
dispatchOut := runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"dispatch",
|
||||
"--run", "run_blog_retry_001",
|
||||
"--task", "T1",
|
||||
"--repo-path", repoPath,
|
||||
"--workspace-root", ".orch/worktrees",
|
||||
"--strict-worktree",
|
||||
)
|
||||
|
||||
var dispatchResp map[string]any
|
||||
mustDecodeJSON(t, dispatchOut, &dispatchResp)
|
||||
threadID := nestedString(t, dispatchResp, "data", "attempt", "thread_id")
|
||||
firstWorktreePath := nestedString(t, dispatchResp, "data", "attempt", "worktree_path")
|
||||
|
||||
runInboxCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"claim",
|
||||
"--agent", "worker-a",
|
||||
"--thread", threadID,
|
||||
)
|
||||
runInboxCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"fail",
|
||||
"--agent", "worker-a",
|
||||
"--thread", threadID,
|
||||
"--summary", "Build failed",
|
||||
)
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"reconcile",
|
||||
"--run", "run_blog_retry_001",
|
||||
)
|
||||
|
||||
retryOut := runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"retry",
|
||||
"--run", "run_blog_retry_001",
|
||||
"--task", "T1",
|
||||
"--body", "Retry after fixing the failure.",
|
||||
)
|
||||
|
||||
var retryResp map[string]any
|
||||
mustDecodeJSON(t, retryOut, &retryResp)
|
||||
if got := nestedString(t, retryResp, "data", "task", "status"); got != "dispatched" {
|
||||
t.Fatalf("expected retried task to be dispatched, got %q", got)
|
||||
}
|
||||
if got := nestedValue(t, retryResp, "data", "attempt", "attempt_no").(float64); got != 2 {
|
||||
t.Fatalf("expected retry attempt 2, got %#v", got)
|
||||
}
|
||||
secondThreadID := nestedString(t, retryResp, "data", "attempt", "thread_id")
|
||||
if secondThreadID == threadID {
|
||||
t.Fatalf("expected retry to create a new thread, got same thread %q", secondThreadID)
|
||||
}
|
||||
secondWorktreePath := nestedString(t, retryResp, "data", "attempt", "worktree_path")
|
||||
if secondWorktreePath == firstWorktreePath {
|
||||
t.Fatalf("expected retry to create a new worktree, got reused path %q", secondWorktreePath)
|
||||
}
|
||||
if _, err := os.Stat(secondWorktreePath); err != nil {
|
||||
t.Fatalf("stat retry worktree %s: %v", secondWorktreePath, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchReassignCancelsOldThreadAndDispatchesNewAttempt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "coord.db")
|
||||
repoPath := initGitRepo(t)
|
||||
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"run", "init",
|
||||
"--run", "run_blog_reassign_001",
|
||||
"--goal", "Validate reassign behavior",
|
||||
)
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"task", "add",
|
||||
"--run", "run_blog_reassign_001",
|
||||
"--task", "T1",
|
||||
"--title", "Implement backend",
|
||||
"--default-to", "worker-a",
|
||||
)
|
||||
|
||||
dispatchOut := runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"dispatch",
|
||||
"--run", "run_blog_reassign_001",
|
||||
"--task", "T1",
|
||||
"--repo-path", repoPath,
|
||||
"--workspace-root", ".orch/worktrees",
|
||||
"--strict-worktree",
|
||||
)
|
||||
|
||||
var dispatchResp map[string]any
|
||||
mustDecodeJSON(t, dispatchOut, &dispatchResp)
|
||||
originalThreadID := nestedString(t, dispatchResp, "data", "attempt", "thread_id")
|
||||
|
||||
runInboxCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"claim",
|
||||
"--agent", "worker-a",
|
||||
"--thread", originalThreadID,
|
||||
)
|
||||
runInboxCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"update",
|
||||
"--agent", "worker-a",
|
||||
"--thread", originalThreadID,
|
||||
"--status", "blocked",
|
||||
"--summary", "Need product decision",
|
||||
"--payload-json", `{"question":"Proceed with v1 scope?"}`,
|
||||
)
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"reconcile",
|
||||
"--run", "run_blog_reassign_001",
|
||||
)
|
||||
|
||||
reassignOut := runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"reassign",
|
||||
"--run", "run_blog_reassign_001",
|
||||
"--task", "T1",
|
||||
"--to", "worker-b",
|
||||
"--reason", "Try another worker with clearer ownership.",
|
||||
)
|
||||
|
||||
var reassignResp map[string]any
|
||||
mustDecodeJSON(t, reassignOut, &reassignResp)
|
||||
if got := nestedString(t, reassignResp, "data", "attempt", "assigned_to"); got != "worker-b" {
|
||||
t.Fatalf("expected reassigned attempt to target worker-b, got %q", got)
|
||||
}
|
||||
if got := nestedValue(t, reassignResp, "data", "attempt", "attempt_no").(float64); got != 2 {
|
||||
t.Fatalf("expected reassign attempt 2, got %#v", got)
|
||||
}
|
||||
newThreadID := nestedString(t, reassignResp, "data", "attempt", "thread_id")
|
||||
if newThreadID == originalThreadID {
|
||||
t.Fatalf("expected reassignment to create a new thread, got %q", newThreadID)
|
||||
}
|
||||
|
||||
showOut := runInboxCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"show",
|
||||
"--thread", originalThreadID,
|
||||
)
|
||||
|
||||
var showResp map[string]any
|
||||
mustDecodeJSON(t, showOut, &showResp)
|
||||
if got := nestedString(t, showResp, "data", "thread", "status"); got != "cancelled" {
|
||||
t.Fatalf("expected old reassigned thread to be cancelled, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchCancelTaskAndRun(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "coord.db")
|
||||
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"run", "init",
|
||||
"--run", "run_blog_cancel_001",
|
||||
"--goal", "Validate cancel behavior",
|
||||
)
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"task", "add",
|
||||
"--run", "run_blog_cancel_001",
|
||||
"--task", "T1",
|
||||
"--title", "Implement backend",
|
||||
"--default-to", "worker-a",
|
||||
)
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"task", "add",
|
||||
"--run", "run_blog_cancel_001",
|
||||
"--task", "T2",
|
||||
"--title", "Implement frontend",
|
||||
"--default-to", "worker-b",
|
||||
)
|
||||
|
||||
dispatchOut := runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"dispatch",
|
||||
"--run", "run_blog_cancel_001",
|
||||
"--task", "T1",
|
||||
)
|
||||
|
||||
var dispatchResp map[string]any
|
||||
mustDecodeJSON(t, dispatchOut, &dispatchResp)
|
||||
threadID := nestedString(t, dispatchResp, "data", "attempt", "thread_id")
|
||||
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"cancel",
|
||||
"--run", "run_blog_cancel_001",
|
||||
"--task", "T1",
|
||||
"--reason", "Task is no longer needed.",
|
||||
)
|
||||
|
||||
statusOut := runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"status",
|
||||
"--run", "run_blog_cancel_001",
|
||||
)
|
||||
|
||||
var statusResp map[string]any
|
||||
mustDecodeJSON(t, statusOut, &statusResp)
|
||||
tasks := nestedArray(t, statusResp, "data", "tasks")
|
||||
taskStatuses := map[string]string{}
|
||||
for _, item := range tasks {
|
||||
task, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected task object, got %#v", item)
|
||||
}
|
||||
taskStatuses[task["task_id"].(string)] = task["status"].(string)
|
||||
}
|
||||
if taskStatuses["T1"] != "cancelled" {
|
||||
t.Fatalf("expected T1 cancelled, got %q", taskStatuses["T1"])
|
||||
}
|
||||
if taskStatuses["T2"] == "cancelled" {
|
||||
t.Fatalf("expected T2 to remain active before run cancel, got %q", taskStatuses["T2"])
|
||||
}
|
||||
|
||||
showOut := runInboxCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"show",
|
||||
"--thread", threadID,
|
||||
)
|
||||
|
||||
var showResp map[string]any
|
||||
mustDecodeJSON(t, showOut, &showResp)
|
||||
if got := nestedString(t, showResp, "data", "thread", "status"); got != "cancelled" {
|
||||
t.Fatalf("expected cancelled task thread to be cancelled, got %q", got)
|
||||
}
|
||||
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"cancel",
|
||||
"--run", "run_blog_cancel_001",
|
||||
"--reason", "Stop the run.",
|
||||
)
|
||||
|
||||
statusOut = runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"status",
|
||||
"--run", "run_blog_cancel_001",
|
||||
)
|
||||
mustDecodeJSON(t, statusOut, &statusResp)
|
||||
if got := nestedString(t, statusResp, "data", "run", "status"); got != "cancelled" {
|
||||
t.Fatalf("expected cancelled run, got %q", got)
|
||||
}
|
||||
tasks = nestedArray(t, statusResp, "data", "tasks")
|
||||
for _, item := range tasks {
|
||||
task, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("expected task object, got %#v", item)
|
||||
}
|
||||
if got, _ := task["status"].(string); got != "cancelled" {
|
||||
t.Fatalf("expected all tasks cancelled after run cancel, got %#v", task["status"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrchCleanupRemovesCompletedWorktree(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dbPath := filepath.Join(t.TempDir(), "coord.db")
|
||||
repoPath := initGitRepo(t)
|
||||
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"run", "init",
|
||||
"--run", "run_blog_cleanup_001",
|
||||
"--goal", "Validate cleanup behavior",
|
||||
)
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"task", "add",
|
||||
"--run", "run_blog_cleanup_001",
|
||||
"--task", "T1",
|
||||
"--title", "Implement backend",
|
||||
"--default-to", "worker-a",
|
||||
)
|
||||
|
||||
dispatchOut := runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"dispatch",
|
||||
"--run", "run_blog_cleanup_001",
|
||||
"--task", "T1",
|
||||
"--repo-path", repoPath,
|
||||
"--workspace-root", ".orch/worktrees",
|
||||
"--strict-worktree",
|
||||
)
|
||||
|
||||
var dispatchResp map[string]any
|
||||
mustDecodeJSON(t, dispatchOut, &dispatchResp)
|
||||
threadID := nestedString(t, dispatchResp, "data", "attempt", "thread_id")
|
||||
worktreePath := nestedString(t, dispatchResp, "data", "attempt", "worktree_path")
|
||||
|
||||
runInboxCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"claim",
|
||||
"--agent", "worker-a",
|
||||
"--thread", threadID,
|
||||
)
|
||||
runInboxCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"done",
|
||||
"--agent", "worker-a",
|
||||
"--thread", threadID,
|
||||
"--summary", "Backend complete",
|
||||
)
|
||||
runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"reconcile",
|
||||
"--run", "run_blog_cleanup_001",
|
||||
)
|
||||
|
||||
cleanupOut := runOrchCommand(
|
||||
t,
|
||||
"--db", dbPath,
|
||||
"--json",
|
||||
"cleanup",
|
||||
"--run", "run_blog_cleanup_001",
|
||||
"--task", "T1",
|
||||
)
|
||||
|
||||
var cleanupResp map[string]any
|
||||
mustDecodeJSON(t, cleanupOut, &cleanupResp)
|
||||
cleaned := nestedArray(t, cleanupResp, "data", "cleaned")
|
||||
if len(cleaned) != 1 {
|
||||
t.Fatalf("expected one cleaned attempt, got %#v", cleaned)
|
||||
}
|
||||
if _, err := os.Stat(worktreePath); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected cleaned worktree path to be removed, err=%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package orch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"ai-workflow-skill/internal/protocol"
|
||||
"ai-workflow-skill/internal/store"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type reassignOptions struct {
|
||||
runID string
|
||||
taskID string
|
||||
toAgent string
|
||||
reason string
|
||||
}
|
||||
|
||||
func newReassignCmd(root *rootOptions) *cobra.Command {
|
||||
opts := &reassignOptions{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "reassign",
|
||||
Short: "Reassign a blocked or failed task to another worker",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ctx := cmd.Context()
|
||||
|
||||
sqlDB, err := openOrchDB(ctx, root.dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
|
||||
s := store.NewOrchStore(sqlDB)
|
||||
task, attempt, err := s.GetTaskWithLatestAttempt(ctx, opts.runID, opts.taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := s.ReassignTask(ctx, store.ReassignInput{
|
||||
RunID: opts.runID,
|
||||
TaskID: opts.taskID,
|
||||
ToAgent: opts.toAgent,
|
||||
Reason: opts.reason,
|
||||
PrepareWorkspace: newAttemptReuseWorkspacePreparer(cmd, task, attempt),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := protocol.Success{
|
||||
OK: true,
|
||||
Command: "reassign",
|
||||
Data: map[string]any{
|
||||
"task": result.Task,
|
||||
"attempt": result.Attempt,
|
||||
"thread": result.Thread,
|
||||
"message": result.Message,
|
||||
"previous_attempt": result.PreviousAttempt,
|
||||
},
|
||||
}
|
||||
if root.json {
|
||||
return protocol.WriteJSON(cmd.OutOrStdout(), resp)
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(cmd.OutOrStdout(), "reassigned task %s to %s as attempt %d\n", result.Task.TaskID, result.Attempt.AssignedTo, result.Attempt.AttemptNo)
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&opts.runID, "run", "", "Run ID")
|
||||
cmd.Flags().StringVar(&opts.taskID, "task", "", "Task ID")
|
||||
cmd.Flags().StringVar(&opts.toAgent, "to", "", "Destination worker agent")
|
||||
cmd.Flags().StringVar(&opts.reason, "reason", "", "Reason for reassignment")
|
||||
_ = cmd.MarkFlagRequired("run")
|
||||
_ = cmd.MarkFlagRequired("task")
|
||||
_ = cmd.MarkFlagRequired("to")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package orch
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"ai-workflow-skill/internal/protocol"
|
||||
"ai-workflow-skill/internal/store"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
type retryOptions struct {
|
||||
runID string
|
||||
taskID string
|
||||
toAgent string
|
||||
body string
|
||||
bodyFile string
|
||||
}
|
||||
|
||||
func newRetryCmd(root *rootOptions) *cobra.Command {
|
||||
opts := &retryOptions{}
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "retry",
|
||||
Short: "Retry a failed task by creating a new attempt",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
body, err := resolveBodyValue(opts.body, opts.bodyFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := cmd.Context()
|
||||
sqlDB, err := openOrchDB(ctx, root.dbPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer sqlDB.Close()
|
||||
|
||||
s := store.NewOrchStore(sqlDB)
|
||||
task, attempt, err := s.GetTaskWithLatestAttempt(ctx, opts.runID, opts.taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
result, err := s.RetryTask(ctx, store.RetryInput{
|
||||
RunID: opts.runID,
|
||||
TaskID: opts.taskID,
|
||||
ToAgent: opts.toAgent,
|
||||
Body: body,
|
||||
PrepareWorkspace: newAttemptReuseWorkspacePreparer(cmd, task, attempt),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp := protocol.Success{
|
||||
OK: true,
|
||||
Command: "retry",
|
||||
Data: map[string]any{
|
||||
"task": result.Task,
|
||||
"attempt": result.Attempt,
|
||||
"thread": result.Thread,
|
||||
"message": result.Message,
|
||||
"previous_attempt": result.PreviousAttempt,
|
||||
},
|
||||
}
|
||||
if root.json {
|
||||
return protocol.WriteJSON(cmd.OutOrStdout(), resp)
|
||||
}
|
||||
|
||||
_, err = fmt.Fprintf(cmd.OutOrStdout(), "retried task %s as attempt %d\n", result.Task.TaskID, result.Attempt.AttemptNo)
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&opts.runID, "run", "", "Run ID")
|
||||
cmd.Flags().StringVar(&opts.taskID, "task", "", "Task ID")
|
||||
cmd.Flags().StringVar(&opts.toAgent, "to", "", "Optional worker agent override")
|
||||
cmd.Flags().StringVar(&opts.body, "body", "", "Retry instruction body")
|
||||
cmd.Flags().StringVar(&opts.bodyFile, "body-file", "", "Read retry instruction body from file")
|
||||
_ = cmd.MarkFlagRequired("run")
|
||||
_ = cmd.MarkFlagRequired("task")
|
||||
|
||||
return cmd
|
||||
}
|
||||
@@ -29,6 +29,10 @@ func NewRootCmd() *cobra.Command {
|
||||
cmd.AddCommand(newDispatchCmd(opts))
|
||||
cmd.AddCommand(newReconcileCmd(opts))
|
||||
cmd.AddCommand(newWaitCmd(opts))
|
||||
cmd.AddCommand(newRetryCmd(opts))
|
||||
cmd.AddCommand(newReassignCmd(opts))
|
||||
cmd.AddCommand(newCancelCmd(opts))
|
||||
cmd.AddCommand(newCleanupCmd(opts))
|
||||
cmd.AddCommand(newBlockedCmd(opts))
|
||||
cmd.AddCommand(newAnswerCmd(opts))
|
||||
cmd.AddCommand(newStatusCmd(opts))
|
||||
|
||||
@@ -26,6 +26,31 @@ func newDispatchWorkspacePreparer(cmd *cobra.Command, opts dispatchOptions) stor
|
||||
}
|
||||
}
|
||||
|
||||
func newAttemptReuseWorkspacePreparer(cmd *cobra.Command, task store.Task, attempt *store.TaskAttempt) store.DispatchWorkspacePreparer {
|
||||
if attempt == nil || attempt.WorktreePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
workspaceRoot, ok := deriveWorkspaceRootFromAttempt(task.RunID, task.TaskID, attempt.WorktreePath)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
baseRef := attempt.BaseRef
|
||||
if strings.TrimSpace(baseRef) == "" {
|
||||
baseRef = attempt.BaseCommit
|
||||
}
|
||||
|
||||
opts := dispatchOptions{
|
||||
repoPath: attempt.WorktreePath,
|
||||
workspaceRoot: workspaceRoot,
|
||||
strictWorktree: true,
|
||||
baseRef: baseRef,
|
||||
}
|
||||
|
||||
return newDispatchWorkspacePreparer(cmd, opts)
|
||||
}
|
||||
|
||||
func dispatchUsesWorktree(opts dispatchOptions) bool {
|
||||
return strings.TrimSpace(opts.repoPath) != "" ||
|
||||
strings.TrimSpace(opts.workspaceRoot) != "" ||
|
||||
@@ -94,11 +119,15 @@ func resolveRepoRoot(ctx context.Context, repoPath string) (string, error) {
|
||||
return "", fmt.Errorf("resolve repo path: %w", err)
|
||||
}
|
||||
|
||||
stdout, _, err := runGit(ctx, absPath, "rev-parse", "--show-toplevel")
|
||||
if _, _, err := runGit(ctx, absPath, "rev-parse", "--show-toplevel"); err != nil {
|
||||
return "", protocol.InvalidInput("repo-path must point to a Git worktree", err)
|
||||
}
|
||||
|
||||
commonDir, err := resolveCommonGitDir(ctx, absPath)
|
||||
if err != nil {
|
||||
return "", protocol.InvalidInput("repo-path must point to a Git worktree", err)
|
||||
}
|
||||
return strings.TrimSpace(stdout), nil
|
||||
return filepath.Dir(commonDir), nil
|
||||
}
|
||||
|
||||
func resolveDispatchBase(ctx context.Context, repoRoot, workspaceRoot, requestedBaseRef string, strict bool) (string, string, error) {
|
||||
@@ -242,6 +271,27 @@ func buildAttemptWorktreePath(workspaceRoot, runID, taskID string, attemptNo int
|
||||
)
|
||||
}
|
||||
|
||||
func deriveWorkspaceRootFromAttempt(runID, taskID, worktreePath string) (string, bool) {
|
||||
suffix := filepath.Join(
|
||||
sanitizePathSegment(runID),
|
||||
sanitizePathSegment(taskID),
|
||||
filepath.Base(worktreePath),
|
||||
)
|
||||
parent := filepath.Dir(worktreePath)
|
||||
if filepath.Base(parent) != sanitizePathSegment(taskID) {
|
||||
return "", false
|
||||
}
|
||||
runDir := filepath.Dir(parent)
|
||||
if filepath.Base(runDir) != sanitizePathSegment(runID) {
|
||||
return "", false
|
||||
}
|
||||
root := filepath.Dir(runDir)
|
||||
if filepath.Clean(filepath.Join(root, suffix)) != filepath.Clean(worktreePath) {
|
||||
return "", false
|
||||
}
|
||||
return root, true
|
||||
}
|
||||
|
||||
func sanitizeGitSegment(value string) string {
|
||||
return sanitizeSegment(value)
|
||||
}
|
||||
@@ -301,3 +351,59 @@ func runGit(ctx context.Context, repoRoot string, args ...string) (string, strin
|
||||
}
|
||||
return "", message, fmt.Errorf("git %s: %s", strings.Join(args, " "), message)
|
||||
}
|
||||
|
||||
func cleanupAttemptWorktree(ctx context.Context, attempt store.TaskAttempt, force bool) error {
|
||||
if strings.TrimSpace(attempt.WorktreePath) == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := os.Stat(attempt.WorktreePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("stat worktree path: %w", err)
|
||||
}
|
||||
|
||||
repoRoot, err := resolveRepoRootFromExistingWorktree(ctx, attempt.WorktreePath)
|
||||
if err != nil {
|
||||
if force {
|
||||
return os.RemoveAll(attempt.WorktreePath)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
args := []string{"worktree", "remove"}
|
||||
if force {
|
||||
args = append(args, "--force")
|
||||
}
|
||||
args = append(args, attempt.WorktreePath)
|
||||
if _, _, err := runGit(ctx, repoRoot, args...); err != nil {
|
||||
if force {
|
||||
return os.RemoveAll(attempt.WorktreePath)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveRepoRootFromExistingWorktree(ctx context.Context, worktreePath string) (string, error) {
|
||||
commonDir, err := resolveCommonGitDir(ctx, worktreePath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Dir(commonDir), nil
|
||||
}
|
||||
|
||||
func resolveCommonGitDir(ctx context.Context, repoPath string) (string, error) {
|
||||
stdout, _, err := runGit(ctx, repoPath, "rev-parse", "--path-format=absolute", "--git-common-dir")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
commonDir := strings.TrimSpace(stdout)
|
||||
if !filepath.IsAbs(commonDir) {
|
||||
commonDir = filepath.Join(repoPath, commonDir)
|
||||
}
|
||||
return filepath.Clean(commonDir), nil
|
||||
}
|
||||
|
||||
+715
-29
@@ -9,6 +9,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"ai-workflow-skill/internal/protocol"
|
||||
)
|
||||
|
||||
var ErrRunNotFound = errors.New("run not found")
|
||||
@@ -173,6 +175,65 @@ type AnswerResult struct {
|
||||
Message Message `json:"message"`
|
||||
}
|
||||
|
||||
type RetryInput struct {
|
||||
RunID string
|
||||
TaskID string
|
||||
ToAgent string
|
||||
Body string
|
||||
PrepareWorkspace DispatchWorkspacePreparer
|
||||
}
|
||||
|
||||
type RetryResult struct {
|
||||
Task Task `json:"task"`
|
||||
Attempt TaskAttempt `json:"attempt"`
|
||||
Thread Thread `json:"thread"`
|
||||
Message Message `json:"message"`
|
||||
PreviousAttempt TaskAttempt `json:"previous_attempt"`
|
||||
}
|
||||
|
||||
type ReassignInput struct {
|
||||
RunID string
|
||||
TaskID string
|
||||
ToAgent string
|
||||
Reason string
|
||||
PrepareWorkspace DispatchWorkspacePreparer
|
||||
}
|
||||
|
||||
type ReassignResult struct {
|
||||
Task Task `json:"task"`
|
||||
Attempt TaskAttempt `json:"attempt"`
|
||||
Thread Thread `json:"thread"`
|
||||
Message Message `json:"message"`
|
||||
PreviousAttempt TaskAttempt `json:"previous_attempt"`
|
||||
}
|
||||
|
||||
type CancelControlInput struct {
|
||||
RunID string
|
||||
TaskID string
|
||||
Reason string
|
||||
}
|
||||
|
||||
type CancelResult struct {
|
||||
Run Run `json:"run"`
|
||||
CancelledTasks []Task `json:"cancelled_tasks"`
|
||||
}
|
||||
|
||||
type CleanupInput struct {
|
||||
RunID string
|
||||
TaskID string
|
||||
AttemptNo int
|
||||
AllCompleted bool
|
||||
Force bool
|
||||
}
|
||||
|
||||
type CleanupCandidate struct {
|
||||
Attempt TaskAttempt `json:"attempt"`
|
||||
}
|
||||
|
||||
type CleanupRecord struct {
|
||||
Attempt TaskAttempt `json:"attempt"`
|
||||
}
|
||||
|
||||
func NewOrchStore(db *sql.DB) *OrchStore {
|
||||
return &OrchStore{db: db}
|
||||
}
|
||||
@@ -502,30 +563,542 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
return DispatchResult{}, fmt.Errorf("%w: task %s is not ready for dispatch", ErrInvalidState, task.TaskID)
|
||||
}
|
||||
|
||||
assignedTo := defaultString(strings.TrimSpace(input.ToAgent), task.DefaultTo)
|
||||
result, finalizeWorkspace, err := s.dispatchTaskTx(ctx, tx, task, strings.TrimSpace(input.ToAgent), input.Body, strings.TrimSpace(input.BaseRef), input.PrepareWorkspace, now)
|
||||
if err != nil {
|
||||
return DispatchResult{}, err
|
||||
}
|
||||
workspaceCommitted := false
|
||||
defer func() {
|
||||
finalizeWorkspace(workspaceCommitted)
|
||||
}()
|
||||
|
||||
if err := updateRunAggregateStatus(ctx, tx, task.RunID, now); err != nil {
|
||||
return DispatchResult{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return DispatchResult{}, fmt.Errorf("commit dispatch transaction: %w", err)
|
||||
}
|
||||
workspaceCommitted = true
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) GetTaskWithLatestAttempt(ctx context.Context, runID, taskID string) (Task, *TaskAttempt, error) {
|
||||
task, err := selectTask(ctx, s.db, runID, taskID)
|
||||
if err != nil {
|
||||
return Task{}, nil, err
|
||||
}
|
||||
if task.LatestAttemptNo == 0 {
|
||||
return task, nil, nil
|
||||
}
|
||||
|
||||
attempt, err := selectAttempt(ctx, s.db, runID, taskID, task.LatestAttemptNo)
|
||||
if err != nil {
|
||||
return Task{}, nil, err
|
||||
}
|
||||
return task, &attempt, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) RetryTask(ctx context.Context, input RetryInput) (RetryResult, error) {
|
||||
if strings.TrimSpace(input.RunID) == "" {
|
||||
return RetryResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
||||
}
|
||||
if strings.TrimSpace(input.TaskID) == "" {
|
||||
return RetryResult{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return RetryResult{}, fmt.Errorf("begin retry transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := selectRun(ctx, tx, input.RunID); err != nil {
|
||||
return RetryResult{}, err
|
||||
}
|
||||
|
||||
task, err := selectTask(ctx, tx, input.RunID, input.TaskID)
|
||||
if err != nil {
|
||||
return RetryResult{}, err
|
||||
}
|
||||
if task.Status != "failed" {
|
||||
return RetryResult{}, fmt.Errorf("%w: task %s is not failed", ErrInvalidState, task.TaskID)
|
||||
}
|
||||
if task.LatestAttemptNo == 0 {
|
||||
return RetryResult{}, fmt.Errorf("%w: task %s has no attempt to retry", ErrInvalidState, task.TaskID)
|
||||
}
|
||||
|
||||
previousAttempt, err := selectAttempt(ctx, tx, task.RunID, task.TaskID, task.LatestAttemptNo)
|
||||
if err != nil {
|
||||
return RetryResult{}, err
|
||||
}
|
||||
|
||||
result, finalizeWorkspace, err := s.dispatchTaskTx(
|
||||
ctx,
|
||||
tx,
|
||||
task,
|
||||
strings.TrimSpace(input.ToAgent),
|
||||
input.Body,
|
||||
defaultString(previousAttempt.BaseRef, previousAttempt.BaseCommit),
|
||||
input.PrepareWorkspace,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return RetryResult{}, err
|
||||
}
|
||||
workspaceCommitted := false
|
||||
defer func() {
|
||||
finalizeWorkspace(workspaceCommitted)
|
||||
}()
|
||||
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE task_attempts
|
||||
SET workspace_status = CASE
|
||||
WHEN workspace_status = 'cleaned' THEN workspace_status
|
||||
ELSE ?
|
||||
END,
|
||||
updated_at = ?
|
||||
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
|
||||
"abandoned",
|
||||
formatTime(now),
|
||||
previousAttempt.RunID,
|
||||
previousAttempt.TaskID,
|
||||
previousAttempt.AttemptNo,
|
||||
)
|
||||
if err != nil {
|
||||
return RetryResult{}, fmt.Errorf("mark previous retry attempt abandoned: %w", err)
|
||||
}
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: task.RunID,
|
||||
TaskID: task.TaskID,
|
||||
ThreadID: result.Thread.ThreadID,
|
||||
Source: "orch",
|
||||
EventType: "task_retried",
|
||||
MessageID: result.Message.MessageID,
|
||||
Summary: result.Message.Summary,
|
||||
PayloadJSON: marshalJSON(map[string]any{
|
||||
"previous_attempt_no": previousAttempt.AttemptNo,
|
||||
"previous_thread_id": previousAttempt.ThreadID,
|
||||
"attempt_no": result.Attempt.AttemptNo,
|
||||
"thread_id": result.Attempt.ThreadID,
|
||||
}),
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return RetryResult{}, err
|
||||
}
|
||||
|
||||
if err := updateRunAggregateStatus(ctx, tx, task.RunID, now); err != nil {
|
||||
return RetryResult{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return RetryResult{}, fmt.Errorf("commit retry transaction: %w", err)
|
||||
}
|
||||
workspaceCommitted = true
|
||||
|
||||
return RetryResult{
|
||||
Task: result.Task,
|
||||
Attempt: result.Attempt,
|
||||
Thread: result.Thread,
|
||||
Message: result.Message,
|
||||
PreviousAttempt: previousAttempt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) ReassignTask(ctx context.Context, input ReassignInput) (ReassignResult, error) {
|
||||
if strings.TrimSpace(input.RunID) == "" {
|
||||
return ReassignResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
||||
}
|
||||
if strings.TrimSpace(input.TaskID) == "" {
|
||||
return ReassignResult{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
|
||||
}
|
||||
if strings.TrimSpace(input.ToAgent) == "" {
|
||||
return ReassignResult{}, fmt.Errorf("%w: destination agent is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return ReassignResult{}, fmt.Errorf("begin reassign transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := selectRun(ctx, tx, input.RunID); err != nil {
|
||||
return ReassignResult{}, err
|
||||
}
|
||||
|
||||
task, err := selectTask(ctx, tx, input.RunID, input.TaskID)
|
||||
if err != nil {
|
||||
return ReassignResult{}, err
|
||||
}
|
||||
if task.Status != "blocked" && task.Status != "failed" {
|
||||
return ReassignResult{}, fmt.Errorf("%w: task %s is not blocked or failed", ErrInvalidState, task.TaskID)
|
||||
}
|
||||
if task.LatestAttemptNo == 0 {
|
||||
return ReassignResult{}, fmt.Errorf("%w: task %s has no attempt to reassign", ErrInvalidState, task.TaskID)
|
||||
}
|
||||
|
||||
previousAttempt, err := selectAttempt(ctx, tx, task.RunID, task.TaskID, task.LatestAttemptNo)
|
||||
if err != nil {
|
||||
return ReassignResult{}, err
|
||||
}
|
||||
|
||||
if task.Status == "blocked" && previousAttempt.ThreadID != "" {
|
||||
thread, err := selectThread(ctx, tx, previousAttempt.ThreadID)
|
||||
if err != nil && !errors.Is(err, ErrThreadNotFound) {
|
||||
return ReassignResult{}, err
|
||||
}
|
||||
if err == nil && !isTerminalStatus(thread.Status) {
|
||||
if err := cancelThreadTx(ctx, tx, thread, defaultString(input.Reason, "task reassigned"), now); err != nil {
|
||||
return ReassignResult{}, err
|
||||
}
|
||||
}
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE task_attempts
|
||||
SET status = ?, workspace_status = CASE
|
||||
WHEN workspace_status = 'cleaned' THEN workspace_status
|
||||
ELSE ?
|
||||
END,
|
||||
updated_at = ?
|
||||
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
|
||||
"cancelled",
|
||||
"abandoned",
|
||||
formatTime(now),
|
||||
previousAttempt.RunID,
|
||||
previousAttempt.TaskID,
|
||||
previousAttempt.AttemptNo,
|
||||
)
|
||||
if err != nil {
|
||||
return ReassignResult{}, fmt.Errorf("mark previous blocked attempt abandoned: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE task_attempts
|
||||
SET workspace_status = CASE
|
||||
WHEN workspace_status = 'cleaned' THEN workspace_status
|
||||
ELSE ?
|
||||
END,
|
||||
updated_at = ?
|
||||
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
|
||||
"abandoned",
|
||||
formatTime(now),
|
||||
previousAttempt.RunID,
|
||||
previousAttempt.TaskID,
|
||||
previousAttempt.AttemptNo,
|
||||
)
|
||||
if err != nil {
|
||||
return ReassignResult{}, fmt.Errorf("mark previous attempt abandoned: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
result, finalizeWorkspace, err := s.dispatchTaskTx(
|
||||
ctx,
|
||||
tx,
|
||||
task,
|
||||
strings.TrimSpace(input.ToAgent),
|
||||
input.Reason,
|
||||
defaultString(previousAttempt.BaseRef, previousAttempt.BaseCommit),
|
||||
input.PrepareWorkspace,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return ReassignResult{}, err
|
||||
}
|
||||
workspaceCommitted := false
|
||||
defer func() {
|
||||
finalizeWorkspace(workspaceCommitted)
|
||||
}()
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: task.RunID,
|
||||
TaskID: task.TaskID,
|
||||
ThreadID: result.Thread.ThreadID,
|
||||
Source: "orch",
|
||||
EventType: "task_reassigned",
|
||||
MessageID: result.Message.MessageID,
|
||||
Summary: defaultString(input.Reason, result.Message.Summary),
|
||||
PayloadJSON: marshalJSON(map[string]any{
|
||||
"previous_attempt_no": previousAttempt.AttemptNo,
|
||||
"previous_thread_id": previousAttempt.ThreadID,
|
||||
"from_agent": previousAttempt.AssignedTo,
|
||||
"to_agent": result.Attempt.AssignedTo,
|
||||
"attempt_no": result.Attempt.AttemptNo,
|
||||
"thread_id": result.Attempt.ThreadID,
|
||||
}),
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return ReassignResult{}, err
|
||||
}
|
||||
|
||||
if err := updateRunAggregateStatus(ctx, tx, task.RunID, now); err != nil {
|
||||
return ReassignResult{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return ReassignResult{}, fmt.Errorf("commit reassign transaction: %w", err)
|
||||
}
|
||||
workspaceCommitted = true
|
||||
|
||||
return ReassignResult{
|
||||
Task: result.Task,
|
||||
Attempt: result.Attempt,
|
||||
Thread: result.Thread,
|
||||
Message: result.Message,
|
||||
PreviousAttempt: previousAttempt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) Cancel(ctx context.Context, input CancelControlInput) (CancelResult, error) {
|
||||
if strings.TrimSpace(input.RunID) == "" {
|
||||
return CancelResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return CancelResult{}, fmt.Errorf("begin cancel transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
run, err := selectRun(ctx, tx, input.RunID)
|
||||
if err != nil {
|
||||
return CancelResult{}, err
|
||||
}
|
||||
|
||||
var tasks []Task
|
||||
if strings.TrimSpace(input.TaskID) != "" {
|
||||
task, err := selectTask(ctx, tx, input.RunID, input.TaskID)
|
||||
if err != nil {
|
||||
return CancelResult{}, err
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
} else {
|
||||
tasks, err = listTasksForRun(ctx, tx, input.RunID)
|
||||
if err != nil {
|
||||
return CancelResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
cancelledTasks := make([]Task, 0, len(tasks))
|
||||
for _, task := range tasks {
|
||||
if task.Status == "cancelled" {
|
||||
if strings.TrimSpace(input.TaskID) != "" {
|
||||
return CancelResult{}, fmt.Errorf("%w: task %s is already cancelled", ErrInvalidState, task.TaskID)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
cancelledTask, err := cancelTaskTx(ctx, tx, task, defaultString(input.Reason, "task cancelled"), now)
|
||||
if err != nil {
|
||||
return CancelResult{}, err
|
||||
}
|
||||
cancelledTasks = append(cancelledTasks, cancelledTask)
|
||||
}
|
||||
|
||||
if len(cancelledTasks) == 0 && len(tasks) == 0 {
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE runs SET status = ?, updated_at = ? WHERE run_id = ?`,
|
||||
"cancelled",
|
||||
formatTime(now),
|
||||
run.RunID,
|
||||
)
|
||||
if err != nil {
|
||||
return CancelResult{}, fmt.Errorf("cancel empty run: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: run.RunID,
|
||||
Source: "orch",
|
||||
EventType: "run_cancelled",
|
||||
Summary: defaultString(input.Reason, "run cancelled"),
|
||||
PayloadJSON: marshalJSON(map[string]any{
|
||||
"task_id": input.TaskID,
|
||||
"reason": input.Reason,
|
||||
}),
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return CancelResult{}, err
|
||||
}
|
||||
|
||||
if err := updateRunAggregateStatus(ctx, tx, run.RunID, now); err != nil {
|
||||
return CancelResult{}, err
|
||||
}
|
||||
|
||||
run, err = selectRun(ctx, tx, run.RunID)
|
||||
if err != nil {
|
||||
return CancelResult{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return CancelResult{}, fmt.Errorf("commit cancel transaction: %w", err)
|
||||
}
|
||||
|
||||
return CancelResult{
|
||||
Run: run,
|
||||
CancelledTasks: cancelledTasks,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) ListCleanupCandidates(ctx context.Context, input CleanupInput) ([]CleanupCandidate, error) {
|
||||
if strings.TrimSpace(input.RunID) == "" {
|
||||
return nil, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
||||
}
|
||||
if input.AttemptNo > 0 && strings.TrimSpace(input.TaskID) == "" {
|
||||
return nil, fmt.Errorf("%w: task id is required when attempt is specified", ErrInvalidInput)
|
||||
}
|
||||
if !input.AllCompleted && strings.TrimSpace(input.TaskID) == "" && input.AttemptNo == 0 {
|
||||
return nil, fmt.Errorf("%w: specify --task, --attempt, or --all-completed", ErrInvalidInput)
|
||||
}
|
||||
|
||||
if _, err := s.GetRun(ctx, input.RunID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conditions := []string{"run_id = ?", "worktree_path <> ''", "workspace_status <> 'cleaned'"}
|
||||
args := []any{input.RunID}
|
||||
if strings.TrimSpace(input.TaskID) != "" {
|
||||
conditions = append(conditions, "task_id = ?")
|
||||
args = append(args, strings.TrimSpace(input.TaskID))
|
||||
}
|
||||
if input.AttemptNo > 0 {
|
||||
conditions = append(conditions, "attempt_no = ?")
|
||||
args = append(args, input.AttemptNo)
|
||||
}
|
||||
if !input.Force {
|
||||
conditions = append(conditions, "workspace_status IN (?, ?)")
|
||||
args = append(args, "completed", "abandoned")
|
||||
}
|
||||
|
||||
query := `SELECT
|
||||
run_id, task_id, attempt_no, assigned_to, thread_id, base_ref, base_commit,
|
||||
branch_name, worktree_path, workspace_status, result_commit, status,
|
||||
created_at, updated_at
|
||||
FROM task_attempts
|
||||
WHERE ` + strings.Join(conditions, " AND ") + `
|
||||
ORDER BY run_id, task_id, attempt_no ASC`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query cleanup candidates: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var candidates []CleanupCandidate
|
||||
for rows.Next() {
|
||||
attempt, err := scanAttempt(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
candidates = append(candidates, CleanupCandidate{Attempt: attempt})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate cleanup candidates: %w", err)
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, protocol.NoMatchingWork("no cleanup candidates matched the requested filters")
|
||||
}
|
||||
return candidates, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) MarkAttemptsCleaned(ctx context.Context, records []CleanupRecord) ([]TaskAttempt, error) {
|
||||
if len(records) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin cleanup commit transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
cleaned := make([]TaskAttempt, 0, len(records))
|
||||
for _, record := range records {
|
||||
attempt := record.Attempt
|
||||
_, err := tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE task_attempts
|
||||
SET workspace_status = ?, updated_at = ?
|
||||
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
|
||||
"cleaned",
|
||||
formatTime(now),
|
||||
attempt.RunID,
|
||||
attempt.TaskID,
|
||||
attempt.AttemptNo,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("mark attempt cleaned: %w", err)
|
||||
}
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: attempt.RunID,
|
||||
TaskID: attempt.TaskID,
|
||||
ThreadID: attempt.ThreadID,
|
||||
Source: "orch",
|
||||
EventType: "workspace_cleaned",
|
||||
Summary: fmt.Sprintf("cleaned workspace for %s/%s attempt %d", attempt.RunID, attempt.TaskID, attempt.AttemptNo),
|
||||
PayloadJSON: marshalJSON(map[string]any{
|
||||
"attempt_no": attempt.AttemptNo,
|
||||
"worktree_path": attempt.WorktreePath,
|
||||
}),
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
attempt.WorkspaceStatus = "cleaned"
|
||||
attempt.UpdatedAt = now
|
||||
cleaned = append(cleaned, attempt)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit cleanup transaction: %w", err)
|
||||
}
|
||||
|
||||
return cleaned, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) dispatchTaskTx(
|
||||
ctx context.Context,
|
||||
tx *sql.Tx,
|
||||
task Task,
|
||||
toAgent string,
|
||||
body string,
|
||||
baseRef string,
|
||||
prepareWorkspace DispatchWorkspacePreparer,
|
||||
now time.Time,
|
||||
) (DispatchResult, func(bool), error) {
|
||||
assignedTo := defaultString(strings.TrimSpace(toAgent), task.DefaultTo)
|
||||
if assignedTo == "" {
|
||||
return DispatchResult{}, fmt.Errorf("%w: dispatch target agent is required", ErrInvalidInput)
|
||||
return DispatchResult{}, nil, fmt.Errorf("%w: dispatch target agent is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
attemptNo := task.LatestAttemptNo + 1
|
||||
workspace := DispatchWorkspace{
|
||||
BaseRef: strings.TrimSpace(input.BaseRef),
|
||||
BaseRef: strings.TrimSpace(baseRef),
|
||||
}
|
||||
cleanupWorkspace := func() {}
|
||||
workspaceCommitted := false
|
||||
if input.PrepareWorkspace != nil {
|
||||
workspace, cleanupWorkspace, err = input.PrepareWorkspace(task, attemptNo)
|
||||
finalizeWorkspace := func(success bool) {}
|
||||
if prepareWorkspace != nil {
|
||||
cleanupWorkspace := func() {}
|
||||
var err error
|
||||
workspace, cleanupWorkspace, err = prepareWorkspace(task, attemptNo)
|
||||
if err != nil {
|
||||
return DispatchResult{}, err
|
||||
return DispatchResult{}, nil, err
|
||||
}
|
||||
if cleanupWorkspace == nil {
|
||||
cleanupWorkspace = func() {}
|
||||
}
|
||||
defer func() {
|
||||
if !workspaceCommitted {
|
||||
finalizeWorkspace = func(success bool) {
|
||||
if !success {
|
||||
cleanupWorkspace()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
threadID := newID("thr")
|
||||
@@ -545,7 +1118,7 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(
|
||||
_, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO threads (
|
||||
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
|
||||
@@ -564,7 +1137,7 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
formatTime(thread.UpdatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return DispatchResult{}, fmt.Errorf("insert dispatch thread: %w", err)
|
||||
return DispatchResult{}, finalizeWorkspace, fmt.Errorf("insert dispatch thread: %w", err)
|
||||
}
|
||||
|
||||
message := Message{
|
||||
@@ -574,12 +1147,12 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
ToAgent: assignedTo,
|
||||
Kind: "task",
|
||||
Summary: defaultString(task.Summary, task.Title),
|
||||
Body: input.Body,
|
||||
Body: body,
|
||||
PayloadJSON: json.RawMessage(payloadJSON),
|
||||
CreatedAt: now,
|
||||
}
|
||||
if err := insertMessage(ctx, tx, message); err != nil {
|
||||
return DispatchResult{}, err
|
||||
return DispatchResult{}, finalizeWorkspace, err
|
||||
}
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: thread.RunID,
|
||||
@@ -592,7 +1165,7 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
PayloadJSON: payloadJSON,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return DispatchResult{}, err
|
||||
return DispatchResult{}, finalizeWorkspace, err
|
||||
}
|
||||
|
||||
attempt := TaskAttempt{
|
||||
@@ -633,7 +1206,7 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
formatTime(attempt.UpdatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return DispatchResult{}, fmt.Errorf("insert task attempt: %w", err)
|
||||
return DispatchResult{}, finalizeWorkspace, fmt.Errorf("insert task attempt: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(
|
||||
@@ -648,7 +1221,7 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
task.TaskID,
|
||||
)
|
||||
if err != nil {
|
||||
return DispatchResult{}, fmt.Errorf("update task dispatch status: %w", err)
|
||||
return DispatchResult{}, finalizeWorkspace, fmt.Errorf("update task dispatch status: %w", err)
|
||||
}
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
@@ -662,18 +1235,9 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
PayloadJSON: payloadJSON,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return DispatchResult{}, err
|
||||
return DispatchResult{}, finalizeWorkspace, err
|
||||
}
|
||||
|
||||
if err := updateRunAggregateStatus(ctx, tx, task.RunID, now); err != nil {
|
||||
return DispatchResult{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return DispatchResult{}, fmt.Errorf("commit dispatch transaction: %w", err)
|
||||
}
|
||||
workspaceCommitted = true
|
||||
|
||||
task.Status = "dispatched"
|
||||
task.LatestAttemptNo = attempt.AttemptNo
|
||||
task.UpdatedAt = now
|
||||
@@ -683,7 +1247,129 @@ func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (Disp
|
||||
Attempt: attempt,
|
||||
Thread: thread,
|
||||
Message: message,
|
||||
}, nil
|
||||
}, finalizeWorkspace, nil
|
||||
}
|
||||
|
||||
func cancelTaskTx(ctx context.Context, tx *sql.Tx, task Task, reason string, now time.Time) (Task, error) {
|
||||
if task.LatestAttemptNo > 0 {
|
||||
attempt, err := selectAttempt(ctx, tx, task.RunID, task.TaskID, task.LatestAttemptNo)
|
||||
if err != nil {
|
||||
return Task{}, err
|
||||
}
|
||||
if attempt.ThreadID != "" {
|
||||
thread, err := selectThread(ctx, tx, attempt.ThreadID)
|
||||
if err != nil && !errors.Is(err, ErrThreadNotFound) {
|
||||
return Task{}, err
|
||||
}
|
||||
if err == nil && !isTerminalStatus(thread.Status) {
|
||||
if err := cancelThreadTx(ctx, tx, thread, reason, now); err != nil {
|
||||
return Task{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
attemptStatus := attempt.Status
|
||||
if attemptStatus != "done" && attemptStatus != "failed" && attemptStatus != "cancelled" {
|
||||
attemptStatus = "cancelled"
|
||||
}
|
||||
workspaceStatus := attempt.WorkspaceStatus
|
||||
if workspaceStatus != "cleaned" {
|
||||
workspaceStatus = "abandoned"
|
||||
}
|
||||
_, err = tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE task_attempts
|
||||
SET status = ?, workspace_status = ?, updated_at = ?
|
||||
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
|
||||
attemptStatus,
|
||||
nullIfEmpty(workspaceStatus),
|
||||
formatTime(now),
|
||||
attempt.RunID,
|
||||
attempt.TaskID,
|
||||
attempt.AttemptNo,
|
||||
)
|
||||
if err != nil {
|
||||
return Task{}, fmt.Errorf("update cancelled attempt: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE tasks
|
||||
SET status = ?, updated_at = ?
|
||||
WHERE run_id = ? AND task_id = ?`,
|
||||
"cancelled",
|
||||
formatTime(now),
|
||||
task.RunID,
|
||||
task.TaskID,
|
||||
)
|
||||
if err != nil {
|
||||
return Task{}, fmt.Errorf("update cancelled task: %w", err)
|
||||
}
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: task.RunID,
|
||||
TaskID: task.TaskID,
|
||||
Source: "orch",
|
||||
EventType: "task_cancelled",
|
||||
Summary: defaultString(reason, "task cancelled"),
|
||||
PayloadJSON: marshalJSON(map[string]any{"reason": reason}),
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return Task{}, err
|
||||
}
|
||||
|
||||
task.Status = "cancelled"
|
||||
task.UpdatedAt = now
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func cancelThreadTx(ctx context.Context, tx *sql.Tx, thread Thread, reason string, now time.Time) error {
|
||||
messageID := newID("msg")
|
||||
summary := defaultString(reason, "thread cancelled")
|
||||
message := Message{
|
||||
MessageID: messageID,
|
||||
ThreadID: thread.ThreadID,
|
||||
FromAgent: "orch",
|
||||
ToAgent: thread.AssignedTo,
|
||||
Kind: "control",
|
||||
Summary: summary,
|
||||
Body: reason,
|
||||
PayloadJSON: json.RawMessage(`{}`),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
if err := insertMessage(ctx, tx, message); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := updateThreadState(ctx, tx, thread.ThreadID, "cancelled", thread.AssignedTo, message.MessageID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE leases
|
||||
SET released_at = ?
|
||||
WHERE thread_id = ?
|
||||
AND released_at IS NULL`,
|
||||
formatTime(now),
|
||||
thread.ThreadID,
|
||||
); err != nil {
|
||||
return fmt.Errorf("release lease on orch cancel: %w", err)
|
||||
}
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: thread.RunID,
|
||||
TaskID: thread.TaskID,
|
||||
ThreadID: thread.ThreadID,
|
||||
Source: "inbox",
|
||||
EventType: "thread_cancelled",
|
||||
MessageID: message.MessageID,
|
||||
Summary: message.Summary,
|
||||
PayloadJSON: string(message.PayloadJSON),
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) ReconcileRun(ctx context.Context, runID string) (ReconcileResult, error) {
|
||||
|
||||
Reference in New Issue
Block a user