Files
ai-terminal-assistant/src/core/agent.ts
T
kurihada 32fdb244f0 feat: 添加 OpenAI 兼容 API 支持和独立 Vision 服务
- 添加 OpenAI AI SDK provider 支持 (@ai-sdk/openai)
- 支持 OpenAI 兼容服务的 baseUrl 配置(如阿里云百炼)
- 添加独立的 Vision 配置(visionProvider/visionApiKey/visionBaseUrl/visionModel)
- 实现图片引用语法 @path/to/image.png,支持带空格的路径
- 当主模型不支持 vision 时,自动调用配置的 Vision 服务分析图片
- 添加图片处理工具函数和单元测试
2025-12-11 17:49:16 +08:00

520 lines
15 KiB
TypeScript

import { createAnthropic } from '@ai-sdk/anthropic';
import { createDeepSeek } from '@ai-sdk/deepseek';
import { createOpenAI } from '@ai-sdk/openai';
import {
generateText,
streamText,
stepCountIs,
type ModelMessage,
type Tool as AITool,
type LanguageModel,
} from 'ai';
import type { Tool, ToolResult, Message, AgentConfig, ProviderType, UserInput, ContentBlock } from '../types/index.js';
import { buildZodSchema } from '../types/index.js';
import { ToolRegistry } from '../tools/registry.js';
import { SessionManager } from '../session/index.js';
import {
CompressionManager,
type TokenUsage,
type CompressionConfig,
} from '../context/index.js';
import type { AgentInfo } from '../agent/types.js';
/** Credentials and endpoint handed to a provider factory. */
interface ProviderOptions {
  apiKey: string;
  /** Optional custom endpoint (a proxy or an OpenAI-compatible service). */
  baseUrl?: string;
}

/** Builds a provider client from options and returns a resolver from model id to `LanguageModel`. */
type ProviderFactory = (options: ProviderOptions) => (model: string) => LanguageModel;

/** Registry of supported providers, keyed by `ProviderType`. */
const providers: Record<ProviderType, ProviderFactory> = {
  anthropic: ({ apiKey, baseUrl }) => {
    const client = createAnthropic({ apiKey, baseURL: baseUrl });
    return (model) => client(model);
  },
  deepseek: ({ apiKey, baseUrl }) => {
    const client = createDeepSeek({ apiKey, baseURL: baseUrl });
    return (model) => client(model);
  },
  openai: ({ apiKey, baseUrl }) => {
    const client = createOpenAI({ apiKey, baseURL: baseUrl });
    // With AI SDK 5 (this file uses stepCountIs / ModelMessage), calling the
    // provider directly, `client(model)`, targets OpenAI's Responses API
    // (`/responses`). OpenAI-compatible services reached through a custom
    // baseUrl (e.g. Alibaba Cloud Bailian) generally only implement
    // `/chat/completions`, so use the Chat Completions model for them.
    return baseUrl ? (model) => client.chat(model) : (model) => client(model);
  },
};
/**
 * Conversational agent built on the Vercel AI SDK.
 *
 * Keeps the model-level conversation history (including tool calls and tool
 * results) and exposes tools to the model. Tools come either from a fixed
 * legacy set or from a `ToolRegistry` (core tools plus tools discovered through
 * `tool_search`). The agent optionally persists state through a
 * `SessionManager` and auto-compresses history through a `CompressionManager`.
 */
export class Agent {
  /** Maximum number of model steps (tool-call rounds) per `chat` call. */
  private static readonly MAX_TOOL_STEPS = 10;
  /** Tool output longer than this many characters is truncated in stream display. */
  private static readonly TOOL_OUTPUT_PREVIEW_CHARS = 500;

  /** Resolves a model id to a `LanguageModel` for the configured provider. */
  private readonly getModel: (model: string) => LanguageModel;
  /** Active configuration; `systemPrompt` is swapped by `setAgentMode`. */
  private config: AgentConfig;
  /** Full model-level history, including tool calls and tool results. */
  private conversationHistory: ModelMessage[] = [];
  /** Tool registry (new mode: supports dynamic tool discovery). */
  private registry: ToolRegistry | null = null;
  /** Names of tools discovered through `tool_search`. */
  private discoveredTools: Set<string> = new Set();
  /** Legacy mode: directly registered tools, all of which are always available. */
  private readonly legacyTools: Map<string, Tool> = new Map();
  /** Optional session manager; when set, state is persisted after each turn. */
  private sessionManager: SessionManager | null = null;
  /** Context compression manager. */
  private readonly compressionManager: CompressionManager;
  /** Current agent mode (`null` means the default mode). */
  private currentAgentMode: AgentInfo | null = null;
  /** System prompt from the constructor config, restored when switching back to default. */
  private readonly originalSystemPrompt: string;

  /**
   * @param config Agent configuration (provider, credentials, model, prompt, limits).
   * @param compressionConfig Optional overrides for the compression manager.
   * @throws Error when `config.provider` is not a registered provider.
   */
  constructor(config: AgentConfig, compressionConfig?: Partial<CompressionConfig>) {
    this.config = config;
    this.originalSystemPrompt = config.systemPrompt;

    // Own-property check so that values such as "constructor" or "toString"
    // (e.g. from a user config file) do not resolve to Object.prototype members.
    const providerFactory = Object.prototype.hasOwnProperty.call(providers, config.provider)
      ? providers[config.provider]
      : undefined;
    if (!providerFactory) {
      throw new Error(`不支持的 provider: ${config.provider}`);
    }
    this.getModel = providerFactory({ apiKey: config.apiKey, baseUrl: config.baseUrl });

    this.compressionManager = new CompressionManager(compressionConfig);
    // The compression manager uses the same model to generate summaries.
    this.compressionManager.setModel(this.getModel(config.model));
  }

  /**
   * Sets the tool registry (new mode: supports dynamic tool discovery).
   */
  setRegistry(registry: ToolRegistry): void {
    this.registry = registry;
  }

  /**
   * Sets the session manager, enabling session persistence. If the manager
   * already holds a session, its messages and discovered tools are restored.
   */
  setSessionManager(manager: SessionManager): void {
    this.sessionManager = manager;
    const session = manager.getSession();
    if (session) {
      this.conversationHistory = [...session.messages];
      this.discoveredTools = new Set(session.discoveredTools);
    }
  }

  /**
   * Returns the session manager, or `null` if none is set.
   */
  getSessionManager(): SessionManager | null {
    return this.sessionManager;
  }

  /**
   * Registers a single tool (legacy mode, kept for backward compatibility).
   * A tool with the same name replaces the earlier one.
   */
  registerTool(customTool: Tool): void {
    this.legacyTools.set(customTool.name, customTool);
  }

  /**
   * Registers several tools at once (legacy mode, kept for backward compatibility).
   */
  registerTools(tools: Tool[]): void {
    for (const tool of tools) {
      this.legacyTools.set(tool.name, tool);
    }
  }

  /**
   * Returns the tools currently available to the model.
   * - Registry mode: core tools plus discovered tools, de-duplicated by name.
   * - Legacy mode: all registered tools.
   * - If an agent mode with a tool configuration is active, its filter is applied.
   */
  private getAvailableTools(): Tool[] {
    let tools: Tool[];

    if (this.registry) {
      const coreTools = this.registry.getCoreTools();
      const discoveredTools = this.registry.getTools([...this.discoveredTools]);
      // A core tool may also appear in tool_search results (and thus in
      // discoveredTools); keep one entry per name. Map keeps first-insertion
      // order with the last value, like the name-keyed record built later.
      const byName = new Map<string, Tool>();
      for (const tool of [...coreTools, ...discoveredTools]) {
        byName.set(tool.name, tool);
      }
      tools = [...byName.values()];
    } else {
      tools = [...this.legacyTools.values()];
    }

    if (this.currentAgentMode?.tools) {
      tools = this.filterToolsByAgentConfig(tools);
    }
    return tools;
  }

  /**
   * Filters tools according to the current agent mode's tool configuration.
   * The `enabled` allow-list is applied first, then the `disabled` deny-list,
   * then `noTask` removes the `task` tool (no nested tasks).
   */
  private filterToolsByAgentConfig(tools: Tool[]): Tool[] {
    const toolConfig = this.currentAgentMode?.tools;
    if (!toolConfig) return tools;

    let filteredTools = tools;

    if (toolConfig.enabled && toolConfig.enabled.length > 0) {
      const enabledSet = new Set(toolConfig.enabled);
      filteredTools = filteredTools.filter((t) => enabledSet.has(t.name));
    }

    if (toolConfig.disabled && toolConfig.disabled.length > 0) {
      const disabledSet = new Set(toolConfig.disabled);
      filteredTools = filteredTools.filter((t) => !disabledSet.has(t.name));
    }

    if (toolConfig.noTask) {
      filteredTools = filteredTools.filter((t) => t.name !== 'task');
    }

    return filteredTools;
  }

  /**
   * Converts the available tools to the Vercel AI SDK tool format.
   *
   * A successful `tool_search` call records the tools it reports as discovered.
   * This record is built once per `chat` call, so newly discovered tools become
   * available to the model on the next `chat` call.
   */
  private getVercelTools(): Record<string, AITool> {
    const vercelTools: Record<string, AITool> = {};

    for (const tool of this.getAvailableTools()) {
      vercelTools[tool.name] = {
        description: tool.description,
        inputSchema: buildZodSchema(tool.parameters),
        execute: async (params) => {
          const result = await tool.execute(params as Record<string, unknown>);
          if (tool.name === 'tool_search' && result.success) {
            this.handleToolSearchResult(result.output);
          }
          return result;
        },
      } as AITool;
    }

    return vercelTools;
  }

  /**
   * Parses `tool_search` output and marks every listed tool that exists in the
   * registry as discovered. Expected line format: "- tool_name: description [category]".
   * Names may contain letters, digits, `_` and `-`.
   */
  private handleToolSearchResult(output: string): void {
    for (const match of output.matchAll(/^- ([\w-]+):/gm)) {
      const toolName = match[1];
      if (toolName && this.registry?.has(toolName)) {
        this.discoveredTools.add(toolName);
      }
    }
  }

  /**
   * Sends a message and processes the model's response, including up to
   * `MAX_TOOL_STEPS` rounds of tool calls.
   *
   * On success, the user message and every response message (tool calls and
   * results included) are appended to the history. The history may then be
   * auto-compressed, and the session is persisted. If the model request fails,
   * the history is restored to its state before the call and the error is rethrown.
   *
   * @param userMessage Plain text, or a `UserInput` carrying text and/or images.
   * @param onStream If provided, the response is streamed. Text deltas, tool
   *   activity and compression notices are reported through this callback.
   * @returns The assistant's final text.
   */
  async chat(userMessage: string | UserInput, onStream?: (text: string) => void): Promise<string> {
    const historyLengthBefore = this.conversationHistory.length;
    this.conversationHistory.push({
      role: 'user',
      content: this.buildUserContent(userMessage),
    } as ModelMessage);

    const tools = this.getVercelTools();

    let turn: { text: string; messages: ModelMessage[] };
    try {
      turn = onStream
        ? await this.runStreaming(tools, onStream)
        : await this.runGenerate(tools);
    } catch (error) {
      // Drop the user message again, so a failed request does not leave a
      // dangling user turn that the next call would follow with another one.
      this.conversationHistory = this.conversationHistory.slice(0, historyLengthBefore);
      throw error;
    }

    this.conversationHistory.push(...turn.messages);
    await this.autoCompress(onStream);
    await this.persistSession();
    return turn.text;
  }

  /**
   * Converts chat input into message content. Images come first, then text.
   * A lone text block collapses to a plain string.
   */
  private buildUserContent(userMessage: string | UserInput): string | ContentBlock[] {
    if (typeof userMessage === 'string') {
      return userMessage;
    }

    const blocks: ContentBlock[] = [];
    for (const img of userMessage.images ?? []) {
      blocks.push({ type: 'image', image: img.data, mimeType: img.mimeType });
    }
    if (userMessage.text) {
      blocks.push({ type: 'text', text: userMessage.text });
    }

    const [first] = blocks;
    if (blocks.length === 1 && first?.type === 'text') {
      return first.text;
    }
    return blocks;
  }

  /**
   * Runs one streaming turn. Text deltas and tool activity go to `onStream`.
   * Returns the full text and all response messages.
   */
  private async runStreaming(
    tools: Record<string, AITool>,
    onStream: (text: string) => void
  ): Promise<{ text: string; messages: ModelMessage[] }> {
    const result = streamText({
      model: this.getModel(this.config.model),
      system: this.config.systemPrompt,
      messages: this.conversationHistory,
      tools,
      maxOutputTokens: this.config.maxTokens,
      stopWhen: stepCountIs(Agent.MAX_TOOL_STEPS),
      onChunk: ({ chunk }) => {
        if (chunk.type === 'tool-call') {
          onStream(`\n[调用工具: ${chunk.toolName}]\n`);
        } else if (chunk.type === 'tool-result') {
          const line = Agent.formatToolResult((chunk as { output?: ToolResult }).output);
          if (line !== null) onStream(line);
        }
      },
    });

    let text = '';
    for await (const delta of result.textStream) {
      text += delta;
      onStream(delta);
    }

    // Resolves once the whole multi-step run has finished; the messages
    // include the tool calls and tool results.
    const response = await result.response;
    return { text, messages: response.messages as ModelMessage[] };
  }

  /**
   * Runs one non-streaming turn and returns the full text and all response messages.
   */
  private async runGenerate(
    tools: Record<string, AITool>
  ): Promise<{ text: string; messages: ModelMessage[] }> {
    const result = await generateText({
      model: this.getModel(this.config.model),
      system: this.config.systemPrompt,
      messages: this.conversationHistory,
      tools,
      maxOutputTokens: this.config.maxTokens,
      stopWhen: stepCountIs(Agent.MAX_TOOL_STEPS),
    });
    return { text: result.text, messages: result.response.messages as ModelMessage[] };
  }

  /**
   * Formats a tool-result chunk for stream display, truncating long output.
   * Returns `null` when the chunk carries no result object.
   */
  private static formatToolResult(output: ToolResult | undefined): string | null {
    if (!output || typeof output !== 'object') return null;
    if (!output.success) return `[错误: ${output.error}]\n`;

    const raw = output.output ?? '';
    const limit = Agent.TOOL_OUTPUT_PREVIEW_CHARS;
    const shown = raw.length > limit ? raw.substring(0, limit) + '...(截断)' : raw;
    return `[结果: ${shown}]\n`;
  }

  /**
   * Compresses the history when the compression manager says it is needed.
   * The compressed history is adopted only if it actually frees tokens.
   */
  private async autoCompress(onStream?: (text: string) => void): Promise<void> {
    if (!this.compressionManager.shouldCompress(this.conversationHistory)) return;

    const result = await this.compressionManager.compress(this.conversationHistory);
    if (result.freedTokens > 0) {
      this.conversationHistory = result.messages;
      onStream?.(`\n[自动压缩: 释放了 ${(result.freedTokens / 1000).toFixed(1)}k tokens]\n`);
    }
  }

  /**
   * Persists the current messages and discovered tools (no-op without a session manager).
   */
  private async persistSession(): Promise<void> {
    if (!this.sessionManager) return;
    await this.sessionManager.setMessages(this.conversationHistory);
    await this.sessionManager.setDiscoveredTools([...this.discoveredTools]);
  }

  /**
   * Clears the history and discovered tools, and starts a new session if a
   * session manager is set.
   */
  async clearHistory(): Promise<void> {
    this.conversationHistory = [];
    this.discoveredTools.clear();
    if (this.sessionManager) {
      await this.sessionManager.newSession();
    }
  }

  /**
   * Returns the user/assistant messages of the history. Non-string content
   * (content blocks, tool calls) is JSON-serialized.
   */
  getHistory(): Message[] {
    return this.conversationHistory
      .filter(
        (msg): msg is ModelMessage & { role: 'user' | 'assistant' } =>
          msg.role === 'user' || msg.role === 'assistant'
      )
      .map((msg) => ({
        role: msg.role,
        content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content),
      }));
  }

  /**
   * Returns tool counts. In registry mode, discovered tools that are also core
   * tools are counted only once, under `core`.
   */
  getToolCount(): { core: number; discovered: number; total: number } {
    if (!this.registry) {
      return {
        core: this.legacyTools.size,
        discovered: 0,
        total: this.legacyTools.size,
      };
    }

    const coreNames = new Set(this.registry.getCoreTools().map((t) => t.name));
    let discovered = 0;
    for (const name of this.discoveredTools) {
      if (!coreNames.has(name)) discovered++;
    }
    return {
      core: coreNames.size,
      discovered,
      total: coreNames.size + discovered,
    };
  }

  /**
   * Returns the current context (token) usage of the history.
   */
  getContextUsage(): TokenUsage {
    return this.compressionManager.calculateUsage(this.conversationHistory);
  }

  /**
   * Returns the context usage formatted for CLI display.
   */
  getContextUsageFormatted(): string {
    return this.compressionManager.formatUsage(this.conversationHistory);
  }

  /**
   * Returns the compression manager.
   */
  getCompressionManager(): CompressionManager {
    return this.compressionManager;
  }

  /**
   * Manually compresses the history (used by the /compact command). The
   * result is adopted and persisted only if it frees tokens.
   */
  async compactHistory(): Promise<{ freedTokens: number; type: string }> {
    const result = await this.compressionManager.forceCompress(this.conversationHistory);
    if (result.freedTokens > 0) {
      this.conversationHistory = result.messages;
      await this.persistSession();
    }
    return {
      freedTokens: result.freedTokens,
      type: result.type,
    };
  }

  /**
   * Switches the agent mode. An agent with a (non-empty) prompt uses it as the
   * system prompt. `null`, or an agent without a prompt, restores the original
   * system prompt.
   */
  setAgentMode(agent: AgentInfo | null): void {
    this.currentAgentMode = agent;
    this.config = {
      ...this.config,
      systemPrompt: agent?.prompt ? agent.prompt : this.originalSystemPrompt,
    };
  }

  /**
   * Returns the current agent mode (`null` for default).
   */
  getAgentMode(): AgentInfo | null {
    return this.currentAgentMode;
  }

  /**
   * Returns the current agent mode name, or 'default'.
   */
  getAgentModeName(): string {
    return this.currentAgentMode?.name ?? 'default';
  }

  /**
   * Heuristically checks whether the configured model accepts image input.
   * - Anthropic: Claude 3 and later. Matches both `claude-3-5-sonnet-…`
   *   (version first) and `claude-sonnet-4-…` / `claude-opus-4-1-…` (family first).
   * - OpenAI: GPT-4 family (GPT-4o, GPT-4 Turbo, GPT-4 Vision, GPT-4.1) and GPT-5.
   * - DeepSeek and anything else: no vision.
   */
  supportsVision(): boolean {
    const model = this.config.model.toLowerCase();

    switch (this.config.provider) {
      case 'anthropic': {
        const match = /claude-(?:(?:opus|sonnet|haiku)-)?(\d+)/.exec(model);
        return match !== null && Number(match[1]) >= 3;
      }
      case 'openai':
        return model.includes('gpt-4') || model.includes('gpt-5');
      case 'deepseek':
      // DeepSeek does not currently support vision.
      default:
        return false;
    }
  }

  /**
   * Returns a shallow copy of the current configuration.
   */
  getConfig(): AgentConfig {
    return { ...this.config };
  }
}