Files
ai-workflow-skill/internal/store/orch.go
T

1640 lines
41 KiB
Go

package store
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
)
// Sentinel errors for missing runs and tasks. Lookup helpers wrap them with
// the offending ids via %w, so callers should match with errors.Is.
var (
	// ErrRunNotFound reports that no run exists with the requested id.
	ErrRunNotFound = errors.New("run not found")
	// ErrTaskNotFound reports that no task exists for the requested run/task pair.
	ErrTaskNotFound = errors.New("task not found")
)
// OrchStore persists orchestration state (runs, tasks, task dependencies and
// task attempts) together with the threads, messages and events it writes
// while dispatching and answering tasks. Construct it with NewOrchStore.
type OrchStore struct {
	// db is the handle every method opens its transactions on.
	db *sql.DB
}
// Run is one orchestration run: a goal plus the tasks added to it.
// CreateRun starts every run with Status "active"; later values are written
// by updateRunAggregateStatus (defined elsewhere — confirm the full set).
type Run struct {
	RunID     string    `json:"run_id"`
	Goal      string    `json:"goal"`
	Summary   string    `json:"summary"`
	Status    string    `json:"status"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
// Task is one unit of work inside a run, as stored in the tasks table.
//
// Status values written in this file are "planned" (new tasks, and tasks
// whose dependencies are unmet), "ready", "dispatched" and whatever
// reconcileTaskStatus returns during ReconcileRun (e.g. "blocked", which
// ListBlockedTasks and AnswerTask look for). LatestAttemptNo is 0 while the
// task has never been dispatched (the column is NULL). AcceptanceJSON holds
// raw JSON validated by AddTask.
type Task struct {
	RunID           string          `json:"run_id"`
	TaskID          string          `json:"task_id"`
	Title           string          `json:"title"`
	Summary         string          `json:"summary"`
	Status          string          `json:"status"`
	DefaultTo       string          `json:"default_to,omitempty"`
	Priority        string          `json:"priority"`
	AcceptanceJSON  json.RawMessage `json:"acceptance_json"`
	LatestAttemptNo int             `json:"latest_attempt_no,omitempty"`
	CreatedAt       time.Time       `json:"created_at"`
	UpdatedAt       time.Time       `json:"updated_at"`
}
// TaskDependency is one edge of a run's dependency graph: TaskID depends on
// DependsOnTaskID. refreshReadyStates keeps TaskID out of "ready" until
// dependenciesSatisfied (defined elsewhere) reports its dependencies met.
type TaskDependency struct {
	RunID           string `json:"run_id"`
	TaskID          string `json:"task_id"`
	DependsOnTaskID string `json:"depends_on_task_id"`
}
// TaskAttempt records one dispatch of a task to an agent, keyed by
// (RunID, TaskID, AttemptNo). Each attempt owns the inbox thread (ThreadID)
// used to talk to the assigned agent. DispatchTask creates attempts with
// Status "dispatched" and leaves the workspace fields (BaseCommit through
// ResultCommit) NULL; presumably other code fills them in later — confirm.
type TaskAttempt struct {
	RunID           string    `json:"run_id"`
	TaskID          string    `json:"task_id"`
	AttemptNo       int       `json:"attempt_no"`
	AssignedTo      string    `json:"assigned_to"`
	ThreadID        string    `json:"thread_id"`
	BaseRef         string    `json:"base_ref,omitempty"`
	BaseCommit      string    `json:"base_commit,omitempty"`
	BranchName      string    `json:"branch_name,omitempty"`
	WorktreePath    string    `json:"worktree_path,omitempty"`
	WorkspaceStatus string    `json:"workspace_status,omitempty"`
	ResultCommit    string    `json:"result_commit,omitempty"`
	Status          string    `json:"status"`
	CreatedAt       time.Time `json:"created_at"`
	UpdatedAt       time.Time `json:"updated_at"`
}
// RunOverview is the snapshot returned by GetRunOverview: the run, the
// number of tasks per status (from collectTaskCounts), and every task of
// the run ordered by creation time.
type RunOverview struct {
	Run        Run            `json:"run"`
	TaskCounts map[string]int `json:"task_counts"`
	Tasks      []Task         `json:"tasks,omitempty"`
}
// CreateRunInput carries the arguments to OrchStore.CreateRun. RunID and
// Goal are required; Summary is optional. All three are whitespace-trimmed.
type CreateRunInput struct {
	RunID   string
	Goal    string
	Summary string
}

// AddTaskInput carries the arguments to OrchStore.AddTask. RunID, TaskID and
// Title are required. AcceptanceJSON must be valid JSON when set; "[]" is
// passed to validateAndNormalizeJSONDefault as its default. Priority is
// checked by normalizePriority.
type AddTaskInput struct {
	RunID          string
	TaskID         string
	Title          string
	Summary        string
	DefaultTo      string
	AcceptanceJSON string
	Priority       string
}

// AddDependencyInput carries the arguments to OrchStore.AddDependency:
// TaskID will depend on DependsOnTaskID. Both tasks must exist in RunID and
// must differ.
type AddDependencyInput struct {
	RunID           string
	TaskID          string
	DependsOnTaskID string
}

// ListReadyInput carries the arguments to OrchStore.ListReadyTasks. A Limit
// of zero or less means the default of 20.
type ListReadyInput struct {
	RunID string
	Limit int
}

// DispatchInput carries the arguments to OrchStore.DispatchTask. ToAgent
// falls back to the task's DefaultTo when empty. Body becomes the body of the
// "task" message, and BaseRef (optional) is recorded on the new attempt.
type DispatchInput struct {
	RunID   string
	TaskID  string
	ToAgent string
	Body    string
	BaseRef string
}
// DispatchResult is returned by OrchStore.DispatchTask. It holds the task
// after dispatch, the new attempt, and the thread and first message created
// for it.
type DispatchResult struct {
	Task    Task        `json:"task"`
	Attempt TaskAttempt `json:"attempt"`
	Thread  Thread      `json:"thread"`
	Message Message     `json:"message"`
}

// ReconcileResult is returned by OrchStore.ReconcileRun. It holds the run
// after its aggregate status was recomputed, per-status task counts, and the
// tasks whose status changed during reconciliation.
type ReconcileResult struct {
	Run          Run            `json:"run"`
	TaskCounts   map[string]int `json:"task_counts"`
	UpdatedTasks []Task         `json:"updated_tasks"`
}

// BlockedTask pairs a "blocked" task and its latest attempt with the most
// recent "question" message on that attempt's thread.
type BlockedTask struct {
	Task     Task        `json:"task"`
	Attempt  TaskAttempt `json:"attempt"`
	Question Message     `json:"question"`
}

// AnswerInput carries the arguments to OrchStore.AnswerTask. At least one of
// Body or PayloadJSON must be supplied; PayloadJSON must be valid JSON.
type AnswerInput struct {
	RunID       string
	TaskID      string
	Body        string
	PayloadJSON string
}

// AnswerResult is returned by OrchStore.AnswerTask. It holds the answered
// task, its latest attempt, the attempt's thread, and the "answer" message
// that was posted.
type AnswerResult struct {
	Task    Task        `json:"task"`
	Attempt TaskAttempt `json:"attempt"`
	Thread  Thread      `json:"thread"`
	Message Message     `json:"message"`
}
// NewOrchStore returns an OrchStore whose methods run their transactions
// against db.
func NewOrchStore(db *sql.DB) *OrchStore {
	store := new(OrchStore)
	store.db = db
	return store
}
// CreateRun starts a new orchestration run in the "active" state.
//
// RunID and Goal are required. They and Summary are trimmed of surrounding
// whitespace before storage. The run row and its "run_initialized" event
// are written in one transaction. A run id that already exists is reported
// as ErrInvalidState; missing fields are reported as ErrInvalidInput.
func (s *OrchStore) CreateRun(ctx context.Context, input CreateRunInput) (Run, error) {
	var (
		id      = strings.TrimSpace(input.RunID)
		goal    = strings.TrimSpace(input.Goal)
		summary = strings.TrimSpace(input.Summary)
	)
	switch {
	case id == "":
		return Run{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	case goal == "":
		return Run{}, fmt.Errorf("%w: goal is required", ErrInvalidInput)
	}

	createdAt := nowUTC()
	run := Run{
		RunID:     id,
		Goal:      goal,
		Summary:   summary,
		Status:    "active",
		CreatedAt: createdAt,
		UpdatedAt: createdAt,
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Run{}, fmt.Errorf("begin create run transaction: %w", err)
	}
	defer tx.Rollback()

	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO runs (run_id, goal, summary, status, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?)`,
		run.RunID, run.Goal, run.Summary, run.Status,
		formatTime(run.CreatedAt), formatTime(run.UpdatedAt),
	); err != nil {
		if isUniqueConstraintError(err) {
			return Run{}, fmt.Errorf("%w: run %s already exists", ErrInvalidState, run.RunID)
		}
		return Run{}, fmt.Errorf("insert run: %w", err)
	}

	initialized := eventInput{
		RunID:       run.RunID,
		TaskID:      "",
		Source:      "orch",
		EventType:   "run_initialized",
		Summary:     defaultString(run.Summary, run.Goal),
		PayloadJSON: marshalJSON(map[string]any{"goal": run.Goal, "summary": run.Summary}),
		CreatedAt:   createdAt,
	}
	if err := insertEvent(ctx, tx, initialized); err != nil {
		return Run{}, err
	}

	if err := tx.Commit(); err != nil {
		return Run{}, fmt.Errorf("commit create run transaction: %w", err)
	}
	return run, nil
}
// GetRun loads one run by id, reading directly from the store's database
// handle rather than a transaction. A missing run yields an error wrapping
// ErrRunNotFound.
func (s *OrchStore) GetRun(ctx context.Context, runID string) (Run, error) {
	run, err := selectRun(ctx, s.db, runID)
	return run, err
}
// AddTask inserts a task into an existing run and records a "task_added"
// event for it.
//
// RunID, TaskID and Title are required. They, Summary and DefaultTo are
// trimmed of surrounding whitespace before being validated and stored,
// matching how CreateRun normalizes its inputs. Priority is validated by
// normalizePriority. AcceptanceJSON is validated as JSON, with "[]" as the
// default. The task is inserted as "planned"; readiness and the run's
// aggregate status are then recomputed in the same transaction, and the
// returned task is re-read after that refresh.
//
// Errors: ErrInvalidInput for missing/invalid fields, ErrRunNotFound for an
// unknown run, ErrInvalidState when the task id already exists in the run.
func (s *OrchStore) AddTask(ctx context.Context, input AddTaskInput) (Task, error) {
	// Store exactly what was validated: previously " t1 " passed the trimmed
	// check but was inserted (and the run looked up) with its padding.
	runID := strings.TrimSpace(input.RunID)
	taskID := strings.TrimSpace(input.TaskID)
	title := strings.TrimSpace(input.Title)
	summary := strings.TrimSpace(input.Summary)
	defaultTo := strings.TrimSpace(input.DefaultTo)
	if runID == "" {
		return Task{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	if taskID == "" {
		return Task{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
	}
	if title == "" {
		return Task{}, fmt.Errorf("%w: title is required", ErrInvalidInput)
	}
	priority, err := normalizePriority(input.Priority)
	if err != nil {
		return Task{}, err
	}
	acceptanceJSON, err := validateAndNormalizeJSONDefault("acceptance-json", input.AcceptanceJSON, "[]")
	if err != nil {
		return Task{}, err
	}
	now := nowUTC()
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Task{}, fmt.Errorf("begin add task transaction: %w", err)
	}
	defer tx.Rollback()
	if _, err := selectRun(ctx, tx, runID); err != nil {
		return Task{}, err
	}
	_, err = tx.ExecContext(
		ctx,
		`INSERT INTO tasks (
run_id, task_id, title, summary, status, default_to, priority,
acceptance_json, latest_attempt_no, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, NULL, ?, ?)`,
		runID,
		taskID,
		title,
		summary,
		"planned",
		nullIfEmpty(defaultTo),
		priority,
		acceptanceJSON,
		formatTime(now),
		formatTime(now),
	)
	if err != nil {
		if isUniqueConstraintError(err) {
			return Task{}, fmt.Errorf("%w: task %s already exists in run %s", ErrInvalidState, taskID, runID)
		}
		return Task{}, fmt.Errorf("insert task: %w", err)
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       runID,
		TaskID:      taskID,
		Source:      "orch",
		EventType:   "task_added",
		Summary:     title,
		PayloadJSON: marshalJSON(map[string]any{"title": title, "priority": priority}),
		CreatedAt:   now,
	}); err != nil {
		return Task{}, err
	}
	// A new task has no dependencies yet, so readiness may change immediately.
	if err := refreshReadyStates(ctx, tx, runID, now); err != nil {
		return Task{}, err
	}
	if err := updateRunAggregateStatus(ctx, tx, runID, now); err != nil {
		return Task{}, err
	}
	task, err := selectTask(ctx, tx, runID, taskID)
	if err != nil {
		return Task{}, err
	}
	if err := tx.Commit(); err != nil {
		return Task{}, fmt.Errorf("commit add task transaction: %w", err)
	}
	return task, nil
}
// AddDependency records that TaskID depends on DependsOnTaskID within the
// same run, then recomputes readiness and the run's aggregate status in the
// same transaction.
//
// All three ids are required and are trimmed of surrounding whitespace
// before use, matching CreateRun. Both tasks must already exist in the run.
// The edge is rejected with ErrInvalidInput when it would make a task depend
// on itself, either directly or through a chain of existing dependencies.
// Such a cycle could never be satisfied, so its tasks would never leave
// "planned". A duplicate edge is reported as ErrInvalidState.
func (s *OrchStore) AddDependency(ctx context.Context, input AddDependencyInput) (TaskDependency, error) {
	runID := strings.TrimSpace(input.RunID)
	taskID := strings.TrimSpace(input.TaskID)
	dependsOnID := strings.TrimSpace(input.DependsOnTaskID)
	if runID == "" {
		return TaskDependency{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	if taskID == "" {
		return TaskDependency{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
	}
	if dependsOnID == "" {
		return TaskDependency{}, fmt.Errorf("%w: depends-on task id is required", ErrInvalidInput)
	}
	if taskID == dependsOnID {
		return TaskDependency{}, fmt.Errorf("%w: task cannot depend on itself", ErrInvalidInput)
	}
	now := nowUTC()
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return TaskDependency{}, fmt.Errorf("begin add dependency transaction: %w", err)
	}
	defer tx.Rollback()
	if _, err := selectRun(ctx, tx, runID); err != nil {
		return TaskDependency{}, err
	}
	if _, err := selectTask(ctx, tx, runID, taskID); err != nil {
		return TaskDependency{}, err
	}
	if _, err := selectTask(ctx, tx, runID, dependsOnID); err != nil {
		return TaskDependency{}, err
	}
	// The new edge taskID -> dependsOnID closes a cycle exactly when
	// dependsOnID already (transitively) depends on taskID.
	cyclic, err := dependencyReaches(ctx, tx, runID, dependsOnID, taskID)
	if err != nil {
		return TaskDependency{}, err
	}
	if cyclic {
		return TaskDependency{}, fmt.Errorf("%w: dependency %s -> %s would create a cycle", ErrInvalidInput, taskID, dependsOnID)
	}
	_, err = tx.ExecContext(
		ctx,
		`INSERT INTO task_dependencies (run_id, task_id, depends_on_task_id)
VALUES (?, ?, ?)`,
		runID,
		taskID,
		dependsOnID,
	)
	if err != nil {
		if isUniqueConstraintError(err) {
			return TaskDependency{}, fmt.Errorf("%w: dependency %s -> %s already exists", ErrInvalidState, taskID, dependsOnID)
		}
		return TaskDependency{}, fmt.Errorf("insert dependency: %w", err)
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       runID,
		TaskID:      taskID,
		Source:      "orch",
		EventType:   "task_dependency_added",
		Summary:     fmt.Sprintf("%s depends on %s", taskID, dependsOnID),
		PayloadJSON: marshalJSON(map[string]any{"depends_on_task_id": dependsOnID}),
		CreatedAt:   now,
	}); err != nil {
		return TaskDependency{}, err
	}
	// A "ready" task may fall back to "planned" because of the new edge.
	if err := refreshReadyStates(ctx, tx, runID, now); err != nil {
		return TaskDependency{}, err
	}
	if err := updateRunAggregateStatus(ctx, tx, runID, now); err != nil {
		return TaskDependency{}, err
	}
	if err := tx.Commit(); err != nil {
		return TaskDependency{}, fmt.Errorf("commit add dependency transaction: %w", err)
	}
	return TaskDependency{
		RunID:           runID,
		TaskID:          taskID,
		DependsOnTaskID: dependsOnID,
	}, nil
}

// dependencyReaches loads runID's dependency edges inside tx and reports
// whether task `to` is reachable from task `from` by following
// task_id -> depends_on_task_id edges. The rows are fully read and closed
// before the graph is searched.
func dependencyReaches(ctx context.Context, tx *sql.Tx, runID, from, to string) (bool, error) {
	rows, err := tx.QueryContext(
		ctx,
		`SELECT task_id, depends_on_task_id
FROM task_dependencies
WHERE run_id = ?`,
		runID,
	)
	if err != nil {
		return false, fmt.Errorf("query dependency graph: %w", err)
	}
	defer rows.Close()
	edges := make(map[string][]string)
	for rows.Next() {
		var taskID, dependsOnID string
		if err := rows.Scan(&taskID, &dependsOnID); err != nil {
			return false, fmt.Errorf("scan dependency edge: %w", err)
		}
		edges[taskID] = append(edges[taskID], dependsOnID)
	}
	if err := rows.Err(); err != nil {
		return false, fmt.Errorf("iterate dependency graph: %w", err)
	}
	return dependencyGraphReaches(edges, from, to), nil
}

// dependencyGraphReaches reports whether `to` is reachable from `from` in
// edges (task id -> ids it depends on), using an iterative depth-first
// search that visits each task at most once.
func dependencyGraphReaches(edges map[string][]string, from, to string) bool {
	seen := map[string]bool{from: true}
	stack := []string{from}
	for len(stack) > 0 {
		current := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		if current == to {
			return true
		}
		for _, next := range edges[current] {
			if !seen[next] {
				seen[next] = true
				stack = append(stack, next)
			}
		}
	}
	return false
}
// ListReadyTasks returns up to input.Limit "ready" tasks of a run (20 when
// Limit <= 0), highest priority first ("high", then "normal", then anything
// else) and oldest first within a priority.
//
// Before reading, it recomputes readiness and the run's aggregate status and
// commits those changes, so this read path also writes derived state. Both
// refresh steps share one timestamp, as in the other OrchStore methods.
// Returns a nil slice when nothing is ready.
func (s *OrchStore) ListReadyTasks(ctx context.Context, input ListReadyInput) ([]Task, error) {
	if strings.TrimSpace(input.RunID) == "" {
		return nil, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	limit := input.Limit
	if limit <= 0 {
		limit = 20
	}
	// Capture the clock once so the readiness refresh and the run aggregate
	// update stamp the same updated_at.
	now := nowUTC()
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("begin list ready transaction: %w", err)
	}
	defer tx.Rollback()
	if _, err := selectRun(ctx, tx, input.RunID); err != nil {
		return nil, err
	}
	if err := refreshReadyStates(ctx, tx, input.RunID, now); err != nil {
		return nil, err
	}
	if err := updateRunAggregateStatus(ctx, tx, input.RunID, now); err != nil {
		return nil, err
	}
	rows, err := tx.QueryContext(
		ctx,
		`SELECT
run_id, task_id, title, summary, status, default_to, priority,
acceptance_json, latest_attempt_no, created_at, updated_at
FROM tasks
WHERE run_id = ? AND status = 'ready'
ORDER BY CASE priority
WHEN 'high' THEN 0
WHEN 'normal' THEN 1
ELSE 2
END, created_at ASC
LIMIT ?`,
		input.RunID,
		limit,
	)
	if err != nil {
		return nil, fmt.Errorf("query ready tasks: %w", err)
	}
	defer rows.Close()
	var tasks []Task
	for rows.Next() {
		task, err := scanTask(rows)
		if err != nil {
			return nil, err
		}
		tasks = append(tasks, task)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate ready tasks: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("commit list ready transaction: %w", err)
	}
	return tasks, nil
}
// DispatchTask hands a "ready" task to an agent in one transaction. It opens
// a new inbox thread, posts a "task" message on it, and records a new task
// attempt.
//
// The target agent is input.ToAgent (trimmed) or, when that is empty, the
// task's DefaultTo. Dispatch fails with ErrInvalidInput if neither is set.
// Readiness is refreshed before the status check, so a task whose
// dependencies became satisfied since the last refresh can be dispatched.
// Any status other than "ready" yields ErrInvalidState.
//
// On success the task moves to "dispatched" and its LatestAttemptNo
// advances by one. Two events are written: "thread_created" (source "inbox")
// and "task_dispatched" (source "orch"). The run's aggregate status is then
// recomputed.
func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (DispatchResult, error) {
	if strings.TrimSpace(input.RunID) == "" {
		return DispatchResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	if strings.TrimSpace(input.TaskID) == "" {
		return DispatchResult{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
	}
	now := nowUTC()
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return DispatchResult{}, fmt.Errorf("begin dispatch transaction: %w", err)
	}
	defer tx.Rollback()
	if _, err := selectRun(ctx, tx, input.RunID); err != nil {
		return DispatchResult{}, err
	}
	// Recompute planned/ready first so the status check below reflects the
	// current dependency state.
	if err := refreshReadyStates(ctx, tx, input.RunID, now); err != nil {
		return DispatchResult{}, err
	}
	task, err := selectTask(ctx, tx, input.RunID, input.TaskID)
	if err != nil {
		return DispatchResult{}, err
	}
	if task.Status != "ready" {
		return DispatchResult{}, fmt.Errorf("%w: task %s is not ready for dispatch", ErrInvalidState, task.TaskID)
	}
	assignedTo := defaultString(strings.TrimSpace(input.ToAgent), task.DefaultTo)
	if assignedTo == "" {
		return DispatchResult{}, fmt.Errorf("%w: dispatch target agent is required", ErrInvalidInput)
	}
	// LatestAttemptNo is 0 for a never-dispatched task, so the first attempt is 1.
	attemptNo := task.LatestAttemptNo + 1
	threadID := newID("thr")
	messageID := newID("msg")
	// NOTE(review): the payload receives the raw input.BaseRef, while the
	// attempt row below stores strings.TrimSpace(input.BaseRef). Confirm
	// whether buildDispatchPayload trims it itself.
	payloadJSON := buildDispatchPayload(task, attemptNo, input.BaseRef)
	// The thread is inserted already pointing at the message inserted next.
	thread := Thread{
		ThreadID:        threadID,
		RunID:           task.RunID,
		TaskID:          task.TaskID,
		Subject:         task.Title,
		CreatedBy:       "orch",
		AssignedTo:      assignedTo,
		Status:          "pending",
		Priority:        task.Priority,
		LatestMessageID: messageID,
		CreatedAt:       now,
		UpdatedAt:       now,
	}
	_, err = tx.ExecContext(
		ctx,
		`INSERT INTO threads (
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
priority, latest_message_id, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		thread.ThreadID,
		thread.RunID,
		thread.TaskID,
		thread.Subject,
		thread.CreatedBy,
		thread.AssignedTo,
		thread.Status,
		thread.Priority,
		thread.LatestMessageID,
		formatTime(thread.CreatedAt),
		formatTime(thread.UpdatedAt),
	)
	if err != nil {
		return DispatchResult{}, fmt.Errorf("insert dispatch thread: %w", err)
	}
	message := Message{
		MessageID:   messageID,
		ThreadID:    threadID,
		FromAgent:   "orch",
		ToAgent:     assignedTo,
		Kind:        "task",
		Summary:     defaultString(task.Summary, task.Title),
		Body:        input.Body,
		PayloadJSON: json.RawMessage(payloadJSON),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return DispatchResult{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_created",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: payloadJSON,
		CreatedAt:   now,
	}); err != nil {
		return DispatchResult{}, err
	}
	attempt := TaskAttempt{
		RunID:      task.RunID,
		TaskID:     task.TaskID,
		AttemptNo:  attemptNo,
		AssignedTo: assignedTo,
		ThreadID:   threadID,
		BaseRef:    strings.TrimSpace(input.BaseRef),
		Status:     "dispatched",
		CreatedAt:  now,
		UpdatedAt:  now,
	}
	// base_commit, branch_name, worktree_path, workspace_status and
	// result_commit start NULL; presumably other code fills them in
	// later — confirm.
	_, err = tx.ExecContext(
		ctx,
		`INSERT INTO task_attempts (
run_id, task_id, attempt_no, assigned_to, thread_id, base_ref, base_commit,
branch_name, worktree_path, workspace_status, result_commit, status,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		attempt.RunID,
		attempt.TaskID,
		attempt.AttemptNo,
		attempt.AssignedTo,
		attempt.ThreadID,
		nullIfEmpty(attempt.BaseRef),
		nil,
		nil,
		nil,
		nil,
		nil,
		attempt.Status,
		formatTime(attempt.CreatedAt),
		formatTime(attempt.UpdatedAt),
	)
	if err != nil {
		return DispatchResult{}, fmt.Errorf("insert task attempt: %w", err)
	}
	_, err = tx.ExecContext(
		ctx,
		`UPDATE tasks
SET status = ?, latest_attempt_no = ?, updated_at = ?
WHERE run_id = ? AND task_id = ?`,
		"dispatched",
		attempt.AttemptNo,
		formatTime(now),
		task.RunID,
		task.TaskID,
	)
	if err != nil {
		return DispatchResult{}, fmt.Errorf("update task dispatch status: %w", err)
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       task.RunID,
		TaskID:      task.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "orch",
		EventType:   "task_dispatched",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: payloadJSON,
		CreatedAt:   now,
	}); err != nil {
		return DispatchResult{}, err
	}
	if err := updateRunAggregateStatus(ctx, tx, task.RunID, now); err != nil {
		return DispatchResult{}, err
	}
	if err := tx.Commit(); err != nil {
		return DispatchResult{}, fmt.Errorf("commit dispatch transaction: %w", err)
	}
	// The in-memory task is patched to mirror the UPDATE above instead of
	// being re-read from the database.
	task.Status = "dispatched"
	task.LatestAttemptNo = attempt.AttemptNo
	task.UpdatedAt = now
	return DispatchResult{
		Task:    task,
		Attempt: attempt,
		Thread:  thread,
		Message: message,
	}, nil
}
// ReconcileRun syncs each task (and its latest attempt) with the status of
// that attempt's inbox thread, as mapped by reconcileTaskStatus. For every
// task whose status changes, it writes a "task_<status>" event. It then
// recomputes readiness and the run's aggregate status, all in one
// transaction.
//
// Candidates are read completely, and their cursor closed, before any
// UPDATE is issued. The previous version wrote to tasks/task_attempts on the
// same transaction while the SELECT over those tables was still open.
// SQLite leaves it undefined whether an open statement sees such writes, and
// drivers that stream result sets over the connection reject the second
// statement. refreshReadyStates uses the same read-then-write pattern.
//
// Returns the refreshed run, per-status task counts, and the tasks whose
// status changed (re-read after all updates).
func (s *OrchStore) ReconcileRun(ctx context.Context, runID string) (ReconcileResult, error) {
	if strings.TrimSpace(runID) == "" {
		return ReconcileResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	now := nowUTC()
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return ReconcileResult{}, fmt.Errorf("begin reconcile transaction: %w", err)
	}
	defer tx.Rollback()
	if _, err := selectRun(ctx, tx, runID); err != nil {
		return ReconcileResult{}, err
	}
	candidates, err := loadReconcileCandidates(ctx, tx, runID)
	if err != nil {
		return ReconcileResult{}, err
	}
	var updatedIDs []string
	for _, c := range candidates {
		nextStatus := reconcileTaskStatus(c.threadStatus)
		if nextStatus == "" {
			continue // thread status has no task-status counterpart
		}
		if nextStatus == c.taskStatus && nextStatus == c.attemptStatus {
			continue // task and attempt already agree with the thread
		}
		if err := applyReconciledStatus(ctx, tx, runID, c, nextStatus, now); err != nil {
			return ReconcileResult{}, err
		}
		updatedIDs = append(updatedIDs, c.taskID)
	}
	if err := refreshReadyStates(ctx, tx, runID, now); err != nil {
		return ReconcileResult{}, err
	}
	if err := updateRunAggregateStatus(ctx, tx, runID, now); err != nil {
		return ReconcileResult{}, err
	}
	run, err := selectRun(ctx, tx, runID)
	if err != nil {
		return ReconcileResult{}, err
	}
	taskCounts, err := collectTaskCounts(ctx, tx, runID)
	if err != nil {
		return ReconcileResult{}, err
	}
	updatedTasks := make([]Task, 0, len(updatedIDs))
	for _, taskID := range updatedIDs {
		task, err := selectTask(ctx, tx, runID, taskID)
		if err != nil {
			return ReconcileResult{}, err
		}
		updatedTasks = append(updatedTasks, task)
	}
	if err := tx.Commit(); err != nil {
		return ReconcileResult{}, fmt.Errorf("commit reconcile transaction: %w", err)
	}
	return ReconcileResult{
		Run:          run,
		TaskCounts:   taskCounts,
		UpdatedTasks: updatedTasks,
	}, nil
}

// reconcileCandidate is one task of a run together with its latest attempt
// and that attempt's thread, as read by loadReconcileCandidates.
type reconcileCandidate struct {
	taskID        string
	taskStatus    string
	attemptNo     int
	attemptStatus string
	threadID      string
	threadStatus  string
}

// loadReconcileCandidates reads every task of runID that has a latest
// attempt, joined with that attempt and its thread. All rows are consumed
// and the cursor is closed before returning, so callers may write freely on
// tx afterwards.
func loadReconcileCandidates(ctx context.Context, tx *sql.Tx, runID string) ([]reconcileCandidate, error) {
	rows, err := tx.QueryContext(
		ctx,
		`SELECT
t.task_id,
t.status,
a.attempt_no,
a.status,
a.thread_id,
th.status
FROM tasks t
JOIN task_attempts a
ON a.run_id = t.run_id
AND a.task_id = t.task_id
AND a.attempt_no = t.latest_attempt_no
JOIN threads th ON th.thread_id = a.thread_id
WHERE t.run_id = ?
AND t.latest_attempt_no IS NOT NULL`,
		runID,
	)
	if err != nil {
		return nil, fmt.Errorf("query reconcile candidates: %w", err)
	}
	defer rows.Close()
	var candidates []reconcileCandidate
	for rows.Next() {
		var c reconcileCandidate
		if err := rows.Scan(&c.taskID, &c.taskStatus, &c.attemptNo, &c.attemptStatus, &c.threadID, &c.threadStatus); err != nil {
			return nil, fmt.Errorf("scan reconcile candidate: %w", err)
		}
		candidates = append(candidates, c)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate reconcile candidates: %w", err)
	}
	return candidates, nil
}

// applyReconciledStatus sets nextStatus on the candidate's task and latest
// attempt and records a "task_<nextStatus>" event carrying the thread and
// the previous statuses.
func applyReconciledStatus(ctx context.Context, tx *sql.Tx, runID string, c reconcileCandidate, nextStatus string, now time.Time) error {
	if _, err := tx.ExecContext(
		ctx,
		`UPDATE tasks
SET status = ?, updated_at = ?
WHERE run_id = ? AND task_id = ?`,
		nextStatus,
		formatTime(now),
		runID,
		c.taskID,
	); err != nil {
		return fmt.Errorf("update reconciled task status: %w", err)
	}
	if _, err := tx.ExecContext(
		ctx,
		`UPDATE task_attempts
SET status = ?, updated_at = ?
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
		nextStatus,
		formatTime(now),
		runID,
		c.taskID,
		c.attemptNo,
	); err != nil {
		return fmt.Errorf("update reconciled attempt status: %w", err)
	}
	return insertEvent(ctx, tx, eventInput{
		RunID:     runID,
		TaskID:    c.taskID,
		ThreadID:  c.threadID,
		Source:    "orch",
		EventType: "task_" + nextStatus,
		Summary:   fmt.Sprintf("%s -> %s", c.taskID, nextStatus),
		PayloadJSON: marshalJSON(map[string]any{
			"thread_id":        c.threadID,
			"thread_status":    c.threadStatus,
			"previous_status":  c.taskStatus,
			"previous_attempt": c.attemptStatus,
		}),
		CreatedAt: now,
	})
}
// ListBlockedTasks returns every "blocked" task of a run, oldest update
// first. Each task is paired with its latest attempt and the most recent
// "question" message on that attempt's thread (including its artifacts).
//
// The task/attempt rows are fully read and their cursor closed before the
// per-thread question lookups run. The previous version issued those nested
// queries on the same transaction while the outer result set was still open,
// which drivers that stream results over a single connection do not support.
//
// A blocked task whose thread has no question message fails the whole call
// with ErrInvalidState (from selectLatestQuestionMessage). Returns a nil
// slice when no task is blocked.
func (s *OrchStore) ListBlockedTasks(ctx context.Context, runID string) ([]BlockedTask, error) {
	if strings.TrimSpace(runID) == "" {
		return nil, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("begin list blocked transaction: %w", err)
	}
	defer tx.Rollback()
	if _, err := selectRun(ctx, tx, runID); err != nil {
		return nil, err
	}
	blocked, err := loadBlockedTaskAttempts(ctx, tx, runID)
	if err != nil {
		return nil, err
	}
	for i := range blocked {
		question, err := selectLatestQuestionMessage(ctx, tx, blocked[i].Attempt.ThreadID)
		if err != nil {
			return nil, err
		}
		blocked[i].Question = question
	}
	if err := tx.Commit(); err != nil {
		return nil, fmt.Errorf("commit list blocked transaction: %w", err)
	}
	return blocked, nil
}

// loadBlockedTaskAttempts reads each "blocked" task of runID joined with its
// latest attempt, leaving Question unset. All rows are consumed and the
// cursor is closed before returning.
func loadBlockedTaskAttempts(ctx context.Context, tx *sql.Tx, runID string) ([]BlockedTask, error) {
	rows, err := tx.QueryContext(
		ctx,
		`SELECT
t.run_id, t.task_id, t.title, t.summary, t.status, t.default_to, t.priority,
t.acceptance_json, t.latest_attempt_no, t.created_at, t.updated_at,
a.run_id, a.task_id, a.attempt_no, a.assigned_to, a.thread_id, a.base_ref,
a.base_commit, a.branch_name, a.worktree_path, a.workspace_status,
a.result_commit, a.status, a.created_at, a.updated_at
FROM tasks t
JOIN task_attempts a
ON a.run_id = t.run_id
AND a.task_id = t.task_id
AND a.attempt_no = t.latest_attempt_no
WHERE t.run_id = ?
AND t.status = 'blocked'
ORDER BY t.updated_at ASC`,
		runID,
	)
	if err != nil {
		return nil, fmt.Errorf("query blocked tasks: %w", err)
	}
	defer rows.Close()
	var blocked []BlockedTask
	for rows.Next() {
		task, attempt, err := scanTaskAndAttempt(rows)
		if err != nil {
			return nil, err
		}
		blocked = append(blocked, BlockedTask{Task: task, Attempt: attempt})
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate blocked tasks: %w", err)
	}
	return blocked, nil
}
// AnswerTask replies to a "blocked" task. It posts an "answer" message from
// "orch" to the agent assigned to the task's latest attempt, on that
// attempt's thread.
//
// At least one of Body or PayloadJSON is required. The task must be
// "blocked" (ErrInvalidState otherwise) and must have an attempt, and that
// attempt's thread must not be terminal. The thread is written back with its
// current status and assignee plus the new latest message. The task keeps
// its "blocked" status here; only updated_at is touched on the task and the
// attempt. Two events are recorded: "thread_reply" (source "inbox") and
// "task_answered" (source "orch").
func (s *OrchStore) AnswerTask(ctx context.Context, input AnswerInput) (AnswerResult, error) {
	if strings.TrimSpace(input.RunID) == "" {
		return AnswerResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	if strings.TrimSpace(input.TaskID) == "" {
		return AnswerResult{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
	}
	payloadJSON, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
	if err != nil {
		return AnswerResult{}, err
	}
	// NOTE(review): assumes validateAndNormalizeJSON maps an empty payload to
	// "{}". Confirm, otherwise an empty body with no payload slips through.
	if strings.TrimSpace(input.Body) == "" && payloadJSON == "{}" {
		return AnswerResult{}, fmt.Errorf("%w: body or payload-json is required", ErrInvalidInput)
	}
	now := nowUTC()
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return AnswerResult{}, fmt.Errorf("begin answer transaction: %w", err)
	}
	defer tx.Rollback()
	task, err := selectTask(ctx, tx, input.RunID, input.TaskID)
	if err != nil {
		return AnswerResult{}, err
	}
	if task.Status != "blocked" {
		return AnswerResult{}, fmt.Errorf("%w: task %s is not blocked", ErrInvalidState, task.TaskID)
	}
	// 0 means latest_attempt_no was NULL: the task was never dispatched.
	if task.LatestAttemptNo == 0 {
		return AnswerResult{}, fmt.Errorf("%w: task %s has no active attempt", ErrInvalidState, task.TaskID)
	}
	attempt, err := selectAttempt(ctx, tx, input.RunID, input.TaskID, task.LatestAttemptNo)
	if err != nil {
		return AnswerResult{}, err
	}
	thread, err := selectThread(ctx, tx, attempt.ThreadID)
	if err != nil {
		return AnswerResult{}, err
	}
	if isTerminalStatus(thread.Status) {
		return AnswerResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, thread.ThreadID)
	}
	message := Message{
		MessageID:   newID("msg"),
		ThreadID:    thread.ThreadID,
		FromAgent:   "orch",
		ToAgent:     attempt.AssignedTo,
		Kind:        "answer",
		Summary:     summarizeAnswer(input.Body),
		Body:        input.Body,
		PayloadJSON: json.RawMessage(payloadJSON),
		CreatedAt:   now,
	}
	if err := insertMessage(ctx, tx, message); err != nil {
		return AnswerResult{}, err
	}
	// Status and assignee are passed through unchanged; only the latest
	// message pointer and timestamp are new.
	if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
		return AnswerResult{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_reply",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: payloadJSON,
		CreatedAt:   now,
	}); err != nil {
		return AnswerResult{}, err
	}
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       task.RunID,
		TaskID:      task.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "orch",
		EventType:   "task_answered",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: payloadJSON,
		CreatedAt:   now,
	}); err != nil {
		return AnswerResult{}, err
	}
	// Touch only the timestamps; the task stays "blocked" here.
	_, err = tx.ExecContext(
		ctx,
		`UPDATE tasks
SET updated_at = ?
WHERE run_id = ? AND task_id = ?`,
		formatTime(now),
		task.RunID,
		task.TaskID,
	)
	if err != nil {
		return AnswerResult{}, fmt.Errorf("touch answered task: %w", err)
	}
	_, err = tx.ExecContext(
		ctx,
		`UPDATE task_attempts
SET updated_at = ?
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
		formatTime(now),
		attempt.RunID,
		attempt.TaskID,
		attempt.AttemptNo,
	)
	if err != nil {
		return AnswerResult{}, fmt.Errorf("touch answered attempt: %w", err)
	}
	if err := updateRunAggregateStatus(ctx, tx, task.RunID, now); err != nil {
		return AnswerResult{}, err
	}
	// Mirror the writes above in the returned values instead of re-reading.
	task.UpdatedAt = now
	attempt.UpdatedAt = now
	thread.LatestMessageID = message.MessageID
	thread.UpdatedAt = now
	if err := tx.Commit(); err != nil {
		return AnswerResult{}, fmt.Errorf("commit answer transaction: %w", err)
	}
	return AnswerResult{
		Task:    task,
		Attempt: attempt,
		Thread:  thread,
		Message: message,
	}, nil
}
// GetRunOverview returns a consistent snapshot of one run: the run itself,
// per-status task counts, and all tasks ordered by creation time.
//
// Readiness and the run's aggregate status are recomputed (and committed)
// first, so the snapshot reflects current dependency state. An unknown run
// yields ErrRunNotFound; an empty run id yields ErrInvalidInput.
func (s *OrchStore) GetRunOverview(ctx context.Context, runID string) (RunOverview, error) {
	if strings.TrimSpace(runID) == "" {
		return RunOverview{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
	}
	stamp := nowUTC()
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return RunOverview{}, fmt.Errorf("begin run overview transaction: %w", err)
	}
	defer tx.Rollback()

	// Validate the run, then bring derived state up to date, in this order.
	prepare := []func() error{
		func() error {
			_, err := selectRun(ctx, tx, runID)
			return err
		},
		func() error { return refreshReadyStates(ctx, tx, runID, stamp) },
		func() error { return updateRunAggregateStatus(ctx, tx, runID, stamp) },
	}
	for _, step := range prepare {
		if err := step(); err != nil {
			return RunOverview{}, err
		}
	}

	var overview RunOverview
	if overview.Run, err = selectRun(ctx, tx, runID); err != nil {
		return RunOverview{}, err
	}
	if overview.TaskCounts, err = collectTaskCounts(ctx, tx, runID); err != nil {
		return RunOverview{}, err
	}
	if overview.Tasks, err = listTasksForRun(ctx, tx, runID); err != nil {
		return RunOverview{}, err
	}
	if err := tx.Commit(); err != nil {
		return RunOverview{}, fmt.Errorf("commit run overview transaction: %w", err)
	}
	return overview, nil
}
// listTasksForRun returns every task of runID, oldest first, using db (a
// transaction or any other queryRowsContexter). It returns a nil slice when
// the run has no tasks.
func listTasksForRun(ctx context.Context, db queryRowsContexter, runID string) ([]Task, error) {
	const query = `SELECT
run_id, task_id, title, summary, status, default_to, priority,
acceptance_json, latest_attempt_no, created_at, updated_at
FROM tasks
WHERE run_id = ?
ORDER BY created_at ASC`
	rows, err := db.QueryContext(ctx, query, runID)
	if err != nil {
		return nil, fmt.Errorf("query tasks for run: %w", err)
	}
	defer rows.Close()

	var out []Task
	for rows.Next() {
		t, scanErr := scanTask(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		out = append(out, t)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate tasks for run: %w", iterErr)
	}
	return out, nil
}
// queryRowsContexter is the minimal multi-row query capability the list
// helpers need. Both *sql.DB and *sql.Tx provide this method.
type queryRowsContexter interface {
	QueryContext(c context.Context, sqlText string, params ...any) (*sql.Rows, error)
}
// scanRun reads one runs row from scanner, in the column order run_id, goal,
// summary, status, created_at, updated_at, and converts the two timestamp
// columns with parseTime. Scan failures, including sql.ErrNoRows from a
// *sql.Row, are wrapped with %w so callers can still test for them.
func scanRun(scanner threadScanner) (Run, error) {
	var r Run
	var created, updated string
	dest := []any{&r.RunID, &r.Goal, &r.Summary, &r.Status, &created, &updated}
	if err := scanner.Scan(dest...); err != nil {
		return Run{}, fmt.Errorf("scan run: %w", err)
	}
	r.CreatedAt, r.UpdatedAt = parseTime(created), parseTime(updated)
	return r, nil
}
// scanTask reads one tasks row from scanner, in the column order used by the
// task SELECTs in this file. A NULL default_to becomes "", and a NULL
// latest_attempt_no leaves LatestAttemptNo at 0. acceptance_json is kept as
// raw JSON, and timestamps are converted with parseTime.
func scanTask(scanner threadScanner) (Task, error) {
	var (
		t          Task
		dflt       sql.NullString
		latest     sql.NullInt64
		acceptance string
		created    string
		updated    string
	)
	dest := []any{
		&t.RunID, &t.TaskID, &t.Title, &t.Summary, &t.Status,
		&dflt, &t.Priority, &acceptance, &latest, &created, &updated,
	}
	if err := scanner.Scan(dest...); err != nil {
		return Task{}, fmt.Errorf("scan task: %w", err)
	}
	t.DefaultTo = dflt.String
	t.AcceptanceJSON = json.RawMessage(acceptance)
	if latest.Valid {
		t.LatestAttemptNo = int(latest.Int64)
	}
	t.CreatedAt, t.UpdatedAt = parseTime(created), parseTime(updated)
	return t, nil
}
// scanAttempt reads one task_attempts row from scanner, in the column order
// used by selectAttempt. The six nullable workspace columns (base_ref
// through result_commit) map NULL to "", and timestamps are converted with
// parseTime.
func scanAttempt(scanner threadScanner) (TaskAttempt, error) {
	var (
		a       TaskAttempt
		opt     [6]sql.NullString // base_ref, base_commit, branch_name, worktree_path, workspace_status, result_commit
		created string
		updated string
	)
	err := scanner.Scan(
		&a.RunID, &a.TaskID, &a.AttemptNo, &a.AssignedTo, &a.ThreadID,
		&opt[0], &opt[1], &opt[2], &opt[3], &opt[4], &opt[5],
		&a.Status, &created, &updated,
	)
	if err != nil {
		return TaskAttempt{}, fmt.Errorf("scan attempt: %w", err)
	}
	a.BaseRef, a.BaseCommit, a.BranchName = opt[0].String, opt[1].String, opt[2].String
	a.WorktreePath, a.WorkspaceStatus, a.ResultCommit = opt[3].String, opt[4].String, opt[5].String
	a.CreatedAt, a.UpdatedAt = parseTime(created), parseTime(updated)
	return a, nil
}
// scanTaskAndAttempt reads one joined row: the 11 tasks columns (same order
// as scanTask) followed by the 14 task_attempts columns (same order as
// scanAttempt). Both are read in a single Scan call. NULL handling and
// timestamp parsing match the single-table scanners.
func scanTaskAndAttempt(scanner threadScanner) (Task, TaskAttempt, error) {
	var (
		t              Task
		a              TaskAttempt
		taskDefaultTo  sql.NullString
		taskLatest     sql.NullInt64
		taskAcceptance string
		taskStamps     [2]string         // created_at, updated_at
		attemptOpt     [6]sql.NullString // base_ref .. result_commit
		attemptStamps  [2]string         // created_at, updated_at
	)
	dest := []any{
		// tasks columns
		&t.RunID, &t.TaskID, &t.Title, &t.Summary, &t.Status,
		&taskDefaultTo, &t.Priority, &taskAcceptance, &taskLatest,
		&taskStamps[0], &taskStamps[1],
		// task_attempts columns
		&a.RunID, &a.TaskID, &a.AttemptNo, &a.AssignedTo, &a.ThreadID,
		&attemptOpt[0], &attemptOpt[1], &attemptOpt[2],
		&attemptOpt[3], &attemptOpt[4], &attemptOpt[5],
		&a.Status, &attemptStamps[0], &attemptStamps[1],
	}
	if err := scanner.Scan(dest...); err != nil {
		return Task{}, TaskAttempt{}, fmt.Errorf("scan task and attempt: %w", err)
	}

	t.DefaultTo = taskDefaultTo.String
	t.AcceptanceJSON = json.RawMessage(taskAcceptance)
	if taskLatest.Valid {
		t.LatestAttemptNo = int(taskLatest.Int64)
	}
	t.CreatedAt, t.UpdatedAt = parseTime(taskStamps[0]), parseTime(taskStamps[1])

	a.BaseRef, a.BaseCommit, a.BranchName = attemptOpt[0].String, attemptOpt[1].String, attemptOpt[2].String
	a.WorktreePath, a.WorkspaceStatus, a.ResultCommit = attemptOpt[3].String, attemptOpt[4].String, attemptOpt[5].String
	a.CreatedAt, a.UpdatedAt = parseTime(attemptStamps[0]), parseTime(attemptStamps[1])
	return t, a, nil
}
// selectRun loads one run by id through db (the store's *sql.DB or an open
// transaction). A missing row is reported as ErrRunNotFound wrapped with
// the id; other errors pass through from scanRun.
func selectRun(ctx context.Context, db queryRower, runID string) (Run, error) {
	const query = `SELECT run_id, goal, summary, status, created_at, updated_at
FROM runs
WHERE run_id = ?`
	run, err := scanRun(db.QueryRowContext(ctx, query, runID))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return Run{}, fmt.Errorf("%w: %s", ErrRunNotFound, runID)
	case err != nil:
		return Run{}, err
	default:
		return run, nil
	}
}
// selectTask loads one task by its (runID, taskID) key through db, e.g. an
// open *sql.Tx. A missing row is reported as ErrTaskNotFound wrapped with
// "runID/taskID"; other errors pass through from scanTask.
func selectTask(ctx context.Context, db queryRower, runID, taskID string) (Task, error) {
	const query = `SELECT
run_id, task_id, title, summary, status, default_to, priority,
acceptance_json, latest_attempt_no, created_at, updated_at
FROM tasks
WHERE run_id = ? AND task_id = ?`
	task, err := scanTask(db.QueryRowContext(ctx, query, runID, taskID))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return Task{}, fmt.Errorf("%w: %s/%s", ErrTaskNotFound, runID, taskID)
	case err != nil:
		return Task{}, err
	default:
		return task, nil
	}
}
// selectAttempt loads attempt attemptNo of (runID, taskID) through db. A
// missing attempt is reported as ErrInvalidState rather than a not-found
// error, because callers only ask for attempts a task row already points to.
func selectAttempt(ctx context.Context, db queryRower, runID, taskID string, attemptNo int) (TaskAttempt, error) {
	const query = `SELECT
run_id, task_id, attempt_no, assigned_to, thread_id, base_ref, base_commit,
branch_name, worktree_path, workspace_status, result_commit, status,
created_at, updated_at
FROM task_attempts
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`
	attempt, err := scanAttempt(db.QueryRowContext(ctx, query, runID, taskID, attemptNo))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return TaskAttempt{}, fmt.Errorf("%w: attempt %s/%s/%d not found", ErrInvalidState, runID, taskID, attemptNo)
	case err != nil:
		return TaskAttempt{}, err
	default:
		return attempt, nil
	}
}
// selectLatestQuestionMessage returns the newest "question" message on
// threadID (by created_at), with its artifacts attached. A thread without
// any question message is reported as ErrInvalidState, since callers only
// ask for blocked threads.
func selectLatestQuestionMessage(ctx context.Context, db queryRowsAndRower, threadID string) (Message, error) {
	const query = `SELECT
message_id, thread_id, from_agent, to_agent, kind, summary, body,
payload_json, created_at
FROM messages
WHERE thread_id = ? AND kind = 'question'
ORDER BY created_at DESC
LIMIT 1`
	msg, err := scanMessage(db.QueryRowContext(ctx, query, threadID))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return Message{}, fmt.Errorf("%w: blocked thread %s has no question message", ErrInvalidState, threadID)
	case err != nil:
		return Message{}, err
	}

	attachments, err := loadArtifactsForMessageIDsFromQueryer(ctx, db, []string{msg.MessageID})
	if err != nil {
		return Message{}, err
	}
	msg.Artifacts = attachments[msg.MessageID]
	return msg, nil
}
// queryRowsAndRower combines single-row lookups (queryRower) with multi-row
// queries (queryRowsContexter). It lets one helper both fetch a row and load
// related rows through the same transaction.
type queryRowsAndRower interface {
	queryRower
	queryRowsContexter
}
// loadArtifactsForMessageIDsFromQueryer fetches all artifacts attached to any
// of messageIDs with a single IN (...) query and groups them by message ID.
// Within each group, artifacts keep the query's created_at ascending order.
//
// The returned map is never nil. Message IDs without artifacts are simply
// absent from it. An empty messageIDs slice returns an empty map without
// touching the database.
//
// NOTE(review): one bind parameter is used per ID, so very large inputs could
// exceed the database's bind-variable limit. Current visible callers pass one.
func loadArtifactsForMessageIDsFromQueryer(ctx context.Context, db queryRowsContexter, messageIDs []string) (map[string][]Artifact, error) {
	grouped := make(map[string][]Artifact)
	if len(messageIDs) == 0 {
		return grouped, nil
	}

	bindArgs := make([]any, len(messageIDs))
	for i, id := range messageIDs {
		bindArgs[i] = id
	}
	query := `SELECT
artifact_id, message_id, path, kind, metadata_json, created_at
FROM artifacts
WHERE message_id IN (` + placeholders(len(messageIDs)) + `)
ORDER BY created_at ASC`

	rows, err := db.QueryContext(ctx, query, bindArgs...)
	if err != nil {
		return nil, fmt.Errorf("query artifacts: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		artifact, scanErr := scanArtifact(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		grouped[artifact.MessageID] = append(grouped[artifact.MessageID], artifact)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate artifacts: %w", iterErr)
	}
	return grouped, nil
}
// refreshReadyStates re-evaluates every task in runID whose status is still
// 'planned' or 'ready' and moves it to whichever of the two its dependencies
// justify. A task is ready when dependenciesSatisfied reports true for it.
//
// Only tasks whose status actually changes are updated, with updated_at set to
// now. For each task that becomes 'ready', a "task_ready" event is recorded in
// the same transaction. The candidate list is read completely before any
// UPDATE is issued, so no result set is open while writing.
func refreshReadyStates(ctx context.Context, tx *sql.Tx, runID string, now time.Time) error {
	type candidate struct {
		id, status, title string
	}

	candidates, err := func() ([]candidate, error) {
		rows, err := tx.QueryContext(
			ctx,
			`SELECT task_id, status, title
FROM tasks
WHERE run_id = ?
AND status IN ('planned', 'ready')`,
			runID,
		)
		if err != nil {
			return nil, fmt.Errorf("query tasks for readiness refresh: %w", err)
		}
		defer rows.Close()

		var found []candidate
		for rows.Next() {
			var c candidate
			if err := rows.Scan(&c.id, &c.status, &c.title); err != nil {
				return nil, fmt.Errorf("scan readiness refresh row: %w", err)
			}
			found = append(found, c)
		}
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("iterate readiness refresh rows: %w", err)
		}
		return found, nil
	}()
	if err != nil {
		return err
	}

	for _, c := range candidates {
		satisfied, err := dependenciesSatisfied(ctx, tx, runID, c.id)
		if err != nil {
			return err
		}

		target := "planned"
		if satisfied {
			target = "ready"
		}
		if c.status == target {
			// Already in the right state; nothing to write or announce.
			continue
		}

		if _, err := tx.ExecContext(
			ctx,
			`UPDATE tasks
SET status = ?, updated_at = ?
WHERE run_id = ? AND task_id = ?`,
			target,
			formatTime(now),
			runID,
			c.id,
		); err != nil {
			return fmt.Errorf("update task readiness: %w", err)
		}

		if target != "ready" {
			continue
		}
		readyEvent := eventInput{
			RunID:       runID,
			TaskID:      c.id,
			Source:      "orch",
			EventType:   "task_ready",
			Summary:     defaultString(c.title, c.id),
			PayloadJSON: marshalJSON(map[string]any{"task_id": c.id}),
			CreatedAt:   now,
		}
		if err := insertEvent(ctx, tx, readyEvent); err != nil {
			return err
		}
	}
	return nil
}
// dependenciesSatisfied reports whether every dependency recorded in
// task_dependencies for (runID, taskID) refers to a task in the same run whose
// status is 'done'. A task with no dependency rows is trivially satisfied.
//
// A dependency row whose depends_on_task_id has no matching task row, or whose
// task has a NULL status, counts as pending. The LEFT JOIN keeps such rows with
// dep.status NULL. An inner join would silently drop them, and a plain
// `<> 'done'` comparison is never true for NULL. Either would make the
// dependent task look ready before its prerequisite exists or has finished.
func dependenciesSatisfied(ctx context.Context, tx *sql.Tx, runID, taskID string) (bool, error) {
	var pendingCount int
	err := tx.QueryRowContext(
		ctx,
		`SELECT COUNT(*)
			FROM task_dependencies d
			LEFT JOIN tasks dep
				ON dep.run_id = d.run_id
				AND dep.task_id = d.depends_on_task_id
			WHERE d.run_id = ?
				AND d.task_id = ?
				AND (dep.status IS NULL OR dep.status <> 'done')`,
		runID,
		taskID,
	).Scan(&pendingCount)
	if err != nil {
		return false, fmt.Errorf("query dependency readiness: %w", err)
	}
	return pendingCount == 0, nil
}
// updateRunAggregateStatus recomputes the run's status from its current
// per-status task counts (collectTaskCounts, then deriveRunStatus) and stores
// it, together with updated_at = now, inside tx.
//
// The number of affected rows is not checked, so an unknown runID is a silent
// no-op rather than an error.
func updateRunAggregateStatus(ctx context.Context, tx *sql.Tx, runID string, now time.Time) error {
	counts, err := collectTaskCounts(ctx, tx, runID)
	if err != nil {
		return err
	}

	if _, err := tx.ExecContext(
		ctx,
		`UPDATE runs
SET status = ?, updated_at = ?
WHERE run_id = ?`,
		deriveRunStatus(counts),
		formatTime(now),
		runID,
	); err != nil {
		return fmt.Errorf("update run aggregate status: %w", err)
	}
	return nil
}
// collectTaskCounts returns how many tasks of runID are in each status, keyed
// by the status string. Statuses with no tasks are absent from the map. The
// result is non-nil on success, even for a run with no tasks.
func collectTaskCounts(ctx context.Context, db queryRowsContexter, runID string) (map[string]int, error) {
	const query = `SELECT status, COUNT(*)
FROM tasks
WHERE run_id = ?
GROUP BY status`

	rows, err := db.QueryContext(ctx, query, runID)
	if err != nil {
		return nil, fmt.Errorf("query task counts: %w", err)
	}
	defer rows.Close()

	byStatus := map[string]int{}
	var (
		status string
		n      int
	)
	for rows.Next() {
		if scanErr := rows.Scan(&status, &n); scanErr != nil {
			return nil, fmt.Errorf("scan task count: %w", scanErr)
		}
		byStatus[status] = n
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate task counts: %w", iterErr)
	}
	return byStatus, nil
}
// deriveRunStatus folds per-status task counts into a single run status.
//
// Rules are checked in priority order, and the first match wins:
//   - no tasks at all               -> "active"
//   - any "blocked"                 -> "blocked"
//   - any "failed"                  -> "failed"
//   - any "running" or "dispatched" -> "running"
//   - any "ready"                   -> "ready"
//   - any "planned"                 -> "planned"
//   - any "done"                    -> "done" (even if some are cancelled)
//   - every task "cancelled"        -> "cancelled"
//   - anything else                 -> "active"
func deriveRunStatus(counts map[string]int) string {
	var total int
	for _, n := range counts {
		total += n
	}

	switch {
	case total == 0:
		return "active"
	case counts["blocked"] > 0:
		return "blocked"
	case counts["failed"] > 0:
		return "failed"
	case counts["running"] > 0 || counts["dispatched"] > 0:
		return "running"
	case counts["ready"] > 0:
		return "ready"
	case counts["planned"] > 0:
		return "planned"
	case counts["done"] > 0:
		return "done"
	case counts["cancelled"] == total:
		return "cancelled"
	default:
		return "active"
	}
}
// reconcileTaskStatus maps a thread status to the task status it implies:
// "pending" becomes "dispatched", and "claimed"/"in_progress" become "running".
// "blocked", "done", "failed" and "cancelled" carry over unchanged. Any other
// value yields "" (presumably meaning "no task change"; confirm at call sites).
func reconcileTaskStatus(threadStatus string) string {
	switch threadStatus {
	case "blocked", "done", "failed", "cancelled":
		// These names are shared by threads and tasks.
		return threadStatus
	case "claimed", "in_progress":
		return "running"
	case "pending":
		return "dispatched"
	}
	return ""
}
// normalizePriority trims priority and falls back to "normal" when nothing is
// left (via defaultString). It then accepts only "low", "normal" or "high",
// compared case-sensitively. Any other value yields an error wrapping
// ErrInvalidInput.
func normalizePriority(priority string) (string, error) {
	normalized := defaultString(strings.TrimSpace(priority), "normal")
	if normalized != "low" && normalized != "normal" && normalized != "high" {
		return "", fmt.Errorf("%w: priority must be one of low, normal, high", ErrInvalidInput)
	}
	return normalized, nil
}
// validateAndNormalizeJSONDefault trims value, substitutes defaultValue when
// the result is empty, and returns the JSON in compact form (insignificant
// whitespace removed).
//
// fieldName is used only in the error message. If the chosen text, including
// a defaulted one, is not valid JSON, the error wraps ErrInvalidInput.
func validateAndNormalizeJSONDefault(fieldName, value, defaultValue string) (string, error) {
	normalized := strings.TrimSpace(value)
	if normalized == "" {
		normalized = defaultValue
	}
	// json.Compact drives the same scanner as json.Valid and fails on any
	// syntax error, including empty input, so it validates and compacts in a
	// single pass. A separate json.Valid call would scan the input twice.
	var compact bytes.Buffer
	if err := json.Compact(&compact, []byte(normalized)); err != nil {
		return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
	}
	return compact.String(), nil
}
// buildDispatchPayload renders the JSON payload describing one dispatch of
// task as attempt attemptNo. The payload contains run_id, task_id,
// attempt_no, title, summary and priority. It also contains:
//   - "acceptance": the task's AcceptanceJSON, but only when that is valid JSON;
//   - "base_ref": baseRef with surrounding whitespace removed, when non-blank.
//
// Top-level keys are emitted in sorted order (map encoding). The result comes
// from marshalJSON, which falls back to "{}" if encoding fails.
func buildDispatchPayload(task Task, attemptNo int, baseRef string) string {
	payload := map[string]any{
		"run_id":     task.RunID,
		"task_id":    task.TaskID,
		"attempt_no": attemptNo,
		"title":      task.Title,
		"summary":    task.Summary,
		"priority":   task.Priority,
	}
	// Embed the stored acceptance JSON verbatim as a json.RawMessage. Decoding it
	// into `any` first would turn every number into a float64 and silently round
	// integers above 2^53. The Valid guard keeps the best-effort behavior of
	// skipping malformed JSON (and empty input) instead of letting it break the
	// whole payload encoding.
	if json.Valid(task.AcceptanceJSON) {
		payload["acceptance"] = task.AcceptanceJSON
	}
	if ref := strings.TrimSpace(baseRef); ref != "" {
		payload["base_ref"] = ref
	}
	return marshalJSON(payload)
}
// marshalJSON encodes v with json.Marshal and returns the text. Encoding
// failures are deliberately swallowed, and the empty object "{}" is returned
// instead, so callers always get a JSON string.
func marshalJSON(v any) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return "{}"
}
// nullIfEmpty returns nil (stored as SQL NULL when used as a query argument)
// for a blank or whitespace-only value. Otherwise it returns value unchanged,
// without trimming it.
func nullIfEmpty(value string) any {
	if trimmed := strings.TrimSpace(value); len(trimmed) > 0 {
		return value
	}
	return nil
}
// summarizeAnswer returns the first line of body with surrounding whitespace
// removed. It uses the placeholder "task answer" when body is blank, so a
// summary is never empty.
func summarizeAnswer(body string) string {
	const fallback = "task answer"
	firstLine, _, _ := strings.Cut(strings.TrimSpace(body), "\n")
	if firstLine = strings.TrimSpace(firstLine); firstLine != "" {
		return firstLine
	}
	return fallback
}
// isUniqueConstraintError reports whether err's message contains
// "unique constraint failed", compared case-insensitively. This is the text
// SQLite uses ("UNIQUE constraint failed: ..."), presumably the backing
// database here. Matching on text avoids depending on driver-specific error
// types. A nil error is never a constraint violation; without this check,
// calling err.Error() on it would panic.
func isUniqueConstraintError(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(strings.ToLower(err.Error()), "unique constraint failed")
}