a6c1e792fa
Provider 和 Agent 配置统一从全局目录加载,无需传递 workdir: - ProviderRegistry.init() 不再需要 workdir 参数 - AgentRegistry.init() 不再需要 workdir 参数 - 配置文件路径统一使用 ~/.ai-terminal-assistant/
449 lines
11 KiB
TypeScript
449 lines
11 KiB
TypeScript
/**
 * Provider Registry
 *
 * Manages all registered providers (built-in + custom).
 */
|
||
|
||
import type { LanguageModel } from 'ai';
|
||
import type {
|
||
ProviderInfo,
|
||
ProviderConfig,
|
||
CustomProviderDefinition,
|
||
RegisteredProvider,
|
||
ConnectionTestResult,
|
||
ProviderFactory,
|
||
ModelInfo,
|
||
ProviderListItem,
|
||
ProviderDetail,
|
||
} from './types.js';
|
||
import { builtinProviders, isBuiltinProvider } from './builtin/index.js';
|
||
import {
|
||
loadProvidersConfig,
|
||
saveProvidersConfig,
|
||
resolveApiKey,
|
||
} from './config.js';
|
||
import {
|
||
testOpenAICompatibleConnection,
|
||
createOpenAICompatibleFactory,
|
||
isValidProviderId,
|
||
isValidUrl,
|
||
} from './utils.js';
|
||
|
||
/**
 * Provider Registry
 *
 * Holds every known provider: the built-in providers (registered
 * synchronously in the constructor) plus custom providers and per-provider
 * configs loaded from the user configuration by `init()`.
 */
export class ProviderRegistry {
  /** Registered providers, keyed by provider id. */
  private providers: Map<string, RegisteredProvider> = new Map();

  /** Per-provider configuration, keyed by provider id. */
  private configs: Map<string, ProviderConfig> = new Map();

  /** Custom provider definitions, keyed by provider id (persisted by `saveConfig`). */
  private customDefinitions: Map<string, CustomProviderDefinition> = new Map();

  /** Whether `init()` has finished loading the user configuration. */
  private fullyInitialized = false;

  constructor() {
    // Built-in providers need no async loading, so register them eagerly.
    this.initBuiltinProviders();
  }

  /**
   * Registers every built-in provider (synchronous).
   */
  private initBuiltinProviders(): void {
    for (const [id, builtin] of Object.entries(builtinProviders)) {
      this.providers.set(id, {
        info: builtin.info,
        factory: builtin.factory,
      });
    }
  }

  /**
   * Whether `baseUrl` points at the local machine, in which case an API key
   * is not required.
   *
   * @param baseUrl - Base URL to inspect; `undefined` is never local.
   * @returns `true` for `localhost`, `127.0.0.1` or `[::1]` hosts.
   */
  private static isLocalBaseUrl(baseUrl: string | undefined): boolean {
    if (!baseUrl) return false;
    try {
      const { hostname } = new URL(baseUrl);
      return hostname === 'localhost' || hostname === '127.0.0.1' || hostname === '[::1]';
    } catch {
      // Not a parseable absolute URL: fall back to the plain substring check.
      return baseUrl.includes('localhost');
    }
  }

  /**
   * Fully initializes the registry by loading the user configuration
   * (custom providers and provider configs). Subsequent calls are no-ops
   * until `reloadConfig()` resets the registry.
   */
  async init(): Promise<void> {
    if (this.fullyInitialized) return;

    // User configuration is read from ~/.ai-terminal-assistant/providers.json.
    const config = await loadProvidersConfig();

    // Load custom providers.
    if (config.providers) {
      for (const definition of Object.values(config.providers)) {
        // The file may have been edited outside this registry. A custom entry
        // must never replace a built-in provider (registerCustom() enforces the
        // same rule), otherwise the built-in is lost for good on the next
        // reloadConfig()/removeCustom() cycle.
        if (isBuiltinProvider(definition.id)) continue;
        this.registerCustomInternal(definition);
      }
    }

    // Load per-provider configs.
    if (config.configs) {
      for (const [id, providerConfig] of Object.entries(config.configs)) {
        this.configs.set(id, providerConfig);
      }
    }

    this.fullyInitialized = true;
  }

  /**
   * Whether `init()` has completed.
   */
  isInitialized(): boolean {
    return this.fullyInitialized;
  }

  /**
   * Placeholder readiness check. Built-in providers are registered in the
   * constructor, so they are always available; user configuration requires
   * an explicit `init()` call. Kept as a hook for future checks.
   */
  private ensureInitialized(): void {
    // Built-in providers are registered synchronously in the constructor.
    // User configuration must be loaded asynchronously via init().
  }

  /**
   * Lists the info of every registered provider.
   */
  list(): ProviderInfo[] {
    this.ensureInitialized();
    return Array.from(this.providers.values()).map((p) => p.info);
  }

  /**
   * Lists every provider in the API response format.
   */
  listForApi(): ProviderListItem[] {
    this.ensureInitialized();
    return Array.from(this.providers.entries()).map(([id, provider]) => {
      const config = this.configs.get(id);
      const apiKey = resolveApiKey(config);
      const customModels = config?.customModels ?? [];

      return {
        id,
        name: provider.info.name,
        description: provider.info.description,
        builtin: provider.info.builtin,
        enabled: config?.enabled ?? true,
        hasApiKey: !!apiKey,
        modelCount: provider.info.models.length + customModels.length,
      };
    });
  }

  /**
   * Returns the registered provider with the given id, if any.
   */
  get(id: string): RegisteredProvider | undefined {
    this.ensureInitialized();
    return this.providers.get(id);
  }

  /**
   * Returns the info of the provider with the given id, if any.
   */
  getInfo(id: string): ProviderInfo | undefined {
    return this.get(id)?.info;
  }

  /**
   * Returns provider details in the API response format, or `undefined`
   * when the provider is not registered.
   */
  getDetail(id: string): ProviderDetail | undefined {
    this.ensureInitialized();
    const provider = this.providers.get(id);
    if (!provider) return undefined;

    const config = this.configs.get(id);
    const apiKey = resolveApiKey(config);

    return {
      id,
      name: provider.info.name,
      description: provider.info.description,
      builtin: provider.info.builtin,
      // A configured override takes precedence over the provider default.
      baseUrl: config?.baseUrl ?? provider.info.baseUrl,
      apiKeyEnvVar: provider.info.apiKeyEnvVar,
      models: provider.info.models,
      allowCustomModels: provider.info.allowCustomModels ?? false,
      config: {
        enabled: config?.enabled ?? true,
        hasApiKey: !!apiKey,
        baseUrl: config?.baseUrl,
        customModels: config?.customModels ?? [],
      },
    };
  }

  /**
   * Whether a provider with the given id is registered.
   */
  has(id: string): boolean {
    this.ensureInitialized();
    return this.providers.has(id);
  }

  /**
   * Registers a custom provider without validation (callers validate).
   * Custom providers always use the OpenAI-compatible factory.
   */
  private registerCustomInternal(definition: CustomProviderDefinition): void {
    const factory = createOpenAICompatibleFactory(definition.baseUrl);

    const info: ProviderInfo = {
      id: definition.id,
      name: definition.name,
      description: definition.description,
      builtin: false,
      baseUrl: definition.baseUrl,
      apiKeyEnvVar: definition.apiKeyEnvVar,
      models: definition.models ?? [],
      allowCustomModels: definition.allowCustomModels ?? true,
    };

    this.providers.set(definition.id, { info, factory });
    this.customDefinitions.set(definition.id, definition);
  }

  /**
   * Registers (or replaces) a custom provider.
   *
   * @throws Error if the id is invalid, collides with a built-in provider,
   *   or the base URL is invalid.
   */
  registerCustom(definition: CustomProviderDefinition): void {
    this.ensureInitialized();

    if (!isValidProviderId(definition.id)) {
      throw new Error(`Invalid provider ID: ${definition.id}`);
    }

    // Built-in providers cannot be overridden.
    if (isBuiltinProvider(definition.id)) {
      throw new Error(`Cannot override builtin provider: ${definition.id}`);
    }

    if (!isValidUrl(definition.baseUrl)) {
      throw new Error(`Invalid base URL: ${definition.baseUrl}`);
    }

    this.registerCustomInternal(definition);
  }

  /**
   * Removes a custom provider together with its definition and config.
   *
   * @returns `true` if a provider was removed.
   * @throws Error if `id` is a built-in provider.
   */
  removeCustom(id: string): boolean {
    this.ensureInitialized();

    if (isBuiltinProvider(id)) {
      throw new Error(`Cannot remove builtin provider: ${id}`);
    }

    const removed = this.providers.delete(id);
    this.customDefinitions.delete(id);
    this.configs.delete(id);

    return removed;
  }

  /**
   * Sets the config of a registered provider; the stored `id` is forced to
   * match `id`.
   *
   * @throws Error if the provider is not registered.
   */
  setConfig(id: string, config: ProviderConfig): void {
    this.ensureInitialized();

    if (!this.providers.has(id)) {
      throw new Error(`Provider not found: ${id}`);
    }

    this.configs.set(id, { ...config, id });
  }

  /**
   * Returns the config of a provider, if any.
   */
  getConfig(id: string): ProviderConfig | undefined {
    this.ensureInitialized();
    return this.configs.get(id);
  }

  /**
   * Returns all provider configs keyed by provider id.
   */
  getAllConfigs(): Record<string, ProviderConfig> {
    this.ensureInitialized();
    return Object.fromEntries(this.configs);
  }

  /**
   * Returns the provider's own models followed by its configured custom
   * models; an empty list for an unknown provider.
   */
  getModels(providerId: string): ModelInfo[] {
    this.ensureInitialized();

    const provider = this.providers.get(providerId);
    if (!provider) return [];

    const config = this.configs.get(providerId);
    const customModels = config?.customModels ?? [];

    return [...provider.info.models, ...customModels];
  }

  /**
   * Returns details of a specific model.
   *
   * @param providerId Provider id.
   * @param modelId Model id.
   * @returns The first matching model info, or `undefined` if not found.
   */
  getModelInfo(providerId: string, modelId: string): ModelInfo | undefined {
    const models = this.getModels(providerId);
    return models.find((m) => m.id === modelId);
  }

  /**
   * Adds a custom model to a provider's config.
   *
   * @throws Error if the provider is not registered, or a model with the same
   *   id already exists (built-in or custom).
   */
  addCustomModel(providerId: string, model: ModelInfo): void {
    this.ensureInitialized();

    const provider = this.providers.get(providerId);
    if (!provider) {
      throw new Error(`Provider not found: ${providerId}`);
    }

    const config = this.configs.get(providerId) ?? { id: providerId };
    const customModels = config.customModels ?? [];

    // Also reject ids of the provider's own models: getModelInfo() returns the
    // first match, so such a custom model would be silently shadowed.
    const exists =
      provider.info.models.some((m) => m.id === model.id) ||
      customModels.some((m) => m.id === model.id);
    if (exists) {
      throw new Error(`Model already exists: ${model.id}`);
    }

    // Build a new array rather than pushing: the stored array can be shared
    // with objects passed to setConfig() or returned by getConfig().
    this.configs.set(providerId, { ...config, customModels: [...customModels, model] });
  }

  /**
   * Removes a custom model from a provider's config.
   *
   * @returns `true` if a model was removed.
   */
  removeCustomModel(providerId: string, modelId: string): boolean {
    this.ensureInitialized();

    const config = this.configs.get(providerId);
    const customModels = config?.customModels;
    if (!config || !customModels) return false;

    const remaining = customModels.filter((m) => m.id !== modelId);
    if (remaining.length === customModels.length) return false;

    // Store a new config object instead of splicing the shared array in place.
    this.configs.set(providerId, { ...config, customModels: remaining });
    return true;
  }

  /**
   * Tests connectivity for a provider.
   *
   * Custom providers with a base URL get a real OpenAI-compatible request;
   * built-in providers only check that an API key is present. Local
   * endpoints (localhost / loopback) may be tested without an API key.
   *
   * @param apiKey Optional key overriding the configured one.
   */
  async testConnection(
    providerId: string,
    apiKey?: string
  ): Promise<ConnectionTestResult> {
    this.ensureInitialized();

    const provider = this.providers.get(providerId);
    if (!provider) {
      return { success: false, error: `Provider not found: ${providerId}` };
    }

    const config = this.configs.get(providerId);
    const resolvedApiKey = apiKey ?? resolveApiKey(config);
    // Resolve the effective URL first so a configured override is honoured
    // when deciding whether the endpoint is local.
    const baseUrl = config?.baseUrl ?? provider.info.baseUrl;

    if (!resolvedApiKey && !ProviderRegistry.isLocalBaseUrl(baseUrl)) {
      return { success: false, error: 'API key not configured' };
    }

    // Custom providers: issue a real OpenAI-compatible request.
    if (!provider.info.builtin && baseUrl) {
      return testOpenAICompatibleConnection(resolvedApiKey ?? '', baseUrl);
    }

    // Built-in providers: only verify that an API key is present.
    return {
      success: !!resolvedApiKey,
      error: resolvedApiKey ? undefined : 'API key not configured',
    };
  }

  /**
   * Returns a model factory for the provider.
   *
   * @param options Optional API key / base URL overriding the stored config.
   * @throws Error if the provider is not registered, or no API key is
   *   available for a non-local endpoint.
   */
  getModelFactory(
    providerId: string,
    options?: { apiKey?: string; baseUrl?: string }
  ): (model: string) => LanguageModel {
    this.ensureInitialized();

    const provider = this.providers.get(providerId);
    if (!provider) {
      throw new Error(`Provider not found: ${providerId}`);
    }

    const config = this.configs.get(providerId);
    const apiKey = options?.apiKey ?? resolveApiKey(config);
    const baseUrl = options?.baseUrl ?? config?.baseUrl ?? provider.info.baseUrl;

    // Same rule as testConnection(): local endpoints may run without a key,
    // otherwise a provider that tests OK could never be used.
    if (!apiKey && !ProviderRegistry.isLocalBaseUrl(baseUrl)) {
      throw new Error(`API key not configured for provider: ${providerId}`);
    }

    return provider.factory({ apiKey: apiKey || '', baseUrl });
  }

  /**
   * Persists custom provider definitions and provider configs.
   */
  async saveConfig(): Promise<void> {
    this.ensureInitialized();

    const config = {
      providers: Object.fromEntries(this.customDefinitions),
      configs: Object.fromEntries(this.configs),
    };

    await saveProvidersConfig(config);
  }

  /**
   * Discards all user configuration and custom providers, then reloads
   * them via `init()`.
   */
  async reloadConfig(): Promise<void> {
    this.fullyInitialized = false;
    this.configs.clear();
    this.customDefinitions.clear();

    // Reset to exactly the built-in set (instead of deleting non-builtin
    // entries) so any built-in entry that was replaced is restored as well.
    this.providers.clear();
    this.initBuiltinProviders();

    await this.init();
  }
}
|
||
|
||
/**
 * Shared registry instance. Built-in providers are registered by the
 * constructor; call `init()` on it to load the user configuration.
 */
export const providerRegistry: ProviderRegistry = new ProviderRegistry();
|
||
|
||
/**
 * Accessor for the shared provider registry.
 *
 * @returns The module-level `providerRegistry` instance.
 */
export function getProviderRegistry(): ProviderRegistry {
  const registry = providerRegistry;
  return registry;
}
|