diff --git a/packages/core/src/agent/executor.ts b/packages/core/src/agent/executor.ts index 706eeec..42cfbe0 100644 --- a/packages/core/src/agent/executor.ts +++ b/packages/core/src/agent/executor.ts @@ -16,7 +16,7 @@ import type { ImageData, } from './types.js'; import { checkBashPermission, isPathInAllowedWritePaths } from './permission-merger.js'; -import { getProviderRegistry } from '../provider/index.js'; +import { getProviderRegistry, resolveApiKey } from '../provider/index.js'; import { renderPromptTemplate, createPlanContext } from '../template/index.js'; import { agentEventEmitter } from './events.js'; @@ -42,10 +42,20 @@ export class AgentExecutor { // 使用 ProviderRegistry 获取模型工厂 const provider = agentInfo.model?.provider ?? baseConfig.provider; const registry = getProviderRegistry(); - this.getModel = registry.getModelFactory(provider, { - apiKey: baseConfig.apiKey, - baseUrl: baseConfig.baseUrl, - }); + + // 当 Agent 指定了不同的 provider 时,需要从 ProviderRegistry 获取对应的配置 + // 而不是使用 baseConfig(主 Agent 配置)的 apiKey 和 baseUrl + let apiKey = baseConfig.apiKey; + let baseUrl = baseConfig.baseUrl; + + if (agentInfo.model?.provider && agentInfo.model.provider !== baseConfig.provider) { + // Agent 使用了不同的 provider,获取对应 provider 的配置 + const providerConfig = registry.getConfig(provider); + apiKey = resolveApiKey(providerConfig) || baseConfig.apiKey; + baseUrl = providerConfig?.baseUrl; + } + + this.getModel = registry.getModelFactory(provider, { apiKey, baseUrl }); } /** diff --git a/packages/core/tests/unit/agent/executor-extended.test.ts b/packages/core/tests/unit/agent/executor-extended.test.ts index 9b6cc4d..4d08cb9 100644 --- a/packages/core/tests/unit/agent/executor-extended.test.ts +++ b/packages/core/tests/unit/agent/executor-extended.test.ts @@ -37,10 +37,13 @@ vi.mock('../../../src/agent/permission-merger.js', () => ({ // Mock provider registry const mockGetModelFactory = vi.fn(); +const mockGetConfig = vi.fn(); vi.mock('../../../src/provider/index.js', () => ({ getProviderRegistry: () => ({ getModelFactory: (...args: unknown[]) => mockGetModelFactory(...args), + getConfig: (...args: unknown[]) => mockGetConfig(...args), }), + resolveApiKey: (config: { apiKey?: string } | undefined) => config?.apiKey, })); import { AgentExecutor } from '../../../src/agent/executor.js'; @@ -113,6 +116,50 @@ describe('AgentExecutor - Agent 执行器扩展测试', () => { expect(mockGetModelFactory).toHaveBeenCalledWith('openai', expect.any(Object)); }); + + it('当 Agent 使用不同 provider 时,获取对应 provider 的配置', () => { + // 设置 openai provider 的配置 + mockGetConfig.mockReturnValue({ + apiKey: 'openai-api-key', + baseUrl: 'https://api.openai.com/v1', + }); + + const agentWithDifferentProvider: AgentInfo = { + ...basicAgentInfo, + model: { provider: 'openai', model: 'gpt-4' }, + }; + + new AgentExecutor(agentWithDifferentProvider, baseConfig, mockToolRegistry as any); + + // 应该调用 getConfig 获取 openai 的配置 + expect(mockGetConfig).toHaveBeenCalledWith('openai'); + + // 应该使用 openai 的配置,而不是 baseConfig 的 + expect(mockGetModelFactory).toHaveBeenCalledWith('openai', { + apiKey: 'openai-api-key', + baseUrl: 'https://api.openai.com/v1', + }); + }); + + it('当 Agent 使用相同 provider 时,使用 baseConfig 的配置', () => { + mockGetConfig.mockClear(); + + const agentWithSameProvider: AgentInfo = { + ...basicAgentInfo, + model: { provider: 'anthropic', model: 'claude-opus' }, + }; + + new AgentExecutor(agentWithSameProvider, baseConfig, mockToolRegistry as any); + + // 不应该调用 getConfig,因为 provider 相同 + expect(mockGetConfig).not.toHaveBeenCalled(); + + // 应该使用 baseConfig 的配置 + expect(mockGetModelFactory).toHaveBeenCalledWith('anthropic', { + apiKey: 'test-key', + baseUrl: undefined, + }); + }); }); describe('execute - 执行任务', () => {