32fdb244f0
- 添加 OpenAI AI SDK provider 支持 (@ai-sdk/openai) - 支持 OpenAI 兼容服务的 baseUrl 配置(如阿里云百炼) - 添加独立的 Vision 配置(visionProvider/visionApiKey/visionBaseUrl/visionModel) - 实现图片引用语法 @path/to/image.png,支持带空格的路径 - 当主模型不支持 vision 时,自动调用配置的 Vision 服务分析图片 - 添加图片处理工具函数和单元测试
520 lines
15 KiB
TypeScript
520 lines
15 KiB
TypeScript
import { createAnthropic } from '@ai-sdk/anthropic';
|
|
import { createDeepSeek } from '@ai-sdk/deepseek';
|
|
import { createOpenAI } from '@ai-sdk/openai';
|
|
import {
|
|
generateText,
|
|
streamText,
|
|
stepCountIs,
|
|
type ModelMessage,
|
|
type Tool as AITool,
|
|
type LanguageModel,
|
|
} from 'ai';
|
|
import type { Tool, ToolResult, Message, AgentConfig, ProviderType, UserInput, ContentBlock } from '../types/index.js';
|
|
import { buildZodSchema } from '../types/index.js';
|
|
import { ToolRegistry } from '../tools/registry.js';
|
|
import { SessionManager } from '../session/index.js';
|
|
import {
|
|
CompressionManager,
|
|
type TokenUsage,
|
|
type CompressionConfig,
|
|
} from '../context/index.js';
|
|
import type { AgentInfo } from '../agent/types.js';
|
|
|
|
// Provider 配置
|
|
/**
 * Connection settings handed to a provider factory.
 */
type ProviderOptions = {
  /** API key used to authenticate against the provider. */
  apiKey: string;
  /** Optional custom endpoint (e.g. an OpenAI-compatible service); omitted means the SDK default. */
  baseUrl?: string;
};
|
|
|
|
// Provider 工厂函数类型
|
|
/** Resolves a model id to a Vercel AI SDK language model. */
type ModelResolver = (model: string) => LanguageModel;

/** Builds a provider client from connection settings and returns a resolver for its models. */
type ProviderFactory = (options: ProviderOptions) => ModelResolver;
|
|
|
|
// Provider 注册表
|
|
/**
 * Normalizes an optional base URL.
 *
 * Blank or whitespace-only values become `undefined`, so the SDK falls back to the
 * provider's default endpoint instead of using an empty string as the URL.
 */
function normalizeBaseUrl(baseUrl?: string): string | undefined {
  const trimmed = baseUrl?.trim();
  return trimmed ? trimmed : undefined;
}

/**
 * Provider registry.
 *
 * Maps each supported ProviderType to a factory. The factory builds an SDK client from
 * the API key and optional base URL, and returns a model-id -> LanguageModel resolver.
 */
const providers: Record<ProviderType, ProviderFactory> = {
  anthropic: ({ apiKey, baseUrl }) => {
    const client = createAnthropic({ apiKey, baseURL: normalizeBaseUrl(baseUrl) });
    return (model) => client(model);
  },
  deepseek: ({ apiKey, baseUrl }) => {
    const client = createDeepSeek({ apiKey, baseURL: normalizeBaseUrl(baseUrl) });
    return (model) => client(model);
  },
  openai: ({ apiKey, baseUrl }) => {
    const baseURL = normalizeBaseUrl(baseUrl);
    const client = createOpenAI({ apiKey, baseURL });
    // In AI SDK v5, invoking the OpenAI provider directly targets the Responses API.
    // OpenAI-compatible services reached through a custom baseUrl (e.g. Alibaba Cloud
    // Bailian / DashScope compatible mode) generally implement only Chat Completions,
    // so use the explicit chat model there. Plain OpenAI keeps the SDK default.
    if (baseURL) {
      return (model) => client.chat(model);
    }
    return (model) => client(model);
  },
};
|
|
|
|
/**
 * Conversational agent built on the Vercel AI SDK.
 *
 * Responsibilities:
 * - Keeps the model-facing conversation history, including tool calls and tool results.
 * - Exposes tools to the model in one of two modes:
 *   - a ToolRegistry: core tools plus tools discovered through `tool_search`;
 *   - legacy mode: tools registered directly on the agent.
 * - Optionally persists state through a SessionManager.
 * - Compresses the history through a CompressionManager.
 */
export class Agent {
  /** Maximum number of model steps (tool-call rounds) allowed in one chat turn. */
  private static readonly MAX_STEPS = 10;

  /** Maximum number of characters of a tool result echoed to the stream callback. */
  private static readonly MAX_TOOL_OUTPUT_PREVIEW = 500;

  /** Resolves a model id to a LanguageModel of the configured provider. */
  private readonly getModel: (model: string) => LanguageModel;

  /** Active configuration; `systemPrompt` is swapped by setAgentMode(). */
  private config: AgentConfig;

  /** History sent to the model, including tool calls and tool results. */
  private conversationHistory: ModelMessage[] = [];

  /** Tool registry (new mode: supports dynamic tool discovery). */
  private registry: ToolRegistry | null = null;

  /** Names of tools discovered through `tool_search`. */
  private discoveredTools: Set<string> = new Set();

  /** Legacy mode: tools registered directly on the agent. */
  private readonly legacyTools: Map<string, Tool> = new Map();

  /** Optional session manager used for persistence. */
  private sessionManager: SessionManager | null = null;

  /** Compression manager for the conversation history. */
  private readonly compressionManager: CompressionManager;

  /** Current agent mode (null means default mode). */
  private currentAgentMode: AgentInfo | null = null;

  /** Original system prompt, restored when switching back to the default mode. */
  private readonly originalSystemPrompt: string;

  /**
   * @param config Agent configuration (provider, credentials, model, system prompt, ...).
   * @param compressionConfig Optional overrides passed to the CompressionManager.
   * @throws Error when `config.provider` has no entry in the provider registry.
   */
  constructor(config: AgentConfig, compressionConfig?: Partial<CompressionConfig>) {
    this.config = config;
    this.originalSystemPrompt = config.systemPrompt;

    const providerFactory = providers[config.provider];
    // Runtime guard: presumably config.provider can come from untyped input
    // (env/config file), where the ProviderType constraint is not enforced.
    if (!providerFactory) {
      throw new Error(`不支持的 provider: ${config.provider}`);
    }
    this.getModel = providerFactory({ apiKey: config.apiKey, baseUrl: config.baseUrl });

    this.compressionManager = new CompressionManager(compressionConfig);
    // The compression manager uses the same model to generate summaries.
    this.compressionManager.setModel(this.getModel(config.model));
  }

  /**
   * Sets the tool registry (new mode: supports dynamic tool discovery).
   */
  setRegistry(registry: ToolRegistry): void {
    this.registry = registry;
  }

  /**
   * Sets the session manager (enables session persistence).
   *
   * If the manager already holds a session, its messages and discovered tools replace
   * the agent's current state.
   */
  setSessionManager(manager: SessionManager): void {
    this.sessionManager = manager;

    const session = manager.getSession();
    if (session) {
      this.conversationHistory = [...session.messages];
      this.discoveredTools = new Set(session.discoveredTools);
    }
  }

  /**
   * Returns the session manager, or null if none is set.
   */
  getSessionManager(): SessionManager | null {
    return this.sessionManager;
  }

  /**
   * Registers a single tool (legacy API). A tool with the same name is replaced.
   */
  registerTool(customTool: Tool): void {
    this.legacyTools.set(customTool.name, customTool);
  }

  /**
   * Registers several tools (legacy API). Later tools replace earlier ones with the same name.
   */
  registerTools(tools: Tool[]): void {
    for (const tool of tools) {
      this.registerTool(tool);
    }
  }

  /**
   * Returns the tools currently available to the model.
   * - Registry mode: core tools plus discovered tools.
   * - Legacy mode: every directly registered tool.
   * - If an agent mode with a tool config is active, its filter is applied.
   */
  private getAvailableTools(): Tool[] {
    let tools: Tool[];

    if (this.registry) {
      const coreTools = this.registry.getCoreTools();
      const discoveredTools = this.registry.getTools([...this.discoveredTools]);
      tools = [...coreTools, ...discoveredTools];
    } else {
      tools = [...this.legacyTools.values()];
    }

    if (this.currentAgentMode?.tools) {
      tools = this.filterToolsByAgentConfig(tools);
    }

    return tools;
  }

  /**
   * Filters tools by the current agent mode's tool config. Steps, in order:
   * 1. If an `enabled` list is set, keep only those tools.
   * 2. If a `disabled` list is set, drop those tools.
   * 3. If `noTask` is set, drop the `task` tool so agents cannot nest tasks.
   */
  private filterToolsByAgentConfig(tools: Tool[]): Tool[] {
    const toolConfig = this.currentAgentMode?.tools;
    if (!toolConfig) return tools;

    let filteredTools = tools;

    if (toolConfig.enabled && toolConfig.enabled.length > 0) {
      const enabledSet = new Set(toolConfig.enabled);
      filteredTools = filteredTools.filter((t) => enabledSet.has(t.name));
    }

    if (toolConfig.disabled && toolConfig.disabled.length > 0) {
      const disabledSet = new Set(toolConfig.disabled);
      filteredTools = filteredTools.filter((t) => !disabledSet.has(t.name));
    }

    if (toolConfig.noTask) {
      filteredTools = filteredTools.filter((t) => t.name !== 'task');
    }

    return filteredTools;
  }

  /**
   * Converts the available tools into the Vercel AI SDK tool format, keyed by tool name.
   *
   * Every tool's parameters are turned into a zod input schema. Each `execute` returns
   * the project ToolResult unchanged. A successful `tool_search` also records the tools
   * it found.
   */
  private getVercelTools(): Record<string, AITool> {
    const vercelTools: Record<string, AITool> = {};

    for (const tool of this.getAvailableTools()) {
      vercelTools[tool.name] = {
        description: tool.description,
        inputSchema: buildZodSchema(tool.parameters),
        execute: async (params) => {
          const result = await tool.execute(params as Record<string, unknown>);

          // Make tools found by tool_search available on later chat turns.
          if (tool.name === 'tool_search' && result.success) {
            this.handleToolSearchResult(result.output);
          }

          return result;
        },
      } as AITool;
    }

    return vercelTools;
  }

  /**
   * Parses `tool_search` output and adds every tool named there that the registry knows.
   *
   * Expected line format: "- tool_name: description [category]".
   * Names may contain '-' or '.' in addition to word characters; `\w+` alone used to
   * reject them. Unknown names are ignored.
   */
  private handleToolSearchResult(output: string): void {
    for (const [, toolName] of output.matchAll(/^- ([\w.-]+):/gm)) {
      if (toolName && this.registry?.has(toolName)) {
        this.discoveredTools.add(toolName);
      }
    }
  }

  /**
   * Sends a user message and returns the model's final text (streaming or not).
   *
   * The user message and all response messages (tool calls and results) are appended
   * to the history. The history is then compressed if needed and the session persisted.
   * If generation fails, the user message is removed again and the error is rethrown,
   * so the in-memory history stays consistent with the persisted session.
   *
   * @param userMessage Plain text, or a UserInput with optional images.
   * @param onStream Streaming callback. When given, text deltas and tool activity are
   *   reported incrementally; when omitted, a single non-streaming request is made.
   * @returns The concatenated assistant text across all steps.
   */
  async chat(userMessage: string | UserInput, onStream?: (text: string) => void): Promise<string> {
    // Remember where this turn starts so a failed request can be rolled back.
    const turnStart = this.conversationHistory.length;
    this.conversationHistory.push({
      role: 'user',
      content: this.buildUserContent(userMessage),
    } as ModelMessage);

    let turn: { text: string; messages: ModelMessage[] };
    try {
      turn = onStream ? await this.generateStreaming(onStream) : await this.generateOnce();
    } catch (error) {
      this.conversationHistory.splice(turnStart);
      throw error;
    }

    // Append the full response (including tool calls and results) to the history.
    this.conversationHistory.push(...turn.messages);

    // Compress automatically when the history has grown too large.
    if (this.compressionManager.shouldCompress(this.conversationHistory)) {
      const compression = await this.compressionManager.compress(this.conversationHistory);
      if (compression.freedTokens > 0) {
        this.conversationHistory = compression.messages;
        if (onStream) {
          onStream(`\n[自动压缩: 释放了 ${(compression.freedTokens / 1000).toFixed(1)}k tokens]\n`);
        }
      }
    }

    await this.persistSession();

    return turn.text;
  }

  /**
   * Builds the content of a user message.
   * - Plain text stays a string.
   * - UserInput becomes image blocks followed by an optional text block.
   * - A UserInput holding only text collapses back to a plain string.
   */
  private buildUserContent(userMessage: string | UserInput): string | ContentBlock[] {
    if (typeof userMessage === 'string') {
      return userMessage;
    }

    const blocks: ContentBlock[] = [];

    for (const img of userMessage.images ?? []) {
      // AI SDK v5 image parts read `mediaType` (`mimeType` was the v4 name). Keep
      // `mimeType` for the project's ContentBlock shape. Pushing through a variable rather
      // than a fresh literal avoids an excess-property error if ContentBlock lacks `mediaType`.
      const imageBlock = {
        type: 'image' as const,
        image: img.data,
        mimeType: img.mimeType,
        mediaType: img.mimeType,
      };
      blocks.push(imageBlock);
    }

    if (userMessage.text) {
      blocks.push({ type: 'text', text: userMessage.text });
    }

    const [only] = blocks;
    return blocks.length === 1 && only?.type === 'text' ? only.text : blocks;
  }

  /**
   * Builds the options shared by streaming and non-streaming generation for one turn.
   */
  private buildGenerationOptions() {
    return {
      model: this.getModel(this.config.model),
      system: this.config.systemPrompt,
      messages: this.conversationHistory,
      tools: this.getVercelTools(),
      maxOutputTokens: this.config.maxTokens,
      // Allow at most MAX_STEPS rounds of tool calls.
      stopWhen: stepCountIs(Agent.MAX_STEPS),
    };
  }

  /**
   * Runs one turn with streamText.
   *
   * Text deltas and tool activity are forwarded to `onStream`. Returns the accumulated
   * text and the full response messages.
   */
  private async generateStreaming(
    onStream: (text: string) => void,
  ): Promise<{ text: string; messages: ModelMessage[] }> {
    const result = streamText({
      ...this.buildGenerationOptions(),
      onChunk: ({ chunk }) => {
        if (chunk.type === 'tool-call') {
          onStream(`\n[调用工具: ${chunk.toolName}]\n`);
        } else if (chunk.type === 'tool-result') {
          this.reportToolResult((chunk as { output?: ToolResult }).output, onStream);
        }
      },
    });

    let text = '';
    for await (const delta of result.textStream) {
      text += delta;
      onStream(delta);
    }

    // Wait for completion to obtain the full response messages (tool calls and results).
    const response = await result.response;
    return { text, messages: response.messages as ModelMessage[] };
  }

  /**
   * Runs one turn with generateText (non-streaming). Returns the final text and the
   * full response messages.
   */
  private async generateOnce(): Promise<{ text: string; messages: ModelMessage[] }> {
    const result = await generateText(this.buildGenerationOptions());
    return { text: result.text, messages: result.response.messages as ModelMessage[] };
  }

  /**
   * Reports a tool result to the stream callback.
   *
   * Successful output longer than MAX_TOOL_OUTPUT_PREVIEW characters is truncated.
   * Missing or non-object outputs are ignored.
   */
  private reportToolResult(output: ToolResult | undefined, onStream: (text: string) => void): void {
    if (!output || typeof output !== 'object') return;

    if (output.success) {
      const limit = Agent.MAX_TOOL_OUTPUT_PREVIEW;
      const displayOutput =
        output.output.length > limit ? output.output.substring(0, limit) + '...(截断)' : output.output;
      onStream(`[结果: ${displayOutput}]\n`);
    } else {
      onStream(`[错误: ${output.error}]\n`);
    }
  }

  /**
   * Persists the current history and discovered tools. Does nothing without a session manager.
   */
  private async persistSession(): Promise<void> {
    if (!this.sessionManager) return;

    await this.sessionManager.setMessages(this.conversationHistory);
    await this.sessionManager.setDiscoveredTools([...this.discoveredTools]);
  }

  /**
   * Clears the history and discovered tools. Starts a new session if a session manager is set.
   */
  async clearHistory(): Promise<void> {
    this.conversationHistory = [];
    this.discoveredTools.clear();

    if (this.sessionManager) {
      await this.sessionManager.newSession();
    }
  }

  /**
   * Returns the user and assistant messages of the history.
   *
   * Other roles (system, tool) are dropped. Non-string content (multi-part messages,
   * tool calls) is JSON-stringified.
   */
  getHistory(): Message[] {
    return this.conversationHistory
      .filter(
        (msg): msg is ModelMessage & { role: 'user' | 'assistant' } =>
          msg.role === 'user' || msg.role === 'assistant',
      )
      .map((msg) => ({
        role: msg.role,
        content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content),
      }));
  }

  /**
   * Returns counts of core, discovered and total tools.
   *
   * Agent-mode filtering is not applied. In legacy mode, every registered tool counts as core.
   */
  getToolCount(): { core: number; discovered: number; total: number } {
    if (this.registry) {
      const coreCount = this.registry.getCoreTools().length;
      const discoveredCount = this.discoveredTools.size;
      return {
        core: coreCount,
        discovered: discoveredCount,
        total: coreCount + discoveredCount,
      };
    }

    return {
      core: this.legacyTools.size,
      discovered: 0,
      total: this.legacyTools.size,
    };
  }

  /**
   * Returns the current context usage as computed by the compression manager.
   */
  getContextUsage(): TokenUsage {
    return this.compressionManager.calculateUsage(this.conversationHistory);
  }

  /**
   * Returns the context usage formatted for CLI display.
   */
  getContextUsageFormatted(): string {
    return this.compressionManager.formatUsage(this.conversationHistory);
  }

  /**
   * Returns the compression manager.
   */
  getCompressionManager(): CompressionManager {
    return this.compressionManager;
  }

  /**
   * Compresses the history on demand (used by the /compact command).
   *
   * The history is replaced and persisted only when tokens were actually freed.
   */
  async compactHistory(): Promise<{ freedTokens: number; type: string }> {
    const result = await this.compressionManager.forceCompress(this.conversationHistory);
    if (result.freedTokens > 0) {
      this.conversationHistory = result.messages;
      await this.persistSession();
    }
    return {
      freedTokens: result.freedTokens,
      type: result.type,
    };
  }

  /**
   * Switches the agent mode.
   *
   * A mode with a non-empty prompt uses it as the system prompt. `null`, or a mode
   * without a prompt, restores the original system prompt.
   */
  setAgentMode(agent: AgentInfo | null): void {
    this.currentAgentMode = agent;
    this.config = {
      ...this.config,
      systemPrompt: agent?.prompt ? agent.prompt : this.originalSystemPrompt,
    };
  }

  /**
   * Returns the current agent mode, or null for the default mode.
   */
  getAgentMode(): AgentInfo | null {
    return this.currentAgentMode;
  }

  /**
   * Returns the current agent mode name ('default' when no mode is set).
   */
  getAgentModeName(): string {
    return this.currentAgentMode?.name ?? 'default';
  }

  /**
   * Heuristically checks, by model id, whether the configured model accepts images.
   *
   * - anthropic: Claude 3 and later. Both `claude-3-5-sonnet-*` and the newer
   *   `claude-sonnet-4-*` / `claude-opus-4-1-*` ids are recognized; the previous
   *   `includes('claude-4')` check missed the latter.
   * - openai: the GPT-4 family (gpt-4o, gpt-4-turbo, gpt-4.1, ...), GPT-5, and models
   *   whose id has a `vl` segment (vision-language naming used by OpenAI-compatible
   *   services such as Qwen-VL; a naming heuristic, not a capability query).
   * - deepseek: no vision support.
   */
  supportsVision(): boolean {
    const model = this.config.model.toLowerCase();

    switch (this.config.provider) {
      case 'anthropic': {
        const match = /claude-(?:(?:opus|sonnet|haiku)-)?(\d+)/.exec(model);
        return match !== null && Number(match[1]) >= 3;
      }
      case 'openai':
        return (
          model.includes('gpt-4') ||
          model.includes('gpt-5') ||
          /(?:^|[-_.])vl(?:[-_.]|$)/.test(model)
        );
      case 'deepseek':
        return false;
      default:
        return false;
    }
  }

  /**
   * Returns a shallow copy of the current configuration.
   */
  getConfig(): AgentConfig {
    return { ...this.config };
  }
}
|