Implement inbox read cursors for unread threads

This commit is contained in:
2026-03-19 03:43:10 +08:00
parent 02d98a78dd
commit 1927930570
19 changed files with 240 additions and 38 deletions
+62 -6
View File
@@ -191,6 +191,7 @@ type WaitReplyInput struct {
AfterMessageID string
AfterEventID int64
Kinds []string
Agent string
Timeout time.Duration
}
@@ -412,7 +413,8 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
}
var (
args []any
joinArgs []any
whereArgs []any
conditions []string
joins []string
)
@@ -424,16 +426,16 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
if assignedTo != "" {
conditions = append(conditions, "t.assigned_to = ?")
args = append(args, assignedTo)
whereArgs = append(whereArgs, assignedTo)
}
if input.CreatedBy != "" {
conditions = append(conditions, "t.created_by = ?")
args = append(args, input.CreatedBy)
whereArgs = append(whereArgs, input.CreatedBy)
}
if len(input.Statuses) > 0 {
conditions = append(conditions, "t.status IN ("+placeholders(len(input.Statuses))+")")
for _, status := range input.Statuses {
args = append(args, status)
whereArgs = append(whereArgs, status)
}
}
if input.Unread {
@@ -441,10 +443,13 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
return nil, fmt.Errorf("%w: agent is required when filtering unread threads", ErrInvalidInput)
}
joins = append(joins, "JOIN messages lm ON lm.message_id = t.latest_message_id")
joins = append(joins, "LEFT JOIN thread_reads tr ON tr.thread_id = t.thread_id AND tr.agent_id = ?")
joinArgs = append(joinArgs, input.Agent)
conditions = append(conditions, "lm.to_agent = ?")
args = append(args, input.Agent)
whereArgs = append(whereArgs, input.Agent)
conditions = append(conditions, "lm.from_agent <> ?")
args = append(args, input.Agent)
whereArgs = append(whereArgs, input.Agent)
conditions = append(conditions, "(tr.last_read_message_id IS NULL OR tr.last_read_message_id <> t.latest_message_id)")
}
query := `SELECT
@@ -458,6 +463,7 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
query += " WHERE " + strings.Join(conditions, " AND ")
}
query += " ORDER BY t.updated_at DESC LIMIT ?"
args := append(joinArgs, whereArgs...)
args = append(args, limit)
rows, err := s.db.QueryContext(ctx, query, args...)
@@ -1078,6 +1084,10 @@ func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Threa
}
func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) {
return s.GetThreadForAgent(ctx, threadID, "", false)
}
func (s *InboxStore) GetThreadForAgent(ctx context.Context, threadID, agent string, markRead bool) (ThreadDetail, error) {
thread, err := selectThread(ctx, s.db, threadID)
if err != nil {
return ThreadDetail{}, err
@@ -1117,6 +1127,12 @@ func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDeta
}
attachArtifacts(messages, artifactsByMessageID)
if markRead {
if err := markThreadRead(ctx, s.db, thread.ThreadID, agent, thread.LatestMessageID, nowUTC()); err != nil {
return ThreadDetail{}, err
}
}
return ThreadDetail{
Thread: thread,
Messages: messages,
@@ -1204,6 +1220,11 @@ func (s *InboxStore) WaitReply(ctx context.Context, input WaitReplyInput) (WaitR
return WaitReplyResult{}, err
}
if found {
if shouldMarkMessageRead(message, input.Agent) {
if err := markThreadRead(waitCtx, s.db, input.ThreadID, input.Agent, message.MessageID, nowUTC()); err != nil {
return WaitReplyResult{}, err
}
}
return WaitReplyResult{
Woke: true,
NextEventID: eventID,
@@ -1363,6 +1384,10 @@ type queryRower interface {
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type execContexter interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
type eventInput struct {
RunID string
TaskID string
@@ -1482,6 +1507,30 @@ func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assign
return nil
}
func markThreadRead(ctx context.Context, execer execContexter, threadID, agent, messageID string, readAt time.Time) error {
if agent == "" || messageID == "" {
return nil
}
_, err := execer.ExecContext(
ctx,
`INSERT INTO thread_reads (
thread_id, agent_id, last_read_message_id, last_read_at
) VALUES (?, ?, ?, ?)
ON CONFLICT(thread_id, agent_id) DO UPDATE SET
last_read_message_id = excluded.last_read_message_id,
last_read_at = excluded.last_read_at`,
threadID,
agent,
messageID,
formatTime(readAt),
)
if err != nil {
return fmt.Errorf("mark thread read: %w", err)
}
return nil
}
func loadArtifactsForMessageIDs(ctx context.Context, db *sql.DB, messageIDs []string) (map[string][]Artifact, error) {
result := make(map[string][]Artifact)
if len(messageIDs) == 0 {
@@ -1809,6 +1858,13 @@ func isSQLiteBusyError(err error) bool {
strings.Contains(message, "database table is locked")
}
func shouldMarkMessageRead(message Message, agent string) bool {
if agent == "" {
return false
}
return message.ToAgent == agent && message.FromAgent != agent
}
func defaultID(value, prefix string) string {
if value != "" {
return value