227 lines
6.4 KiB
Go
227 lines
6.4 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"inbox/internal/domain/humantask"
|
|
"inbox/internal/domain/message"
|
|
)
|
|
|
|
func (s *Store) CreateHumanTask(ctx context.Context, value humantask.Record) (humantask.Record, error) {
|
|
if err := value.Validate(); err != nil {
|
|
return humantask.Record{}, err
|
|
}
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return humantask.Record{}, fmt.Errorf("begin create human task: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
item, err := insertHumanTaskTx(ctx, tx, s, value)
|
|
if err != nil {
|
|
return humantask.Record{}, err
|
|
}
|
|
if err := tx.Commit(); err != nil {
|
|
return humantask.Record{}, fmt.Errorf("commit create human task: %w", err)
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func insertHumanTaskTx(ctx context.Context, tx *sql.Tx, s *Store, value humantask.Record) (humantask.Record, error) {
|
|
if value.ID == "" {
|
|
id, err := s.newID("human-task")
|
|
if err != nil {
|
|
return humantask.Record{}, err
|
|
}
|
|
value.ID = id
|
|
}
|
|
value.Status = humantask.Status(strings.TrimSpace(string(value.Status)))
|
|
if value.CreatedAt == "" {
|
|
value.CreatedAt = s.now()
|
|
}
|
|
if value.UpdatedAt == "" {
|
|
value.UpdatedAt = value.CreatedAt
|
|
}
|
|
if err := value.Validate(); err != nil {
|
|
return humantask.Record{}, err
|
|
}
|
|
if _, err := tx.ExecContext(ctx, `
|
|
INSERT INTO human_tasks(id, workspace_id, topic_id, role_name, prompt_message_id, status, answered_message_id, created_at, updated_at)
|
|
VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
`,
|
|
value.ID,
|
|
value.WorkspaceID,
|
|
value.TopicID,
|
|
value.RoleName,
|
|
value.PromptMessageID,
|
|
string(value.Status),
|
|
nullableString(value.AnsweredMessageID),
|
|
value.CreatedAt,
|
|
value.UpdatedAt,
|
|
); err != nil {
|
|
return humantask.Record{}, fmt.Errorf("insert human task: %w", err)
|
|
}
|
|
return value, nil
|
|
}
|
|
|
|
func (s *Store) GetHumanTask(ctx context.Context, taskID string) (humantask.Record, error) {
|
|
row := s.db.QueryRowContext(ctx, `
|
|
SELECT id, workspace_id, topic_id, role_name, prompt_message_id, status, answered_message_id, created_at, updated_at
|
|
FROM human_tasks
|
|
WHERE id = ?
|
|
`, taskID)
|
|
return scanHumanTask(row)
|
|
}
|
|
|
|
func (s *Store) UpdateHumanTask(ctx context.Context, value humantask.Record) (humantask.Record, error) {
|
|
if err := value.Validate(); err != nil {
|
|
return humantask.Record{}, err
|
|
}
|
|
if value.UpdatedAt == "" {
|
|
value.UpdatedAt = s.now()
|
|
}
|
|
if _, err := s.db.ExecContext(ctx, `
|
|
UPDATE human_tasks
|
|
SET status = ?, answered_message_id = ?, updated_at = ?
|
|
WHERE id = ?
|
|
`,
|
|
string(value.Status),
|
|
nullableString(value.AnsweredMessageID),
|
|
value.UpdatedAt,
|
|
value.ID,
|
|
); err != nil {
|
|
return humantask.Record{}, fmt.Errorf("update human task: %w", err)
|
|
}
|
|
return s.GetHumanTask(ctx, value.ID)
|
|
}
|
|
|
|
func (s *Store) AnswerHumanTask(ctx context.Context, taskID string, reply message.Record) (humantask.Record, message.Record, error) {
|
|
tx, err := s.db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return humantask.Record{}, message.Record{}, fmt.Errorf("begin answer human task: %w", err)
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
task, err := getHumanTaskTx(ctx, tx, taskID)
|
|
if err != nil {
|
|
return humantask.Record{}, message.Record{}, err
|
|
}
|
|
if task.Status != humantask.StatusPending {
|
|
return humantask.Record{}, message.Record{}, fmt.Errorf("human task %s is not pending", task.ID)
|
|
}
|
|
|
|
savedReply, err := s.createMessageTx(ctx, tx, reply)
|
|
if err != nil {
|
|
return humantask.Record{}, message.Record{}, err
|
|
}
|
|
|
|
task.Status = humantask.StatusAnswered
|
|
task.AnsweredMessageID = savedReply.ID
|
|
task.UpdatedAt = coalesceString(reply.CreatedAt, s.now())
|
|
if _, err := tx.ExecContext(ctx, `
|
|
UPDATE human_tasks
|
|
SET status = ?, answered_message_id = ?, updated_at = ?
|
|
WHERE id = ?
|
|
`, string(task.Status), nullableString(task.AnsweredMessageID), task.UpdatedAt, task.ID); err != nil {
|
|
return humantask.Record{}, message.Record{}, fmt.Errorf("update answered human task: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return humantask.Record{}, message.Record{}, fmt.Errorf("commit answer human task: %w", err)
|
|
}
|
|
|
|
updated, err := s.GetHumanTask(ctx, taskID)
|
|
if err != nil {
|
|
return humantask.Record{}, message.Record{}, err
|
|
}
|
|
return updated, savedReply, nil
|
|
}
|
|
|
|
func (s *Store) ListHumanTasksByTopic(ctx context.Context, topicID string) ([]humantask.Record, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT id, workspace_id, topic_id, role_name, prompt_message_id, status, answered_message_id, created_at, updated_at
|
|
FROM human_tasks
|
|
WHERE topic_id = ?
|
|
ORDER BY created_at, id
|
|
`, topicID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list human tasks by topic: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []humantask.Record
|
|
for rows.Next() {
|
|
item, err := scanHumanTask(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate human tasks by topic: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func (s *Store) ListPendingHumanTasksByWorkspace(ctx context.Context, workspaceID string) ([]humantask.Record, error) {
|
|
rows, err := s.db.QueryContext(ctx, `
|
|
SELECT id, workspace_id, topic_id, role_name, prompt_message_id, status, answered_message_id, created_at, updated_at
|
|
FROM human_tasks
|
|
WHERE workspace_id = ? AND status = ?
|
|
ORDER BY updated_at DESC, created_at DESC, id DESC
|
|
`, workspaceID, string(humantask.StatusPending))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list pending human tasks by workspace: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var out []humantask.Record
|
|
for rows.Next() {
|
|
item, err := scanHumanTask(rows)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate pending human tasks by workspace: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
func scanHumanTask(s scanner) (humantask.Record, error) {
|
|
var item humantask.Record
|
|
var status string
|
|
var answeredMessageID sql.NullString
|
|
if err := s.Scan(
|
|
&item.ID,
|
|
&item.WorkspaceID,
|
|
&item.TopicID,
|
|
&item.RoleName,
|
|
&item.PromptMessageID,
|
|
&status,
|
|
&answeredMessageID,
|
|
&item.CreatedAt,
|
|
&item.UpdatedAt,
|
|
); err != nil {
|
|
return humantask.Record{}, err
|
|
}
|
|
item.Status = humantask.Status(status)
|
|
if answeredMessageID.Valid {
|
|
item.AnsweredMessageID = answeredMessageID.String
|
|
}
|
|
return item, nil
|
|
}
|
|
|
|
func getHumanTaskTx(ctx context.Context, tx *sql.Tx, taskID string) (humantask.Record, error) {
|
|
row := tx.QueryRowContext(ctx, `
|
|
SELECT id, workspace_id, topic_id, role_name, prompt_message_id, status, answered_message_id, created_at, updated_at
|
|
FROM human_tasks
|
|
WHERE id = ?
|
|
`, taskID)
|
|
return scanHumanTask(row)
|
|
}
|