Add orch wait command
This commit is contained in:
+203
-13
@@ -119,6 +119,30 @@ type ReconcileResult struct {
|
||||
UpdatedTasks []Task `json:"updated_tasks"`
|
||||
}
|
||||
|
||||
type RunEvent struct {
|
||||
EventID int64 `json:"event_id"`
|
||||
Type string `json:"type"`
|
||||
RunID string `json:"run_id"`
|
||||
TaskID string `json:"task_id"`
|
||||
ThreadID string `json:"thread_id,omitempty"`
|
||||
Summary string `json:"summary"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type WaitInput struct {
|
||||
RunID string
|
||||
EventTypes []string
|
||||
AfterEventID int64
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type WaitResult struct {
|
||||
Woke bool `json:"woke"`
|
||||
NextEventID int64 `json:"next_event_id"`
|
||||
Events []RunEvent `json:"events,omitempty"`
|
||||
}
|
||||
|
||||
type DispatchWorkspace struct {
|
||||
BaseRef string `json:"base_ref,omitempty"`
|
||||
BaseCommit string `json:"base_commit,omitempty"`
|
||||
@@ -753,20 +777,31 @@ func (s *OrchStore) ReconcileRun(ctx context.Context, runID string) (ReconcileRe
|
||||
return ReconcileResult{}, fmt.Errorf("update reconciled attempt status: %w", err)
|
||||
}
|
||||
|
||||
summary := fmt.Sprintf("%s -> %s", taskID, nextStatus)
|
||||
payloadJSON := marshalJSON(map[string]any{
|
||||
"thread_id": threadID,
|
||||
"thread_status": threadStatus,
|
||||
"previous_status": taskStatus,
|
||||
"previous_attempt": attemptStatus,
|
||||
})
|
||||
if nextStatus == "blocked" {
|
||||
question, err := selectLatestQuestionMessage(ctx, tx, threadID)
|
||||
if err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
summary = question.Summary
|
||||
payloadJSON = string(question.PayloadJSON)
|
||||
}
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: runID,
|
||||
TaskID: taskID,
|
||||
ThreadID: threadID,
|
||||
Source: "orch",
|
||||
EventType: "task_" + nextStatus,
|
||||
Summary: fmt.Sprintf("%s -> %s", taskID, nextStatus),
|
||||
PayloadJSON: marshalJSON(map[string]any{
|
||||
"thread_id": threadID,
|
||||
"thread_status": threadStatus,
|
||||
"previous_status": taskStatus,
|
||||
"previous_attempt": attemptStatus,
|
||||
}),
|
||||
CreatedAt: now,
|
||||
RunID: runID,
|
||||
TaskID: taskID,
|
||||
ThreadID: threadID,
|
||||
Source: "orch",
|
||||
EventType: "task_" + nextStatus,
|
||||
Summary: summary,
|
||||
PayloadJSON: payloadJSON,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return ReconcileResult{}, err
|
||||
}
|
||||
@@ -1060,6 +1095,88 @@ func (s *OrchStore) GetRunOverview(ctx context.Context, runID string) (RunOvervi
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) WaitForEvents(ctx context.Context, input WaitInput) (WaitResult, error) {
|
||||
if strings.TrimSpace(input.RunID) == "" {
|
||||
return WaitResult{}, fmt.Errorf("%w: run id is required", ErrInvalidInput)
|
||||
}
|
||||
|
||||
eventTypes := normalizeWaitEventTypes(input.EventTypes)
|
||||
if _, err := s.GetRun(ctx, input.RunID); err != nil {
|
||||
return WaitResult{}, err
|
||||
}
|
||||
|
||||
cursor := input.AfterEventID
|
||||
waitCtx := ctx
|
||||
cancel := func() {}
|
||||
if input.Timeout > 0 {
|
||||
waitCtx, cancel = context.WithTimeout(ctx, input.Timeout)
|
||||
}
|
||||
defer cancel()
|
||||
|
||||
for {
|
||||
events, nextEventID, found, err := s.findRunEventsAfter(waitCtx, input.RunID, cursor, eventTypes)
|
||||
if err != nil {
|
||||
if isDeadlineExceeded(waitCtx) {
|
||||
return WaitResult{Woke: false, NextEventID: cursor}, nil
|
||||
}
|
||||
return WaitResult{}, err
|
||||
}
|
||||
if found {
|
||||
return WaitResult{
|
||||
Woke: true,
|
||||
NextEventID: nextEventID,
|
||||
Events: events,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if _, err := s.ReconcileRun(waitCtx, input.RunID); err != nil {
|
||||
if isSQLiteBusyError(err) {
|
||||
ok, waitErr := waitForNextPoll(waitCtx, 25*time.Millisecond)
|
||||
if waitErr != nil {
|
||||
if errors.Is(waitErr, context.DeadlineExceeded) {
|
||||
return WaitResult{Woke: false, NextEventID: cursor}, nil
|
||||
}
|
||||
return WaitResult{}, waitErr
|
||||
}
|
||||
if !ok {
|
||||
return WaitResult{Woke: false, NextEventID: cursor}, nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
if isDeadlineExceeded(waitCtx) {
|
||||
return WaitResult{Woke: false, NextEventID: cursor}, nil
|
||||
}
|
||||
return WaitResult{}, err
|
||||
}
|
||||
|
||||
events, nextEventID, found, err = s.findRunEventsAfter(waitCtx, input.RunID, cursor, eventTypes)
|
||||
if err != nil {
|
||||
if isDeadlineExceeded(waitCtx) {
|
||||
return WaitResult{Woke: false, NextEventID: cursor}, nil
|
||||
}
|
||||
return WaitResult{}, err
|
||||
}
|
||||
if found {
|
||||
return WaitResult{
|
||||
Woke: true,
|
||||
NextEventID: nextEventID,
|
||||
Events: events,
|
||||
}, nil
|
||||
}
|
||||
|
||||
ok, err := waitForNextPoll(waitCtx, 200*time.Millisecond)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return WaitResult{Woke: false, NextEventID: cursor}, nil
|
||||
}
|
||||
return WaitResult{}, err
|
||||
}
|
||||
if !ok {
|
||||
return WaitResult{Woke: false, NextEventID: cursor}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func listTasksForRun(ctx context.Context, db queryRowsContexter, runID string) ([]Task, error) {
|
||||
rows, err := db.QueryContext(
|
||||
ctx,
|
||||
@@ -1090,6 +1207,55 @@ func listTasksForRun(ctx context.Context, db queryRowsContexter, runID string) (
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
func (s *OrchStore) findRunEventsAfter(ctx context.Context, runID string, afterEventID int64, eventTypes []string) ([]RunEvent, int64, bool, error) {
|
||||
args := []any{runID, afterEventID}
|
||||
query := `SELECT
|
||||
event_id, event_type, run_id, task_id, thread_id, summary, payload_json, created_at
|
||||
FROM events
|
||||
WHERE run_id = ?
|
||||
AND event_id > ?`
|
||||
if len(eventTypes) > 0 {
|
||||
query += " AND event_type IN (" + placeholders(len(eventTypes)) + ")"
|
||||
for _, eventType := range eventTypes {
|
||||
args = append(args, eventType)
|
||||
}
|
||||
}
|
||||
query += " ORDER BY event_id ASC LIMIT 1"
|
||||
|
||||
row := s.db.QueryRowContext(ctx, query, args...)
|
||||
|
||||
var (
|
||||
event RunEvent
|
||||
threadID sql.NullString
|
||||
payload string
|
||||
createdAt string
|
||||
)
|
||||
err := row.Scan(
|
||||
&event.EventID,
|
||||
&event.Type,
|
||||
&event.RunID,
|
||||
&event.TaskID,
|
||||
&threadID,
|
||||
&event.Summary,
|
||||
&payload,
|
||||
&createdAt,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, 0, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, false, fmt.Errorf("query run events after %d: %w", afterEventID, err)
|
||||
}
|
||||
|
||||
if threadID.Valid {
|
||||
event.ThreadID = threadID.String
|
||||
}
|
||||
event.Payload = json.RawMessage(payload)
|
||||
event.CreatedAt = parseTime(createdAt)
|
||||
|
||||
return []RunEvent{event}, event.EventID, true, nil
|
||||
}
|
||||
|
||||
type queryRowsContexter interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
@@ -1601,6 +1767,30 @@ func normalizePriority(priority string) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeWaitEventTypes(eventTypes []string) []string {
|
||||
if len(eventTypes) == 0 {
|
||||
return []string{"task_ready", "task_blocked", "task_done", "task_failed"}
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(eventTypes))
|
||||
seen := make(map[string]struct{}, len(eventTypes))
|
||||
for _, eventType := range eventTypes {
|
||||
eventType = strings.TrimSpace(eventType)
|
||||
if eventType == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[eventType]; ok {
|
||||
continue
|
||||
}
|
||||
seen[eventType] = struct{}{}
|
||||
normalized = append(normalized, eventType)
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return []string{"task_ready", "task_blocked", "task_done", "task_failed"}
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func validateAndNormalizeJSONDefault(fieldName, value, defaultValue string) (string, error) {
|
||||
normalized := strings.TrimSpace(value)
|
||||
if normalized == "" {
|
||||
|
||||
Reference in New Issue
Block a user