Files

472 lines
14 KiB
Go

package sqlite
import (
"context"
"database/sql"
"fmt"
"inbox/internal/domain/lane"
"inbox/internal/domain/task"
"inbox/internal/domain/workflow"
)
// ClaimTaskExecution records a new workflow run for taskID and marks that
// task as running under it, all inside one transaction.
//
// The task is loaded, run is inserted with StartedAt forced to startedAt, and
// the task is then updated: Status becomes task.StatusRunning, AssignedRunID
// points at the new run, and StartedAt/UpdatedAt are set to startedAt. This
// function itself does not inspect the task's current status before claiming.
//
// The stored run is read back through the same transaction before Commit, so
// a failed read-back rolls the claim back instead of reporting an error for a
// claim that has already been persisted.
//
// It returns the stored run and the updated task. On error both results are
// zero values.
func (s *Store) ClaimTaskExecution(ctx context.Context, run workflow.Run, taskID, startedAt string) (workflow.Run, task.Record, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return workflow.Run{}, task.Record{}, fmt.Errorf("begin claim task execution: %w", err)
	}
	// After a successful Commit this Rollback only returns sql.ErrTxDone, so
	// it matters on the error paths alone.
	defer tx.Rollback()

	taskRecord, err := getTaskTx(ctx, tx, taskID)
	if err != nil {
		return workflow.Run{}, task.Record{}, err
	}

	run.StartedAt = startedAt
	if err := createWorkflowRunTx(ctx, tx, s, &run); err != nil {
		return workflow.Run{}, task.Record{}, err
	}

	taskRecord.Status = task.StatusRunning
	taskRecord.AssignedRunID = run.ID
	taskRecord.StartedAt = startedAt
	taskRecord.UpdatedAt = startedAt
	if err := updateTaskTx(ctx, tx, taskRecord); err != nil {
		return workflow.Run{}, task.Record{}, err
	}

	// Read the run back while the transaction is still open: once Commit has
	// succeeded there must be no remaining step that can fail.
	claimedRun, err := getWorkflowRunTx(ctx, tx, run.ID)
	if err != nil {
		return workflow.Run{}, task.Record{}, err
	}

	if err := tx.Commit(); err != nil {
		return workflow.Run{}, task.Record{}, fmt.Errorf("commit claim task execution: %w", err)
	}
	return claimedRun, taskRecord, nil
}
// CompleteTaskExecution records the outcome of a task run inside a single
// transaction.
//
// It loads the run, task and lane identified by runID, taskID and laneID and
// then:
//   - marks the task task.StatusSucceeded (storing resultMarkdown and clearing
//     the blocking reason) when status is workflow.RunStatusSucceeded, or
//     task.StatusFailed (storing errorMessage as the blocking reason and
//     clearing the result summary) for every other status;
//   - sets the task's AssignedRunID, CompletedAt and UpdatedAt;
//   - promotes draft tasks of the run's topic via promoteReadyTasksTx;
//   - refreshes the status of every lane touched by that promotion plus
//     laneID itself;
//   - stores status, exitCode, completedAt and errorMessage on the run.
//
// The stored run is read back through the same transaction before Commit, so
// a failed read-back rolls everything back instead of reporting an error for
// a completion that has already been persisted.
//
// It returns the stored run, the updated task, the refreshed lane for laneID
// and the tasks that were promoted to ready. On error all results are zero
// values.
func (s *Store) CompleteTaskExecution(
	ctx context.Context,
	runID, taskID, laneID string,
	status workflow.RunStatus,
	exitCode int,
	resultMarkdown, errorMessage, completedAt string,
) (workflow.Run, task.Record, lane.Record, []task.Record, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, fmt.Errorf("begin complete task execution: %w", err)
	}
	// After a successful Commit this Rollback only returns sql.ErrTxDone, so
	// it matters on the error paths alone.
	defer tx.Rollback()

	run, err := getWorkflowRunTx(ctx, tx, runID)
	if err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, err
	}
	taskRecord, err := getTaskTx(ctx, tx, taskID)
	if err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, err
	}
	laneRecord, err := getLaneTx(ctx, tx, laneID)
	if err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, err
	}

	if status == workflow.RunStatusSucceeded {
		taskRecord.Status = task.StatusSucceeded
		taskRecord.ResultSummaryMarkdown = resultMarkdown
		taskRecord.BlockingReasonMarkdown = ""
	} else {
		taskRecord.Status = task.StatusFailed
		taskRecord.ResultSummaryMarkdown = ""
		taskRecord.BlockingReasonMarkdown = errorMessage
	}
	taskRecord.AssignedRunID = run.ID
	taskRecord.CompletedAt = completedAt
	taskRecord.UpdatedAt = completedAt
	if err := updateTaskTx(ctx, tx, taskRecord); err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, err
	}

	// The task update must be written before promotion so that dependency
	// checks performed through tx observe the new task status.
	promotedTasks, touchedLaneIDs, err := promoteReadyTasksTx(ctx, tx, run.TopicID, completedAt)
	if err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, err
	}
	// Always refresh the lane of the completed task, even if promotion did
	// not touch it.
	touchedLaneIDs[laneID] = struct{}{}
	updatedLanes, err := refreshTopicLaneStatusesTx(ctx, tx, run.TopicID, touchedLaneIDs, completedAt)
	if err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, err
	}
	if updated, ok := updatedLanes[laneID]; ok {
		laneRecord = updated
	}

	run.Status = status
	run.ExitCode = exitCode
	run.CompletedAt = completedAt
	run.ErrorMessage = errorMessage
	if err := updateWorkflowRunTx(ctx, tx, run); err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, err
	}

	// Read the run back while the transaction is still open: once Commit has
	// succeeded there must be no remaining step that can fail.
	updatedRun, err := getWorkflowRunTx(ctx, tx, run.ID)
	if err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, err
	}

	if err := tx.Commit(); err != nil {
		return workflow.Run{}, task.Record{}, lane.Record{}, nil, fmt.Errorf("commit complete task execution: %w", err)
	}
	return updatedRun, taskRecord, laneRecord, promotedTasks, nil
}
// getLaneTx selects the lanes row whose id equals laneID through tx and hands
// it to scanLane for decoding. QueryRowContext defers query errors until the
// row is scanned, so every failure surfaces from scanLane unchanged.
func getLaneTx(ctx context.Context, tx *sql.Tx, laneID string) (lane.Record, error) {
	const selectLane = `
SELECT id, workspace_id, topic_id, name, slug, purpose, status, base_branch, branch_name, head_commit, worktree_path,
container_name, runtime_endpoint, created_by_role_name, result_summary_markdown, error_message,
created_at, updated_at, started_at, completed_at
FROM lanes
WHERE id = ?
`
	return scanLane(tx.QueryRowContext(ctx, selectLane, laneID))
}
// updateLaneTx overwrites the mutable columns of the lanes row identified by
// value.ID with the fields of value, using tx. The id and created_at columns
// are not written. StartedAt and CompletedAt are passed through
// nullableString before binding. The number of affected rows is not checked.
// Any execution error is wrapped with the prefix "update lane".
func updateLaneTx(ctx context.Context, tx *sql.Tx, value lane.Record) error {
	const updateLane = `
UPDATE lanes
SET workspace_id = ?, topic_id = ?, name = ?, slug = ?, purpose = ?, status = ?, base_branch = ?, branch_name = ?, head_commit = ?,
worktree_path = ?, container_name = ?, runtime_endpoint = ?, created_by_role_name = ?,
result_summary_markdown = ?, error_message = ?, updated_at = ?, started_at = ?, completed_at = ?
WHERE id = ?
`
	// Bind values in the exact order of the placeholders above.
	args := []interface{}{
		value.WorkspaceID,
		value.TopicID,
		value.Name,
		value.Slug,
		value.Purpose,
		string(value.Status),
		value.BaseBranch,
		value.BranchName,
		value.HeadCommit,
		value.WorktreePath,
		value.ContainerName,
		value.RuntimeEndpoint,
		value.CreatedByRoleName,
		value.ResultSummaryMarkdown,
		value.ErrorMessage,
		value.UpdatedAt,
		nullableString(value.StartedAt),
		nullableString(value.CompletedAt),
		value.ID,
	}
	_, err := tx.ExecContext(ctx, updateLane, args...)
	if err != nil {
		return fmt.Errorf("update lane: %w", err)
	}
	return nil
}
// getWorkflowRunTx selects the workflow_runs row whose id equals runID
// through tx and hands it to scanWorkflowRun for decoding. QueryRowContext
// defers query errors until the row is scanned, so every failure surfaces
// from scanWorkflowRun unchanged.
func getWorkflowRunTx(ctx context.Context, tx *sql.Tx, runID string) (workflow.Run, error) {
	const selectRun = `
SELECT id, workspace_id, topic_id, role_name, stage, mode, status, request_message_id,
config_snapshot_json, command_json, reply_message_id, exit_code, started_at, completed_at, error_message
FROM workflow_runs
WHERE id = ?
`
	return scanWorkflowRun(tx.QueryRowContext(ctx, selectRun, runID))
}
// createWorkflowRunTx validates value, fills in its defaults and inserts it
// into workflow_runs through tx.
//
// value is modified in place: an empty ID is replaced with one from
// s.newID("workflow-run"), StartedAt is resolved with coalesceString against
// s.now(), an empty CommandJSON becomes "[]" and an empty ConfigSnapshotJSON
// becomes "{}". Validation runs before any of these defaults are applied.
//
// Errors from Validate and newID are returned as-is; an insert failure is
// wrapped with the prefix "create workflow run".
func createWorkflowRunTx(ctx context.Context, tx *sql.Tx, s *Store, value *workflow.Run) error {
	if err := value.Validate(); err != nil {
		return err
	}

	if value.ID == "" {
		generated, err := s.newID("workflow-run")
		if err != nil {
			return err
		}
		value.ID = generated
	}
	value.StartedAt = coalesceString(value.StartedAt, s.now())
	if value.CommandJSON == "" {
		value.CommandJSON = "[]"
	}
	if value.ConfigSnapshotJSON == "" {
		value.ConfigSnapshotJSON = "{}"
	}

	const insertRun = `
INSERT INTO workflow_runs(id, workspace_id, topic_id, role_name, stage, mode, status, request_message_id, config_snapshot_json, command_json, reply_message_id, exit_code, started_at, completed_at, error_message)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
	// Bind values in the exact order of the column list above.
	args := []interface{}{
		value.ID,
		value.WorkspaceID,
		value.TopicID,
		value.RoleName,
		string(value.Stage),
		value.Mode,
		string(value.Status),
		nullableString(value.RequestMessageID),
		value.ConfigSnapshotJSON,
		value.CommandJSON,
		nullableString(value.ReplyMessageID),
		value.ExitCode,
		value.StartedAt,
		nullableString(value.CompletedAt),
		value.ErrorMessage,
	}
	_, err := tx.ExecContext(ctx, insertRun, args...)
	if err != nil {
		return fmt.Errorf("create workflow run: %w", err)
	}
	return nil
}
// updateWorkflowRunTx overwrites the workflow_runs row identified by value.ID
// with the fields of value, using tx. The id and started_at columns are not
// written. RequestMessageID, ReplyMessageID and CompletedAt are passed
// through nullableString before binding. The number of affected rows is not
// checked. Any execution error is wrapped with the prefix
// "update workflow run".
func updateWorkflowRunTx(ctx context.Context, tx *sql.Tx, value workflow.Run) error {
	const updateRun = `
UPDATE workflow_runs
SET workspace_id = ?, topic_id = ?, role_name = ?, stage = ?, mode = ?, status = ?, request_message_id = ?,
config_snapshot_json = ?, command_json = ?, reply_message_id = ?,
exit_code = ?, completed_at = ?, error_message = ?
WHERE id = ?
`
	// Bind values in the exact order of the placeholders above.
	args := []interface{}{
		value.WorkspaceID,
		value.TopicID,
		value.RoleName,
		string(value.Stage),
		value.Mode,
		string(value.Status),
		nullableString(value.RequestMessageID),
		value.ConfigSnapshotJSON,
		value.CommandJSON,
		nullableString(value.ReplyMessageID),
		value.ExitCode,
		nullableString(value.CompletedAt),
		value.ErrorMessage,
		value.ID,
	}
	_, err := tx.ExecContext(ctx, updateRun, args...)
	if err != nil {
		return fmt.Errorf("update workflow run: %w", err)
	}
	return nil
}
// promoteReadyTasksTx advances draft tasks of topicID whose dependencies have
// all succeeded.
//
// Tasks are grouped by lane. A lane that contains a running task is skipped
// entirely. In every other lane each draft task whose dependencies are
// satisfied (per dependenciesSatisfiedTx) is updated through tx:
//   - a task of kind task.KindMilestone is marked task.StatusSucceeded with a
//     fixed result summary and CompletedAt set to now;
//   - any other task is marked task.StatusReady and appended to the returned
//     slice.
//
// Passes repeat until one makes no change, because a milestone that
// auto-completes can satisfy the dependencies of further draft tasks.
//
// Lanes are visited in the order they first appear in the task list rather
// than in map iteration order, so the order of the returned tasks and of the
// UPDATE statements is the same on every call for the same data.
//
// It returns the tasks promoted to ready (non-nil, possibly empty) and the
// set of lane IDs in which any task was changed. On error both are nil.
func promoteReadyTasksTx(ctx context.Context, tx *sql.Tx, topicID, now string) ([]task.Record, map[string]struct{}, error) {
	items, err := listTasksByTopicTx(ctx, tx, topicID)
	if err != nil {
		return nil, nil, err
	}

	byLane := make(map[string][]task.Record)
	laneOrder := make([]string, 0)
	for _, item := range items {
		if _, seen := byLane[item.LaneID]; !seen {
			laneOrder = append(laneOrder, item.LaneID)
		}
		byLane[item.LaneID] = append(byLane[item.LaneID], item)
	}

	promoted := make([]task.Record, 0)
	touchedLanes := map[string]struct{}{}
	for {
		changed := false
		for _, laneID := range laneOrder {
			laneItems := byLane[laneID]
			if hasRunningTaskTx(laneItems) {
				continue
			}
			for idx, item := range laneItems {
				if item.Status != task.StatusDraft {
					continue
				}
				ready, err := dependenciesSatisfiedTx(ctx, tx, item.ID)
				if err != nil {
					return nil, nil, err
				}
				if !ready {
					continue
				}
				if item.Kind == task.KindMilestone {
					item.Status = task.StatusSucceeded
					item.ResultSummaryMarkdown = "Milestone completed automatically after dependencies succeeded."
					item.BlockingReasonMarkdown = ""
					item.UpdatedAt = now
					item.CompletedAt = now
				} else {
					item.Status = task.StatusReady
					item.UpdatedAt = now
					promoted = append(promoted, item)
				}
				if err := updateTaskTx(ctx, tx, item); err != nil {
					return nil, nil, err
				}
				// laneItems shares its backing array with byLane[laneID],
				// so this single write keeps the grouped view current for
				// later passes.
				laneItems[idx] = item
				touchedLanes[item.LaneID] = struct{}{}
				changed = true
			}
		}
		if !changed {
			break
		}
	}
	return promoted, touchedLanes, nil
}
// refreshLaneStatusTx recomputes the status of lane laneID from the statuses
// of its tasks, writes the lane back through tx and returns the updated
// record. UpdatedAt is always set to now.
//
// Resulting status:
//   - lane.StatusDraft when the lane has no tasks;
//   - lane.StatusSucceeded when the lane has tasks and none of them is failed,
//     running, blocked, draft, ready or cancelled;
//   - otherwise the value left by the loop below, starting from
//     lane.StatusReady.
//
// NOTE(review): the failed, running and blocked cases assign the status
// unconditionally, so when a lane holds both a running task and a failed or
// blocked task the outcome depends on which of them comes later in the order
// returned by listTasksByLaneTx. Confirm the intended precedence.
func refreshLaneStatusTx(ctx context.Context, tx *sql.Tx, laneID, now string) (lane.Record, error) {
laneRecord, err := getLaneTx(ctx, tx, laneID)
if err != nil {
return lane.Record{}, err
}
items, err := listTasksByLaneTx(ctx, tx, laneID)
if err != nil {
return lane.Record{}, err
}
// Default for a lane that has tasks; an empty lane is a draft.
status := lane.StatusReady
if len(items) == 0 {
status = lane.StatusDraft
}
// An empty lane can never count as succeeded.
allSucceeded := len(items) > 0
for _, item := range items {
// Task statuses without a case here (for example task.StatusSucceeded)
// change neither status nor allSucceeded.
switch item.Status {
case task.StatusFailed:
status = lane.StatusBlocked
allSucceeded = false
case task.StatusRunning:
status = lane.StatusRunning
allSucceeded = false
case task.StatusBlocked:
status = lane.StatusBlocked
allSucceeded = false
case task.StatusDraft, task.StatusReady:
// Pending work never downgrades a blocked or running lane.
if status != lane.StatusBlocked && status != lane.StatusRunning {
status = lane.StatusReady
}
allSucceeded = false
case task.StatusCancelled:
// A cancelled task prevents the lane from succeeding but does not
// otherwise change its status.
allSucceeded = false
}
}
if allSucceeded {
status = lane.StatusSucceeded
// NOTE(review): CompletedAt is overwritten with now on every refresh of
// an already succeeded lane, and it is not cleared when the lane is no
// longer succeeded. Confirm this is intended.
laneRecord.CompletedAt = now
}
laneRecord.Status = status
// StartedAt is only stamped while it is still empty and the lane is running.
if status == lane.StatusRunning && laneRecord.StartedAt == "" {
laneRecord.StartedAt = now
}
laneRecord.UpdatedAt = now
if err := updateLaneTx(ctx, tx, laneRecord); err != nil {
return lane.Record{}, err
}
return laneRecord, nil
}
// refreshTopicLaneStatusesTx calls refreshLaneStatusTx for every lane ID in
// laneIDs and returns the refreshed records keyed by lane ID. An empty or nil
// laneIDs yields an empty, non-nil map. The first refresh error aborts the
// loop and is returned with a nil map. topicID is accepted but not used by
// this function.
func refreshTopicLaneStatusesTx(ctx context.Context, tx *sql.Tx, topicID string, laneIDs map[string]struct{}, now string) (map[string]lane.Record, error) {
	refreshed := make(map[string]lane.Record, len(laneIDs))
	for id := range laneIDs {
		record, err := refreshLaneStatusTx(ctx, tx, id, now)
		if err != nil {
			return nil, err
		}
		refreshed[id] = record
	}
	return refreshed, nil
}
// listTasksByLaneTx returns every tasks row whose lane_id equals laneID, read
// through tx and ordered by task_order, then priority descending, then
// created_at, then id. Rows are decoded with scanTask, whose errors are
// returned as-is; query and iteration errors are wrapped. The result is nil
// when no row matches.
func listTasksByLaneTx(ctx context.Context, tx *sql.Tx, laneID string) ([]task.Record, error) {
	const selectTasks = `
SELECT id, workspace_id, topic_id, lane_id, title, body_markdown, acceptance_markdown, task_kind, deliverables_json, batch_key, status, priority,
task_order, created_by_role_name, blocking_reason_markdown, result_summary_markdown, assigned_run_id,
created_at, updated_at, started_at, completed_at
FROM tasks
WHERE lane_id = ?
ORDER BY task_order, priority DESC, created_at, id
`
	rows, queryErr := tx.QueryContext(ctx, selectTasks, laneID)
	if queryErr != nil {
		return nil, fmt.Errorf("list tasks by lane: %w", queryErr)
	}
	defer rows.Close()

	var records []task.Record
	for rows.Next() {
		record, scanErr := scanTask(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		records = append(records, record)
	}
	// rows.Next returning false may hide an iteration failure.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate tasks by lane: %w", iterErr)
	}
	return records, nil
}
// listTasksByTopicTx returns every tasks row whose topic_id equals topicID,
// read through tx and ordered by lane_id, then task_order, then priority
// descending, then created_at, then id. Rows are decoded with scanTask, whose
// errors are returned as-is; query and iteration errors are wrapped. The
// result is nil when no row matches.
func listTasksByTopicTx(ctx context.Context, tx *sql.Tx, topicID string) ([]task.Record, error) {
	const selectTasks = `
SELECT id, workspace_id, topic_id, lane_id, title, body_markdown, acceptance_markdown, task_kind, deliverables_json, batch_key, status, priority,
task_order, created_by_role_name, blocking_reason_markdown, result_summary_markdown, assigned_run_id,
created_at, updated_at, started_at, completed_at
FROM tasks
WHERE topic_id = ?
ORDER BY lane_id, task_order, priority DESC, created_at, id
`
	rows, queryErr := tx.QueryContext(ctx, selectTasks, topicID)
	if queryErr != nil {
		return nil, fmt.Errorf("list tasks by topic: %w", queryErr)
	}
	defer rows.Close()

	var records []task.Record
	for rows.Next() {
		record, scanErr := scanTask(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		records = append(records, record)
	}
	// rows.Next returning false may hide an iteration failure.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate tasks by topic: %w", iterErr)
	}
	return records, nil
}
// dependenciesSatisfiedTx reports whether every task that taskID depends on
// has status task.StatusSucceeded. Dependencies are loaded with
// listTaskDependenciesTx and each referenced task with getTaskTx, all through
// tx. A task without dependencies is satisfied. The check stops at the first
// dependency that has not succeeded; any lookup error is returned together
// with false.
func dependenciesSatisfiedTx(ctx context.Context, tx *sql.Tx, taskID string) (bool, error) {
	edges, err := listTaskDependenciesTx(ctx, tx, taskID)
	if err != nil {
		return false, err
	}
	for i := range edges {
		upstream, lookupErr := getTaskTx(ctx, tx, edges[i].DependsOnTaskID)
		switch {
		case lookupErr != nil:
			return false, lookupErr
		case upstream.Status != task.StatusSucceeded:
			return false, nil
		}
	}
	return true, nil
}
// listTaskDependenciesTx returns the task_dependencies rows whose task_id
// equals taskID, read through tx and ordered by depends_on_task_id. The
// result is nil when the task has no dependencies. Query, scan and iteration
// errors are each wrapped with a message naming the failing step.
func listTaskDependenciesTx(ctx context.Context, tx *sql.Tx, taskID string) ([]task.Dependency, error) {
	rows, err := tx.QueryContext(ctx, `
SELECT task_id, depends_on_task_id
FROM task_dependencies
WHERE task_id = ?
ORDER BY depends_on_task_id
`, taskID)
	if err != nil {
		return nil, fmt.Errorf("list task dependencies: %w", err)
	}
	defer rows.Close()

	var out []task.Dependency
	for rows.Next() {
		var item task.Dependency
		if err := rows.Scan(&item.TaskID, &item.DependsOnTaskID); err != nil {
			// Wrapped like the other failures in this function so callers
			// can tell which step failed; %w keeps errors.Is/As working.
			return nil, fmt.Errorf("scan task dependency: %w", err)
		}
		out = append(out, item)
	}
	// rows.Next returning false may hide an iteration failure.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate task dependencies: %w", err)
	}
	return out, nil
}
// hasRunningTaskTx reports whether at least one element of items has status
// task.StatusRunning. It works purely on the given slice and performs no
// database access; a nil or empty slice yields false.
func hasRunningTaskTx(items []task.Record) bool {
	for i := range items {
		if items[i].Status == task.StatusRunning {
			return true
		}
	}
	return false
}