Add orch wait command

This commit is contained in:
2026-03-19 14:02:33 +08:00
parent 1b0cd723d7
commit f1785b314f
7 changed files with 535 additions and 22 deletions
+203 -13
View File
@@ -119,6 +119,30 @@ type ReconcileResult struct {
UpdatedTasks []Task `json:"updated_tasks"`
}
type RunEvent struct {
EventID int64 `json:"event_id"`
Type string `json:"type"`
RunID string `json:"run_id"`
TaskID string `json:"task_id"`
ThreadID string `json:"thread_id,omitempty"`
Summary string `json:"summary"`
Payload json.RawMessage `json:"payload"`
CreatedAt time.Time `json:"created_at"`
}
type WaitInput struct {
RunID string
EventTypes []string
AfterEventID int64
Timeout time.Duration
}
type WaitResult struct {
Woke bool `json:"woke"`
NextEventID int64 `json:"next_event_id"`
Events []RunEvent `json:"events,omitempty"`
}
type DispatchWorkspace struct {
BaseRef string `json:"base_ref,omitempty"`
BaseCommit string `json:"base_commit,omitempty"`
@@ -753,20 +777,31 @@ func (s *OrchStore) ReconcileRun(ctx context.Context, runID string) (ReconcileRe
return ReconcileResult{}, fmt.Errorf("update reconciled attempt status: %w", err)
}
summary := fmt.Sprintf("%s -> %s", taskID, nextStatus)
payloadJSON := marshalJSON(map[string]any{
"thread_id": threadID,
"thread_status": threadStatus,
"previous_status": taskStatus,
"previous_attempt": attemptStatus,
})
if nextStatus == "blocked" {
question, err := selectLatestQuestionMessage(ctx, tx, threadID)
if err != nil {
return ReconcileResult{}, err
}
summary = question.Summary
payloadJSON = string(question.PayloadJSON)
}
if err := insertEvent(ctx, tx, eventInput{
RunID: runID,
TaskID: taskID,
ThreadID: threadID,
Source: "orch",
EventType: "task_" + nextStatus,
Summary: fmt.Sprintf("%s -> %s", taskID, nextStatus),
PayloadJSON: marshalJSON(map[string]any{
"thread_id": threadID,
"thread_status": threadStatus,
"previous_status": taskStatus,
"previous_attempt": attemptStatus,
}),
CreatedAt: now,
RunID: runID,
TaskID: taskID,
ThreadID: threadID,
Source: "orch",
EventType: "task_" + nextStatus,
Summary: summary,
PayloadJSON: payloadJSON,
CreatedAt: now,
}); err != nil {
return ReconcileResult{}, err
}
@@ -1060,6 +1095,88 @@ func (s *OrchStore) GetRunOverview(ctx context.Context, runID string) (RunOvervi
}, nil
}
func (s *OrchStore) WaitForEvents(ctx context.Context, input WaitInput) (WaitResult, error) {
if strings.TrimSpace(input.RunID) == "" {
return WaitResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
}
eventTypes := normalizeWaitEventTypes(input.EventTypes)
if _, err := s.GetRun(ctx, input.RunID); err != nil {
return WaitResult{}, err
}
cursor := input.AfterEventID
waitCtx := ctx
cancel := func() {}
if input.Timeout > 0 {
waitCtx, cancel = context.WithTimeout(ctx, input.Timeout)
}
defer cancel()
for {
events, nextEventID, found, err := s.findRunEventsAfter(waitCtx, input.RunID, cursor, eventTypes)
if err != nil {
if isDeadlineExceeded(waitCtx) {
return WaitResult{Woke: false, NextEventID: cursor}, nil
}
return WaitResult{}, err
}
if found {
return WaitResult{
Woke: true,
NextEventID: nextEventID,
Events: events,
}, nil
}
if _, err := s.ReconcileRun(waitCtx, input.RunID); err != nil {
if isSQLiteBusyError(err) {
ok, waitErr := waitForNextPoll(waitCtx, 25*time.Millisecond)
if waitErr != nil {
if errors.Is(waitErr, context.DeadlineExceeded) {
return WaitResult{Woke: false, NextEventID: cursor}, nil
}
return WaitResult{}, waitErr
}
if !ok {
return WaitResult{Woke: false, NextEventID: cursor}, nil
}
continue
}
if isDeadlineExceeded(waitCtx) {
return WaitResult{Woke: false, NextEventID: cursor}, nil
}
return WaitResult{}, err
}
events, nextEventID, found, err = s.findRunEventsAfter(waitCtx, input.RunID, cursor, eventTypes)
if err != nil {
if isDeadlineExceeded(waitCtx) {
return WaitResult{Woke: false, NextEventID: cursor}, nil
}
return WaitResult{}, err
}
if found {
return WaitResult{
Woke: true,
NextEventID: nextEventID,
Events: events,
}, nil
}
ok, err := waitForNextPoll(waitCtx, 200*time.Millisecond)
if err != nil {
if errors.Is(err, context.DeadlineExceeded) {
return WaitResult{Woke: false, NextEventID: cursor}, nil
}
return WaitResult{}, err
}
if !ok {
return WaitResult{Woke: false, NextEventID: cursor}, nil
}
}
}
func listTasksForRun(ctx context.Context, db queryRowsContexter, runID string) ([]Task, error) {
rows, err := db.QueryContext(
ctx,
@@ -1090,6 +1207,55 @@ func listTasksForRun(ctx context.Context, db queryRowsContexter, runID string) (
return tasks, nil
}
func (s *OrchStore) findRunEventsAfter(ctx context.Context, runID string, afterEventID int64, eventTypes []string) ([]RunEvent, int64, bool, error) {
args := []any{runID, afterEventID}
query := `SELECT
event_id, event_type, run_id, task_id, thread_id, summary, payload_json, created_at
FROM events
WHERE run_id = ?
AND event_id > ?`
if len(eventTypes) > 0 {
query += " AND event_type IN (" + placeholders(len(eventTypes)) + ")"
for _, eventType := range eventTypes {
args = append(args, eventType)
}
}
query += " ORDER BY event_id ASC LIMIT 1"
row := s.db.QueryRowContext(ctx, query, args...)
var (
event RunEvent
threadID sql.NullString
payload string
createdAt string
)
err := row.Scan(
&event.EventID,
&event.Type,
&event.RunID,
&event.TaskID,
&threadID,
&event.Summary,
&payload,
&createdAt,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, 0, false, nil
}
if err != nil {
return nil, 0, false, fmt.Errorf("query run events after %d: %w", afterEventID, err)
}
if threadID.Valid {
event.ThreadID = threadID.String
}
event.Payload = json.RawMessage(payload)
event.CreatedAt = parseTime(createdAt)
return []RunEvent{event}, event.EventID, true, nil
}
type queryRowsContexter interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
@@ -1601,6 +1767,30 @@ func normalizePriority(priority string) (string, error) {
}
}
func normalizeWaitEventTypes(eventTypes []string) []string {
if len(eventTypes) == 0 {
return []string{"task_ready", "task_blocked", "task_done", "task_failed"}
}
normalized := make([]string, 0, len(eventTypes))
seen := make(map[string]struct{}, len(eventTypes))
for _, eventType := range eventTypes {
eventType = strings.TrimSpace(eventType)
if eventType == "" {
continue
}
if _, ok := seen[eventType]; ok {
continue
}
seen[eventType] = struct{}{}
normalized = append(normalized, eventType)
}
if len(normalized) == 0 {
return []string{"task_ready", "task_blocked", "task_done", "task_failed"}
}
return normalized
}
func validateAndNormalizeJSONDefault(fieldName, value, defaultValue string) (string, error) {
normalized := strings.TrimSpace(value)
if normalized == "" {