diff --git a/packages/core/src/checkpoint/checkpoint-events.ts b/packages/core/src/checkpoint/checkpoint-events.ts new file mode 100644 index 0000000..4880a5c --- /dev/null +++ b/packages/core/src/checkpoint/checkpoint-events.ts @@ -0,0 +1,110 @@ +/** + * 检查点事件系统 + * 负责事件的发布和订阅 + */ + +import type { + CheckpointEvent, + CheckpointEventListener, + CheckpointMetadata, +} from './types.js'; + +/** + * 检查点事件管理器 + */ +export class CheckpointEvents { + private listeners: Set = new Set(); + + /** + * 添加事件监听器 + */ + addEventListener(listener: CheckpointEventListener): void { + this.listeners.add(listener); + } + + /** + * 移除事件监听器 + */ + removeEventListener(listener: CheckpointEventListener): void { + this.listeners.delete(listener); + } + + /** + * 触发事件 + */ + emit(event: CheckpointEvent): void { + for (const listener of this.listeners) { + try { + listener(event); + } catch (error) { + console.warn('Checkpoint event listener error:', error); + } + } + } + + /** + * 触发创建事件 + */ + emitCreated(checkpoint: CheckpointMetadata): void { + this.emit({ + type: 'created', + checkpoint, + timestamp: Date.now(), + }); + } + + /** + * 触发恢复事件 + */ + emitRestored( + checkpoint: CheckpointMetadata, + details: { + files: string[]; + previousCommit: string; + mode: string; + } + ): void { + this.emit({ + type: 'restored', + checkpoint, + timestamp: Date.now(), + details, + }); + } + + /** + * 触发删除事件 + */ + emitDeleted(checkpoint: CheckpointMetadata): void { + this.emit({ + type: 'deleted', + checkpoint, + timestamp: Date.now(), + }); + } + + /** + * 触发清理事件 + */ + emitCleanup(deletedCount: number): void { + this.emit({ + type: 'cleanup', + timestamp: Date.now(), + details: { deletedCount }, + }); + } + + /** + * 获取监听器数量 + */ + get listenerCount(): number { + return this.listeners.size; + } + + /** + * 清除所有监听器 + */ + clearListeners(): void { + this.listeners.clear(); + } +} diff --git a/packages/core/src/checkpoint/checkpoint-rollback.ts b/packages/core/src/checkpoint/checkpoint-rollback.ts new file mode 100644 index 0000000..7b76d1a --- /dev/null +++ b/packages/core/src/checkpoint/checkpoint-rollback.ts @@ -0,0 +1,289 @@ +/** + * 检查点回滚 + * 负责回滚相关操作 + */ + +import { nanoid } from 'nanoid'; +import type { ShadowGit } from './shadow-git.js'; +import type { CheckpointLock } from './lock.js'; +import type { CheckpointSafetyChecker } from './safety.js'; +import type { CheckpointStore } from './checkpoint-store.js'; +import { + RestoreMode, + type CheckpointMetadata, + type RollbackOptions, + type RollbackResult, + type RollbackRecord, + type UnrevertResult, + type DiffInfo, +} from './types.js'; + +/** + * 检查点回滚 + */ +export class CheckpointRollback { + private shadowGit: ShadowGit; + private lock: CheckpointLock; + private safetyChecker: CheckpointSafetyChecker; + private store: CheckpointStore; + private lastRollback: RollbackRecord | null = null; + + constructor( + shadowGit: ShadowGit, + lock: CheckpointLock, + safetyChecker: CheckpointSafetyChecker, + store: CheckpointStore + ) { + this.shadowGit = shadowGit; + this.lock = lock; + this.safetyChecker = safetyChecker; + this.store = store; + } + + /** + * 回滚到指定检查点 + */ + async rollback( + options: RollbackOptions, + onEvent?: (event: { type: string; checkpoint: CheckpointMetadata; details?: unknown }) => void + ): Promise { + const checkpoint = this.store.get(options.target); + if (!checkpoint) { + throw new Error(`Checkpoint not found: ${options.target}`); + } + + // 安全检查(除非明确跳过) + if (!options.skipSafetyCheck) { + const safetyResult = await this.safetyChecker.checkBeforeRollback(checkpoint, { + getCheckpoint: (id: string) => Promise.resolve(this.store.get(id)), + listCheckpoints: () => Promise.resolve(this.store.list()), + getDiff: (checkpointId: string) => this.getDiffById(checkpointId), + }); + if (!safetyResult.safe) { + const errorMsg = safetyResult.errors.join('; '); + throw new Error(`Safety check failed: ${errorMsg}`); + } + if (safetyResult.warnings.length > 0) { + console.warn('Rollback warnings:', safetyResult.warnings.join('; ')); + } + } + + // 使用锁保护回滚操作 + return this.lock.withLock(async () => { + const previousCommit = await this.shadowGit.getHead(); + + // 预览模式 + if (options.dryRun) { + const diff = await this.getDiff(checkpoint); + return { + success: true, + restoredFiles: diff.files.map((f) => f.path), + errors: [], + previousCommit, + }; + } + + // 创建回滚前检查点(用于 unrevert) + let preRollbackCheckpoint: CheckpointMetadata | null = null; + try { + preRollbackCheckpoint = await this.store.createInternal({ + trigger: 'pre_rollback', + description: `Before rollback to ${options.target}`, + }); + } catch { + // 忽略创建失败 + } + + const result: RollbackResult = { + success: true, + restoredFiles: [], + errors: [], + previousCommit, + }; + + try { + const mode = options.mode || RestoreMode.FULL; + + if (options.files && options.files.length > 0) { + // 选择性回滚(指定文件) + await this.shadowGit.checkoutFiles(checkpoint.commitHash, options.files); + result.restoredFiles = options.files; + } else if (mode === RestoreMode.AI_CHANGES_ONLY) { + // 仅恢复 AI 修改的文件 + const aiFiles = await this.getAiModifiedFiles(checkpoint); + if (aiFiles.length > 0) { + await this.shadowGit.checkoutFiles(checkpoint.commitHash, aiFiles); + result.restoredFiles = aiFiles; + } + } else if (mode === RestoreMode.WORKSPACE_ONLY) { + // 仅恢复工作区变更(不包括 AI 修改) + const workspaceFiles = await this.getWorkspaceOnlyFiles(checkpoint); + if (workspaceFiles.length > 0) { + await this.shadowGit.checkoutFiles(checkpoint.commitHash, workspaceFiles); + result.restoredFiles = workspaceFiles; + } + } else { + // 完整回滚 + await this.shadowGit.resetHard(checkpoint.commitHash); + + // 获取恢复的文件列表 + const diff = await this.shadowGit.getDiffSummary( + previousCommit, + checkpoint.commitHash + ); + result.restoredFiles = diff.files.map((f) => f.path); + } + + // 记录回滚信息(用于 unrevert) + this.lastRollback = { + id: nanoid(10), + timestamp: Date.now(), + targetCheckpoint: checkpoint.id, + previousCommit: preRollbackCheckpoint?.commitHash || previousCommit, + restoredFiles: result.restoredFiles, + canUnrevert: true, + }; + + // 触发事件 + onEvent?.({ + type: 'restored', + checkpoint, + details: { + files: result.restoredFiles, + previousCommit, + mode, + }, + }); + } catch (error) { + result.success = false; + result.errors.push({ + file: '*', + error: error instanceof Error ? error.message : String(error), + }); + } + + return result; + }); + } + + /** + * 撤销操作(回滚到上一个检查点) + */ + async undo( + onEvent?: (event: { type: string; checkpoint: CheckpointMetadata; details?: unknown }) => void + ): Promise { + const latest = this.store.getLatest(); + if (!latest) { + throw new Error('No checkpoints available'); + } + + const checkpoints = this.store.list(); + if (checkpoints.length < 2) { + return this.rollback({ target: latest.id }, onEvent); + } + + return this.rollback({ target: checkpoints[1].id }, onEvent); + } + + /** + * 撤销最近一次回滚 + */ + async unrevert(): Promise { + if (!this.lastRollback || !this.lastRollback.canUnrevert) { + return { + success: false, + restoredCommit: '', + filesRestored: 0, + error: 'No rollback to unrevert', + }; + } + + return this.lock.withLock(async () => { + try { + await this.shadowGit.resetHard(this.lastRollback!.previousCommit); + + const result: UnrevertResult = { + success: true, + restoredCommit: this.lastRollback!.previousCommit, + filesRestored: this.lastRollback!.restoredFiles.length, + }; + + this.lastRollback = null; + + return result; + } catch (error) { + return { + success: false, + restoredCommit: '', + filesRestored: 0, + error: error instanceof Error ? error.message : String(error), + }; + } + }); + } + + /** + * 检查是否可以 unrevert + */ + canUnrevert(): boolean { + return this.lastRollback !== null && this.lastRollback.canUnrevert; + } + + /** + * 获取最后一次回滚记录 + */ + getLastRollback(): RollbackRecord | null { + return this.lastRollback; + } + + /** + * 获取检查点与当前工作区的差异 + */ + private async getDiff(checkpoint: CheckpointMetadata): Promise { + return this.shadowGit.getDiffSummary(checkpoint.commitHash, 'HEAD'); + } + + /** + * 根据 ID 获取检查点与当前工作区的差异 + */ + private async getDiffById(checkpointId: string): Promise { + const checkpoint = this.store.get(checkpointId); + if (!checkpoint) { + return { from: '', to: 'HEAD', files: [], totalInsertions: 0, totalDeletions: 0 }; + } + return this.getDiff(checkpoint); + } + + /** + * 获取 AI 修改的文件列表 + */ + private async getAiModifiedFiles(checkpoint: CheckpointMetadata): Promise { + const files: string[] = []; + const checkpoints = this.store.list(); + + for (const cp of checkpoints) { + if (cp.timestamp > checkpoint.timestamp && cp.toolCall) { + const filePath = + (cp.toolCall.params.file_path as string) || + (cp.toolCall.params.path as string); + if (filePath && !files.includes(filePath)) { + files.push(filePath); + } + } + } + + return files; + } + + /** + * 获取仅工作区变更的文件(不包括 AI 修改) + */ + private async getWorkspaceOnlyFiles(checkpoint: CheckpointMetadata): Promise { + const diff = await this.getDiff(checkpoint); + const aiFiles = await this.getAiModifiedFiles(checkpoint); + + return diff.files + .map((f) => f.path) + .filter((path) => !aiFiles.includes(path)); + } +} diff --git a/packages/core/src/checkpoint/checkpoint-session.ts b/packages/core/src/checkpoint/checkpoint-session.ts new file mode 100644 index 0000000..2095194 --- /dev/null +++ b/packages/core/src/checkpoint/checkpoint-session.ts @@ -0,0 +1,178 @@ +/** + * 检查点会话追踪 + * 负责会话与检查点的关联 + */ + +import { nanoid } from 'nanoid'; +import type { CheckpointStore } from './checkpoint-store.js'; +import type { CheckpointMetadata, CheckpointTrigger } from './types.js'; + +/** + * 检查点会话追踪 + */ +export class CheckpointSession { + private store: CheckpointStore; + private currentSessionId: string | null = null; + private sessionCheckpoints: Map = new Map(); + + constructor(store: CheckpointStore) { + this.store = store; + } + + /** + * 开始新会话 + */ + async startSession( + sessionId?: string, + onEvent?: (event: { type: string; checkpoint?: CheckpointMetadata }) => void + ): Promise { + const id = sessionId || nanoid(10); + this.currentSessionId = id; + this.sessionCheckpoints.set(id, []); + + // 创建会话开始检查点 + try { + const checkpoint = await this.store.create( + { + trigger: 'session_start', + description: `Session started: ${id}`, + sessionId: id, + }, + id, + (checkpointId) => { + this.recordCheckpoint(checkpointId); + } + ); + onEvent?.({ type: 'created', checkpoint }); + } catch { + // 忽略创建失败 + } + + return id; + } + + /** + * 结束当前会话 + */ + async endSession( + onEvent?: (event: { type: string; checkpoint?: CheckpointMetadata }) => void + ): Promise { + if (!this.currentSessionId) return; + + // 创建会话结束检查点 + try { + const checkpoint = await this.store.create( + { + trigger: 'session_end', + description: `Session ended: ${this.currentSessionId}`, + sessionId: this.currentSessionId, + }, + this.currentSessionId, + (checkpointId) => { + this.recordCheckpoint(checkpointId); + } + ); + onEvent?.({ type: 'created', checkpoint }); + } catch { + // 忽略创建失败 + } + + this.currentSessionId = null; + } + + /** + * 获取当前会话 ID + */ + getCurrentSessionId(): string | null { + return this.currentSessionId; + } + + /** + * 设置当前会话 ID(用于外部恢复) + */ + setCurrentSessionId(sessionId: string | null): void { + this.currentSessionId = sessionId; + } + + /** + * 获取会话的所有检查点 + */ + getSessionCheckpoints(sessionId: string): CheckpointMetadata[] { + const checkpointIds = this.sessionCheckpoints.get(sessionId); + if (!checkpointIds) return []; + + const checkpoints: CheckpointMetadata[] = []; + for (const id of checkpointIds) { + const checkpoint = this.store.get(id); + if (checkpoint) { + checkpoints.push(checkpoint); + } + } + + return checkpoints.sort((a, b) => a.timestamp - b.timestamp); + } + + /** + * 创建与消息关联的检查点 + */ + async createMessageCheckpoint( + messageId: string, + turnIndex?: number, + options?: { + trigger?: CheckpointTrigger; + description?: string; + }, + onEvent?: (event: { type: string; checkpoint: CheckpointMetadata }) => void + ): Promise { + const checkpoint = await this.store.create( + { + trigger: options?.trigger || 'auto', + description: options?.description || `Message checkpoint: ${messageId}`, + messageId, + sessionId: this.currentSessionId || undefined, + turnIndex, + }, + this.currentSessionId, + (checkpointId) => { + this.recordCheckpoint(checkpointId); + } + ); + + onEvent?.({ type: 'created', checkpoint }); + + return checkpoint; + } + + /** + * 获取与消息关联的检查点 + */ + getMessageCheckpoints(messageId: string): CheckpointMetadata[] { + const checkpoints: CheckpointMetadata[] = []; + for (const checkpoint of this.store.list()) { + if (checkpoint.messageId === messageId) { + checkpoints.push(checkpoint); + } + } + + return checkpoints.sort((a, b) => a.timestamp - b.timestamp); + } + + /** + * 记录检查点到当前会话 + */ + recordCheckpoint(checkpointId: string): void { + if (!this.currentSessionId) return; + + const sessionCps = this.sessionCheckpoints.get(this.currentSessionId) || []; + sessionCps.push(checkpointId); + this.sessionCheckpoints.set(this.currentSessionId, sessionCps); + } + + /** + * 获取会话开始的检查点 + */ + getSessionStartCheckpoint(sessionId: string): CheckpointMetadata | null { + const checkpoints = this.getSessionCheckpoints(sessionId); + return checkpoints.find((cp) => cp.trigger === 'session_start') || null; + } +} diff --git a/packages/core/src/checkpoint/checkpoint-store.ts b/packages/core/src/checkpoint/checkpoint-store.ts new file mode 100644 index 0000000..4c6ac68 --- /dev/null +++ b/packages/core/src/checkpoint/checkpoint-store.ts @@ -0,0 +1,263 @@ +/** + * 检查点存储 + * 负责检查点的 CRUD 操作 + */ + +import { nanoid } from 'nanoid'; +import type { ShadowGit } from './shadow-git.js'; +import type { CheckpointLock } from './lock.js'; +import type { CommitMessageGenerator } from './commit-message.js'; +import type { + CheckpointMetadata, + CheckpointConfig, + CheckpointTrigger, +} from './types.js'; + +/** + * 检查点提交消息前缀 + */ +const CHECKPOINT_PREFIX = 'checkpoint:'; + +/** + * 创建检查点选项 + */ +export interface CreateCheckpointOptions { + name?: string; + description?: string; + trigger?: CheckpointTrigger; + toolCall?: { tool: string; params: Record }; + messageId?: string; + sessionId?: string; + turnIndex?: number; +} + +/** + * 检查点存储 + */ +export class CheckpointStore { + private shadowGit: ShadowGit; + private lock: CheckpointLock; + private commitMessageGenerator: CommitMessageGenerator; + private config: CheckpointConfig; + private index: Map = new Map(); + + constructor( + shadowGit: ShadowGit, + lock: CheckpointLock, + commitMessageGenerator: CommitMessageGenerator, + config: CheckpointConfig + ) { + this.shadowGit = shadowGit; + this.lock = lock; + this.commitMessageGenerator = commitMessageGenerator; + this.config = config; + } + + /** + * 从 Git 历史加载检查点索引 + */ + async loadIndex(): Promise { + try { + const commits = await this.shadowGit.getCommits(this.config.maxCheckpoints); + + for (const commit of commits) { + if (commit.message.startsWith(CHECKPOINT_PREFIX)) { + try { + const jsonStr = commit.message.slice(CHECKPOINT_PREFIX.length); + const metadata = JSON.parse(jsonStr) as CheckpointMetadata; + metadata.commitHash = commit.hash; + this.index.set(metadata.id, metadata); + } catch { + // 解析失败,跳过 + } + } + } + } catch { + // 仓库可能是空的 + } + } + + /** + * 创建检查点 + */ + async create( + options: CreateCheckpointOptions, + currentSessionId?: string | null, + onCheckpointCreated?: (id: string) => void + ): Promise { + return this.lock.withLock(async () => { + return this.createInternal(options, currentSessionId, onCheckpointCreated); + }); + } + + /** + * 创建内部检查点(不使用锁,供内部调用) + */ + async createInternal( + options: CreateCheckpointOptions, + currentSessionId?: string | null, + onCheckpointCreated?: (id: string) => void + ): Promise { + const id = nanoid(10); + const timestamp = Date.now(); + const trigger = options.trigger || 'manual'; + + // 创建元数据 + const metadata: CheckpointMetadata = { + id, + name: options.name, + description: options.description, + timestamp, + trigger, + toolCall: options.toolCall, + commitHash: '', + filesChanged: 0, + messageId: options.messageId, + sessionId: options.sessionId || currentSessionId || undefined, + turnIndex: options.turnIndex, + }; + + // 获取变更文件数 + let filesChanged: Array<{ path: string; type: string }> = []; + try { + const diff = await this.shadowGit.getWorkingDirDiff(); + metadata.filesChanged = diff.files.length; + filesChanged = diff.files; + } catch { + // 忽略 + } + + // 生成智能提交消息 + const humanReadableMessage = this.commitMessageGenerator.generateMessage( + trigger, + options.toolCall, + filesChanged as Array<{ path: string; type: 'added' | 'modified' | 'deleted' | 'renamed' }> + ); + + // 创建 commit + const commitMessage = CHECKPOINT_PREFIX + JSON.stringify({ + ...metadata, + _readableMessage: humanReadableMessage, + }); + const commitHash = await this.shadowGit.createCommit(commitMessage); + metadata.commitHash = commitHash; + + // 更新索引 + this.index.set(id, metadata); + + // 通知回调 + onCheckpointCreated?.(id); + + return metadata; + } + + /** + * 列出所有检查点 + */ + list(): CheckpointMetadata[] { + return Array.from(this.index.values()).sort( + (a, b) => b.timestamp - a.timestamp + ); + } + + /** + * 获取检查点 + */ + get(idOrHash: string): CheckpointMetadata | null { + // 先按 ID 查找 + if (this.index.has(idOrHash)) { + return this.index.get(idOrHash)!; + } + + // 再按 commit hash 查找 + for (const checkpoint of this.index.values()) { + if (checkpoint.commitHash.startsWith(idOrHash)) { + return checkpoint; + } + } + + return null; + } + + /** + * 获取最新检查点 + */ + getLatest(): CheckpointMetadata | null { + const checkpoints = this.list(); + return checkpoints[0] || null; + } + + /** + * 删除检查点 + */ + delete(checkpointId: string): CheckpointMetadata | null { + const checkpoint = this.index.get(checkpointId); + if (!checkpoint) { + return null; + } + + this.index.delete(checkpointId); + return checkpoint; + } + + /** + * 判断是否应该为指定工具创建检查点 + */ + shouldCreateForTool(tool: string): boolean { + if (!this.config.enabled) return false; + + const { autoCheckpoint } = this.config; + + switch (tool) { + case 'write_file': + return autoCheckpoint.beforeWrite; + case 'edit_file': + return autoCheckpoint.beforeEdit; + case 'delete_file': + return autoCheckpoint.beforeDelete; + case 'move_file': + case 'copy_file': + return autoCheckpoint.beforeMove; + case 'bash': + return autoCheckpoint.beforeBash; + default: + return false; + } + } + + /** + * 生成检查点描述 + */ + generateDescription(tool: string, params: Record): string { + switch (tool) { + case 'write_file': + return `Write file: ${params.file_path || params.path}`; + case 'edit_file': + return `Edit file: ${params.file_path || params.path}`; + case 'delete_file': + return `Delete file: ${params.file_path || params.path}`; + case 'move_file': + return `Move: ${params.source} -> ${params.destination}`; + case 'copy_file': + return `Copy: ${params.source} -> ${params.destination}`; + case 'bash': + return `Bash: ${String(params.command).slice(0, 50)}`; + default: + return `Tool: ${tool}`; + } + } + + /** + * 获取检查点数量 + */ + get size(): number { + return this.index.size; + } + + /** + * 获取配置 + */ + getConfig(): CheckpointConfig { + return this.config; + } +} diff --git a/packages/core/src/checkpoint/manager.ts b/packages/core/src/checkpoint/manager.ts index 7efad2c..822113f 100644 --- a/packages/core/src/checkpoint/manager.ts +++ b/packages/core/src/checkpoint/manager.ts @@ -1,10 +1,9 @@ /** * 检查点管理器 - * 管理检查点的创建、回滚、清理等操作 + * 作为编排器,委托具体工作给各个子模块 */ import * as path from 'path'; -import { nanoid } from 'nanoid'; import { getCheckpointsDir } from '../constants/paths.js'; import { ShadowGit, createShadowGit } from './shadow-git.js'; import { CheckpointLock } from './lock.js'; @@ -13,7 +12,6 @@ import { WorkspacePathValidator } from './path-validator.js'; import { CommitMessageGenerator } from './commit-message.js'; import { LFSPatternLoader } from './lfs.js'; import { - RestoreMode, type CheckpointMetadata, type CheckpointConfig, type CheckpointTrigger, @@ -21,17 +19,17 @@ import { type RollbackResult, type DiffInfo, type FileDiff, - type CheckpointEvent, type CheckpointEventListener, type RollbackRecord, type UnrevertResult, type SafetyCheckResult, } from './types.js'; -/** - * 检查点提交消息前缀 - */ -const CHECKPOINT_PREFIX = 'checkpoint:'; +// 子模块 +import { CheckpointStore, type CreateCheckpointOptions } from './checkpoint-store.js'; +import { CheckpointRollback } from './checkpoint-rollback.js'; +import { CheckpointSession } from './checkpoint-session.js'; +import { CheckpointEvents } from './checkpoint-events.js'; /** * 检查点管理器 @@ -40,25 +38,21 @@ export class CheckpointManager { private shadowGit: ShadowGit; private config: CheckpointConfig; private workDir: string; - private checkpointsIndex: Map = new Map(); private initialized = false; private lastCheckpointTime = 0; - private eventListeners: Set = new Set(); - // 新增:增强功能组件 + // 子模块 + private store: CheckpointStore; + private rollbackHandler: CheckpointRollback; + private session: CheckpointSession; + private events: CheckpointEvents; + + // 辅助组件 private lock: CheckpointLock; private safetyChecker: CheckpointSafetyChecker; private pathValidator: WorkspacePathValidator; - private commitMessageGenerator: CommitMessageGenerator; private lfsLoader: LFSPatternLoader; - // 新增:Unrevert 支持 - private lastRollback: RollbackRecord | null = null; - - // 新增:会话跟踪 - private currentSessionId: string | null = null; - private sessionCheckpoints: Map = new Map(); - // 防止重复创建检查点的最小间隔 (毫秒) private static readonly MIN_CHECKPOINT_INTERVAL = 1000; @@ -81,12 +75,18 @@ export class CheckpointManager { this.shadowGit = createShadowGit(this.workDir, this.config.storageDir); - // 初始化增强组件 + // 初始化辅助组件 this.lock = new CheckpointLock(this.shadowGit.getShadowGitDir()); this.safetyChecker = new CheckpointSafetyChecker(this.workDir); this.pathValidator = new WorkspacePathValidator(); - this.commitMessageGenerator = new CommitMessageGenerator(); this.lfsLoader = new LFSPatternLoader(); + + // 初始化子模块 + const commitMessageGenerator = new CommitMessageGenerator(); + this.store = new CheckpointStore(this.shadowGit, this.lock, commitMessageGenerator, this.config); + this.rollbackHandler = new CheckpointRollback(this.shadowGit, this.lock, this.safetyChecker, this.store); + this.session = new CheckpointSession(this.store); + this.events = new CheckpointEvents(); } /** @@ -113,58 +113,20 @@ export class CheckpointManager { await this.shadowGit.initialize(); // 加载检查点索引 - await this.loadCheckpointsIndex(); + await this.store.loadIndex(); this.initialized = true; } - /** - * 加载检查点索引 - */ - private async loadCheckpointsIndex(): Promise { - try { - const commits = await this.shadowGit.getCommits(this.config.maxCheckpoints); - - for (const commit of commits) { - if (commit.message.startsWith(CHECKPOINT_PREFIX)) { - try { - const jsonStr = commit.message.slice(CHECKPOINT_PREFIX.length); - const metadata = JSON.parse(jsonStr) as CheckpointMetadata; - metadata.commitHash = commit.hash; - this.checkpointsIndex.set(metadata.id, metadata); - } catch { - // 解析失败,跳过 - } - } - } - } catch { - // 仓库可能是空的 - } - } + // ============================================================================ + // 检查点操作(委托给 store) + // ============================================================================ /** * 判断是否应该为指定工具创建检查点 */ shouldCreateCheckpoint(tool: string): boolean { - if (!this.config.enabled) return false; - - const { autoCheckpoint } = this.config; - - switch (tool) { - case 'write_file': - return autoCheckpoint.beforeWrite; - case 'edit_file': - return autoCheckpoint.beforeEdit; - case 'delete_file': - return autoCheckpoint.beforeDelete; - case 'move_file': - case 'copy_file': - return autoCheckpoint.beforeMove; - case 'bash': - return autoCheckpoint.beforeBash; - default: - return false; - } + return this.store.shouldCreateForTool(tool); } /** @@ -188,7 +150,7 @@ export class CheckpointManager { const checkpoint = await this.createCheckpoint({ trigger: `tool:${tool}` as CheckpointTrigger, toolCall: { tool, params }, - description: this.generateDescription(tool, params), + description: this.store.generateDescription(tool, params), }); this.lastCheckpointTime = now; @@ -199,118 +161,29 @@ export class CheckpointManager { } } - /** - * 生成检查点描述 - */ - private generateDescription( - tool: string, - params: Record - ): string { - switch (tool) { - case 'write_file': - return `Write file: ${params.file_path || params.path}`; - case 'edit_file': - return `Edit file: ${params.file_path || params.path}`; - case 'delete_file': - return `Delete file: ${params.file_path || params.path}`; - case 'move_file': - return `Move: ${params.source} -> ${params.destination}`; - case 'copy_file': - return `Copy: ${params.source} -> ${params.destination}`; - case 'bash': - return `Bash: ${String(params.command).slice(0, 50)}`; - default: - return `Tool: ${tool}`; - } - } - /** * 创建检查点 */ - async createCheckpoint(options: { - name?: string; - description?: string; - trigger?: CheckpointTrigger; - toolCall?: { tool: string; params: Record }; - messageId?: string; - sessionId?: string; - turnIndex?: number; - }): Promise { + async createCheckpoint(options: CreateCheckpointOptions): Promise { await this.initialize(); if (!this.config.enabled) { throw new Error('Checkpoint system is disabled'); } - // 使用锁保护检查点创建 - return this.lock.withLock(async () => { - const id = nanoid(10); - const timestamp = Date.now(); - const trigger = options.trigger || 'manual'; + const checkpoint = await this.store.create( + options, + this.session.getCurrentSessionId(), + (id) => this.session.recordCheckpoint(id) + ); - // 创建元数据 - const metadata: CheckpointMetadata = { - id, - name: options.name, - description: options.description, - timestamp, - trigger, - toolCall: options.toolCall, - commitHash: '', // 待填充 - filesChanged: 0, // 待填充 - // 新增:消息和会话关联 - messageId: options.messageId, - sessionId: options.sessionId || this.currentSessionId || undefined, - turnIndex: options.turnIndex, - }; + // 触发事件 + this.events.emitCreated(checkpoint); - // 获取变更文件数 - let filesChanged: Array<{ path: string; type: string }> = []; - try { - const diff = await this.shadowGit.getWorkingDirDiff(); - metadata.filesChanged = diff.files.length; - filesChanged = diff.files; - } catch { - // 忽略 - } + // 异步清理 + this.cleanupAsync(); - // 生成智能提交消息 - const humanReadableMessage = this.commitMessageGenerator.generateMessage( - trigger, - options.toolCall, - filesChanged as any - ); - - // 创建 commit(使用 JSON 元数据作为 commit message,但包含可读描述) - const commitMessage = CHECKPOINT_PREFIX + JSON.stringify({ - ...metadata, - _readableMessage: humanReadableMessage, - }); - const commitHash = await this.shadowGit.createCommit(commitMessage); - metadata.commitHash = commitHash; - - // 更新索引 - this.checkpointsIndex.set(id, metadata); - - // 记录到当前会话 - if (this.currentSessionId) { - const sessionCps = this.sessionCheckpoints.get(this.currentSessionId) || []; - sessionCps.push(id); - this.sessionCheckpoints.set(this.currentSessionId, sessionCps); - } - - // 触发事件 - this.emitEvent({ - type: 'created', - checkpoint: metadata, - timestamp, - }); - - // 异步清理 - this.cleanupAsync(); - - return metadata; - }); + return checkpoint; } /** @@ -329,10 +202,7 @@ export class CheckpointManager { */ async listCheckpoints(): Promise { await this.initialize(); - - return Array.from(this.checkpointsIndex.values()).sort( - (a, b) => b.timestamp - a.timestamp - ); + return this.store.list(); } /** @@ -340,30 +210,35 @@ export class CheckpointManager { */ async getCheckpoint(idOrHash: string): Promise { await this.initialize(); - - // 先按 ID 查找 - if (this.checkpointsIndex.has(idOrHash)) { - return this.checkpointsIndex.get(idOrHash)!; - } - - // 再按 commit hash 查找 - for (const checkpoint of this.checkpointsIndex.values()) { - if (checkpoint.commitHash.startsWith(idOrHash)) { - return checkpoint; - } - } - - return null; + return this.store.get(idOrHash); } /** * 获取最近的检查点 */ async getLatestCheckpoint(): Promise { - const checkpoints = await this.listCheckpoints(); - return checkpoints[0] || null; + await this.initialize(); + return this.store.getLatest(); } + /** + * 删除检查点 + */ + async deleteCheckpoint(checkpointId: string): Promise { + await this.initialize(); + + const checkpoint = this.store.delete(checkpointId); + if (checkpoint) { + this.events.emitDeleted(checkpoint); + return true; + } + return false; + } + + // ============================================================================ + // 差异操作 + // ============================================================================ + /** * 获取检查点与当前工作区的差异 */ @@ -415,183 +290,64 @@ export class CheckpointManager { return this.shadowGit.getFileDiff(checkpoint.commitHash, head, filePath); } + // ============================================================================ + // 回滚操作(委托给 rollbackHandler) + // ============================================================================ + /** * 回滚到检查点 */ async rollback(options: RollbackOptions): Promise { await this.initialize(); - const checkpoint = await this.getCheckpoint(options.target); - if (!checkpoint) { - throw new Error(`Checkpoint not found: ${options.target}`); - } - - // 安全检查(除非明确跳过) - if (!options.skipSafetyCheck) { - const safetyResult = await this.safetyChecker.checkBeforeRollback(checkpoint, this); - if (!safetyResult.safe) { - const errorMsg = safetyResult.errors.join('; '); - throw new Error(`Safety check failed: ${errorMsg}`); - } - // 警告仍然记录,但不阻止操作 - if (safetyResult.warnings.length > 0) { - console.warn('Rollback warnings:', safetyResult.warnings.join('; ')); - } - } - - // 使用锁保护回滚操作 - return this.lock.withLock(async () => { - // 获取当前 HEAD 用于可能的撤销 - const previousCommit = await this.shadowGit.getHead(); - - // 预览模式 - if (options.dryRun) { - const diff = await this.getDiff(checkpoint.id); - return { - success: true, - restoredFiles: diff.files.map((f) => f.path), - errors: [], - previousCommit, - }; - } - - // 创建回滚前检查点(用于 unrevert) - let preRollbackCheckpoint: CheckpointMetadata | null = null; - try { - preRollbackCheckpoint = await this.createCheckpointInternal({ - trigger: 'pre_rollback', - description: `Before rollback to ${options.target}`, - }); - } catch { - // 忽略创建失败 - } - - const result: RollbackResult = { - success: true, - restoredFiles: [], - errors: [], - previousCommit, - }; - - try { - const mode = options.mode || RestoreMode.FULL; - - if (options.files && options.files.length > 0) { - // 选择性回滚(指定文件) - await this.shadowGit.checkoutFiles(checkpoint.commitHash, options.files); - result.restoredFiles = options.files; - } else if (mode === RestoreMode.AI_CHANGES_ONLY) { - // 仅恢复 AI 修改的文件 - const aiFiles = await this.getAiModifiedFiles(checkpoint); - if (aiFiles.length > 0) { - await this.shadowGit.checkoutFiles(checkpoint.commitHash, aiFiles); - result.restoredFiles = aiFiles; - } - } else if (mode === RestoreMode.WORKSPACE_ONLY) { - // 仅恢复工作区变更(不包括 AI 修改) - const workspaceFiles = await this.getWorkspaceOnlyFiles(checkpoint); - if (workspaceFiles.length > 0) { - await this.shadowGit.checkoutFiles(checkpoint.commitHash, workspaceFiles); - result.restoredFiles = workspaceFiles; - } - } else { - // 完整回滚 - await this.shadowGit.resetHard(checkpoint.commitHash); - - // 获取恢复的文件列表 - const diff = await this.shadowGit.getDiffSummary( - previousCommit, - checkpoint.commitHash - ); - result.restoredFiles = diff.files.map((f) => f.path); - } - - // 记录回滚信息(用于 unrevert) - this.lastRollback = { - id: nanoid(10), - timestamp: Date.now(), - targetCheckpoint: checkpoint.id, - previousCommit: preRollbackCheckpoint?.commitHash || previousCommit, - restoredFiles: result.restoredFiles, - canUnrevert: true, - }; - - // 触发事件 - this.emitEvent({ - type: 'restored', - checkpoint, - timestamp: Date.now(), - details: { - files: result.restoredFiles, - previousCommit, - mode, - }, - }); - } catch (error) { - result.success = false; - result.errors.push({ - file: '*', - error: error instanceof Error ? error.message : String(error), + return this.rollbackHandler.rollback(options, (event) => { + if (event.type === 'restored' && event.checkpoint) { + this.events.emitRestored(event.checkpoint, event.details as { + files: string[]; + previousCommit: string; + mode: string; }); } - - return result; }); } /** - * 撤销最近一次回滚(Unrevert) + * 撤销操作(回滚到上一个检查点) + */ + async undo(): Promise { + await this.initialize(); + + return this.rollbackHandler.undo((event) => { + if (event.type === 'restored' && event.checkpoint) { + this.events.emitRestored(event.checkpoint, event.details as { + files: string[]; + previousCommit: string; + mode: string; + }); + } + }); + } + + /** + * 撤销最近一次回滚 */ async unrevert(): Promise { await this.initialize(); - - if (!this.lastRollback || !this.lastRollback.canUnrevert) { - return { - success: false, - restoredCommit: '', - filesRestored: 0, - error: 'No rollback to unrevert', - }; - } - - return this.lock.withLock(async () => { - try { - // 恢复到回滚前的状态 - await this.shadowGit.resetHard(this.lastRollback!.previousCommit); - - const result: UnrevertResult = { - success: true, - restoredCommit: this.lastRollback!.previousCommit, - filesRestored: this.lastRollback!.restoredFiles.length, - }; - - // 清除 unrevert 记录 - this.lastRollback = null; - - return result; - } catch (error) { - return { - success: false, - restoredCommit: '', - filesRestored: 0, - error: error instanceof Error ? error.message : String(error), - }; - } - }); + return this.rollbackHandler.unrevert(); } /** - * 检查是否可以执行 unrevert + * 检查是否可以 unrevert */ canUnrevert(): boolean { - return this.lastRollback !== null && this.lastRollback.canUnrevert; + return this.rollbackHandler.canUnrevert(); } /** * 获取最后一次回滚记录 */ getLastRollback(): RollbackRecord | null { - return this.lastRollback; + return this.rollbackHandler.getLastRollback(); } /** @@ -609,115 +365,122 @@ export class CheckpointManager { }; } - return this.safetyChecker.checkBeforeRollback(checkpoint, this); + return this.safetyChecker.checkBeforeRollback(checkpoint, { + getCheckpoint: (id: string) => Promise.resolve(this.store.get(id)), + listCheckpoints: () => Promise.resolve(this.store.list()), + getDiff: (id: string) => this.getDiff(id), + }); } - /** - * 内部创建检查点(不使用锁,供 rollback 内部调用) - */ - private async createCheckpointInternal(options: { - trigger: CheckpointTrigger; - description?: string; - }): Promise { - const id = nanoid(10); - const timestamp = Date.now(); - - const metadata: CheckpointMetadata = { - id, - description: options.description, - timestamp, - trigger: options.trigger, - commitHash: '', - filesChanged: 0, - }; - - const commitMessage = CHECKPOINT_PREFIX + JSON.stringify(metadata); - const commitHash = await this.shadowGit.createCommit(commitMessage); - metadata.commitHash = commitHash; - - this.checkpointsIndex.set(id, metadata); - - return metadata; - } + // ============================================================================ + // 会话操作(委托给 session) + // ============================================================================ /** - * 获取 AI 修改的文件列表 + * 开始新会话 */ - private async getAiModifiedFiles(checkpoint: CheckpointMetadata): Promise { - const files: string[] = []; - const checkpoints = await this.listCheckpoints(); - - // 找到该检查点之后的所有检查点 - for (const cp of checkpoints) { - if (cp.timestamp > checkpoint.timestamp && cp.toolCall) { - const filePath = - (cp.toolCall.params.file_path as string) || - (cp.toolCall.params.path as string); - if (filePath && !files.includes(filePath)) { - files.push(filePath); - } - } - } - - return files; - } - - /** - * 获取仅工作区变更的文件(不包括 AI 修改) - */ - private async getWorkspaceOnlyFiles(checkpoint: CheckpointMetadata): Promise { - const diff = await this.getDiff(checkpoint.id); - const aiFiles = await this.getAiModifiedFiles(checkpoint); - - // 返回不在 AI 修改列表中的文件 - return diff.files - .map((f) => f.path) - .filter((path) => !aiFiles.includes(path)); - } - - /** - * 撤销操作 (回滚到上一个检查点) - */ - async undo(): Promise { - const latest = await this.getLatestCheckpoint(); - if (!latest) { - throw new Error('No checkpoints available'); - } - - // 找到倒数第二个检查点 - const checkpoints = await this.listCheckpoints(); - if (checkpoints.length < 2) { - // 只有一个检查点,回滚到它 - return this.rollback({ target: latest.id }); - } - - // 回滚到倒数第二个检查点 - return this.rollback({ target: checkpoints[1].id }); - } - - /** - * 删除检查点 - */ - async deleteCheckpoint(checkpointId: string): Promise { + async startSession(sessionId?: string): Promise { await this.initialize(); - if (!this.checkpointsIndex.has(checkpointId)) { - return false; + return this.session.startSession(sessionId, (event) => { + if (event.checkpoint) { + this.events.emitCreated(event.checkpoint); + } + }); + } + + /** + * 结束当前会话 + */ + async endSession(): Promise { + return this.session.endSession((event) => { + if (event.checkpoint) { + this.events.emitCreated(event.checkpoint); + } + }); + } + + /** + * 获取当前会话 ID + */ + getCurrentSessionId(): string | null { + return this.session.getCurrentSessionId(); + } + + /** + * 获取会话的所有检查点 + */ + async getSessionCheckpoints(sessionId: string): Promise { + await this.initialize(); + return this.session.getSessionCheckpoints(sessionId); + } + + /** + * 创建与消息关联的检查点 + */ + async createMessageCheckpoint( + messageId: string, + turnIndex?: number, + options?: { + trigger?: CheckpointTrigger; + description?: string; + } + ): Promise { + await this.initialize(); + + return this.session.createMessageCheckpoint(messageId, turnIndex, options, (event) => { + this.events.emitCreated(event.checkpoint); + }); + } + + /** + * 获取与消息关联的检查点 + */ + async getMessageCheckpoints(messageId: string): Promise { + await this.initialize(); + return this.session.getMessageCheckpoints(messageId); + } + + /** + * 撤销整个会话的修改 + */ + async undoSession(sessionId: string): Promise { + const sessionCheckpoints = await this.getSessionCheckpoints(sessionId); + if (sessionCheckpoints.length === 0) { + throw new Error(`No checkpoints found for session: ${sessionId}`); } - const checkpoint = this.checkpointsIndex.get(checkpointId)!; - this.checkpointsIndex.delete(checkpointId); + const startCheckpoint = this.session.getSessionStartCheckpoint(sessionId); - // 触发事件 - this.emitEvent({ - type: 'deleted', - checkpoint, - timestamp: Date.now(), - }); + if (!startCheckpoint) { + return this.rollback({ target: sessionCheckpoints[0].id }); + } - return true; + return this.rollback({ target: startCheckpoint.id }); } + // ============================================================================ + // 事件操作(委托给 events) + // ============================================================================ + + /** + * 添加事件监听器 + */ + addEventListener(listener: CheckpointEventListener): void { + this.events.addEventListener(listener); + } + + /** + * 移除事件监听器 + */ + removeEventListener(listener: CheckpointEventListener): void { + this.events.removeEventListener(listener); + } + + // ============================================================================ + // 清理操作 + // ============================================================================ + /** * 异步清理过期检查点 */ @@ -737,15 +500,17 @@ export class CheckpointManager { async cleanup(): Promise { await this.initialize(); - const checkpoints = await this.listCheckpoints(); + const checkpoints = this.store.list(); const now = Date.now(); let deletedCount = 0; // 按时间过期清理 for (const checkpoint of checkpoints) { if (now - checkpoint.timestamp > this.config.maxAge) { - await this.deleteCheckpoint(checkpoint.id); - deletedCount++; + if (this.store.delete(checkpoint.id)) { + this.events.emitDeleted(checkpoint); + deletedCount++; + } } } @@ -754,28 +519,25 @@ export class CheckpointManager { if (remaining > this.config.maxCheckpoints) { const toDelete = checkpoints.slice(this.config.maxCheckpoints); for (const checkpoint of toDelete) { - if (this.checkpointsIndex.has(checkpoint.id)) { - await this.deleteCheckpoint(checkpoint.id); + if (this.store.delete(checkpoint.id)) { + this.events.emitDeleted(checkpoint); deletedCount++; } } } if (deletedCount > 0) { - // 触发清理事件 - this.emitEvent({ - type: 'cleanup', - timestamp: now, - details: { deletedCount }, - }); - - // 运行 git gc + this.events.emitCleanup(deletedCount); await this.shadowGit.cleanup(this.config.maxCheckpoints); } return deletedCount; } + // ============================================================================ + // 其他方法 + // ============================================================================ + /** * 获取检查点存储统计 */ @@ -795,33 +557,6 @@ export class CheckpointManager { }; } - /** - * 添加事件监听器 - */ - addEventListener(listener: CheckpointEventListener): void { - this.eventListeners.add(listener); - } - - /** - * 移除事件监听器 - */ - removeEventListener(listener: CheckpointEventListener): void { - this.eventListeners.delete(listener); - } - - /** - * 触发事件 - */ - private emitEvent(event: CheckpointEvent): void { - for (const listener of this.eventListeners) { - try { - listener(event); - } catch (error) { - console.warn('Checkpoint event listener error:', error); - } - } - } - /** * 检查是否启用 */ @@ -836,137 +571,6 @@ export class CheckpointManager { return { ...this.config }; } - // ==================== 会话管理方法 ==================== - - /** - * 开始新会话 - */ - async startSession(sessionId?: string): Promise { - await this.initialize(); - - const id = sessionId || nanoid(10); - this.currentSessionId = id; - this.sessionCheckpoints.set(id, []); - - // 创建会话开始检查点 - try { - await this.createCheckpoint({ - trigger: 'session_start', - description: `Session started: ${id}`, - sessionId: id, - }); - } catch { - // 忽略创建失败 - } - - return id; - } - - /** - * 结束当前会话 - */ - async endSession(): Promise { - if (!this.currentSessionId) return; - - // 创建会话结束检查点 - try { - await this.createCheckpoint({ - trigger: 'session_end', - description: `Session ended: ${this.currentSessionId}`, - sessionId: this.currentSessionId, - }); - } catch { - // 忽略创建失败 - } - - this.currentSessionId = null; - } - - /** - * 获取当前会话 ID - */ - getCurrentSessionId(): string | null { - return this.currentSessionId; - } - - /** - * 获取会话的所有检查点 - */ - async getSessionCheckpoints(sessionId: string): Promise { - await this.initialize(); - - const checkpointIds = this.sessionCheckpoints.get(sessionId); - if (!checkpointIds) return []; - - const checkpoints: CheckpointMetadata[] = []; - for (const id of checkpointIds) { - const checkpoint = this.checkpointsIndex.get(id); - if (checkpoint) { - checkpoints.push(checkpoint); - } - } - - return checkpoints.sort((a, b) => a.timestamp - b.timestamp); - } - - /** - * 创建与消息关联的检查点 - */ - async createMessageCheckpoint( - messageId: string, - turnIndex?: number, - options?: { - trigger?: CheckpointTrigger; - description?: string; - } - ): Promise { - return this.createCheckpoint({ - trigger: options?.trigger || 'auto', - description: options?.description || `Message checkpoint: ${messageId}`, - messageId, - sessionId: this.currentSessionId || undefined, - turnIndex, - }); - } - - /** - * 获取与消息关联的检查点 - */ - async getMessageCheckpoints(messageId: string): Promise { - await this.initialize(); - - const checkpoints: CheckpointMetadata[] = []; - for (const checkpoint of this.checkpointsIndex.values()) { - if (checkpoint.messageId === messageId) { - checkpoints.push(checkpoint); - } - } - - return checkpoints.sort((a, b) => a.timestamp - b.timestamp); - } - - /** - * 撤销整个会话的修改 - */ - async undoSession(sessionId: string): Promise { - const sessionCheckpoints = await this.getSessionCheckpoints(sessionId); - if (sessionCheckpoints.length === 0) { - throw new Error(`No checkpoints found for session: ${sessionId}`); - } - - // 找到会话开始的检查点 - const startCheckpoint = sessionCheckpoints.find( - (cp) => cp.trigger === 'session_start' - ); - - if (!startCheckpoint) { - // 如果没有明确的开始检查点,使用第一个检查点 - return this.rollback({ target: sessionCheckpoints[0].id }); - } - - return this.rollback({ target: startCheckpoint.id }); - } - /** * 获取 LFS 模式加载器 */ diff --git a/packages/core/src/checkpoint/safety.ts b/packages/core/src/checkpoint/safety.ts index e443d9d..59fd133 100644 --- a/packages/core/src/checkpoint/safety.ts +++ b/packages/core/src/checkpoint/safety.ts @@ -5,11 +5,20 @@ import { exec } from 'child_process'; import { promisify } from 'util'; -import type { CheckpointManager } from './manager.js'; -import type { SafetyCheckResult, CheckpointMetadata } from './types.js'; +import type { SafetyCheckResult, CheckpointMetadata, DiffInfo } from './types.js'; const execAsync = promisify(exec); +/** + * 安全检查器所需的 CheckpointManager 接口 + * 使用接口而非完整类以避免循环依赖 + */ +export interface SafetyCheckManagerInterface { + listCheckpoints(): Promise; + getDiff(checkpointId: string): Promise; + getCheckpoint(idOrHash: string): Promise; +} + /** * 检查点安全检查器 */ @@ -25,7 +34,7 @@ export class CheckpointSafetyChecker { */ async checkBeforeRollback( checkpoint: CheckpointMetadata, - manager: CheckpointManager + manager: SafetyCheckManagerInterface ): Promise { const result: SafetyCheckResult = { safe: true, @@ -87,7 +96,7 @@ export class CheckpointSafetyChecker { /** * 检查工作区是否有未保存的变更 */ - private async hasUnsavedChanges(manager: CheckpointManager): Promise { + private async hasUnsavedChanges(manager: SafetyCheckManagerInterface): Promise { try { // 通过 manager 检查 const checkpoints = await manager.listCheckpoints(); diff --git a/packages/core/src/core/agent-message-handler.ts b/packages/core/src/core/agent-message-handler.ts new file mode 100644 index 0000000..4647b71 --- /dev/null +++ b/packages/core/src/core/agent-message-handler.ts @@ -0,0 +1,406 @@ +/** + * Agent 消息处理器 + * 负责消息的构建、流式处理、压缩 + */ + +import { + generateText, + streamText, + stepCountIs, + type ModelMessage, + type Tool as AITool, + type LanguageModel, +} from 'ai'; +import type { ToolResult, UserInput, ContentBlock, ChatResult } from '../types/index.js'; +import { + CompressionManager, + CompressionStatus, + type TokenUsage, +} from '../context/index.js'; +import { + createDoomLoopDetector, + type DoomLoopDetector, + DOOM_LOOP_WARNING, +} from './doom-loop.js'; +import type { ToolStartInfo, ToolEndInfo } from './agent-tool-executor.js'; + +/** + * Doom Loop 检测事件信息 + */ +export interface DoomLoopInfo { + toolName: string; + count: number; +} + +/** + * 消息处理配置 + */ +export interface MessageHandlerConfig { + model: LanguageModel; + systemPrompt: string; + maxTokens?: number; + maxSteps?: number; +} + +/** + * 流式回调 + */ +export interface StreamCallbacks { + onStream?: (text: string) => void; + onToolStart?: (info: ToolStartInfo) => void; + onToolEnd?: (info: ToolEndInfo) => void; + onDoomLoop?: (info: DoomLoopInfo) => void; +} + +/** + * Agent 消息处理器 + */ +export class AgentMessageHandler { + private compressionManager: CompressionManager; + private conversationHistory: ModelMessage[] = []; + private doomLoopDetector: DoomLoopDetector = createDoomLoopDetector(); + + constructor(compressionManager: CompressionManager) { + this.compressionManager = compressionManager; + } + + /** + * 重置 Doom Loop 检测器 + */ + resetDoomLoop(): void { + this.doomLoopDetector.reset(); + } + + /** + * 构建用户消息内容(处理文本和图片) + */ + buildUserMessageContent(input: string | UserInput): string | ContentBlock[] { + if (typeof input === 'string') { + return input; + } + + const blocks: ContentBlock[] = []; + + // 添加图片 + if (input.images && input.images.length > 0) { + for (const img of input.images) { + blocks.push({ + type: 'image', + image: img.data, + mimeType: img.mimeType, + }); + } + } + + // 添加文本 + if (input.text) { + blocks.push({ + type: 'text', + text: input.text, + }); + } + + // 如果只有一个文本块,直接返回文本 + if (blocks.length === 1 && blocks[0].type === 'text') { + return blocks[0].text; + } + + return blocks; + } + + /** + * 添加用户消息到历史 + */ + addUserMessage(content: string | ContentBlock[]): void { + this.conversationHistory.push({ + role: 'user', + content, + } as ModelMessage); + } + + /** + * 流式聊天 + */ + async streamChat( + config: MessageHandlerConfig, + tools: Record, + callbacks: StreamCallbacks, + abortSignal?: AbortSignal + ): Promise { + const { onStream, onToolStart, onToolEnd, onDoomLoop } = callbacks; + const maxSteps = config.maxSteps ?? 50; + + // 工具调用时间跟踪 + const toolStartTimes = new Map(); + // Doom loop 检测状态 + let doomLoopTriggered = false; + let fullResponse = ''; + let responseMessages: ModelMessage[] = []; + + const result = streamText({ + model: config.model, + system: config.systemPrompt, + messages: this.conversationHistory, + tools: doomLoopTriggered ? {} : tools, // doom loop 时禁用工具 + maxOutputTokens: config.maxTokens, + stopWhen: stepCountIs(maxSteps), + abortSignal, + onChunk: ({ chunk }) => { + if (chunk.type === 'tool-call') { + this.handleToolCallChunk( + chunk, + toolStartTimes, + doomLoopTriggered, + onToolStart, + onDoomLoop, + onStream, + (triggered) => { doomLoopTriggered = triggered; } + ); + } else if (chunk.type === 'tool-result') { + const toolResultChunk = chunk as unknown as { toolCallId: string; output?: unknown }; + this.handleToolResultChunk(toolResultChunk, toolStartTimes, onToolEnd, onStream); + } + }, + }); + + // 流式输出文本 + let aborted = false; + try { + for await (const chunk of result.textStream) { + if (abortSignal?.aborted) { + aborted = true; + break; + } + fullResponse += chunk; + onStream?.(chunk); + } + + if (aborted) { + onStream?.('\n[已取消]\n'); + if (fullResponse) { + this.conversationHistory.push({ + role: 'assistant', + content: fullResponse + '\n[已取消]', + } as ModelMessage); + } + return { text: fullResponse, messages: [] }; + } + + const response = await result.response; + responseMessages = response.messages as ModelMessage[]; + } catch (error) { + if (error instanceof Error && (error.name === 'AbortError' || abortSignal?.aborted)) { + onStream?.('\n[已取消]\n'); + if (fullResponse) { + this.conversationHistory.push({ + role: 'assistant', + content: fullResponse + '\n[已取消]', + } as ModelMessage); + } + return { text: fullResponse, messages: [] }; + } + throw error; + } + + // 将完整的响应消息添加到历史 + this.conversationHistory.push(...responseMessages); + + return { + text: fullResponse, + messages: responseMessages, + }; + } + + /** + * 处理工具调用 chunk + */ + private handleToolCallChunk( + chunk: { toolCallId: string; toolName: string; input: unknown }, + toolStartTimes: Map, + doomLoopTriggered: boolean, + onToolStart?: (info: ToolStartInfo) => void, + onDoomLoop?: (info: DoomLoopInfo) => void, + onStream?: (text: string) => void, + setDoomLoopTriggered?: (triggered: boolean) => void + ): void { + const toolCallId = chunk.toolCallId || `tool-${Date.now()}`; + + // Doom Loop 检测 + this.doomLoopDetector.record(chunk.toolName, chunk.input); + + if (this.doomLoopDetector.isTriggered() && !doomLoopTriggered) { + setDoomLoopTriggered?.(true); + const toolName = this.doomLoopDetector.getLastToolName() || chunk.toolName; + onDoomLoop?.({ toolName, count: 3 }); + onStream?.(`\n[警告: 检测到 Doom Loop - ${toolName} 被重复调用]\n`); + onStream?.(DOOM_LOOP_WARNING); + } + + // 记录开始时间 + toolStartTimes.set(toolCallId, Date.now()); + + // 调用回调 + if (onToolStart) { + onToolStart({ + id: toolCallId, + toolName: chunk.toolName, + args: (chunk.input as Record) || {}, + }); + } else { + onStream?.(`\n[调用工具: ${chunk.toolName}]\n`); + } + } + + /** + * 处理工具结果 chunk + */ + private handleToolResultChunk( + chunk: { toolCallId: string; output?: unknown }, + toolStartTimes: Map, + onToolEnd?: (info: ToolEndInfo) => void, + onStream?: (text: string) => void + ): void { + const toolCallId = chunk.toolCallId || ''; + const output = chunk.output as ToolResult | undefined; + + // 计算执行时长 + const startTime = toolStartTimes.get(toolCallId); + const duration = startTime ? Date.now() - startTime : undefined; + toolStartTimes.delete(toolCallId); + + if (output && typeof output === 'object') { + if (onToolEnd) { + onToolEnd({ + id: toolCallId, + status: output.success ? 'completed' : 'error', + result: output.success ? output.output : undefined, + error: output.success ? undefined : output.error, + duration, + }); + } else { + if (output.success) { + const displayOutput = + output.output.length > 500 + ? output.output.substring(0, 500) + '...(截断)' + : output.output; + onStream?.(`[结果: ${displayOutput}]\n`); + } else { + onStream?.(`[错误: ${output.error}]\n`); + } + } + } + } + + /** + * 非流式聊天 + */ + async generateChat( + config: MessageHandlerConfig, + tools: Record, + abortSignal?: AbortSignal + ): Promise { + const maxSteps = config.maxSteps ?? 50; + + const result = await generateText({ + model: config.model, + system: config.systemPrompt, + messages: this.conversationHistory, + tools, + maxOutputTokens: config.maxTokens, + stopWhen: stepCountIs(maxSteps), + abortSignal, + }); + + const fullResponse = result.text; + const responseMessages = result.response.messages as ModelMessage[]; + + // 将完整的响应消息添加到历史 + this.conversationHistory.push(...responseMessages); + + return { + text: fullResponse, + messages: responseMessages, + }; + } + + /** + * 检查并执行自动压缩 + */ + async autoCompress(onStream?: (text: string) => void): Promise { + if (!this.compressionManager.shouldCompress(this.conversationHistory)) { + return; + } + + const result = await this.compressionManager.compress(this.conversationHistory); + + if (result.status === CompressionStatus.SUCCESS && result.freedTokens > 0) { + this.conversationHistory = result.messages; + if (onStream) { + const typeLabel = result.type === 'both' ? 'prune+摘要' : result.type === 'compaction' ? '摘要' : 'prune'; + onStream(`\n[自动压缩(${typeLabel}): 释放了 ${(result.freedTokens / 1000).toFixed(1)}k tokens]\n`); + } + } else if (result.status === CompressionStatus.FAILED_EMPTY_SUMMARY) { + onStream?.('\n[压缩失败: 摘要生成为空,已跳过]\n'); + } else if (result.status === CompressionStatus.FAILED_TOKEN_INFLATED) { + onStream?.('\n[压缩失败: 摘要反而增加了 token,已跳过]\n'); + } else if (result.status === CompressionStatus.FAILED_ERROR) { + onStream?.(`\n[压缩失败: ${result.error || '未知错误'}]\n`); + } + } + + /** + * 手动压缩 + */ + async forceCompress(): Promise<{ freedTokens: number; type: string }> { + const result = await this.compressionManager.forceCompress(this.conversationHistory); + if (result.freedTokens > 0) { + this.conversationHistory = result.messages; + } + return { + freedTokens: result.freedTokens, + type: result.type, + }; + } + + /** + * 获取对话历史 + */ + getHistory(): ModelMessage[] { + return this.conversationHistory; + } + + /** + * 设置对话历史 + */ + setHistory(messages: ModelMessage[]): void { + this.conversationHistory = [...messages]; + } + + /** + * 清空对话历史 + */ + clearHistory(): void { + this.conversationHistory = []; + } + + /** + * 获取上下文使用情况 + */ + getContextUsage(): TokenUsage { + return this.compressionManager.calculateUsage(this.conversationHistory); + } + + /** + * 获取格式化的上下文使用情况 + */ + getContextUsageFormatted(): string { + return this.compressionManager.formatUsage(this.conversationHistory); + } + + /** + * 获取压缩管理器 + */ + getCompressionManager(): CompressionManager { + return this.compressionManager; + } +} diff --git a/packages/core/src/core/agent-mode-manager.ts b/packages/core/src/core/agent-mode-manager.ts new file mode 100644 index 0000000..63388b3 --- /dev/null +++ b/packages/core/src/core/agent-mode-manager.ts @@ -0,0 +1,131 @@ +/** + * Agent 模式管理器 + * 负责 Agent 模式的切换和权限管理 + */ + +import type { AgentInfo } from '../agent/types.js'; +import { + agentRegistry, + renderPromptTemplate, + createPlanContext, +} from '../agent/index.js'; + +/** + * Agent 模式管理器 + */ +export class AgentModeManager { + private currentMode: AgentInfo | null = null; + private originalSystemPrompt: string; + private currentSystemPrompt: string; + + constructor(originalSystemPrompt: string) { + this.originalSystemPrompt = originalSystemPrompt; + this.currentSystemPrompt = originalSystemPrompt; + } + + /** + * 设置 Agent 模式 + */ + setMode(mode: AgentInfo | 'build' | 'plan' | null): void { + // 如果是字符串模式,从 registry 获取预设 + if (typeof mode === 'string') { + const presetAgent = agentRegistry.get(mode); + if (presetAgent) { + this.currentMode = presetAgent; + this.currentSystemPrompt = this.resolveSystemPrompt(presetAgent); + } else { + // 如果找不到预设,回退到默认模式 + this.currentMode = null; + this.currentSystemPrompt = this.originalSystemPrompt; + } + return; + } + + this.currentMode = mode; + + if (mode) { + this.currentSystemPrompt = this.resolveSystemPrompt(mode); + } else { + this.currentSystemPrompt = this.originalSystemPrompt; + } + } + + /** + * 解析系统提示词 + */ + private resolveSystemPrompt(agent: AgentInfo): string { + if (!agent.prompt) { + return ''; + } + + // 如果启用了模板渲染,动态解析变量 + if (agent.promptTemplate) { + const context = createPlanContext({ + workdir: process.cwd(), + isSubagent: agent.mode === 'subagent', + }); + return renderPromptTemplate(agent.prompt, context); + } + + return agent.prompt; + } + + /** + * 获取当前系统提示词 + */ + getSystemPrompt(): string { + return this.currentSystemPrompt; + } + + /** + * 获取原始系统提示词 + */ + getOriginalSystemPrompt(): string { + return this.originalSystemPrompt; + } + + /** + * 检查是否只读模式 + */ + isReadOnlyMode(): boolean { + const permission = this.currentMode?.permission; + if (!permission) return false; + + return permission.file?.write === 'deny' && permission.file?.edit === 'deny'; + } + + /** + * 获取当前模式信息 + */ + getCurrentMode(): AgentInfo | null { + return this.currentMode; + } + + /** + * 获取当前模式名称 + */ + getModeName(): string { + return this.currentMode?.name ?? 'default'; + } + + /** + * 获取当前模式的 maxSteps 配置 + */ + getMaxSteps(): number { + return this.currentMode?.maxSteps ?? 50; + } + + /** + * 获取当前模式的权限配置 + */ + getPermission(): AgentInfo['permission'] | undefined { + return this.currentMode?.permission; + } + + /** + * 获取当前模式的工具配置 + */ + getToolConfig(): AgentInfo['tools'] | undefined { + return this.currentMode?.tools; + } +} diff --git a/packages/core/src/core/agent-tool-executor.ts b/packages/core/src/core/agent-tool-executor.ts new file mode 100644 index 0000000..e7ff991 --- /dev/null +++ b/packages/core/src/core/agent-tool-executor.ts @@ -0,0 +1,355 @@ +/** + * Agent 工具执行器 + * 负责工具的获取、过滤、转换和执行 + */ + +import type { Tool as AITool } from 'ai'; +import type { Tool, ToolResult } from '../types/index.js'; +import { buildZodSchema } from '../types/index.js'; +import type { ToolRegistry } from '../tools/registry.js'; +import type { AgentInfo } from '../agent/types.js'; +import { + checkBashPermission, + renderPromptTemplate, + createToolDescriptionContext, + type PromptContext, +} from '../agent/index.js'; +import { getHookManager } from '../hooks/index.js'; +import { getGitManager } from '../git/index.js'; + +/** + * 工具调用开始事件信息 + */ +export interface ToolStartInfo { + id: string; + toolName: string; + args: Record; +} + +/** + * 工具调用结束事件信息 + */ +export interface ToolEndInfo { + id: string; + status: 'completed' | 'error'; + result?: unknown; + error?: string; + duration?: number; +} + +/** + * 工具执行上下文 + */ +export interface ToolExecutionContext { + sessionId: string; + agentMode: AgentInfo | null; + onToolStart?: (info: ToolStartInfo) => void; + onToolEnd?: (info: ToolEndInfo) => void; +} + +/** + * Agent 工具执行器 + */ +export class AgentToolExecutor { + private registry: ToolRegistry; + private discoveredTools: Set = new Set(); + private toolDescriptionContext: PromptContext | null = null; + private currentAgentMode: AgentInfo | null = null; + + constructor(registry: ToolRegistry) { + this.registry = registry; + } + + /** + * 设置当前 Agent 模式 + */ + setAgentMode(mode: AgentInfo | null): void { + this.currentAgentMode = mode; + // 清除工具描述上下文缓存 + this.toolDescriptionContext = null; + } + + /** + * 获取可用工具(核心 + 已发现) + */ + getAvailableTools(): Tool[] { + // 核心工具 + 已发现的工具 + const coreTools = this.registry.getCoreTools(); + const discoveredTools = this.registry.getTools([...this.discoveredTools]); + let tools = [...coreTools, ...discoveredTools]; + + // 应用 Agent 模式的工具过滤 + if (this.currentAgentMode?.tools) { + tools = this.filterToolsByAgentConfig(tools); + } + + return tools; + } + + /** + * 根据 Agent 配置过滤工具 + */ + private filterToolsByAgentConfig(tools: Tool[]): Tool[] { + const toolConfig = this.currentAgentMode?.tools; + if (!toolConfig) return tools; + + let filteredTools = tools; + + // 如果设置了 enabled 列表,只保留这些工具 + if (toolConfig.enabled && toolConfig.enabled.length > 0) { + const enabledSet = new Set(toolConfig.enabled); + filteredTools = filteredTools.filter((t) => enabledSet.has(t.name)); + } + + // 如果设置了 disabled 列表,排除这些工具 + if (toolConfig.disabled && toolConfig.disabled.length > 0) { + const disabledSet = new Set(toolConfig.disabled); + filteredTools = filteredTools.filter((t) => !disabledSet.has(t.name)); + } + + // 如果禁止嵌套 Task,移除 task 工具 + if (toolConfig.noTask) { + filteredTools = filteredTools.filter((t) => t.name !== 'task'); + } + + return filteredTools; + } + + /** + * 获取或创建工具描述渲染上下文 + */ + private getToolDescriptionContext(): PromptContext { + if (!this.toolDescriptionContext) { + this.toolDescriptionContext = createToolDescriptionContext({ + agent: { + name: this.currentAgentMode?.name ?? 'default', + mode: this.currentAgentMode?.mode ?? 'primary', + isSubagent: this.currentAgentMode?.mode === 'subagent', + }, + }); + } + return this.toolDescriptionContext; + } + + /** + * 渲染工具描述中的模板变量 + */ + private renderToolDescription(description: string): string { + const context = this.getToolDescriptionContext(); + return renderPromptTemplate(description, context, { + throwOnUndefined: false, + undefinedValue: '', + }); + } + + /** + * 转换为 Vercel AI SDK 工具格式 + */ + toVercelTools(context: ToolExecutionContext): Record { + const vercelTools: Record = {}; + const availableTools = this.getAvailableTools(); + const hookManager = getHookManager(); + + for (const tool of availableTools) { + const schema = buildZodSchema(tool.parameters); + const renderedDescription = this.renderToolDescription(tool.description); + + vercelTools[tool.name] = { + description: renderedDescription, + inputSchema: schema, + execute: async (params) => { + return this.executeTool(tool, params as Record, context, hookManager); + }, + } as AITool; + } + + return vercelTools; + } + + /** + * 执行单个工具 + */ + private async executeTool( + tool: Tool, + args: Record, + context: ToolExecutionContext, + hookManager: ReturnType + ): Promise { + const callId = `${tool.name}-${Date.now()}`; + const { sessionId, onToolStart, onToolEnd } = context; + + // 触发工具执行前 hook + let finalArgs = args; + if (hookManager) { + const beforeOutput = await hookManager.triggerToolExecuteBefore({ + tool: tool.name, + sessionId, + callId, + args, + }); + + // 如果 hook 指定跳过,直接返回 + if (beforeOutput.skip && beforeOutput.skipResult) { + return beforeOutput.skipResult; + } + + finalArgs = beforeOutput.args; + } + + // Agent 级别的 Bash 权限检查 + if (tool.name === 'bash' && this.currentAgentMode?.permission?.bash) { + const command = finalArgs.command as string; + if (command) { + const action = checkBashPermission(command, this.currentAgentMode.permission.bash); + if (action === 'deny') { + return { + success: false, + output: '', + error: `[Agent 权限拒绝] 当前模式 (${this.currentAgentMode.name}) 禁止执行此命令: ${command}`, + }; + } + } + } + + // 通知工具开始 + onToolStart?.({ + id: callId, + toolName: tool.name, + args: finalArgs, + }); + + // 执行工具 + const startTime = Date.now(); + let result = await tool.execute(finalArgs); + const duration = Date.now() - startTime; + + // 触发工具执行后 hook + if (hookManager) { + const afterOutput = await hookManager.triggerToolExecuteAfter( + { + tool: tool.name, + sessionId, + callId, + args: finalArgs, + duration, + }, + result + ); + result = afterOutput.result; + + // 对于文件操作工具,触发相应的文件 hook 和 Git 自动提交 + if (result.success) { + await this.triggerFileHooksIfNeeded(tool.name, finalArgs, sessionId, hookManager); + } + } + + // 通知工具结束 + onToolEnd?.({ + id: callId, + status: result.success ? 'completed' : 'error', + result: result.success ? result.output : undefined, + error: result.success ? undefined : result.error, + duration, + }); + + // 如果是 tool_search 调用,解析结果并注入发现的工具 + if (tool.name === 'tool_search' && result.success) { + this.handleToolSearchResult(result.output); + } + + return result; + } + + /** + * 触发文件相关的 Hooks + */ + private async triggerFileHooksIfNeeded( + toolName: string, + args: Record, + sessionId: string, + hookManager: NonNullable> + ): Promise { + const filePath = args.path as string | undefined; + if (!filePath) return; + + const gitManager = getGitManager(); + + if (toolName === 'write_file') { + await hookManager.triggerFileCreated({ + path: filePath, + tool: toolName, + sessionId, + }); + if (gitManager) { + await gitManager.onFileChanged(filePath, 'create'); + } + } else if (toolName === 'edit_file') { + await hookManager.triggerFileEdited({ + path: filePath, + tool: toolName, + sessionId, + }); + if (gitManager) { + await gitManager.onFileChanged(filePath, 'modify'); + } + } else if (toolName === 'delete_file') { + await hookManager.triggerFileDeleted({ + path: filePath, + tool: toolName, + sessionId, + }); + if (gitManager) { + await gitManager.onFileChanged(filePath, 'delete'); + } + } + } + + /** + * 处理 tool_search 的结果,将发现的工具添加到可用列表 + */ + handleToolSearchResult(output: string): void { + // 解析输出,提取工具名称 + // 格式: "- tool_name: description [category]" + const matches = output.matchAll(/^- (\w+):/gm); + for (const match of matches) { + const toolName = match[1]; + if (this.registry.has(toolName)) { + this.discoveredTools.add(toolName); + } + } + } + + /** + * 获取已发现的工具 + */ + getDiscoveredTools(): string[] { + return [...this.discoveredTools]; + } + + /** + * 设置已发现的工具(用于会话恢复) + */ + setDiscoveredTools(tools: string[]): void { + this.discoveredTools = new Set(tools); + } + + /** + * 清除已发现的工具 + */ + clearDiscoveredTools(): void { + this.discoveredTools.clear(); + } + + /** + * 获取工具数量统计 + */ + getToolCount(): { core: number; discovered: number; total: number } { + const coreCount = this.registry.getCoreTools().length; + const discoveredCount = this.discoveredTools.size; + return { + core: coreCount, + discovered: discoveredCount, + total: coreCount + discoveredCount, + }; + } +} diff --git a/packages/core/src/core/agent-vision-handler.ts b/packages/core/src/core/agent-vision-handler.ts new file mode 100644 index 0000000..0d27201 --- /dev/null +++ b/packages/core/src/core/agent-vision-handler.ts @@ -0,0 +1,122 @@ +/** + * Agent Vision 处理器 + * 负责 Vision 处理的委托逻辑 + */ + +import type { AgentConfig } from '../types/index.js'; +import type { ToolRegistry } from '../tools/registry.js'; +import type { ImageData } from '../agent/types.js'; +import { agentRegistry, AgentExecutor } from '../agent/index.js'; +import { loadVisionConfig } from '../utils/config.js'; + +/** + * Agent Vision 处理器 + */ +export class AgentVisionHandler { + private config: AgentConfig; + private registry: ToolRegistry | null = null; + + constructor(config: AgentConfig) { + this.config = config; + } + + /** + * 设置工具注册表 + */ + setRegistry(registry: ToolRegistry): void { + this.registry = registry; + } + + /** + * 检查当前模型是否支持 Vision(图片理解) + */ + supportsVision(): boolean { + const model = this.config.model.toLowerCase(); + + // Anthropic Claude 模型支持 vision + if (this.config.provider === 'anthropic') { + // Claude 3 及以上版本支持 vision + return model.includes('claude-3') || model.includes('claude-4'); + } + + // OpenAI GPT-4 系列支持 vision + if (this.config.provider === 'openai') { + // GPT-4o, GPT-4 Turbo, GPT-4 Vision 等支持 + return model.includes('gpt-4'); + } + + // DeepSeek 目前不支持 vision + if (this.config.provider === 'deepseek') { + return false; + } + + return false; + } + + /** + * 使用 Vision Agent 处理图片 + * 当主模型不支持 vision 时,委托给 Vision Agent 分析图片 + * @returns 包含图片分析结果的文本消息,或 null 表示失败 + */ + async processWithVisionAgent( + images: ImageData[], + userText?: string, + onStream?: (text: string) => void + ): Promise { + // 检查 Vision 配置是否可用 + const visionConfig = loadVisionConfig(); + if (!visionConfig) { + onStream?.('\n⚠ Vision 服务未配置,无法处理图片\n'); + return null; + } + + // 获取 Vision Agent + const visionAgent = agentRegistry.get('vision'); + if (!visionAgent) { + onStream?.('\n⚠ Vision Agent 未注册\n'); + return null; + } + + // 确保有工具注册表 + if (!this.registry) { + onStream?.('\n⚠ 工具注册表未初始化\n'); + return null; + } + + onStream?.(`\n[委托 Vision Agent (${visionConfig.model}) 分析图片...]\n`); + + // 构建 Vision 配置 + const visionAgentConfig: AgentConfig = { + ...this.config, + provider: visionConfig.provider, + apiKey: visionConfig.apiKey, + model: visionConfig.model, + baseUrl: visionConfig.baseUrl, + }; + + // 创建 Vision Agent 执行器 + const executor = new AgentExecutor(visionAgent, visionAgentConfig, this.registry); + + // 构建提示词 + const prompt = userText || '请详细描述这张图片的内容'; + + // 执行 Vision 分析 + const result = await executor.execute(prompt, { + workdir: process.cwd(), + images, + onStream: undefined, // Vision Agent 不使用流式输出 + }); + + if (!result.success) { + onStream?.(`\n⚠ Vision 分析失败: ${result.error}\n`); + return null; + } + + onStream?.('\n[Vision 分析完成]\n'); + + // 构建带分析结果的文本消息 + const combinedText = `[图片分析结果 - 由 ${visionConfig.model} 提供]\n${result.text}\n\n用户问题: ${userText || '(无附加问题)'}`; + + return combinedText; + } +} diff --git a/packages/core/src/core/agent.ts b/packages/core/src/core/agent.ts index 0508f01..b899516 100644 --- a/packages/core/src/core/agent.ts +++ b/packages/core/src/core/agent.ts @@ -1,117 +1,64 @@ -import { - generateText, - streamText, - stepCountIs, - type ModelMessage, - type Tool as AITool, - type LanguageModel, -} from 'ai'; -import type { Tool, ToolResult, Message, AgentConfig, UserInput, ContentBlock, ChatResult } from '../types/index.js'; -import { buildZodSchema } from '../types/index.js'; +/** + * Agent 主类 + * 作为编排器,委托具体工作给各个子模块 + */ + +import type { LanguageModel } from 'ai'; +import type { Tool, AgentConfig, UserInput, Message, ChatResult } from '../types/index.js'; import { ToolRegistry } from '../tools/registry.js'; import { SessionManager } from '../session/index.js'; import { CompressionManager, - CompressionStatus, type TokenUsage, type CompressionConfig, } from '../context/index.js'; -import type { AgentInfo, ImageData } from '../agent/types.js'; -import { - agentRegistry, - AgentExecutor, - checkBashPermission, - renderPromptTemplate, - createPlanContext, - createToolDescriptionContext, - type PromptContext, -} from '../agent/index.js'; -import { loadVisionConfig } from '../utils/config.js'; +import type { AgentInfo } from '../agent/types.js'; +import { agentRegistry } from '../agent/index.js'; import { getProviderRegistry, resolveApiKey } from '../provider/index.js'; -import { getHookManager } from '../hooks/index.js'; -import { getGitManager } from '../git/index.js'; -import { - createDoomLoopDetector, - type DoomLoopDetector, - DOOM_LOOP_WARNING, -} from './doom-loop.js'; import { todoManager } from '../tools/todo/todo-manager.js'; import { initTaskContext } from '../tools/task/index.js'; -/** - * 工具调用开始事件信息 - */ -export interface ToolStartInfo { - id: string; - toolName: string; - args: Record; -} +// 子模块 +import { AgentToolExecutor, type ToolStartInfo, type ToolEndInfo } from './agent-tool-executor.js'; +import { AgentMessageHandler, type DoomLoopInfo } from './agent-message-handler.js'; +import { AgentModeManager } from './agent-mode-manager.js'; +import { AgentVisionHandler } from './agent-vision-handler.js'; -/** - * 工具调用结束事件信息 - */ -export interface ToolEndInfo { - id: string; - status: 'completed' | 'error'; - result?: unknown; - error?: string; - duration?: number; -} - -/** - * Doom Loop 检测事件信息 - */ -export interface DoomLoopInfo { - toolName: string; - count: number; -} +// 重新导出类型 +export type { ToolStartInfo, ToolEndInfo, DoomLoopInfo }; /** * Agent.chat() 选项 */ export interface AgentChatOptions { onStream?: (text: string) => void; - /** 工具开始执行回调 */ onToolStart?: (info: ToolStartInfo) => void; - /** 工具执行完成回调 */ onToolEnd?: (info: ToolEndInfo) => void; - /** Doom Loop 检测回调 */ onDoomLoop?: (info: DoomLoopInfo) => void; abortSignal?: AbortSignal; } +/** + * Agent 主类 + */ export class Agent { private getModel: (model: string) => LanguageModel; private config: AgentConfig; - private conversationHistory: ModelMessage[] = []; - // 工具注册表 - private registry: ToolRegistry | null = null; + // 子模块 + private toolExecutor: AgentToolExecutor | null = null; + private messageHandler: AgentMessageHandler; + private modeManager: AgentModeManager; + private visionHandler: AgentVisionHandler; - // 已发现的工具(通过 tool_search 发现的) - private discoveredTools: Set = new Set(); - - // 会话管理器(可选) + // 会话管理 private sessionManager: SessionManager | null = null; - // 压缩管理器 - private compressionManager: CompressionManager; - - // Doom Loop 检测器 - private doomLoopDetector: DoomLoopDetector = createDoomLoopDetector(); - - // 当前 Agent 模式(null 表示默认模式) - private currentAgentMode: AgentInfo | null = null; - - // 原始 system prompt(用于切换回 default 时恢复) - private originalSystemPrompt: string; - - // 工具描述渲染上下文缓存 - private toolDescriptionContext: PromptContext | null = null; + // Auto-approve 配置 + private autoApproveConfig: { file?: { write?: 'allow'; edit?: 'allow' } } | null = null; constructor(config: AgentConfig, compressionConfig?: Partial) { this.config = config; - this.originalSystemPrompt = config.systemPrompt; // 使用 ProviderRegistry 获取模型工厂 const providerRegistry = getProviderRegistry(); @@ -120,7 +67,7 @@ export class Agent { baseUrl: config.baseUrl, }); - // 构建压缩配置,使用模型的 contextWindow(如果有) + // 构建压缩配置 const finalCompressionConfig: Partial = { ...compressionConfig, }; @@ -129,12 +76,16 @@ export class Agent { } // 初始化压缩管理器 - this.compressionManager = new CompressionManager(finalCompressionConfig); - // 设置主模型(作为摘要模型的后备) - this.compressionManager.setModel(this.getModel(config.model)); + const compressionManager = new CompressionManager(finalCompressionConfig); + compressionManager.setModel(this.getModel(config.model)); // 从 Agent Registry 加载 Summary Agent 配置 - this.initSummaryModel(config, providerRegistry); + this.initSummaryModel(config, providerRegistry, compressionManager); + + // 初始化子模块 + this.messageHandler = new AgentMessageHandler(compressionManager); + this.modeManager = new AgentModeManager(config.systemPrompt); + this.visionHandler = new AgentVisionHandler(config); } /** @@ -142,19 +93,16 @@ export class Agent { */ private initSummaryModel( config: AgentConfig, - providerRegistry: ReturnType + providerRegistry: ReturnType, + compressionManager: CompressionManager ): void { - // 获取 Summary Agent(internal 模式) const summaryAgentInfo = agentRegistry.getInternal('summary'); if (!summaryAgentInfo?.model) { return; } const modelConfig = summaryAgentInfo.model; - // 确定 provider(默认使用主配置的 provider) const provider = modelConfig.provider || config.provider; - - // 从 ProviderRegistry 获取 API Key const providerConfig = providerRegistry.getConfig(provider); const apiKey = resolveApiKey(providerConfig) || config.apiKey; @@ -162,32 +110,35 @@ export class Agent { return; } - // 设置 Summary 模型 const baseUrl = providerConfig?.baseUrl; - this.compressionManager.setSummaryModelFromAgentConfig(modelConfig, apiKey, baseUrl); + compressionManager.setSummaryModelFromAgentConfig(modelConfig, apiKey, baseUrl); } /** - * 设置工具注册表(新模式:支持动态工具发现) + * 设置工具注册表 */ setRegistry(registry: ToolRegistry): void { - this.registry = registry; + this.toolExecutor = new AgentToolExecutor(registry); + this.visionHandler.setRegistry(registry); } /** - * 设置会话管理器(启用会话持久化) + * 设置会话管理器 */ setSessionManager(manager: SessionManager): void { this.sessionManager = manager; - // 初始化 todoManager,使其能够访问会话数据 + + // 初始化 todoManager todoManager.setSessionManager(manager); - // 初始化 Task 工具上下文(使子 Agent 能够正常工作) + + // 初始化 Task 工具上下文 initTaskContext(this.config, manager); + // 从会话恢复状态 const session = manager.getSession(); if (session) { - this.conversationHistory = [...session.messages]; - this.discoveredTools = new Set(session.discoveredTools); + this.messageHandler.setHistory([...session.messages]); + this.toolExecutor?.setDiscoveredTools(session.discoveredTools); } } @@ -199,260 +150,33 @@ export class Agent { } /** - * 获取当前可用的工具 - * 返回核心工具 + 已发现的工具,如果当前有 Agent 模式,应用工具过滤 - */ - private getAvailableTools(): Tool[] { - if (!this.registry) { - throw new Error('工具注册表未初始化,请先调用 setRegistry()'); - } - - // 核心工具 + 已发现的工具 - const coreTools = this.registry.getCoreTools(); - const discoveredTools = this.registry.getTools([...this.discoveredTools]); - let tools = [...coreTools, ...discoveredTools]; - - // 应用 Agent 模式的工具过滤 - if (this.currentAgentMode?.tools) { - tools = this.filterToolsByAgentConfig(tools); - } - - return tools; - } - - /** - * 根据 Agent 配置过滤工具 - */ - private filterToolsByAgentConfig(tools: Tool[]): Tool[] { - const toolConfig = this.currentAgentMode?.tools; - if (!toolConfig) return tools; - - let filteredTools = tools; - - // 如果设置了 enabled 列表,只保留这些工具 - if (toolConfig.enabled && toolConfig.enabled.length > 0) { - const enabledSet = new Set(toolConfig.enabled); - filteredTools = filteredTools.filter((t) => enabledSet.has(t.name)); - } - - // 如果设置了 disabled 列表,排除这些工具 - if (toolConfig.disabled && toolConfig.disabled.length > 0) { - const disabledSet = new Set(toolConfig.disabled); - filteredTools = filteredTools.filter((t) => !disabledSet.has(t.name)); - } - - // 如果禁止嵌套 Task,移除 task 工具 - if (toolConfig.noTask) { - filteredTools = filteredTools.filter((t) => t.name !== 'task'); - } - - return filteredTools; - } - - /** - * 获取或创建工具描述渲染上下文 - */ - private getToolDescriptionContext(): PromptContext { - if (!this.toolDescriptionContext) { - this.toolDescriptionContext = createToolDescriptionContext({ - agent: { - name: this.currentAgentMode?.name ?? 'default', - mode: this.currentAgentMode?.mode ?? 'primary', - isSubagent: this.currentAgentMode?.mode === 'subagent', - }, - }); - } - return this.toolDescriptionContext; - } - - /** - * 渲染工具描述中的模板变量 - */ - private renderToolDescription(description: string): string { - const context = this.getToolDescriptionContext(); - return renderPromptTemplate(description, context, { - throwOnUndefined: false, - undefinedValue: '', - }); - } - - /** - * 将工具转换为 Vercel AI SDK 的工具格式 - */ - private getVercelTools(): Record { - const vercelTools: Record = {}; - const availableTools = this.getAvailableTools(); - const hookManager = getHookManager(); - - for (const tool of availableTools) { - const schema = buildZodSchema(tool.parameters); - const renderedDescription = this.renderToolDescription(tool.description); - - vercelTools[tool.name] = { - description: renderedDescription, - inputSchema: schema, - execute: async (params) => { - const args = params as Record; - const callId = `${tool.name}-${Date.now()}`; - const sessionId = this.sessionManager?.getSession()?.id || 'default'; - - // 触发工具执行前 hook - let finalArgs = args; - if (hookManager) { - const beforeOutput = await hookManager.triggerToolExecuteBefore({ - tool: tool.name, - sessionId, - callId, - args, - }); - - // 如果 hook 指定跳过,直接返回 - if (beforeOutput.skip && beforeOutput.skipResult) { - return beforeOutput.skipResult; - } - - finalArgs = beforeOutput.args; - } - - // Agent 级别的权限检查(在全局权限检查之前) - if (tool.name === 'bash' && this.currentAgentMode?.permission?.bash) { - const command = finalArgs.command as string; - if (command) { - const action = checkBashPermission(command, this.currentAgentMode.permission.bash); - if (action === 'deny') { - return { - success: false, - output: '', - error: `[Agent 权限拒绝] 当前模式 (${this.currentAgentMode.name}) 禁止执行此命令: ${command}`, - }; - } - } - } - - // 执行工具 - const startTime = Date.now(); - let result = await tool.execute(finalArgs); - const duration = Date.now() - startTime; - - // 触发工具执行后 hook - if (hookManager) { - const afterOutput = await hookManager.triggerToolExecuteAfter( - { - tool: tool.name, - sessionId, - callId, - args: finalArgs, - duration, - }, - result - ); - result = afterOutput.result; - - // 对于文件操作工具,触发相应的文件 hook 和 Git 自动提交 - if (result.success) { - const filePath = finalArgs.path as string | undefined; - if (filePath) { - const gitManager = getGitManager(); - - if (tool.name === 'write_file') { - await hookManager.triggerFileCreated({ - path: filePath, - tool: tool.name, - sessionId, - }); - // Git 自动提交 - if (gitManager) { - await gitManager.onFileChanged(filePath, 'create'); - } - } else if (tool.name === 'edit_file') { - await hookManager.triggerFileEdited({ - path: filePath, - tool: tool.name, - sessionId, - }); - // Git 自动提交 - if (gitManager) { - await gitManager.onFileChanged(filePath, 'modify'); - } - } else if (tool.name === 'delete_file') { - await hookManager.triggerFileDeleted({ - path: filePath, - tool: tool.name, - sessionId, - }); - // Git 自动提交 - if (gitManager) { - await gitManager.onFileChanged(filePath, 'delete'); - } - } - } - } - } - - // 如果是 tool_search 调用,解析结果并注入发现的工具 - if (tool.name === 'tool_search' && result.success) { - this.handleToolSearchResult(result.output); - } - - return result; - }, - } as AITool; - } - - return vercelTools; - } - - /** - * 处理 tool_search 的结果,将发现的工具添加到可用列表 - */ - private handleToolSearchResult(output: string): void { - // 解析输出,提取工具名称 - // 格式: "- tool_name: description [category]" - const matches = output.matchAll(/^- (\w+):/gm); - for (const match of matches) { - const toolName = match[1]; - if (this.registry?.has(toolName)) { - this.discoveredTools.add(toolName); - } - } - } - - /** - * 发送消息并处理响应(流式) - * @param userMessage 用户消息文本或包含图片的 UserInput - * @param options 选项,包含 onStream 回调和 abortSignal - * @returns ChatResult 包含最终文本和完整的响应消息链 + * 主聊天入口 */ async chat(userMessage: string | UserInput, options?: AgentChatOptions | ((text: string) => void)): Promise { // 兼容旧的 onStream 参数 const opts: AgentChatOptions = typeof options === 'function' ? { onStream: options } : (options || {}); const { onStream, onToolStart, onToolEnd, onDoomLoop, abortSignal } = opts; - // 重置 doom loop 检测器(每次对话开始时) - this.doomLoopDetector.reset(); + if (!this.toolExecutor) { + throw new Error('工具注册表未初始化,请先调用 setRegistry()'); + } + + // 重置 doom loop 检测器 + this.messageHandler.resetDoomLoop(); - // 工具调用时间跟踪 - const toolStartTimes = new Map(); - // Doom loop 检测状态 - let doomLoopTriggered = false; // 处理带图片的消息 let processedMessage = userMessage; - if (typeof userMessage !== 'string' && userMessage.images && userMessage.images.length > 0) { - // 检查当前模型是否支持 vision - if (!this.supportsVision()) { - // 不支持 vision,尝试使用 Vision Agent 处理图片 - const visionResult = await this.processImagesWithVisionAgent( + if (!this.visionHandler.supportsVision()) { + const visionResult = await this.visionHandler.processWithVisionAgent( userMessage.images, userMessage.text, onStream ); if (visionResult) { - // 成功,将图片分析结果转换为文本消息 processedMessage = visionResult; } else { - // 失败,返回错误信息 const errorText = '无法处理图片:当前模型不支持图片理解,且 Vision 服务未配置或调用失败。'; return { text: errorText, messages: [] }; } @@ -460,257 +184,73 @@ export class Agent { } // 构建消息内容 - let messageContent: string | ContentBlock[]; - - if (typeof processedMessage === 'string') { - // 纯文本消息 - messageContent = processedMessage; - } else { - // 带图片的消息 - const blocks: ContentBlock[] = []; - - // 添加图片 - if (processedMessage.images && processedMessage.images.length > 0) { - for (const img of processedMessage.images) { - blocks.push({ - type: 'image', - image: img.data, - mimeType: img.mimeType, - }); - } - } - - // 添加文本 - if (processedMessage.text) { - blocks.push({ - type: 'text', - text: processedMessage.text, - }); - } - - messageContent = blocks.length === 1 && blocks[0].type === 'text' - ? blocks[0].text - : blocks; - } + const messageContent = this.messageHandler.buildUserMessageContent(processedMessage); // 添加用户消息到历史 - this.conversationHistory.push({ - role: 'user', - content: messageContent, - } as ModelMessage); + this.messageHandler.addUserMessage(messageContent); - const vercelTools = this.getVercelTools(); - let fullResponse = ''; - let responseMessages: ModelMessage[] = []; + // 同步 Agent 模式到工具执行器 + this.toolExecutor.setAgentMode(this.modeManager.getCurrentMode()); + // 获取工具 + const vercelTools = this.toolExecutor.toVercelTools({ + sessionId: this.sessionManager?.getSession()?.id || 'default', + agentMode: this.modeManager.getCurrentMode(), + onToolStart, + onToolEnd, + }); + + // 配置消息处理 + const handlerConfig = { + model: this.getModel(this.config.model), + systemPrompt: this.modeManager.getSystemPrompt(), + maxTokens: this.config.maxTokens, + maxSteps: this.modeManager.getMaxSteps(), + }; + + // 执行聊天 + let result: ChatResult; if (onStream) { - // 流式模式 - // 获取当前 Agent 的 maxSteps 配置(默认 50) - const maxSteps = this.currentAgentMode?.maxSteps ?? 50; - - const result = streamText({ - model: this.getModel(this.config.model), - system: this.config.systemPrompt, - messages: this.conversationHistory, - tools: doomLoopTriggered ? {} : vercelTools, // doom loop 时禁用工具 - maxOutputTokens: this.config.maxTokens, - stopWhen: stepCountIs(maxSteps), - abortSignal, // 支持取消 - onChunk: ({ chunk }) => { - if (chunk.type === 'tool-call') { - // AI SDK 中工具参数字段名为 input - const toolCallChunk = chunk as { toolCallId: string; toolName: string; input: unknown }; - const toolCallId = toolCallChunk.toolCallId || `tool-${Date.now()}`; - - // Doom Loop 检测:记录工具调用 - this.doomLoopDetector.record( - toolCallChunk.toolName, - toolCallChunk.input - ); - - // 检查是否触发 doom loop - if (this.doomLoopDetector.isTriggered() && !doomLoopTriggered) { - doomLoopTriggered = true; - const toolName = this.doomLoopDetector.getLastToolName() || toolCallChunk.toolName; - - // 通知回调 - onDoomLoop?.({ toolName, count: 3 }); - - // 输出警告 - onStream?.(`\n[警告: 检测到 Doom Loop - ${toolName} 被重复调用]\n`); - onStream?.(DOOM_LOOP_WARNING); - } - - // 记录开始时间 - toolStartTimes.set(toolCallId, Date.now()); - - // 调用 onToolStart 回调 - if (onToolStart) { - onToolStart({ - id: toolCallId, - toolName: toolCallChunk.toolName, - args: (toolCallChunk.input as Record) || {}, - }); - } else { - // 仅在没有 onToolStart 回调时输出文本(向后兼容 CLI) - onStream?.(`\n[调用工具: ${toolCallChunk.toolName}]\n`); - } - } else if (chunk.type === 'tool-result') { - const toolResultChunk = chunk as { toolCallId: string; output?: ToolResult }; - const toolCallId = toolResultChunk.toolCallId || ''; - const output = toolResultChunk.output; - - // 计算执行时长 - const startTime = toolStartTimes.get(toolCallId); - const duration = startTime ? Date.now() - startTime : undefined; - toolStartTimes.delete(toolCallId); - - if (output && typeof output === 'object') { - // 调用 onToolEnd 回调 - if (onToolEnd) { - onToolEnd({ - id: toolCallId, - status: output.success ? 'completed' : 'error', - result: output.success ? output.output : undefined, - error: output.success ? undefined : output.error, - duration, - }); - } else { - // 仅在没有 onToolEnd 回调时输出文本(向后兼容 CLI) - if (output.success) { - const displayOutput = - output.output.length > 500 - ? output.output.substring(0, 500) + '...(截断)' - : output.output; - onStream?.(`[结果: ${displayOutput}]\n`); - } else { - onStream?.(`[错误: ${output.error}]\n`); - } - } - } - } - }, - }); - - // 流式输出文本 - let aborted = false; - try { - for await (const chunk of result.textStream) { - // 检查是否已中止 - if (abortSignal?.aborted) { - aborted = true; - break; - } - fullResponse += chunk; - onStream(chunk); - } - - // 如果是手动中止(通过 break 退出),保存已收到的内容 - if (aborted) { - onStream?.('\n[已取消]\n'); - if (fullResponse) { - this.conversationHistory.push({ - role: 'assistant', - content: fullResponse + '\n[已取消]', - } as ModelMessage); - await this.persistSession(); - } - return { text: fullResponse, messages: [] }; - } - - // 等待完成并获取完整的响应消息(包括工具调用和结果) - const response = await result.response; - responseMessages = response.messages as ModelMessage[]; - } catch (error) { - // 如果是中止错误(AbortController.abort() 抛出),优雅处理 - if (error instanceof Error && (error.name === 'AbortError' || abortSignal?.aborted)) { - onStream?.('\n[已取消]\n'); - // 取消时也要保存已收到的内容 - if (fullResponse) { - this.conversationHistory.push({ - role: 'assistant', - content: fullResponse + '\n[已取消]', - } as ModelMessage); - await this.persistSession(); - } - return { text: fullResponse, messages: [] }; - } - throw error; - } + result = await this.messageHandler.streamChat( + handlerConfig, + vercelTools, + { onStream, onToolStart, onToolEnd, onDoomLoop }, + abortSignal + ); } else { - // 非流式模式 - // 获取当前 Agent 的 maxSteps 配置(默认 50) - const maxSteps = this.currentAgentMode?.maxSteps ?? 50; - - const result = await generateText({ - model: this.getModel(this.config.model), - system: this.config.systemPrompt, - messages: this.conversationHistory, - tools: vercelTools, - maxOutputTokens: this.config.maxTokens, - stopWhen: stepCountIs(maxSteps), - abortSignal, // 支持取消 - }); - - fullResponse = result.text; - responseMessages = result.response.messages as ModelMessage[]; + result = await this.messageHandler.generateChat( + handlerConfig, + vercelTools, + abortSignal + ); } - // 将完整的响应消息添加到历史(包括工具调用和结果) - this.conversationHistory.push(...responseMessages); - - // 检查是否需要自动压缩 - if (this.compressionManager.shouldCompress(this.conversationHistory)) { - const result = await this.compressionManager.compress(this.conversationHistory); - - if (result.status === CompressionStatus.SUCCESS && result.freedTokens > 0) { - this.conversationHistory = result.messages; - if (onStream) { - const typeLabel = result.type === 'both' ? 'prune+摘要' : result.type === 'compaction' ? '摘要' : 'prune'; - onStream(`\n[自动压缩(${typeLabel}): 释放了 ${(result.freedTokens / 1000).toFixed(1)}k tokens]\n`); - } - } else if (result.status === CompressionStatus.FAILED_EMPTY_SUMMARY) { - if (onStream) { - onStream('\n[压缩失败: 摘要生成为空,已跳过]\n'); - } - } else if (result.status === CompressionStatus.FAILED_TOKEN_INFLATED) { - if (onStream) { - onStream('\n[压缩失败: 摘要反而增加了 token,已跳过]\n'); - } - } else if (result.status === CompressionStatus.FAILED_ERROR) { - if (onStream) { - onStream(`\n[压缩失败: ${result.error || '未知错误'}]\n`); - } - } - } + // 自动压缩 + await this.messageHandler.autoCompress(onStream); // 持久化会话 await this.persistSession(); - return { - text: fullResponse, - messages: responseMessages, - }; + return result; } /** * 持久化当前会话状态 */ private async persistSession(): Promise { - if (!this.sessionManager) return; + if (!this.sessionManager || !this.toolExecutor) return; - await this.sessionManager.setMessages(this.conversationHistory); - await this.sessionManager.setDiscoveredTools([...this.discoveredTools]); + await this.sessionManager.setMessages(this.messageHandler.getHistory()); + await this.sessionManager.setDiscoveredTools(this.toolExecutor.getDiscoveredTools()); } /** * 清空对话历史和发现的工具 */ async clearHistory(): Promise { - this.conversationHistory = []; - this.discoveredTools.clear(); + this.messageHandler.clearHistory(); + this.toolExecutor?.clearDiscoveredTools(); - // 如果有会话管理器,创建新会话 if (this.sessionManager) { await this.sessionManager.newSession(); } @@ -720,13 +260,10 @@ export class Agent { * 获取对话历史 */ getHistory(): Message[] { - return this.conversationHistory - .filter( - (msg): msg is ModelMessage & { role: 'user' | 'assistant' } => - msg.role === 'user' || msg.role === 'assistant' - ) + return this.messageHandler.getHistory() + .filter((msg) => msg.role === 'user' || msg.role === 'assistant') .map((msg) => ({ - role: msg.role, + role: msg.role as 'user' | 'assistant', content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content), })); } @@ -735,169 +272,90 @@ export class Agent { * 获取当前可用工具的数量 */ getToolCount(): { core: number; discovered: number; total: number } { - if (!this.registry) { - return { core: 0, discovered: 0, total: 0 }; - } - const coreCount = this.registry.getCoreTools().length; - const discoveredCount = this.discoveredTools.size; - return { - core: coreCount, - discovered: discoveredCount, - total: coreCount + discoveredCount, - }; + return this.toolExecutor?.getToolCount() ?? { core: 0, discovered: 0, total: 0 }; } /** * 获取当前上下文使用情况 */ getContextUsage(): TokenUsage { - return this.compressionManager.calculateUsage(this.conversationHistory); + return this.messageHandler.getContextUsage(); } /** - * 获取格式化的上下文使用情况(用于 CLI 显示) + * 获取格式化的上下文使用情况 */ getContextUsageFormatted(): string { - return this.compressionManager.formatUsage(this.conversationHistory); + return this.messageHandler.getContextUsageFormatted(); } /** * 获取压缩管理器 */ getCompressionManager(): CompressionManager { - return this.compressionManager; + return this.messageHandler.getCompressionManager(); } /** - * 手动压缩对话历史(用于 /compact 命令) + * 手动压缩对话历史 */ async compactHistory(): Promise<{ freedTokens: number; type: string }> { - const result = await this.compressionManager.forceCompress(this.conversationHistory); - if (result.freedTokens > 0) { - this.conversationHistory = result.messages; - await this.persistSession(); - } - return { - freedTokens: result.freedTokens, - type: result.type, - }; + const result = await this.messageHandler.forceCompress(); + await this.persistSession(); + return result; } + // ============================================================================ + // Agent 模式管理 + // ============================================================================ + /** * 切换 Agent 模式 - * @param agent AgentInfo 对象或模式字符串 ('build'/'plan') */ setAgentMode(agent: AgentInfo | 'build' | 'plan' | null): void { - // 清除工具描述上下文缓存,下次使用时重新创建 - this.toolDescriptionContext = null; - - // 如果是字符串模式,从 registry 获取预设 - if (typeof agent === 'string') { - const presetAgent = agentRegistry.get(agent); - if (presetAgent) { - this.currentAgentMode = presetAgent; - // 使用 resolveSystemPrompt 获取提示词 - this.config = { - ...this.config, - systemPrompt: this.resolveSystemPrompt(presetAgent), - }; - } else { - // 如果找不到预设,回退到默认模式 - this.currentAgentMode = null; - this.config = { - ...this.config, - systemPrompt: this.originalSystemPrompt, - }; - } - return; - } - - this.currentAgentMode = agent; - - if (agent) { - // 切换到指定 Agent,使用 resolveSystemPrompt 获取提示词 - this.config = { - ...this.config, - systemPrompt: this.resolveSystemPrompt(agent), - }; - } else { - // 切换回 default,恢复原始 prompt - this.config = { - ...this.config, - systemPrompt: this.originalSystemPrompt, - }; - } - } - - /** - * 解析系统提示词 - * - * 如果 Agent 启用了 promptTemplate,则动态渲染模板变量 - */ - private resolveSystemPrompt(agent: AgentInfo): string { - if (!agent.prompt) { - return ''; - } - - // 如果启用了模板渲染,动态解析变量 - if (agent.promptTemplate) { - const context = createPlanContext({ - workdir: process.cwd(), - isSubagent: agent.mode === 'subagent', - }); - return renderPromptTemplate(agent.prompt, context); - } - - return agent.prompt; + this.modeManager.setMode(agent); + this.toolExecutor?.setAgentMode(this.modeManager.getCurrentMode()); } /** * 切换 Agent 模式(保留对话历史) - * - * 与 setAgentMode 不同,switchMode 会保留对话历史, - * 适用于会话中动态切换 Build ↔ Plan 模式。 - * - * @param mode 目标模式 - * @param preserveHistory 是否保留对话历史(默认 true) */ switchMode(mode: AgentInfo | 'build' | 'plan' | null, preserveHistory = true): void { - // 保存当前对话历史 - const currentHistory = preserveHistory ? [...this.conversationHistory] : []; - - // 执行模式切换 + const currentHistory = preserveHistory ? [...this.messageHandler.getHistory()] : []; this.setAgentMode(mode); - - // 恢复对话历史 if (preserveHistory) { - this.conversationHistory = currentHistory; + this.messageHandler.setHistory(currentHistory); } - - // 重置 doom loop 检测器(新模式重新计数) - this.doomLoopDetector.reset(); + this.messageHandler.resetDoomLoop(); } /** * 检查当前是否为只读模式 - * - * 只读模式:file.write === 'deny' && file.edit === 'deny' */ isReadOnlyMode(): boolean { - const permission = this.currentAgentMode?.permission; - if (!permission) return false; + return this.modeManager.isReadOnlyMode(); + } - return permission.file?.write === 'deny' && permission.file?.edit === 'deny'; + /** + * 获取当前 Agent 模式 + */ + getAgentMode(): AgentInfo | null { + return this.modeManager.getCurrentMode(); + } + + /** + * 获取当前 Agent 名称 + */ + getAgentModeName(): string { + return this.modeManager.getModeName(); } // ============================================================================ - // Auto-approve 功能(用于前端 Build 模式的自动授权) + // Auto-approve 功能 // ============================================================================ - /** 临时自动授权配置 */ - private autoApproveConfig: { file?: { write?: 'allow'; edit?: 'allow' } } | null = null; - /** * 设置自动授权配置 - * 仅影响 file write 和 file edit 操作(不包含 delete) */ setAutoApprove(config: { file?: { write?: 'allow'; edit?: 'allow' } }): void { this.autoApproveConfig = config; @@ -917,111 +375,15 @@ export class Agent { return this.autoApproveConfig; } - /** - * 获取当前 Agent 模式 - */ - getAgentMode(): AgentInfo | null { - return this.currentAgentMode; - } + // ============================================================================ + // Vision 功能 + // ============================================================================ /** - * 获取当前 Agent 名称 - */ - getAgentModeName(): string { - return this.currentAgentMode?.name ?? 'default'; - } - - /** - * 使用 Vision Agent 处理图片 - * 当主模型不支持 vision 时,委托给 Vision Agent 分析图片 - * @returns 包含图片分析结果的文本消息,或 null 表示失败 - */ - private async processImagesWithVisionAgent( - images: ImageData[], - userText?: string, - onStream?: (text: string) => void - ): Promise { - // 检查 Vision 配置是否可用 - const visionConfig = loadVisionConfig(); - if (!visionConfig) { - onStream?.('\n⚠ Vision 服务未配置,无法处理图片\n'); - return null; - } - - // 获取 Vision Agent - const visionAgent = agentRegistry.get('vision'); - if (!visionAgent) { - onStream?.('\n⚠ Vision Agent 未注册\n'); - return null; - } - - // 确保有工具注册表 - if (!this.registry) { - onStream?.('\n⚠ 工具注册表未初始化\n'); - return null; - } - - onStream?.(`\n[委托 Vision Agent (${visionConfig.model}) 分析图片...]\n`); - - // 构建 Vision 配置 - const visionAgentConfig: AgentConfig = { - ...this.config, - provider: visionConfig.provider, - apiKey: visionConfig.apiKey, - model: visionConfig.model, - baseUrl: visionConfig.baseUrl, - }; - - // 创建 Vision Agent 执行器 - const executor = new AgentExecutor(visionAgent, visionAgentConfig, this.registry); - - // 构建提示词 - const prompt = userText || '请详细描述这张图片的内容'; - - // 执行 Vision 分析 - const result = await executor.execute(prompt, { - workdir: process.cwd(), - images, - onStream: undefined, // Vision Agent 不使用流式输出 - }); - - if (!result.success) { - onStream?.(`\n⚠ Vision 分析失败: ${result.error}\n`); - return null; - } - - onStream?.('\n[Vision 分析完成]\n'); - - // 构建带分析结果的文本消息 - const combinedText = `[图片分析结果 - 由 ${visionConfig.model} 提供]\n${result.text}\n\n用户问题: ${userText || '(无附加问题)'}`; - - return combinedText; - } - - /** - * 检查当前模型是否支持 vision(图片理解) + * 检查当前模型是否支持 vision */ supportsVision(): boolean { - const model = this.config.model.toLowerCase(); - - // Anthropic Claude 模型支持 vision - if (this.config.provider === 'anthropic') { - // Claude 3 及以上版本支持 vision - return model.includes('claude-3') || model.includes('claude-4'); - } - - // OpenAI GPT-4 系列支持 vision - if (this.config.provider === 'openai') { - // GPT-4o, GPT-4 Turbo, GPT-4 Vision 等支持 - return model.includes('gpt-4'); - } - - // DeepSeek 目前不支持 vision - if (this.config.provider === 'deepseek') { - return false; - } - - return false; + return this.visionHandler.supportsVision(); } /** diff --git a/packages/core/src/core/index.ts b/packages/core/src/core/index.ts new file mode 100644 index 0000000..65dd668 --- /dev/null +++ b/packages/core/src/core/index.ts @@ -0,0 +1,32 @@ +/** + * Core 模块导出 + */ + +// Agent 主类 +export { Agent, type AgentChatOptions } from './agent.js'; + +// 子模块 +export { + AgentToolExecutor, + type ToolStartInfo, + type ToolEndInfo, + type ToolExecutionContext, +} from './agent-tool-executor.js'; + +export { + AgentMessageHandler, + type DoomLoopInfo, + type MessageHandlerConfig, + type StreamCallbacks, +} from './agent-message-handler.js'; + +export { AgentModeManager } from './agent-mode-manager.js'; + +export { AgentVisionHandler } from './agent-vision-handler.js'; + +// Doom Loop +export { + createDoomLoopDetector, + DOOM_LOOP_WARNING, + type DoomLoopDetector, +} from './doom-loop.js'; diff --git a/packages/core/src/session/manager.ts b/packages/core/src/session/manager.ts index 41850ca..2d8b8e0 100644 --- a/packages/core/src/session/manager.ts +++ b/packages/core/src/session/manager.ts @@ -1,48 +1,19 @@ +/** + * 会话管理器 + * 作为编排器,委托具体工作给各个子模块 + */ + import type { ModelMessage } from 'ai'; import * as storage from './storage/index.js'; -import { SessionStorage, MessageStorage, PartStorage, TodoStorage } from './storage/index.js'; -import type { SessionInfo, Part, TodoItem } from './storage/index.js'; -import { generateSessionId } from './id.js'; -import { getProjectId, isGitRepository } from './project.js'; +import type { TodoItem } from './storage/index.js'; -/** - * 会话摘要(用于列表展示) - */ -export interface SessionSummary { - id: string; - title: string; - workdir: string; - messageCount: number; - createdAt: string; - updatedAt: string; -} +// 子模块 +import { SessionStore, type SessionData, type SessionSummary } from './session-store.js'; +import { ProjectManager, type ProjectMetadata } from './project-manager.js'; +import { SessionAutoSave } from './session-auto-save.js'; -/** - * 运行时会话数据(兼容旧接口) - */ -export interface SessionData { - id: string; - projectId: string; - parentId?: string; - agentName?: string; - createdAt: string; - updatedAt: string; - workdir: string; - title?: string; - messages: ModelMessage[]; - discoveredTools: string[]; - todos: TodoItem[]; -} - -/** - * 项目元数据 - */ -export interface ProjectMetadata { - id: string; - workdir: string; - createdAt: string; - isGitRepo: boolean; -} +// 重新导出类型 +export type { SessionData, SessionSummary, ProjectMetadata }; /** * 会话管理器 @@ -50,14 +21,24 @@ export interface ProjectMetadata { */ export class SessionManager { private currentSession: SessionData | null = null; - private currentProject: ProjectMetadata | null = null; - private autoSaveInterval: ReturnType | null = null; private storageDir?: string; + // 子模块 + private store: SessionStore; + private projectManager: ProjectManager; + private autoSave: SessionAutoSave; + constructor(storageDir?: string) { this.storageDir = storageDir; + this.store = new SessionStore(); + this.projectManager = new ProjectManager(); + this.autoSave = new SessionAutoSave(); } + // ============================================================================ + // 初始化 + // ============================================================================ + /** * 初始化 - 尝试恢复或创建新会话 */ @@ -66,13 +47,14 @@ export class SessionManager { await storage.initStorage(this.storageDir); // 获取或创建项目 - this.currentProject = await this.getOrCreateProject(workdir); + await this.projectManager.getOrCreate(workdir); // 尝试加载当前会话 const currentSessionId = await this.getCurrentSessionId(); if (currentSessionId) { - const existing = await this.loadSession(this.currentProject.id, currentSessionId); + const projectId = this.projectManager.getProjectId()!; + const existing = await this.store.load(projectId, currentSessionId); if (existing && existing.workdir === workdir) { this.currentSession = existing; @@ -82,218 +64,18 @@ export class SessionManager { } // 创建新会话 - this.currentSession = await this.createNewSession(workdir); - await this.saveSessionInfo(); + const projectId = this.projectManager.getProjectId()!; + this.currentSession = await this.store.create(projectId, workdir); + await this.store.save(this.currentSession); await this.setCurrentSessionPointer(this.currentSession.id); this.startAutoSave(); return this.currentSession; } - /** - * 获取或创建项目 - */ - private async getOrCreateProject(workdir: string): Promise { - const projectId = await getProjectId(workdir); - - try { - const existing = await storage.read(['project', projectId]); - return existing; - } catch (e) { - if (e instanceof storage.StorageNotFoundError) { - const isGitRepo = await isGitRepository(workdir); - const project: ProjectMetadata = { - id: projectId, - workdir, - createdAt: new Date().toISOString(), - isGitRepo, - }; - await storage.write(['project', projectId], project); - return project; - } - throw e; - } - } - - /** - * 创建新会话 - */ - private async createNewSession(workdir: string): Promise { - if (!this.currentProject) { - throw new Error('Project not initialized. Call init() first.'); - } - - const sessionInfo = await SessionStorage.create(this.currentProject.id, workdir); - - return { - id: sessionInfo.id, - projectId: sessionInfo.projectId, - createdAt: new Date(sessionInfo.createdAt).toISOString(), - updatedAt: new Date(sessionInfo.updatedAt).toISOString(), - workdir: sessionInfo.workdir, - title: sessionInfo.title, - messages: [], - discoveredTools: sessionInfo.discoveredTools, - todos: [], - }; - } - - /** - * 加载会话(从存储重建) - */ - private async loadSession(projectId: string, sessionId: string): Promise { - const sessionInfo = await SessionStorage.get(projectId, sessionId); - if (!sessionInfo) return null; - - // 加载消息 - const messages = await this.loadMessagesFromStorage(sessionId); - - // 加载 todos - const todoList = await TodoStorage.get(sessionId); - - return { - id: sessionInfo.id, - projectId: sessionInfo.projectId, - parentId: sessionInfo.parentId, - agentName: sessionInfo.agentName, - createdAt: new Date(sessionInfo.createdAt).toISOString(), - updatedAt: new Date(sessionInfo.updatedAt).toISOString(), - workdir: sessionInfo.workdir, - title: sessionInfo.title, - messages, - discoveredTools: sessionInfo.discoveredTools, - todos: todoList?.items || [], - }; - } - - /** - * 从存储加载消息并转换为 AI SDK 格式 - */ - private async loadMessagesFromStorage(sessionId: string): Promise { - const messageInfos = await MessageStorage.listBySession(sessionId); - const messages: ModelMessage[] = []; - - for (const messageInfo of messageInfos) { - const parts = await PartStorage.getByIds(messageInfo.id, messageInfo.partIds); - const modelMessages = this.partsToModelMessages(messageInfo.role, parts); - messages.push(...modelMessages); - } - - return messages; - } - - /** - * 将 Parts 转换为 AI SDK ModelMessage(用于加载历史消息) - * - * 新逻辑: - * - user 消息:直接转换 - * - assistant 消息:转换文本和工具调用,然后为已完成的工具生成 tool 消息 - */ - private partsToModelMessages(role: string, parts: Part[]): ModelMessage[] { - if (parts.length === 0) return []; - - const result: ModelMessage[] = []; - - if (role === 'user') { - // User 消息:只有文本和文件 - const content: unknown[] = []; - for (const part of parts) { - if (part.type === 'text') { - content.push({ type: 'text', text: part.text }); - } else if (part.type === 'file') { - content.push({ - type: 'image', - image: part.data, - mimeType: part.mimeType, - }); - } - } - - if (content.length === 1 && (content[0] as { type: string }).type === 'text') { - result.push({ - role: 'user', - content: (content[0] as { text: string }).text, - }); - } else if (content.length > 0) { - result.push({ - role: 'user', - content, - } as ModelMessage); - } - - } else if (role === 'assistant') { - // Assistant 消息:文本 + 工具调用 - const content: unknown[] = []; - // input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等) - const completedTools: Array<{ toolCallId: string; toolName: string; input: unknown; output: unknown }> = []; - - for (const part of parts) { - if (part.type === 'text') { - content.push({ type: 'text', text: part.text }); - } else if (part.type === 'tool') { - // 只有非 pending 状态的工具调用才添加到 AI SDK 消息 - if (part.state.status !== 'pending') { - // AI SDK v5 使用 input 字段(不是 args) - content.push({ - type: 'tool-call', - toolCallId: part.toolCallId, - toolName: part.toolName, - input: part.state.input, - }); - - // 收集已完成的工具结果 - if (part.state.status === 'completed') { - completedTools.push({ - toolCallId: part.toolCallId, - toolName: part.toolName, - input: part.state.input, - output: part.state.output, - }); - } else if (part.state.status === 'error') { - completedTools.push({ - toolCallId: part.toolCallId, - toolName: part.toolName, - input: part.state.input, - output: part.state.error, - }); - } - } - } else if (part.type === 'reasoning') { - content.push({ type: 'text', text: `[Reasoning] ${part.text}` }); - } - } - - // 添加 assistant 消息 - if (content.length === 1 && (content[0] as { type: string }).type === 'text') { - result.push({ - role: 'assistant', - content: (content[0] as { text: string }).text, - }); - } else if (content.length > 0) { - result.push({ - role: 'assistant', - content, - } as ModelMessage); - } - - // 添加 tool 消息(如果有已完成的工具) - // AI SDK v5 要求 tool-result 必须包含 input 和 output 字段 - if (completedTools.length > 0) { - result.push({ - role: 'tool', - content: completedTools.map((t) => ({ - type: 'tool-result', - toolCallId: t.toolCallId, - toolName: t.toolName, - input: t.input, - output: t.output, - })), - } as unknown as ModelMessage); - } - } - - return result; - } + // ============================================================================ + // 会话获取 + // ============================================================================ /** * 获取当前会话 @@ -306,189 +88,108 @@ export class SessionManager { * 获取当前项目 */ getProject(): ProjectMetadata | null { - return this.currentProject; + return this.projectManager.getProject(); } /** - * 保存会话信息 + * 获取当前会话 ID */ - private async saveSessionInfo(): Promise { - if (!this.currentSession) return; - - const sessionInfo: SessionInfo = { - id: this.currentSession.id, - projectId: this.currentSession.projectId, - parentId: this.currentSession.parentId, - agentName: this.currentSession.agentName, - createdAt: new Date(this.currentSession.createdAt).getTime(), - updatedAt: Date.now(), - workdir: this.currentSession.workdir, - title: this.currentSession.title, - discoveredTools: this.currentSession.discoveredTools, - stats: { - messageCount: this.currentSession.messages.length, - inputTokens: 0, - outputTokens: 0, - }, - }; - - await SessionStorage.save(sessionInfo); + getSessionId(): string | undefined { + return this.currentSession?.id; } + // ============================================================================ + // 会话操作 + // ============================================================================ + /** * 保存当前会话 */ async save(): Promise { if (!this.currentSession) return; - await this.saveSessionInfo(); + await this.store.save(this.currentSession); } /** - * 同步消息到存储(将 AI SDK 消息转换为 Message + Parts) - * - * 新逻辑:只存储 user 和 assistant 消息 - * - user 消息:直接存储 - * - assistant 消息:合并后续的 tool 消息中的工具结果 - * - tool 消息:跳过(结果合并到 assistant) + * 清空当前会话并创建新会话 */ - async syncMessages(messages: ModelMessage[]): Promise { - if (!this.currentSession) return; - - const sessionId = this.currentSession.id; - - // 删除旧消息 - await MessageStorage.removeBySession(sessionId); - - // 用于跟踪当前 assistant 消息的工具调用 - let currentAssistantMsgId: string | null = null; - let currentUserMsgId: string | null = null; - const toolCallPartIds = new Map(); // toolCallId -> partId - - for (let i = 0; i < messages.length; i++) { - const message = messages[i]; - - if (message.role === 'user') { - // User 消息 - const messageInfo = await MessageStorage.create(sessionId, 'user'); - currentUserMsgId = messageInfo.id; - const partIds: string[] = []; - - if (typeof message.content === 'string') { - const part = await PartStorage.createText(messageInfo.id, message.content); - partIds.push(part.id); - } else if (Array.isArray(message.content)) { - for (const item of message.content) { - const itemType = (item as { type: string }).type; - if (itemType === 'text') { - const part = await PartStorage.createText(messageInfo.id, (item as { text: string }).text); - partIds.push(part.id); - } else if (itemType === 'image') { - const img = item as unknown as { image: string; mimeType: string }; - const part = await PartStorage.create(messageInfo.id, 'file', { - filename: 'image', - mimeType: img.mimeType, - data: typeof img.image === 'string' ? img.image : '', - }); - partIds.push(part.id); - } - } - } - - if (partIds.length > 0) { - await MessageStorage.update(sessionId, messageInfo.id, { partIds }); - } - - // 重置工具调用追踪 - currentAssistantMsgId = null; - toolCallPartIds.clear(); - - } else if (message.role === 'assistant') { - // Assistant 消息:如果当前轮次已有 assistant 消息,则追加 Parts - let messageId: string; - let existingPartIds: string[] = []; - - if (currentAssistantMsgId) { - // 同一轮对话的后续 assistant 消息,追加到现有消息 - messageId = currentAssistantMsgId; - const existingMsg = await MessageStorage.get(sessionId, messageId); - existingPartIds = existingMsg?.partIds ?? []; - } else { - // 新的 assistant 消息 - const messageInfo = await MessageStorage.create(sessionId, 'assistant', { - parentId: currentUserMsgId ?? undefined, - }); - messageId = messageInfo.id; - currentAssistantMsgId = messageId; - } - - const newPartIds: string[] = []; - - if (typeof message.content === 'string') { - const part = await PartStorage.createText(messageId, message.content); - newPartIds.push(part.id); - } else if (Array.isArray(message.content)) { - for (const item of message.content) { - const itemType = (item as { type: string }).type; - if (itemType === 'text') { - const part = await PartStorage.createText(messageId, (item as { text: string }).text); - newPartIds.push(part.id); - } else if (itemType === 'tool-call') { - // AI SDK 的 tool-call 使用 input 字段存储参数(不是 args) - const toolCall = item as unknown as { toolCallId: string; toolName: string; input: Record }; - // 创建 running 状态的工具 Part - const part = await PartStorage.createToolRunning( - messageId, - toolCall.toolCallId, - toolCall.toolName, - (toolCall.input as Record) ?? {} - ); - newPartIds.push(part.id); - toolCallPartIds.set(toolCall.toolCallId, part.id); - } - } - } - - if (newPartIds.length > 0) { - // 合并已有的和新的 partIds - const allPartIds = [...existingPartIds, ...newPartIds]; - await MessageStorage.update(sessionId, messageId, { partIds: allPartIds }); - } - - } else if (message.role === 'tool' && currentAssistantMsgId) { - // Tool 消息:更新对应 assistant 消息中的工具 Part 状态 - if (Array.isArray(message.content)) { - for (const item of message.content) { - const itemType = (item as { type: string }).type; - if (itemType === 'tool-result') { - // AI SDK v5 使用 output 字段存储结果(不是 result) - const toolResult = item as unknown as { toolCallId: string; toolName: string; output: unknown }; - const partId = toolCallPartIds.get(toolResult.toolCallId); - if (partId) { - // 更新工具状态为 completed - // 获取原始 start time - const part = await PartStorage.get(currentAssistantMsgId, partId); - const startTime = part?.type === 'tool' && part.state.status === 'running' - ? part.state.time.start - : Date.now(); - await PartStorage.setToolCompleted(currentAssistantMsgId, partId, toolResult.output, startTime); - } - } - } - } - // 不创建新消息,跳过 tool role - } - // 忽略 system 消息(system prompt 通过其他方式注入) + async newSession(workdir?: string): Promise { + if (!this.projectManager.isInitialized()) { + throw new Error('Project not initialized. Call init() first.'); } + + const newWorkdir = workdir || this.currentSession?.workdir || process.cwd(); + + // 如果工作目录变化,需要切换项目 + if (workdir && workdir !== this.projectManager.getProject()?.workdir) { + await this.projectManager.switchProject(workdir); + } + + const projectId = this.projectManager.getProjectId()!; + this.currentSession = await this.store.create(projectId, newWorkdir); + await this.store.save(this.currentSession); + await this.setCurrentSessionPointer(this.currentSession.id); + + return this.currentSession; } + /** + * 恢复指定会话 + */ + async restoreSession(sessionId: string): Promise { + if (!this.projectManager.isInitialized()) { + throw new Error('Project not initialized. Call init() first.'); + } + + const projectId = this.projectManager.getProjectId()!; + const session = await this.store.load(projectId, sessionId); + if (!session) return null; + + this.currentSession = session; + await this.setCurrentSessionPointer(sessionId); + + return session; + } + + /** + * 列出当前项目的历史会话 + */ + async listSessions(): Promise { + const projectId = this.projectManager.getProjectId(); + if (!projectId) { + return this.listAllSessions(); + } + return this.store.listByProject(projectId); + } + + /** + * 列出所有项目的会话 + */ + async listAllSessions(): Promise { + return this.store.listAll(); + } + + /** + * 删除历史会话 + */ + async deleteSession(sessionId: string): Promise { + const projectId = this.projectManager.getProjectId(); + if (!projectId) return false; + return this.store.delete(projectId, sessionId); + } + + // ============================================================================ + // 消息操作 + // ============================================================================ + /** * 批量设置消息(用于同步整个对话历史) */ async setMessages(messages: ModelMessage[]): Promise { if (!this.currentSession) return; this.currentSession.messages = messages; - await this.syncMessages(messages); - await this.saveSessionInfo(); + await this.store.syncMessages(this.currentSession.id, messages); + await this.store.save(this.currentSession); } /** @@ -507,13 +208,17 @@ export class SessionManager { return this.currentSession?.messages || []; } + // ============================================================================ + // 工具和待办操作 + // ============================================================================ + /** * 设置已发现的工具 */ async setDiscoveredTools(tools: string[]): Promise { if (!this.currentSession) return; this.currentSession.discoveredTools = tools; - await this.saveSessionInfo(); + await this.store.save(this.currentSession); } /** @@ -526,10 +231,12 @@ export class SessionManager { /** * 更新待办事项 */ - async setTodos(todos: Array<{ content: string; status: 'pending' | 'in_progress' | 'completed' }>): Promise { + async setTodos( + todos: Array<{ content: string; status: 'pending' | 'in_progress' | 'completed' }> + ): Promise { if (!this.currentSession) return; - const todoList = await TodoStorage.replace(this.currentSession.id, todos); - this.currentSession.todos = todoList.items; + const items = await this.store.setTodos(this.currentSession.id, todos); + this.currentSession.todos = items; } /** @@ -539,192 +246,46 @@ export class SessionManager { return this.currentSession?.todos || []; } - /** - * 清空当前会话并创建新会话 - */ - async newSession(workdir?: string): Promise { - if (!this.currentProject) { - throw new Error('Project not initialized. Call init() first.'); - } - - const newWorkdir = workdir || this.currentSession?.workdir || process.cwd(); - - // 如果工作目录变化,需要切换项目 - if (workdir && workdir !== this.currentProject.workdir) { - this.currentProject = await this.getOrCreateProject(workdir); - } - - this.currentSession = await this.createNewSession(newWorkdir); - await this.saveSessionInfo(); - await this.setCurrentSessionPointer(this.currentSession.id); - - return this.currentSession; - } + // ============================================================================ + // 子会话操作 + // ============================================================================ /** * 创建子会话(用于 Task 工具) */ createChildSession(parentId: string, agentName: string, title?: string): SessionData { - if (!this.currentProject) { + if (!this.projectManager.isInitialized()) { throw new Error('Project not initialized. Call init() first.'); } + const projectId = this.projectManager.getProjectId()!; const workdir = this.currentSession?.workdir || process.cwd(); - return { - id: generateSessionId(), - projectId: this.currentProject.id, - parentId, - agentName, - createdAt: new Date().toISOString(), - updatedAt: new Date().toISOString(), - workdir, - title: title || `子任务 (@${agentName})`, - messages: [], - discoveredTools: [], - todos: [], - }; + return this.store.createChildSession(projectId, parentId, agentName, workdir, title); } /** * 保存子会话 */ async saveChildSession(session: SessionData): Promise { - const sessionInfo: SessionInfo = { - id: session.id, - projectId: session.projectId, - parentId: session.parentId, - agentName: session.agentName, - createdAt: new Date(session.createdAt).getTime(), - updatedAt: Date.now(), - workdir: session.workdir, - title: session.title, - discoveredTools: session.discoveredTools, - }; - await SessionStorage.save(sessionInfo); + await this.store.saveChildSession(session); } - /** - * 获取当前会话 ID - */ - getSessionId(): string | undefined { - return this.currentSession?.id; - } + // ============================================================================ + // 自动保存 + // ============================================================================ /** - * 恢复指定会话 - */ - async restoreSession(sessionId: string): Promise { - if (!this.currentProject) { - throw new Error('Project not initialized. Call init() first.'); - } - - const session = await this.loadSession(this.currentProject.id, sessionId); - if (!session) return null; - - this.currentSession = session; - await this.setCurrentSessionPointer(sessionId); - - return session; - } - - /** - * 列出当前项目的历史会话 - */ - async listSessions(): Promise { - if (!this.currentProject) { - return this.listAllSessions(); - } - - const sessions = await SessionStorage.listByProject(this.currentProject.id); - return sessions.map((s) => ({ - id: s.id, - title: s.title || `会话 ${s.id}`, - workdir: s.workdir, - messageCount: s.stats?.messageCount || 0, - createdAt: new Date(s.createdAt).toISOString(), - updatedAt: new Date(s.updatedAt).toISOString(), - })); - } - - /** - * 列出所有项目的会话 - */ - async listAllSessions(): Promise { - const sessions = await SessionStorage.listAll(); - return sessions.map((s) => ({ - id: s.id, - title: s.title || `会话 ${s.id}`, - workdir: s.workdir, - messageCount: s.stats?.messageCount || 0, - createdAt: new Date(s.createdAt).toISOString(), - updatedAt: new Date(s.updatedAt).toISOString(), - })); - } - - /** - * 删除历史会话 - */ - async deleteSession(sessionId: string): Promise { - if (!this.currentProject) return false; - - try { - // 删除会话的消息和 Parts - const messageInfos = await MessageStorage.listBySession(sessionId); - for (const msg of messageInfos) { - await PartStorage.removeByMessage(msg.id); - } - await MessageStorage.removeBySession(sessionId); - - // 删除 todos - await TodoStorage.removeBySession(sessionId); - - // 删除会话信息 - await SessionStorage.remove(this.currentProject.id, sessionId); - - return true; - } catch { - return false; - } - } - - /** - * 获取当前会话 ID(从存储) - */ - private async getCurrentSessionId(): Promise { - try { - const pointer = await storage.read<{ sessionId: string }>(['current-session']); - return pointer.sessionId; - } catch { - return null; - } - } - - /** - * 设置当前会话指针 - */ - private async setCurrentSessionPointer(sessionId: string): Promise { - await storage.write(['current-session'], { sessionId }); - } - - /** - * 启动自动保存(每 30 秒) + * 启动自动保存 */ private startAutoSave(): void { - if (this.autoSaveInterval) return; - - this.autoSaveInterval = setInterval(async () => { - await this.save(); - }, 30000); + this.autoSave.start(() => this.save()); } /** * 停止自动保存 */ stopAutoSave(): void { - if (this.autoSaveInterval) { - clearInterval(this.autoSaveInterval); - this.autoSaveInterval = null; - } + this.autoSave.stop(); } /** @@ -735,6 +296,10 @@ export class SessionManager { await this.save(); } + // ============================================================================ + // 清理 + // ============================================================================ + /** * 清理旧会话 */ @@ -756,6 +321,29 @@ export class SessionManager { return deletedCount; } + // ============================================================================ + // 辅助方法 + // ============================================================================ + + /** + * 获取当前会话 ID(从存储) + */ + private async getCurrentSessionId(): Promise { + try { + const pointer = await storage.read<{ sessionId: string }>(['current-session']); + return pointer.sessionId; + } catch { + return null; + } + } + + /** + * 设置当前会话指针 + */ + private async setCurrentSessionPointer(sessionId: string): Promise { + await storage.write(['current-session'], { sessionId }); + } + /** * 获取存储目录 */ diff --git a/packages/core/src/session/message-converter.ts b/packages/core/src/session/message-converter.ts new file mode 100644 index 0000000..2dbc390 --- /dev/null +++ b/packages/core/src/session/message-converter.ts @@ -0,0 +1,331 @@ +/** + * 消息格式转换器 + * 负责 Part ↔ ModelMessage 的转换 + */ + +import type { ModelMessage } from 'ai'; +import { MessageStorage, PartStorage } from './storage/index.js'; +import type { Part } from './storage/index.js'; + +/** + * 消息格式转换器 + */ +export class MessageConverter { + /** + * 从存储加载消息并转换为 AI SDK 格式 + */ + async loadFromStorage(sessionId: string): Promise { + const messageInfos = await MessageStorage.listBySession(sessionId); + const messages: ModelMessage[] = []; + + for (const messageInfo of messageInfos) { + const parts = await PartStorage.getByIds(messageInfo.id, messageInfo.partIds); + const modelMessages = this.partsToModelMessages(messageInfo.role, parts); + messages.push(...modelMessages); + } + + return messages; + } + + /** + * 将 Parts 转换为 AI SDK ModelMessage(用于加载历史消息) + * + * 逻辑: + * - user 消息:直接转换 + * - assistant 消息:转换文本和工具调用,然后为已完成的工具生成 tool 消息 + */ + partsToModelMessages(role: string, parts: Part[]): ModelMessage[] { + if (parts.length === 0) return []; + + const result: ModelMessage[] = []; + + if (role === 'user') { + result.push(...this.convertUserParts(parts)); + } else if (role === 'assistant') { + result.push(...this.convertAssistantParts(parts)); + } + + return result; + } + + /** + * 转换用户消息 Parts + */ + private convertUserParts(parts: Part[]): ModelMessage[] { + const content: unknown[] = []; + + for (const part of parts) { + if (part.type === 'text') { + content.push({ type: 'text', text: part.text }); + } else if (part.type === 'file') { + content.push({ + type: 'image', + image: part.data, + mimeType: part.mimeType, + }); + } + } + + if (content.length === 0) return []; + + if (content.length === 1 && (content[0] as { type: string }).type === 'text') { + return [{ + role: 'user', + content: (content[0] as { text: string }).text, + }]; + } + + return [{ + role: 'user', + content, + } as ModelMessage]; + } + + /** + * 转换助手消息 Parts + */ + private convertAssistantParts(parts: Part[]): ModelMessage[] { + const result: ModelMessage[] = []; + const content: unknown[] = []; + const completedTools: Array<{ + toolCallId: string; + toolName: string; + input: unknown; + output: unknown; + }> = []; + + for (const part of parts) { + if (part.type === 'text') { + content.push({ type: 'text', text: part.text }); + } else if (part.type === 'tool') { + // 只有非 pending 状态的工具调用才添加到 AI SDK 消息 + if (part.state.status !== 'pending') { + content.push({ + type: 'tool-call', + toolCallId: part.toolCallId, + toolName: part.toolName, + input: part.state.input, + }); + + // 收集已完成的工具结果 + if (part.state.status === 'completed') { + completedTools.push({ + toolCallId: part.toolCallId, + toolName: part.toolName, + input: part.state.input, + output: part.state.output, + }); + } else if (part.state.status === 'error') { + completedTools.push({ + toolCallId: part.toolCallId, + toolName: part.toolName, + input: part.state.input, + output: part.state.error, + }); + } + } + } else if (part.type === 'reasoning') { + content.push({ type: 'text', text: `[Reasoning] ${part.text}` }); + } + } + + // 添加 assistant 消息 + if (content.length === 1 && (content[0] as { type: string }).type === 'text') { + result.push({ + role: 'assistant', + content: (content[0] as { text: string }).text, + }); + } else if (content.length > 0) { + result.push({ + role: 'assistant', + content, + } as ModelMessage); + } + + // 添加 tool 消息(如果有已完成的工具) + if (completedTools.length > 0) { + result.push({ + role: 'tool', + content: completedTools.map((t) => ({ + type: 'tool-result', + toolCallId: t.toolCallId, + toolName: t.toolName, + input: t.input, + output: t.output, + })), + } as unknown as ModelMessage); + } + + return result; + } + + /** + * 同步消息到存储(将 AI SDK 消息转换为 Message + Parts) + * + * 逻辑:只存储 user 和 assistant 消息 + * - user 消息:直接存储 + * - assistant 消息:合并后续的 tool 消息中的工具结果 + * - tool 消息:跳过(结果合并到 assistant) + */ + async syncToStorage(sessionId: string, messages: ModelMessage[]): Promise { + // 删除旧消息 + await MessageStorage.removeBySession(sessionId); + + // 用于跟踪当前 assistant 消息的工具调用 + let currentAssistantMsgId: string | null = null; + let currentUserMsgId: string | null = null; + const toolCallPartIds = new Map(); // toolCallId -> partId + + for (const message of messages) { + if (message.role === 'user') { + await this.syncUserMessage(sessionId, message); + currentUserMsgId = (await MessageStorage.listBySession(sessionId)).slice(-1)[0]?.id ?? null; + currentAssistantMsgId = null; + toolCallPartIds.clear(); + } else if (message.role === 'assistant') { + const result = await this.syncAssistantMessage( + sessionId, + message, + currentAssistantMsgId, + currentUserMsgId, + toolCallPartIds + ); + currentAssistantMsgId = result.messageId; + result.toolCallPartIds.forEach((v, k) => toolCallPartIds.set(k, v)); + } else if (message.role === 'tool' && currentAssistantMsgId) { + await this.syncToolMessage(currentAssistantMsgId, message, toolCallPartIds); + } + } + } + + /** + * 同步用户消息 + */ + private async syncUserMessage(sessionId: string, message: ModelMessage): Promise { + const messageInfo = await MessageStorage.create(sessionId, 'user'); + const partIds: string[] = []; + + if (typeof message.content === 'string') { + const part = await PartStorage.createText(messageInfo.id, message.content); + partIds.push(part.id); + } else if (Array.isArray(message.content)) { + for (const item of message.content) { + const itemType = (item as { type: string }).type; + if (itemType === 'text') { + const part = await PartStorage.createText(messageInfo.id, (item as { text: string }).text); + partIds.push(part.id); + } else if (itemType === 'image') { + const img = item as unknown as { image: string; mimeType: string }; + const part = await PartStorage.create(messageInfo.id, 'file', { + filename: 'image', + mimeType: img.mimeType, + data: typeof img.image === 'string' ? img.image : '', + }); + partIds.push(part.id); + } + } + } + + if (partIds.length > 0) { + await MessageStorage.update(sessionId, messageInfo.id, { partIds }); + } + } + + /** + * 同步助手消息 + */ + private async syncAssistantMessage( + sessionId: string, + message: ModelMessage, + currentAssistantMsgId: string | null, + currentUserMsgId: string | null, + existingToolCallPartIds: Map + ): Promise<{ messageId: string; toolCallPartIds: Map }> { + let messageId: string; + let existingPartIds: string[] = []; + const newToolCallPartIds = new Map(); + + if (currentAssistantMsgId) { + // 同一轮对话的后续 assistant 消息,追加到现有消息 + messageId = currentAssistantMsgId; + const existingMsg = await MessageStorage.get(sessionId, messageId); + existingPartIds = existingMsg?.partIds ?? []; + } else { + // 新的 assistant 消息 + const messageInfo = await MessageStorage.create(sessionId, 'assistant', { + parentId: currentUserMsgId ?? undefined, + }); + messageId = messageInfo.id; + } + + const newPartIds: string[] = []; + + if (typeof message.content === 'string') { + const part = await PartStorage.createText(messageId, message.content); + newPartIds.push(part.id); + } else if (Array.isArray(message.content)) { + for (const item of message.content) { + const itemType = (item as { type: string }).type; + if (itemType === 'text') { + const part = await PartStorage.createText(messageId, (item as { text: string }).text); + newPartIds.push(part.id); + } else if (itemType === 'tool-call') { + const toolCall = item as unknown as { + toolCallId: string; + toolName: string; + input: Record; + }; + const part = await PartStorage.createToolRunning( + messageId, + toolCall.toolCallId, + toolCall.toolName, + (toolCall.input as Record) ?? {} + ); + newPartIds.push(part.id); + newToolCallPartIds.set(toolCall.toolCallId, part.id); + } + } + } + + if (newPartIds.length > 0) { + const allPartIds = [...existingPartIds, ...newPartIds]; + await MessageStorage.update(sessionId, messageId, { partIds: allPartIds }); + } + + // 合并工具调用 ID + existingToolCallPartIds.forEach((v, k) => newToolCallPartIds.set(k, v)); + + return { messageId, toolCallPartIds: newToolCallPartIds }; + } + + /** + * 同步工具消息(更新工具状态) + */ + private async syncToolMessage( + assistantMsgId: string, + message: ModelMessage, + toolCallPartIds: Map + ): Promise { + if (!Array.isArray(message.content)) return; + + for (const item of message.content) { + const itemType = (item as { type: string }).type; + if (itemType === 'tool-result') { + const toolResult = item as unknown as { + toolCallId: string; + toolName: string; + output: unknown; + }; + const partId = toolCallPartIds.get(toolResult.toolCallId); + if (partId) { + const part = await PartStorage.get(assistantMsgId, partId); + const startTime = + part?.type === 'tool' && part.state.status === 'running' + ? part.state.time.start + : Date.now(); + await PartStorage.setToolCompleted(assistantMsgId, partId, toolResult.output, startTime); + } + } + } + } +} diff --git a/packages/core/src/session/project-manager.ts b/packages/core/src/session/project-manager.ts new file mode 100644 index 0000000..cf787d3 --- /dev/null +++ b/packages/core/src/session/project-manager.ts @@ -0,0 +1,86 @@ +/** + * 项目管理器 + * 负责项目的创建和管理 + */ + +import * as storage from './storage/index.js'; +import { getProjectId, isGitRepository } from './project.js'; + +/** + * 项目元数据 + */ +export interface ProjectMetadata { + id: string; + workdir: string; + createdAt: string; + isGitRepo: boolean; +} + +/** + * 项目管理器 + */ +export class ProjectManager { + private currentProject: ProjectMetadata | null = null; + + /** + * 获取当前项目 + */ + getProject(): ProjectMetadata | null { + return this.currentProject; + } + + /** + * 设置当前项目 + */ + setProject(project: ProjectMetadata | null): void { + this.currentProject = project; + } + + /** + * 获取或创建项目 + */ + async getOrCreate(workdir: string): Promise { + const projectId = await getProjectId(workdir); + + try { + const existing = await storage.read(['project', projectId]); + this.currentProject = existing; + return existing; + } catch (e) { + if (e instanceof storage.StorageNotFoundError) { + const isGitRepo = await isGitRepository(workdir); + const project: ProjectMetadata = { + id: projectId, + workdir, + createdAt: new Date().toISOString(), + isGitRepo, + }; + await storage.write(['project', projectId], project); + this.currentProject = project; + return project; + } + throw e; + } + } + + /** + * 切换项目 + */ + async switchProject(workdir: string): Promise { + return this.getOrCreate(workdir); + } + + /** + * 检查项目是否初始化 + */ + isInitialized(): boolean { + return this.currentProject !== null; + } + + /** + * 获取项目 ID + */ + getProjectId(): string | null { + return this.currentProject?.id ?? null; + } +} diff --git a/packages/core/src/session/session-auto-save.ts b/packages/core/src/session/session-auto-save.ts new file mode 100644 index 0000000..208c006 --- /dev/null +++ b/packages/core/src/session/session-auto-save.ts @@ -0,0 +1,93 @@ +/** + * 会话自动保存 + * 负责定期保存会话 + */ + +/** + * 保存回调类型 + */ +export type SaveCallback = () => Promise; + +/** + * 自动保存配置 + */ +export interface AutoSaveConfig { + /** 保存间隔(毫秒),默认 30000 */ + interval: number; +} + +const DEFAULT_CONFIG: AutoSaveConfig = { + interval: 30000, +}; + +/** + * 会话自动保存管理 + */ +export class SessionAutoSave { + private intervalId: ReturnType | null = null; + private saveCallback: SaveCallback | null = null; + private config: AutoSaveConfig; + + constructor(config?: Partial) { + this.config = { ...DEFAULT_CONFIG, ...config }; + } + + /** + * 启动自动保存 + */ + start(saveCallback: SaveCallback): void { + if (this.intervalId) return; + + this.saveCallback = saveCallback; + this.intervalId = setInterval(async () => { + try { + await this.saveCallback?.(); + } catch (error) { + console.warn('Auto-save failed:', error); + } + }, this.config.interval); + } + + /** + * 停止自动保存 + */ + stop(): void { + if (this.intervalId) { + clearInterval(this.intervalId); + this.intervalId = null; + } + this.saveCallback = null; + } + + /** + * 检查是否正在运行 + */ + isRunning(): boolean { + return this.intervalId !== null; + } + + /** + * 立即触发保存(不影响定时器) + */ + async saveNow(): Promise { + await this.saveCallback?.(); + } + + /** + * 更新配置 + */ + setConfig(config: Partial): void { + const wasRunning = this.isRunning(); + const callback = this.saveCallback; + + if (wasRunning) { + this.stop(); + } + + this.config = { ...this.config, ...config }; + + if (wasRunning && callback) { + this.start(callback); + } + } +} diff --git a/packages/core/src/session/session-store.ts b/packages/core/src/session/session-store.ts new file mode 100644 index 0000000..3429052 --- /dev/null +++ b/packages/core/src/session/session-store.ts @@ -0,0 +1,236 @@ +/** + * 会话存储管理 + * 负责会话的 CRUD 操作 + */ + +import type { ModelMessage } from 'ai'; +import { SessionStorage, MessageStorage, PartStorage, TodoStorage } from './storage/index.js'; +import type { SessionInfo, TodoItem } from './storage/index.js'; +import { MessageConverter } from './message-converter.js'; +import { generateSessionId } from './id.js'; + +/** + * 会话摘要(用于列表展示) + */ +export interface SessionSummary { + id: string; + title: string; + workdir: string; + messageCount: number; + createdAt: string; + updatedAt: string; +} + +/** + * 运行时会话数据 + */ +export interface SessionData { + id: string; + projectId: string; + parentId?: string; + agentName?: string; + createdAt: string; + updatedAt: string; + workdir: string; + title?: string; + messages: ModelMessage[]; + discoveredTools: string[]; + todos: TodoItem[]; +} + +/** + * 会话存储管理 + */ +export class SessionStore { + private messageConverter: MessageConverter; + + constructor() { + this.messageConverter = new MessageConverter(); + } + + /** + * 创建新会话 + */ + async create(projectId: string, workdir: string): Promise { + const sessionInfo = await SessionStorage.create(projectId, workdir); + + return { + id: sessionInfo.id, + projectId: sessionInfo.projectId, + createdAt: new Date(sessionInfo.createdAt).toISOString(), + updatedAt: new Date(sessionInfo.updatedAt).toISOString(), + workdir: sessionInfo.workdir, + title: sessionInfo.title, + messages: [], + discoveredTools: sessionInfo.discoveredTools, + todos: [], + }; + } + + /** + * 加载会话(从存储重建) + */ + async load(projectId: string, sessionId: string): Promise { + const sessionInfo = await SessionStorage.get(projectId, sessionId); + if (!sessionInfo) return null; + + // 加载消息 + const messages = await this.messageConverter.loadFromStorage(sessionId); + + // 加载 todos + const todoList = await TodoStorage.get(sessionId); + + return { + id: sessionInfo.id, + projectId: sessionInfo.projectId, + parentId: sessionInfo.parentId, + agentName: sessionInfo.agentName, + createdAt: new Date(sessionInfo.createdAt).toISOString(), + updatedAt: new Date(sessionInfo.updatedAt).toISOString(), + workdir: sessionInfo.workdir, + title: sessionInfo.title, + messages, + discoveredTools: sessionInfo.discoveredTools, + todos: todoList?.items || [], + }; + } + + /** + * 保存会话信息 + */ + async save(session: SessionData): Promise { + const sessionInfo: SessionInfo = { + id: session.id, + projectId: session.projectId, + parentId: session.parentId, + agentName: session.agentName, + createdAt: new Date(session.createdAt).getTime(), + updatedAt: Date.now(), + workdir: session.workdir, + title: session.title, + discoveredTools: session.discoveredTools, + stats: { + messageCount: session.messages.length, + inputTokens: 0, + outputTokens: 0, + }, + }; + + await SessionStorage.save(sessionInfo); + } + + /** + * 同步消息到存储 + */ + async syncMessages(sessionId: string, messages: ModelMessage[]): Promise { + await this.messageConverter.syncToStorage(sessionId, messages); + } + + /** + * 更新待办事项 + */ + async setTodos( + sessionId: string, + todos: Array<{ content: string; status: 'pending' | 'in_progress' | 'completed' }> + ): Promise { + const todoList = await TodoStorage.replace(sessionId, todos); + return todoList.items; + } + + /** + * 列出项目的所有会话 + */ + async listByProject(projectId: string): Promise { + const sessions = await SessionStorage.listByProject(projectId); + return this.toSummaries(sessions); + } + + /** + * 列出所有会话 + */ + async listAll(): Promise { + const sessions = await SessionStorage.listAll(); + return this.toSummaries(sessions); + } + + /** + * 删除会话 + */ + async delete(projectId: string, sessionId: string): Promise { + try { + // 删除会话的消息和 Parts + const messageInfos = await MessageStorage.listBySession(sessionId); + for (const msg of messageInfos) { + await PartStorage.removeByMessage(msg.id); + } + await MessageStorage.removeBySession(sessionId); + + // 删除 todos + await TodoStorage.removeBySession(sessionId); + + // 删除会话信息 + await SessionStorage.remove(projectId, sessionId); + + return true; + } catch { + return false; + } + } + + /** + * 创建子会话(用于 Task 工具) + */ + createChildSession( + projectId: string, + parentId: string, + agentName: string, + workdir: string, + title?: string + ): SessionData { + return { + id: generateSessionId(), + projectId, + parentId, + agentName, + createdAt: new Date().toISOString(), + updatedAt: new Date().toISOString(), + workdir, + title: title || `子任务 (@${agentName})`, + messages: [], + discoveredTools: [], + todos: [], + }; + } + + /** + * 保存子会话 + */ + async saveChildSession(session: SessionData): Promise { + const sessionInfo: SessionInfo = { + id: session.id, + projectId: session.projectId, + parentId: session.parentId, + agentName: session.agentName, + createdAt: new Date(session.createdAt).getTime(), + updatedAt: Date.now(), + workdir: session.workdir, + title: session.title, + discoveredTools: session.discoveredTools, + }; + await SessionStorage.save(sessionInfo); + } + + /** + * 转换为摘要列表 + */ + private toSummaries(sessions: SessionInfo[]): SessionSummary[] { + return sessions.map((s) => ({ + id: s.id, + title: s.title || `会话 ${s.id}`, + workdir: s.workdir, + messageCount: s.stats?.messageCount || 0, + createdAt: new Date(s.createdAt).toISOString(), + updatedAt: new Date(s.updatedAt).toISOString(), + })); + } +}