diff --git a/.env.example b/.env.example index 4373662..a6c5187 100644 --- a/.env.example +++ b/.env.example @@ -4,3 +4,8 @@ ANTHROPIC_API_KEY=sk-ant-xxxxx # 可选配置 AI_MODEL=claude-sonnet-4-20250514 AI_MAX_TOKENS=4096 + +# Vision 配置(用于图片理解,当主模型不支持 vision 时使用) +# 如果不配置,默认使用 Anthropic Claude +# VISION_PROVIDER=anthropic +# VISION_MODEL=claude-sonnet-4-20250514 diff --git a/package-lock.json b/package-lock.json index 60c7568..5c79154 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,6 +11,7 @@ "dependencies": { "@ai-sdk/anthropic": "^2.0.54", "@ai-sdk/deepseek": "^1.0.31", + "@ai-sdk/openai": "^2.0.80", "@tavily/core": "^0.6.0", "ai": "^5.0.108", "chalk": "^5.3.0", @@ -88,6 +89,22 @@ "zod": "^3.25.76 || ^4.1.8" } }, + "node_modules/@ai-sdk/openai": { + "version": "2.0.80", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-2.0.80.tgz", + "integrity": "sha512-tNHuraF11db+8xJEDBoU9E3vMcpnHFKRhnLQ3DQX2LnEzfPB9DksZ8rE+yVuDN1WRW9cm2OWAhgHFgVKs7ICuw==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "2.0.0", + "@ai-sdk/provider-utils": "3.0.18" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.25.76 || ^4.1.8" + } + }, "node_modules/@ai-sdk/provider": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0.tgz", diff --git a/package.json b/package.json index c5fdf65..97560cc 100644 --- a/package.json +++ b/package.json @@ -28,6 +28,7 @@ "dependencies": { "@ai-sdk/anthropic": "^2.0.54", "@ai-sdk/deepseek": "^1.0.31", + "@ai-sdk/openai": "^2.0.80", "@tavily/core": "^0.6.0", "ai": "^5.0.108", "chalk": "^5.3.0", diff --git a/src/agent/executor.ts b/src/agent/executor.ts index 77eaefa..9d5f7ba 100644 --- a/src/agent/executor.ts +++ b/src/agent/executor.ts @@ -1,5 +1,6 @@ import { createAnthropic } from '@ai-sdk/anthropic'; import { createDeepSeek } from '@ai-sdk/deepseek'; +import { createOpenAI } from '@ai-sdk/openai'; import { generateText, streamText, @@ -18,17 +19,27 @@ import type { } from './types.js'; import { checkBashPermission } from './permission-merger.js'; +// Provider 配置 +interface ProviderOptions { + apiKey: string; + baseUrl?: string; +} + // Provider 工厂函数类型 -type ProviderFactory = (apiKey: string) => (model: string) => LanguageModel; +type ProviderFactory = (options: ProviderOptions) => (model: string) => LanguageModel; // Provider 注册表 const providers: Record = { - anthropic: (apiKey) => { - const client = createAnthropic({ apiKey }); + anthropic: ({ apiKey, baseUrl }) => { + const client = createAnthropic({ apiKey, baseURL: baseUrl }); return (model) => client(model); }, - deepseek: (apiKey) => { - const client = createDeepSeek({ apiKey }); + deepseek: ({ apiKey, baseUrl }) => { + const client = createDeepSeek({ apiKey, baseURL: baseUrl }); + return (model) => client(model); + }, + openai: ({ apiKey, baseUrl }) => { + const client = createOpenAI({ apiKey, baseURL: baseUrl }); return (model) => client(model); }, }; @@ -58,7 +69,7 @@ export class AgentExecutor { if (!providerFactory) { throw new Error(`不支持的 provider: ${provider}`); } - this.getModel = providerFactory(baseConfig.apiKey); + this.getModel = providerFactory({ apiKey: baseConfig.apiKey, baseUrl: baseConfig.baseUrl }); } /** diff --git a/src/core/agent.ts b/src/core/agent.ts index 8ebbc96..2e26c4c 100644 --- a/src/core/agent.ts +++ b/src/core/agent.ts @@ -1,5 +1,6 @@ import { createAnthropic } from '@ai-sdk/anthropic'; import { createDeepSeek } from '@ai-sdk/deepseek'; +import { createOpenAI } from '@ai-sdk/openai'; import { generateText, streamText, @@ -8,7 +9,7 @@ import { type Tool as AITool, type LanguageModel, } from 'ai'; -import type { Tool, ToolResult, Message, AgentConfig, ProviderType } from '../types/index.js'; +import type { Tool, ToolResult, Message, AgentConfig, ProviderType, UserInput, ContentBlock } from '../types/index.js'; import { buildZodSchema } from '../types/index.js'; import { ToolRegistry } from '../tools/registry.js'; import { SessionManager } from '../session/index.js'; @@ -19,17 +20,27 @@ import { } from '../context/index.js'; import type { AgentInfo } from '../agent/types.js'; +// Provider 配置 +interface ProviderOptions { + apiKey: string; + baseUrl?: string; +} + // Provider 工厂函数类型 -type ProviderFactory = (apiKey: string) => (model: string) => LanguageModel; +type ProviderFactory = (options: ProviderOptions) => (model: string) => LanguageModel; // Provider 注册表 const providers: Record = { - anthropic: (apiKey) => { - const client = createAnthropic({ apiKey }); + anthropic: ({ apiKey, baseUrl }) => { + const client = createAnthropic({ apiKey, baseURL: baseUrl }); return (model) => client(model); }, - deepseek: (apiKey) => { - const client = createDeepSeek({ apiKey }); + deepseek: ({ apiKey, baseUrl }) => { + const client = createDeepSeek({ apiKey, baseURL: baseUrl }); + return (model) => client(model); + }, + openai: ({ apiKey, baseUrl }) => { + const client = createOpenAI({ apiKey, baseURL: baseUrl }); return (model) => client(model); }, }; @@ -68,7 +79,7 @@ export class Agent { if (!providerFactory) { throw new Error(`不支持的 provider: ${config.provider}`); } - this.getModel = providerFactory(config.apiKey); + this.getModel = providerFactory({ apiKey: config.apiKey, baseUrl: config.baseUrl }); // 初始化压缩管理器 this.compressionManager = new CompressionManager(compressionConfig); @@ -221,13 +232,49 @@ export class Agent { /** * 发送消息并处理响应(流式) + * @param userMessage 用户消息文本或包含图片的 UserInput + * @param onStream 流式输出回调 */ - async chat(userMessage: string, onStream?: (text: string) => void): Promise { + async chat(userMessage: string | UserInput, onStream?: (text: string) => void): Promise { + // 构建消息内容 + let messageContent: string | ContentBlock[]; + + if (typeof userMessage === 'string') { + // 纯文本消息 + messageContent = userMessage; + } else { + // 带图片的消息 + const blocks: ContentBlock[] = []; + + // 添加图片 + if (userMessage.images && userMessage.images.length > 0) { + for (const img of userMessage.images) { + blocks.push({ + type: 'image', + image: img.data, + mimeType: img.mimeType, + }); + } + } + + // 添加文本 + if (userMessage.text) { + blocks.push({ + type: 'text', + text: userMessage.text, + }); + } + + messageContent = blocks.length === 1 && blocks[0].type === 'text' + ? blocks[0].text + : blocks; + } + // 添加用户消息到历史 this.conversationHistory.push({ role: 'user', - content: userMessage, - }); + content: messageContent, + } as ModelMessage); const vercelTools = this.getVercelTools(); let fullResponse = ''; @@ -436,4 +483,37 @@ export class Agent { getAgentModeName(): string { return this.currentAgentMode?.name ?? 'default'; } + + /** + * 检查当前模型是否支持 vision(图片理解) + */ + supportsVision(): boolean { + const model = this.config.model.toLowerCase(); + + // Anthropic Claude 模型支持 vision + if (this.config.provider === 'anthropic') { + // Claude 3 及以上版本支持 vision + return model.includes('claude-3') || model.includes('claude-4'); + } + + // OpenAI GPT-4 系列支持 vision + if (this.config.provider === 'openai') { + // GPT-4o, GPT-4 Turbo, GPT-4 Vision 等支持 + return model.includes('gpt-4'); + } + + // DeepSeek 目前不支持 vision + if (this.config.provider === 'deepseek') { + return false; + } + + return false; + } + + /** + * 获取当前配置 + */ + getConfig(): AgentConfig { + return { ...this.config }; + } } diff --git a/src/types/index.ts b/src/types/index.ts index f023f56..ff3e97e 100644 --- a/src/types/index.ts +++ b/src/types/index.ts @@ -1,9 +1,38 @@ import { z } from 'zod'; -// 消息类型 +// 内容块类型(支持多模态) +export interface TextContentBlock { + type: 'text'; + text: string; +} + +export interface ImageContentBlock { + type: 'image'; + /** base64 编码的图片数据 */ + image: string; + /** MIME 类型 */ + mimeType?: string; +} + +export type ContentBlock = TextContentBlock | ImageContentBlock; + +// 消息类型(支持多模态) export interface Message { role: 'user' | 'assistant'; - content: string; + /** 内容可以是纯文本或内容块数组 */ + content: string | ContentBlock[]; +} + +// 用户输入(带图片) +export interface UserInput { + /** 文本内容 */ + text: string; + /** 图片列表(base64 编码) */ + images?: Array<{ + data: string; + mimeType: string; + filename?: string; + }>; } // 工具参数定义 @@ -37,7 +66,7 @@ export interface ToolCall { } // 支持的 Provider 类型 -export type ProviderType = 'anthropic' | 'deepseek'; +export type ProviderType = 'anthropic' | 'deepseek' | 'openai'; // Agent 配置 export interface AgentConfig { @@ -46,6 +75,8 @@ export interface AgentConfig { model: string; maxTokens: number; systemPrompt: string; + /** 自定义 API 基础 URL(用于兼容 OpenAI API 的第三方服务,如阿里云百炼) */ + baseUrl?: string; } // 会话上下文 diff --git a/src/ui/terminal.ts b/src/ui/terminal.ts index 0f4d3c8..8956805 100644 --- a/src/ui/terminal.ts +++ b/src/ui/terminal.ts @@ -7,6 +7,17 @@ import { createCommandExecutor, type CommandExecutionResult, } from '../commands/index.js'; +import { + extractImageReferences, + loadImages, + formatFileSize, +} from '../utils/image.js'; +import { + analyzeImages, + isVisionAvailable, + getVisionInfo, +} from '../utils/vision.js'; +import type { UserInput } from '../types/index.js'; export class TerminalUI { private agent: Agent; @@ -278,6 +289,118 @@ export class TerminalUI { } } + // 处理包含图片引用的输入 + private async processImageInput( + input: string + ): Promise<{ userInput: UserInput; hasImages: boolean } | null> { + const { imagePaths, textContent } = extractImageReferences(input); + + // 没有图片引用,返回纯文本 + if (imagePaths.length === 0) { + return { + userInput: { text: input }, + hasImages: false, + }; + } + + // 加载图片 + console.log(chalk.gray(`\n正在加载 ${imagePaths.length} 张图片...`)); + const { images, errors } = await loadImages(imagePaths, process.cwd()); + + // 显示加载错误 + for (const err of errors) { + console.log(chalk.red(` ✗ ${err.path}: ${err.error}`)); + } + + // 如果没有成功加载任何图片 + if (images.length === 0) { + console.log(chalk.red('没有成功加载任何图片\n')); + return null; + } + + // 显示成功加载的图片 + for (const img of images) { + console.log( + chalk.green(` ✓ ${img.filename}`) + + chalk.gray(` (${formatFileSize(img.size)})`) + ); + } + console.log(''); + + return { + userInput: { + text: textContent, + images: images.map((img) => ({ + data: img.base64, + mimeType: img.mimeType, + filename: img.filename, + })), + }, + hasImages: true, + }; + } + + // 处理不支持 Vision 的情况 + private async handleNoVisionSupport( + userInput: UserInput + ): Promise { + // 检查 Vision 服务是否可用 + if (!isVisionAvailable()) { + console.log(chalk.yellow('\n⚠ 当前模型不支持图片理解,且未配置 Vision 服务')); + console.log(chalk.gray('请在配置文件中设置 visionProvider、visionApiKey 等参数')); + console.log(chalk.gray('或切换到支持图片理解的模型(如 Claude、GPT-4o)\n')); + return null; + } + + const visionInfo = getVisionInfo(); + + // 提示用户选择 + console.log(chalk.yellow('\n⚠ 当前模型不支持图片理解')); + console.log(chalk.gray('请选择处理方式:')); + console.log(chalk.white(` 1. 使用 Vision 服务 (${visionInfo.model}) 分析图片后继续对话`)); + console.log(chalk.white(' 2. 取消本次输入')); + + const choice = await new Promise((resolve) => { + this.rl.question(chalk.green('选择 (1/2): '), resolve); + }); + + if (choice.trim() !== '1') { + console.log(chalk.gray('已取消\n')); + return null; + } + + // 使用 Vision 服务分析图片 + console.log(chalk.cyan(`\n正在使用 ${visionInfo.model} 分析图片...`)); + + const images = userInput.images || []; + if (images.length === 0) { + console.log(chalk.red('没有图片需要分析\n')); + return null; + } + + // 调用 Vision API 分析图片 + const result = await analyzeImages( + images.map(img => ({ + data: img.data, + mimeType: img.mimeType, + filename: img.filename, + })), + userInput.text || undefined + ); + + if (!result.success) { + console.log(chalk.red(`\n图片分析失败: ${result.error}\n`)); + return null; + } + + console.log(chalk.green('✓ 图片分析完成\n')); + + // 构建带图片描述的文本消息 + const combinedText = `[Vision 服务分析结果]\n${result.description}\n\n用户原始问题: ${userInput.text}`; + + return combinedText; + } + // 提问并获取用户输入 private prompt(): Promise { return new Promise((resolve, reject) => { @@ -378,13 +501,35 @@ export class TerminalUI { continue; } + // 处理图片引用 + const processed = await this.processImageInput(input); + if (!processed) { + continue; + } + + let { userInput, hasImages } = processed; + + // 如果有图片且当前模型不支持 vision + if (hasImages && !this.agent.supportsVision()) { + const fallbackText = await this.handleNoVisionSupport(userInput); + if (!fallbackText) { + continue; + } + // 使用 Vision 分析结果替代图片 + userInput = { text: fallbackText }; + hasImages = false; + } + // 发送给 AI process.stdout.write(chalk.gray('思考中...')); try { let isFirstChunk = true; - await this.agent.chat(input, (text) => { + // 根据是否有图片选择发送格式 + const messageToSend = hasImages ? userInput : userInput.text; + + await this.agent.chat(messageToSend, (text) => { if (isFirstChunk) { // 清除 "思考中..." 并显示 AI 前缀 process.stdout.write('\r' + ' '.repeat(20) + '\r'); diff --git a/src/utils/config.ts b/src/utils/config.ts index 73f83a5..4baed6a 100644 --- a/src/utils/config.ts +++ b/src/utils/config.ts @@ -10,15 +10,42 @@ interface StoredConfig { provider?: ProviderType; apiKey?: string; deepseekApiKey?: string; + openaiApiKey?: string; model?: string; maxTokens?: number; tavilyApiKey?: string; + /** 自定义 API 基础 URL(用于 OpenAI 兼容服务,如阿里云百炼) */ + baseUrl?: string; + // Vision 配置 + visionProvider?: ProviderType; + visionModel?: string; + /** Vision 专用的 API Key(可选,不设置则使用对应 provider 的 key) */ + visionApiKey?: string; + /** Vision 专用的 Base URL(用于 OpenAI 兼容的 Vision 服务) */ + visionBaseUrl?: string; +} + +// Vision 配置接口 +export interface VisionConfig { + provider: ProviderType; + apiKey: string; + model: string; + /** 自定义 Base URL(用于 OpenAI 兼容的 Vision 服务) */ + baseUrl?: string; } // 默认模型配置 const DEFAULT_MODELS: Record = { anthropic: 'claude-sonnet-4-20250514', deepseek: 'deepseek-chat', + openai: 'gpt-4o', +}; + +// 默认 Vision 模型(需要支持图片理解) +const DEFAULT_VISION_MODELS: Record = { + anthropic: 'claude-sonnet-4-20250514', + deepseek: 'deepseek-chat', // DeepSeek 暂不支持 vision,占位用 + openai: 'gpt-4o', }; // 默认系统提示词 @@ -60,8 +87,10 @@ export function loadConfig(): AgentConfig { const provider = (process.env.AI_PROVIDER as ProviderType) || 'anthropic'; const anthropicApiKey = process.env.ANTHROPIC_API_KEY; const deepseekApiKey = process.env.DEEPSEEK_API_KEY; + const openaiApiKey = process.env.OPENAI_API_KEY; const model = process.env.AI_MODEL; const maxTokens = parseInt(process.env.AI_MAX_TOKENS || '4096', 10); + const baseUrl = process.env.AI_BASE_URL; // 从配置文件读取 let storedConfig: StoredConfig = {}; @@ -83,10 +112,17 @@ export function loadConfig(): AgentConfig { finalApiKey = anthropicApiKey || storedConfig.apiKey; } else if (finalProvider === 'deepseek') { finalApiKey = deepseekApiKey || storedConfig.deepseekApiKey; + } else if (finalProvider === 'openai') { + finalApiKey = openaiApiKey || storedConfig.openaiApiKey; } if (!finalApiKey) { - const envVar = finalProvider === 'anthropic' ? 'ANTHROPIC_API_KEY' : 'DEEPSEEK_API_KEY'; + const envVarMap: Record = { + anthropic: 'ANTHROPIC_API_KEY', + deepseek: 'DEEPSEEK_API_KEY', + openai: 'OPENAI_API_KEY', + }; + const envVar = envVarMap[finalProvider]; console.error(`❌ 错误: 未设置 ${envVar}`); console.error(`请设置环境变量: export ${envVar}=your-api-key`); console.error('或运行: ai-assist init 进行初始化配置'); @@ -96,12 +132,71 @@ export function loadConfig(): AgentConfig { // 确定模型 const finalModel = model || storedConfig.model || DEFAULT_MODELS[finalProvider]; + // 确定 baseUrl(环境变量优先) + const finalBaseUrl = baseUrl || storedConfig.baseUrl; + return { provider: finalProvider, apiKey: finalApiKey, model: finalModel, maxTokens: storedConfig.maxTokens || maxTokens, systemPrompt: DEFAULT_SYSTEM_PROMPT, + baseUrl: finalBaseUrl, + }; +} + +/** + * 加载 Vision 配置 + * Vision 用于图片理解,当主模型不支持 vision 时使用 + * 优先级:环境变量 > 配置文件 > 默认使用 Anthropic Claude + */ +export function loadVisionConfig(): VisionConfig | null { + // 从环境变量获取 + const visionProvider = process.env.VISION_PROVIDER as ProviderType | undefined; + const visionModel = process.env.VISION_MODEL; + const visionApiKey = process.env.VISION_API_KEY; + const visionBaseUrl = process.env.VISION_BASE_URL; + const anthropicApiKey = process.env.ANTHROPIC_API_KEY; + const deepseekApiKey = process.env.DEEPSEEK_API_KEY; + const openaiApiKey = process.env.OPENAI_API_KEY; + + // 从配置文件读取 + const storedConfig = getConfig(); + + // 确定 vision provider(默认使用 anthropic,因为 Claude 支持 vision) + const finalProvider = visionProvider || storedConfig.visionProvider || 'anthropic'; + + // 获取 Vision 专用的 API Key(优先级:环境变量 > 配置文件专用 key > provider 对应的 key) + let finalApiKey: string | undefined; + finalApiKey = visionApiKey || storedConfig.visionApiKey; + + // 如果没有专用 key,回退到对应 provider 的 key + if (!finalApiKey) { + if (finalProvider === 'anthropic') { + finalApiKey = anthropicApiKey || storedConfig.apiKey; + } else if (finalProvider === 'deepseek') { + finalApiKey = deepseekApiKey || storedConfig.deepseekApiKey; + } else if (finalProvider === 'openai') { + finalApiKey = openaiApiKey || storedConfig.openaiApiKey; + } + } + + // 如果没有 API Key,返回 null + if (!finalApiKey) { + return null; + } + + // 确定模型 + const finalModel = visionModel || storedConfig.visionModel || DEFAULT_VISION_MODELS[finalProvider]; + + // 确定 baseUrl(Vision 专用) + const finalBaseUrl = visionBaseUrl || storedConfig.visionBaseUrl; + + return { + provider: finalProvider, + apiKey: finalApiKey, + model: finalModel, + baseUrl: finalBaseUrl, }; } @@ -142,45 +237,120 @@ export async function initConfig(): Promise { message: '选择 AI 服务商:', choices: [ { name: 'Anthropic (Claude)', value: 'anthropic' }, + { name: 'OpenAI (GPT)', value: 'openai' }, + { name: 'OpenAI 兼容服务 (阿里云百炼、Azure 等)', value: 'openai-compatible' }, { name: 'DeepSeek', value: 'deepseek' }, ], default: 'anthropic', }, ]); + // 是否是 OpenAI 兼容服务 + const isOpenAICompatible = provider === 'openai-compatible'; + const actualProvider = isOpenAICompatible ? 'openai' : provider; + + // 如果是 OpenAI 兼容服务,询问 base URL + let baseUrl: string | undefined; + if (isOpenAICompatible) { + const { customBaseUrl } = await inquirer.prompt([ + { + type: 'input', + name: 'customBaseUrl', + message: '请输入 API 基础 URL (如: https://dashscope.aliyuncs.com/compatible-mode/v1):', + validate: (input: string) => { + if (!input) return 'Base URL 不能为空'; + try { + new URL(input); + return true; + } catch { + return '请输入有效的 URL'; + } + }, + }, + ]); + baseUrl = customBaseUrl; + } + // 根据 provider 显示不同的模型选项 - const modelChoices = - provider === 'anthropic' - ? [ - { name: 'Claude Sonnet 4 (推荐,平衡性能和成本)', value: 'claude-sonnet-4-20250514' }, - { name: 'Claude Opus 4 (最强,成本较高)', value: 'claude-opus-4-20250514' }, - { name: 'Claude 3.5 Haiku (快速,成本低)', value: 'claude-3-5-haiku-20241022' }, - ] - : [ - { name: 'DeepSeek Chat (推荐)', value: 'deepseek-chat' }, - { name: 'DeepSeek Reasoner (推理增强)', value: 'deepseek-reasoner' }, - ]; + let modelChoices: Array<{ name: string; value: string }>; + let allowCustomModel = false; - const apiKeyField = provider === 'anthropic' ? 'apiKey' : 'deepseekApiKey'; - const apiKeyMessage = - provider === 'anthropic' - ? '请输入你的 Anthropic API Key:' - : '请输入你的 DeepSeek API Key:'; + if (actualProvider === 'anthropic') { + modelChoices = [ + { name: 'Claude Sonnet 4 (推荐,平衡性能和成本)', value: 'claude-sonnet-4-20250514' }, + { name: 'Claude Opus 4 (最强,成本较高)', value: 'claude-opus-4-20250514' }, + { name: 'Claude 3.5 Haiku (快速,成本低)', value: 'claude-3-5-haiku-20241022' }, + ]; + } else if (actualProvider === 'openai') { + if (isOpenAICompatible) { + // OpenAI 兼容服务允许自定义模型名称 + modelChoices = [ + { name: 'qwen-plus (通义千问)', value: 'qwen-plus' }, + { name: 'qwen-turbo (通义千问快速版)', value: 'qwen-turbo' }, + { name: 'qwen-max (通义千问最强版)', value: 'qwen-max' }, + { name: 'gpt-4o', value: 'gpt-4o' }, + { name: '自定义模型名称...', value: '__custom__' }, + ]; + allowCustomModel = true; + } else { + modelChoices = [ + { name: 'GPT-4o (推荐,支持 vision)', value: 'gpt-4o' }, + { name: 'GPT-4o mini (快速,成本低)', value: 'gpt-4o-mini' }, + { name: 'GPT-4 Turbo', value: 'gpt-4-turbo' }, + { name: 'o1 (推理增强)', value: 'o1' }, + { name: 'o1-mini (推理,成本低)', value: 'o1-mini' }, + ]; + } + } else { + modelChoices = [ + { name: 'DeepSeek Chat (推荐)', value: 'deepseek-chat' }, + { name: 'DeepSeek Reasoner (推理增强)', value: 'deepseek-reasoner' }, + ]; + } - const answers = await inquirer.prompt([ + const apiKeyMessageMap: Record = { + anthropic: '请输入你的 Anthropic API Key:', + openai: '请输入你的 OpenAI API Key:', + deepseek: '请输入你的 DeepSeek API Key:', + }; + + // 分开询问 API Key + const { apiKey } = await inquirer.prompt([ { type: 'password', - name: apiKeyField, - message: apiKeyMessage, + name: 'apiKey', + message: isOpenAICompatible ? '请输入你的 API Key:' : apiKeyMessageMap[actualProvider], validate: (input: string) => input.length > 0 || 'API Key 不能为空', }, + ]); + + // 询问模型配置 + const { model: selectedModel } = await inquirer.prompt([ { type: 'list', name: 'model', message: '选择默认模型:', choices: modelChoices, - default: DEFAULT_MODELS[provider as ProviderType], + default: DEFAULT_MODELS[actualProvider as ProviderType], }, + ]); + + // 如果选择自定义模型,询问模型名称 + let finalModel = selectedModel; + if (allowCustomModel && selectedModel === '__custom__') { + const { customModel } = await inquirer.prompt([ + { + type: 'input', + name: 'customModel', + message: '请输入模型名称:', + validate: (input: string) => input.length > 0 || '模型名称不能为空', + }, + ]); + finalModel = customModel; + } + + // 询问 token 配置 + const { maxTokens } = await inquirer.prompt([ { type: 'number', name: 'maxTokens', @@ -189,7 +359,28 @@ export async function initConfig(): Promise { }, ]); - saveConfig({ provider, ...answers }); + // 根据 provider 构建配置对象 + const configToSave: Partial = { + provider: actualProvider as ProviderType, + model: finalModel, + maxTokens, + }; + + // 存储 API Key 到对应字段 + if (actualProvider === 'anthropic') { + configToSave.apiKey = apiKey; + } else if (actualProvider === 'openai') { + configToSave.openaiApiKey = apiKey; + } else if (actualProvider === 'deepseek') { + configToSave.deepseekApiKey = apiKey; + } + + // 存储 base URL + if (baseUrl) { + configToSave.baseUrl = baseUrl; + } + + saveConfig(configToSave); console.log('\n✅ 配置已保存到', CONFIG_FILE); console.log('现在可以运行 ai-assist 开始使用了!\n'); } diff --git a/src/utils/image.ts b/src/utils/image.ts new file mode 100644 index 0000000..f22eea8 --- /dev/null +++ b/src/utils/image.ts @@ -0,0 +1,201 @@ +/** + * 图片处理工具 + * + * 提供图片文件读取、格式检测、base64 编码等功能 + */ + +import * as fs from 'fs/promises'; +import * as path from 'path'; + +/** 支持的图片扩展名 */ +export const IMAGE_EXTENSIONS = ['.png', '.jpg', '.jpeg', '.gif', '.webp']; + +/** 图片 MIME 类型映射 */ +const MIME_TYPES: Record = { + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.gif': 'image/gif', + '.webp': 'image/webp', +}; + +/** 图片信息 */ +export interface ImageInfo { + /** 原始文件路径 */ + path: string; + /** 文件名 */ + filename: string; + /** 扩展名 */ + extension: string; + /** MIME 类型 */ + mimeType: string; + /** 文件大小(字节) */ + size: number; + /** base64 编码的数据 */ + base64: string; + /** 完整的 data URL */ + dataUrl: string; +} + +/** 图片加载结果 */ +export interface ImageLoadResult { + success: boolean; + image?: ImageInfo; + error?: string; +} + +/** + * 判断文件路径是否为图片 + */ +export function isImagePath(filePath: string): boolean { + const ext = path.extname(filePath).toLowerCase(); + return IMAGE_EXTENSIONS.includes(ext); +} + +/** + * 从输入文本中提取图片引用 + * 支持多种格式: + * 1. @path/to/image.png(不带空格的路径) + * 2. @"path/to/image with spaces.png"(带空格的路径用引号包裹) + * 3. @/path/to/image.png(绝对路径,自动匹配到图片扩展名结束) + * + * @param input 用户输入 + * @returns 图片路径列表和去除图片引用后的文本 + */ +export function extractImageReferences(input: string): { + imagePaths: string[]; + textContent: string; +} { + const imagePaths: string[] = []; + let textContent = input; + + // 模式1: 带引号的路径 @"path/to/image.png" 或 @'path/to/image.png' + const quotedMatches = [...input.matchAll(/@["']([^"']+\.(?:png|jpg|jpeg|gif|webp))["']/gi)]; + for (const match of quotedMatches) { + imagePaths.push(match[1]); + textContent = textContent.replace(match[0], ' '); + } + + // 模式2: 绝对路径(以 / 或 ~ 开头,匹配到图片扩展名结束) + // 支持路径中包含空格 + const absoluteMatches = [...textContent.matchAll(/@([/~][^\n]*?\.(?:png|jpg|jpeg|gif|webp))(?=\s|$)/gi)]; + for (const match of absoluteMatches) { + if (!imagePaths.includes(match[1])) { + imagePaths.push(match[1]); + textContent = textContent.replace(match[0], ' '); + } + } + + // 模式3: 相对路径(不以 / 开头,不包含空格) + const relativeMatches = [...textContent.matchAll(/@((?:\.\/|\.\.\/)?[^\s@"'/][^\s@"']*\.(?:png|jpg|jpeg|gif|webp))/gi)]; + for (const match of relativeMatches) { + if (!imagePaths.includes(match[1])) { + imagePaths.push(match[1]); + textContent = textContent.replace(match[0], ' '); + } + } + + // 清理多余空格 + textContent = textContent.replace(/\s+/g, ' ').trim(); + + return { imagePaths, textContent }; +} + +/** + * 加载图片文件 + * @param filePath 图片路径(相对或绝对) + * @param workdir 工作目录(用于解析相对路径) + */ +export async function loadImage( + filePath: string, + workdir: string = process.cwd() +): Promise { + try { + // 解析路径 + const absolutePath = path.isAbsolute(filePath) + ? filePath + : path.resolve(workdir, filePath); + + // 检查扩展名 + const ext = path.extname(absolutePath).toLowerCase(); + if (!IMAGE_EXTENSIONS.includes(ext)) { + return { + success: false, + error: `不支持的图片格式: ${ext}。支持的格式: ${IMAGE_EXTENSIONS.join(', ')}`, + }; + } + + // 读取文件 + const buffer = await fs.readFile(absolutePath); + const stats = await fs.stat(absolutePath); + + // 转换为 base64 + const base64 = buffer.toString('base64'); + const mimeType = MIME_TYPES[ext] || 'application/octet-stream'; + const dataUrl = `data:${mimeType};base64,${base64}`; + + return { + success: true, + image: { + path: absolutePath, + filename: path.basename(absolutePath), + extension: ext, + mimeType, + size: stats.size, + base64, + dataUrl, + }, + }; + } catch (error) { + if ((error as NodeJS.ErrnoException).code === 'ENOENT') { + return { + success: false, + error: `图片文件不存在: ${filePath}`, + }; + } + return { + success: false, + error: `加载图片失败: ${error instanceof Error ? error.message : String(error)}`, + }; + } +} + +/** + * 批量加载图片 + * @param filePaths 图片路径列表 + * @param workdir 工作目录 + */ +export async function loadImages( + filePaths: string[], + workdir: string = process.cwd() +): Promise<{ + images: ImageInfo[]; + errors: Array<{ path: string; error: string }>; +}> { + const images: ImageInfo[] = []; + const errors: Array<{ path: string; error: string }> = []; + + for (const filePath of filePaths) { + const result = await loadImage(filePath, workdir); + if (result.success && result.image) { + images.push(result.image); + } else { + errors.push({ path: filePath, error: result.error || '未知错误' }); + } + } + + return { images, errors }; +} + +/** + * 格式化文件大小 + */ +export function formatFileSize(bytes: number): string { + if (bytes < 1024) { + return `${bytes}B`; + } else if (bytes < 1024 * 1024) { + return `${(bytes / 1024).toFixed(1)}KB`; + } else { + return `${(bytes / (1024 * 1024)).toFixed(1)}MB`; + } +} diff --git a/src/utils/vision.ts b/src/utils/vision.ts new file mode 100644 index 0000000..b7049cb --- /dev/null +++ b/src/utils/vision.ts @@ -0,0 +1,217 @@ +import { loadVisionConfig, type VisionConfig } from './config.js'; + +/** + * Vision 服务 - 用于图片理解 + * 当主模型不支持 vision 时,使用独立的 Vision 服务分析图片 + * 使用原生 fetch 调用 OpenAI 兼容接口,以确保与百炼等服务兼容 + */ + +export interface ImageData { + /** base64 编码的图片数据 */ + data: string; + /** MIME 类型 */ + mimeType: string; + /** 文件名(可选) */ + filename?: string; +} + +export interface VisionAnalysisResult { + success: boolean; + /** 图片描述 */ + description: string; + /** 错误信息(如果失败) */ + error?: string; +} + +/** + * 分析单张图片 + */ +export async function analyzeImage( + image: ImageData, + prompt?: string +): Promise { + const config = loadVisionConfig(); + + if (!config) { + return { + success: false, + description: '', + error: '未配置 Vision 服务。请在配置文件中设置 visionProvider、visionApiKey 等参数。', + }; + } + + try { + const description = await callVisionAPI(config, [image], prompt); + return { + success: true, + description, + }; + } catch (error) { + return { + success: false, + description: '', + error: error instanceof Error ? error.message : String(error), + }; + } +} + +/** + * 批量分析图片 + */ +export async function analyzeImages( + images: ImageData[], + prompt?: string +): Promise { + const config = loadVisionConfig(); + + if (!config) { + return { + success: false, + description: '', + error: '未配置 Vision 服务。请在配置文件中设置 visionProvider、visionApiKey 等参数。', + }; + } + + if (images.length === 0) { + return { + success: false, + description: '', + error: '没有提供图片', + }; + } + + try { + const description = await callVisionAPI(config, images, prompt); + return { + success: true, + description, + }; + } catch (error) { + return { + success: false, + description: '', + error: error instanceof Error ? error.message : String(error), + }; + } +} + +/** + * 调用 Vision API + * 使用原生 fetch 调用 OpenAI 兼容接口,确保与百炼等服务兼容 + */ +async function callVisionAPI( + config: VisionConfig, + images: ImageData[], + userPrompt?: string +): Promise { + // 目前只支持 OpenAI 兼容的 Vision API(如百炼的 qwen-vl-plus) + if (config.provider !== 'openai') { + throw new Error(`暂不支持 ${config.provider} 的 Vision 服务`); + } + + // 构建消息内容(OpenAI Vision API 格式) + const content: Array< + | { type: 'text'; text: string } + | { type: 'image_url'; image_url: { url: string } } + > = []; + + // 添加图片(使用 data URL 格式) + for (const img of images) { + content.push({ + type: 'image_url', + image_url: { + url: `data:${img.mimeType};base64,${img.data}`, + }, + }); + } + + // 添加提示文本 + const defaultPrompt = images.length === 1 + ? '请详细描述这张图片的内容,包括主要元素、文字、颜色、布局等信息。' + : `请详细描述这 ${images.length} 张图片的内容,包括主要元素、文字、颜色、布局等信息。`; + + content.push({ + type: 'text', + text: userPrompt || defaultPrompt, + }); + + // 构建请求体 + const requestBody = { + model: config.model, + messages: [ + { + role: 'user', + content, + }, + ], + max_tokens: 2000, + }; + + // 确定 API 端点 + const baseUrl = config.baseUrl || 'https://api.openai.com/v1'; + const endpoint = `${baseUrl.replace(/\/$/, '')}/chat/completions`; + + // 发送请求 + const response = await fetch(endpoint, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${config.apiKey}`, + }, + body: JSON.stringify(requestBody), + }); + + if (!response.ok) { + const errorText = await response.text(); + let errorMessage = `API 请求失败: ${response.status} ${response.statusText}`; + try { + const errorJson = JSON.parse(errorText); + if (errorJson.error?.message) { + errorMessage = errorJson.error.message; + } + } catch { + if (errorText) { + errorMessage += ` - ${errorText}`; + } + } + throw new Error(errorMessage); + } + + const result = await response.json() as { + choices?: Array<{ + message?: { + content?: string; + }; + }>; + }; + + const text = result.choices?.[0]?.message?.content; + if (!text) { + throw new Error('API 返回了空响应'); + } + + return text; +} + +/** + * 检查 Vision 服务是否可用 + */ +export function isVisionAvailable(): boolean { + const config = loadVisionConfig(); + return config !== null; +} + +/** + * 获取 Vision 配置信息(用于显示) + */ +export function getVisionInfo(): { available: boolean; provider?: string; model?: string } { + const config = loadVisionConfig(); + if (!config) { + return { available: false }; + } + return { + available: true, + provider: config.provider, + model: config.model, + }; +} diff --git a/tests/unit/utils/image.test.ts b/tests/unit/utils/image.test.ts new file mode 100644 index 0000000..8155927 --- /dev/null +++ b/tests/unit/utils/image.test.ts @@ -0,0 +1,155 @@ +import { describe, it, expect } from 'vitest'; +import { + isImagePath, + extractImageReferences, + formatFileSize, + IMAGE_EXTENSIONS, +} from '../../../src/utils/image.js'; + +describe('Image Utils - 图片处理工具', () => { + describe('IMAGE_EXTENSIONS', () => { + it('包含常见图片扩展名', () => { + expect(IMAGE_EXTENSIONS).toContain('.png'); + expect(IMAGE_EXTENSIONS).toContain('.jpg'); + expect(IMAGE_EXTENSIONS).toContain('.jpeg'); + expect(IMAGE_EXTENSIONS).toContain('.gif'); + expect(IMAGE_EXTENSIONS).toContain('.webp'); + }); + }); + + describe('isImagePath - 判断是否为图片路径', () => { + it('识别 PNG 文件', () => { + expect(isImagePath('screenshot.png')).toBe(true); + expect(isImagePath('path/to/image.PNG')).toBe(true); + }); + + it('识别 JPG/JPEG 文件', () => { + expect(isImagePath('photo.jpg')).toBe(true); + expect(isImagePath('photo.jpeg')).toBe(true); + expect(isImagePath('photo.JPG')).toBe(true); + }); + + it('识别 GIF 文件', () => { + expect(isImagePath('animation.gif')).toBe(true); + }); + + it('识别 WebP 文件', () => { + expect(isImagePath('modern.webp')).toBe(true); + }); + + it('不识别非图片文件', () => { + expect(isImagePath('document.txt')).toBe(false); + expect(isImagePath('script.ts')).toBe(false); + expect(isImagePath('data.json')).toBe(false); + expect(isImagePath('readme.md')).toBe(false); + }); + + it('处理没有扩展名的文件', () => { + expect(isImagePath('noextension')).toBe(false); + expect(isImagePath('Makefile')).toBe(false); + }); + }); + + describe('extractImageReferences - 提取图片引用', () => { + it('提取单个 @ 引用', () => { + const result = extractImageReferences('请分析这张图片 @screenshot.png'); + expect(result.imagePaths).toEqual(['screenshot.png']); + expect(result.textContent).toBe('请分析这张图片'); + }); + + it('提取多个 @ 引用', () => { + const result = extractImageReferences('对比 @before.png 和 @after.jpg'); + expect(result.imagePaths).toEqual(['before.png', 'after.jpg']); + expect(result.textContent).toBe('对比 和'); + }); + + it('提取带路径的图片引用', () => { + const result = extractImageReferences('分析 @./images/test.png'); + expect(result.imagePaths).toEqual(['./images/test.png']); + }); + + it('提取绝对路径图片引用', () => { + const result = extractImageReferences('查看 @/tmp/screenshot.png'); + expect(result.imagePaths).toEqual(['/tmp/screenshot.png']); + }); + + it('提取带空格的绝对路径(自动匹配到扩展名)', () => { + const result = extractImageReferences('这张图片内容是什么?@/Users/xd/Adobe Express - file.png'); + expect(result.imagePaths).toEqual(['/Users/xd/Adobe Express - file.png']); + expect(result.textContent).toBe('这张图片内容是什么?'); + }); + + it('提取带引号的路径(支持空格)', () => { + const result = extractImageReferences('分析 @"./my images/test photo.png"'); + expect(result.imagePaths).toEqual(['./my images/test photo.png']); + expect(result.textContent).toBe('分析'); + }); + + it('提取带单引号的路径', () => { + const result = extractImageReferences("查看 @'./path with spaces/image.jpg'"); + expect(result.imagePaths).toEqual(['./path with spaces/image.jpg']); + expect(result.textContent).toBe('查看'); + }); + + it('忽略非图片的 @ 引用', () => { + const result = extractImageReferences('请查看 @readme.md 文件'); + expect(result.imagePaths).toEqual([]); + expect(result.textContent).toBe('请查看 @readme.md 文件'); + }); + + it('忽略邮箱地址', () => { + const result = extractImageReferences('联系 user@example.com 了解详情'); + expect(result.imagePaths).toEqual([]); + expect(result.textContent).toBe('联系 user@example.com 了解详情'); + }); + + it('混合图片和非图片引用', () => { + const result = extractImageReferences( + '查看 @screenshot.png 和 @config.json' + ); + expect(result.imagePaths).toEqual(['screenshot.png']); + expect(result.textContent).toBe('查看 和 @config.json'); + }); + + it('没有引用时返回原文本', () => { + const result = extractImageReferences('这是一段普通文本'); + expect(result.imagePaths).toEqual([]); + expect(result.textContent).toBe('这是一段普通文本'); + }); + + it('只有图片引用时文本内容为空', () => { + const result = extractImageReferences('@image.png'); + expect(result.imagePaths).toEqual(['image.png']); + expect(result.textContent).toBe(''); + }); + + it('处理多个空格', () => { + const result = extractImageReferences(' @test.png 描述 '); + expect(result.imagePaths).toEqual(['test.png']); + expect(result.textContent.trim()).toBe('描述'); + }); + }); + + describe('formatFileSize - 格式化文件大小', () => { + it('格式化字节', () => { + expect(formatFileSize(0)).toBe('0B'); + expect(formatFileSize(100)).toBe('100B'); + expect(formatFileSize(1023)).toBe('1023B'); + }); + + it('格式化 KB', () => { + expect(formatFileSize(1024)).toBe('1.0KB'); + expect(formatFileSize(1536)).toBe('1.5KB'); + expect(formatFileSize(10240)).toBe('10.0KB'); + }); + + it('格式化 MB', () => { + expect(formatFileSize(1024 * 1024)).toBe('1.0MB'); + expect(formatFileSize(1024 * 1024 * 2.5)).toBe('2.5MB'); + }); + + it('大文件以 MB 为单位', () => { + expect(formatFileSize(1024 * 1024 * 1024)).toBe('1024.0MB'); + }); + }); +});