// Extracted from ai-workflow-skill/internal/query/read_service.go (Go; listed as 212 lines, 4.6 KiB).

package query
import (
"context"
"database/sql"
"fmt"
"time"
"ai-workflow-skill/internal/store"
)
// ReadService groups the handles used by the read-only query methods in this
// file: the raw database handle plus the two stores built on top of it.
type ReadService struct {
// db is queried directly by ListRuns and collectTaskCounts.
db *sql.DB
// orch serves run overviews and blocked-task listings.
orch *store.OrchStore
// inbox serves thread lookups.
inbox *store.InboxStore
}
// RunListItem is one entry of the ListRuns result: a run together with its
// task counts.
type RunListItem struct {
// Run is the run row itself.
Run store.Run `json:"run"`
// TaskCounts maps a task status to the number of the run's tasks in that status.
TaskCounts map[string]int `json:"task_counts"`
// TotalTasks is the sum of all values in TaskCounts.
TotalTasks int `json:"total_tasks"`
}
// RunDetail is the result of GetRunDetail: the run overview fields plus the
// run's blocked tasks.
type RunDetail struct {
// Run is copied from the run overview.
Run store.Run `json:"run"`
// TaskCounts is copied from the run overview.
TaskCounts map[string]int `json:"task_counts"`
// TotalTasks is the sum of all values in TaskCounts.
TotalTasks int `json:"total_tasks"`
// Tasks is copied from the run overview.
Tasks []store.Task `json:"tasks"`
// BlockedTasks is the orchestration store's blocked-task listing for the run.
BlockedTasks []store.BlockedTask `json:"blocked_tasks"`
}
// NewReadService returns a ReadService backed by db. The same handle is kept
// for direct queries and is also used to construct the orchestration and
// inbox stores.
func NewReadService(db *sql.DB) *ReadService {
	svc := new(ReadService)
	svc.db = db
	svc.orch = store.NewOrchStore(db)
	svc.inbox = store.NewInboxStore(db)
	return svc
}
// ListRuns returns every row of the runs table, ordered by updated_at and then
// created_at (both descending), each paired with its per-status task counts
// and their total.
//
// The created_at and updated_at columns are scanned as text and converted with
// parseRFC3339. A run that has no entry in the collected counts gets an empty,
// non-nil TaskCounts map.
func (s *ReadService) ListRuns(ctx context.Context) ([]RunListItem, error) {
	const listRunsSQL = `SELECT run_id, goal, summary, status, created_at, updated_at
FROM runs
ORDER BY updated_at DESC, created_at DESC`

	cursor, err := s.db.QueryContext(ctx, listRunsSQL)
	if err != nil {
		return nil, fmt.Errorf("query runs: %w", err)
	}
	defer cursor.Close()

	var runs []store.Run
	for cursor.Next() {
		var (
			current                  store.Run
			createdText, updatedText string
		)
		scanErr := cursor.Scan(
			&current.RunID,
			&current.Goal,
			&current.Summary,
			&current.Status,
			&createdText,
			&updatedText,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("scan run list row: %w", scanErr)
		}
		current.CreatedAt = parseRFC3339(createdText)
		current.UpdatedAt = parseRFC3339(updatedText)
		runs = append(runs, current)
	}
	if iterErr := cursor.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate runs: %w", iterErr)
	}

	// Gather the IDs once the scan is complete so all counts can be fetched
	// through a single collectTaskCounts call.
	ids := make([]string, len(runs))
	for i := range runs {
		ids[i] = runs[i].RunID
	}
	countsByRun, err := s.collectTaskCounts(ctx, ids)
	if err != nil {
		return nil, err
	}

	items := make([]RunListItem, len(runs))
	for i, run := range runs {
		counts := countsByRun[run.RunID]
		if counts == nil {
			counts = map[string]int{}
		}
		items[i] = RunListItem{
			Run:        run,
			TaskCounts: counts,
			TotalTasks: totalTasks(counts),
		}
	}
	return items, nil
}
// GetRunDetail assembles a RunDetail for runID from two orchestration-store
// calls: the run overview (run, task counts, tasks) and the blocked-task
// listing. An error from either call is returned unchanged together with a
// zero RunDetail.
func (s *ReadService) GetRunDetail(ctx context.Context, runID string) (RunDetail, error) {
	var none RunDetail

	overview, err := s.orch.GetRunOverview(ctx, runID)
	if err != nil {
		return none, err
	}

	blockedTasks, err := s.orch.ListBlockedTasks(ctx, runID)
	if err != nil {
		return none, err
	}

	detail := RunDetail{
		Run:          overview.Run,
		TaskCounts:   overview.TaskCounts,
		Tasks:        overview.Tasks,
		BlockedTasks: blockedTasks,
	}
	detail.TotalTasks = totalTasks(detail.TaskCounts)
	return detail, nil
}
// ListRunTasks returns the Tasks field of the orchestration store's run
// overview for runID.
//
// The overview is read directly rather than through GetRunDetail, so listing
// tasks does not execute the blocked-task query and cannot fail because of it.
// An error from the overview lookup is returned unchanged with a nil slice.
func (s *ReadService) ListRunTasks(ctx context.Context, runID string) ([]store.Task, error) {
	overview, err := s.orch.GetRunOverview(ctx, runID)
	if err != nil {
		return nil, err
	}
	return overview.Tasks, nil
}
// ListBlockedTasks forwards to the orchestration store's ListBlockedTasks for
// runID and returns its result and error unchanged.
func (s *ReadService) ListBlockedTasks(ctx context.Context, runID string) ([]store.BlockedTask, error) {
	blockedTasks, err := s.orch.ListBlockedTasks(ctx, runID)
	return blockedTasks, err
}
// GetThreadDetail forwards to the inbox store's GetThread for threadID and
// returns its result and error unchanged.
func (s *ReadService) GetThreadDetail(ctx context.Context, threadID string) (store.ThreadDetail, error) {
	thread, err := s.inbox.GetThread(ctx, threadID)
	return thread, err
}
// collectTaskCounts returns, for each of the given run IDs that has rows in
// the tasks table, a map from task status to the number of tasks in that
// status. Run IDs without tasks are absent from the result. An empty runIDs
// slice yields an empty, non-nil map without touching the database.
//
// The IDs are queried in fixed-size batches rather than in one statement:
// every ID becomes a bind parameter, and database engines cap the number of
// parameters per statement (SQLite's default limit was 999 before 3.32.0), so
// a single unbounded IN list would start failing once enough runs exist.
// Each batch is its own statement, so the counts are not taken from a single
// snapshot when more than one batch is needed.
func (s *ReadService) collectTaskCounts(ctx context.Context, runIDs []string) (map[string]map[string]int, error) {
	// Kept well below the smallest common bind-parameter limit.
	const batchSize = 500

	result := make(map[string]map[string]int, len(runIDs))

	// queryBatch runs the grouped count query for one batch and merges the
	// rows into result. It is a closure so that the deferred Close runs at
	// the end of each batch instead of piling up until the outer return.
	queryBatch := func(batch []string) error {
		args := make([]any, len(batch))
		for i, id := range batch {
			args[i] = id
		}

		rows, err := s.db.QueryContext(
			ctx,
			`SELECT run_id, status, COUNT(*)
FROM tasks
WHERE run_id IN (`+placeholders(len(batch))+`)
GROUP BY run_id, status`,
			args...,
		)
		if err != nil {
			return fmt.Errorf("query task counts for runs: %w", err)
		}
		defer rows.Close()

		for rows.Next() {
			var (
				runID  string
				status string
				count  int
			)
			if err := rows.Scan(&runID, &status, &count); err != nil {
				return fmt.Errorf("scan run task count: %w", err)
			}
			if result[runID] == nil {
				result[runID] = make(map[string]int)
			}
			// Assignment (not addition): a (run_id, status) pair appears
			// once per statement, and a repeated ID in another batch would
			// report the same count again.
			result[runID][status] = count
		}
		if err := rows.Err(); err != nil {
			return fmt.Errorf("iterate run task counts: %w", err)
		}
		return nil
	}

	for start := 0; start < len(runIDs); start += batchSize {
		end := start + batchSize
		if end > len(runIDs) {
			end = len(runIDs)
		}
		if err := queryBatch(runIDs[start:end]); err != nil {
			return nil, err
		}
	}
	return result, nil
}
// totalTasks adds up every value in counts and returns the sum. A nil or empty
// map yields 0.
func totalTasks(counts map[string]int) int {
	var sum int
	for status := range counts {
		sum += counts[status]
	}
	return sum
}
// placeholders returns count "?" bind markers separated by commas, e.g.
// "?,?,?" for 3. It returns the empty string when count is zero or negative.
func placeholders(count int) string {
	if count < 1 {
		return ""
	}
	// The output has a fixed layout: '?' at even offsets, ',' at odd ones.
	marks := make([]byte, 2*count-1)
	for i := range marks {
		if i%2 == 0 {
			marks[i] = '?'
		} else {
			marks[i] = ','
		}
	}
	return string(marks)
}
// parseRFC3339 parses value using the time.RFC3339Nano layout. If value does
// not parse, the error is discarded and the zero time.Time is returned.
func parseRFC3339(value string) time.Time {
	if ts, err := time.Parse(time.RFC3339Nano, value); err == nil {
		return ts
	}
	return time.Time{}
}