1640 lines
41 KiB
Go
1640 lines
41 KiB
Go
package store
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Sentinel errors for missing store entities. Compare against them with
// errors.Is so that wrapped forms of the same error also match.
var (
	// ErrRunNotFound is the sentinel for a run id with no matching run.
	ErrRunNotFound = errors.New("run not found")

	// ErrTaskNotFound is the sentinel for a task id with no matching task.
	ErrTaskNotFound = errors.New("task not found")
)
|
|
|
|
// OrchStore persists orchestration state (runs, tasks, task dependencies and
// task attempts) through a SQL database handle. Construct it with
// NewOrchStore.
type OrchStore struct {
	db *sql.DB // every store method begins its transactions on this handle
}
|
|
|
|
// Run is one orchestration run: a top-level goal under which tasks are added.
// CreateRun starts every run in status "active".
type Run struct {
	RunID     string    `json:"run_id"`
	Goal      string    `json:"goal"`
	Summary   string    `json:"summary"`
	Status    string    `json:"status"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
|
|
|
|
// Task is a unit of work inside a run.
//
// Status values used by OrchStore include "planned" (set by AddTask),
// "ready" (required by DispatchTask), "dispatched" (set by DispatchTask) and
// "blocked" (required by AnswerTask); ReconcileRun can set further values
// derived from thread status.
type Task struct {
	RunID   string `json:"run_id"`
	TaskID  string `json:"task_id"`
	Title   string `json:"title"`
	Summary string `json:"summary"`
	Status  string `json:"status"`

	// DefaultTo is the agent DispatchTask falls back to when no explicit
	// target agent is given. Omitted from JSON when empty.
	DefaultTo string `json:"default_to,omitempty"`
	Priority  string `json:"priority"`

	// AcceptanceJSON is the task's acceptance data kept as raw JSON
	// (AddTask passes "[]" as its default).
	AcceptanceJSON json.RawMessage `json:"acceptance_json"`

	// LatestAttemptNo is the number of the most recent attempt; 0 means the
	// task has no attempt yet (and the field is then omitted from JSON).
	LatestAttemptNo int `json:"latest_attempt_no,omitempty"`

	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
|
|
|
|
// TaskDependency is a single edge of a run's dependency graph: TaskID depends
// on DependsOnTaskID, and both tasks belong to RunID.
type TaskDependency struct {
	RunID           string `json:"run_id"`
	TaskID          string `json:"task_id"`
	DependsOnTaskID string `json:"depends_on_task_id"`
}
|
|
|
|
// TaskAttempt records one dispatch of a task to an agent, including the
// thread the work is discussed on.
//
// DispatchTask writes BaseRef (when given) and leaves the workspace columns
// (BaseCommit through ResultCommit) NULL, which scan as "". Presumably a later
// workspace step fills them in (TODO confirm). Empty values are omitted from
// JSON.
type TaskAttempt struct {
	RunID      string `json:"run_id"`
	TaskID     string `json:"task_id"`
	AttemptNo  int    `json:"attempt_no"`
	AssignedTo string `json:"assigned_to"`
	ThreadID   string `json:"thread_id"`

	BaseRef         string `json:"base_ref,omitempty"`
	BaseCommit      string `json:"base_commit,omitempty"`
	BranchName      string `json:"branch_name,omitempty"`
	WorktreePath    string `json:"worktree_path,omitempty"`
	WorkspaceStatus string `json:"workspace_status,omitempty"`
	ResultCommit    string `json:"result_commit,omitempty"`

	Status    string    `json:"status"`
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"`
}
|
|
|
|
type RunOverview struct {
|
|
Run Run `json:"run"`
|
|
TaskCounts map[string]int `json:"task_counts"`
|
|
Tasks []Task `json:"tasks,omitempty"`
|
|
}
|
|
|
|
// CreateRunInput carries the arguments for CreateRun. RunID and Goal are
// required; all three fields are trimmed before use.
type CreateRunInput struct {
	RunID   string
	Goal    string
	Summary string
}
|
|
|
|
// AddTaskInput carries the arguments for AddTask.
//
// RunID, TaskID and Title are required. AcceptanceJSON must be valid JSON when
// set, and Priority is checked by normalizePriority.
type AddTaskInput struct {
	RunID          string
	TaskID         string
	Title          string
	Summary        string
	DefaultTo      string // optional default agent for DispatchTask
	AcceptanceJSON string
	Priority       string
}
|
|
|
|
// AddDependencyInput carries the arguments for AddDependency: TaskID will
// depend on DependsOnTaskID, and both tasks must belong to RunID.
type AddDependencyInput struct {
	RunID           string
	TaskID          string
	DependsOnTaskID string
}
|
|
|
|
// ListReadyInput carries the arguments for ListReadyTasks.
type ListReadyInput struct {
	RunID string
	Limit int // maximum rows to return; ListReadyTasks uses 20 when <= 0
}
|
|
|
|
// DispatchInput carries the arguments for DispatchTask.
//
// ToAgent overrides the task's default agent. Body becomes the body of the
// initial task message, and BaseRef is recorded on the new attempt.
type DispatchInput struct {
	RunID   string
	TaskID  string
	ToAgent string
	Body    string
	BaseRef string
}
|
|
|
|
type DispatchResult struct {
|
|
Task Task `json:"task"`
|
|
Attempt TaskAttempt `json:"attempt"`
|
|
Thread Thread `json:"thread"`
|
|
Message Message `json:"message"`
|
|
}
|
|
|
|
type ReconcileResult struct {
|
|
Run Run `json:"run"`
|
|
TaskCounts map[string]int `json:"task_counts"`
|
|
UpdatedTasks []Task `json:"updated_tasks"`
|
|
}
|
|
|
|
type BlockedTask struct {
|
|
Task Task `json:"task"`
|
|
Attempt TaskAttempt `json:"attempt"`
|
|
Question Message `json:"question"`
|
|
}
|
|
|
|
// AnswerInput carries the arguments for AnswerTask. At least one of Body or a
// non-empty PayloadJSON object must be supplied.
type AnswerInput struct {
	RunID       string
	TaskID      string
	Body        string
	PayloadJSON string
}
|
|
|
|
type AnswerResult struct {
|
|
Task Task `json:"task"`
|
|
Attempt TaskAttempt `json:"attempt"`
|
|
Thread Thread `json:"thread"`
|
|
Message Message `json:"message"`
|
|
}
|
|
|
|
func NewOrchStore(db *sql.DB) *OrchStore {
|
|
return &OrchStore{db: db}
|
|
}
|
|
|
|
func (s *OrchStore) CreateRun(ctx context.Context, input CreateRunInput) (Run, error) {
|
|
runID := strings.TrimSpace(input.RunID)
|
|
goal := strings.TrimSpace(input.Goal)
|
|
if runID == "" {
|
|
return Run{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
if goal == "" {
|
|
return Run{}, fmt.Errorf("%w: goal is required", ErrInvalidInput)
|
|
}
|
|
|
|
now := nowUTC()
|
|
run := Run{
|
|
RunID: runID,
|
|
Goal: goal,
|
|
Summary: strings.TrimSpace(input.Summary),
|
|
Status: "active",
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Run{}, fmt.Errorf("begin create run transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO runs (run_id, goal, summary, status, created_at, updated_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)`,
|
|
run.RunID,
|
|
run.Goal,
|
|
run.Summary,
|
|
run.Status,
|
|
formatTime(run.CreatedAt),
|
|
formatTime(run.UpdatedAt),
|
|
)
|
|
if err != nil {
|
|
if isUniqueConstraintError(err) {
|
|
return Run{}, fmt.Errorf("%w: run %s already exists", ErrInvalidState, run.RunID)
|
|
}
|
|
return Run{}, fmt.Errorf("insert run: %w", err)
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: run.RunID,
|
|
TaskID: "",
|
|
Source: "orch",
|
|
EventType: "run_initialized",
|
|
Summary: defaultString(run.Summary, run.Goal),
|
|
PayloadJSON: marshalJSON(map[string]any{"goal": run.Goal, "summary": run.Summary}),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return Run{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Run{}, fmt.Errorf("commit create run transaction: %w", err)
|
|
}
|
|
|
|
return run, nil
|
|
}
|
|
|
|
func (s *OrchStore) GetRun(ctx context.Context, runID string) (Run, error) {
|
|
return selectRun(ctx, s.db, runID)
|
|
}
|
|
|
|
func (s *OrchStore) AddTask(ctx context.Context, input AddTaskInput) (Task, error) {
|
|
if strings.TrimSpace(input.RunID) == "" {
|
|
return Task{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
if strings.TrimSpace(input.TaskID) == "" {
|
|
return Task{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
|
|
}
|
|
if strings.TrimSpace(input.Title) == "" {
|
|
return Task{}, fmt.Errorf("%w: title is required", ErrInvalidInput)
|
|
}
|
|
|
|
priority, err := normalizePriority(input.Priority)
|
|
if err != nil {
|
|
return Task{}, err
|
|
}
|
|
acceptanceJSON, err := validateAndNormalizeJSONDefault("acceptance-json", input.AcceptanceJSON, "[]")
|
|
if err != nil {
|
|
return Task{}, err
|
|
}
|
|
|
|
now := nowUTC()
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return Task{}, fmt.Errorf("begin add task transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := selectRun(ctx, tx, input.RunID); err != nil {
|
|
return Task{}, err
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO tasks (
|
|
run_id, task_id, title, summary, status, default_to, priority,
|
|
acceptance_json, latest_attempt_no, created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, NULL, ?, ?)`,
|
|
input.RunID,
|
|
input.TaskID,
|
|
input.Title,
|
|
input.Summary,
|
|
"planned",
|
|
nullIfEmpty(input.DefaultTo),
|
|
priority,
|
|
acceptanceJSON,
|
|
formatTime(now),
|
|
formatTime(now),
|
|
)
|
|
if err != nil {
|
|
if isUniqueConstraintError(err) {
|
|
return Task{}, fmt.Errorf("%w: task %s already exists in run %s", ErrInvalidState, input.TaskID, input.RunID)
|
|
}
|
|
return Task{}, fmt.Errorf("insert task: %w", err)
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: input.RunID,
|
|
TaskID: input.TaskID,
|
|
Source: "orch",
|
|
EventType: "task_added",
|
|
Summary: input.Title,
|
|
PayloadJSON: marshalJSON(map[string]any{"title": input.Title, "priority": priority}),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return Task{}, err
|
|
}
|
|
|
|
if err := refreshReadyStates(ctx, tx, input.RunID, now); err != nil {
|
|
return Task{}, err
|
|
}
|
|
if err := updateRunAggregateStatus(ctx, tx, input.RunID, now); err != nil {
|
|
return Task{}, err
|
|
}
|
|
|
|
task, err := selectTask(ctx, tx, input.RunID, input.TaskID)
|
|
if err != nil {
|
|
return Task{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return Task{}, fmt.Errorf("commit add task transaction: %w", err)
|
|
}
|
|
|
|
return task, nil
|
|
}
|
|
|
|
func (s *OrchStore) AddDependency(ctx context.Context, input AddDependencyInput) (TaskDependency, error) {
|
|
if strings.TrimSpace(input.RunID) == "" {
|
|
return TaskDependency{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
if strings.TrimSpace(input.TaskID) == "" {
|
|
return TaskDependency{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
|
|
}
|
|
if strings.TrimSpace(input.DependsOnTaskID) == "" {
|
|
return TaskDependency{}, fmt.Errorf("%w: depends-on task id is required", ErrInvalidInput)
|
|
}
|
|
if input.TaskID == input.DependsOnTaskID {
|
|
return TaskDependency{}, fmt.Errorf("%w: task cannot depend on itself", ErrInvalidInput)
|
|
}
|
|
|
|
now := nowUTC()
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return TaskDependency{}, fmt.Errorf("begin add dependency transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := selectRun(ctx, tx, input.RunID); err != nil {
|
|
return TaskDependency{}, err
|
|
}
|
|
if _, err := selectTask(ctx, tx, input.RunID, input.TaskID); err != nil {
|
|
return TaskDependency{}, err
|
|
}
|
|
if _, err := selectTask(ctx, tx, input.RunID, input.DependsOnTaskID); err != nil {
|
|
return TaskDependency{}, err
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO task_dependencies (run_id, task_id, depends_on_task_id)
|
|
VALUES (?, ?, ?)`,
|
|
input.RunID,
|
|
input.TaskID,
|
|
input.DependsOnTaskID,
|
|
)
|
|
if err != nil {
|
|
if isUniqueConstraintError(err) {
|
|
return TaskDependency{}, fmt.Errorf("%w: dependency %s -> %s already exists", ErrInvalidState, input.TaskID, input.DependsOnTaskID)
|
|
}
|
|
return TaskDependency{}, fmt.Errorf("insert dependency: %w", err)
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: input.RunID,
|
|
TaskID: input.TaskID,
|
|
Source: "orch",
|
|
EventType: "task_dependency_added",
|
|
Summary: fmt.Sprintf("%s depends on %s", input.TaskID, input.DependsOnTaskID),
|
|
PayloadJSON: marshalJSON(map[string]any{"depends_on_task_id": input.DependsOnTaskID}),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return TaskDependency{}, err
|
|
}
|
|
|
|
if err := refreshReadyStates(ctx, tx, input.RunID, now); err != nil {
|
|
return TaskDependency{}, err
|
|
}
|
|
if err := updateRunAggregateStatus(ctx, tx, input.RunID, now); err != nil {
|
|
return TaskDependency{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return TaskDependency{}, fmt.Errorf("commit add dependency transaction: %w", err)
|
|
}
|
|
|
|
return TaskDependency{
|
|
RunID: input.RunID,
|
|
TaskID: input.TaskID,
|
|
DependsOnTaskID: input.DependsOnTaskID,
|
|
}, nil
|
|
}
|
|
|
|
func (s *OrchStore) ListReadyTasks(ctx context.Context, input ListReadyInput) ([]Task, error) {
|
|
if strings.TrimSpace(input.RunID) == "" {
|
|
return nil, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
|
|
limit := input.Limit
|
|
if limit <= 0 {
|
|
limit = 20
|
|
}
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("begin list ready transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := selectRun(ctx, tx, input.RunID); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := refreshReadyStates(ctx, tx, input.RunID, nowUTC()); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := updateRunAggregateStatus(ctx, tx, input.RunID, nowUTC()); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := tx.QueryContext(
|
|
ctx,
|
|
`SELECT
|
|
run_id, task_id, title, summary, status, default_to, priority,
|
|
acceptance_json, latest_attempt_no, created_at, updated_at
|
|
FROM tasks
|
|
WHERE run_id = ? AND status = 'ready'
|
|
ORDER BY CASE priority
|
|
WHEN 'high' THEN 0
|
|
WHEN 'normal' THEN 1
|
|
ELSE 2
|
|
END, created_at ASC
|
|
LIMIT ?`,
|
|
input.RunID,
|
|
limit,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query ready tasks: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tasks []Task
|
|
for rows.Next() {
|
|
task, err := scanTask(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate ready tasks: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("commit list ready transaction: %w", err)
|
|
}
|
|
|
|
return tasks, nil
|
|
}
|
|
|
|
func (s *OrchStore) DispatchTask(ctx context.Context, input DispatchInput) (DispatchResult, error) {
|
|
if strings.TrimSpace(input.RunID) == "" {
|
|
return DispatchResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
if strings.TrimSpace(input.TaskID) == "" {
|
|
return DispatchResult{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
|
|
}
|
|
|
|
now := nowUTC()
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return DispatchResult{}, fmt.Errorf("begin dispatch transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := selectRun(ctx, tx, input.RunID); err != nil {
|
|
return DispatchResult{}, err
|
|
}
|
|
if err := refreshReadyStates(ctx, tx, input.RunID, now); err != nil {
|
|
return DispatchResult{}, err
|
|
}
|
|
|
|
task, err := selectTask(ctx, tx, input.RunID, input.TaskID)
|
|
if err != nil {
|
|
return DispatchResult{}, err
|
|
}
|
|
if task.Status != "ready" {
|
|
return DispatchResult{}, fmt.Errorf("%w: task %s is not ready for dispatch", ErrInvalidState, task.TaskID)
|
|
}
|
|
|
|
assignedTo := defaultString(strings.TrimSpace(input.ToAgent), task.DefaultTo)
|
|
if assignedTo == "" {
|
|
return DispatchResult{}, fmt.Errorf("%w: dispatch target agent is required", ErrInvalidInput)
|
|
}
|
|
|
|
attemptNo := task.LatestAttemptNo + 1
|
|
threadID := newID("thr")
|
|
messageID := newID("msg")
|
|
payloadJSON := buildDispatchPayload(task, attemptNo, input.BaseRef)
|
|
thread := Thread{
|
|
ThreadID: threadID,
|
|
RunID: task.RunID,
|
|
TaskID: task.TaskID,
|
|
Subject: task.Title,
|
|
CreatedBy: "orch",
|
|
AssignedTo: assignedTo,
|
|
Status: "pending",
|
|
Priority: task.Priority,
|
|
LatestMessageID: messageID,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO threads (
|
|
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
|
|
priority, latest_message_id, created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
thread.ThreadID,
|
|
thread.RunID,
|
|
thread.TaskID,
|
|
thread.Subject,
|
|
thread.CreatedBy,
|
|
thread.AssignedTo,
|
|
thread.Status,
|
|
thread.Priority,
|
|
thread.LatestMessageID,
|
|
formatTime(thread.CreatedAt),
|
|
formatTime(thread.UpdatedAt),
|
|
)
|
|
if err != nil {
|
|
return DispatchResult{}, fmt.Errorf("insert dispatch thread: %w", err)
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: messageID,
|
|
ThreadID: threadID,
|
|
FromAgent: "orch",
|
|
ToAgent: assignedTo,
|
|
Kind: "task",
|
|
Summary: defaultString(task.Summary, task.Title),
|
|
Body: input.Body,
|
|
PayloadJSON: json.RawMessage(payloadJSON),
|
|
CreatedAt: now,
|
|
}
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return DispatchResult{}, err
|
|
}
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_created",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: payloadJSON,
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return DispatchResult{}, err
|
|
}
|
|
|
|
attempt := TaskAttempt{
|
|
RunID: task.RunID,
|
|
TaskID: task.TaskID,
|
|
AttemptNo: attemptNo,
|
|
AssignedTo: assignedTo,
|
|
ThreadID: threadID,
|
|
BaseRef: strings.TrimSpace(input.BaseRef),
|
|
Status: "dispatched",
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO task_attempts (
|
|
run_id, task_id, attempt_no, assigned_to, thread_id, base_ref, base_commit,
|
|
branch_name, worktree_path, workspace_status, result_commit, status,
|
|
created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
attempt.RunID,
|
|
attempt.TaskID,
|
|
attempt.AttemptNo,
|
|
attempt.AssignedTo,
|
|
attempt.ThreadID,
|
|
nullIfEmpty(attempt.BaseRef),
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
nil,
|
|
attempt.Status,
|
|
formatTime(attempt.CreatedAt),
|
|
formatTime(attempt.UpdatedAt),
|
|
)
|
|
if err != nil {
|
|
return DispatchResult{}, fmt.Errorf("insert task attempt: %w", err)
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`UPDATE tasks
|
|
SET status = ?, latest_attempt_no = ?, updated_at = ?
|
|
WHERE run_id = ? AND task_id = ?`,
|
|
"dispatched",
|
|
attempt.AttemptNo,
|
|
formatTime(now),
|
|
task.RunID,
|
|
task.TaskID,
|
|
)
|
|
if err != nil {
|
|
return DispatchResult{}, fmt.Errorf("update task dispatch status: %w", err)
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: task.RunID,
|
|
TaskID: task.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "orch",
|
|
EventType: "task_dispatched",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: payloadJSON,
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return DispatchResult{}, err
|
|
}
|
|
|
|
if err := updateRunAggregateStatus(ctx, tx, task.RunID, now); err != nil {
|
|
return DispatchResult{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return DispatchResult{}, fmt.Errorf("commit dispatch transaction: %w", err)
|
|
}
|
|
|
|
task.Status = "dispatched"
|
|
task.LatestAttemptNo = attempt.AttemptNo
|
|
task.UpdatedAt = now
|
|
|
|
return DispatchResult{
|
|
Task: task,
|
|
Attempt: attempt,
|
|
Thread: thread,
|
|
Message: message,
|
|
}, nil
|
|
}
|
|
|
|
func (s *OrchStore) ReconcileRun(ctx context.Context, runID string) (ReconcileResult, error) {
|
|
if strings.TrimSpace(runID) == "" {
|
|
return ReconcileResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
|
|
now := nowUTC()
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return ReconcileResult{}, fmt.Errorf("begin reconcile transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := selectRun(ctx, tx, runID); err != nil {
|
|
return ReconcileResult{}, err
|
|
}
|
|
|
|
rows, err := tx.QueryContext(
|
|
ctx,
|
|
`SELECT
|
|
t.task_id,
|
|
t.status,
|
|
a.attempt_no,
|
|
a.status,
|
|
a.thread_id,
|
|
th.status
|
|
FROM tasks t
|
|
JOIN task_attempts a
|
|
ON a.run_id = t.run_id
|
|
AND a.task_id = t.task_id
|
|
AND a.attempt_no = t.latest_attempt_no
|
|
JOIN threads th ON th.thread_id = a.thread_id
|
|
WHERE t.run_id = ?
|
|
AND t.latest_attempt_no IS NOT NULL`,
|
|
runID,
|
|
)
|
|
if err != nil {
|
|
return ReconcileResult{}, fmt.Errorf("query reconcile candidates: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var updatedIDs []string
|
|
for rows.Next() {
|
|
var (
|
|
taskID string
|
|
taskStatus string
|
|
attemptNo int
|
|
attemptStatus string
|
|
threadID string
|
|
threadStatus string
|
|
)
|
|
if err := rows.Scan(&taskID, &taskStatus, &attemptNo, &attemptStatus, &threadID, &threadStatus); err != nil {
|
|
return ReconcileResult{}, fmt.Errorf("scan reconcile candidate: %w", err)
|
|
}
|
|
|
|
nextStatus := reconcileTaskStatus(threadStatus)
|
|
if nextStatus == "" {
|
|
continue
|
|
}
|
|
if nextStatus == taskStatus && nextStatus == attemptStatus {
|
|
continue
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`UPDATE tasks
|
|
SET status = ?, updated_at = ?
|
|
WHERE run_id = ? AND task_id = ?`,
|
|
nextStatus,
|
|
formatTime(now),
|
|
runID,
|
|
taskID,
|
|
)
|
|
if err != nil {
|
|
return ReconcileResult{}, fmt.Errorf("update reconciled task status: %w", err)
|
|
}
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`UPDATE task_attempts
|
|
SET status = ?, updated_at = ?
|
|
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
|
|
nextStatus,
|
|
formatTime(now),
|
|
runID,
|
|
taskID,
|
|
attemptNo,
|
|
)
|
|
if err != nil {
|
|
return ReconcileResult{}, fmt.Errorf("update reconciled attempt status: %w", err)
|
|
}
|
|
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: runID,
|
|
TaskID: taskID,
|
|
ThreadID: threadID,
|
|
Source: "orch",
|
|
EventType: "task_" + nextStatus,
|
|
Summary: fmt.Sprintf("%s -> %s", taskID, nextStatus),
|
|
PayloadJSON: marshalJSON(map[string]any{
|
|
"thread_id": threadID,
|
|
"thread_status": threadStatus,
|
|
"previous_status": taskStatus,
|
|
"previous_attempt": attemptStatus,
|
|
}),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return ReconcileResult{}, err
|
|
}
|
|
|
|
updatedIDs = append(updatedIDs, taskID)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return ReconcileResult{}, fmt.Errorf("iterate reconcile candidates: %w", err)
|
|
}
|
|
|
|
if err := refreshReadyStates(ctx, tx, runID, now); err != nil {
|
|
return ReconcileResult{}, err
|
|
}
|
|
if err := updateRunAggregateStatus(ctx, tx, runID, now); err != nil {
|
|
return ReconcileResult{}, err
|
|
}
|
|
|
|
run, err := selectRun(ctx, tx, runID)
|
|
if err != nil {
|
|
return ReconcileResult{}, err
|
|
}
|
|
taskCounts, err := collectTaskCounts(ctx, tx, runID)
|
|
if err != nil {
|
|
return ReconcileResult{}, err
|
|
}
|
|
|
|
updatedTasks := make([]Task, 0, len(updatedIDs))
|
|
for _, taskID := range updatedIDs {
|
|
task, err := selectTask(ctx, tx, runID, taskID)
|
|
if err != nil {
|
|
return ReconcileResult{}, err
|
|
}
|
|
updatedTasks = append(updatedTasks, task)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return ReconcileResult{}, fmt.Errorf("commit reconcile transaction: %w", err)
|
|
}
|
|
|
|
return ReconcileResult{
|
|
Run: run,
|
|
TaskCounts: taskCounts,
|
|
UpdatedTasks: updatedTasks,
|
|
}, nil
|
|
}
|
|
|
|
func (s *OrchStore) ListBlockedTasks(ctx context.Context, runID string) ([]BlockedTask, error) {
|
|
if strings.TrimSpace(runID) == "" {
|
|
return nil, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("begin list blocked transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := selectRun(ctx, tx, runID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := tx.QueryContext(
|
|
ctx,
|
|
`SELECT
|
|
t.run_id, t.task_id, t.title, t.summary, t.status, t.default_to, t.priority,
|
|
t.acceptance_json, t.latest_attempt_no, t.created_at, t.updated_at,
|
|
a.run_id, a.task_id, a.attempt_no, a.assigned_to, a.thread_id, a.base_ref,
|
|
a.base_commit, a.branch_name, a.worktree_path, a.workspace_status,
|
|
a.result_commit, a.status, a.created_at, a.updated_at
|
|
FROM tasks t
|
|
JOIN task_attempts a
|
|
ON a.run_id = t.run_id
|
|
AND a.task_id = t.task_id
|
|
AND a.attempt_no = t.latest_attempt_no
|
|
WHERE t.run_id = ?
|
|
AND t.status = 'blocked'
|
|
ORDER BY t.updated_at ASC`,
|
|
runID,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query blocked tasks: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var blocked []BlockedTask
|
|
for rows.Next() {
|
|
task, attempt, err := scanTaskAndAttempt(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
question, err := selectLatestQuestionMessage(ctx, tx, attempt.ThreadID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
blocked = append(blocked, BlockedTask{
|
|
Task: task,
|
|
Attempt: attempt,
|
|
Question: question,
|
|
})
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate blocked tasks: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("commit list blocked transaction: %w", err)
|
|
}
|
|
|
|
return blocked, nil
|
|
}
|
|
|
|
func (s *OrchStore) AnswerTask(ctx context.Context, input AnswerInput) (AnswerResult, error) {
|
|
if strings.TrimSpace(input.RunID) == "" {
|
|
return AnswerResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
if strings.TrimSpace(input.TaskID) == "" {
|
|
return AnswerResult{}, fmt.Errorf("%w: task id is required", ErrInvalidInput)
|
|
}
|
|
|
|
payloadJSON, err := validateAndNormalizeJSON("payload-json", input.PayloadJSON)
|
|
if err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
if strings.TrimSpace(input.Body) == "" && payloadJSON == "{}" {
|
|
return AnswerResult{}, fmt.Errorf("%w: body or payload-json is required", ErrInvalidInput)
|
|
}
|
|
|
|
now := nowUTC()
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return AnswerResult{}, fmt.Errorf("begin answer transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
task, err := selectTask(ctx, tx, input.RunID, input.TaskID)
|
|
if err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
if task.Status != "blocked" {
|
|
return AnswerResult{}, fmt.Errorf("%w: task %s is not blocked", ErrInvalidState, task.TaskID)
|
|
}
|
|
if task.LatestAttemptNo == 0 {
|
|
return AnswerResult{}, fmt.Errorf("%w: task %s has no active attempt", ErrInvalidState, task.TaskID)
|
|
}
|
|
|
|
attempt, err := selectAttempt(ctx, tx, input.RunID, input.TaskID, task.LatestAttemptNo)
|
|
if err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
thread, err := selectThread(ctx, tx, attempt.ThreadID)
|
|
if err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
if isTerminalStatus(thread.Status) {
|
|
return AnswerResult{}, fmt.Errorf("%w: thread %s is already terminal", ErrInvalidState, thread.ThreadID)
|
|
}
|
|
|
|
message := Message{
|
|
MessageID: newID("msg"),
|
|
ThreadID: thread.ThreadID,
|
|
FromAgent: "orch",
|
|
ToAgent: attempt.AssignedTo,
|
|
Kind: "answer",
|
|
Summary: summarizeAnswer(input.Body),
|
|
Body: input.Body,
|
|
PayloadJSON: json.RawMessage(payloadJSON),
|
|
CreatedAt: now,
|
|
}
|
|
if err := insertMessage(ctx, tx, message); err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
if err := updateThreadState(ctx, tx, thread.ThreadID, thread.Status, thread.AssignedTo, message.MessageID, now); err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: thread.RunID,
|
|
TaskID: thread.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "inbox",
|
|
EventType: "thread_reply",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: payloadJSON,
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: task.RunID,
|
|
TaskID: task.TaskID,
|
|
ThreadID: thread.ThreadID,
|
|
Source: "orch",
|
|
EventType: "task_answered",
|
|
MessageID: message.MessageID,
|
|
Summary: message.Summary,
|
|
PayloadJSON: payloadJSON,
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`UPDATE tasks
|
|
SET updated_at = ?
|
|
WHERE run_id = ? AND task_id = ?`,
|
|
formatTime(now),
|
|
task.RunID,
|
|
task.TaskID,
|
|
)
|
|
if err != nil {
|
|
return AnswerResult{}, fmt.Errorf("touch answered task: %w", err)
|
|
}
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`UPDATE task_attempts
|
|
SET updated_at = ?
|
|
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
|
|
formatTime(now),
|
|
attempt.RunID,
|
|
attempt.TaskID,
|
|
attempt.AttemptNo,
|
|
)
|
|
if err != nil {
|
|
return AnswerResult{}, fmt.Errorf("touch answered attempt: %w", err)
|
|
}
|
|
if err := updateRunAggregateStatus(ctx, tx, task.RunID, now); err != nil {
|
|
return AnswerResult{}, err
|
|
}
|
|
|
|
task.UpdatedAt = now
|
|
attempt.UpdatedAt = now
|
|
thread.LatestMessageID = message.MessageID
|
|
thread.UpdatedAt = now
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return AnswerResult{}, fmt.Errorf("commit answer transaction: %w", err)
|
|
}
|
|
|
|
return AnswerResult{
|
|
Task: task,
|
|
Attempt: attempt,
|
|
Thread: thread,
|
|
Message: message,
|
|
}, nil
|
|
}
|
|
|
|
func (s *OrchStore) GetRunOverview(ctx context.Context, runID string) (RunOverview, error) {
|
|
if strings.TrimSpace(runID) == "" {
|
|
return RunOverview{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
|
}
|
|
|
|
now := nowUTC()
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return RunOverview{}, fmt.Errorf("begin run overview transaction: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
if _, err := selectRun(ctx, tx, runID); err != nil {
|
|
return RunOverview{}, err
|
|
}
|
|
if err := refreshReadyStates(ctx, tx, runID, now); err != nil {
|
|
return RunOverview{}, err
|
|
}
|
|
if err := updateRunAggregateStatus(ctx, tx, runID, now); err != nil {
|
|
return RunOverview{}, err
|
|
}
|
|
|
|
run, err := selectRun(ctx, tx, runID)
|
|
if err != nil {
|
|
return RunOverview{}, err
|
|
}
|
|
taskCounts, err := collectTaskCounts(ctx, tx, runID)
|
|
if err != nil {
|
|
return RunOverview{}, err
|
|
}
|
|
tasks, err := listTasksForRun(ctx, tx, runID)
|
|
if err != nil {
|
|
return RunOverview{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return RunOverview{}, fmt.Errorf("commit run overview transaction: %w", err)
|
|
}
|
|
|
|
return RunOverview{
|
|
Run: run,
|
|
TaskCounts: taskCounts,
|
|
Tasks: tasks,
|
|
}, nil
|
|
}
|
|
|
|
func listTasksForRun(ctx context.Context, db queryRowsContexter, runID string) ([]Task, error) {
|
|
rows, err := db.QueryContext(
|
|
ctx,
|
|
`SELECT
|
|
run_id, task_id, title, summary, status, default_to, priority,
|
|
acceptance_json, latest_attempt_no, created_at, updated_at
|
|
FROM tasks
|
|
WHERE run_id = ?
|
|
ORDER BY created_at ASC`,
|
|
runID,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query tasks for run: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tasks []Task
|
|
for rows.Next() {
|
|
task, err := scanTask(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate tasks for run: %w", err)
|
|
}
|
|
return tasks, nil
|
|
}
|
|
|
|
// queryRowsContexter is the minimal query capability shared by *sql.DB and
// *sql.Tx. It lets read helpers run either inside or outside a transaction.
type queryRowsContexter interface {
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
|
|
|
|
func scanRun(scanner threadScanner) (Run, error) {
|
|
var (
|
|
run Run
|
|
createdAt, updated string
|
|
)
|
|
|
|
if err := scanner.Scan(
|
|
&run.RunID,
|
|
&run.Goal,
|
|
&run.Summary,
|
|
&run.Status,
|
|
&createdAt,
|
|
&updated,
|
|
); err != nil {
|
|
return Run{}, fmt.Errorf("scan run: %w", err)
|
|
}
|
|
|
|
run.CreatedAt = parseTime(createdAt)
|
|
run.UpdatedAt = parseTime(updated)
|
|
return run, nil
|
|
}
|
|
|
|
func scanTask(scanner threadScanner) (Task, error) {
|
|
var (
|
|
task Task
|
|
defaultTo sql.NullString
|
|
latestAttempt sql.NullInt64
|
|
acceptanceJSON string
|
|
createdAt, updatedAt string
|
|
)
|
|
|
|
if err := scanner.Scan(
|
|
&task.RunID,
|
|
&task.TaskID,
|
|
&task.Title,
|
|
&task.Summary,
|
|
&task.Status,
|
|
&defaultTo,
|
|
&task.Priority,
|
|
&acceptanceJSON,
|
|
&latestAttempt,
|
|
&createdAt,
|
|
&updatedAt,
|
|
); err != nil {
|
|
return Task{}, fmt.Errorf("scan task: %w", err)
|
|
}
|
|
|
|
task.DefaultTo = defaultTo.String
|
|
task.AcceptanceJSON = json.RawMessage(acceptanceJSON)
|
|
if latestAttempt.Valid {
|
|
task.LatestAttemptNo = int(latestAttempt.Int64)
|
|
}
|
|
task.CreatedAt = parseTime(createdAt)
|
|
task.UpdatedAt = parseTime(updatedAt)
|
|
return task, nil
|
|
}
|
|
|
|
func scanAttempt(scanner threadScanner) (TaskAttempt, error) {
|
|
var (
|
|
attempt TaskAttempt
|
|
baseRef sql.NullString
|
|
baseCommit sql.NullString
|
|
branchName sql.NullString
|
|
worktreePath sql.NullString
|
|
workspaceStatus sql.NullString
|
|
resultCommit sql.NullString
|
|
createdAt, updated string
|
|
)
|
|
|
|
if err := scanner.Scan(
|
|
&attempt.RunID,
|
|
&attempt.TaskID,
|
|
&attempt.AttemptNo,
|
|
&attempt.AssignedTo,
|
|
&attempt.ThreadID,
|
|
&baseRef,
|
|
&baseCommit,
|
|
&branchName,
|
|
&worktreePath,
|
|
&workspaceStatus,
|
|
&resultCommit,
|
|
&attempt.Status,
|
|
&createdAt,
|
|
&updated,
|
|
); err != nil {
|
|
return TaskAttempt{}, fmt.Errorf("scan attempt: %w", err)
|
|
}
|
|
|
|
attempt.BaseRef = baseRef.String
|
|
attempt.BaseCommit = baseCommit.String
|
|
attempt.BranchName = branchName.String
|
|
attempt.WorktreePath = worktreePath.String
|
|
attempt.WorkspaceStatus = workspaceStatus.String
|
|
attempt.ResultCommit = resultCommit.String
|
|
attempt.CreatedAt = parseTime(createdAt)
|
|
attempt.UpdatedAt = parseTime(updated)
|
|
return attempt, nil
|
|
}
|
|
|
|
func scanTaskAndAttempt(scanner threadScanner) (Task, TaskAttempt, error) {
|
|
var (
|
|
task Task
|
|
taskDefaultTo sql.NullString
|
|
taskLatestAttempt sql.NullInt64
|
|
taskAcceptanceJSON string
|
|
taskCreatedAt string
|
|
taskUpdatedAt string
|
|
attempt TaskAttempt
|
|
attemptBaseRef sql.NullString
|
|
attemptBaseCommit sql.NullString
|
|
attemptBranchName sql.NullString
|
|
attemptWorktreePath sql.NullString
|
|
attemptWorkspaceState sql.NullString
|
|
attemptResultCommit sql.NullString
|
|
attemptCreatedAt string
|
|
attemptUpdatedAt string
|
|
)
|
|
|
|
if err := scanner.Scan(
|
|
&task.RunID,
|
|
&task.TaskID,
|
|
&task.Title,
|
|
&task.Summary,
|
|
&task.Status,
|
|
&taskDefaultTo,
|
|
&task.Priority,
|
|
&taskAcceptanceJSON,
|
|
&taskLatestAttempt,
|
|
&taskCreatedAt,
|
|
&taskUpdatedAt,
|
|
&attempt.RunID,
|
|
&attempt.TaskID,
|
|
&attempt.AttemptNo,
|
|
&attempt.AssignedTo,
|
|
&attempt.ThreadID,
|
|
&attemptBaseRef,
|
|
&attemptBaseCommit,
|
|
&attemptBranchName,
|
|
&attemptWorktreePath,
|
|
&attemptWorkspaceState,
|
|
&attemptResultCommit,
|
|
&attempt.Status,
|
|
&attemptCreatedAt,
|
|
&attemptUpdatedAt,
|
|
); err != nil {
|
|
return Task{}, TaskAttempt{}, fmt.Errorf("scan task and attempt: %w", err)
|
|
}
|
|
|
|
task.DefaultTo = taskDefaultTo.String
|
|
task.AcceptanceJSON = json.RawMessage(taskAcceptanceJSON)
|
|
if taskLatestAttempt.Valid {
|
|
task.LatestAttemptNo = int(taskLatestAttempt.Int64)
|
|
}
|
|
task.CreatedAt = parseTime(taskCreatedAt)
|
|
task.UpdatedAt = parseTime(taskUpdatedAt)
|
|
|
|
attempt.BaseRef = attemptBaseRef.String
|
|
attempt.BaseCommit = attemptBaseCommit.String
|
|
attempt.BranchName = attemptBranchName.String
|
|
attempt.WorktreePath = attemptWorktreePath.String
|
|
attempt.WorkspaceStatus = attemptWorkspaceState.String
|
|
attempt.ResultCommit = attemptResultCommit.String
|
|
attempt.CreatedAt = parseTime(attemptCreatedAt)
|
|
attempt.UpdatedAt = parseTime(attemptUpdatedAt)
|
|
|
|
return task, attempt, nil
|
|
}
|
|
|
|
func selectRun(ctx context.Context, db queryRower, runID string) (Run, error) {
|
|
row := db.QueryRowContext(
|
|
ctx,
|
|
`SELECT run_id, goal, summary, status, created_at, updated_at
|
|
FROM runs
|
|
WHERE run_id = ?`,
|
|
runID,
|
|
)
|
|
run, err := scanRun(row)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return Run{}, fmt.Errorf("%w: %s", ErrRunNotFound, runID)
|
|
}
|
|
return run, err
|
|
}
|
|
|
|
func selectTask(ctx context.Context, db queryRower, runID, taskID string) (Task, error) {
|
|
row := db.QueryRowContext(
|
|
ctx,
|
|
`SELECT
|
|
run_id, task_id, title, summary, status, default_to, priority,
|
|
acceptance_json, latest_attempt_no, created_at, updated_at
|
|
FROM tasks
|
|
WHERE run_id = ? AND task_id = ?`,
|
|
runID,
|
|
taskID,
|
|
)
|
|
task, err := scanTask(row)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return Task{}, fmt.Errorf("%w: %s/%s", ErrTaskNotFound, runID, taskID)
|
|
}
|
|
return task, err
|
|
}
|
|
|
|
func selectAttempt(ctx context.Context, db queryRower, runID, taskID string, attemptNo int) (TaskAttempt, error) {
|
|
row := db.QueryRowContext(
|
|
ctx,
|
|
`SELECT
|
|
run_id, task_id, attempt_no, assigned_to, thread_id, base_ref, base_commit,
|
|
branch_name, worktree_path, workspace_status, result_commit, status,
|
|
created_at, updated_at
|
|
FROM task_attempts
|
|
WHERE run_id = ? AND task_id = ? AND attempt_no = ?`,
|
|
runID,
|
|
taskID,
|
|
attemptNo,
|
|
)
|
|
attempt, err := scanAttempt(row)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return TaskAttempt{}, fmt.Errorf("%w: attempt %s/%s/%d not found", ErrInvalidState, runID, taskID, attemptNo)
|
|
}
|
|
return attempt, err
|
|
}
|
|
|
|
func selectLatestQuestionMessage(ctx context.Context, db queryRowsAndRower, threadID string) (Message, error) {
|
|
row := db.QueryRowContext(
|
|
ctx,
|
|
`SELECT
|
|
message_id, thread_id, from_agent, to_agent, kind, summary, body,
|
|
payload_json, created_at
|
|
FROM messages
|
|
WHERE thread_id = ? AND kind = 'question'
|
|
ORDER BY created_at DESC
|
|
LIMIT 1`,
|
|
threadID,
|
|
)
|
|
message, err := scanMessage(row)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return Message{}, fmt.Errorf("%w: blocked thread %s has no question message", ErrInvalidState, threadID)
|
|
}
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
artifactsByMessageID, err := loadArtifactsForMessageIDsFromQueryer(ctx, db, []string{message.MessageID})
|
|
if err != nil {
|
|
return Message{}, err
|
|
}
|
|
message.Artifacts = artifactsByMessageID[message.MessageID]
|
|
return message, nil
|
|
}
|
|
|
|
type queryRowsAndRower interface {
|
|
queryRower
|
|
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
|
}
|
|
|
|
func loadArtifactsForMessageIDsFromQueryer(ctx context.Context, db queryRowsContexter, messageIDs []string) (map[string][]Artifact, error) {
|
|
result := make(map[string][]Artifact)
|
|
if len(messageIDs) == 0 {
|
|
return result, nil
|
|
}
|
|
|
|
args := make([]any, 0, len(messageIDs))
|
|
for _, messageID := range messageIDs {
|
|
args = append(args, messageID)
|
|
}
|
|
|
|
rows, err := db.QueryContext(
|
|
ctx,
|
|
`SELECT
|
|
artifact_id, message_id, path, kind, metadata_json, created_at
|
|
FROM artifacts
|
|
WHERE message_id IN (`+placeholders(len(messageIDs))+`)
|
|
ORDER BY created_at ASC`,
|
|
args...,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query artifacts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
artifact, err := scanArtifact(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result[artifact.MessageID] = append(result[artifact.MessageID], artifact)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate artifacts: %w", err)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func refreshReadyStates(ctx context.Context, tx *sql.Tx, runID string, now time.Time) error {
|
|
rows, err := tx.QueryContext(
|
|
ctx,
|
|
`SELECT task_id, status, title
|
|
FROM tasks
|
|
WHERE run_id = ?
|
|
AND status IN ('planned', 'ready')`,
|
|
runID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("query tasks for readiness refresh: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
type readinessRow struct {
|
|
taskID string
|
|
status string
|
|
title string
|
|
}
|
|
|
|
var tasks []readinessRow
|
|
for rows.Next() {
|
|
var row readinessRow
|
|
if err := rows.Scan(&row.taskID, &row.status, &row.title); err != nil {
|
|
return fmt.Errorf("scan readiness refresh row: %w", err)
|
|
}
|
|
tasks = append(tasks, row)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return fmt.Errorf("iterate readiness refresh rows: %w", err)
|
|
}
|
|
|
|
for _, task := range tasks {
|
|
ready, err := dependenciesSatisfied(ctx, tx, runID, task.taskID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
desired := "planned"
|
|
if ready {
|
|
desired = "ready"
|
|
}
|
|
if desired == task.status {
|
|
continue
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`UPDATE tasks
|
|
SET status = ?, updated_at = ?
|
|
WHERE run_id = ? AND task_id = ?`,
|
|
desired,
|
|
formatTime(now),
|
|
runID,
|
|
task.taskID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update task readiness: %w", err)
|
|
}
|
|
|
|
if desired == "ready" {
|
|
if err := insertEvent(ctx, tx, eventInput{
|
|
RunID: runID,
|
|
TaskID: task.taskID,
|
|
Source: "orch",
|
|
EventType: "task_ready",
|
|
Summary: defaultString(task.title, task.taskID),
|
|
PayloadJSON: marshalJSON(map[string]any{"task_id": task.taskID}),
|
|
CreatedAt: now,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// dependenciesSatisfied reports whether every dependency of (runID, taskID)
// has status 'done'.
//
// It counts dependency rows whose target task exists and is not 'done', and
// returns true when that count is zero. Because of the inner JOIN, a
// dependency pointing at a task ID with no tasks row is not counted, and so
// is treated as satisfied.
// NOTE(review): confirm that is intended. Query errors are wrapped as
// "query dependency readiness: ...".
func dependenciesSatisfied(ctx context.Context, tx *sql.Tx, runID, taskID string) (bool, error) {
	const pendingDependencies = `SELECT COUNT(*)
		FROM task_dependencies d
		JOIN tasks dep
		  ON dep.run_id = d.run_id
		 AND dep.task_id = d.depends_on_task_id
		WHERE d.run_id = ?
		  AND d.task_id = ?
		  AND dep.status <> 'done'`

	var pending int
	row := tx.QueryRowContext(ctx, pendingDependencies, runID, taskID)
	if err := row.Scan(&pending); err != nil {
		return false, fmt.Errorf("query dependency readiness: %w", err)
	}
	return pending == 0, nil
}
|
|
|
|
func updateRunAggregateStatus(ctx context.Context, tx *sql.Tx, runID string, now time.Time) error {
|
|
counts, err := collectTaskCounts(ctx, tx, runID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
nextStatus := deriveRunStatus(counts)
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
`UPDATE runs
|
|
SET status = ?, updated_at = ?
|
|
WHERE run_id = ?`,
|
|
nextStatus,
|
|
formatTime(now),
|
|
runID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update run aggregate status: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func collectTaskCounts(ctx context.Context, db queryRowsContexter, runID string) (map[string]int, error) {
|
|
rows, err := db.QueryContext(
|
|
ctx,
|
|
`SELECT status, COUNT(*)
|
|
FROM tasks
|
|
WHERE run_id = ?
|
|
GROUP BY status`,
|
|
runID,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query task counts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
counts := make(map[string]int)
|
|
for rows.Next() {
|
|
var (
|
|
status string
|
|
count int
|
|
)
|
|
if err := rows.Scan(&status, &count); err != nil {
|
|
return nil, fmt.Errorf("scan task count: %w", err)
|
|
}
|
|
counts[status] = count
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate task counts: %w", err)
|
|
}
|
|
|
|
return counts, nil
|
|
}
|
|
|
|
// deriveRunStatus collapses per-status task counts into a single run status.
//
// Precedence, highest first:
//   - no tasks at all                → "active"
//   - any blocked                    → "blocked"
//   - any failed                     → "failed"
//   - any running or dispatched      → "running"
//   - any ready                      → "ready"
//   - any planned                    → "planned"
//   - any done                       → "done" (remaining tasks may be cancelled)
//   - every task cancelled           → "cancelled"
//   - otherwise (unknown statuses)   → "active"
func deriveRunStatus(counts map[string]int) string {
	var total int
	for _, n := range counts {
		total += n
	}

	switch {
	case total == 0:
		return "active"
	case counts["blocked"] > 0:
		return "blocked"
	case counts["failed"] > 0:
		return "failed"
	case counts["running"] > 0, counts["dispatched"] > 0:
		return "running"
	case counts["ready"] > 0:
		return "ready"
	case counts["planned"] > 0:
		return "planned"
	case counts["done"] > 0:
		return "done"
	case counts["cancelled"] == total:
		return "cancelled"
	default:
		return "active"
	}
}
|
|
|
|
// reconcileTaskStatus maps a thread status to the task status it implies.
//
// "pending" maps to "dispatched"; "claimed" and "in_progress" map to
// "running"; "blocked", "done", "failed" and "cancelled" map to themselves.
// Any other value returns "", meaning no mapping is known.
func reconcileTaskStatus(threadStatus string) string {
	switch threadStatus {
	case "blocked", "done", "failed", "cancelled":
		// These statuses are spelled the same for threads and tasks.
		return threadStatus
	case "claimed", "in_progress":
		return "running"
	case "pending":
		return "dispatched"
	}
	return ""
}
|
|
|
|
func normalizePriority(priority string) (string, error) {
|
|
priority = defaultString(strings.TrimSpace(priority), "normal")
|
|
switch priority {
|
|
case "low", "normal", "high":
|
|
return priority, nil
|
|
default:
|
|
return "", fmt.Errorf("%w: priority must be one of low, normal, high", ErrInvalidInput)
|
|
}
|
|
}
|
|
|
|
func validateAndNormalizeJSONDefault(fieldName, value, defaultValue string) (string, error) {
|
|
normalized := strings.TrimSpace(value)
|
|
if normalized == "" {
|
|
normalized = defaultValue
|
|
}
|
|
if !json.Valid([]byte(normalized)) {
|
|
return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
|
|
}
|
|
|
|
var compact bytes.Buffer
|
|
if err := json.Compact(&compact, []byte(normalized)); err != nil {
|
|
return "", fmt.Errorf("%w: %s must be valid JSON", ErrInvalidInput, fieldName)
|
|
}
|
|
return compact.String(), nil
|
|
}
|
|
|
|
func buildDispatchPayload(task Task, attemptNo int, baseRef string) string {
|
|
payload := map[string]any{
|
|
"run_id": task.RunID,
|
|
"task_id": task.TaskID,
|
|
"attempt_no": attemptNo,
|
|
"title": task.Title,
|
|
"summary": task.Summary,
|
|
"priority": task.Priority,
|
|
}
|
|
|
|
if len(task.AcceptanceJSON) > 0 {
|
|
var acceptance any
|
|
if err := json.Unmarshal(task.AcceptanceJSON, &acceptance); err == nil {
|
|
payload["acceptance"] = acceptance
|
|
}
|
|
}
|
|
if strings.TrimSpace(baseRef) != "" {
|
|
payload["base_ref"] = strings.TrimSpace(baseRef)
|
|
}
|
|
|
|
return marshalJSON(payload)
|
|
}
|
|
|
|
// marshalJSON encodes v with encoding/json and returns the text.
// If encoding fails (e.g. v holds a channel or func), the error is discarded
// and the literal "{}" is returned instead.
func marshalJSON(v any) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return "{}"
}
|
|
|
|
// nullIfEmpty converts a blank string into a nil bind argument, so the column
// is stored as SQL NULL.
//
// A value that is empty after trimming whitespace returns nil. Any other value
// is returned unchanged as a string; the original, untrimmed text is kept.
func nullIfEmpty(value string) any {
	if strings.TrimSpace(value) != "" {
		return value
	}
	return nil
}
|
|
|
|
// summarizeAnswer derives a one-line summary from an answer body.
//
// It returns the first line of the whitespace-trimmed body, itself trimmed.
// If nothing remains, it falls back to "task answer".
func summarizeAnswer(body string) string {
	const fallback = "task answer"

	first, _, _ := strings.Cut(strings.TrimSpace(body), "\n")
	if first = strings.TrimSpace(first); first == "" {
		return fallback
	}
	return first
}
|
|
|
|
// isUniqueConstraintError reports whether err looks like a database
// uniqueness violation.
//
// Detection is purely textual: it matches "unique constraint failed"
// case-insensitively anywhere in err.Error(), so it still works through
// wrapping that keeps the message. A nil err reports false instead of
// panicking.
// NOTE(review): the phrase matches SQLite-style messages such as
// "UNIQUE constraint failed: t.col". Other drivers word this differently.
func isUniqueConstraintError(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(strings.ToLower(err.Error()), "unique constraint failed")
}
|