Add inbox send fetch claim show commands
This commit is contained in:
@@ -0,0 +1,596 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var ErrLeaseConflict = errors.New("thread already claimed by another worker")
|
||||
|
||||
type InboxStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
type Thread struct {
|
||||
ThreadID string `json:"thread_id"`
|
||||
RunID string `json:"run_id"`
|
||||
TaskID string `json:"task_id"`
|
||||
Subject string `json:"subject"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
AssignedTo string `json:"assigned_to"`
|
||||
Status string `json:"status"`
|
||||
Priority string `json:"priority"`
|
||||
LatestMessageID string `json:"latest_message_id,omitempty"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
MessageID string `json:"message_id"`
|
||||
ThreadID string `json:"thread_id"`
|
||||
FromAgent string `json:"from_agent"`
|
||||
ToAgent string `json:"to_agent"`
|
||||
Kind string `json:"kind"`
|
||||
Summary string `json:"summary"`
|
||||
Body string `json:"body"`
|
||||
PayloadJSON json.RawMessage `json:"payload_json"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type ThreadDetail struct {
|
||||
Thread Thread `json:"thread"`
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
type SendInput struct {
|
||||
ThreadID string
|
||||
RunID string
|
||||
TaskID string
|
||||
Subject string
|
||||
FromAgent string
|
||||
ToAgent string
|
||||
Kind string
|
||||
Summary string
|
||||
Body string
|
||||
PayloadJSON string
|
||||
Priority string
|
||||
}
|
||||
|
||||
type FetchInput struct {
|
||||
Agent string
|
||||
Statuses []string
|
||||
Limit int
|
||||
}
|
||||
|
||||
type ClaimInput struct {
|
||||
ThreadID string
|
||||
Agent string
|
||||
LeaseSeconds int
|
||||
}
|
||||
|
||||
type ClaimResult struct {
|
||||
Thread Thread `json:"thread"`
|
||||
Message Message `json:"message"`
|
||||
}
|
||||
|
||||
func NewInboxStore(db *sql.DB) *InboxStore {
|
||||
return &InboxStore{db: db}
|
||||
}
|
||||
|
||||
func (s *InboxStore) Send(ctx context.Context, input SendInput) (Thread, Message, error) {
|
||||
now := nowUTC()
|
||||
|
||||
threadID := defaultID(input.ThreadID, "thr")
|
||||
runID := defaultID(input.RunID, "run")
|
||||
taskID := defaultID(input.TaskID, "task")
|
||||
kind := defaultString(input.Kind, "task")
|
||||
priority := defaultString(input.Priority, "normal")
|
||||
summary := defaultString(input.Summary, input.Subject)
|
||||
payload := normalizeJSON(input.PayloadJSON)
|
||||
messageID := newID("msg")
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return Thread{}, Message{}, fmt.Errorf("begin send transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
thread := Thread{
|
||||
ThreadID: threadID,
|
||||
RunID: runID,
|
||||
TaskID: taskID,
|
||||
Subject: input.Subject,
|
||||
CreatedBy: input.FromAgent,
|
||||
AssignedTo: input.ToAgent,
|
||||
Status: "pending",
|
||||
Priority: priority,
|
||||
LatestMessageID: messageID,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO threads (
|
||||
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
|
||||
priority, latest_message_id, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
thread.ThreadID,
|
||||
thread.RunID,
|
||||
thread.TaskID,
|
||||
thread.Subject,
|
||||
thread.CreatedBy,
|
||||
thread.AssignedTo,
|
||||
thread.Status,
|
||||
thread.Priority,
|
||||
thread.LatestMessageID,
|
||||
formatTime(thread.CreatedAt),
|
||||
formatTime(thread.UpdatedAt),
|
||||
); err != nil {
|
||||
return Thread{}, Message{}, fmt.Errorf("insert thread: %w", err)
|
||||
}
|
||||
|
||||
message := Message{
|
||||
MessageID: messageID,
|
||||
ThreadID: threadID,
|
||||
FromAgent: input.FromAgent,
|
||||
ToAgent: input.ToAgent,
|
||||
Kind: kind,
|
||||
Summary: summary,
|
||||
Body: input.Body,
|
||||
PayloadJSON: json.RawMessage(payload),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO messages (
|
||||
message_id, thread_id, from_agent, to_agent, kind, summary, body,
|
||||
payload_json, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
message.MessageID,
|
||||
message.ThreadID,
|
||||
message.FromAgent,
|
||||
message.ToAgent,
|
||||
message.Kind,
|
||||
message.Summary,
|
||||
message.Body,
|
||||
string(message.PayloadJSON),
|
||||
formatTime(message.CreatedAt),
|
||||
); err != nil {
|
||||
return Thread{}, Message{}, fmt.Errorf("insert message: %w", err)
|
||||
}
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: thread.RunID,
|
||||
TaskID: thread.TaskID,
|
||||
ThreadID: thread.ThreadID,
|
||||
Source: "inbox",
|
||||
EventType: "thread_created",
|
||||
MessageID: message.MessageID,
|
||||
Summary: summary,
|
||||
PayloadJSON: payload,
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return Thread{}, Message{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return Thread{}, Message{}, fmt.Errorf("commit send transaction: %w", err)
|
||||
}
|
||||
|
||||
return thread, message, nil
|
||||
}
|
||||
|
||||
func (s *InboxStore) FetchThreads(ctx context.Context, input FetchInput) ([]Thread, error) {
|
||||
statuses := input.Statuses
|
||||
if len(statuses) == 0 {
|
||||
statuses = []string{"pending"}
|
||||
}
|
||||
|
||||
limit := input.Limit
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
var args []any
|
||||
var conditions []string
|
||||
|
||||
if input.Agent != "" {
|
||||
conditions = append(conditions, "assigned_to = ?")
|
||||
args = append(args, input.Agent)
|
||||
}
|
||||
|
||||
conditions = append(conditions, "status IN ("+placeholders(len(statuses))+")")
|
||||
for _, status := range statuses {
|
||||
args = append(args, status)
|
||||
}
|
||||
args = append(args, limit)
|
||||
|
||||
query := `SELECT
|
||||
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
|
||||
priority, latest_message_id, created_at, updated_at
|
||||
FROM threads`
|
||||
if len(conditions) > 0 {
|
||||
query += " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
query += " ORDER BY updated_at DESC LIMIT ?"
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetch threads: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var threads []Thread
|
||||
for rows.Next() {
|
||||
thread, err := scanThread(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
threads = append(threads, thread)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate threads: %w", err)
|
||||
}
|
||||
|
||||
return threads, nil
|
||||
}
|
||||
|
||||
func (s *InboxStore) ClaimThread(ctx context.Context, input ClaimInput) (ClaimResult, error) {
|
||||
if input.LeaseSeconds <= 0 {
|
||||
input.LeaseSeconds = 900
|
||||
}
|
||||
|
||||
now := nowUTC()
|
||||
expiresAt := now.Add(time.Duration(input.LeaseSeconds) * time.Second)
|
||||
leaseToken := newID("lease")
|
||||
messageID := newID("msg")
|
||||
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("begin claim transaction: %w", err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
thread, err := selectThreadForUpdate(ctx, tx, input.ThreadID)
|
||||
if err != nil {
|
||||
return ClaimResult{}, err
|
||||
}
|
||||
|
||||
if thread.Status != "pending" {
|
||||
return ClaimResult{}, fmt.Errorf("thread %s is not pending", input.ThreadID)
|
||||
}
|
||||
|
||||
var activeLease string
|
||||
err = tx.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT agent_id FROM leases
|
||||
WHERE thread_id = ?
|
||||
AND released_at IS NULL
|
||||
AND expires_at > ?`,
|
||||
input.ThreadID,
|
||||
formatTime(now),
|
||||
).Scan(&activeLease)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return ClaimResult{}, fmt.Errorf("check active lease: %w", err)
|
||||
}
|
||||
if activeLease != "" {
|
||||
return ClaimResult{}, ErrLeaseConflict
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO leases (
|
||||
thread_id, agent_id, lease_token, claimed_at, expires_at, released_at
|
||||
) VALUES (?, ?, ?, ?, ?, NULL)
|
||||
ON CONFLICT(thread_id) DO UPDATE SET
|
||||
agent_id = excluded.agent_id,
|
||||
lease_token = excluded.lease_token,
|
||||
claimed_at = excluded.claimed_at,
|
||||
expires_at = excluded.expires_at,
|
||||
released_at = NULL`,
|
||||
input.ThreadID,
|
||||
input.Agent,
|
||||
leaseToken,
|
||||
formatTime(now),
|
||||
formatTime(expiresAt),
|
||||
); err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("upsert lease: %w", err)
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
ctx,
|
||||
`UPDATE threads
|
||||
SET status = ?, assigned_to = ?, latest_message_id = ?, updated_at = ?
|
||||
WHERE thread_id = ?`,
|
||||
"claimed",
|
||||
input.Agent,
|
||||
messageID,
|
||||
formatTime(now),
|
||||
input.ThreadID,
|
||||
); err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("update thread claim status: %w", err)
|
||||
}
|
||||
|
||||
message := Message{
|
||||
MessageID: messageID,
|
||||
ThreadID: input.ThreadID,
|
||||
FromAgent: input.Agent,
|
||||
ToAgent: input.Agent,
|
||||
Kind: "event",
|
||||
Summary: "thread claimed",
|
||||
Body: "",
|
||||
PayloadJSON: json.RawMessage(fmt.Sprintf(`{"lease_seconds":%d,"lease_token":"%s"}`, input.LeaseSeconds, leaseToken)),
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
if _, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO messages (
|
||||
message_id, thread_id, from_agent, to_agent, kind, summary, body,
|
||||
payload_json, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
message.MessageID,
|
||||
message.ThreadID,
|
||||
message.FromAgent,
|
||||
message.ToAgent,
|
||||
message.Kind,
|
||||
message.Summary,
|
||||
message.Body,
|
||||
string(message.PayloadJSON),
|
||||
formatTime(message.CreatedAt),
|
||||
); err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("insert claim event message: %w", err)
|
||||
}
|
||||
|
||||
if err := insertEvent(ctx, tx, eventInput{
|
||||
RunID: thread.RunID,
|
||||
TaskID: thread.TaskID,
|
||||
ThreadID: thread.ThreadID,
|
||||
Source: "inbox",
|
||||
EventType: "thread_claimed",
|
||||
MessageID: message.MessageID,
|
||||
Summary: message.Summary,
|
||||
PayloadJSON: string(message.PayloadJSON),
|
||||
CreatedAt: now,
|
||||
}); err != nil {
|
||||
return ClaimResult{}, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return ClaimResult{}, fmt.Errorf("commit claim transaction: %w", err)
|
||||
}
|
||||
|
||||
thread.Status = "claimed"
|
||||
thread.AssignedTo = input.Agent
|
||||
thread.LatestMessageID = messageID
|
||||
thread.UpdatedAt = now
|
||||
|
||||
return ClaimResult{
|
||||
Thread: thread,
|
||||
Message: message,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) {
|
||||
thread, err := selectThread(ctx, s.db, threadID)
|
||||
if err != nil {
|
||||
return ThreadDetail{}, err
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(
|
||||
ctx,
|
||||
`SELECT
|
||||
message_id, thread_id, from_agent, to_agent, kind, summary, body,
|
||||
payload_json, created_at
|
||||
FROM messages
|
||||
WHERE thread_id = ?
|
||||
ORDER BY created_at ASC`,
|
||||
threadID,
|
||||
)
|
||||
if err != nil {
|
||||
return ThreadDetail{}, fmt.Errorf("query thread messages: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []Message
|
||||
for rows.Next() {
|
||||
message, err := scanMessage(rows)
|
||||
if err != nil {
|
||||
return ThreadDetail{}, err
|
||||
}
|
||||
messages = append(messages, message)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return ThreadDetail{}, fmt.Errorf("iterate thread messages: %w", err)
|
||||
}
|
||||
|
||||
return ThreadDetail{
|
||||
Thread: thread,
|
||||
Messages: messages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type threadScanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func scanThread(scanner threadScanner) (Thread, error) {
|
||||
var (
|
||||
thread Thread
|
||||
createdAt, updatedAt string
|
||||
latestMessageID sql.NullString
|
||||
)
|
||||
|
||||
if err := scanner.Scan(
|
||||
&thread.ThreadID,
|
||||
&thread.RunID,
|
||||
&thread.TaskID,
|
||||
&thread.Subject,
|
||||
&thread.CreatedBy,
|
||||
&thread.AssignedTo,
|
||||
&thread.Status,
|
||||
&thread.Priority,
|
||||
&latestMessageID,
|
||||
&createdAt,
|
||||
&updatedAt,
|
||||
); err != nil {
|
||||
return Thread{}, fmt.Errorf("scan thread: %w", err)
|
||||
}
|
||||
|
||||
thread.CreatedAt = parseTime(createdAt)
|
||||
thread.UpdatedAt = parseTime(updatedAt)
|
||||
if latestMessageID.Valid {
|
||||
thread.LatestMessageID = latestMessageID.String
|
||||
}
|
||||
|
||||
return thread, nil
|
||||
}
|
||||
|
||||
func scanMessage(scanner threadScanner) (Message, error) {
|
||||
var (
|
||||
message Message
|
||||
payload, createdAt string
|
||||
)
|
||||
|
||||
if err := scanner.Scan(
|
||||
&message.MessageID,
|
||||
&message.ThreadID,
|
||||
&message.FromAgent,
|
||||
&message.ToAgent,
|
||||
&message.Kind,
|
||||
&message.Summary,
|
||||
&message.Body,
|
||||
&payload,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return Message{}, fmt.Errorf("scan message: %w", err)
|
||||
}
|
||||
|
||||
message.PayloadJSON = json.RawMessage(payload)
|
||||
message.CreatedAt = parseTime(createdAt)
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func selectThread(ctx context.Context, db queryRower, threadID string) (Thread, error) {
|
||||
row := db.QueryRowContext(
|
||||
ctx,
|
||||
`SELECT
|
||||
thread_id, run_id, task_id, subject, created_by, assigned_to, status,
|
||||
priority, latest_message_id, created_at, updated_at
|
||||
FROM threads
|
||||
WHERE thread_id = ?`,
|
||||
threadID,
|
||||
)
|
||||
|
||||
thread, err := scanThread(row)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return Thread{}, fmt.Errorf("thread %s not found", threadID)
|
||||
}
|
||||
return thread, err
|
||||
}
|
||||
|
||||
func selectThreadForUpdate(ctx context.Context, tx *sql.Tx, threadID string) (Thread, error) {
|
||||
return selectThread(ctx, tx, threadID)
|
||||
}
|
||||
|
||||
type queryRower interface {
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type eventInput struct {
|
||||
RunID string
|
||||
TaskID string
|
||||
ThreadID string
|
||||
Source string
|
||||
EventType string
|
||||
MessageID string
|
||||
Summary string
|
||||
PayloadJSON string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
func insertEvent(ctx context.Context, tx *sql.Tx, input eventInput) error {
|
||||
_, err := tx.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO events (
|
||||
run_id, task_id, thread_id, source, event_type, message_id, summary,
|
||||
payload_json, created_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
input.RunID,
|
||||
input.TaskID,
|
||||
input.ThreadID,
|
||||
input.Source,
|
||||
input.EventType,
|
||||
input.MessageID,
|
||||
input.Summary,
|
||||
normalizeJSON(input.PayloadJSON),
|
||||
formatTime(input.CreatedAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert event: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultID(value, prefix string) string {
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
return newID(prefix)
|
||||
}
|
||||
|
||||
func newID(prefix string) string {
|
||||
return prefix + "_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
}
|
||||
|
||||
func defaultString(value, fallback string) string {
|
||||
if value != "" {
|
||||
return value
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func normalizeJSON(value string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return "{}"
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
if n <= 0 {
|
||||
return ""
|
||||
}
|
||||
parts := make([]string, n)
|
||||
for i := range parts {
|
||||
parts[i] = "?"
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
func nowUTC() time.Time {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func formatTime(t time.Time) string {
|
||||
return t.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func parseTime(value string) time.Time {
|
||||
parsed, err := time.Parse(time.RFC3339Nano, value)
|
||||
if err != nil {
|
||||
return time.Time{}
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
Reference in New Issue
Block a user