265 lines
7.3 KiB
Go
265 lines
7.3 KiB
Go
package topics
|
|
|
|
import (
	"context"
	"database/sql"
	"errors"

	"inbox/internal/base/slug"
	"inbox/internal/base/timeutil"
	"inbox/internal/domain/lane"
	"inbox/internal/domain/message"
	"inbox/internal/domain/task"
	"inbox/internal/domain/taskgraph"
	"inbox/internal/domain/topic"
	"inbox/internal/domain/workflow"
)
|
|
|
|
type Repository interface {
|
|
ListTopics(ctx context.Context, workspaceID string) ([]topic.Record, error)
|
|
CreateTopic(ctx context.Context, value topic.Record) (topic.Record, error)
|
|
GetTopic(ctx context.Context, topicID string) (topic.Record, error)
|
|
UpdateTopic(ctx context.Context, value topic.Record) (topic.Record, error)
|
|
DeleteTopic(ctx context.Context, topicID string) error
|
|
ListMessagesByTopic(ctx context.Context, topicID string) ([]message.Record, error)
|
|
CreateMessage(ctx context.Context, value message.Record) (message.Record, error)
|
|
ListLanesByTopic(ctx context.Context, topicID string) ([]lane.Record, error)
|
|
UpdateLane(ctx context.Context, value lane.Record) (lane.Record, error)
|
|
ListTasksByTopic(ctx context.Context, topicID string) ([]task.Record, error)
|
|
UpdateTask(ctx context.Context, value task.Record) (task.Record, error)
|
|
ListWorkflowRunsByTopic(ctx context.Context, topicID string) ([]workflow.Run, error)
|
|
UpdateWorkflowRun(ctx context.Context, value workflow.Run) (workflow.Run, error)
|
|
ArchiveMessageDeliveries(ctx context.Context, messageID string) error
|
|
CreateTaskGraphVersion(ctx context.Context, value taskgraph.Record) (taskgraph.Record, error)
|
|
GetLatestTaskGraphVersionByTopic(ctx context.Context, topicID string) (taskgraph.Record, error)
|
|
UpdateTaskGraphVersion(ctx context.Context, value taskgraph.Record) (taskgraph.Record, error)
|
|
}
|
|
|
|
type RuntimeManager interface {
|
|
EnsureLane(ctx context.Context, laneID string) (lane.Record, error)
|
|
StopLane(ctx context.Context, laneID string) (lane.Record, error)
|
|
}
|
|
|
|
type Service struct {
|
|
repo Repository
|
|
runtime RuntimeManager
|
|
clock timeutil.Clock
|
|
}
|
|
|
|
func NewService(repo Repository, runtime RuntimeManager, clock timeutil.Clock) *Service {
|
|
if clock == nil {
|
|
clock = timeutil.SystemClock{}
|
|
}
|
|
return &Service{repo: repo, runtime: runtime, clock: clock}
|
|
}
|
|
|
|
func (s *Service) List(ctx context.Context, workspaceID string) ([]topic.Record, error) {
|
|
return s.repo.ListTopics(ctx, workspaceID)
|
|
}
|
|
|
|
func (s *Service) Create(ctx context.Context, value topic.Record) (topic.Record, error) {
|
|
if value.Slug == "" {
|
|
value.Slug = normalizeSlug(value.Title)
|
|
}
|
|
return s.repo.CreateTopic(ctx, value)
|
|
}
|
|
|
|
func (s *Service) Get(ctx context.Context, topicID string) (topic.Record, error) {
|
|
return s.repo.GetTopic(ctx, topicID)
|
|
}
|
|
|
|
func (s *Service) Delete(ctx context.Context, topicID string) error {
|
|
return s.repo.DeleteTopic(ctx, topicID)
|
|
}
|
|
|
|
func (s *Service) Stop(ctx context.Context, topicID string) (topic.Record, error) {
|
|
current, err := s.repo.GetTopic(ctx, topicID)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
|
|
now := timeutil.FormatRFC3339(s.clock.Now())
|
|
stopReason := "Stopped manually on user request."
|
|
|
|
lanes, err := s.repo.ListLanesByTopic(ctx, topicID)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
for _, item := range lanes {
|
|
if isTerminalLaneStatus(item.Status) {
|
|
continue
|
|
}
|
|
if s.runtime != nil {
|
|
if _, err := s.runtime.StopLane(ctx, item.ID); err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
continue
|
|
}
|
|
item.Status = lane.StatusCancelled
|
|
item.RuntimeEndpoint = ""
|
|
item.ErrorMessage = stopReason
|
|
if item.CompletedAt == "" {
|
|
item.CompletedAt = now
|
|
}
|
|
if _, err := s.repo.UpdateLane(ctx, item); err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
}
|
|
|
|
tasks, err := s.repo.ListTasksByTopic(ctx, topicID)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
for _, item := range tasks {
|
|
if isTerminalTaskStatus(item.Status) {
|
|
continue
|
|
}
|
|
item.Status = task.StatusCancelled
|
|
item.BlockingReasonMarkdown = stopReason
|
|
if item.CompletedAt == "" {
|
|
item.CompletedAt = now
|
|
}
|
|
if _, err := s.repo.UpdateTask(ctx, item); err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
}
|
|
|
|
runs, err := s.repo.ListWorkflowRunsByTopic(ctx, topicID)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
for _, item := range runs {
|
|
if item.Status != workflow.RunStatusRunning {
|
|
continue
|
|
}
|
|
item.Status = workflow.RunStatusCancelled
|
|
item.ExitCode = 130
|
|
item.ErrorMessage = stopReason
|
|
if item.CompletedAt == "" {
|
|
item.CompletedAt = now
|
|
}
|
|
if _, err := s.repo.UpdateWorkflowRun(ctx, item); err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
}
|
|
|
|
messages, err := s.repo.ListMessagesByTopic(ctx, topicID)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
for _, item := range messages {
|
|
if err := s.repo.ArchiveMessageDeliveries(ctx, item.ID); err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
}
|
|
|
|
current.Status = "cancelled"
|
|
if current.ClosedAt == "" {
|
|
current.ClosedAt = now
|
|
}
|
|
return s.repo.UpdateTopic(ctx, current)
|
|
}
|
|
|
|
func (s *Service) ConfirmPlan(ctx context.Context, topicID string) (topic.Record, error) {
|
|
current, err := s.repo.GetTopic(ctx, topicID)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
if current.Status != "awaiting_confirmation" {
|
|
return current, nil
|
|
}
|
|
|
|
latestGraph, err := s.repo.GetLatestTaskGraphVersionByTopic(ctx, topicID)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return topic.Record{}, err
|
|
}
|
|
if err == nil && latestGraph.Status == taskgraph.StatusDraft {
|
|
latestGraph.Status = taskgraph.StatusActive
|
|
latestGraph.ConfirmedAt = timeutil.FormatRFC3339(s.clock.Now())
|
|
if _, err := s.repo.UpdateTaskGraphVersion(ctx, latestGraph); err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
}
|
|
|
|
current.Status = "execution"
|
|
current.ClosedAt = ""
|
|
current, err = s.repo.UpdateTopic(ctx, current)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
if s.runtime == nil {
|
|
return current, nil
|
|
}
|
|
|
|
tasks, err := s.repo.ListTasksByTopic(ctx, topicID)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
lanes, err := s.repo.ListLanesByTopic(ctx, topicID)
|
|
if err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
laneByID := make(map[string]lane.Record, len(lanes))
|
|
for _, item := range lanes {
|
|
laneByID[item.ID] = item
|
|
}
|
|
|
|
hasOpenGate := false
|
|
for _, item := range tasks {
|
|
if item.Kind == task.KindGate && item.Status != task.StatusSucceeded && item.Status != task.StatusCancelled {
|
|
hasOpenGate = true
|
|
break
|
|
}
|
|
}
|
|
seen := map[string]struct{}{}
|
|
for _, item := range tasks {
|
|
if item.Status != task.StatusReady {
|
|
continue
|
|
}
|
|
if hasOpenGate && item.Kind != task.KindGate {
|
|
continue
|
|
}
|
|
laneItem, ok := laneByID[item.LaneID]
|
|
if !ok {
|
|
continue
|
|
}
|
|
if _, exists := seen[laneItem.ID]; exists {
|
|
continue
|
|
}
|
|
seen[laneItem.ID] = struct{}{}
|
|
if _, err := s.runtime.EnsureLane(ctx, laneItem.ID); err != nil {
|
|
return topic.Record{}, err
|
|
}
|
|
}
|
|
return current, nil
|
|
}
|
|
|
|
func (s *Service) ListMessages(ctx context.Context, topicID string) ([]message.Record, error) {
|
|
return s.repo.ListMessagesByTopic(ctx, topicID)
|
|
}
|
|
|
|
func (s *Service) CreateMessage(ctx context.Context, value message.Record) (message.Record, error) {
|
|
return s.repo.CreateMessage(ctx, value)
|
|
}
|
|
|
|
func normalizeSlug(value string) string {
|
|
return slug.Normalize(value)
|
|
}
|
|
|
|
func isTerminalLaneStatus(status lane.Status) bool {
|
|
switch status {
|
|
case lane.StatusSucceeded, lane.StatusFailed, lane.StatusCancelled:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func isTerminalTaskStatus(status task.Status) bool {
|
|
switch status {
|
|
case task.StatusSucceeded, task.StatusFailed, task.StatusCancelled:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|