/**
 * Provider Registry
 *
 * Manages every registered provider (builtin + custom).
 */

import type { LanguageModel } from 'ai';
import type {
  ProviderInfo,
  ProviderConfig,
  CustomProviderDefinition,
  RegisteredProvider,
  ConnectionTestResult,
  ProviderFactory,
  ModelInfo,
  ProviderListItem,
  ProviderDetail,
} from './types.js';
import { builtinProviders, isBuiltinProvider } from './builtin/index.js';
import {
  loadProvidersConfig,
  saveProvidersConfig,
  resolveApiKey,
} from './config.js';
import {
  testOpenAICompatibleConnection,
  createOpenAICompatibleFactory,
  isValidProviderId,
  isValidUrl,
} from './utils.js';

/** Shape of the persisted user configuration, derived from its loader. */
type LoadedProvidersConfig = Awaited<ReturnType<typeof loadProvidersConfig>>;

/**
 * Provider Registry
 *
 * Holds the builtin providers (registered synchronously in the constructor)
 * plus the custom providers and per-provider configuration loaded by
 * {@link ProviderRegistry.init}. A shared instance is exported below as
 * `providerRegistry`.
 */
export class ProviderRegistry {
  /** Registered providers, keyed by provider ID. */
  private providers: Map<string, RegisteredProvider> = new Map();

  /** Per-provider configuration, keyed by provider ID. */
  private configs: Map<string, ProviderConfig> = new Map();

  /** Definitions of custom providers, keyed by provider ID (what gets persisted). */
  private customDefinitions: Map<string, CustomProviderDefinition> = new Map();

  /** Whether the user configuration has been loaded in addition to the builtins. */
  private fullyInitialized = false;

  constructor() {
    // Builtin providers need no async work, so register them right away.
    this.initBuiltinProviders();
  }

  /**
   * Registers every builtin provider synchronously.
   */
  private initBuiltinProviders(): void {
    for (const [id, builtin] of Object.entries(builtinProviders)) {
      this.providers.set(id, {
        info: builtin.info,
        factory: builtin.factory,
      });
    }
  }

  /**
   * Fully initializes the registry by loading the user configuration
   * (custom providers and provider configs). Does nothing if the registry is
   * already fully initialized.
   */
  async init(): Promise<void> {
    if (this.fullyInitialized) return;

    const config = await loadProvidersConfig();
    this.applyUserConfig(config);
    this.fullyInitialized = true;
  }

  /**
   * Applies a loaded user configuration to the in-memory state.
   *
   * Custom provider definitions whose ID belongs to a builtin provider are
   * skipped: `registerCustom` forbids that override, and accepting it here
   * would replace the builtin with a `builtin: false` entry that
   * `reloadConfig` later deletes without restoring the builtin.
   */
  private applyUserConfig(config: LoadedProvidersConfig): void {
    if (config.providers) {
      for (const definition of Object.values(config.providers)) {
        if (isBuiltinProvider(definition.id)) {
          console.warn(
            `Ignoring custom provider definition that shadows builtin provider: ${definition.id}`
          );
          continue;
        }
        this.registerCustomInternal(definition);
      }
    }

    if (config.configs) {
      for (const [id, providerConfig] of Object.entries(config.configs)) {
        this.configs.set(id, providerConfig);
      }
    }
  }

  /**
   * Reports whether the user configuration has been loaded.
   */
  isInitialized(): boolean {
    return this.fullyInitialized;
  }

  /**
   * Intentionally a no-op: builtin providers are registered in the
   * constructor and are therefore always available. User configuration still
   * requires an explicit `init()` call. Kept as a hook for future checks.
   */
  private ensureInitialized(): void {
    // Nothing to do — see the doc comment above.
  }

  /**
   * Lists the info of every registered provider.
   */
  list(): ProviderInfo[] {
    this.ensureInitialized();
    return Array.from(this.providers.values()).map((p) => p.info);
  }

  /**
   * Lists every registered provider in the API response format.
   */
  listForApi(): ProviderListItem[] {
    this.ensureInitialized();
    return Array.from(this.providers.entries()).map(([id, provider]) => {
      const config = this.configs.get(id);
      const apiKey = resolveApiKey(config);
      const customModels = config?.customModels ?? [];

      return {
        id,
        name: provider.info.name,
        description: provider.info.description,
        builtin: provider.info.builtin,
        // A provider without explicit config counts as enabled.
        enabled: config?.enabled ?? true,
        hasApiKey: !!apiKey,
        modelCount: provider.info.models.length + customModels.length,
      };
    });
  }

  /**
   * Returns the registered provider for `id`, or `undefined` if unknown.
   */
  get(id: string): RegisteredProvider | undefined {
    this.ensureInitialized();
    return this.providers.get(id);
  }

  /**
   * Returns the provider info for `id`, or `undefined` if unknown.
   */
  getInfo(id: string): ProviderInfo | undefined {
    return this.get(id)?.info;
  }

  /**
   * Returns the provider detail for `id` in the API response format, or
   * `undefined` if the provider is unknown. The top-level `baseUrl` is the
   * configured override when present, otherwise the provider's own base URL.
   */
  getDetail(id: string): ProviderDetail | undefined {
    this.ensureInitialized();
    const provider = this.providers.get(id);
    if (!provider) return undefined;

    const config = this.configs.get(id);
    const apiKey = resolveApiKey(config);

    return {
      id,
      name: provider.info.name,
      description: provider.info.description,
      builtin: provider.info.builtin,
      baseUrl: config?.baseUrl ?? provider.info.baseUrl,
      models: provider.info.models,
      allowCustomModels: provider.info.allowCustomModels ?? false,
      config: {
        enabled: config?.enabled ?? true,
        hasApiKey: !!apiKey,
        baseUrl: config?.baseUrl,
        customModels: config?.customModels ?? [],
      },
    };
  }

  /**
   * Reports whether a provider with the given ID is registered.
   */
  has(id: string): boolean {
    this.ensureInitialized();
    return this.providers.has(id);
  }

  /**
   * Registers a custom provider without validating it. The provider is
   * backed by an OpenAI-compatible factory bound to the definition's base URL.
   */
  private registerCustomInternal(definition: CustomProviderDefinition): void {
    const factory = createOpenAICompatibleFactory(definition.baseUrl);

    const info: ProviderInfo = {
      id: definition.id,
      name: definition.name,
      description: definition.description,
      builtin: false,
      baseUrl: definition.baseUrl,
      models: definition.models ?? [],
      // Custom providers accept user-defined models unless the definition opts out.
      allowCustomModels: definition.allowCustomModels ?? true,
    };

    this.providers.set(definition.id, { info, factory });
    this.customDefinitions.set(definition.id, definition);
  }

  /**
   * Validates and registers a custom provider.
   *
   * @throws Error if the ID is invalid, the ID belongs to a builtin provider,
   *   or the base URL is invalid.
   */
  registerCustom(definition: CustomProviderDefinition): void {
    this.ensureInitialized();

    if (!isValidProviderId(definition.id)) {
      throw new Error(`Invalid provider ID: ${definition.id}`);
    }

    // Builtin providers must never be overridden.
    if (isBuiltinProvider(definition.id)) {
      throw new Error(`Cannot override builtin provider: ${definition.id}`);
    }

    if (!isValidUrl(definition.baseUrl)) {
      throw new Error(`Invalid base URL: ${definition.baseUrl}`);
    }

    this.registerCustomInternal(definition);
  }

  /**
   * Removes a custom provider together with its definition and config.
   *
   * @returns `true` if a provider with that ID was registered.
   * @throws Error if the ID belongs to a builtin provider.
   */
  removeCustom(id: string): boolean {
    this.ensureInitialized();

    if (isBuiltinProvider(id)) {
      throw new Error(`Cannot remove builtin provider: ${id}`);
    }

    const removed = this.providers.delete(id);
    this.customDefinitions.delete(id);
    this.configs.delete(id);

    return removed;
  }

  /**
   * Stores the configuration for a provider. The stored `id` is always the
   * `id` argument, regardless of what `config.id` contains.
   *
   * @throws Error if the provider is not registered.
   */
  setConfig(id: string, config: ProviderConfig): void {
    this.ensureInitialized();

    if (!this.providers.has(id)) {
      throw new Error(`Provider not found: ${id}`);
    }

    this.configs.set(id, { ...config, id });
  }

  /**
   * Returns the stored configuration for a provider, if any.
   */
  getConfig(id: string): ProviderConfig | undefined {
    this.ensureInitialized();
    return this.configs.get(id);
  }

  /**
   * Returns all stored configurations as a plain object keyed by provider ID.
   */
  getAllConfigs(): Record<string, ProviderConfig> {
    this.ensureInitialized();
    return Object.fromEntries(this.configs);
  }

  /**
   * Returns the provider's own models followed by its configured custom
   * models. Returns an empty array for an unknown provider.
   */
  getModels(providerId: string): ModelInfo[] {
    this.ensureInitialized();
    const provider = this.providers.get(providerId);
    if (!provider) return [];

    const config = this.configs.get(providerId);
    const customModels = config?.customModels ?? [];

    return [...provider.info.models, ...customModels];
  }

  /**
   * Looks up a single model of a provider.
   *
   * @param providerId Provider ID
   * @param modelId Model ID
   * @returns The model info, or `undefined` when not found.
   */
  getModelInfo(providerId: string, modelId: string): ModelInfo | undefined {
    const models = this.getModels(providerId);
    return models.find((m) => m.id === modelId);
  }

  /**
   * Adds a custom model to a provider's configuration, creating the
   * configuration entry when none exists yet.
   *
   * @throws Error if the provider is not registered or a custom model with
   *   the same ID already exists.
   */
  addCustomModel(providerId: string, model: ModelInfo): void {
    this.ensureInitialized();

    if (!this.providers.has(providerId)) {
      throw new Error(`Provider not found: ${providerId}`);
    }

    const config = this.configs.get(providerId) ?? { id: providerId };
    const customModels = config.customModels ?? [];

    // Only the custom models are checked for duplicates.
    if (customModels.some((m) => m.id === model.id)) {
      throw new Error(`Model already exists: ${model.id}`);
    }

    customModels.push(model);
    this.configs.set(providerId, { ...config, customModels });
  }

  /**
   * Removes a custom model from a provider's configuration.
   *
   * @returns `true` if the model was found and removed, otherwise `false`.
   */
  removeCustomModel(providerId: string, modelId: string): boolean {
    this.ensureInitialized();

    const config = this.configs.get(providerId);
    if (!config?.customModels) return false;

    const index = config.customModels.findIndex((m) => m.id === modelId);
    if (index === -1) return false;

    config.customModels.splice(index, 1);
    this.configs.set(providerId, config);

    return true;
  }

  /**
   * Tests the connection to a provider.
   *
   * Custom providers with a base URL are probed through the OpenAI-compatible
   * connection test; builtin providers are only checked for the presence of
   * an API key.
   *
   * @param providerId Provider ID
   * @param apiKey Optional API key that takes precedence over the configured one.
   */
  async testConnection(
    providerId: string,
    apiKey?: string
  ): Promise<ConnectionTestResult> {
    this.ensureInitialized();

    const provider = this.providers.get(providerId);
    if (!provider) {
      return { success: false, error: `Provider not found: ${providerId}` };
    }

    const config = this.configs.get(providerId);
    const resolvedApiKey = apiKey ?? resolveApiKey(config);
    const baseUrl = config?.baseUrl ?? provider.info.baseUrl;

    // The localhost exemption must look at the URL that is actually used,
    // i.e. the configured override when there is one.
    const isLocal = baseUrl?.includes('localhost') ?? false;
    if (!resolvedApiKey && !isLocal) {
      return { success: false, error: 'API key not configured' };
    }

    // Custom providers: run the OpenAI-compatible connection test.
    if (!provider.info.builtin && baseUrl) {
      return testOpenAICompatibleConnection(resolvedApiKey ?? '', baseUrl);
    }

    // Builtin providers: only verify that an API key is present.
    return {
      success: !!resolvedApiKey,
      error: resolvedApiKey ? undefined : 'API key not configured',
    };
  }

  /**
   * Returns a function that creates language models for the given provider.
   *
   * `options.apiKey` / `options.baseUrl` take precedence over the stored
   * configuration, which in turn takes precedence over the provider defaults.
   *
   * @throws Error if the provider is not registered or no API key is available.
   */
  getModelFactory(
    providerId: string,
    options?: { apiKey?: string; baseUrl?: string }
  ): (model: string) => LanguageModel {
    this.ensureInitialized();

    const provider = this.providers.get(providerId);
    if (!provider) {
      throw new Error(`Provider not found: ${providerId}`);
    }

    const config = this.configs.get(providerId);
    const apiKey = options?.apiKey ?? resolveApiKey(config);
    const baseUrl =
      options?.baseUrl ?? config?.baseUrl ?? provider.info.baseUrl;

    // NOTE(review): unlike testConnection, there is no localhost exemption
    // here — confirm whether key-less local providers should be allowed.
    if (!apiKey) {
      throw new Error(`API key not configured for provider: ${providerId}`);
    }

    return provider.factory({ apiKey, baseUrl });
  }

  /**
   * Persists the custom provider definitions and provider configs.
   */
  async saveConfig(): Promise<void> {
    this.ensureInitialized();

    const config = {
      providers: Object.fromEntries(this.customDefinitions),
      configs: Object.fromEntries(this.configs),
    };

    await saveProvidersConfig(config);
  }

  /**
   * Reloads the user configuration, replacing all custom providers and
   * provider configs with what is currently persisted.
   *
   * The configuration is loaded before any state is discarded, so a failed
   * load rejects without wiping the providers and configs already in memory.
   */
  async reloadConfig(): Promise<void> {
    const config = await loadProvidersConfig();

    this.fullyInitialized = false;
    this.configs.clear();
    this.customDefinitions.clear();

    // Drop custom providers but keep the builtin ones.
    for (const [id, provider] of this.providers.entries()) {
      if (!provider.info.builtin) {
        this.providers.delete(id);
      }
    }

    this.applyUserConfig(config);
    this.fullyInitialized = true;
  }
}

/** Shared singleton instance. */
export const providerRegistry = new ProviderRegistry();

/** Returns the shared ProviderRegistry instance. */
export function getProviderRegistry(): ProviderRegistry {
  return providerRegistry;
}