Files
ai-workflow-skill/internal/store/inbox.go
T

597 lines
14 KiB
Go

package store
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
)
// ErrLeaseConflict is the sentinel error ClaimThread returns when the thread
// is already held under an active lease. Compare against it with errors.Is.
var ErrLeaseConflict = errors.New("thread already claimed by another worker")
// InboxStore reads and writes inbox threads, messages, leases and events
// through a SQL database handle.
type InboxStore struct {
// db is the handle every query and transaction of the store runs against.
db *sql.DB
}
// Thread is one row of the threads table: a unit of work addressed to an
// agent, together with a pointer to its most recent message.
type Thread struct {
// ThreadID identifies the thread.
ThreadID string `json:"thread_id"`
// RunID and TaskID tie the thread to a run and a task; they are copied
// onto every event recorded for the thread.
RunID string `json:"run_id"`
TaskID string `json:"task_id"`
// Subject is the subject text supplied when the thread was sent.
Subject string `json:"subject"`
// CreatedBy is the sending agent of the first message.
CreatedBy string `json:"created_by"`
// AssignedTo starts as the recipient of the first message and is replaced
// by the claiming agent in ClaimThread.
AssignedTo string `json:"assigned_to"`
// Status is "pending" after Send and "claimed" after ClaimThread.
Status string `json:"status"`
// Priority is the caller-supplied priority; Send uses "normal" when none
// is given.
Priority string `json:"priority"`
// LatestMessageID is the ID of the newest message on the thread. It is
// left empty (and omitted from JSON) when the column is NULL.
LatestMessageID string `json:"latest_message_id,omitempty"`
// CreatedAt and UpdatedAt are stored as RFC 3339 strings and parsed back
// on read.
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Message is one row of the messages table: a single entry on a thread.
type Message struct {
// MessageID identifies the message.
MessageID string `json:"message_id"`
// ThreadID is the thread the message belongs to.
ThreadID string `json:"thread_id"`
// FromAgent and ToAgent name the sender and the recipient.
FromAgent string `json:"from_agent"`
ToAgent string `json:"to_agent"`
// Kind classifies the message. Send defaults it to "task"; the entry
// written by ClaimThread uses "event".
Kind string `json:"kind"`
// Summary is a short description; Body carries the full text.
Summary string `json:"summary"`
Body string `json:"body"`
// PayloadJSON holds the payload column verbatim as raw JSON text.
PayloadJSON json.RawMessage `json:"payload_json"`
// CreatedAt is the creation time of the message.
CreatedAt time.Time `json:"created_at"`
}
// ThreadDetail is the result of GetThread: a thread and its messages.
type ThreadDetail struct {
// Thread is the thread row itself.
Thread Thread `json:"thread"`
// Messages lists the thread's messages ordered by creation time, oldest
// first.
Messages []Message `json:"messages"`
}
// SendInput carries the arguments of InboxStore.Send. Several fields may be
// left empty and are then defaulted by Send.
type SendInput struct {
// ThreadID, RunID and TaskID are optional; an empty value is replaced by
// a generated ID with the prefix "thr", "run" or "task" respectively.
ThreadID string
RunID string
TaskID string
// Subject is stored on the thread and is the fallback for Summary.
Subject string
// FromAgent becomes the thread's creator and the message sender.
FromAgent string
// ToAgent becomes the thread's assignee and the message recipient.
ToAgent string
// Kind defaults to "task" when empty.
Kind string
// Summary defaults to Subject when empty.
Summary string
// Body is the message text, stored as given.
Body string
// PayloadJSON is the JSON payload as text; a blank value is stored as "{}".
PayloadJSON string
// Priority defaults to "normal" when empty.
Priority string
}
// FetchInput carries the filters of InboxStore.FetchThreads.
type FetchInput struct {
// Agent restricts results to threads assigned to this agent; when empty
// no assignee filter is applied.
Agent string
// Statuses lists the accepted thread statuses; when empty only "pending"
// threads are returned.
Statuses []string
// Limit caps the number of rows; values <= 0 fall back to 20.
Limit int
}
// ClaimInput carries the arguments of InboxStore.ClaimThread.
type ClaimInput struct {
// ThreadID is the thread to claim.
ThreadID string
// Agent is the claiming agent; it becomes the lease holder and the
// thread's assignee.
Agent string
// LeaseSeconds is the lease duration; values <= 0 fall back to 900.
LeaseSeconds int
}
// ClaimResult is what ClaimThread returns on success.
type ClaimResult struct {
// Thread is the thread as it looks after the claim was applied.
Thread Thread `json:"thread"`
// Message is the "thread claimed" event message written by the claim; its
// payload contains the lease duration and the lease token.
Message Message `json:"message"`
}
// NewInboxStore returns an InboxStore that runs its queries against db.
func NewInboxStore(db *sql.DB) *InboxStore {
return &InboxStore{db: db}
}
// Send creates a new thread together with its first message and records a
// "thread_created" event, all inside one transaction.
//
// Empty fields of input are defaulted: ThreadID, RunID and TaskID receive
// generated identifiers, Kind becomes "task", Priority becomes "normal",
// Summary falls back to Subject, and a blank PayloadJSON becomes "{}".
//
// Send returns an error when PayloadJSON is not well-formed JSON or when any
// database step fails. In either case nothing is persisted and the returned
// Thread and Message are zero values.
func (s *InboxStore) Send(ctx context.Context, input SendInput) (Thread, Message, error) {
	now := nowUTC()
	threadID := defaultID(input.ThreadID, "thr")
	runID := defaultID(input.RunID, "run")
	taskID := defaultID(input.TaskID, "task")
	kind := defaultString(input.Kind, "task")
	priority := defaultString(input.Priority, "normal")
	summary := defaultString(input.Summary, input.Subject)
	payload := normalizeJSON(input.PayloadJSON)
	messageID := newID("msg")

	// The payload is handed back to callers as json.RawMessage, and
	// encoding/json refuses to marshal a RawMessage that holds malformed
	// JSON. Rejecting it here keeps an unserializable message out of the
	// database.
	if !json.Valid([]byte(payload)) {
		return Thread{}, Message{}, fmt.Errorf("payload_json for thread %s is not valid JSON", threadID)
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return Thread{}, Message{}, fmt.Errorf("begin send transaction: %w", err)
	}
	// Rollback is a no-op once Commit has succeeded; on every early return it
	// discards the partial writes.
	defer tx.Rollback()

	thread := Thread{
		ThreadID:        threadID,
		RunID:           runID,
		TaskID:          taskID,
		Subject:         input.Subject,
		CreatedBy:       input.FromAgent,
		AssignedTo:      input.ToAgent,
		Status:          "pending",
		Priority:        priority,
		LatestMessageID: messageID,
		CreatedAt:       now,
		UpdatedAt:       now,
	}
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO threads (
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
priority, latest_message_id, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		thread.ThreadID,
		thread.RunID,
		thread.TaskID,
		thread.Subject,
		thread.CreatedBy,
		thread.AssignedTo,
		thread.Status,
		thread.Priority,
		thread.LatestMessageID,
		formatTime(thread.CreatedAt),
		formatTime(thread.UpdatedAt),
	); err != nil {
		return Thread{}, Message{}, fmt.Errorf("insert thread: %w", err)
	}

	message := Message{
		MessageID:   messageID,
		ThreadID:    threadID,
		FromAgent:   input.FromAgent,
		ToAgent:     input.ToAgent,
		Kind:        kind,
		Summary:     summary,
		Body:        input.Body,
		PayloadJSON: json.RawMessage(payload),
		CreatedAt:   now,
	}
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO messages (
message_id, thread_id, from_agent, to_agent, kind, summary, body,
payload_json, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		message.MessageID,
		message.ThreadID,
		message.FromAgent,
		message.ToAgent,
		message.Kind,
		message.Summary,
		message.Body,
		string(message.PayloadJSON),
		formatTime(message.CreatedAt),
	); err != nil {
		return Thread{}, Message{}, fmt.Errorf("insert message: %w", err)
	}

	// The event mirrors the first message so the event log can be read
	// without joining back to the messages table.
	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_created",
		MessageID:   message.MessageID,
		Summary:     summary,
		PayloadJSON: payload,
		CreatedAt:   now,
	}); err != nil {
		return Thread{}, Message{}, err
	}

	if err := tx.Commit(); err != nil {
		return Thread{}, Message{}, fmt.Errorf("commit send transaction: %w", err)
	}
	return thread, message, nil
}
// FetchThreads lists threads matching input, most recently updated first.
//
// Threads are filtered by status (only "pending" when input.Statuses is
// empty) and, when input.Agent is set, by assignee. At most input.Limit rows
// are returned, or 20 when the limit is not positive. The result is nil when
// no thread matches.
func (s *InboxStore) FetchThreads(ctx context.Context, input FetchInput) ([]Thread, error) {
	wanted := input.Statuses
	if len(wanted) == 0 {
		wanted = []string{"pending"}
	}
	maxRows := input.Limit
	if maxRows <= 0 {
		maxRows = 20
	}

	// Filters and their bind parameters are built side by side so the
	// placeholder order always matches the argument order.
	var (
		filters []string
		params  []any
	)
	if input.Agent != "" {
		filters = append(filters, "assigned_to = ?")
		params = append(params, input.Agent)
	}
	filters = append(filters, "status IN ("+placeholders(len(wanted))+")")
	for i := range wanted {
		params = append(params, wanted[i])
	}
	params = append(params, maxRows)

	const selectThreads = `SELECT
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
priority, latest_message_id, created_at, updated_at
FROM threads`
	// The status filter is always present, so a WHERE clause is always needed.
	stmt := selectThreads +
		" WHERE " + strings.Join(filters, " AND ") +
		" ORDER BY updated_at DESC LIMIT ?"

	rows, err := s.db.QueryContext(ctx, stmt, params...)
	if err != nil {
		return nil, fmt.Errorf("fetch threads: %w", err)
	}
	defer rows.Close()

	var found []Thread
	for rows.Next() {
		item, scanErr := scanThread(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		found = append(found, item)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate threads: %w", iterErr)
	}
	return found, nil
}
// ClaimThread assigns a pending thread to input.Agent under a time-limited
// lease. Inside one transaction it upserts the lease row, flips the thread to
// "claimed", appends a "thread claimed" event message and records a
// "thread_claimed" event.
//
// input.LeaseSeconds defaults to 900 when not positive.
//
// Errors:
//   - the lookup error from selectThreadForUpdate (including "not found");
//   - a "not pending" error when the thread's status is anything but "pending";
//   - ErrLeaseConflict when an unreleased, unexpired lease row exists for the
//     thread, or when the thread stopped being pending before the status
//     update was applied;
//   - a wrapped database error for any failing statement.
//
// On any error nothing is persisted and a zero ClaimResult is returned.
func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimResult, error) {
	if input.LeaseSeconds <= 0 {
		input.LeaseSeconds = 900
	}
	now := nowUTC()
	expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
	leaseToken := newID("lease")
	messageID := newID("msg")

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return ClaimResult{}, fmt.Errorf("begin claim transaction: %w", err)
	}
	// Rollback is a no-op once Commit has succeeded; on every early return it
	// discards the partial writes.
	defer tx.Rollback()

	thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
	if err != nil {
		return ClaimResult{}, err
	}
	if thread.Status != "pending" {
		return ClaimResult{}, fmt.Errorf("thread %s is not pending", input.ThreadID)
	}

	// A conflict is signalled by the mere presence of an active lease row,
	// not by the value of agent_id: a lease written with an empty agent is
	// still a lease and must not be silently overwritten.
	var leaseHolder string
	err = tx.QueryRowContext(
		ctx,
		`SELECT agent_id FROM leases
WHERE thread_id = ?
AND released_at IS NULL
AND expires_at > ?`,
		input.ThreadID,
		formatTime(now),
	).Scan(&leaseHolder)
	switch {
	case err == nil:
		return ClaimResult{}, ErrLeaseConflict
	case !errors.Is(err, sql.ErrNoRows):
		return ClaimResult{}, fmt.Errorf("check active lease: %w", err)
	}

	// Any remaining row for this thread is released or expired, so it is
	// safe to take it over in place.
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO leases (
thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
) VALUES (?, ?, ?, ?, ?, NULL)
ON CONFLICT(thread_id) DO UPDATE SET
agent_id = excluded.agent_id,
lease_token = excluded.lease_token,
claimed_at = excluded.claimed_at,
expires_at = excluded.expires_at,
released_at = NULL`,
		input.ThreadID,
		input.Agent,
		leaseToken,
		formatTime(now),
		formatTime(expiresAt),
	); err != nil {
		return ClaimResult{}, fmt.Errorf("upsert lease: %w", err)
	}

	// The status change is a compare-and-set on "pending". The earlier read
	// takes no row lock, so this guard keeps a concurrent claim from being
	// overwritten regardless of the isolation the database provides.
	updated, err := tx.ExecContext(
		ctx,
		`UPDATE threads
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
WHERE thread_id = ? AND status = ?`,
		"claimed",
		input.Agent,
		messageID,
		formatTime(now),
		input.ThreadID,
		"pending",
	)
	if err != nil {
		return ClaimResult{}, fmt.Errorf("update thread claim status: %w", err)
	}
	affected, err := updated.RowsAffected()
	if err != nil {
		return ClaimResult{}, fmt.Errorf("check thread claim update: %w", err)
	}
	if affected == 0 {
		// Someone else moved the thread out of "pending" first; the deferred
		// rollback also undoes the lease upsert above.
		return ClaimResult{}, ErrLeaseConflict
	}

	message := Message{
		MessageID: messageID,
		ThreadID:  input.ThreadID,
		FromAgent: input.Agent,
		ToAgent:   input.Agent,
		Kind:      "event",
		Summary:   "thread claimed",
		Body:      "",
		// leaseToken is generated locally by newID, so it needs no JSON
		// escaping here.
		PayloadJSON: json.RawMessage(fmt.Sprintf(`{"lease_seconds":%d,"lease_token":"%s"}`, input.LeaseSeconds, leaseToken)),
		CreatedAt:   now,
	}
	if _, err := tx.ExecContext(
		ctx,
		`INSERT INTO messages (
message_id, thread_id, from_agent, to_agent, kind, summary, body,
payload_json, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		message.MessageID,
		message.ThreadID,
		message.FromAgent,
		message.ToAgent,
		message.Kind,
		message.Summary,
		message.Body,
		string(message.PayloadJSON),
		formatTime(message.CreatedAt),
	); err != nil {
		return ClaimResult{}, fmt.Errorf("insert claim event message: %w", err)
	}

	if err := insertEvent(ctx, tx, eventInput{
		RunID:       thread.RunID,
		TaskID:      thread.TaskID,
		ThreadID:    thread.ThreadID,
		Source:      "inbox",
		EventType:   "thread_claimed",
		MessageID:   message.MessageID,
		Summary:     message.Summary,
		PayloadJSON: string(message.PayloadJSON),
		CreatedAt:   now,
	}); err != nil {
		return ClaimResult{}, err
	}

	if err := tx.Commit(); err != nil {
		return ClaimResult{}, fmt.Errorf("commit claim transaction: %w", err)
	}

	// Mirror the committed column changes onto the in-memory copy so the
	// caller sees the post-claim state without a second read.
	thread.Status = "claimed"
	thread.AssignedTo = input.Agent
	thread.LatestMessageID = messageID
	thread.UpdatedAt = now
	return ClaimResult{
		Thread:  thread,
		Message: message,
	}, nil
}
// GetThread loads the thread identified by threadID together with all of its
// messages, oldest first.
//
// The thread row is read first and a failure of that lookup (including the
// "not found" case) is returned unchanged. Messages is nil when the thread
// has no messages.
func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) {
	header, err := selectThread(ctx, s.db, threadID)
	if err != nil {
		return ThreadDetail{}, err
	}

	const listMessages = `SELECT
message_id, thread_id, from_agent, to_agent, kind, summary, body,
payload_json, created_at
FROM messages
WHERE thread_id = ?
ORDER BY created_at ASC`
	rows, err := s.db.QueryContext(ctx, listMessages, threadID)
	if err != nil {
		return ThreadDetail{}, fmt.Errorf("query thread messages: %w", err)
	}
	defer rows.Close()

	var history []Message
	for rows.Next() {
		entry, scanErr := scanMessage(rows)
		if scanErr != nil {
			return ThreadDetail{}, scanErr
		}
		history = append(history, entry)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return ThreadDetail{}, fmt.Errorf("iterate thread messages: %w", iterErr)
	}

	return ThreadDetail{Thread: header, Messages: history}, nil
}
// threadScanner is the Scan method shared by *sql.Row and *sql.Rows, which
// lets scanThread and scanMessage serve both single-row and multi-row
// queries.
type threadScanner interface {
Scan(dest ...any) error
}
// scanThread decodes one threads row from scanner. The columns must arrive in
// the order thread_id, run_id, task_id, subject, created_by, assigned_to,
// status, priority, latest_message_id, created_at, updated_at.
//
// A NULL latest_message_id leaves LatestMessageID empty. Timestamps that do
// not parse become the zero time (see parseTime). Scan failures are wrapped
// with %w so callers can still match sql.ErrNoRows.
func scanThread(scanner threadScanner) (Thread, error) {
	var (
		row        Thread
		latest     sql.NullString
		createdRaw string
		updatedRaw string
	)
	columns := []any{
		&row.ThreadID,
		&row.RunID,
		&row.TaskID,
		&row.Subject,
		&row.CreatedBy,
		&row.AssignedTo,
		&row.Status,
		&row.Priority,
		&latest,
		&createdRaw,
		&updatedRaw,
	}
	if err := scanner.Scan(columns...); err != nil {
		return Thread{}, fmt.Errorf("scan thread: %w", err)
	}

	if latest.Valid {
		row.LatestMessageID = latest.String
	}
	row.CreatedAt = parseTime(createdRaw)
	row.UpdatedAt = parseTime(updatedRaw)
	return row, nil
}
// scanMessage decodes one messages row from scanner. The columns must arrive
// in the order message_id, thread_id, from_agent, to_agent, kind, summary,
// body, payload_json, created_at.
//
// The payload text is carried over verbatim as raw JSON, and a timestamp that
// does not parse becomes the zero time (see parseTime).
func scanMessage(scanner threadScanner) (Message, error) {
	var (
		row        Message
		rawPayload string
		createdRaw string
	)
	columns := []any{
		&row.MessageID,
		&row.ThreadID,
		&row.FromAgent,
		&row.ToAgent,
		&row.Kind,
		&row.Summary,
		&row.Body,
		&rawPayload,
		&createdRaw,
	}
	if err := scanner.Scan(columns...); err != nil {
		return Message{}, fmt.Errorf("scan message: %w", err)
	}

	row.CreatedAt = parseTime(createdRaw)
	row.PayloadJSON = json.RawMessage(rawPayload)
	return row, nil
}
// selectThread reads the thread row with the given ID through db, which may
// be either a *sql.DB or a *sql.Tx.
//
// A missing row is reported as a plain "thread <id> not found" error; every
// other failure is whatever scanThread returned.
func selectThread(ctx context.Context, db queryRower, threadID string) (Thread, error) {
	const byID = `SELECT
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
priority, latest_message_id, created_at, updated_at
FROM threads
WHERE thread_id = ?`

	found, err := scanThread(db.QueryRowContext(ctx, byID, threadID))
	if err != nil && errors.Is(err, sql.ErrNoRows) {
		return Thread{}, fmt.Errorf("thread %s not found", threadID)
	}
	return found, err
}
// selectThreadForUpdate reads the thread row through the given transaction.
// It delegates directly to selectThread and adds no locking clause to the
// query, so any isolation comes from the transaction itself.
func selectThreadForUpdate(ctx context.Context, tx *sql.Tx, threadID string) (Thread, error) {
return selectThread(ctx, tx, threadID)
}
// queryRower is the single-row query method shared by *sql.DB and *sql.Tx,
// letting selectThread run either standalone or inside a transaction.
type queryRower interface {
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
// eventInput carries the column values for one row of the events table, as
// written by insertEvent.
type eventInput struct {
// RunID, TaskID and ThreadID identify what the event belongs to.
RunID string
TaskID string
ThreadID string
// Source names the component emitting the event; this file always uses
// "inbox".
Source string
// EventType is the event name, e.g. "thread_created" or "thread_claimed".
EventType string
// MessageID is the message the event refers to.
MessageID string
// Summary is a short description of the event.
Summary string
// PayloadJSON is the JSON payload as text; insertEvent stores a blank
// value as "{}".
PayloadJSON string
// CreatedAt is the event time.
CreatedAt time.Time
}
// insertEvent appends one row to the events table inside tx. A blank payload
// is stored as "{}" and the timestamp is written in the format produced by
// formatTime. Any database failure is returned wrapped as "insert event".
func insertEvent(ctx context.Context, tx *sql.Tx, input eventInput) error {
	const stmt = `INSERT INTO events (
run_id, task_id, thread_id, source, event_type, message_id, summary,
payload_json, created_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`

	values := []any{
		input.RunID,
		input.TaskID,
		input.ThreadID,
		input.Source,
		input.EventType,
		input.MessageID,
		input.Summary,
		normalizeJSON(input.PayloadJSON),
		formatTime(input.CreatedAt),
	}
	if _, err := tx.ExecContext(ctx, stmt, values...); err != nil {
		return fmt.Errorf("insert event: %w", err)
	}
	return nil
}
// defaultID returns value when it is non-empty; otherwise it generates a
// fresh identifier carrying the given prefix.
func defaultID(value, prefix string) string {
	if value == "" {
		return newID(prefix)
	}
	return value
}
// newID builds an identifier of the form "<prefix>_<uuid>", where the UUID is
// freshly generated and written without its dashes.
func newID(prefix string) string {
	compact := strings.ReplaceAll(uuid.NewString(), "-", "")
	return prefix + "_" + compact
}
// defaultString returns fallback when value is empty and value otherwise.
func defaultString(value, fallback string) string {
	if value == "" {
		return fallback
	}
	return value
}
// normalizeJSON substitutes the empty object "{}" for a value that is empty
// or consists only of whitespace. Any other value is returned untouched,
// including its surrounding whitespace; it is not validated as JSON.
func normalizeJSON(value string) string {
	if trimmed := strings.TrimSpace(value); trimmed != "" {
		return value
	}
	return "{}"
}
// placeholders returns n SQL bind markers joined by commas, e.g. "?,?,?" for
// n = 3. It returns the empty string when n is zero or negative.
func placeholders(n int) string {
	if n <= 0 {
		return ""
	}
	// One leading marker, then ",?" for each remaining one.
	return "?" + strings.Repeat(",?", n-1)
}
// nowUTC returns the current wall-clock time expressed in UTC.
func nowUTC() time.Time {
	current := time.Now()
	return current.UTC()
}
// formatTime renders t as an RFC 3339 timestamp with up to nanosecond
// precision, always converted to UTC first.
func formatTime(t time.Time) string {
	const layout = time.RFC3339Nano
	inUTC := t.UTC()
	return inUTC.Format(layout)
}
// parseTime parses an RFC 3339 timestamp, with or without fractional seconds.
// A value that does not parse yields the zero time; the parse error is
// deliberately not reported to the caller.
func parseTime(value string) time.Time {
	if parsed, err := time.Parse(time.RFC3339Nano, value); err == nil {
		return parsed
	}
	return time.Time{}
}