277 lines
8.0 KiB
Go
277 lines
8.0 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"sort"
|
|
"strings"
|
|
|
|
"inbox/internal/domain/humantask"
|
|
"inbox/internal/domain/message"
|
|
"inbox/internal/domain/role"
|
|
)
|
|
|
|
func (s *Store) CreateMessage(ctx context.Context, value message.Record) (message.Record, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return message.Record{}, fmt.Errorf("begin create message: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
item, err := s.createMessageTx(ctx, tx, value)
|
|
if err != nil {
|
|
return message.Record{}, err
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return message.Record{}, fmt.Errorf("commit create message: %w", err)
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func (s *Store) ListMessagesByTopic(ctx context.Context, topicID string) ([]message.Record, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT id, workspace_id, topic_id, from_role_name, to_expr, type, stage, reply_to_message_id, body_markdown, created_at
|
|
FROM messages
|
|
WHERE topic_id = ?
|
|
ORDER BY created_at, id
|
|
`, topicID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list messages by topic: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []message.Record
|
|
for rows.Next() {
|
|
item, err := scanMessage(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate messages: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) ListMessagesByWorkspace(ctx context.Context, workspaceID string) ([]message.Record, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT id, workspace_id, topic_id, from_role_name, to_expr, type, stage, reply_to_message_id, body_markdown, created_at
|
|
FROM messages
|
|
WHERE workspace_id = ?
|
|
ORDER BY created_at, id
|
|
`, workspaceID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list messages by workspace: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []message.Record
|
|
for rows.Next() {
|
|
item, err := scanMessage(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate messages by workspace: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) ListPendingDeliveriesByWorkspace(ctx context.Context, workspaceID string) ([]message.PendingDelivery, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT m.topic_id, d.recipient_role_name, COUNT(*), MAX(d.updated_at)
|
|
FROM message_deliveries d
|
|
JOIN messages m ON m.id = d.message_id
|
|
JOIN roles r ON r.name = d.recipient_role_name
|
|
WHERE m.workspace_id = ? AND d.state = ? AND r.executor_kind = ?
|
|
GROUP BY m.topic_id, d.recipient_role_name
|
|
ORDER BY MAX(d.updated_at) DESC, m.topic_id, d.recipient_role_name
|
|
`, workspaceID, string(message.DeliveryPending), string(role.ExecutorKindCodex))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list pending deliveries by workspace: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []message.PendingDelivery
|
|
for rows.Next() {
|
|
var item message.PendingDelivery
|
|
if err := rows.Scan(&item.TopicID, &item.RoleName, &item.Count, &item.LastUpdated); err != nil {
|
|
return nil, fmt.Errorf("scan pending delivery: %w", err)
|
|
}
|
|
out = append(out, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate pending deliveries: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) expandRecipients(ctx context.Context, tx *sql.Tx, toExpr, fromRole string) ([]string, error) {
|
|
toExpr = strings.TrimSpace(toExpr)
|
|
if toExpr == "" {
|
|
return nil, fmt.Errorf("to expr is required")
|
|
}
|
|
if toExpr == "all" {
|
|
rows, err := tx.QueryContext(ctx, `SELECT name FROM roles WHERE is_enabled = 1 ORDER BY sort_order, name`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list recipient roles: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []string
|
|
for rows.Next() {
|
|
var name string
|
|
if err := rows.Scan(&name); err != nil {
|
|
return nil, fmt.Errorf("scan recipient role: %w", err)
|
|
}
|
|
if name == fromRole {
|
|
continue
|
|
}
|
|
out = append(out, name)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate recipient roles: %w", err)
|
|
}
|
|
if len(out) == 0 {
|
|
return nil, fmt.Errorf("no recipients resolved for %q", toExpr)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
parts := strings.Split(toExpr, ",")
|
|
out := make([]string, 0, len(parts))
|
|
seen := make(map[string]struct{}, len(parts))
|
|
for _, part := range parts {
|
|
name := strings.TrimSpace(part)
|
|
if name == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[name]; ok {
|
|
continue
|
|
}
|
|
seen[name] = struct{}{}
|
|
out = append(out, name)
|
|
}
|
|
sort.Strings(out)
|
|
if len(out) == 0 {
|
|
return nil, fmt.Errorf("no recipients resolved for %q", toExpr)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) createMessageTx(ctx context.Context, tx *sql.Tx, value message.Record) (message.Record, error) {
|
|
if err := value.Validate(); err != nil {
|
|
return message.Record{}, err
|
|
}
|
|
if value.ID == "" {
|
|
id, err := s.newID("message")
|
|
if err != nil {
|
|
return message.Record{}, err
|
|
}
|
|
value.ID = id
|
|
}
|
|
value.CreatedAt = coalesceString(value.CreatedAt, s.now())
|
|
|
|
if _, err := tx.ExecContext(ctx, `
|
|
INSERT INTO messages(id, workspace_id, topic_id, from_role_name, to_expr, type, stage, reply_to_message_id, body_markdown, created_at)
|
|
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
`, value.ID, value.WorkspaceID, value.TopicID, value.FromRoleName, value.ToExpr, string(value.Type), value.Stage, nullableString(value.ReplyToMessageID), value.BodyMarkdown, value.CreatedAt); err != nil {
|
|
return message.Record{}, fmt.Errorf("insert message: %w", err)
|
|
}
|
|
|
|
recipients, err := s.expandRecipients(ctx, tx, value.ToExpr, value.FromRoleName)
|
|
if err != nil {
|
|
return message.Record{}, err
|
|
}
|
|
recipientRoles, err := resolveRecipientRolesTx(ctx, tx, recipients)
|
|
if err != nil {
|
|
return message.Record{}, err
|
|
}
|
|
for _, recipient := range recipients {
|
|
definition, ok := recipientRoles[recipient]
|
|
if !ok {
|
|
continue
|
|
}
|
|
switch definition.ExecutorKind {
|
|
case role.ExecutorKindHuman:
|
|
if value.Type != message.TypeQuestion {
|
|
continue
|
|
}
|
|
if _, err := insertHumanTaskTx(ctx, tx, s, humantask.Record{
|
|
WorkspaceID: value.WorkspaceID,
|
|
TopicID: value.TopicID,
|
|
RoleName: recipient,
|
|
PromptMessageID: value.ID,
|
|
Status: humantask.StatusPending,
|
|
CreatedAt: value.CreatedAt,
|
|
UpdatedAt: value.CreatedAt,
|
|
}); err != nil {
|
|
return message.Record{}, err
|
|
}
|
|
default:
|
|
if _, err := tx.ExecContext(ctx, `
|
|
INSERT INTO message_deliveries(message_id, recipient_role_name, state, delivered_at, updated_at)
|
|
VALUES(?, ?, ?, ?, ?)
|
|
`, value.ID, recipient, string(message.DeliveryPending), value.CreatedAt, value.CreatedAt); err != nil {
|
|
return message.Record{}, fmt.Errorf("insert message delivery: %w", err)
|
|
}
|
|
}
|
|
}
|
|
return value, nil
|
|
}
|
|
|
|
func resolveRecipientRolesTx(ctx context.Context, tx *sql.Tx, recipients []string) (map[string]role.Definition, error) {
|
|
if len(recipients) == 0 {
|
|
return map[string]role.Definition{}, nil
|
|
}
|
|
rows, err := tx.QueryContext(ctx, `
|
|
SELECT name, title, executor_kind, description, is_enabled, is_builtin, sort_order, created_at, updated_at
|
|
FROM roles
|
|
WHERE name IN (`+placeholders(len(recipients))+`)
|
|
`, stringSliceToAny(recipients)...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list recipient role executors: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
result := make(map[string]role.Definition, len(recipients))
|
|
for rows.Next() {
|
|
item, err := scanRole(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result[item.Name] = item
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate recipient role executors: %w", err)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// stringSliceToAny copies values, in order, into a freshly allocated []any
// of the same length so they can be spread into a variadic database/sql
// argument list. The result is never nil, even for a nil or empty input.
func stringSliceToAny(values []string) []any {
	args := make([]any, len(values))
	for i := range values {
		args[i] = values[i]
	}
	return args
}
|
|
|
|
func scanMessage(s scanner) (message.Record, error) {
|
|
var item message.Record
|
|
var msgType string
|
|
var replyTo sql.NullString
|
|
if err := s.Scan(&item.ID, &item.WorkspaceID, &item.TopicID, &item.FromRoleName, &item.ToExpr, &msgType, &item.Stage, &replyTo, &item.BodyMarkdown, &item.CreatedAt); err != nil {
|
|
return message.Record{}, err
|
|
}
|
|
item.Type = message.Type(msgType)
|
|
if replyTo.Valid {
|
|
item.ReplyToMessageID = replyTo.String
|
|
}
|
|
return item, nil
|
|
}
|