Files

277 lines
8.0 KiB
Go

package sqlite
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"inbox/internal/domain/humantask"
"inbox/internal/domain/message"
"inbox/internal/domain/role"
)
// CreateMessage stores value and fans it out to its recipients inside a single
// database transaction opened on s.db. All of the work is delegated to
// createMessageTx. The transaction is committed only when that call succeeds.
// Otherwise it is rolled back and the error is returned as-is.
//
// On success the returned record is the one createMessageTx produced. On any
// failure the zero message.Record is returned together with the error.
func (s *Store) CreateMessage(ctx context.Context, value message.Record) (message.Record, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return message.Record{}, fmt.Errorf("begin create message: %w", err)
	}
	// Rollback after a successful Commit is harmless (database/sql reports
	// ErrTxDone), so deferring it unconditionally covers every early return.
	defer tx.Rollback()

	created, err := s.createMessageTx(ctx, tx, value)
	if err != nil {
		return message.Record{}, err
	}
	if commitErr := tx.Commit(); commitErr != nil {
		return message.Record{}, fmt.Errorf("commit create message: %w", commitErr)
	}
	return created, nil
}
// ListMessagesByTopic returns every message whose topic_id equals topicID.
// Results are ordered by created_at and then by id, so equal timestamps still
// produce a deterministic order. A topic with no messages yields a nil slice
// and a nil error.
//
// Every failure is wrapped with context: query, per-row scan, and row
// iteration. Callers can still match the underlying cause with errors.Is or
// errors.As.
func (s *Store) ListMessagesByTopic(ctx context.Context, topicID string) ([]message.Record, error) {
	rows, err := s.db.QueryContext(ctx, `
SELECT id, workspace_id, topic_id, from_role_name, to_expr, type, stage, reply_to_message_id, body_markdown, created_at
FROM messages
WHERE topic_id = ?
ORDER BY created_at, id
`, topicID)
	if err != nil {
		return nil, fmt.Errorf("list messages by topic: %w", err)
	}
	defer rows.Close()
	var out []message.Record
	for rows.Next() {
		item, err := scanMessage(rows)
		if err != nil {
			// scanMessage returns the raw driver error. Add context here so a
			// column/type mismatch can be traced back to this query.
			return nil, fmt.Errorf("scan message by topic: %w", err)
		}
		out = append(out, item)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate messages: %w", err)
	}
	return out, nil
}
// ListMessagesByWorkspace returns every message whose workspace_id equals
// workspaceID, across all topics. Results are ordered by created_at and then
// by id. A workspace with no messages yields a nil slice and a nil error.
//
// Every failure is wrapped with context: query, per-row scan, and row
// iteration. Callers can still match the underlying cause with errors.Is or
// errors.As.
func (s *Store) ListMessagesByWorkspace(ctx context.Context, workspaceID string) ([]message.Record, error) {
	rows, err := s.db.QueryContext(ctx, `
SELECT id, workspace_id, topic_id, from_role_name, to_expr, type, stage, reply_to_message_id, body_markdown, created_at
FROM messages
WHERE workspace_id = ?
ORDER BY created_at, id
`, workspaceID)
	if err != nil {
		return nil, fmt.Errorf("list messages by workspace: %w", err)
	}
	defer rows.Close()
	var out []message.Record
	for rows.Next() {
		item, err := scanMessage(rows)
		if err != nil {
			// scanMessage returns the raw driver error. Add context here so a
			// column/type mismatch can be traced back to this query.
			return nil, fmt.Errorf("scan message by workspace: %w", err)
		}
		out = append(out, item)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate messages by workspace: %w", err)
	}
	return out, nil
}
// ListPendingDeliveriesByWorkspace summarizes pending deliveries in one
// workspace. Only recipients whose role row has executor_kind equal to
// role.ExecutorKindCodex are included.
//
// Deliveries are grouped by (topic, recipient role). Each entry carries the
// number of pending deliveries in that group and the largest updated_at among
// them. Entries are ordered with the most recently updated first, then by
// topic ID and role name. No matches yields a nil slice and a nil error.
func (s *Store) ListPendingDeliveriesByWorkspace(ctx context.Context, workspaceID string) ([]message.PendingDelivery, error) {
	args := []any{workspaceID, string(message.DeliveryPending), string(role.ExecutorKindCodex)}
	rows, err := s.db.QueryContext(ctx, `
SELECT m.topic_id, d.recipient_role_name, COUNT(*), MAX(d.updated_at)
FROM message_deliveries d
JOIN messages m ON m.id = d.message_id
JOIN roles r ON r.name = d.recipient_role_name
WHERE m.workspace_id = ? AND d.state = ? AND r.executor_kind = ?
GROUP BY m.topic_id, d.recipient_role_name
ORDER BY MAX(d.updated_at) DESC, m.topic_id, d.recipient_role_name
`, args...)
	if err != nil {
		return nil, fmt.Errorf("list pending deliveries by workspace: %w", err)
	}
	defer rows.Close()

	var summary []message.PendingDelivery
	for rows.Next() {
		var pd message.PendingDelivery
		if scanErr := rows.Scan(&pd.TopicID, &pd.RoleName, &pd.Count, &pd.LastUpdated); scanErr != nil {
			return nil, fmt.Errorf("scan pending delivery: %w", scanErr)
		}
		summary = append(summary, pd)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate pending deliveries: %w", err)
	}
	return summary, nil
}
// expandRecipients turns a message's to-expression into recipient role names.
// Surrounding whitespace is trimmed first, and an empty expression is an error.
// Two forms are accepted:
//
//   - "all": every role with is_enabled = 1, read through tx, ordered by
//     sort_order and then name. fromRole is left out so a sender does not
//     message itself.
//   - anything else: a comma-separated list of role names. Entries are trimmed,
//     blanks and duplicates are dropped, and the result is sorted
//     alphabetically. These names are not checked against the roles table
//     here, and fromRole is NOT excluded in this form.
//
// An error is returned when either form resolves to no names.
func (s *Store) expandRecipients(ctx context.Context, tx *sql.Tx, toExpr, fromRole string) ([]string, error) {
	expr := strings.TrimSpace(toExpr)
	switch expr {
	case "":
		return nil, fmt.Errorf("to expr is required")
	case "all":
		rows, err := tx.QueryContext(ctx, `SELECT name FROM roles WHERE is_enabled = 1 ORDER BY sort_order, name`)
		if err != nil {
			return nil, fmt.Errorf("list recipient roles: %w", err)
		}
		defer rows.Close()
		var enabled []string
		for rows.Next() {
			var name string
			if err := rows.Scan(&name); err != nil {
				return nil, fmt.Errorf("scan recipient role: %w", err)
			}
			if name != fromRole {
				enabled = append(enabled, name)
			}
		}
		if err := rows.Err(); err != nil {
			return nil, fmt.Errorf("iterate recipient roles: %w", err)
		}
		if len(enabled) == 0 {
			return nil, fmt.Errorf("no recipients resolved for %q", expr)
		}
		return enabled, nil
	}

	// Explicit list: keep the first occurrence of each non-blank name.
	seen := make(map[string]struct{})
	var listed []string
	for _, raw := range strings.Split(expr, ",") {
		name := strings.TrimSpace(raw)
		if name == "" {
			continue
		}
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		listed = append(listed, name)
	}
	if len(listed) == 0 {
		return nil, fmt.Errorf("no recipients resolved for %q", expr)
	}
	sort.Strings(listed)
	return listed, nil
}
// createMessageTx inserts value and its fan-out records using tx. It does not
// commit or roll back; the caller owns tx and must roll it back on error.
//
// Steps, in order:
//  1. value.Validate() must pass.
//  2. An empty ID is replaced with s.newID("message").
//  3. An empty CreatedAt is replaced with s.now() (via coalesceString).
//  4. The messages row is inserted.
//  5. Recipients are resolved with expandRecipients and looked up with
//     resolveRecipientRolesTx.
//  6. For each recipient:
//     - names with no matching role row are skipped;
//     - human-executed roles get a pending human task, but only for question
//       messages;
//     - every other executor kind gets a pending message_deliveries row.
//
// The returned record carries the final ID and CreatedAt. On error the zero
// record is returned.
func (s *Store) createMessageTx(ctx context.Context, tx *sql.Tx, value message.Record) (message.Record, error) {
	if err := value.Validate(); err != nil {
		return message.Record{}, err
	}
	if value.ID == "" {
		generated, err := s.newID("message")
		if err != nil {
			return message.Record{}, err
		}
		value.ID = generated
	}
	value.CreatedAt = coalesceString(value.CreatedAt, s.now())

	// The message row is written before recipients are resolved. If
	// resolution fails, undoing the insert relies on the caller rolling tx back.
	_, err := tx.ExecContext(ctx, `
INSERT INTO messages(id, workspace_id, topic_id, from_role_name, to_expr, type, stage, reply_to_message_id, body_markdown, created_at)
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`, value.ID, value.WorkspaceID, value.TopicID, value.FromRoleName, value.ToExpr, string(value.Type), value.Stage, nullableString(value.ReplyToMessageID), value.BodyMarkdown, value.CreatedAt)
	if err != nil {
		return message.Record{}, fmt.Errorf("insert message: %w", err)
	}

	names, err := s.expandRecipients(ctx, tx, value.ToExpr, value.FromRoleName)
	if err != nil {
		return message.Record{}, err
	}
	defs, err := resolveRecipientRolesTx(ctx, tx, names)
	if err != nil {
		return message.Record{}, err
	}

	for _, name := range names {
		target, known := defs[name]
		if !known {
			// NOTE(review): an unknown role name (e.g. a typo in ToExpr) is
			// dropped without an error, so the message may reach nobody.
			continue
		}
		if target.ExecutorKind == role.ExecutorKindHuman {
			// Humans are only asked to act on questions. Other message types
			// create nothing for them.
			if value.Type != message.TypeQuestion {
				continue
			}
			task := humantask.Record{
				WorkspaceID:     value.WorkspaceID,
				TopicID:         value.TopicID,
				RoleName:        name,
				PromptMessageID: value.ID,
				Status:          humantask.StatusPending,
				CreatedAt:       value.CreatedAt,
				UpdatedAt:       value.CreatedAt,
			}
			if _, err := insertHumanTaskTx(ctx, tx, s, task); err != nil {
				return message.Record{}, err
			}
			continue
		}
		// NOTE(review): delivered_at is filled with the creation time even
		// though the state is pending. Confirm this is intended (e.g. the column
		// is NOT NULL) rather than a placeholder.
		if _, err := tx.ExecContext(ctx, `
INSERT INTO message_deliveries(message_id, recipient_role_name, state, delivered_at, updated_at)
VALUES(?, ?, ?, ?, ?)
`, value.ID, name, string(message.DeliveryPending), value.CreatedAt, value.CreatedAt); err != nil {
			return message.Record{}, fmt.Errorf("insert message delivery: %w", err)
		}
	}
	return value, nil
}
// resolveRecipientRolesTx loads the role rows for the given names through tx.
// The result is keyed by role name. Names with no row in roles are simply
// absent from the map, and no error is raised for them. An empty recipients
// slice returns an empty, non-nil map without querying.
func resolveRecipientRolesTx(ctx context.Context, tx *sql.Tx, recipients []string) (map[string]role.Definition, error) {
	byName := make(map[string]role.Definition, len(recipients))
	if len(recipients) == 0 {
		return byName, nil
	}

	// One bind parameter per name for the IN (...) list.
	query := `
SELECT name, title, executor_kind, description, is_enabled, is_builtin, sort_order, created_at, updated_at
FROM roles
WHERE name IN (` + placeholders(len(recipients)) + `)
`
	rows, err := tx.QueryContext(ctx, query, stringSliceToAny(recipients)...)
	if err != nil {
		return nil, fmt.Errorf("list recipient role executors: %w", err)
	}
	defer rows.Close()

	for rows.Next() {
		def, scanErr := scanRole(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		byName[def.Name] = def
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate recipient role executors: %w", err)
	}
	return byName, nil
}
// stringSliceToAny copies values into a []any of the same length and order,
// so the strings can be spread as variadic query arguments. The result is
// never nil, even for an empty or nil input.
func stringSliceToAny(values []string) []any {
	args := make([]any, len(values))
	for i := range values {
		args[i] = values[i]
	}
	return args
}
// scanMessage reads one row from s into a message.Record. Columns must come in
// this order: id, workspace_id, topic_id, from_role_name, to_expr, type, stage,
// reply_to_message_id, body_markdown, created_at.
//
// A NULL reply_to_message_id leaves ReplyToMessageID empty. A scan error is
// returned as-is, without wrapping, alongside the zero record.
func scanMessage(s scanner) (message.Record, error) {
	var (
		rec     message.Record
		kind    string
		replyTo sql.NullString
	)
	err := s.Scan(&rec.ID, &rec.WorkspaceID, &rec.TopicID, &rec.FromRoleName, &rec.ToExpr, &kind, &rec.Stage, &replyTo, &rec.BodyMarkdown, &rec.CreatedAt)
	if err != nil {
		return message.Record{}, err
	}
	rec.Type = message.Type(kind)
	if replyTo.Valid {
		rec.ReplyToMessageID = replyTo.String
	}
	return rec, nil
}