Implement inbox read cursors for unread threads
This commit is contained in:
+62
-6
@@ -191,6 +191,7 @@ type WaitReplyInput struct {
|
||||
AfterMessageID string
|
||||
AfterEventID int64
|
||||
Kinds []string
|
||||
Agent string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
@@ -412,7 +413,8 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
|
||||
}
|
||||
|
||||
var (
|
||||
args []any
|
||||
joinArgs []any
|
||||
whereArgs []any
|
||||
conditions []string
|
||||
joins []string
|
||||
)
|
||||
@@ -424,16 +426,16 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
|
||||
|
||||
if assignedTo != "" {
|
||||
conditions = append(conditions, "t.assigned_to = ?")
|
||||
args = append(args, assignedTo)
|
||||
whereArgs = append(whereArgs, assignedTo)
|
||||
}
|
||||
if input.CreatedBy != "" {
|
||||
conditions = append(conditions, "t.created_by = ?")
|
||||
args = append(args, input.CreatedBy)
|
||||
whereArgs = append(whereArgs, input.CreatedBy)
|
||||
}
|
||||
if len(input.Statuses) > 0 {
|
||||
conditions = append(conditions, "t.status IN ("+placeholders(len(input.Statuses))+")")
|
||||
for _, status := range input.Statuses {
|
||||
args = append(args, status)
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
if input.Unread {
|
||||
@@ -441,10 +443,13 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
|
||||
return nil, fmt.Errorf("%w: agent is required when filtering unread threads", ErrInvalidInput)
|
||||
}
|
||||
joins = append(joins, "JOIN messages lm ON lm.message_id = t.latest_message_id")
|
||||
joins = append(joins, "LEFT JOIN thread_reads tr ON tr.thread_id = t.thread_id AND tr.agent_id = ?")
|
||||
joinArgs = append(joinArgs, input.Agent)
|
||||
conditions = append(conditions, "lm.to_agent = ?")
|
||||
args = append(args, input.Agent)
|
||||
whereArgs = append(whereArgs, input.Agent)
|
||||
conditions = append(conditions, "lm.from_agent <> ?")
|
||||
args = append(args, input.Agent)
|
||||
whereArgs = append(whereArgs, input.Agent)
|
||||
conditions = append(conditions, "(tr.last_read_message_id IS NULL OR tr.last_read_message_id <> t.latest_message_id)")
|
||||
}
|
||||
|
||||
query := `SELECT
|
||||
@@ -458,6 +463,7 @@ func (s *InboxStore) ListThreads(ctx context.Context, input ListInput) ([]Thread
|
||||
query += " WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
query += " ORDER BY t.updated_at DESC LIMIT ?"
|
||||
args := append(joinArgs, whereArgs...)
|
||||
args = append(args, limit)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
@@ -1078,6 +1084,10 @@ func (s *InboxStore) CancelThread(ctx context.Context, input CancelInput) (Threa
|
||||
}
|
||||
|
||||
func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDetail, error) {
|
||||
return s.GetThreadForAgent(ctx, threadID, "", false)
|
||||
}
|
||||
|
||||
func (s *InboxStore) GetThreadForAgent(ctx context.Context, threadID, agent string, markRead bool) (ThreadDetail, error) {
|
||||
thread, err := selectThread(ctx, s.db, threadID)
|
||||
if err != nil {
|
||||
return ThreadDetail{}, err
|
||||
@@ -1117,6 +1127,12 @@ func (s *InboxStore) GetThread(ctx context.Context, threadID string) (ThreadDeta
|
||||
}
|
||||
attachArtifacts(messages, artifactsByMessageID)
|
||||
|
||||
if markRead {
|
||||
if err := markThreadRead(ctx, s.db, thread.ThreadID, agent, thread.LatestMessageID, nowUTC()); err != nil {
|
||||
return ThreadDetail{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return ThreadDetail{
|
||||
Thread: thread,
|
||||
Messages: messages,
|
||||
@@ -1204,6 +1220,11 @@ func (s *InboxStore) WaitReply(ctx context.Context, input WaitReplyInput) (WaitR
|
||||
return WaitReplyResult{}, err
|
||||
}
|
||||
if found {
|
||||
if shouldMarkMessageRead(message, input.Agent) {
|
||||
if err := markThreadRead(waitCtx, s.db, input.ThreadID, input.Agent, message.MessageID, nowUTC()); err != nil {
|
||||
return WaitReplyResult{}, err
|
||||
}
|
||||
}
|
||||
return WaitReplyResult{
|
||||
Woke: true,
|
||||
NextEventID: eventID,
|
||||
@@ -1363,6 +1384,10 @@ type queryRower interface {
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type execContexter interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
type eventInput struct {
|
||||
RunID string
|
||||
TaskID string
|
||||
@@ -1482,6 +1507,30 @@ func updateThreadState(ctx context.Context, tx *sql.Tx, threadID, status, assign
|
||||
return nil
|
||||
}
|
||||
|
||||
func markThreadRead(ctx context.Context, execer execContexter, threadID, agent, messageID string, readAt time.Time) error {
|
||||
if agent == "" || messageID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := execer.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO thread_reads (
|
||||
thread_id, agent_id, last_read_message_id, last_read_at
|
||||
) VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(thread_id, agent_id) DO UPDATE SET
|
||||
last_read_message_id = excluded.last_read_message_id,
|
||||
last_read_at = excluded.last_read_at`,
|
||||
threadID,
|
||||
agent,
|
||||
messageID,
|
||||
formatTime(readAt),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("mark thread read: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func loadArtifactsForMessageIDs(ctx context.Context, db *sql.DB, messageIDs []string) (map[string][]Artifact, error) {
|
||||
result := make(map[string][]Artifact)
|
||||
if len(messageIDs) == 0 {
|
||||
@@ -1809,6 +1858,13 @@ func isSQLiteBusyError(err error) bool {
|
||||
strings.Contains(message, "database table is locked")
|
||||
}
|
||||
|
||||
func shouldMarkMessageRead(message Message, agent string) bool {
|
||||
if agent == "" {
|
||||
return false
|
||||
}
|
||||
return message.ToAgent == agent && message.FromAgent != agent
|
||||
}
|
||||
|
||||
func defaultID(value, prefix string) string {
|
||||
if value != "" {
|
||||
return value
|
||||
|
||||
Reference in New Issue
Block a user