feat(context): 优化对话压缩系统

- 添加独立摘要模型配置支持(SUMMARY_PROVIDER/MODEL/API_KEY/BASE_URL)
- 添加 CompressionStatus 枚举和 DetailedCompressionResult 详细返回类型
- 实现压缩失败检测(空摘要、token膨胀)
- 添加首条 user-assistant 对保护,确保上下文连贯性
- CompressionManager 支持独立摘要模型(优先使用小模型降低成本)
- Agent 自动压缩时显示详细状态信息
- 更新相关测试用例
This commit is contained in:
2025-12-13 11:13:20 +08:00
parent 9ff2934089
commit f54f24b079
10 changed files with 495 additions and 102 deletions
+180 -27
View File
@@ -2,7 +2,9 @@ import { generateText, type ModelMessage, type LanguageModel } from 'ai';
import { TokenCounter } from './token-counter.js'; import { TokenCounter } from './token-counter.js';
import { import {
SUMMARY_MARKER, SUMMARY_MARKER,
CompressionStatus,
type CompressionConfig, type CompressionConfig,
type DetailedCompressionResult,
DEFAULT_COMPRESSION_CONFIG, DEFAULT_COMPRESSION_CONFIG,
} from './types.js'; } from './types.js';
@@ -58,6 +60,90 @@ function createSummaryMessage(summary: string): ModelMessage {
}; };
} }
/**
 * Decides whether a generated summary can replace the messages it condenses.
 *
 * The summary is rejected when it is empty or whitespace-only, or when its
 * estimated token count is not strictly smaller than `originalTokens`.
 *
 * @param summary Summary text produced by the model.
 * @param originalTokens Estimated token count of the messages being replaced.
 * @returns The validity flag, the matching status, and the estimated token
 *   count of the summary text (0 for an empty summary).
 */
function validateSummary(
  summary: string,
  originalTokens: number
): { valid: boolean; status: CompressionStatus; summaryTokens: number } {
  // Guard: nothing usable came back from the model.
  const isBlank = !summary || summary.trim().length === 0;
  if (isBlank) {
    return { valid: false, status: CompressionStatus.FAILED_EMPTY_SUMMARY, summaryTokens: 0 };
  }

  // A summary that is as large as (or larger than) its source defeats the purpose.
  const summaryTokens = TokenCounter.estimateText(summary);
  const inflated = summaryTokens >= originalTokens;

  return {
    valid: !inflated,
    status: inflated ? CompressionStatus.FAILED_TOKEN_INFLATED : CompressionStatus.SUCCESS,
    summaryTokens,
  };
}
/**
 * Splits a conversation after its opening exchange.
 *
 * `firstPair` holds everything from the start of `messages` up to and
 * including the first assistant message that appears after the first user
 * message (any messages preceding that user message are included as well).
 * When no assistant message follows, the split happens right after the first
 * user message instead.
 *
 * @param messages Conversation messages in chronological order.
 * @returns The leading slice and the remainder, or `null` when there is no
 *   user message at all.
 */
function findFirstUserAssistantPair(messages: ModelMessage[]): {
  firstPair: ModelMessage[];
  rest: ModelMessage[];
} | null {
  const firstUser = messages.findIndex((message) => message.role === 'user');
  if (firstUser < 0) {
    return null;
  }

  // First assistant message located anywhere after that user message.
  const firstReply = messages.findIndex(
    (message, position) => position > firstUser && message.role === 'assistant'
  );

  // Without a reply, only the part up to the user message is kept together.
  const splitAt = (firstReply < 0 ? firstUser : firstReply) + 1;

  return {
    firstPair: messages.slice(0, splitAt),
    rest: messages.slice(splitAt),
  };
}
/**
 * Options accepted by the compaction functions.
 */
export interface CompactOptions {
/**
 * Whether the first user-assistant pair is kept out of the compacted range.
 * `compact` and `simpleCompact` treat an omitted value as `true`.
 */
protectFirstPair?: boolean;
}
/** /**
* Compaction 策略:使用 AI 生成对话摘要 * Compaction 策略:使用 AI 生成对话摘要
* *
@@ -70,21 +156,36 @@ function createSummaryMessage(summary: string): ModelMessage {
* @param messages 消息数组 * @param messages 消息数组
* @param model 语言模型 * @param model 语言模型
* @param config 压缩配置 * @param config 压缩配置
* @returns 压缩后的消息数组和释放的 tokens * @param options 压缩选项
* @returns 详细压缩结果
*/ */
export async function compact( export async function compact(
messages: ModelMessage[], messages: ModelMessage[],
model: LanguageModel, model: LanguageModel,
config: CompressionConfig = DEFAULT_COMPRESSION_CONFIG config: CompressionConfig = DEFAULT_COMPRESSION_CONFIG,
): Promise<{ messages: ModelMessage[]; freedTokens: number }> { options: CompactOptions = {}
): Promise<DetailedCompressionResult> {
const { pruneProtect } = config; const { pruneProtect } = config;
const { protectFirstPair = true } = options;
// 计算需要保护的消息数量 // 首条消息保护:分离首个 user-assistant 对
let protectedFirst: ModelMessage[] = [];
let compressibleMessages = messages;
if (protectFirstPair) {
const firstPairResult = findFirstUserAssistantPair(messages);
if (firstPairResult && firstPairResult.rest.length > 0) {
protectedFirst = firstPairResult.firstPair;
compressibleMessages = firstPairResult.rest;
}
}
// 计算需要保护的消息数量(从可压缩部分的末尾算起)
let protectedTokens = 0; let protectedTokens = 0;
let protectedCount = 0; let protectedCount = 0;
for (let i = messages.length - 1; i >= 0; i--) { for (let i = compressibleMessages.length - 1; i >= 0; i--) {
const tokens = TokenCounter.estimateMessage(messages[i]); const tokens = TokenCounter.estimateMessage(compressibleMessages[i]);
if (protectedTokens + tokens > pruneProtect) { if (protectedTokens + tokens > pruneProtect) {
break; break;
} }
@@ -101,12 +202,17 @@ export async function compact(
} }
// 分割消息:需要压缩的部分 vs 保护的部分 // 分割消息:需要压缩的部分 vs 保护的部分
const toCompact = messages.slice(0, messages.length - protectedCount); const toCompact = compressibleMessages.slice(0, compressibleMessages.length - protectedCount);
const toKeep = messages.slice(messages.length - protectedCount); const toKeep = compressibleMessages.slice(compressibleMessages.length - protectedCount);
// 如果没有需要压缩的消息,直接返回 // 如果没有需要压缩的消息,直接返回
if (toCompact.length === 0) { if (toCompact.length === 0) {
return { messages, freedTokens: 0 }; return {
messages,
freedTokens: 0,
type: 'none',
status: CompressionStatus.NOOP,
};
} }
// 检查是否已有摘要消息 // 检查是否已有摘要消息
@@ -115,7 +221,7 @@ export async function compact(
existingSummaryIndex >= 0 ? toCompact.slice(existingSummaryIndex) : toCompact; existingSummaryIndex >= 0 ? toCompact.slice(existingSummaryIndex) : toCompact;
// 计算压缩前的 tokens // 计算压缩前的 tokens
const beforeTokens = TokenCounter.estimateMessages(toCompact); const originalTokens = TokenCounter.estimateMessages(toCompact);
try { try {
// 调用 AI 生成摘要 // 调用 AI 生成摘要
@@ -132,18 +238,44 @@ export async function compact(
maxOutputTokens: 2000, maxOutputTokens: 2000,
}); });
const summaryMessage = createSummaryMessage(result.text); // 验证摘要结果
const afterTokens = TokenCounter.estimateMessage(summaryMessage); const validation = validateSummary(result.text, originalTokens);
// 返回:摘要 + 保护的消息 if (!validation.valid) {
console.warn(`摘要验证失败: ${validation.status}`);
return { return {
messages: [summaryMessage, ...toKeep], messages,
freedTokens: beforeTokens - afterTokens, freedTokens: 0,
type: 'none',
status: validation.status,
originalTokens,
summaryTokens: validation.summaryTokens,
};
}
const summaryMessage = createSummaryMessage(result.text);
const summaryTokens = TokenCounter.estimateMessage(summaryMessage);
const freedTokens = originalTokens - summaryTokens;
// 返回:首条保护 + 摘要 + 末尾保护的消息
return {
messages: [...protectedFirst, summaryMessage, ...toKeep],
freedTokens,
type: 'compaction',
status: CompressionStatus.SUCCESS,
originalTokens,
summaryTokens,
}; };
} catch (error) { } catch (error) {
console.error('生成摘要失败:', error); console.error('生成摘要失败:', error);
// 失败时返回原消息 return {
return { messages, freedTokens: 0 }; messages,
freedTokens: 0,
type: 'none',
status: CompressionStatus.FAILED_ERROR,
error: error instanceof Error ? error.message : String(error),
originalTokens,
};
} }
} }
@@ -153,16 +285,30 @@ export async function compact(
*/ */
export function simpleCompact( export function simpleCompact(
messages: ModelMessage[], messages: ModelMessage[],
config: CompressionConfig = DEFAULT_COMPRESSION_CONFIG config: CompressionConfig = DEFAULT_COMPRESSION_CONFIG,
): { messages: ModelMessage[]; freedTokens: number } { options: CompactOptions = {}
): DetailedCompressionResult {
const { pruneProtect } = config; const { pruneProtect } = config;
const { protectFirstPair = true } = options;
// 首条消息保护:分离首个 user-assistant 对
let protectedFirst: ModelMessage[] = [];
let compressibleMessages = messages;
if (protectFirstPair) {
const firstPairResult = findFirstUserAssistantPair(messages);
if (firstPairResult && firstPairResult.rest.length > 0) {
protectedFirst = firstPairResult.firstPair;
compressibleMessages = firstPairResult.rest;
}
}
// 计算需要保留的消息 // 计算需要保留的消息
let keptTokens = 0; let keptTokens = 0;
let keepFromIndex = messages.length; let keepFromIndex = compressibleMessages.length;
for (let i = messages.length - 1; i >= 0; i--) { for (let i = compressibleMessages.length - 1; i >= 0; i--) {
const tokens = TokenCounter.estimateMessage(messages[i]); const tokens = TokenCounter.estimateMessage(compressibleMessages[i]);
if (keptTokens + tokens > pruneProtect) { if (keptTokens + tokens > pruneProtect) {
break; break;
} }
@@ -172,13 +318,18 @@ export function simpleCompact(
// 确保至少保留最后 N 条消息(强制模式下保留 1 条,否则保留 2 条) // 确保至少保留最后 N 条消息(强制模式下保留 1 条,否则保留 2 条)
const minKeep = pruneProtect > 0 ? 2 : 1; const minKeep = pruneProtect > 0 ? 2 : 1;
keepFromIndex = Math.min(keepFromIndex, messages.length - minKeep); keepFromIndex = Math.min(keepFromIndex, compressibleMessages.length - minKeep);
const removed = messages.slice(0, keepFromIndex); const removed = compressibleMessages.slice(0, keepFromIndex);
const kept = messages.slice(keepFromIndex); const kept = compressibleMessages.slice(keepFromIndex);
if (removed.length === 0) { if (removed.length === 0) {
return { messages, freedTokens: 0 }; return {
messages,
freedTokens: 0,
type: 'none',
status: CompressionStatus.NOOP,
};
} }
// 创建简单摘要 // 创建简单摘要
@@ -190,7 +341,9 @@ export function simpleCompact(
const freedTokens = TokenCounter.estimateMessages(removed); const freedTokens = TokenCounter.estimateMessages(removed);
return { return {
messages: [simpleSummary, ...kept], messages: [...protectedFirst, simpleSummary, ...kept],
freedTokens, freedTokens,
type: 'compaction',
status: CompressionStatus.SUCCESS,
}; };
} }
+3 -1
View File
@@ -4,6 +4,7 @@ export type {
CompressionConfig, CompressionConfig,
CompressionContext, CompressionContext,
CompressionResult, CompressionResult,
DetailedCompressionResult,
} from './types.js'; } from './types.js';
export { export {
@@ -11,6 +12,7 @@ export {
COMPACTED_PLACEHOLDER, COMPACTED_PLACEHOLDER,
SUMMARY_MARKER, SUMMARY_MARKER,
COMPACTED_MARKER, COMPACTED_MARKER,
CompressionStatus,
} from './types.js'; } from './types.js';
// Token 计数器 // Token 计数器
@@ -20,7 +22,7 @@ export { TokenCounter } from './token-counter.js';
export { prune, filterCompacted } from './prune.js'; export { prune, filterCompacted } from './prune.js';
// Compaction 策略 // Compaction 策略
export { compact, simpleCompact, isSummaryMessage } from './compaction.js'; export { compact, simpleCompact, isSummaryMessage, type CompactOptions } from './compaction.js';
// 压缩管理器 // 压缩管理器
export { CompressionManager, compressionManager } from './manager.js'; export { CompressionManager, compressionManager } from './manager.js';
+122 -45
View File
@@ -1,11 +1,12 @@
import type { ModelMessage, LanguageModel } from 'ai'; import type { ModelMessage, LanguageModel } from 'ai';
import { TokenCounter } from './token-counter.js'; import { TokenCounter } from './token-counter.js';
import { prune, filterCompacted } from './prune.js'; import { prune, filterCompacted } from './prune.js';
import { compact, simpleCompact, isSummaryMessage } from './compaction.js'; import { compact, simpleCompact, isSummaryMessage, type CompactOptions } from './compaction.js';
import { import {
type TokenUsage, type TokenUsage,
type CompressionConfig, type CompressionConfig,
type CompressionResult, type DetailedCompressionResult,
CompressionStatus,
DEFAULT_COMPRESSION_CONFIG, DEFAULT_COMPRESSION_CONFIG,
} from './types.js'; } from './types.js';
@@ -15,19 +16,55 @@ import {
*/ */
export class CompressionManager { export class CompressionManager {
private config: CompressionConfig; private config: CompressionConfig;
/** 主模型(摘要模型的后备) */
private model: LanguageModel | null = null; private model: LanguageModel | null = null;
/** 专用摘要模型(推荐使用小模型以降低成本) */
private summaryModel: LanguageModel | null = null;
/** 是否保护首条 user-assistant 对 */
private protectFirstPair: boolean = true;
constructor(config: Partial<CompressionConfig> = {}) { constructor(config: Partial<CompressionConfig> = {}) {
this.config = { ...DEFAULT_COMPRESSION_CONFIG, ...config }; this.config = { ...DEFAULT_COMPRESSION_CONFIG, ...config };
} }
/** /**
* 设置用于生成摘要的模型 * 设置用于生成摘要的模型(后备)
*/ */
setModel(model: LanguageModel): void { setModel(model: LanguageModel): void {
this.model = model; this.model = model;
} }
/**
 * Sets the dedicated summary model. It is stored separately from the main
 * model set through `setModel`.
 *
 * @param model Language model to use for generating summaries.
 */
setSummaryModel(model: LanguageModel): void {
this.summaryModel = model;
}
/**
 * Picks the model used for summary generation.
 *
 * @returns The dedicated summary model when one has been set, otherwise the
 *   main model (which may itself be `null`).
 */
private getSummaryModel(): LanguageModel | null {
  const dedicated = this.summaryModel;
  if (dedicated !== null && dedicated !== undefined) {
    return dedicated;
  }
  return this.model;
}
/**
 * Enables or disables protection of the first user-assistant pair.
 *
 * @param protect `true` to keep the first pair out of compaction.
 */
setProtectFirstPair(protect: boolean): void {
this.protectFirstPair = protect;
}
/**
 * Builds the options object handed to the compaction functions from the
 * manager's current settings.
 *
 * @returns Options carrying the current first-pair protection flag.
 */
private getCompactOptions(): CompactOptions {
  const options: CompactOptions = { protectFirstPair: this.protectFirstPair };
  return options;
}
/** /**
* 获取当前配置 * 获取当前配置
*/ */
@@ -85,37 +122,70 @@ export class CompressionManager {
/** /**
* 执行 compaction 策略 * 执行 compaction 策略
*/ */
async compact(messages: ModelMessage[]): Promise<{ messages: ModelMessage[]; freedTokens: number }> { async compact(messages: ModelMessage[]): Promise<DetailedCompressionResult> {
if (this.model) { const summaryModel = this.getSummaryModel();
return compact(messages, this.model, this.config); if (summaryModel) {
return compact(messages, summaryModel, this.config, this.getCompactOptions());
} }
// 没有模型时使用简单压缩 // 没有模型时使用简单压缩
return simpleCompact(messages, this.config); return simpleCompact(messages, this.config, this.getCompactOptions());
} }
/** /**
* 自动压缩:先 prune,不够再 compact * 自动压缩:先 prune,不够再 compact
*/ */
async compress(messages: ModelMessage[]): Promise<CompressionResult> { async compress(messages: ModelMessage[]): Promise<DetailedCompressionResult> {
// 检查是否需要压缩
if (!this.shouldCompress(messages)) {
return {
messages,
freedTokens: 0,
type: 'none',
status: CompressionStatus.NOOP,
};
}
let result = [...messages]; let result = [...messages];
let totalFreed = 0; let totalFreed = 0;
let type: CompressionResult['type'] = 'prune'; let type: DetailedCompressionResult['type'] = 'none';
// 第一步:尝试 prune // 第一步:尝试 prune
const pruneResult = this.prune(result); const pruneResult = this.prune(result);
if (pruneResult.freedTokens > 0) { if (pruneResult.freedTokens > 0) {
result = pruneResult.messages; result = pruneResult.messages;
totalFreed += pruneResult.freedTokens; totalFreed += pruneResult.freedTokens;
type = 'prune';
} }
// 检查是否还需要进一步压缩 // 检查是否还需要进一步压缩
if (this.shouldCompress(result)) { if (this.shouldCompress(result)) {
// 第二步:执行 compaction // 第二步:执行 compaction
const compactResult = await this.compact(result); const compactResult = await this.compact(result);
if (compactResult.freedTokens > 0) {
if (compactResult.status === CompressionStatus.SUCCESS && compactResult.freedTokens > 0) {
result = compactResult.messages; result = compactResult.messages;
totalFreed += compactResult.freedTokens; totalFreed += compactResult.freedTokens;
type = pruneResult.freedTokens > 0 ? 'both' : 'compaction'; type = pruneResult.freedTokens > 0 ? 'both' : 'compaction';
return {
messages: result,
freedTokens: totalFreed,
type,
status: CompressionStatus.SUCCESS,
originalTokens: compactResult.originalTokens,
summaryTokens: compactResult.summaryTokens,
};
}
// compaction 失败,返回失败状态
if (compactResult.status !== CompressionStatus.NOOP) {
return {
messages, // 返回原消息
freedTokens: 0,
type: 'none',
status: compactResult.status,
error: compactResult.error,
};
} }
} }
@@ -123,6 +193,7 @@ export class CompressionManager {
messages: result, messages: result,
freedTokens: totalFreed, freedTokens: totalFreed,
type, type,
status: totalFreed > 0 ? CompressionStatus.SUCCESS : CompressionStatus.NOOP,
}; };
} }
@@ -130,19 +201,20 @@ export class CompressionManager {
* 强制压缩(用于 /compact 命令) * 强制压缩(用于 /compact 命令)
* 无论是否达到阈值都执行压缩 * 无论是否达到阈值都执行压缩
*/ */
async forceCompress(messages: ModelMessage[]): Promise<CompressionResult> { async forceCompress(messages: ModelMessage[]): Promise<DetailedCompressionResult> {
// 消息数量太少时不压缩(至少需要 4 条消息) // 消息数量太少时不压缩(至少需要 4 条消息)
if (messages.length <= 4) { if (messages.length <= 4) {
return { return {
messages, messages,
freedTokens: 0, freedTokens: 0,
type: 'prune', type: 'none',
status: CompressionStatus.NOOP,
}; };
} }
let result = [...messages]; let result = [...messages];
let totalFreed = 0; let totalFreed = 0;
let type: CompressionResult['type'] = 'prune'; let type: DetailedCompressionResult['type'] = 'none';
// 先尝试 prune(使用强制配置) // 先尝试 prune(使用强制配置)
const pruneConfig: CompressionConfig = { const pruneConfig: CompressionConfig = {
@@ -156,48 +228,52 @@ export class CompressionManager {
if (pruneResult.freedTokens > 0) { if (pruneResult.freedTokens > 0) {
result = pruneResult.messages; result = pruneResult.messages;
totalFreed += pruneResult.freedTokens; totalFreed += pruneResult.freedTokens;
type = 'prune';
} }
// 强制 compaction只保留最后 2 条消息 // 强制 compaction使用强制配置
// 计算保留消息的 tokens const summaryModel = this.getSummaryModel();
const keepCount = Math.min(2, result.length - 1); const forceConfig: CompressionConfig = {
const toKeep = result.slice(-keepCount);
const toCompact = result.slice(0, result.length - keepCount);
if (toCompact.length > 0) {
if (this.model) {
try {
const compactResult = await compact(result, this.model, {
...this.config, ...this.config,
pruneProtect: 0, // 强制模式:不保护任何 tokens pruneProtect: 0, // 强制模式:不保护任何 tokens
}); };
if (compactResult.freedTokens > 0) { // 强制模式不保护首条消息对
const forceOptions = { protectFirstPair: false };
if (summaryModel) {
const compactResult = await compact(result, summaryModel, forceConfig, forceOptions);
if (compactResult.status === CompressionStatus.SUCCESS && compactResult.freedTokens > 0) {
result = compactResult.messages; result = compactResult.messages;
totalFreed += compactResult.freedTokens; totalFreed += compactResult.freedTokens;
type = pruneResult.freedTokens > 0 ? 'both' : 'compaction'; type = type === 'prune' ? 'both' : 'compaction';
return {
messages: result,
freedTokens: totalFreed,
type,
status: CompressionStatus.SUCCESS,
originalTokens: compactResult.originalTokens,
summaryTokens: compactResult.summaryTokens,
};
} }
} catch {
// AI 压缩失败,使用简单压缩 // AI 压缩失败,回退到简单压缩
const compactResult = simpleCompact(result, { if (compactResult.status !== CompressionStatus.NOOP) {
...this.config, const simpleResult = simpleCompact(result, forceConfig, forceOptions);
pruneProtect: 0, if (simpleResult.freedTokens > 0) {
}); result = simpleResult.messages;
if (compactResult.freedTokens > 0) { totalFreed += simpleResult.freedTokens;
result = compactResult.messages; type = type === 'prune' ? 'both' : 'compaction';
totalFreed += compactResult.freedTokens;
type = pruneResult.freedTokens > 0 ? 'both' : 'compaction';
} }
} }
} else { } else {
const compactResult = simpleCompact(result, { // 没有模型,使用简单压缩
...this.config, const simpleResult = simpleCompact(result, forceConfig, forceOptions);
pruneProtect: 0, if (simpleResult.freedTokens > 0) {
}); result = simpleResult.messages;
if (compactResult.freedTokens > 0) { totalFreed += simpleResult.freedTokens;
result = compactResult.messages; type = type === 'prune' ? 'both' : 'compaction';
totalFreed += compactResult.freedTokens;
type = pruneResult.freedTokens > 0 ? 'both' : 'compaction';
}
} }
} }
@@ -205,6 +281,7 @@ export class CompressionManager {
messages: result, messages: result,
freedTokens: totalFreed, freedTokens: totalFreed,
type, type,
status: totalFreed > 0 ? CompressionStatus.SUCCESS : CompressionStatus.NOOP,
}; };
} }
+37 -1
View File
@@ -65,7 +65,7 @@ export interface CompressionContext {
} }
/** /**
* 压缩结果 * 压缩结果(基础)
*/ */
export interface CompressionResult { export interface CompressionResult {
/** 压缩后的消息 */ /** 压缩后的消息 */
@@ -75,3 +75,39 @@ export interface CompressionResult {
/** 压缩类型 */ /** 压缩类型 */
type: 'prune' | 'compaction' | 'both'; type: 'prune' | 'compaction' | 'both';
} }
/**
 * Outcome status of a compression attempt.
 */
export enum CompressionStatus {
/** Compression succeeded. */
SUCCESS = 'success',
/** Nothing was compressed (e.g. the threshold was not reached or there was nothing to compact). */
NOOP = 'noop',
/** Failure: the generated summary was empty. */
FAILED_EMPTY_SUMMARY = 'failed_empty_summary',
/** Failure: token inflation (the summary did not reduce the token count). */
FAILED_TOKEN_INFLATED = 'failed_token_inflated',
/** Failure: any other error. */
FAILED_ERROR = 'failed_error',
}
/**
 * Detailed result of a compression attempt: the resulting messages plus the
 * status and token statistics of the attempt.
 */
export interface DetailedCompressionResult {
/** Messages after compression (the unmodified input when nothing was compressed or the attempt failed). */
messages: import('ai').ModelMessage[];
/** Tokens freed. Positive on success; 0 when nothing was freed (no-op or failure). */
freedTokens: number;
/** Which strategy produced the result; `'none'` when no compression was applied. */
type: 'prune' | 'compaction' | 'both' | 'none';
/** Detailed status of the attempt. */
status: CompressionStatus;
/** Error message (set on failure). */
error?: string;
/** Estimated token count of the messages selected for compaction, before summarizing. */
originalTokens?: number;
/** Estimated token count of the generated summary. */
summaryTokens?: number;
}
+31 -6
View File
@@ -12,12 +12,13 @@ import { ToolRegistry } from '../tools/registry.js';
import { SessionManager } from '../session/index.js'; import { SessionManager } from '../session/index.js';
import { import {
CompressionManager, CompressionManager,
CompressionStatus,
type TokenUsage, type TokenUsage,
type CompressionConfig, type CompressionConfig,
} from '../context/index.js'; } from '../context/index.js';
import type { AgentInfo, ImageData } from '../agent/types.js'; import type { AgentInfo, ImageData } from '../agent/types.js';
import { agentRegistry, AgentExecutor } from '../agent/index.js'; import { agentRegistry, AgentExecutor } from '../agent/index.js';
import { loadVisionConfig } from '../utils/config.js'; import { loadVisionConfig, loadSummaryConfig } from '../utils/config.js';
import { getProviderRegistry } from '../provider/index.js'; import { getProviderRegistry } from '../provider/index.js';
import { getHookManager } from '../hooks/index.js'; import { getHookManager } from '../hooks/index.js';
import { getGitManager } from '../git/index.js'; import { getGitManager } from '../git/index.js';
@@ -50,16 +51,26 @@ export class Agent {
this.originalSystemPrompt = config.systemPrompt; this.originalSystemPrompt = config.systemPrompt;
// 使用 ProviderRegistry 获取模型工厂 // 使用 ProviderRegistry 获取模型工厂
const registry = getProviderRegistry(); const providerRegistry = getProviderRegistry();
this.getModel = registry.getModelFactory(config.provider, { this.getModel = providerRegistry.getModelFactory(config.provider, {
apiKey: config.apiKey, apiKey: config.apiKey,
baseUrl: config.baseUrl, baseUrl: config.baseUrl,
}); });
// 初始化压缩管理器 // 初始化压缩管理器
this.compressionManager = new CompressionManager(compressionConfig); this.compressionManager = new CompressionManager(compressionConfig);
// 设置模型用于生成摘要 // 设置模型(作为摘要模型的后备)
this.compressionManager.setModel(this.getModel(config.model)); this.compressionManager.setModel(this.getModel(config.model));
// 加载摘要模型配置(可选,用于降低压缩成本)
const summaryConfig = loadSummaryConfig();
if (summaryConfig) {
const summaryModelFactory = providerRegistry.getModelFactory(summaryConfig.provider, {
apiKey: summaryConfig.apiKey,
baseUrl: summaryConfig.baseUrl,
});
this.compressionManager.setSummaryModel(summaryModelFactory(summaryConfig.model));
}
} }
/** /**
@@ -398,10 +409,24 @@ export class Agent {
// 检查是否需要自动压缩 // 检查是否需要自动压缩
if (this.compressionManager.shouldCompress(this.conversationHistory)) { if (this.compressionManager.shouldCompress(this.conversationHistory)) {
const result = await this.compressionManager.compress(this.conversationHistory); const result = await this.compressionManager.compress(this.conversationHistory);
if (result.freedTokens > 0) {
if (result.status === CompressionStatus.SUCCESS && result.freedTokens > 0) {
this.conversationHistory = result.messages; this.conversationHistory = result.messages;
if (onStream) { if (onStream) {
onStream(`\n[自动压缩: 释放了 ${(result.freedTokens / 1000).toFixed(1)}k tokens]\n`); const typeLabel = result.type === 'both' ? 'prune+摘要' : result.type === 'compaction' ? '摘要' : 'prune';
onStream(`\n[自动压缩(${typeLabel}): 释放了 ${(result.freedTokens / 1000).toFixed(1)}k tokens]\n`);
}
} else if (result.status === CompressionStatus.FAILED_EMPTY_SUMMARY) {
if (onStream) {
onStream('\n[压缩失败: 摘要生成为空,已跳过]\n');
}
} else if (result.status === CompressionStatus.FAILED_TOKEN_INFLATED) {
if (onStream) {
onStream('\n[压缩失败: 摘要反而增加了 token,已跳过]\n');
}
} else if (result.status === CompressionStatus.FAILED_ERROR) {
if (onStream) {
onStream(`\n[压缩失败: ${result.error || '未知错误'}]\n`);
} }
} }
} }
+92
View File
@@ -23,6 +23,13 @@ interface StoredConfig {
visionApiKey?: string; visionApiKey?: string;
/** Vision 专用的 Base URL(用于 OpenAI 兼容的 Vision 服务) */ /** Vision 专用的 Base URL(用于 OpenAI 兼容的 Vision 服务) */
visionBaseUrl?: string; visionBaseUrl?: string;
// Summary 配置(用于对话压缩摘要生成)
summaryProvider?: ProviderType;
summaryModel?: string;
/** Summary-specific API key (optional; when unset, the key of the corresponding provider is used) */
summaryApiKey?: string;
/** Summary 专用的 Base URL(用于 OpenAI 兼容的 Summary 服务) */
summaryBaseUrl?: string;
} }
// Vision 配置接口 // Vision 配置接口
@@ -34,6 +41,15 @@ export interface VisionConfig {
baseUrl?: string; baseUrl?: string;
} }
// Resolved configuration of the model used to generate conversation-compression summaries.
export interface SummaryConfig {
// Provider that serves the summary model.
provider: ProviderType;
// API key used for the summary provider.
apiKey: string;
// Model identifier used for summaries.
model: string;
/** Custom base URL (for OpenAI-compatible summary services). */
baseUrl?: string;
}
// 默认模型配置 // 默认模型配置
const DEFAULT_MODELS: Record<ProviderType, string> = { const DEFAULT_MODELS: Record<ProviderType, string> = {
anthropic: 'claude-sonnet-4-20250514', anthropic: 'claude-sonnet-4-20250514',
@@ -48,6 +64,13 @@ const DEFAULT_VISION_MODELS: Record<ProviderType, string> = {
openai: 'gpt-4o', openai: 'gpt-4o',
}; };
// Default summary model per provider, used when no summary model is configured
// explicitly (lower-cost models are recommended for summarization).
const DEFAULT_SUMMARY_MODELS: Record<ProviderType, string> = {
anthropic: 'claude-3-5-haiku-20241022',
deepseek: 'deepseek-chat',
openai: 'gpt-4o-mini',
};
// 默认系统提示词 // 默认系统提示词
const DEFAULT_SYSTEM_PROMPT = `你是一个运行在终端中的 AI 编程助手。你可以帮助用户: const DEFAULT_SYSTEM_PROMPT = `你是一个运行在终端中的 AI 编程助手。你可以帮助用户:
- 读取和写入文件 - 读取和写入文件
@@ -200,6 +223,75 @@ export function loadVisionConfig(): VisionConfig | null {
}; };
} }
/**
 * Resolves the configuration of the model used to summarize conversations
 * during compression (a low-cost small model is recommended).
 *
 * Every setting is looked up in environment variables first and in the stored
 * config file second. `null` is returned, so that callers can fall back to the
 * main model, when:
 * - no summary-specific provider, model or API key is configured anywhere;
 * - the resolved provider is not one of the known providers (the provider
 *   values come from unvalidated environment/config strings); or
 * - no API key can be resolved for the provider.
 *
 * @returns The resolved summary configuration, or `null` when no usable
 *   dedicated summary model is configured.
 */
export function loadSummaryConfig(): SummaryConfig | null {
  const env = process.env;
  const storedConfig = getConfig();

  const summaryProvider = env.SUMMARY_PROVIDER as ProviderType | undefined;
  const summaryModel = env.SUMMARY_MODEL;
  const summaryApiKey = env.SUMMARY_API_KEY;

  // Without any summary-specific setting there is no dedicated summary model.
  const hasSummaryConfig =
    summaryProvider ||
    summaryModel ||
    summaryApiKey ||
    storedConfig.summaryProvider ||
    storedConfig.summaryModel ||
    storedConfig.summaryApiKey;
  if (!hasSummaryConfig) {
    return null;
  }

  // The summary provider defaults to the provider of the main configuration.
  const mainProvider = (env.AI_PROVIDER as ProviderType) || storedConfig.provider || 'anthropic';
  const finalProvider = summaryProvider || storedConfig.summaryProvider || mainProvider;

  // The casts above do not validate anything: an unknown provider string would
  // otherwise yield a config with an unusable provider and, lacking an explicit
  // model, an undefined `model`. Treat it like a missing configuration.
  if (!Object.prototype.hasOwnProperty.call(DEFAULT_SUMMARY_MODELS, finalProvider)) {
    return null;
  }

  // API key priority: summary-specific env var > summary-specific stored key >
  // the regular key of the resolved provider.
  let finalApiKey: string | undefined = summaryApiKey || storedConfig.summaryApiKey;
  if (!finalApiKey) {
    switch (finalProvider) {
      case 'anthropic':
        finalApiKey = env.ANTHROPIC_API_KEY || storedConfig.apiKey;
        break;
      case 'deepseek':
        finalApiKey = env.DEEPSEEK_API_KEY || storedConfig.deepseekApiKey;
        break;
      case 'openai':
        finalApiKey = env.OPENAI_API_KEY || storedConfig.openaiApiKey;
        break;
    }
  }
  if (!finalApiKey) {
    return null;
  }

  return {
    provider: finalProvider,
    apiKey: finalApiKey,
    model: summaryModel || storedConfig.summaryModel || DEFAULT_SUMMARY_MODELS[finalProvider],
    // Base URL is summary-specific; it is never inherited from the main configuration.
    baseUrl: env.SUMMARY_BASE_URL || storedConfig.summaryBaseUrl,
  };
}
// 保存配置 // 保存配置
export function saveConfig(config: Partial<StoredConfig>): void { export function saveConfig(config: Partial<StoredConfig>): void {
// 确保目录存在 // 确保目录存在
@@ -178,7 +178,8 @@ describe('simpleCompact - 简单压缩', () => {
messages.push(createUserMessage(`Message ${i}: ${'a'.repeat(100)}`)); messages.push(createUserMessage(`Message ${i}: ${'a'.repeat(100)}`));
} }
const result = simpleCompact(messages, testConfig); // 禁用首条保护以便测试摘要消息在第一位
const result = simpleCompact(messages, testConfig, { protectFirstPair: false });
if (result.freedTokens > 0) { if (result.freedTokens > 0) {
// 第一条消息应该是摘要 // 第一条消息应该是摘要
@@ -192,7 +193,8 @@ describe('simpleCompact - 简单压缩', () => {
messages.push(createUserMessage(`Message ${i}: ${'a'.repeat(100)}`)); messages.push(createUserMessage(`Message ${i}: ${'a'.repeat(100)}`));
} }
const result = simpleCompact(messages, testConfig); // 禁用首条保护以便测试摘要消息
const result = simpleCompact(messages, testConfig, { protectFirstPair: false });
if (result.freedTokens > 0) { if (result.freedTokens > 0) {
const summaryContent = result.messages[0].content as string; const summaryContent = result.messages[0].content as string;
@@ -1,5 +1,6 @@
import { describe, it, expect, beforeEach, vi } from 'vitest'; import { describe, it, expect, beforeEach, vi } from 'vitest';
import type { ModelMessage, LanguageModel } from 'ai'; import type { ModelMessage, LanguageModel } from 'ai';
import { CompressionStatus } from '../../../src/context/types.js';
// Mock prune module // Mock prune module
const mockPrune = vi.fn(); const mockPrune = vi.fn();
@@ -45,13 +46,13 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
vi.clearAllMocks(); vi.clearAllMocks();
manager = new CompressionManager(); manager = new CompressionManager();
// 默认 mock 返回值 // 默认 mock 返回值 - 使用 DetailedCompressionResult
mockEstimateMessages.mockReturnValue(1000); mockEstimateMessages.mockReturnValue(1000);
mockFormat.mockReturnValue('1K'); mockFormat.mockReturnValue('1K');
mockPrune.mockReturnValue({ messages: [], freedTokens: 0 }); mockPrune.mockReturnValue({ messages: [], freedTokens: 0 });
mockFilterCompacted.mockImplementation((msgs) => msgs); mockFilterCompacted.mockImplementation((msgs) => msgs);
mockCompact.mockResolvedValue({ messages: [], freedTokens: 0 }); mockCompact.mockResolvedValue({ messages: [], freedTokens: 0, type: 'none', status: CompressionStatus.NOOP });
mockSimpleCompact.mockReturnValue({ messages: [], freedTokens: 0 }); mockSimpleCompact.mockReturnValue({ messages: [], freedTokens: 0, type: 'none', status: CompressionStatus.NOOP });
mockIsSummaryMessage.mockReturnValue(false); mockIsSummaryMessage.mockReturnValue(false);
}); });
@@ -174,7 +175,7 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
it('有模型时使用 AI 压缩', async () => { it('有模型时使用 AI 压缩', async () => {
const mockModel = {} as LanguageModel; const mockModel = {} as LanguageModel;
manager.setModel(mockModel); manager.setModel(mockModel);
mockCompact.mockResolvedValue({ messages: [], freedTokens: 2000 }); mockCompact.mockResolvedValue({ messages: [], freedTokens: 2000, type: 'compaction', status: CompressionStatus.SUCCESS });
const messages = createMessages(5); const messages = createMessages(5);
const result = await manager.compact(messages); const result = await manager.compact(messages);
@@ -184,7 +185,7 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
}); });
it('无模型时使用简单压缩', async () => { it('无模型时使用简单压缩', async () => {
mockSimpleCompact.mockReturnValue({ messages: [], freedTokens: 500 }); mockSimpleCompact.mockReturnValue({ messages: [], freedTokens: 500, type: 'compaction', status: CompressionStatus.SUCCESS });
const messages = createMessages(5); const messages = createMessages(5);
const result = await manager.compact(messages); const result = await manager.compact(messages);
@@ -196,7 +197,9 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
describe('compress - 自动压缩', () => { describe('compress - 自动压缩', () => {
it('先 prune 后不需要 compact', async () => { it('先 prune 后不需要 compact', async () => {
mockEstimateMessages.mockReturnValue(10000); // 低于阈值 mockEstimateMessages
.mockReturnValueOnce(150000) // shouldCompress check - 高于阈值
.mockReturnValueOnce(10000); // 第二次 shouldCompress check - prune 后低于阈值
mockPrune.mockReturnValue({ messages: [], freedTokens: 500 }); mockPrune.mockReturnValue({ messages: [], freedTokens: 500 });
const messages = createMessages(5); const messages = createMessages(5);
@@ -210,7 +213,7 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
it('prune 后仍需 compact', async () => { it('prune 后仍需 compact', async () => {
mockEstimateMessages.mockReturnValue(150000); // 高于阈值 mockEstimateMessages.mockReturnValue(150000); // 高于阈值
mockPrune.mockReturnValue({ messages: createMessages(3), freedTokens: 500 }); mockPrune.mockReturnValue({ messages: createMessages(3), freedTokens: 500 });
mockSimpleCompact.mockReturnValue({ messages: [], freedTokens: 1000 }); mockSimpleCompact.mockReturnValue({ messages: [], freedTokens: 1000, type: 'compaction', status: CompressionStatus.SUCCESS });
const messages = createMessages(5); const messages = createMessages(5);
const result = await manager.compress(messages); const result = await manager.compress(messages);
@@ -222,7 +225,7 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
it('只 compact 时类型为 compaction', async () => { it('只 compact 时类型为 compaction', async () => {
mockEstimateMessages.mockReturnValue(150000); mockEstimateMessages.mockReturnValue(150000);
mockPrune.mockReturnValue({ messages: createMessages(5), freedTokens: 0 }); mockPrune.mockReturnValue({ messages: createMessages(5), freedTokens: 0 });
mockSimpleCompact.mockReturnValue({ messages: [], freedTokens: 1000 }); mockSimpleCompact.mockReturnValue({ messages: [], freedTokens: 1000, type: 'compaction', status: CompressionStatus.SUCCESS });
const messages = createMessages(5); const messages = createMessages(5);
const result = await manager.compress(messages); const result = await manager.compress(messages);
@@ -243,7 +246,7 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
it('足够消息时执行压缩', async () => { it('足够消息时执行压缩', async () => {
mockEstimateMessages.mockReturnValue(10000); mockEstimateMessages.mockReturnValue(10000);
mockPrune.mockReturnValue({ messages: createMessages(3), freedTokens: 500 }); mockPrune.mockReturnValue({ messages: createMessages(3), freedTokens: 500 });
mockSimpleCompact.mockReturnValue({ messages: createMessages(2), freedTokens: 300 }); mockSimpleCompact.mockReturnValue({ messages: createMessages(2), freedTokens: 300, type: 'compaction', status: CompressionStatus.SUCCESS });
const messages = createMessages(10); const messages = createMessages(10);
const result = await manager.forceCompress(messages); const result = await manager.forceCompress(messages);
@@ -256,7 +259,7 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
manager.setModel(mockModel); manager.setModel(mockModel);
mockEstimateMessages.mockReturnValue(10000); mockEstimateMessages.mockReturnValue(10000);
mockPrune.mockReturnValue({ messages: createMessages(5), freedTokens: 500 }); mockPrune.mockReturnValue({ messages: createMessages(5), freedTokens: 500 });
mockCompact.mockResolvedValue({ messages: createMessages(2), freedTokens: 1000 }); mockCompact.mockResolvedValue({ messages: createMessages(2), freedTokens: 1000, type: 'compaction', status: CompressionStatus.SUCCESS });
const messages = createMessages(10); const messages = createMessages(10);
await manager.forceCompress(messages); await manager.forceCompress(messages);
@@ -269,8 +272,9 @@ describe('CompressionManager - 压缩管理器扩展测试', () => {
manager.setModel(mockModel); manager.setModel(mockModel);
mockEstimateMessages.mockReturnValue(10000); mockEstimateMessages.mockReturnValue(10000);
mockPrune.mockReturnValue({ messages: createMessages(5), freedTokens: 500 }); mockPrune.mockReturnValue({ messages: createMessages(5), freedTokens: 500 });
mockCompact.mockRejectedValue(new Error('AI error')); // 使用 FAILED 状态而不是 reject
mockSimpleCompact.mockReturnValue({ messages: createMessages(2), freedTokens: 800 }); mockCompact.mockResolvedValue({ messages: createMessages(5), freedTokens: 0, type: 'none', status: CompressionStatus.FAILED_ERROR });
mockSimpleCompact.mockReturnValue({ messages: createMessages(2), freedTokens: 800, type: 'compaction', status: CompressionStatus.SUCCESS });
const messages = createMessages(10); const messages = createMessages(10);
const result = await manager.forceCompress(messages); const result = await manager.forceCompress(messages);
@@ -213,7 +213,8 @@ describe('CompressionManager - 压缩管理器', () => {
const result = await manager.compress(messages); const result = await manager.compress(messages);
expect(['prune', 'compaction', 'both']).toContain(result.type); // 小对话不压缩时返回 'none'
expect(['prune', 'compaction', 'both', 'none']).toContain(result.type);
}); });
}); });
+2 -1
View File
@@ -49,9 +49,10 @@ vi.mock('../../../src/agent/index.js', () => ({
})), })),
})); }));
// Mock vision config // Mock vision and summary config
vi.mock('../../../src/utils/config.js', () => ({ vi.mock('../../../src/utils/config.js', () => ({
loadVisionConfig: vi.fn(() => null), loadVisionConfig: vi.fn(() => null),
loadSummaryConfig: vi.fn(() => null),
})); }));
// Create mock tool // Create mock tool