import { describe, it, expect, beforeEach, vi } from 'vitest';

/**
 * Builds a fresh fake `streamText` result.
 * The async generator can only be iterated once, so every test needs its own instance.
 */
function createStreamTextResult() {
  return {
    textStream: (async function* () {
      yield '流式';
      yield '输出';
    })(),
    response: Promise.resolve({}),
  };
}

// Mutable state read lazily by the `ai` mock below; tests overwrite it per case.
const mockState = {
  generateTextResult: {
    text: '任务完成',
    steps: [{ toolCalls: [] }],
  },
  streamTextResult: createStreamTextResult(),
};

// Mock @ai-sdk/anthropic: the provider factory returns a model factory.
vi.mock('@ai-sdk/anthropic', () => ({
  createAnthropic: vi.fn(() => vi.fn(() => ({ modelId: 'claude-3' }))),
}));

// Mock @ai-sdk/deepseek
vi.mock('@ai-sdk/deepseek', () => ({
  createDeepSeek: vi.fn(() => vi.fn(() => ({ modelId: 'deepseek' }))),
}));

// Mock the `ai` package; results are read from mockState at call time.
vi.mock('ai', () => ({
  generateText: vi.fn(async () => mockState.generateTextResult),
  streamText: vi.fn(() => mockState.streamTextResult),
  stepCountIs: vi.fn(() => () => false),
}));

// Mock permission-merger; bash commands are allowed unless a test overrides it.
vi.mock('../../../src/agent/permission-merger.js', () => ({
  checkBashPermission: vi.fn(() => 'allow'),
}));

// Mock types: schema building is irrelevant to these tests.
vi.mock('../../../src/types/index.js', () => ({
  buildZodSchema: vi.fn(() => ({})),
}));

import { AgentExecutor } from '../../../src/agent/executor.js';
import { generateText, streamText } from 'ai';
import { checkBashPermission } from '../../../src/agent/permission-merger.js';

/** Result shape asserted on when invoking wrapped tools directly. */
interface ToolResult {
  success: boolean;
  output?: string;
  error?: string;
}

/**
 * Returns the options object passed to the first `generateText` call.
 * @throws Error if `generateText` was never called, so failures are explicit.
 */
function firstGenerateTextCall() {
  const call = vi.mocked(generateText).mock.calls[0]?.[0];
  if (!call) {
    throw new Error('generateText was not called');
  }
  return call;
}

/**
 * Invokes the named tool that the executor handed to the first `generateText` call.
 * Fails the test if the tool is missing or not executable, rather than silently
 * skipping the assertions that follow.
 * @param name - tool name as registered in the tool set.
 * @param input - arguments forwarded to the tool's `execute`.
 */
async function runTool(name: string, input: Record<string, unknown>): Promise<ToolResult> {
  const tool = firstGenerateTextCall().tools?.[name];
  if (!tool || !('execute' in tool) || typeof tool.execute !== 'function') {
    throw new Error(`tool "${name}" was not passed to generateText with an execute function`);
  }
  return (await tool.execute(input)) as ToolResult;
}

describe('AgentExecutor - Agent 执行器', () => {
  let executor: AgentExecutor;
  let mockToolRegistry: any;
  let mockAgentInfo: any;
  let mockBaseConfig: any;

  beforeEach(() => {
    // clearAllMocks only wipes call history; implementations installed with
    // mockReturnValue survive it, so restore the default permission verdict.
    vi.clearAllMocks();
    vi.mocked(checkBashPermission).mockReturnValue('allow');

    mockToolRegistry = {
      getAllTools: vi.fn(() => [
        {
          name: 'bash',
          description: '执行命令',
          parameters: { command: { type: 'string', required: true } },
          execute: vi.fn().mockResolvedValue({ success: true, output: 'ok' }),
        },
        {
          name: 'read_file',
          description: '读取文件',
          parameters: { path: { type: 'string', required: true } },
          execute: vi.fn().mockResolvedValue({ success: true, output: 'content' }),
        },
        {
          name: 'task',
          description: '子任务',
          parameters: { prompt: { type: 'string', required: true } },
          execute: vi.fn().mockResolvedValue({ success: true, output: 'done' }),
        },
      ]),
    };

    mockAgentInfo = {
      name: 'test-agent',
      description: '测试 Agent',
      mode: 'subagent',
      prompt: '你是测试助手',
    };

    mockBaseConfig = {
      provider: 'anthropic',
      model: 'claude-3-sonnet',
      apiKey: 'test-api-key',
      maxTokens: 4096,
      systemPrompt: '默认系统提示词',
    };

    // Reset mock results; the stream generator is single-use, so rebuild it.
    mockState.generateTextResult = {
      text: '任务完成',
      steps: [{ toolCalls: [] }],
    };
    mockState.streamTextResult = createStreamTextResult();

    executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
  });

  describe('构造函数', () => {
    it('成功创建 Anthropic provider', () => {
      const exec = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      expect(exec).toBeDefined();
    });

    it('成功创建 DeepSeek provider', () => {
      const config = { ...mockBaseConfig, provider: 'deepseek' as const };
      const exec = new AgentExecutor(mockAgentInfo, config, mockToolRegistry);
      expect(exec).toBeDefined();
    });

    it('不支持的 provider 抛出错误', () => {
      const config = { ...mockBaseConfig, provider: 'unknown' as any };
      expect(() => new AgentExecutor(mockAgentInfo, config, mockToolRegistry)).toThrow('不支持的 provider');
    });

    it('使用 Agent 指定的 provider', () => {
      const agentInfo = {
        ...mockAgentInfo,
        model: { provider: 'deepseek' as const, model: 'deepseek-chat' },
      };
      const exec = new AgentExecutor(agentInfo, mockBaseConfig, mockToolRegistry);
      expect(exec).toBeDefined();
    });
  });

  describe('execute - 执行', () => {
    it('非流式模式成功执行', async () => {
      const result = await executor.execute('测试任务', {});
      expect(result.success).toBe(true);
      expect(result.text).toBe('任务完成');
      expect(generateText).toHaveBeenCalled();
    });

    it('流式模式成功执行', async () => {
      const onStream = vi.fn();
      // beforeEach installed a fresh stream generator for this test.
      const result = await executor.execute('测试任务', { onStream });
      expect(result.success).toBe(true);
      expect(result.text).toBe('流式输出');
      expect(streamText).toHaveBeenCalled();
      expect(onStream).toHaveBeenCalledWith('流式');
      expect(onStream).toHaveBeenCalledWith('输出');
    });

    it('执行失败返回错误', async () => {
      vi.mocked(generateText).mockRejectedValueOnce(new Error('API 错误'));
      const result = await executor.execute('测试任务', {});
      expect(result.success).toBe(false);
      expect(result.error).toContain('API 错误');
    });

    it('传递父会话 ID', async () => {
      const result = await executor.execute('测试任务', {
        parentSessionId: 'parent-123',
      });
      expect(result.sessionId).toBe('parent-123');
    });

    it('无父会话 ID 使用 standalone', async () => {
      const result = await executor.execute('测试任务', {});
      expect(result.sessionId).toBe('standalone');
    });
  });

  describe('工具过滤', () => {
    it('无配置返回所有工具', async () => {
      await executor.execute('测试', {});
      const call = firstGenerateTextCall();
      expect(Object.keys(call.tools ?? {})).toHaveLength(3);
    });

    it('enabled 配置只保留指定工具', async () => {
      mockAgentInfo.tools = { enabled: ['bash'] };
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});
      const toolNames = Object.keys(firstGenerateTextCall().tools ?? {});
      expect(toolNames).toContain('bash');
      expect(toolNames).not.toContain('read_file');
    });

    it('disabled 配置移除指定工具', async () => {
      mockAgentInfo.tools = { disabled: ['task'] };
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});
      expect(Object.keys(firstGenerateTextCall().tools ?? {})).not.toContain('task');
    });

    it('noTask 配置移除 task 工具', async () => {
      mockAgentInfo.tools = { noTask: true };
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});
      expect(Object.keys(firstGenerateTextCall().tools ?? {})).not.toContain('task');
    });
  });

  describe('权限检查', () => {
    it('bash 命令被拒绝', async () => {
      // Overrides the default; beforeEach restores 'allow' for later tests.
      vi.mocked(checkBashPermission).mockReturnValue('deny');
      mockAgentInfo.permission = { bash: { deny: ['rm *'] } };
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});

      // Invoke the wrapped tool directly to exercise the permission check.
      const result = await runTool('bash', { command: 'rm -rf /' });
      expect(result.success).toBe(false);
      expect(result.error).toContain('权限拒绝');
    });

    it('文件写入被拒绝', async () => {
      mockAgentInfo.permission = { file: { write: 'deny' } };
      // Register a write_file tool for this case only.
      mockToolRegistry.getAllTools.mockReturnValue([
        {
          name: 'write_file',
          description: '写文件',
          parameters: { path: { type: 'string', required: true } },
          execute: vi.fn().mockResolvedValue({ success: true, output: 'ok' }),
        },
      ]);
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});

      const result = await runTool('write_file', { path: '/test.txt', content: 'test' });
      expect(result.success).toBe(false);
      expect(result.error).toContain('权限拒绝');
    });

    it('Git 写操作被拒绝', async () => {
      mockAgentInfo.permission = { git: { write: 'deny' } };
      mockToolRegistry.getAllTools.mockReturnValue([
        {
          name: 'git_push',
          description: 'Git push',
          parameters: {},
          execute: vi.fn().mockResolvedValue({ success: true, output: 'ok' }),
        },
      ]);
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});

      const result = await runTool('git_push', {});
      expect(result.success).toBe(false);
      expect(result.error).toContain('Git 写操作被禁止');
    });

    it('无权限配置允许所有操作', async () => {
      // No permission configuration at all.
      delete mockAgentInfo.permission;
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});

      const result = await runTool('bash', { command: 'ls' });
      expect(result.success).toBe(true);
    });
  });

  describe('系统提示词', () => {
    it('使用 Agent 自定义提示词', async () => {
      mockAgentInfo.prompt = '自定义提示词';
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});
      expect(firstGenerateTextCall().system).toBe('自定义提示词');
    });

    it('无自定义提示词使用基础配置', async () => {
      delete mockAgentInfo.prompt;
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});
      expect(firstGenerateTextCall().system).toBe('默认系统提示词');
    });
  });

  describe('模型配置', () => {
    // NOTE(review): the next two cases only prove generateText runs; they do not
    // verify that the agent's model / maxSteps override actually reaches the SDK.
    it('使用 Agent 指定的模型', async () => {
      mockAgentInfo.model = { model: 'claude-3-opus' };
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});
      expect(generateText).toHaveBeenCalled();
    });

    it('使用 Agent 指定的 maxSteps', async () => {
      mockAgentInfo.maxSteps = 5;
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});
      expect(generateText).toHaveBeenCalled();
    });

    it('使用 Agent 指定的 maxTokens', async () => {
      mockAgentInfo.model = { maxTokens: 8192 };
      executor = new AgentExecutor(mockAgentInfo, mockBaseConfig, mockToolRegistry);
      await executor.execute('测试', {});
      expect(firstGenerateTextCall().maxOutputTokens).toBe(8192);
    });
  });
});