Files

193 lines
5.5 KiB
Go

package sqlite
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"inbox/internal/domain/message"
	"inbox/internal/domain/role"
)
// errClaimConflict reports that the delivery row selected for a claim no
// longer matched its observed state/updated_at when the guarded UPDATE ran,
// so zero rows were changed. Claimers treat it as retryable.
var errClaimConflict = errors.New("delivery claim conflict")
// GetMessage reads the messages row whose id equals messageID and decodes it
// with scanMessage. Error behavior for a missing row (for example, whether
// sql.ErrNoRows is passed through) is whatever scanMessage does; confirm there.
func (s *Store) GetMessage(ctx context.Context, messageID string) (message.Record, error) {
	const selectMessageByID = `
SELECT id, workspace_id, topic_id, from_role_name, to_expr, type, stage, reply_to_message_id, body_markdown, created_at
FROM messages
WHERE id = ?
`
	return scanMessage(s.db.QueryRowContext(ctx, selectMessageByID, messageID))
}
// ClaimNextDelivery claims one delivery in workspaceID for a Codex-executed
// recipient role and marks it received.
//
// roleNames optionally restricts which recipients are considered. Names are
// trimmed, blanks and duplicates are dropped, and an empty list means any
// role. A delivery is claimable when it is pending, or when it is received
// but its updated_at is earlier than staleBefore (TouchDelivery is what
// refreshes updated_at on received deliveries).
//
// It returns sql.ErrNoRows (unwrapped) when nothing is claimable or when every
// attempt lost a race to another claimer. Other errors are wrapped with
// context.
func (s *Store) ClaimNextDelivery(ctx context.Context, workspaceID string, roleNames []string, staleBefore string) (message.DeliveryClaim, error) {
	roleNames = normalizeRoleNames(roleNames)
	const maxAttempts = 5
	for attempt := 0; attempt < maxAttempts; attempt++ {
		item, err := s.claimNextDeliveryAttempt(ctx, workspaceID, roleNames, staleBefore)
		if errors.Is(err, errClaimConflict) {
			// The row changed between our SELECT and UPDATE. Retry in a fresh
			// transaction. Retrying inside the old one would keep reading the
			// same snapshot and could keep picking the same stale row.
			continue
		}
		return item, err
	}
	return message.DeliveryClaim{}, sql.ErrNoRows
}

// claimNextDeliveryAttempt runs a single claim attempt in its own transaction
// and commits only when claimNextDeliveryTx succeeds.
func (s *Store) claimNextDeliveryAttempt(ctx context.Context, workspaceID string, roleNames []string, staleBefore string) (message.DeliveryClaim, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return message.DeliveryClaim{}, fmt.Errorf("begin claim delivery: %w", err)
	}
	// Rollback after a successful Commit only returns sql.ErrTxDone, so
	// deferring it unconditionally is safe and covers every error path.
	defer tx.Rollback()

	item, err := s.claimNextDeliveryTx(ctx, tx, workspaceID, roleNames, staleBefore)
	if err != nil {
		// sql.ErrNoRows and errClaimConflict are passed through unwrapped so
		// the caller can compare against them.
		return message.DeliveryClaim{}, err
	}
	if err := tx.Commit(); err != nil {
		return message.DeliveryClaim{}, fmt.Errorf("commit claim delivery: %w", err)
	}
	return item, nil
}
// claimNextDeliveryTx selects the best claimable delivery inside tx and
// switches it to the received state.
//
// Candidates must satisfy all of the following:
//   - they are deliveries of messages in workspaceID;
//   - their recipient role's executor_kind is Codex;
//   - they are either pending, or received with updated_at < staleBefore;
//   - when roleNames is non-empty, their recipient is one of those names.
//
// Pending rows are preferred, then the oldest updated_at, then the oldest
// message (created_at, then id).
//
// The UPDATE repeats the state and updated_at values just read, making it a
// compare-and-set. If it changes no rows, errClaimConflict is returned.
// sql.ErrNoRows (unwrapped) is returned when there is no candidate. On success
// the returned claim carries the received state and the s.now() value written
// to updated_at.
func (s *Store) claimNextDeliveryTx(ctx context.Context, tx *sql.Tx, workspaceID string, roleNames []string, staleBefore string) (message.DeliveryClaim, error) {
	query := `
SELECT m.id, m.workspace_id, m.topic_id, m.from_role_name, m.to_expr, m.type, m.stage, m.reply_to_message_id, m.body_markdown, m.created_at,
d.recipient_role_name, d.state, d.updated_at
FROM message_deliveries d
JOIN messages m ON m.id = d.message_id
JOIN roles r ON r.name = d.recipient_role_name
WHERE m.workspace_id = ?
AND r.executor_kind = ?
AND (d.state = ? OR (d.state = ? AND d.updated_at < ?))
`
	args := []any{workspaceID, string(role.ExecutorKindCodex), string(message.DeliveryPending), string(message.DeliveryReceived), staleBefore}
	if len(roleNames) > 0 {
		query += " AND d.recipient_role_name IN (" + placeholders(len(roleNames)) + ")"
		for _, roleName := range roleNames {
			args = append(args, roleName)
		}
	}
	query += `
ORDER BY CASE d.state WHEN ? THEN 0 ELSE 1 END, d.updated_at, m.created_at, m.id
LIMIT 1
`
	// This bind value belongs to the ORDER BY clause, so it must come after
	// any IN-list arguments.
	args = append(args, string(message.DeliveryPending))

	var (
		item    message.DeliveryClaim
		msgType string
		replyTo sql.NullString
		state   string
	)
	row := tx.QueryRowContext(ctx, query, args...)
	if err := row.Scan(
		&item.Message.ID,
		&item.Message.WorkspaceID,
		&item.Message.TopicID,
		&item.Message.FromRoleName,
		&item.Message.ToExpr,
		&msgType,
		&item.Message.Stage,
		&replyTo,
		&item.Message.BodyMarkdown,
		&item.Message.CreatedAt,
		&item.RecipientRoleName,
		&state,
		&item.UpdatedAt,
	); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return message.DeliveryClaim{}, sql.ErrNoRows
		}
		return message.DeliveryClaim{}, fmt.Errorf("select claim delivery: %w", err)
	}
	item.Message.Type = message.Type(msgType)
	item.State = message.DeliveryState(state)
	if replyTo.Valid {
		item.Message.ReplyToMessageID = replyTo.String
	}

	now := s.now()
	// Guard on the observed state/updated_at so a concurrent claimer that
	// already changed this row makes our UPDATE a no-op.
	result, err := tx.ExecContext(ctx, `
UPDATE message_deliveries
SET state = ?, updated_at = ?
WHERE message_id = ? AND recipient_role_name = ? AND state = ? AND updated_at = ?
`,
		string(message.DeliveryReceived),
		now,
		item.Message.ID,
		item.RecipientRoleName,
		string(item.State),
		item.UpdatedAt,
	)
	if err != nil {
		return message.DeliveryClaim{}, fmt.Errorf("claim delivery: %w", err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		// Without a row count we cannot tell whether the claim took effect,
		// so report failure rather than assume success.
		return message.DeliveryClaim{}, fmt.Errorf("claim delivery rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return message.DeliveryClaim{}, errClaimConflict
	}
	item.State = message.DeliveryReceived
	item.UpdatedAt = now
	return item, nil
}
// TouchDelivery sets updated_at to s.now() on the delivery of messageID to
// roleName, but only while that delivery is in the received state. Matching
// zero rows is not treated as an error; only a failed statement is reported.
func (s *Store) TouchDelivery(ctx context.Context, messageID, roleName string) error {
	const touchReceived = `
UPDATE message_deliveries
SET updated_at = ?
WHERE message_id = ? AND recipient_role_name = ? AND state = ?
`
	_, err := s.db.ExecContext(ctx, touchReceived, s.now(), messageID, roleName, string(message.DeliveryReceived))
	if err != nil {
		return fmt.Errorf("touch delivery: %w", err)
	}
	return nil
}
// ArchiveDelivery moves the delivery of messageID to roleName into the
// archived state and stamps updated_at with s.now(), whatever state it was in
// before. Matching zero rows is not treated as an error.
func (s *Store) ArchiveDelivery(ctx context.Context, messageID, roleName string) error {
	const archiveOne = `
UPDATE message_deliveries
SET state = ?, updated_at = ?
WHERE message_id = ? AND recipient_role_name = ?
`
	args := []any{string(message.DeliveryArchived), s.now(), messageID, roleName}
	if _, err := s.db.ExecContext(ctx, archiveOne, args...); err != nil {
		return fmt.Errorf("archive delivery: %w", err)
	}
	return nil
}
// ArchiveMessageDeliveries archives every delivery of messageID that is not
// already archived, stamping updated_at with s.now(). Already archived rows
// keep their existing updated_at.
func (s *Store) ArchiveMessageDeliveries(ctx context.Context, messageID string) error {
	archived := string(message.DeliveryArchived)
	const archiveAll = `
UPDATE message_deliveries
SET state = ?, updated_at = ?
WHERE message_id = ? AND state != ?
`
	_, err := s.db.ExecContext(ctx, archiveAll, archived, s.now(), messageID, archived)
	if err != nil {
		return fmt.Errorf("archive message deliveries: %w", err)
	}
	return nil
}
// placeholders returns count comma-separated SQL bind markers. For example,
// placeholders(3) is "?, ?, ?".
//
// A count of zero or less yields "" instead of panicking; the original
// make([]string, 0, count) panicked on negative counts. Callers must still
// avoid emitting an empty "IN ()" list, which is invalid SQL.
func placeholders(count int) string {
	if count <= 0 {
		return ""
	}
	return strings.Repeat("?, ", count-1) + "?"
}
// normalizeRoleNames trims surrounding whitespace from each name and drops
// empty and duplicate entries, keeping the first occurrence in input order.
// The result is always a non-nil slice, even when values is nil or empty.
func normalizeRoleNames(values []string) []string {
	unique := make([]string, 0, len(values))
	present := make(map[string]bool, len(values))
	for i := range values {
		name := strings.TrimSpace(values[i])
		if name == "" || present[name] {
			continue
		}
		present[name] = true
		unique = append(unique, name)
	}
	return unique
}