Files
ai-terminal-assistant/packages/core/src/provider/registry.ts
T
kurihada 5f38753f6d refactor: 清理未使用的类型定义和接口字段
- 移除 Provider 相关的 apiKeyEnvVar 字段(未实现的功能)
- 清理 Server routes 中未使用的 Core 类型导入
- 清理 UI Message 接口中未使用的 metadata 字段
2025-12-30 10:41:38 +08:00

447 lines
11 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* Provider Registry
*
* 管理所有注册的提供商(内置 + 自定义)
* Manages all registered providers (builtin + custom).
*/
import type { LanguageModel } from 'ai';
import type {
ProviderInfo,
ProviderConfig,
CustomProviderDefinition,
RegisteredProvider,
ConnectionTestResult,
ProviderFactory,
ModelInfo,
ProviderListItem,
ProviderDetail,
} from './types.js';
import { builtinProviders, isBuiltinProvider } from './builtin/index.js';
import {
loadProvidersConfig,
saveProvidersConfig,
resolveApiKey,
} from './config.js';
import {
testOpenAICompatibleConnection,
createOpenAICompatibleFactory,
isValidProviderId,
isValidUrl,
} from './utils.js';
/**
 * Provider Registry
 *
 * Tracks every known provider: builtin providers (registered synchronously in
 * the constructor) plus custom OpenAI-compatible providers, which are loaded
 * from the user config by `init()` or added at runtime via `registerCustom()`.
 * The module exports one shared instance (`providerRegistry`).
 */
export class ProviderRegistry {
  /** Registered providers keyed by provider ID (builtin + custom). */
  private providers: Map<string, RegisteredProvider> = new Map();
  /** Per-provider user configuration keyed by provider ID. */
  private configs: Map<string, ProviderConfig> = new Map();
  /** Custom provider definitions, kept so `saveConfig()` can persist them. */
  private customDefinitions: Map<string, CustomProviderDefinition> = new Map();
  /** Whether `init()` has completed (i.e. the user config has been loaded). */
  private fullyInitialized = false;

  constructor() {
    // Builtin providers need no async loading, so register them eagerly.
    this.initBuiltinProviders();
  }

  /**
   * Registers every builtin provider synchronously.
   * Calling it again overwrites whatever entry currently uses a builtin ID.
   */
  private initBuiltinProviders(): void {
    for (const [id, builtin] of Object.entries(builtinProviders)) {
      this.providers.set(id, {
        info: builtin.info,
        factory: builtin.factory,
      });
    }
  }

  /**
   * Fully initializes the registry by loading the user configuration
   * (custom provider definitions and per-provider configs).
   * Subsequent calls are no-ops until `reloadConfig()` resets the flag.
   *
   * Custom definitions whose ID collides with a builtin provider are skipped
   * with a warning. `registerCustom()` rejects such IDs, and letting them
   * through would replace the builtin entry; `reloadConfig()` would later drop
   * that replacement because it is not marked builtin.
   */
  async init(): Promise<void> {
    if (this.fullyInitialized) return;

    // Load the user config (per the original note, from ~/.ai-terminal-assistant/providers.json).
    const config = await loadProvidersConfig();

    // Register custom providers.
    if (config.providers) {
      for (const definition of Object.values(config.providers)) {
        if (isBuiltinProvider(definition.id)) {
          console.warn(
            `Ignoring custom provider "${definition.id}": it conflicts with a builtin provider`
          );
          continue;
        }
        this.registerCustomInternal(definition);
      }
    }

    // Load per-provider configs.
    if (config.configs) {
      for (const [id, providerConfig] of Object.entries(config.configs)) {
        // Force the stored `id` to match the map key, as setConfig() does.
        this.configs.set(id, { ...providerConfig, id });
      }
    }

    this.fullyInitialized = true;
  }

  /**
   * Returns whether `init()` has finished loading the user configuration.
   */
  isInitialized(): boolean {
    return this.fullyInitialized;
  }

  /**
   * Intentional no-op hook invoked by the public accessors.
   * Builtin providers are always available because the constructor registers
   * them; the user config still requires an explicit `await init()`. It is kept
   * so a real readiness check can be added later without touching every caller.
   */
  private ensureInitialized(): void {
    // Nothing to check yet.
  }

  /**
   * Lists the info of every registered provider (builtin + custom).
   */
  list(): ProviderInfo[] {
    this.ensureInitialized();
    return Array.from(this.providers.values()).map((p) => p.info);
  }

  /**
   * Lists every provider in the API response shape.
   * `modelCount` includes both the provider's own models and any custom models
   * from its config. `enabled` defaults to true when no config exists.
   */
  listForApi(): ProviderListItem[] {
    this.ensureInitialized();
    return Array.from(this.providers.entries()).map(([id, provider]) => {
      const config = this.configs.get(id);
      const apiKey = resolveApiKey(config);
      const customModels = config?.customModels ?? [];
      return {
        id,
        name: provider.info.name,
        description: provider.info.description,
        builtin: provider.info.builtin,
        enabled: config?.enabled ?? true,
        hasApiKey: !!apiKey,
        modelCount: provider.info.models.length + customModels.length,
      };
    });
  }

  /**
   * Returns the registered provider with the given ID, if any.
   */
  get(id: string): RegisteredProvider | undefined {
    this.ensureInitialized();
    return this.providers.get(id);
  }

  /**
   * Returns the info of the provider with the given ID, if any.
   */
  getInfo(id: string): ProviderInfo | undefined {
    return this.get(id)?.info;
  }

  /**
   * Returns provider details in the API response shape.
   * The top-level `baseUrl` is the configured override, falling back to the
   * provider default. `models` lists only the provider's own models; custom
   * models are reported under `config.customModels`.
   */
  getDetail(id: string): ProviderDetail | undefined {
    this.ensureInitialized();
    const provider = this.providers.get(id);
    if (!provider) return undefined;
    const config = this.configs.get(id);
    const apiKey = resolveApiKey(config);
    return {
      id,
      name: provider.info.name,
      description: provider.info.description,
      builtin: provider.info.builtin,
      baseUrl: config?.baseUrl ?? provider.info.baseUrl,
      models: provider.info.models,
      allowCustomModels: provider.info.allowCustomModels ?? false,
      config: {
        enabled: config?.enabled ?? true,
        hasApiKey: !!apiKey,
        baseUrl: config?.baseUrl,
        customModels: config?.customModels ?? [],
      },
    };
  }

  /**
   * Returns whether a provider with the given ID is registered.
   */
  has(id: string): boolean {
    this.ensureInitialized();
    return this.providers.has(id);
  }

  /**
   * Registers a custom provider without validation.
   * The provider uses an OpenAI-compatible factory bound to `definition.baseUrl`.
   * `allowCustomModels` defaults to true.
   */
  private registerCustomInternal(definition: CustomProviderDefinition): void {
    const factory = createOpenAICompatibleFactory(definition.baseUrl);
    const info: ProviderInfo = {
      id: definition.id,
      name: definition.name,
      description: definition.description,
      builtin: false,
      baseUrl: definition.baseUrl,
      models: definition.models ?? [],
      allowCustomModels: definition.allowCustomModels ?? true,
    };
    this.providers.set(definition.id, { info, factory });
    this.customDefinitions.set(definition.id, definition);
  }

  /**
   * Registers a custom provider after validating it.
   *
   * @throws Error if the ID is invalid, collides with a builtin provider, or
   *   the base URL is invalid.
   */
  registerCustom(definition: CustomProviderDefinition): void {
    this.ensureInitialized();
    // Validate the ID.
    if (!isValidProviderId(definition.id)) {
      throw new Error(`Invalid provider ID: ${definition.id}`);
    }
    // Builtin providers cannot be overridden.
    if (isBuiltinProvider(definition.id)) {
      throw new Error(`Cannot override builtin provider: ${definition.id}`);
    }
    // Validate the URL.
    if (!isValidUrl(definition.baseUrl)) {
      throw new Error(`Invalid base URL: ${definition.baseUrl}`);
    }
    this.registerCustomInternal(definition);
  }

  /**
   * Removes a custom provider along with its definition and config.
   *
   * @returns true if a provider entry was removed
   * @throws Error if `id` is a builtin provider
   */
  removeCustom(id: string): boolean {
    this.ensureInitialized();
    // Builtin providers cannot be removed.
    if (isBuiltinProvider(id)) {
      throw new Error(`Cannot remove builtin provider: ${id}`);
    }
    const removed = this.providers.delete(id);
    this.customDefinitions.delete(id);
    this.configs.delete(id);
    return removed;
  }

  /**
   * Replaces the config of a registered provider.
   * A copy is stored with `id` forced to match the provider ID.
   *
   * @throws Error if the provider is not registered
   */
  setConfig(id: string, config: ProviderConfig): void {
    this.ensureInitialized();
    if (!this.providers.has(id)) {
      throw new Error(`Provider not found: ${id}`);
    }
    this.configs.set(id, { ...config, id });
  }

  /**
   * Returns the stored config of a provider, if any.
   */
  getConfig(id: string): ProviderConfig | undefined {
    this.ensureInitialized();
    return this.configs.get(id);
  }

  /**
   * Returns all stored configs as a plain object keyed by provider ID.
   */
  getAllConfigs(): Record<string, ProviderConfig> {
    this.ensureInitialized();
    return Object.fromEntries(this.configs);
  }

  /**
   * Returns the provider's own models followed by its configured custom models.
   * Returns an empty list for an unknown provider.
   */
  getModels(providerId: string): ModelInfo[] {
    this.ensureInitialized();
    const provider = this.providers.get(providerId);
    if (!provider) return [];
    const config = this.configs.get(providerId);
    const customModels = config?.customModels ?? [];
    return [...provider.info.models, ...customModels];
  }

  /**
   * Looks up a single model of a provider.
   * @param providerId Provider ID
   * @param modelId Model ID
   * @returns The model info, or undefined when not found
   */
  getModelInfo(providerId: string, modelId: string): ModelInfo | undefined {
    const models = this.getModels(providerId);
    return models.find((m) => m.id === modelId);
  }

  /**
   * Adds a user-defined model to a provider's config.
   *
   * @throws Error if the provider is not registered, or if a model with the
   *   same ID is already listed for it (either a provider model or a custom one).
   */
  addCustomModel(providerId: string, model: ModelInfo): void {
    this.ensureInitialized();
    if (!this.providers.has(providerId)) {
      throw new Error(`Provider not found: ${providerId}`);
    }
    // Check against the merged list so a custom entry cannot duplicate one of
    // the provider's own models; otherwise getModels() would list the ID twice.
    if (this.getModels(providerId).some((m) => m.id === model.id)) {
      throw new Error(`Model already exists: ${model.id}`);
    }
    const config: ProviderConfig = this.configs.get(providerId) ?? { id: providerId };
    // Build a new array instead of pushing into the existing one. The existing
    // array may be shared with a config object already returned by getConfig().
    const customModels = [...(config.customModels ?? []), model];
    this.configs.set(providerId, { ...config, customModels });
  }

  /**
   * Removes the first custom model with the given ID from a provider's config.
   *
   * @returns true if a model was removed, false if none matched
   */
  removeCustomModel(providerId: string, modelId: string): boolean {
    this.ensureInitialized();
    const config = this.configs.get(providerId);
    if (!config?.customModels) return false;
    const index = config.customModels.findIndex((m) => m.id === modelId);
    if (index === -1) return false;
    // Store a new config rather than splicing in place, so objects previously
    // returned by getConfig()/getAllConfigs() are not mutated.
    const customModels = config.customModels.filter((_, i) => i !== index);
    this.configs.set(providerId, { ...config, customModels });
    return true;
  }

  /**
   * Tests whether a provider is usable.
   *
   * Without an API key, the test fails unless the effective base URL points at
   * a local host. Custom providers with a base URL get a live OpenAI-compatible
   * check. Builtin providers only get a check that an API key is present.
   *
   * @param apiKey Optional key to test instead of the configured one
   */
  async testConnection(
    providerId: string,
    apiKey?: string
  ): Promise<ConnectionTestResult> {
    this.ensureInitialized();
    const provider = this.providers.get(providerId);
    if (!provider) {
      return { success: false, error: `Provider not found: ${providerId}` };
    }
    const config = this.configs.get(providerId);
    const resolvedApiKey = apiKey ?? resolveApiKey(config);
    // Resolve the effective endpoint first. A configured baseUrl override must
    // decide whether the "local server, no key needed" exception applies.
    const baseUrl = config?.baseUrl ?? provider.info.baseUrl;
    if (!resolvedApiKey && !ProviderRegistry.isLocalBaseUrl(baseUrl)) {
      return { success: false, error: 'API key not configured' };
    }
    // Custom providers: run the OpenAI-compatible connection test.
    if (!provider.info.builtin && baseUrl) {
      return testOpenAICompatibleConnection(resolvedApiKey ?? '', baseUrl);
    }
    // Builtin providers: only check that an API key is present.
    return {
      success: !!resolvedApiKey,
      error: resolvedApiKey ? undefined : 'API key not configured',
    };
  }

  /**
   * Returns true when `baseUrl` targets a loopback/local host.
   * The hostname is parsed so URLs like `https://notlocalhost.example` do not
   * match. If the value cannot be parsed as a URL, it falls back to a plain
   * substring check.
   */
  private static isLocalBaseUrl(baseUrl: string | undefined): boolean {
    if (!baseUrl) return false;
    try {
      const { hostname } = new URL(baseUrl);
      return (
        hostname === 'localhost' ||
        hostname.endsWith('.localhost') ||
        hostname === '127.0.0.1' ||
        hostname === '[::1]'
      );
    } catch {
      return baseUrl.includes('localhost');
    }
  }

  /**
   * Returns a model factory for a provider.
   * The API key comes from `options.apiKey` or else the config.
   * The base URL comes from `options.baseUrl`, then the config, then the
   * provider default.
   *
   * @throws Error if the provider is not registered or no API key resolves
   */
  getModelFactory(
    providerId: string,
    options?: { apiKey?: string; baseUrl?: string }
  ): (model: string) => LanguageModel {
    this.ensureInitialized();
    const provider = this.providers.get(providerId);
    if (!provider) {
      throw new Error(`Provider not found: ${providerId}`);
    }
    const config = this.configs.get(providerId);
    const apiKey = options?.apiKey ?? resolveApiKey(config);
    const baseUrl = options?.baseUrl ?? config?.baseUrl ?? provider.info.baseUrl;
    if (!apiKey) {
      throw new Error(`API key not configured for provider: ${providerId}`);
    }
    return provider.factory({ apiKey, baseUrl });
  }

  /**
   * Persists the custom provider definitions and all provider configs via
   * `saveProvidersConfig`.
   */
  async saveConfig(): Promise<void> {
    this.ensureInitialized();
    const config = {
      providers: Object.fromEntries(this.customDefinitions),
      configs: Object.fromEntries(this.configs),
    };
    await saveProvidersConfig(config);
  }

  /**
   * Discards all user state (configs and custom providers) and reloads it
   * with `init()`.
   */
  async reloadConfig(): Promise<void> {
    this.fullyInitialized = false;
    this.configs.clear();
    this.customDefinitions.clear();
    // Drop every provider and re-register the builtins. Deleting only the
    // non-builtin entries would not restore a builtin that a custom
    // definition had previously replaced.
    this.providers.clear();
    this.initBuiltinProviders();
    await this.init();
  }
}
/**
 * Shared registry instance for the whole process.
 * Builtin providers are usable immediately; call `await providerRegistry.init()`
 * to load the user configuration.
 */
export const providerRegistry: ProviderRegistry = new ProviderRegistry();

/**
 * Accessor for the shared registry.
 * @returns The same `providerRegistry` instance on every call.
 */
export function getProviderRegistry(): ProviderRegistry {
  const shared: ProviderRegistry = providerRegistry;
  return shared;
}