feat(ui): 优化流式输出工具调用渲染
- 添加 tool_start/tool_end WebSocket 事件支持 - 流式消息复用 ChatMessage 组件渲染工具调用卡片 - 修复 AI SDK v5 格式兼容问题(input/output 字段) - 修复会话恢复时 tool-result 格式错误 - 放宽 ToolState schema 中 input 字段类型为 unknown
This commit is contained in:
@@ -23,11 +23,35 @@ import { getProviderRegistry, resolveApiKey } from '../provider/index.js';
|
||||
import { getHookManager } from '../hooks/index.js';
|
||||
import { getGitManager } from '../git/index.js';
|
||||
|
||||
/**
|
||||
* 工具调用开始事件信息
|
||||
*/
|
||||
export interface ToolStartInfo {
|
||||
id: string;
|
||||
toolName: string;
|
||||
args: Record<string, unknown>;
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具调用结束事件信息
|
||||
*/
|
||||
export interface ToolEndInfo {
|
||||
id: string;
|
||||
status: 'completed' | 'error';
|
||||
result?: unknown;
|
||||
error?: string;
|
||||
duration?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Agent.chat() 选项
|
||||
*/
|
||||
export interface AgentChatOptions {
|
||||
onStream?: (text: string) => void;
|
||||
/** 工具开始执行回调 */
|
||||
onToolStart?: (info: ToolStartInfo) => void;
|
||||
/** 工具执行完成回调 */
|
||||
onToolEnd?: (info: ToolEndInfo) => void;
|
||||
abortSignal?: AbortSignal;
|
||||
}
|
||||
|
||||
@@ -324,7 +348,10 @@ export class Agent {
|
||||
async chat(userMessage: string | UserInput, options?: AgentChatOptions | ((text: string) => void)): Promise<ChatResult> {
|
||||
// 兼容旧的 onStream 参数
|
||||
const opts: AgentChatOptions = typeof options === 'function' ? { onStream: options } : (options || {});
|
||||
const { onStream, abortSignal } = opts;
|
||||
const { onStream, onToolStart, onToolEnd, abortSignal } = opts;
|
||||
|
||||
// 工具调用时间跟踪
|
||||
const toolStartTimes = new Map<string, number>();
|
||||
// 处理带图片的消息
|
||||
let processedMessage = userMessage;
|
||||
|
||||
@@ -405,19 +432,55 @@ export class Agent {
|
||||
abortSignal, // 支持取消
|
||||
onChunk: ({ chunk }) => {
|
||||
if (chunk.type === 'tool-call') {
|
||||
onStream(`\n[调用工具: ${chunk.toolName}]\n`);
|
||||
// AI SDK 中工具参数字段名为 input
|
||||
const toolCallChunk = chunk as { toolCallId: string; toolName: string; input: unknown };
|
||||
const toolCallId = toolCallChunk.toolCallId || `tool-${Date.now()}`;
|
||||
|
||||
// 记录开始时间
|
||||
toolStartTimes.set(toolCallId, Date.now());
|
||||
|
||||
// 调用 onToolStart 回调
|
||||
if (onToolStart) {
|
||||
onToolStart({
|
||||
id: toolCallId,
|
||||
toolName: toolCallChunk.toolName,
|
||||
args: (toolCallChunk.input as Record<string, unknown>) || {},
|
||||
});
|
||||
} else {
|
||||
// 仅在没有 onToolStart 回调时输出文本(向后兼容 CLI)
|
||||
onStream?.(`\n[调用工具: ${toolCallChunk.toolName}]\n`);
|
||||
}
|
||||
} else if (chunk.type === 'tool-result') {
|
||||
const output = (chunk as { output?: ToolResult }).output;
|
||||
const toolResultChunk = chunk as { toolCallId: string; output?: ToolResult };
|
||||
const toolCallId = toolResultChunk.toolCallId || '';
|
||||
const output = toolResultChunk.output;
|
||||
|
||||
// 计算执行时长
|
||||
const startTime = toolStartTimes.get(toolCallId);
|
||||
const duration = startTime ? Date.now() - startTime : undefined;
|
||||
toolStartTimes.delete(toolCallId);
|
||||
|
||||
if (output && typeof output === 'object') {
|
||||
if (output.success) {
|
||||
// 截断过长的输出
|
||||
const displayOutput =
|
||||
output.output.length > 500
|
||||
? output.output.substring(0, 500) + '...(截断)'
|
||||
: output.output;
|
||||
onStream(`[结果: ${displayOutput}]\n`);
|
||||
// 调用 onToolEnd 回调
|
||||
if (onToolEnd) {
|
||||
onToolEnd({
|
||||
id: toolCallId,
|
||||
status: output.success ? 'completed' : 'error',
|
||||
result: output.success ? output.output : undefined,
|
||||
error: output.success ? undefined : output.error,
|
||||
duration,
|
||||
});
|
||||
} else {
|
||||
onStream(`[错误: ${output.error}]\n`);
|
||||
// 仅在没有 onToolEnd 回调时输出文本(向后兼容 CLI)
|
||||
if (output.success) {
|
||||
const displayOutput =
|
||||
output.output.length > 500
|
||||
? output.output.substring(0, 500) + '...(截断)'
|
||||
: output.output;
|
||||
onStream?.(`[结果: ${displayOutput}]\n`);
|
||||
} else {
|
||||
onStream?.(`[错误: ${output.error}]\n`);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
export { Agent } from './core/agent.js';
|
||||
export type { AgentChatOptions } from './core/agent.js';
|
||||
export type { AgentChatOptions, ToolStartInfo, ToolEndInfo } from './core/agent.js';
|
||||
export { toolRegistry, todoManager, initTaskContext, updateTaskDescription, updateSkillDescription } from './tools/index.js';
|
||||
export { loadConfig, saveConfig, getConfig, loadVisionConfig, ConfigurationError } from './utils/config.js';
|
||||
export type { VisionConfig } from './utils/config.js';
|
||||
|
||||
@@ -45,13 +45,14 @@ function messageToModelMessages(msg: Message): ModelMessage[] {
|
||||
}
|
||||
|
||||
// 添加工具调用部分(只有 running 或已完成的工具)
|
||||
// AI SDK v5 使用 input 字段(不是 args)
|
||||
for (const toolPart of toolParts) {
|
||||
if (toolPart.state.status !== 'pending') {
|
||||
assistantContent.push({
|
||||
type: 'tool-call',
|
||||
toolCallId: toolPart.toolCallId,
|
||||
toolName: toolPart.toolName,
|
||||
args: toolPart.state.input,
|
||||
input: toolPart.state.input,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -87,12 +88,17 @@ function messageToModelMessages(msg: Message): ModelMessage[] {
|
||||
const output = state.status === 'completed'
|
||||
? (state as { output: unknown }).output
|
||||
: (state as { error: string }).error;
|
||||
// 获取 input(AI SDK v5 要求 tool-result 必须包含 input)
|
||||
const input = state.status !== 'pending'
|
||||
? (state as { input: Record<string, unknown> }).input
|
||||
: {};
|
||||
|
||||
return {
|
||||
type: 'tool-result' as const,
|
||||
toolCallId: toolPart.toolCallId,
|
||||
toolName: toolPart.toolName,
|
||||
result: output,
|
||||
input,
|
||||
output,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -134,12 +140,18 @@ export function toModelMessages(messages: Message[]): ModelMessage[] {
|
||||
|
||||
/**
|
||||
* 获取工具调用的输入参数(兼容不同状态)
|
||||
* 注意:AI SDK 的 input 是 unknown 类型,这里做安全转换
|
||||
*/
|
||||
export function getToolInput(toolPart: ToolPart): Record<string, unknown> {
|
||||
if (toolPart.state.status === 'pending') {
|
||||
return {};
|
||||
}
|
||||
return toolPart.state.input;
|
||||
const input = toolPart.state.input;
|
||||
// 安全转换:如果是对象返回对象,否则返回空对象
|
||||
if (input && typeof input === 'object' && !Array.isArray(input)) {
|
||||
return input as Record<string, unknown>;
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -224,7 +224,8 @@ export class SessionManager {
|
||||
} else if (role === 'assistant') {
|
||||
// Assistant 消息:文本 + 工具调用
|
||||
const content: unknown[] = [];
|
||||
const completedTools: Array<{ toolCallId: string; toolName: string; output: unknown }> = [];
|
||||
// input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等)
|
||||
const completedTools: Array<{ toolCallId: string; toolName: string; input: unknown; output: unknown }> = [];
|
||||
|
||||
for (const part of parts) {
|
||||
if (part.type === 'text') {
|
||||
@@ -232,11 +233,12 @@ export class SessionManager {
|
||||
} else if (part.type === 'tool') {
|
||||
// 只有非 pending 状态的工具调用才添加到 AI SDK 消息
|
||||
if (part.state.status !== 'pending') {
|
||||
// AI SDK v5 使用 input 字段(不是 args)
|
||||
content.push({
|
||||
type: 'tool-call',
|
||||
toolCallId: part.toolCallId,
|
||||
toolName: part.toolName,
|
||||
args: part.state.input,
|
||||
input: part.state.input,
|
||||
});
|
||||
|
||||
// 收集已完成的工具结果
|
||||
@@ -244,12 +246,14 @@ export class SessionManager {
|
||||
completedTools.push({
|
||||
toolCallId: part.toolCallId,
|
||||
toolName: part.toolName,
|
||||
input: part.state.input,
|
||||
output: part.state.output,
|
||||
});
|
||||
} else if (part.state.status === 'error') {
|
||||
completedTools.push({
|
||||
toolCallId: part.toolCallId,
|
||||
toolName: part.toolName,
|
||||
input: part.state.input,
|
||||
output: part.state.error,
|
||||
});
|
||||
}
|
||||
@@ -273,6 +277,7 @@ export class SessionManager {
|
||||
}
|
||||
|
||||
// 添加 tool 消息(如果有已完成的工具)
|
||||
// AI SDK v5 要求 tool-result 必须包含 input 和 output 字段
|
||||
if (completedTools.length > 0) {
|
||||
result.push({
|
||||
role: 'tool',
|
||||
@@ -280,7 +285,8 @@ export class SessionManager {
|
||||
type: 'tool-result',
|
||||
toolCallId: t.toolCallId,
|
||||
toolName: t.toolName,
|
||||
result: t.output,
|
||||
input: t.input,
|
||||
output: t.output,
|
||||
})),
|
||||
} as unknown as ModelMessage);
|
||||
}
|
||||
@@ -454,7 +460,8 @@ export class SessionManager {
|
||||
for (const item of message.content) {
|
||||
const itemType = (item as { type: string }).type;
|
||||
if (itemType === 'tool-result') {
|
||||
const toolResult = item as unknown as { toolCallId: string; toolName: string; result: unknown };
|
||||
// AI SDK v5 使用 output 字段存储结果(不是 result)
|
||||
const toolResult = item as unknown as { toolCallId: string; toolName: string; output: unknown };
|
||||
const partId = toolCallPartIds.get(toolResult.toolCallId);
|
||||
if (partId) {
|
||||
// 更新工具状态为 completed
|
||||
@@ -463,7 +470,7 @@ export class SessionManager {
|
||||
const startTime = part?.type === 'tool' && part.state.status === 'running'
|
||||
? part.state.time.start
|
||||
: Date.now();
|
||||
await PartStorage.setToolCompleted(currentAssistantMsgId, partId, toolResult.result, startTime);
|
||||
await PartStorage.setToolCompleted(currentAssistantMsgId, partId, toolResult.output, startTime);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,20 +27,22 @@ export type ToolStatePending = z.infer<typeof ToolStatePendingSchema>;
|
||||
|
||||
/**
|
||||
* 工具状态机 - Running(执行中)
|
||||
* 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等)
|
||||
*/
|
||||
export const ToolStateRunningSchema = z.object({
|
||||
status: z.literal('running'),
|
||||
input: z.record(z.string(), z.unknown()),
|
||||
input: z.unknown(),
|
||||
time: z.object({ start: z.number() }),
|
||||
});
|
||||
export type ToolStateRunning = z.infer<typeof ToolStateRunningSchema>;
|
||||
|
||||
/**
|
||||
* 工具状态机 - Completed(执行完成)
|
||||
* 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等)
|
||||
*/
|
||||
export const ToolStateCompletedSchema = z.object({
|
||||
status: z.literal('completed'),
|
||||
input: z.record(z.string(), z.unknown()),
|
||||
input: z.unknown(),
|
||||
output: z.unknown(),
|
||||
time: z.object({ start: z.number(), end: z.number() }),
|
||||
});
|
||||
@@ -48,10 +50,11 @@ export type ToolStateCompleted = z.infer<typeof ToolStateCompletedSchema>;
|
||||
|
||||
/**
|
||||
* 工具状态机 - Error(执行出错)
|
||||
* 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等)
|
||||
*/
|
||||
export const ToolStateErrorSchema = z.object({
|
||||
status: z.literal('error'),
|
||||
input: z.record(z.string(), z.unknown()),
|
||||
input: z.unknown(),
|
||||
error: z.string(),
|
||||
time: z.object({ start: z.number(), end: z.number() }),
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user