feat: 添加 OpenAI 兼容 API 支持和独立 Vision 服务

- 添加 OpenAI AI SDK provider 支持 (@ai-sdk/openai)
- 支持 OpenAI 兼容服务的 baseUrl 配置(如阿里云百炼)
- 添加独立的 Vision 配置(visionProvider/visionApiKey/visionBaseUrl/visionModel)
- 实现图片引用语法 @path/to/image.png,支持带空格的路径
- 当主模型不支持 vision 时,自动调用配置的 Vision 服务分析图片
- 添加图片处理工具函数和单元测试
This commit is contained in:
2025-12-11 17:49:16 +08:00
parent a476a4240c
commit 32fdb244f0
11 changed files with 1096 additions and 42 deletions
+5
View File
@@ -4,3 +4,8 @@ ANTHROPIC_API_KEY=sk-ant-xxxxx
# 可选配置
AI_MODEL=claude-sonnet-4-20250514
AI_MAX_TOKENS=4096
# Vision 配置(用于图片理解,当主模型不支持 vision 时使用)
# 如果不配置,默认使用 Anthropic Claude
# VISION_PROVIDER=anthropic
# VISION_MODEL=claude-sonnet-4-20250514
+17
View File
@@ -11,6 +11,7 @@
"dependencies": {
"@ai-sdk/anthropic": "^2.0.54",
"@ai-sdk/deepseek": "^1.0.31",
"@ai-sdk/openai": "^2.0.80",
"@tavily/core": "^0.6.0",
"ai": "^5.0.108",
"chalk": "^5.3.0",
@@ -88,6 +89,22 @@
"zod": "^3.25.76 || ^4.1.8"
}
},
"node_modules/@ai-sdk/openai": {
"version": "2.0.80",
"resolved": "https://registry.npmjs.org/@ai-sdk/openai/-/openai-2.0.80.tgz",
"integrity": "sha512-tNHuraF11db+8xJEDBoU9E3vMcpnHFKRhnLQ3DQX2LnEzfPB9DksZ8rE+yVuDN1WRW9cm2OWAhgHFgVKs7ICuw==",
"license": "Apache-2.0",
"dependencies": {
"@ai-sdk/provider": "2.0.0",
"@ai-sdk/provider-utils": "3.0.18"
},
"engines": {
"node": ">=18"
},
"peerDependencies": {
"zod": "^3.25.76 || ^4.1.8"
}
},
"node_modules/@ai-sdk/provider": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-2.0.0.tgz",
+1
View File
@@ -28,6 +28,7 @@
"dependencies": {
"@ai-sdk/anthropic": "^2.0.54",
"@ai-sdk/deepseek": "^1.0.31",
"@ai-sdk/openai": "^2.0.80",
"@tavily/core": "^0.6.0",
"ai": "^5.0.108",
"chalk": "^5.3.0",
+17 -6
View File
@@ -1,5 +1,6 @@
import { createAnthropic } from '@ai-sdk/anthropic';
import { createDeepSeek } from '@ai-sdk/deepseek';
import { createOpenAI } from '@ai-sdk/openai';
import {
generateText,
streamText,
@@ -18,17 +19,27 @@ import type {
} from './types.js';
import { checkBashPermission } from './permission-merger.js';
// Provider options: the values handed to a provider factory when building a model client.
interface ProviderOptions {
// API key passed through to the provider client.
apiKey: string;
// Optional endpoint override; the factories forward it to the client as `baseURL`.
baseUrl?: string;
}
// Provider 工厂函数类型
type ProviderFactory = (apiKey: string) => (model: string) => LanguageModel;
type ProviderFactory = (options: ProviderOptions) => (model: string) => LanguageModel;
// Provider 注册表
const providers: Record<ProviderType, ProviderFactory> = {
anthropic: (apiKey) => {
const client = createAnthropic({ apiKey });
anthropic: ({ apiKey, baseUrl }) => {
const client = createAnthropic({ apiKey, baseURL: baseUrl });
return (model) => client(model);
},
deepseek: (apiKey) => {
const client = createDeepSeek({ apiKey });
deepseek: ({ apiKey, baseUrl }) => {
const client = createDeepSeek({ apiKey, baseURL: baseUrl });
return (model) => client(model);
},
openai: ({ apiKey, baseUrl }) => {
const client = createOpenAI({ apiKey, baseURL: baseUrl });
return (model) => client(model);
},
};
@@ -58,7 +69,7 @@ export class AgentExecutor {
if (!providerFactory) {
throw new Error(`不支持的 provider: ${provider}`);
}
this.getModel = providerFactory(baseConfig.apiKey);
this.getModel = providerFactory({ apiKey: baseConfig.apiKey, baseUrl: baseConfig.baseUrl });
}
/**
+90 -10
View File
@@ -1,5 +1,6 @@
import { createAnthropic } from '@ai-sdk/anthropic';
import { createDeepSeek } from '@ai-sdk/deepseek';
import { createOpenAI } from '@ai-sdk/openai';
import {
generateText,
streamText,
@@ -8,7 +9,7 @@ import {
type Tool as AITool,
type LanguageModel,
} from 'ai';
import type { Tool, ToolResult, Message, AgentConfig, ProviderType } from '../types/index.js';
import type { Tool, ToolResult, Message, AgentConfig, ProviderType, UserInput, ContentBlock } from '../types/index.js';
import { buildZodSchema } from '../types/index.js';
import { ToolRegistry } from '../tools/registry.js';
import { SessionManager } from '../session/index.js';
@@ -19,17 +20,27 @@ import {
} from '../context/index.js';
import type { AgentInfo } from '../agent/types.js';
// Provider options: the values handed to a provider factory when building a model client.
interface ProviderOptions {
// API key passed through to the provider client.
apiKey: string;
// Optional endpoint override; the factories forward it to the client as `baseURL`.
baseUrl?: string;
}
// Provider 工厂函数类型
type ProviderFactory = (apiKey: string) => (model: string) => LanguageModel;
type ProviderFactory = (options: ProviderOptions) => (model: string) => LanguageModel;
// Provider 注册表
const providers: Record<ProviderType, ProviderFactory> = {
anthropic: (apiKey) => {
const client = createAnthropic({ apiKey });
anthropic: ({ apiKey, baseUrl }) => {
const client = createAnthropic({ apiKey, baseURL: baseUrl });
return (model) => client(model);
},
deepseek: (apiKey) => {
const client = createDeepSeek({ apiKey });
deepseek: ({ apiKey, baseUrl }) => {
const client = createDeepSeek({ apiKey, baseURL: baseUrl });
return (model) => client(model);
},
openai: ({ apiKey, baseUrl }) => {
const client = createOpenAI({ apiKey, baseURL: baseUrl });
return (model) => client(model);
},
};
@@ -68,7 +79,7 @@ export class Agent {
if (!providerFactory) {
throw new Error(`不支持的 provider: ${config.provider}`);
}
this.getModel = providerFactory(config.apiKey);
this.getModel = providerFactory({ apiKey: config.apiKey, baseUrl: config.baseUrl });
// 初始化压缩管理器
this.compressionManager = new CompressionManager(compressionConfig);
@@ -221,13 +232,49 @@ export class Agent {
/**
* 发送消息并处理响应(流式)
* @param userMessage 用户消息文本或包含图片的 UserInput
* @param onStream 流式输出回调
*/
async chat(userMessage: string, onStream?: (text: string) => void): Promise<string> {
async chat(userMessage: string | UserInput, onStream?: (text: string) => void): Promise<string> {
// 构建消息内容
let messageContent: string | ContentBlock[];
if (typeof userMessage === 'string') {
// 纯文本消息
messageContent = userMessage;
} else {
// 带图片的消息
const blocks: ContentBlock[] = [];
// 添加图片
if (userMessage.images && userMessage.images.length > 0) {
for (const img of userMessage.images) {
blocks.push({
type: 'image',
image: img.data,
mimeType: img.mimeType,
});
}
}
// 添加文本
if (userMessage.text) {
blocks.push({
type: 'text',
text: userMessage.text,
});
}
messageContent = blocks.length === 1 && blocks[0].type === 'text'
? blocks[0].text
: blocks;
}
// 添加用户消息到历史
this.conversationHistory.push({
role: 'user',
content: userMessage,
});
content: messageContent,
} as ModelMessage);
const vercelTools = this.getVercelTools();
let fullResponse = '';
@@ -436,4 +483,37 @@ export class Agent {
getAgentModeName(): string {
return this.currentAgentMode?.name ?? 'default';
}
/**
 * Reports whether the currently configured model is treated as vision-capable
 * (able to receive image content directly).
 *
 * The decision is a heuristic based on `config.provider` and the lower-cased
 * `config.model` id:
 * - anthropic: Claude generation 3 or later
 * - openai: any model id containing `gpt-4`
 * - deepseek and anything else: not supported
 *
 * @returns true when images can be sent to the main model as-is
 */
supportsVision(): boolean {
  const model = this.config.model.toLowerCase();

  switch (this.config.provider) {
    case 'anthropic': {
      // Claude ids come in two shapes: `claude-3-5-haiku-…` (generation right
      // after the prefix) and `claude-sonnet-4-…` / `claude-opus-4-…` (family
      // name first). A plain `includes('claude-4')` misses the second shape,
      // so extract the generation number from either form instead.
      const generation = /claude-(?:(?:sonnet|opus|haiku)-)?(\d+)/.exec(model);
      return generation !== null && Number(generation[1]) >= 3;
    }
    case 'openai':
      // GPT-4o, GPT-4 Turbo, GPT-4 Vision, etc.
      return model.includes('gpt-4');
    case 'deepseek':
      // DeepSeek is treated as having no vision support.
      return false;
    default:
      return false;
  }
}
/**
 * Returns the agent's current configuration as a shallow copy, so reassigning
 * top-level fields on the result does not affect the agent's own config object.
 */
getConfig(): AgentConfig {
  const snapshot: AgentConfig = { ...this.config };
  return snapshot;
}
}
+34 -3
View File
@@ -1,9 +1,38 @@
import { z } from 'zod';
// 消息类型
// Content block types (multimodal support)
/** A plain-text part of a message, discriminated by `type: 'text'`. */
export interface TextContentBlock {
type: 'text';
text: string;
}
/** An image part of a message, discriminated by `type: 'image'`. */
export interface ImageContentBlock {
type: 'image';
/** Base64-encoded image data */
image: string;
/** MIME type */
mimeType?: string;
}
/** Any single part of a multimodal message: either text or an image. */
export type ContentBlock = TextContentBlock | ImageContentBlock;
// 消息类型(支持多模态)
export interface Message {
role: 'user' | 'assistant';
content: string;
/** 内容可以是纯文本或内容块数组 */
content: string | ContentBlock[];
}
// User input (optionally carrying images)
export interface UserInput {
/** Text content */
text: string;
/** Attached images (base64-encoded) */
images?: Array<{
// Base64-encoded image bytes.
data: string;
// MIME type of the image.
mimeType: string;
// Optional file name of the image.
filename?: string;
}>;
}
// 工具参数定义
@@ -37,7 +66,7 @@ export interface ToolCall {
}
// 支持的 Provider 类型
export type ProviderType = 'anthropic' | 'deepseek';
export type ProviderType = 'anthropic' | 'deepseek' | 'openai';
// Agent 配置
export interface AgentConfig {
@@ -46,6 +75,8 @@ export interface AgentConfig {
model: string;
maxTokens: number;
systemPrompt: string;
/** 自定义 API 基础 URL(用于兼容 OpenAI API 的第三方服务,如阿里云百炼) */
baseUrl?: string;
}
// 会话上下文
+146 -1
View File
@@ -7,6 +7,17 @@ import {
createCommandExecutor,
type CommandExecutionResult,
} from '../commands/index.js';
import {
extractImageReferences,
loadImages,
formatFileSize,
} from '../utils/image.js';
import {
analyzeImages,
isVisionAvailable,
getVisionInfo,
} from '../utils/vision.js';
import type { UserInput } from '../types/index.js';
export class TerminalUI {
private agent: Agent;
@@ -278,6 +289,118 @@ export class TerminalUI {
}
}
/**
 * Resolves `@path` image references found in a line of user input.
 *
 * @param input raw text typed by the user
 * @returns the input unchanged with `hasImages: false` when nothing is
 *   referenced; the remaining text plus the loaded images with
 *   `hasImages: true` when at least one image loads; `null` when images were
 *   referenced but none could be loaded (failures are printed to the console)
 */
private async processImageInput(
  input: string
): Promise<{ userInput: UserInput; hasImages: boolean } | null> {
  const references = extractImageReferences(input);

  // Nothing referenced: hand the raw input straight back.
  if (references.imagePaths.length === 0) {
    return { userInput: { text: input }, hasImages: false };
  }

  // Paths are resolved relative to the current working directory.
  console.log(chalk.gray(`\n正在加载 ${references.imagePaths.length} 张图片...`));
  const loaded = await loadImages(references.imagePaths, process.cwd());

  // Report every image that failed to load.
  loaded.errors.forEach((failure) => {
    console.log(chalk.red(`${failure.path}: ${failure.error}`));
  });

  // Abort this input when not a single image could be read.
  if (loaded.images.length === 0) {
    console.log(chalk.red('没有成功加载任何图片\n'));
    return null;
  }

  // List the images that were loaded, with their sizes.
  loaded.images.forEach((picture) => {
    const label = chalk.green(`${picture.filename}`);
    const sizeNote = chalk.gray(` (${formatFileSize(picture.size)})`);
    console.log(label + sizeNote);
  });
  console.log('');

  const attachments = loaded.images.map((picture) => ({
    data: picture.base64,
    mimeType: picture.mimeType,
    filename: picture.filename,
  }));

  return {
    userInput: { text: references.textContent, images: attachments },
    hasImages: true,
  };
}
/**
 * Fallback used when the user attached images but the main model is not
 * vision-capable.
 *
 * If a Vision service is available, the user is asked whether to have it
 * describe the images; on confirmation the description is combined with the
 * user's original text into a plain-text message.
 *
 * @param userInput the text and images the user submitted
 * @returns the combined text to send to the main model, or `null` when the
 *   Vision service is unavailable, the user cancels, there are no images, or
 *   the analysis fails
 */
private async handleNoVisionSupport(
  userInput: UserInput
): Promise<string | null> {
  // Without a usable Vision service there is nothing to fall back to.
  if (!isVisionAvailable()) {
    console.log(chalk.yellow('\n⚠ 当前模型不支持图片理解,且未配置 Vision 服务'));
    console.log(chalk.gray('请在配置文件中设置 visionProvider、visionApiKey 等参数'));
    console.log(chalk.gray('或切换到支持图片理解的模型(如 Claude、GPT-4o\n'));
    return null;
  }

  const visionInfo = getVisionInfo();

  // Let the user decide between the Vision fallback and cancelling.
  console.log(chalk.yellow('\n⚠ 当前模型不支持图片理解'));
  console.log(chalk.gray('请选择处理方式:'));
  console.log(chalk.white(` 1. 使用 Vision 服务 (${visionInfo.model}) 分析图片后继续对话`));
  console.log(chalk.white(' 2. 取消本次输入'));

  const answer = await new Promise<string>((resolve) => {
    this.rl.question(chalk.green('选择 (1/2): '), resolve);
  });

  // Anything other than "1" counts as a cancel.
  const confirmed = answer.trim() === '1';
  if (!confirmed) {
    console.log(chalk.gray('已取消\n'));
    return null;
  }

  console.log(chalk.cyan(`\n正在使用 ${visionInfo.model} 分析图片...`));

  const attachments = userInput.images ?? [];
  if (attachments.length === 0) {
    console.log(chalk.red('没有图片需要分析\n'));
    return null;
  }

  // An empty user text is passed as undefined so the Vision service falls
  // back to its own default prompt.
  const payload = attachments.map((attachment) => ({
    data: attachment.data,
    mimeType: attachment.mimeType,
    filename: attachment.filename,
  }));
  const analysis = await analyzeImages(payload, userInput.text || undefined);

  if (!analysis.success) {
    console.log(chalk.red(`\n图片分析失败: ${analysis.error}\n`));
    return null;
  }

  console.log(chalk.green('✓ 图片分析完成\n'));

  // Prefix the Vision description so the main model sees it together with
  // the user's original question.
  return `[Vision 服务分析结果]\n${analysis.description}\n\n用户原始问题: ${userInput.text}`;
}
// 提问并获取用户输入
private prompt(): Promise<string> {
return new Promise((resolve, reject) => {
@@ -378,13 +501,35 @@ export class TerminalUI {
continue;
}
// 处理图片引用
const processed = await this.processImageInput(input);
if (!processed) {
continue;
}
let { userInput, hasImages } = processed;
// 如果有图片且当前模型不支持 vision
if (hasImages && !this.agent.supportsVision()) {
const fallbackText = await this.handleNoVisionSupport(userInput);
if (!fallbackText) {
continue;
}
// 使用 Vision 分析结果替代图片
userInput = { text: fallbackText };
hasImages = false;
}
// 发送给 AI
process.stdout.write(chalk.gray('思考中...'));
try {
let isFirstChunk = true;
await this.agent.chat(input, (text) => {
// 根据是否有图片选择发送格式
const messageToSend = hasImages ? userInput : userInput.text;
await this.agent.chat(messageToSend, (text) => {
if (isFirstChunk) {
// 清除 "思考中..." 并显示 AI 前缀
process.stdout.write('\r' + ' '.repeat(20) + '\r');
+207 -16
View File
@@ -10,15 +10,42 @@ interface StoredConfig {
provider?: ProviderType;
apiKey?: string;
deepseekApiKey?: string;
openaiApiKey?: string;
model?: string;
maxTokens?: number;
tavilyApiKey?: string;
/** 自定义 API 基础 URL(用于 OpenAI 兼容服务,如阿里云百炼) */
baseUrl?: string;
// Vision 配置
visionProvider?: ProviderType;
visionModel?: string;
/** Vision 专用的 API Key(可选,不设置则使用对应 provider 的 key) */
visionApiKey?: string;
/** Vision 专用的 Base URL(用于 OpenAI 兼容的 Vision 服务) */
visionBaseUrl?: string;
}
// Vision configuration: the resolved settings for the standalone image-understanding service.
export interface VisionConfig {
// Provider that serves Vision requests.
provider: ProviderType;
// API key used for Vision requests.
apiKey: string;
// Model id used for Vision requests.
model: string;
/** Custom base URL (for OpenAI-compatible Vision services) */
baseUrl?: string;
}
// Default chat model per provider, used when neither the environment nor the config file names one.
const DEFAULT_MODELS: Record<ProviderType, string> = {
anthropic: 'claude-sonnet-4-20250514',
deepseek: 'deepseek-chat',
openai: 'gpt-4o',
};
// Default Vision model per provider (must be able to understand images).
const DEFAULT_VISION_MODELS: Record<ProviderType, string> = {
anthropic: 'claude-sonnet-4-20250514',
deepseek: 'deepseek-chat', // DeepSeek does not support vision yet; placeholder only
openai: 'gpt-4o',
};
// 默认系统提示词
@@ -60,8 +87,10 @@ export function loadConfig(): AgentConfig {
const provider = (process.env.AI_PROVIDER as ProviderType) || 'anthropic';
const anthropicApiKey = process.env.ANTHROPIC_API_KEY;
const deepseekApiKey = process.env.DEEPSEEK_API_KEY;
const openaiApiKey = process.env.OPENAI_API_KEY;
const model = process.env.AI_MODEL;
const maxTokens = parseInt(process.env.AI_MAX_TOKENS || '4096', 10);
const baseUrl = process.env.AI_BASE_URL;
// 从配置文件读取
let storedConfig: StoredConfig = {};
@@ -83,10 +112,17 @@ export function loadConfig(): AgentConfig {
finalApiKey = anthropicApiKey || storedConfig.apiKey;
} else if (finalProvider === 'deepseek') {
finalApiKey = deepseekApiKey || storedConfig.deepseekApiKey;
} else if (finalProvider === 'openai') {
finalApiKey = openaiApiKey || storedConfig.openaiApiKey;
}
if (!finalApiKey) {
const envVar = finalProvider === 'anthropic' ? 'ANTHROPIC_API_KEY' : 'DEEPSEEK_API_KEY';
const envVarMap: Record<ProviderType, string> = {
anthropic: 'ANTHROPIC_API_KEY',
deepseek: 'DEEPSEEK_API_KEY',
openai: 'OPENAI_API_KEY',
};
const envVar = envVarMap[finalProvider];
console.error(`❌ 错误: 未设置 ${envVar}`);
console.error(`请设置环境变量: export ${envVar}=your-api-key`);
console.error('或运行: ai-assist init 进行初始化配置');
@@ -96,12 +132,71 @@ export function loadConfig(): AgentConfig {
// 确定模型
const finalModel = model || storedConfig.model || DEFAULT_MODELS[finalProvider];
// 确定 baseUrl(环境变量优先)
const finalBaseUrl = baseUrl || storedConfig.baseUrl;
return {
provider: finalProvider,
apiKey: finalApiKey,
model: finalModel,
maxTokens: storedConfig.maxTokens || maxTokens,
systemPrompt: DEFAULT_SYSTEM_PROMPT,
baseUrl: finalBaseUrl,
};
}
/**
 * Resolves the configuration of the Vision service, which is used for image
 * understanding when the main model cannot handle images itself.
 *
 * Each setting is taken from the environment first, then from the stored
 * config file, then from a built-in default (the provider defaults to
 * `anthropic`). The API key prefers the Vision-specific key and otherwise
 * falls back to the key of the selected provider.
 *
 * @returns the resolved Vision configuration, or `null` when no API key can
 *   be found
 */
export function loadVisionConfig(): VisionConfig | null {
  const env = process.env;
  const storedConfig = getConfig();

  // Provider: VISION_PROVIDER > stored visionProvider > anthropic.
  const provider: ProviderType =
    (env.VISION_PROVIDER as ProviderType | undefined) ||
    storedConfig.visionProvider ||
    'anthropic';

  // Key: Vision-specific key first, then the selected provider's own key.
  let apiKey: string | undefined = env.VISION_API_KEY || storedConfig.visionApiKey;
  if (!apiKey) {
    switch (provider) {
      case 'anthropic':
        apiKey = env.ANTHROPIC_API_KEY || storedConfig.apiKey;
        break;
      case 'deepseek':
        apiKey = env.DEEPSEEK_API_KEY || storedConfig.deepseekApiKey;
        break;
      case 'openai':
        apiKey = env.OPENAI_API_KEY || storedConfig.openaiApiKey;
        break;
    }
  }

  // Without a key the Vision service cannot be used at all.
  if (!apiKey) {
    return null;
  }

  return {
    provider,
    apiKey,
    model: env.VISION_MODEL || storedConfig.visionModel || DEFAULT_VISION_MODELS[provider],
    // Vision-specific base URL; the main `baseUrl` is intentionally not reused.
    baseUrl: env.VISION_BASE_URL || storedConfig.visionBaseUrl,
  };
}
@@ -142,45 +237,120 @@ export async function initConfig(): Promise<void> {
message: '选择 AI 服务商:',
choices: [
{ name: 'Anthropic (Claude)', value: 'anthropic' },
{ name: 'OpenAI (GPT)', value: 'openai' },
{ name: 'OpenAI 兼容服务 (阿里云百炼、Azure 等)', value: 'openai-compatible' },
{ name: 'DeepSeek', value: 'deepseek' },
],
default: 'anthropic',
},
]);
// 是否是 OpenAI 兼容服务
const isOpenAICompatible = provider === 'openai-compatible';
const actualProvider = isOpenAICompatible ? 'openai' : provider;
// 如果是 OpenAI 兼容服务,询问 base URL
let baseUrl: string | undefined;
if (isOpenAICompatible) {
const { customBaseUrl } = await inquirer.prompt([
{
type: 'input',
name: 'customBaseUrl',
message: '请输入 API 基础 URL (如: https://dashscope.aliyuncs.com/compatible-mode/v1):',
validate: (input: string) => {
if (!input) return 'Base URL 不能为空';
try {
new URL(input);
return true;
} catch {
return '请输入有效的 URL';
}
},
},
]);
baseUrl = customBaseUrl;
}
// 根据 provider 显示不同的模型选项
const modelChoices =
provider === 'anthropic'
? [
let modelChoices: Array<{ name: string; value: string }>;
let allowCustomModel = false;
if (actualProvider === 'anthropic') {
modelChoices = [
{ name: 'Claude Sonnet 4 (推荐,平衡性能和成本)', value: 'claude-sonnet-4-20250514' },
{ name: 'Claude Opus 4 (最强,成本较高)', value: 'claude-opus-4-20250514' },
{ name: 'Claude 3.5 Haiku (快速,成本低)', value: 'claude-3-5-haiku-20241022' },
]
: [
];
} else if (actualProvider === 'openai') {
if (isOpenAICompatible) {
// OpenAI 兼容服务允许自定义模型名称
modelChoices = [
{ name: 'qwen-plus (通义千问)', value: 'qwen-plus' },
{ name: 'qwen-turbo (通义千问快速版)', value: 'qwen-turbo' },
{ name: 'qwen-max (通义千问最强版)', value: 'qwen-max' },
{ name: 'gpt-4o', value: 'gpt-4o' },
{ name: '自定义模型名称...', value: '__custom__' },
];
allowCustomModel = true;
} else {
modelChoices = [
{ name: 'GPT-4o (推荐,支持 vision)', value: 'gpt-4o' },
{ name: 'GPT-4o mini (快速,成本低)', value: 'gpt-4o-mini' },
{ name: 'GPT-4 Turbo', value: 'gpt-4-turbo' },
{ name: 'o1 (推理增强)', value: 'o1' },
{ name: 'o1-mini (推理,成本低)', value: 'o1-mini' },
];
}
} else {
modelChoices = [
{ name: 'DeepSeek Chat (推荐)', value: 'deepseek-chat' },
{ name: 'DeepSeek Reasoner (推理增强)', value: 'deepseek-reasoner' },
];
}
const apiKeyField = provider === 'anthropic' ? 'apiKey' : 'deepseekApiKey';
const apiKeyMessage =
provider === 'anthropic'
? '请输入你的 Anthropic API Key:'
: '请输入你的 DeepSeek API Key:';
const apiKeyMessageMap: Record<string, string> = {
anthropic: '请输入你的 Anthropic API Key:',
openai: '请输入你的 OpenAI API Key:',
deepseek: '请输入你的 DeepSeek API Key:',
};
const answers = await inquirer.prompt([
// 分开询问 API Key
const { apiKey } = await inquirer.prompt([
{
type: 'password',
name: apiKeyField,
message: apiKeyMessage,
name: 'apiKey',
message: isOpenAICompatible ? '请输入你的 API Key:' : apiKeyMessageMap[actualProvider],
validate: (input: string) => input.length > 0 || 'API Key 不能为空',
},
]);
// 询问模型配置
const { model: selectedModel } = await inquirer.prompt([
{
type: 'list',
name: 'model',
message: '选择默认模型:',
choices: modelChoices,
default: DEFAULT_MODELS[provider as ProviderType],
default: DEFAULT_MODELS[actualProvider as ProviderType],
},
]);
// 如果选择自定义模型,询问模型名称
let finalModel = selectedModel;
if (allowCustomModel && selectedModel === '__custom__') {
const { customModel } = await inquirer.prompt([
{
type: 'input',
name: 'customModel',
message: '请输入模型名称:',
validate: (input: string) => input.length > 0 || '模型名称不能为空',
},
]);
finalModel = customModel;
}
// 询问 token 配置
const { maxTokens } = await inquirer.prompt([
{
type: 'number',
name: 'maxTokens',
@@ -189,7 +359,28 @@ export async function initConfig(): Promise<void> {
},
]);
saveConfig({ provider, ...answers });
// 根据 provider 构建配置对象
const configToSave: Partial<StoredConfig> = {
provider: actualProvider as ProviderType,
model: finalModel,
maxTokens,
};
// 存储 API Key 到对应字段
if (actualProvider === 'anthropic') {
configToSave.apiKey = apiKey;
} else if (actualProvider === 'openai') {
configToSave.openaiApiKey = apiKey;
} else if (actualProvider === 'deepseek') {
configToSave.deepseekApiKey = apiKey;
}
// 存储 base URL
if (baseUrl) {
configToSave.baseUrl = baseUrl;
}
saveConfig(configToSave);
console.log('\n✅ 配置已保存到', CONFIG_FILE);
console.log('现在可以运行 ai-assist 开始使用了!\n');
}
+201
View File
@@ -0,0 +1,201 @@
/**
* 图片处理工具
*
* 提供图片文件读取、格式检测、base64 编码等功能
*/
import * as fs from 'fs/promises';
import * as path from 'path';
/** Supported image file extensions (lower-case, including the leading dot) */
export const IMAGE_EXTENSIONS = ['.png', '.jpg', '.jpeg', '.gif', '.webp'];
/** Mapping from image file extension to MIME type */
const MIME_TYPES: Record<string, string> = {
'.png': 'image/png',
'.jpg': 'image/jpeg',
'.jpeg': 'image/jpeg',
'.gif': 'image/gif',
'.webp': 'image/webp',
};
/** Information about a loaded image */
export interface ImageInfo {
/** File path (loadImage stores the resolved absolute path here) */
path: string;
/** File name */
filename: string;
/** File extension */
extension: string;
/** MIME type */
mimeType: string;
/** File size in bytes */
size: number;
/** Base64-encoded data */
base64: string;
/** Complete data URL */
dataUrl: string;
}
/** Result of loading a single image */
export interface ImageLoadResult {
// True when the image was read and encoded.
success: boolean;
// The loaded image; set by loadImage on success.
image?: ImageInfo;
// Human-readable failure reason; set by loadImage on failure.
error?: string;
}
/**
 * Tells whether a file path points at a supported image, judged purely by its
 * extension (case-insensitive). The file itself is not touched.
 */
export function isImagePath(filePath: string): boolean {
  const extension = path.extname(filePath).toLowerCase();
  return IMAGE_EXTENSIONS.some((supported) => supported === extension);
}
/**
 * Extracts image references from user input.
 *
 * Supported forms, tried in this order:
 * 1. `@"path/to/image with spaces.png"` or `@'...'` (quoted; may contain spaces)
 * 2. `@/path/to/image.png` or `@~/image.png` (starts with `/` or `~`; may
 *    contain spaces and runs up to an image extension followed by whitespace
 *    or end of input)
 * 3. `@path/to/image.png` (relative path without spaces)
 *
 * Only paths ending in png / jpg / jpeg / gif / webp are treated as images;
 * other `@` tokens stay in the text.
 *
 * Every reference is stripped from the text, including repeats of a path that
 * was already collected, and each distinct path is returned once, in order of
 * first appearance within each form.
 *
 * @param input raw user input
 * @returns the referenced image paths and the input with the references
 *   removed and whitespace collapsed
 */
export function extractImageReferences(input: string): {
  imagePaths: string[];
  textContent: string;
} {
  const imagePaths: string[] = [];
  let textContent = input;

  // Collect the path once, but always strip the reference text so a repeated
  // `@same.png` does not leak into the prompt sent to the model.
  const consume = (match: RegExpMatchArray): void => {
    const referencedPath = match[1];
    if (referencedPath === undefined) {
      return;
    }
    if (!imagePaths.includes(referencedPath)) {
      imagePaths.push(referencedPath);
    }
    textContent = textContent.replace(match[0], ' ');
  };

  // Each pass scans the text as it stands after the previous pass. The
  // matches are materialised first because `consume` rewrites `textContent`.

  // Form 1: quoted paths.
  const quotedMatches = [...textContent.matchAll(/@["']([^"']+\.(?:png|jpg|jpeg|gif|webp))["']/gi)];
  quotedMatches.forEach(consume);

  // Form 2: paths starting with `/` or `~`; spaces inside the path are allowed.
  const absoluteMatches = [...textContent.matchAll(/@([/~][^\n]*?\.(?:png|jpg|jpeg|gif|webp))(?=\s|$)/gi)];
  absoluteMatches.forEach(consume);

  // Form 3: relative paths (not starting with `/`, no spaces).
  const relativeMatches = [...textContent.matchAll(/@((?:\.\/|\.\.\/)?[^\s@"'/][^\s@"']*\.(?:png|jpg|jpeg|gif|webp))/gi)];
  relativeMatches.forEach(consume);

  // Collapse the gaps left behind by removed references.
  textContent = textContent.replace(/\s+/g, ' ').trim();

  return { imagePaths, textContent };
}
/**
 * Expands a leading `~` (alone, or followed by a path separator) to the
 * user's home directory taken from HOME / USERPROFILE. Any other path, or any
 * path when no home directory is known, is returned unchanged.
 */
function expandHomeDirectory(filePath: string): string {
  const home = process.env.HOME ?? process.env.USERPROFILE;
  if (!home) {
    return filePath;
  }
  if (filePath === '~') {
    return home;
  }
  if (filePath.startsWith('~/') || filePath.startsWith('~\\')) {
    return path.join(home, filePath.slice(2));
  }
  return filePath;
}

/**
 * Loads an image file and encodes it as base64.
 *
 * Failures are reported through the result object rather than thrown: the
 * whole body runs inside a try/catch.
 *
 * @param filePath image path; absolute, relative to `workdir`, or starting with `~`
 * @param workdir directory used to resolve relative paths (defaults to the
 *   current working directory)
 * @returns `{ success: true, image }` on success, otherwise
 *   `{ success: false, error }` with a human-readable reason
 */
export async function loadImage(
  filePath: string,
  workdir: string = process.cwd()
): Promise<ImageLoadResult> {
  try {
    // `@~/...` references are accepted by the extractor, so `~` has to be
    // expanded here; otherwise it would be resolved as a directory literally
    // named "~" under workdir.
    const expandedPath = expandHomeDirectory(filePath);
    const absolutePath = path.isAbsolute(expandedPath)
      ? expandedPath
      : path.resolve(workdir, expandedPath);

    // Reject unsupported formats before touching the disk.
    const ext = path.extname(absolutePath).toLowerCase();
    if (!IMAGE_EXTENSIONS.includes(ext)) {
      return {
        success: false,
        error: `不支持的图片格式: ${ext}。支持的格式: ${IMAGE_EXTENSIONS.join(', ')}`,
      };
    }

    // The buffer length is the size of what was actually read, so no separate
    // stat call is needed.
    const buffer = await fs.readFile(absolutePath);

    const base64 = buffer.toString('base64');
    const mimeType = MIME_TYPES[ext] || 'application/octet-stream';
    const dataUrl = `data:${mimeType};base64,${base64}`;

    return {
      success: true,
      image: {
        path: absolutePath,
        filename: path.basename(absolutePath),
        extension: ext,
        mimeType,
        size: buffer.length,
        base64,
        dataUrl,
      },
    };
  } catch (error) {
    // A missing file gets a dedicated message that echoes the path as typed.
    if ((error as NodeJS.ErrnoException).code === 'ENOENT') {
      return {
        success: false,
        error: `图片文件不存在: ${filePath}`,
      };
    }
    return {
      success: false,
      error: `加载图片失败: ${error instanceof Error ? error.message : String(error)}`,
    };
  }
}
/**
 * Loads several images at once.
 *
 * The files are read concurrently; the returned `images` and `errors` keep
 * the order of `filePaths`.
 *
 * @param filePaths image paths to load
 * @param workdir directory used to resolve relative paths (defaults to the
 *   current working directory)
 * @returns the images that loaded, plus one `{ path, error }` entry for every
 *   path that failed
 */
export async function loadImages(
  filePaths: string[],
  workdir: string = process.cwd()
): Promise<{
  images: ImageInfo[];
  errors: Array<{ path: string; error: string }>;
}> {
  // The loads are independent of each other, so start them all together.
  // Promise.all resolves in input order regardless of completion order.
  const outcomes = await Promise.all(
    filePaths.map(async (filePath) => ({
      filePath,
      result: await loadImage(filePath, workdir),
    }))
  );

  const images: ImageInfo[] = [];
  const errors: Array<{ path: string; error: string }> = [];

  for (const { filePath, result } of outcomes) {
    if (result.success && result.image) {
      images.push(result.image);
    } else {
      errors.push({ path: filePath, error: result.error || '未知错误' });
    }
  }

  return { images, errors };
}
/**
 * Renders a byte count for display: whole bytes below 1024 (`"512B"`),
 * otherwise one decimal place in KB or MB. MB is the largest unit, so sizes
 * of 1 GiB and above are still shown in MB.
 */
export function formatFileSize(bytes: number): string {
  const KIB = 1024;
  const MIB = KIB * KIB;

  if (bytes < KIB) {
    return `${bytes}B`;
  }

  const inKilobytes = bytes < MIB;
  const divisor = inKilobytes ? KIB : MIB;
  const unit = inKilobytes ? 'KB' : 'MB';
  return `${(bytes / divisor).toFixed(1)}${unit}`;
}
+217
View File
@@ -0,0 +1,217 @@
import { loadVisionConfig, type VisionConfig } from './config.js';
/**
* Vision 服务 - 用于图片理解
* 当主模型不支持 vision 时,使用独立的 Vision 服务分析图片
* 使用原生 fetch 调用 OpenAI 兼容接口,以确保与百炼等服务兼容
*/
/** An image handed to the Vision service. */
export interface ImageData {
/** Base64-encoded image data */
data: string;
/** MIME type */
mimeType: string;
/** File name (optional) */
filename?: string;
}
/** Outcome of a Vision analysis request. */
export interface VisionAnalysisResult {
// True when the Vision service returned a description.
success: boolean;
/** Image description (the analyze functions set it to an empty string on failure) */
description: string;
/** Error message (when the analysis failed) */
error?: string;
}
/**
 * Analyzes a single image with the configured Vision service.
 *
 * @param image the image to describe
 * @param prompt optional instruction sent along with the image
 * @returns a result object; failures (missing configuration or a request
 *   error) are reported via `success: false` and `error` instead of throwing
 */
export async function analyzeImage(
  image: ImageData,
  prompt?: string
): Promise<VisionAnalysisResult> {
  // Shared shape for every failure path.
  const failure = (error: string): VisionAnalysisResult => ({
    success: false,
    description: '',
    error,
  });

  const config = loadVisionConfig();
  if (!config) {
    return failure('未配置 Vision 服务。请在配置文件中设置 visionProvider、visionApiKey 等参数。');
  }

  try {
    const description = await callVisionAPI(config, [image], prompt);
    return { success: true, description };
  } catch (error) {
    return failure(error instanceof Error ? error.message : String(error));
  }
}
/**
 * Analyzes several images in one Vision request.
 *
 * @param images the images to describe; must not be empty
 * @param prompt optional instruction sent along with the images
 * @returns a result object; failures (missing configuration, no images, or a
 *   request error) are reported via `success: false` and `error` instead of
 *   throwing
 */
export async function analyzeImages(
  images: ImageData[],
  prompt?: string
): Promise<VisionAnalysisResult> {
  // Shared shape for every failure path.
  const failure = (error: string): VisionAnalysisResult => ({
    success: false,
    description: '',
    error,
  });

  // The configuration check comes first, so a missing config is reported even
  // when the image list is also empty.
  const config = loadVisionConfig();
  if (!config) {
    return failure('未配置 Vision 服务。请在配置文件中设置 visionProvider、visionApiKey 等参数。');
  }

  if (images.length === 0) {
    return failure('没有提供图片');
  }

  try {
    const description = await callVisionAPI(config, images, prompt);
    return { success: true, description };
  } catch (error) {
    return failure(error instanceof Error ? error.message : String(error));
  }
}
/**
 * Sends the images and a prompt to the Vision service and returns its text
 * answer.
 *
 * The request is a plain `fetch` POST to `<baseUrl>/chat/completions` in the
 * OpenAI chat-completions format, with each image embedded as a data URL.
 * Only the `openai` provider is handled.
 *
 * @param config resolved Vision configuration
 * @param images images to include in the request
 * @param userPrompt instruction text; a default "describe the image(s)" prompt
 *   is used when empty or omitted
 * @returns the content of the first choice in the response
 * @throws Error when the provider is not `openai`, when the HTTP response is
 *   not ok, or when the response contains no text
 */
async function callVisionAPI(
  config: VisionConfig,
  images: ImageData[],
  userPrompt?: string
): Promise<string> {
  // Only OpenAI-compatible Vision endpoints are implemented for now.
  if (config.provider !== 'openai') {
    throw new Error(`暂不支持 ${config.provider} 的 Vision 服务`);
  }

  type ContentPart =
    | { type: 'text'; text: string }
    | { type: 'image_url'; image_url: { url: string } };

  // Fall back to a generic description prompt when the caller gave none.
  let promptText = userPrompt;
  if (!promptText) {
    promptText = images.length === 1
      ? '请详细描述这张图片的内容,包括主要元素、文字、颜色、布局等信息。'
      : `请详细描述这 ${images.length} 张图片的内容,包括主要元素、文字、颜色、布局等信息。`;
  }

  // Images first (as data URLs), then the prompt text.
  const content: ContentPart[] = [
    ...images.map((img): ContentPart => ({
      type: 'image_url',
      image_url: {
        url: `data:${img.mimeType};base64,${img.data}`,
      },
    })),
    { type: 'text', text: promptText },
  ];

  const requestBody = {
    model: config.model,
    messages: [{ role: 'user', content }],
    max_tokens: 2000,
  };

  // A single trailing slash on the base URL is dropped before appending the route.
  const root = (config.baseUrl || 'https://api.openai.com/v1').replace(/\/$/, '');
  const endpoint = `${root}/chat/completions`;

  const response = await fetch(endpoint, {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
      'Authorization': `Bearer ${config.apiKey}`,
    },
    body: JSON.stringify(requestBody),
  });

  if (response.ok) {
    const payload = await response.json() as {
      choices?: Array<{
        message?: {
          content?: string;
        };
      }>;
    };

    const answer = payload.choices?.[0]?.message?.content;
    if (!answer) {
      throw new Error('API 返回了空响应');
    }
    return answer;
  }

  // Error path: prefer the service's own message when the body is JSON with
  // `error.message`; otherwise append the raw body to the status line.
  const rawBody = await response.text();
  let failureMessage = `API 请求失败: ${response.status} ${response.statusText}`;
  try {
    const parsed = JSON.parse(rawBody);
    if (parsed.error?.message) {
      failureMessage = parsed.error.message;
    }
  } catch {
    if (rawBody) {
      failureMessage += ` - ${rawBody}`;
    }
  }
  throw new Error(failureMessage);
}
/**
 * Checks whether the Vision service can actually be used.
 *
 * A configuration alone is not enough: `callVisionAPI` only implements the
 * OpenAI-compatible protocol and throws for every other provider. Reporting
 * those providers as available would offer the user a fallback that is
 * guaranteed to fail.
 *
 * @returns true when a Vision configuration exists and its provider is `openai`
 */
export function isVisionAvailable(): boolean {
  const config = loadVisionConfig();
  return config !== null && config.provider === 'openai';
}
/**
 * Summarises the Vision configuration for display.
 *
 * @returns `{ available: false }` when no Vision configuration can be
 *   resolved, otherwise `available: true` with the configured provider and model
 */
export function getVisionInfo(): { available: boolean; provider?: string; model?: string } {
  const config = loadVisionConfig();
  return config
    ? { available: true, provider: config.provider, model: config.model }
    : { available: false };
}
+155
View File
@@ -0,0 +1,155 @@
import { describe, it, expect } from 'vitest';
import {
isImagePath,
extractImageReferences,
formatFileSize,
IMAGE_EXTENSIONS,
} from '../../../src/utils/image.js';
// Unit tests for the pure helpers in utils/image: extension detection,
// `@path` reference extraction and file-size formatting.
describe('Image Utils - 图片处理工具', () => {
// The exported extension list covers the common image formats.
describe('IMAGE_EXTENSIONS', () => {
it('包含常见图片扩展名', () => {
expect(IMAGE_EXTENSIONS).toContain('.png');
expect(IMAGE_EXTENSIONS).toContain('.jpg');
expect(IMAGE_EXTENSIONS).toContain('.jpeg');
expect(IMAGE_EXTENSIONS).toContain('.gif');
expect(IMAGE_EXTENSIONS).toContain('.webp');
});
});
// isImagePath decides by extension only, case-insensitively.
describe('isImagePath - 判断是否为图片路径', () => {
it('识别 PNG 文件', () => {
expect(isImagePath('screenshot.png')).toBe(true);
expect(isImagePath('path/to/image.PNG')).toBe(true);
});
it('识别 JPG/JPEG 文件', () => {
expect(isImagePath('photo.jpg')).toBe(true);
expect(isImagePath('photo.jpeg')).toBe(true);
expect(isImagePath('photo.JPG')).toBe(true);
});
it('识别 GIF 文件', () => {
expect(isImagePath('animation.gif')).toBe(true);
});
it('识别 WebP 文件', () => {
expect(isImagePath('modern.webp')).toBe(true);
});
it('不识别非图片文件', () => {
expect(isImagePath('document.txt')).toBe(false);
expect(isImagePath('script.ts')).toBe(false);
expect(isImagePath('data.json')).toBe(false);
expect(isImagePath('readme.md')).toBe(false);
});
it('处理没有扩展名的文件', () => {
expect(isImagePath('noextension')).toBe(false);
expect(isImagePath('Makefile')).toBe(false);
});
});
// extractImageReferences: image references are pulled out of the text and
// removed from it; non-image `@` tokens and e-mail addresses are left alone.
describe('extractImageReferences - 提取图片引用', () => {
it('提取单个 @ 引用', () => {
const result = extractImageReferences('请分析这张图片 @screenshot.png');
expect(result.imagePaths).toEqual(['screenshot.png']);
expect(result.textContent).toBe('请分析这张图片');
});
it('提取多个 @ 引用', () => {
const result = extractImageReferences('对比 @before.png 和 @after.jpg');
expect(result.imagePaths).toEqual(['before.png', 'after.jpg']);
expect(result.textContent).toBe('对比 和');
});
it('提取带路径的图片引用', () => {
const result = extractImageReferences('分析 @./images/test.png');
expect(result.imagePaths).toEqual(['./images/test.png']);
});
it('提取绝对路径图片引用', () => {
const result = extractImageReferences('查看 @/tmp/screenshot.png');
expect(result.imagePaths).toEqual(['/tmp/screenshot.png']);
});
// Unquoted absolute paths may contain spaces; the match runs to the extension.
it('提取带空格的绝对路径(自动匹配到扩展名)', () => {
const result = extractImageReferences('这张图片内容是什么?@/Users/xd/Adobe Express - file.png');
expect(result.imagePaths).toEqual(['/Users/xd/Adobe Express - file.png']);
expect(result.textContent).toBe('这张图片内容是什么?');
});
it('提取带引号的路径(支持空格)', () => {
const result = extractImageReferences('分析 @"./my images/test photo.png"');
expect(result.imagePaths).toEqual(['./my images/test photo.png']);
expect(result.textContent).toBe('分析');
});
it('提取带单引号的路径', () => {
const result = extractImageReferences("查看 @'./path with spaces/image.jpg'");
expect(result.imagePaths).toEqual(['./path with spaces/image.jpg']);
expect(result.textContent).toBe('查看');
});
it('忽略非图片的 @ 引用', () => {
const result = extractImageReferences('请查看 @readme.md 文件');
expect(result.imagePaths).toEqual([]);
expect(result.textContent).toBe('请查看 @readme.md 文件');
});
it('忽略邮箱地址', () => {
const result = extractImageReferences('联系 user@example.com 了解详情');
expect(result.imagePaths).toEqual([]);
expect(result.textContent).toBe('联系 user@example.com 了解详情');
});
it('混合图片和非图片引用', () => {
const result = extractImageReferences(
'查看 @screenshot.png 和 @config.json'
);
expect(result.imagePaths).toEqual(['screenshot.png']);
expect(result.textContent).toBe('查看 和 @config.json');
});
it('没有引用时返回原文本', () => {
const result = extractImageReferences('这是一段普通文本');
expect(result.imagePaths).toEqual([]);
expect(result.textContent).toBe('这是一段普通文本');
});
it('只有图片引用时文本内容为空', () => {
const result = extractImageReferences('@image.png');
expect(result.imagePaths).toEqual(['image.png']);
expect(result.textContent).toBe('');
});
it('处理多个空格', () => {
const result = extractImageReferences(' @test.png 描述 ');
expect(result.imagePaths).toEqual(['test.png']);
expect(result.textContent.trim()).toBe('描述');
});
});
// formatFileSize: B below 1024, then KB, then MB with one decimal place.
describe('formatFileSize - 格式化文件大小', () => {
it('格式化字节', () => {
expect(formatFileSize(0)).toBe('0B');
expect(formatFileSize(100)).toBe('100B');
expect(formatFileSize(1023)).toBe('1023B');
});
it('格式化 KB', () => {
expect(formatFileSize(1024)).toBe('1.0KB');
expect(formatFileSize(1536)).toBe('1.5KB');
expect(formatFileSize(10240)).toBe('10.0KB');
});
it('格式化 MB', () => {
expect(formatFileSize(1024 * 1024)).toBe('1.0MB');
expect(formatFileSize(1024 * 1024 * 2.5)).toBe('2.5MB');
});
// MB is the largest unit, so 1 GiB is still reported in MB.
it('大文件以 MB 为单位', () => {
expect(formatFileSize(1024 * 1024 * 1024)).toBe('1024.0MB');
});
});
});