Files

265 lines
7.3 KiB
Go

package topics
import (
	"context"
	"database/sql"
	"errors"

	"inbox/internal/base/slug"
	"inbox/internal/base/timeutil"
	"inbox/internal/domain/lane"
	"inbox/internal/domain/message"
	"inbox/internal/domain/task"
	"inbox/internal/domain/taskgraph"
	"inbox/internal/domain/topic"
	"inbox/internal/domain/workflow"
)
// Repository is the persistence boundary used by Service. It exposes CRUD
// access to topics plus the per-topic lanes, tasks, workflow runs, messages
// and task-graph versions that topic operations need to read or update.
type Repository interface {
	// Topics.
	ListTopics(ctx context.Context, workspaceID string) ([]topic.Record, error)
	GetTopic(ctx context.Context, topicID string) (topic.Record, error)
	CreateTopic(ctx context.Context, value topic.Record) (topic.Record, error)
	UpdateTopic(ctx context.Context, value topic.Record) (topic.Record, error)
	DeleteTopic(ctx context.Context, topicID string) error

	// Messages and their deliveries.
	ListMessagesByTopic(ctx context.Context, topicID string) ([]message.Record, error)
	CreateMessage(ctx context.Context, value message.Record) (message.Record, error)
	ArchiveMessageDeliveries(ctx context.Context, messageID string) error

	// Lanes belonging to a topic.
	ListLanesByTopic(ctx context.Context, topicID string) ([]lane.Record, error)
	UpdateLane(ctx context.Context, value lane.Record) (lane.Record, error)

	// Tasks belonging to a topic.
	ListTasksByTopic(ctx context.Context, topicID string) ([]task.Record, error)
	UpdateTask(ctx context.Context, value task.Record) (task.Record, error)

	// Workflow runs belonging to a topic.
	ListWorkflowRunsByTopic(ctx context.Context, topicID string) ([]workflow.Run, error)
	UpdateWorkflowRun(ctx context.Context, value workflow.Run) (workflow.Run, error)

	// Task-graph versions. Service.ConfirmPlan treats sql.ErrNoRows from
	// GetLatestTaskGraphVersionByTopic as "this topic has no graph yet".
	CreateTaskGraphVersion(ctx context.Context, value taskgraph.Record) (taskgraph.Record, error)
	GetLatestTaskGraphVersionByTopic(ctx context.Context, topicID string) (taskgraph.Record, error)
	UpdateTaskGraphVersion(ctx context.Context, value taskgraph.Record) (taskgraph.Record, error)
}
// RuntimeManager controls the runtime side of lanes. Service treats it as
// optional: when the configured RuntimeManager is nil, Service never calls it.
type RuntimeManager interface {
	// StopLane is called by Service.Stop for each non-terminal lane. The
	// returned record is not used by Service.
	StopLane(ctx context.Context, laneID string) (lane.Record, error)
	// EnsureLane is called by Service.ConfirmPlan, at most once per lane, for
	// lanes that have a ready task. Presumably it starts the lane if it is not
	// already running (confirm against implementations).
	EnsureLane(ctx context.Context, laneID string) (lane.Record, error)
}
// Service implements topic-level operations (listing, creation, stopping,
// plan confirmation, messaging) on top of a Repository, optionally driving
// lane execution through a RuntimeManager.
type Service struct {
// repo is the persistence backend used by every method.
repo Repository
// runtime may be nil; Stop and ConfirmPlan check for nil before using it.
runtime RuntimeManager
// clock supplies the current time for timestamps; NewService substitutes
// timeutil.SystemClock{} when nil is passed.
clock timeutil.Clock
}
// NewService wires a Service from its dependencies. runtime may be nil, which
// disables runtime calls. A nil clock falls back to timeutil.SystemClock{}.
func NewService(repo Repository, runtime RuntimeManager, clock timeutil.Clock) *Service {
	svc := &Service{repo: repo, runtime: runtime, clock: clock}
	if svc.clock == nil {
		svc.clock = timeutil.SystemClock{}
	}
	return svc
}
// List returns the topics stored for the given workspace, exactly as the
// repository reports them.
func (s *Service) List(ctx context.Context, workspaceID string) ([]topic.Record, error) {
	topics, err := s.repo.ListTopics(ctx, workspaceID)
	return topics, err
}
// Create persists a new topic. If the caller supplied no Slug, one is
// derived from the Title via normalizeSlug before the record is stored.
func (s *Service) Create(ctx context.Context, value topic.Record) (topic.Record, error) {
	if value.Slug != "" {
		return s.repo.CreateTopic(ctx, value)
	}
	value.Slug = normalizeSlug(value.Title)
	return s.repo.CreateTopic(ctx, value)
}
// Get loads a single topic by ID directly from the repository.
func (s *Service) Get(ctx context.Context, topicID string) (topic.Record, error) {
	record, err := s.repo.GetTopic(ctx, topicID)
	return record, err
}
// Delete removes the topic with the given ID. Unlike Stop, this method does
// no lane, task or workflow teardown of its own; any cascading is left to the
// repository implementation.
func (s *Service) Delete(ctx context.Context, topicID string) error {
	err := s.repo.DeleteTopic(ctx, topicID)
	return err
}
// Stop cancels the in-flight work attached to a topic and then marks the
// topic itself as cancelled.
//
// Teardown runs in four phases, in this order. Stop returns on the first
// error: earlier phases are not rolled back and later ones do not run.
//
//  1. Non-terminal lanes are stopped (stopTopicLanes).
//  2. Non-terminal tasks are cancelled (cancelTopicTasks).
//  3. Workflow runs still in the running state are cancelled
//     (cancelTopicWorkflowRuns).
//  4. Deliveries of every message in the topic are archived
//     (archiveTopicDeliveries).
//
// Every record touched shares one RFC 3339 timestamp, taken from the service
// clock at the start of the call. The topic's Status becomes "cancelled",
// ClosedAt is set unless it already holds a value, and the persisted topic is
// returned.
func (s *Service) Stop(ctx context.Context, topicID string) (topic.Record, error) {
	current, err := s.repo.GetTopic(ctx, topicID)
	if err != nil {
		return topic.Record{}, err
	}
	now := timeutil.FormatRFC3339(s.clock.Now())
	stopReason := "Stopped manually on user request."

	if err := s.stopTopicLanes(ctx, topicID, now, stopReason); err != nil {
		return topic.Record{}, err
	}
	if err := s.cancelTopicTasks(ctx, topicID, now, stopReason); err != nil {
		return topic.Record{}, err
	}
	if err := s.cancelTopicWorkflowRuns(ctx, topicID, now, stopReason); err != nil {
		return topic.Record{}, err
	}
	if err := s.archiveTopicDeliveries(ctx, topicID); err != nil {
		return topic.Record{}, err
	}

	current.Status = "cancelled"
	if current.ClosedAt == "" {
		current.ClosedAt = now
	}
	return s.repo.UpdateTopic(ctx, current)
}

// stopTopicLanes halts every lane of the topic that is not already in a terminal status.
func (s *Service) stopTopicLanes(ctx context.Context, topicID, now, reason string) error {
	lanes, err := s.repo.ListLanesByTopic(ctx, topicID)
	if err != nil {
		return err
	}
	for _, item := range lanes {
		if isTerminalLaneStatus(item.Status) {
			continue
		}
		if s.runtime != nil {
			// With a runtime configured, stopping is fully delegated and this
			// method writes no lane record itself. Presumably StopLane persists
			// the resulting state (confirm).
			if _, err := s.runtime.StopLane(ctx, item.ID); err != nil {
				return err
			}
			continue
		}
		item.Status = lane.StatusCancelled
		item.RuntimeEndpoint = ""
		item.ErrorMessage = reason
		if item.CompletedAt == "" {
			item.CompletedAt = now
		}
		if _, err := s.repo.UpdateLane(ctx, item); err != nil {
			return err
		}
	}
	return nil
}

// cancelTopicTasks marks every non-terminal task of the topic as cancelled.
func (s *Service) cancelTopicTasks(ctx context.Context, topicID, now, reason string) error {
	tasks, err := s.repo.ListTasksByTopic(ctx, topicID)
	if err != nil {
		return err
	}
	for _, item := range tasks {
		if isTerminalTaskStatus(item.Status) {
			continue
		}
		item.Status = task.StatusCancelled
		item.BlockingReasonMarkdown = reason
		if item.CompletedAt == "" {
			item.CompletedAt = now
		}
		if _, err := s.repo.UpdateTask(ctx, item); err != nil {
			return err
		}
	}
	return nil
}

// cancelTopicWorkflowRuns marks every workflow run of the topic that is still running as cancelled.
func (s *Service) cancelTopicWorkflowRuns(ctx context.Context, topicID, now, reason string) error {
	runs, err := s.repo.ListWorkflowRunsByTopic(ctx, topicID)
	if err != nil {
		return err
	}
	for _, item := range runs {
		if item.Status != workflow.RunStatusRunning {
			continue
		}
		item.Status = workflow.RunStatusCancelled
		// 130 matches the shell convention for a process ended by SIGINT
		// (128+2). Presumably chosen to look like an interrupted run (confirm).
		item.ExitCode = 130
		item.ErrorMessage = reason
		if item.CompletedAt == "" {
			item.CompletedAt = now
		}
		if _, err := s.repo.UpdateWorkflowRun(ctx, item); err != nil {
			return err
		}
	}
	return nil
}

// archiveTopicDeliveries archives the deliveries of every message in the topic.
func (s *Service) archiveTopicDeliveries(ctx context.Context, topicID string) error {
	messages, err := s.repo.ListMessagesByTopic(ctx, topicID)
	if err != nil {
		return err
	}
	for _, item := range messages {
		if err := s.repo.ArchiveMessageDeliveries(ctx, item.ID); err != nil {
			return err
		}
	}
	return nil
}
// ConfirmPlan moves a topic from "awaiting_confirmation" into "execution".
//
// A topic in any other status is returned unchanged with a nil error. If the
// topic's latest task-graph version is still a draft, it is activated and its
// ConfirmedAt is stamped. A topic with no task graph yet (sql.ErrNoRows,
// including when wrapped by the repository) is confirmed without one. The
// topic's ClosedAt is cleared and the topic is saved.
//
// When a RuntimeManager is configured, ConfirmPlan then starts the lanes of
// ready tasks via EnsureLane. Each lane is started at most once. While any gate
// task is still open, only lanes of ready gate tasks are started. A gate is
// open until it has succeeded or been cancelled, so a failed gate also counts
// as open. An error there is returned even though the topic was already saved.
func (s *Service) ConfirmPlan(ctx context.Context, topicID string) (topic.Record, error) {
	current, err := s.repo.GetTopic(ctx, topicID)
	if err != nil {
		return topic.Record{}, err
	}
	if current.Status != "awaiting_confirmation" {
		// Nothing to confirm; report the topic as-is rather than failing.
		return current, nil
	}

	latestGraph, err := s.repo.GetLatestTaskGraphVersionByTopic(ctx, topicID)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// No task graph recorded yet; confirm the topic without one. errors.Is
		// (rather than ==) keeps this working if the repository wraps the error.
	case err != nil:
		return topic.Record{}, err
	case latestGraph.Status == taskgraph.StatusDraft:
		latestGraph.Status = taskgraph.StatusActive
		latestGraph.ConfirmedAt = timeutil.FormatRFC3339(s.clock.Now())
		if _, err := s.repo.UpdateTaskGraphVersion(ctx, latestGraph); err != nil {
			return topic.Record{}, err
		}
	}

	current.Status = "execution"
	current.ClosedAt = ""
	current, err = s.repo.UpdateTopic(ctx, current)
	if err != nil {
		return topic.Record{}, err
	}
	if s.runtime == nil {
		return current, nil
	}
	if err := s.startReadyTaskLanes(ctx, topicID); err != nil {
		return topic.Record{}, err
	}
	return current, nil
}

// startReadyTaskLanes calls EnsureLane once per known lane that has a ready
// task. While any gate is open, only lanes of ready gate tasks are considered.
func (s *Service) startReadyTaskLanes(ctx context.Context, topicID string) error {
	tasks, err := s.repo.ListTasksByTopic(ctx, topicID)
	if err != nil {
		return err
	}
	lanes, err := s.repo.ListLanesByTopic(ctx, topicID)
	if err != nil {
		return err
	}
	// Only lane IDs are needed. A task whose LaneID is not among the topic's
	// lanes is skipped.
	knownLanes := make(map[string]struct{}, len(lanes))
	for _, item := range lanes {
		knownLanes[item.ID] = struct{}{}
	}

	// A gate counts as open until it has succeeded or been cancelled; a failed
	// gate therefore keeps blocking non-gate work.
	hasOpenGate := false
	for _, item := range tasks {
		if item.Kind == task.KindGate && item.Status != task.StatusSucceeded && item.Status != task.StatusCancelled {
			hasOpenGate = true
			break
		}
	}

	started := make(map[string]struct{}, len(lanes))
	for _, item := range tasks {
		if item.Status != task.StatusReady {
			continue
		}
		if hasOpenGate && item.Kind != task.KindGate {
			continue
		}
		if _, ok := knownLanes[item.LaneID]; !ok {
			continue
		}
		if _, done := started[item.LaneID]; done {
			continue
		}
		started[item.LaneID] = struct{}{}
		if _, err := s.runtime.EnsureLane(ctx, item.LaneID); err != nil {
			return err
		}
	}
	return nil
}
// ListMessages returns every message recorded for the topic, as the
// repository reports them.
func (s *Service) ListMessages(ctx context.Context, topicID string) ([]message.Record, error) {
	messages, err := s.repo.ListMessagesByTopic(ctx, topicID)
	return messages, err
}
// CreateMessage persists a message unchanged and returns the stored record.
func (s *Service) CreateMessage(ctx context.Context, value message.Record) (message.Record, error) {
	created, err := s.repo.CreateMessage(ctx, value)
	return created, err
}
// normalizeSlug derives a slug from free-form text (Create passes the topic
// title) by delegating to the shared slug package.
func normalizeSlug(value string) string {
	normalized := slug.Normalize(value)
	return normalized
}
// isTerminalLaneStatus reports whether a lane has finished: succeeded,
// failed or cancelled. Stop skips lanes for which this is true.
func isTerminalLaneStatus(status lane.Status) bool {
	return status == lane.StatusSucceeded ||
		status == lane.StatusFailed ||
		status == lane.StatusCancelled
}
// isTerminalTaskStatus reports whether a task has finished: succeeded,
// failed or cancelled. Stop skips tasks for which this is true.
func isTerminalTaskStatus(status task.Status) bool {
	return status == task.StatusSucceeded ||
		status == task.StatusFailed ||
		status == task.StatusCancelled
}