feat(ui): 优化流式输出工具调用渲染

- 添加 tool_start/tool_end WebSocket 事件支持
- 流式消息复用 ChatMessage 组件渲染工具调用卡片
- 修复 AI SDK v5 格式兼容问题(input/output 字段)
- 修复会话恢复时 tool-result 格式错误
- 放宽 ToolState schema 中 input 字段类型为 unknown
This commit is contained in:
2025-12-15 17:35:39 +08:00
parent 865e0906b9
commit 3fd8fd98b8
12 changed files with 384 additions and 58 deletions
+74 -11
View File
@@ -23,11 +23,35 @@ import { getProviderRegistry, resolveApiKey } from '../provider/index.js';
import { getHookManager } from '../hooks/index.js';
import { getGitManager } from '../git/index.js';
/**
* 工具调用开始事件信息
*/
export interface ToolStartInfo {
id: string;
toolName: string;
args: Record<string, unknown>;
}
/**
* 工具调用结束事件信息
*/
export interface ToolEndInfo {
id: string;
status: 'completed' | 'error';
result?: unknown;
error?: string;
duration?: number;
}
/**
* Agent.chat() 选项
*/
export interface AgentChatOptions {
onStream?: (text: string) => void;
/** 工具开始执行回调 */
onToolStart?: (info: ToolStartInfo) => void;
/** 工具执行完成回调 */
onToolEnd?: (info: ToolEndInfo) => void;
abortSignal?: AbortSignal;
}
@@ -324,7 +348,10 @@ export class Agent {
async chat(userMessage: string | UserInput, options?: AgentChatOptions | ((text: string) => void)): Promise<ChatResult> {
// 兼容旧的 onStream 参数
const opts: AgentChatOptions = typeof options === 'function' ? { onStream: options } : (options || {});
const { onStream, abortSignal } = opts;
const { onStream, onToolStart, onToolEnd, abortSignal } = opts;
// 工具调用时间跟踪
const toolStartTimes = new Map<string, number>();
// 处理带图片的消息
let processedMessage = userMessage;
@@ -405,19 +432,55 @@ export class Agent {
abortSignal, // 支持取消
onChunk: ({ chunk }) => {
if (chunk.type === 'tool-call') {
onStream(`\n[调用工具: ${chunk.toolName}]\n`);
// AI SDK 中工具参数字段名为 input
const toolCallChunk = chunk as { toolCallId: string; toolName: string; input: unknown };
const toolCallId = toolCallChunk.toolCallId || `tool-${Date.now()}`;
// 记录开始时间
toolStartTimes.set(toolCallId, Date.now());
// 调用 onToolStart 回调
if (onToolStart) {
onToolStart({
id: toolCallId,
toolName: toolCallChunk.toolName,
args: (toolCallChunk.input as Record<string, unknown>) || {},
});
} else {
// 仅在没有 onToolStart 回调时输出文本(向后兼容 CLI)
onStream?.(`\n[调用工具: ${toolCallChunk.toolName}]\n`);
}
} else if (chunk.type === 'tool-result') {
const output = (chunk as { output?: ToolResult }).output;
const toolResultChunk = chunk as { toolCallId: string; output?: ToolResult };
const toolCallId = toolResultChunk.toolCallId || '';
const output = toolResultChunk.output;
// 计算执行时长
const startTime = toolStartTimes.get(toolCallId);
const duration = startTime ? Date.now() - startTime : undefined;
toolStartTimes.delete(toolCallId);
if (output && typeof output === 'object') {
if (output.success) {
// 截断过长的输出
const displayOutput =
output.output.length > 500
? output.output.substring(0, 500) + '...(截断)'
: output.output;
onStream(`[结果: ${displayOutput}]\n`);
// 调用 onToolEnd 回调
if (onToolEnd) {
onToolEnd({
id: toolCallId,
status: output.success ? 'completed' : 'error',
result: output.success ? output.output : undefined,
error: output.success ? undefined : output.error,
duration,
});
} else {
onStream(`[错误: ${output.error}]\n`);
// 仅在没有 onToolEnd 回调时输出文本(向后兼容 CLI)
if (output.success) {
const displayOutput =
output.output.length > 500
? output.output.substring(0, 500) + '...(截断)'
: output.output;
onStream?.(`[结果: ${displayOutput}]\n`);
} else {
onStream?.(`[错误: ${output.error}]\n`);
}
}
}
}
+1 -1
View File
@@ -1,5 +1,5 @@
export { Agent } from './core/agent.js';
export type { AgentChatOptions } from './core/agent.js';
export type { AgentChatOptions, ToolStartInfo, ToolEndInfo } from './core/agent.js';
export { toolRegistry, todoManager, initTaskContext, updateTaskDescription, updateSkillDescription } from './tools/index.js';
export { loadConfig, saveConfig, getConfig, loadVisionConfig, ConfigurationError } from './utils/config.js';
export type { VisionConfig } from './utils/config.js';
+15 -3
View File
@@ -45,13 +45,14 @@ function messageToModelMessages(msg: Message): ModelMessage[] {
}
// 添加工具调用部分(只有 running 或已完成的工具)
// AI SDK v5 使用 input 字段(不是 args
for (const toolPart of toolParts) {
if (toolPart.state.status !== 'pending') {
assistantContent.push({
type: 'tool-call',
toolCallId: toolPart.toolCallId,
toolName: toolPart.toolName,
args: toolPart.state.input,
input: toolPart.state.input,
});
}
}
@@ -87,12 +88,17 @@ function messageToModelMessages(msg: Message): ModelMessage[] {
const output = state.status === 'completed'
? (state as { output: unknown }).output
: (state as { error: string }).error;
// 获取 inputAI SDK v5 要求 tool-result 必须包含 input
const input = state.status !== 'pending'
? (state as { input: Record<string, unknown> }).input
: {};
return {
type: 'tool-result' as const,
toolCallId: toolPart.toolCallId,
toolName: toolPart.toolName,
result: output,
input,
output,
};
});
@@ -134,12 +140,18 @@ export function toModelMessages(messages: Message[]): ModelMessage[] {
/**
* 获取工具调用的输入参数(兼容不同状态)
* 注意:AI SDK 的 input 是 unknown 类型,这里做安全转换
*/
export function getToolInput(toolPart: ToolPart): Record<string, unknown> {
if (toolPart.state.status === 'pending') {
return {};
}
return toolPart.state.input;
const input = toolPart.state.input;
// 安全转换:如果是对象返回对象,否则返回空对象
if (input && typeof input === 'object' && !Array.isArray(input)) {
return input as Record<string, unknown>;
}
return {};
}
/**
+12 -5
View File
@@ -224,7 +224,8 @@ export class SessionManager {
} else if (role === 'assistant') {
// Assistant 消息:文本 + 工具调用
const content: unknown[] = [];
const completedTools: Array<{ toolCallId: string; toolName: string; output: unknown }> = [];
// input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等)
const completedTools: Array<{ toolCallId: string; toolName: string; input: unknown; output: unknown }> = [];
for (const part of parts) {
if (part.type === 'text') {
@@ -232,11 +233,12 @@ export class SessionManager {
} else if (part.type === 'tool') {
// 只有非 pending 状态的工具调用才添加到 AI SDK 消息
if (part.state.status !== 'pending') {
// AI SDK v5 使用 input 字段(不是 args
content.push({
type: 'tool-call',
toolCallId: part.toolCallId,
toolName: part.toolName,
args: part.state.input,
input: part.state.input,
});
// 收集已完成的工具结果
@@ -244,12 +246,14 @@ export class SessionManager {
completedTools.push({
toolCallId: part.toolCallId,
toolName: part.toolName,
input: part.state.input,
output: part.state.output,
});
} else if (part.state.status === 'error') {
completedTools.push({
toolCallId: part.toolCallId,
toolName: part.toolName,
input: part.state.input,
output: part.state.error,
});
}
@@ -273,6 +277,7 @@ export class SessionManager {
}
// 添加 tool 消息(如果有已完成的工具)
// AI SDK v5 要求 tool-result 必须包含 input 和 output 字段
if (completedTools.length > 0) {
result.push({
role: 'tool',
@@ -280,7 +285,8 @@ export class SessionManager {
type: 'tool-result',
toolCallId: t.toolCallId,
toolName: t.toolName,
result: t.output,
input: t.input,
output: t.output,
})),
} as unknown as ModelMessage);
}
@@ -454,7 +460,8 @@ export class SessionManager {
for (const item of message.content) {
const itemType = (item as { type: string }).type;
if (itemType === 'tool-result') {
const toolResult = item as unknown as { toolCallId: string; toolName: string; result: unknown };
// AI SDK v5 使用 output 字段存储结果(不是 result)
const toolResult = item as unknown as { toolCallId: string; toolName: string; output: unknown };
const partId = toolCallPartIds.get(toolResult.toolCallId);
if (partId) {
// 更新工具状态为 completed
@@ -463,7 +470,7 @@ export class SessionManager {
const startTime = part?.type === 'tool' && part.state.status === 'running'
? part.state.time.start
: Date.now();
await PartStorage.setToolCompleted(currentAssistantMsgId, partId, toolResult.result, startTime);
await PartStorage.setToolCompleted(currentAssistantMsgId, partId, toolResult.output, startTime);
}
}
}
+6 -3
View File
@@ -27,20 +27,22 @@ export type ToolStatePending = z.infer<typeof ToolStatePendingSchema>;
/**
* 工具状态机 - Running(执行中)
* 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等)
*/
export const ToolStateRunningSchema = z.object({
status: z.literal('running'),
input: z.record(z.string(), z.unknown()),
input: z.unknown(),
time: z.object({ start: z.number() }),
});
export type ToolStateRunning = z.infer<typeof ToolStateRunningSchema>;
/**
* 工具状态机 - Completed(执行完成)
* 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等)
*/
export const ToolStateCompletedSchema = z.object({
status: z.literal('completed'),
input: z.record(z.string(), z.unknown()),
input: z.unknown(),
output: z.unknown(),
time: z.object({ start: z.number(), end: z.number() }),
});
@@ -48,10 +50,11 @@ export type ToolStateCompleted = z.infer<typeof ToolStateCompletedSchema>;
/**
* 工具状态机 - Error(执行出错)
* 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等)
*/
export const ToolStateErrorSchema = z.object({
status: z.literal('error'),
input: z.record(z.string(), z.unknown()),
input: z.unknown(),
error: z.string(),
time: z.object({ start: z.number(), end: z.number() }),
});
+7 -5
View File
@@ -9,7 +9,6 @@ import { toast } from 'sonner';
import {
useChat,
ChatMessage,
StreamingMessage,
TypingIndicator,
ChatInput,
} from '@ai-assistant/ui';
@@ -46,7 +45,7 @@ export function ChatPage({
messages,
isConnected,
isLoading,
streamingContent,
streamingMessage,
sendMessage,
cancelProcessing,
} = useChat({
@@ -73,7 +72,7 @@ export function ChatPage({
// 自动滚动到底部
useEffect(() => {
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' });
}, [messages, streamingContent]);
}, [messages, streamingMessage]);
// 空状态组件
const EmptyState = () => (
@@ -270,9 +269,12 @@ export function ChatPage({
))}
</AnimatePresence>
{streamingContent && <StreamingMessage content={streamingContent} />}
{/* 流式消息 - 复用 ChatMessage 组件 */}
{streamingMessage && (
<ChatMessage message={streamingMessage} isStreaming />
)}
{isLoading && !streamingContent && <TypingIndicator />}
{isLoading && !streamingMessage && <TypingIndicator />}
<div ref={messagesEndRef} />
</div>
+55 -1
View File
@@ -67,11 +67,33 @@ interface SessionManagerConstructor {
new (): SessionManagerInstance;
}
/**
* 工具开始信息
*/
interface ToolStartInfo {
id: string;
toolName: string;
args: Record<string, unknown>;
}
/**
* 工具结束信息
*/
interface ToolEndInfo {
id: string;
status: 'completed' | 'error';
result?: unknown;
error?: string;
duration?: number;
}
/**
* Chat 选项接口
*/
interface ChatOptions {
onStream?: (chunk: string) => void;
onToolStart?: (info: ToolStartInfo) => void;
onToolEnd?: (info: ToolEndInfo) => void;
abortSignal?: AbortSignal;
}
@@ -397,7 +419,7 @@ export async function processMessage(sessionId: string, content: string): Promis
payload: { content: chunk },
});
// 检测工具调用
// 检测工具调用(向后兼容 - SSE 日志)
if (chunk.includes('[调用工具:')) {
const match = chunk.match(/\[调用工具: (.+?)\]/);
if (match) {
@@ -405,6 +427,38 @@ export async function processMessage(sessionId: string, content: string): Promis
}
}
},
onToolStart: (info) => {
// 检查是否已取消
if (abortController.signal.aborted) return;
// 推送工具开始事件
broadcastToSession(sessionId, {
type: 'tool_start',
sessionId,
payload: {
id: info.id,
toolName: info.toolName,
arguments: info.args,
},
});
},
onToolEnd: (info) => {
// 检查是否已取消
if (abortController.signal.aborted) return;
// 推送工具结束事件
broadcastToSession(sessionId, {
type: 'tool_end',
sessionId,
payload: {
id: info.id,
status: info.status,
result: info.result,
error: info.error,
duration: info.duration,
},
});
},
abortSignal: abortController.signal,
});
+18
View File
@@ -110,6 +110,8 @@ export interface ServerMessage {
| 'chunk'
| 'tool_call'
| 'tool_result'
| 'tool_start' // 工具开始执行
| 'tool_end' // 工具执行完成
| 'done'
| 'cancelled'
| 'error'
@@ -119,6 +121,22 @@ export interface ServerMessage {
payload?: unknown;
}
// 工具开始事件 Payload
export interface ToolStartPayload {
id: string;
toolName: string;
arguments: Record<string, unknown>;
}
// 工具结束事件 Payload
export interface ToolEndPayload {
id: string;
status: 'completed' | 'error';
result?: unknown;
error?: string;
duration?: number;
}
// ============ Permission 相关 ============
export type PermissionType = 'bash' | 'file' | 'git' | 'web';
+26
View File
@@ -877,3 +877,29 @@ export interface FileSearchResponse {
};
}
// ============ 流式工具调用事件 ============
/** 工具开始事件 Payload */
export interface ToolStartPayload {
/** 工具调用唯一 ID */
id: string;
/** 工具名称 */
toolName: string;
/** 调用参数 */
arguments: Record<string, unknown>;
}
/** 工具结束事件 Payload */
export interface ToolEndPayload {
/** 对应 tool_start 的 ID */
id: string;
/** 执行状态 */
status: 'completed' | 'error';
/** 执行结果 */
result?: unknown;
/** 错误信息 */
error?: string;
/** 执行时长 (ms) */
duration?: number;
}
+27 -4
View File
@@ -25,10 +25,12 @@ import type { Message, ToolCallInfo, ToolCallStatus, ToolMessagePart } from '../
interface ChatMessageProps {
message: Message;
/** 是否为流式输出中(显示打字光标) */
isStreaming?: boolean;
}
export const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>(
({ message }, ref) => {
({ message, isStreaming = false }, ref) => {
const isUser = message.role === 'user';
const [copied, setCopied] = useState(false);
@@ -42,18 +44,39 @@ export const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>(
const renderContent = () => {
// 优先使用 parts 数组(保持原始顺序)
if (message.parts && message.parts.length > 0) {
// 查找最后一个文本 part 的索引(用于显示打字光标)
let lastTextPartIndex = -1;
if (isStreaming) {
for (let i = message.parts.length - 1; i >= 0; i--) {
if (message.parts[i].type === 'text') {
lastTextPartIndex = i;
break;
}
}
}
return (
<div className="message-content text-fg-secondary space-y-3">
{message.parts.map((part) => {
{message.parts.map((part, index) => {
switch (part.type) {
case 'text':
if (!part.text) return null;
if (!part.text && index !== lastTextPartIndex) return null;
return isUser ? (
<div key={part.id}>
<FileMentionText text={part.text} />
</div>
) : (
<Markdown key={part.id} content={part.text} />
<div key={part.id}>
<Markdown content={part.text} />
{/* 流式输出时在最后一个文本末尾显示打字光标 */}
{isStreaming && index === lastTextPartIndex && (
<motion.span
animate={{ opacity: [1, 0] }}
transition={{ duration: 0.8, repeat: Infinity, repeatType: 'reverse' }}
className="inline-block w-2 h-4 bg-primary-400 ml-1 rounded-sm align-middle"
/>
)}
</div>
);
case 'tool':
return <ToolPartItem key={part.id} part={part} />;
+136 -20
View File
@@ -7,7 +7,13 @@
import { useState, useEffect, useCallback, useRef } from 'react';
import { createWebSocket, getMessages, type Message } from '../api/client.js';
import type { PermissionRequest } from '../components/PermissionDialog.js';
import type { ConfigErrorPayload } from '../api/types.js';
import type {
ConfigErrorPayload,
ToolStartPayload,
ToolEndPayload,
MessagePart,
ToolMessagePart,
} from '../api/types.js';
interface UseChatOptions {
sessionId: string;
@@ -22,7 +28,8 @@ interface ChatState {
messages: Message[];
isConnected: boolean;
isLoading: boolean;
streamingContent: string;
/** 流式消息对象,复用 Message 结构 */
streamingMessage: Message | null;
permissionRequest: PermissionRequest | null;
}
@@ -31,7 +38,7 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
messages: [],
isConnected: false,
isLoading: false,
streamingContent: '',
streamingMessage: null,
permissionRequest: null,
});
@@ -114,27 +121,136 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
const message = JSON.parse(event.data);
switch (message.type) {
case 'chunk':
setState((prev) => ({
...prev,
streamingContent: prev.streamingContent + (message.payload?.content || ''),
}));
case 'chunk': {
const chunkContent = message.payload?.content || '';
setState((prev) => {
// 初始化或获取当前流式消息
const streaming = prev.streamingMessage || {
id: `streaming-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`,
role: 'assistant' as const,
timestamp: new Date().toISOString(),
parts: [] as MessagePart[],
content: '',
};
// 复制 parts 数组以进行修改
const parts = [...streaming.parts];
const lastPart = parts[parts.length - 1];
// 如果最后一个 part 是 text,追加内容;否则创建新 text part
if (lastPart?.type === 'text') {
parts[parts.length - 1] = {
...lastPart,
text: lastPart.text + chunkContent,
};
} else if (chunkContent) {
parts.push({
type: 'text',
id: `text-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`,
text: chunkContent,
});
}
return {
...prev,
streamingMessage: {
...streaming,
parts,
content: (streaming.content || '') + chunkContent,
},
};
});
break;
}
case 'tool_start': {
const payload = message.payload as ToolStartPayload;
setState((prev) => {
// 初始化或获取当前流式消息
const streaming = prev.streamingMessage || {
id: `streaming-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`,
role: 'assistant' as const,
timestamp: new Date().toISOString(),
parts: [] as MessagePart[],
content: '',
};
// 添加工具调用 part
const toolPart: ToolMessagePart = {
type: 'tool',
id: payload.id,
toolCallId: payload.id,
toolName: payload.toolName,
status: 'running',
arguments: payload.arguments,
};
return {
...prev,
streamingMessage: {
...streaming,
parts: [...streaming.parts, toolPart],
},
};
});
break;
}
case 'tool_end': {
const payload = message.payload as ToolEndPayload;
setState((prev) => {
if (!prev.streamingMessage) return prev;
// 查找并更新对应的工具 part
const parts = prev.streamingMessage.parts.map((part) => {
if (part.type === 'tool' && part.id === payload.id) {
return {
...part,
status: payload.status,
result: payload.result,
error: payload.error,
duration: payload.duration,
} as ToolMessagePart;
}
return part;
});
return {
...prev,
streamingMessage: {
...prev.streamingMessage,
parts,
},
};
});
break;
}
case 'done':
setState((prev) => {
const content = message.payload?.content || prev.streamingContent;
const newMessage: Message = {
id: message.payload?.id || `assistant-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`,
role: 'assistant',
timestamp: message.payload?.timestamp || new Date().toISOString(),
parts: [{ type: 'text', id: `text-${Date.now()}`, text: content }],
content,
};
// 使用流式消息或创建新消息
const streaming = prev.streamingMessage;
const content = message.payload?.content || streaming?.content || '';
const newMessage: Message = streaming
? {
...streaming,
id: message.payload?.id || streaming.id,
timestamp: message.payload?.timestamp || streaming.timestamp,
content,
}
: {
id: message.payload?.id || `assistant-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`,
role: 'assistant',
timestamp: message.payload?.timestamp || new Date().toISOString(),
parts: [{ type: 'text', id: `text-${Date.now()}`, text: content }],
content,
};
return {
...prev,
messages: [...prev.messages, newMessage],
streamingContent: '',
streamingMessage: null,
isLoading: false,
};
});
@@ -165,7 +281,7 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
} else {
onErrorRef.current?.(new Error(message.payload?.message || 'Unknown error'));
}
setState((prev) => ({ ...prev, isLoading: false, streamingContent: '' }));
setState((prev) => ({ ...prev, isLoading: false, streamingMessage: null }));
break;
case 'session_updated':
@@ -225,7 +341,7 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
})
);
setState((prev) => ({ ...prev, isLoading: false, streamingContent: '' }));
setState((prev) => ({ ...prev, isLoading: false, streamingMessage: null }));
}, [sessionId]);
// 发送权限响应
@@ -274,7 +390,7 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
messages: [],
isConnected: false,
isLoading: false,
streamingContent: '',
streamingMessage: null,
permissionRequest: null,
});
reconnectAttemptsRef.current = 0;
+7 -5
View File
@@ -9,7 +9,6 @@ import { toast } from 'sonner';
import {
useChat,
ChatMessage,
StreamingMessage,
TypingIndicator,
ChatInput,
PermissionDialog,
@@ -52,7 +51,7 @@ export function ChatPage({
messages,
isConnected,
isLoading,
streamingContent,
streamingMessage,
sendMessage,
cancelProcessing,
permissionRequest,
@@ -83,7 +82,7 @@ export function ChatPage({
// 自动滚动到底部
useEffect(() => {
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' });
}, [messages, streamingContent]);
}, [messages, streamingMessage]);
// 空状态组件
const EmptyState = () => (
@@ -290,9 +289,12 @@ export function ChatPage({
))}
</AnimatePresence>
{streamingContent && <StreamingMessage content={streamingContent} />}
{/* 流式消息 - 复用 ChatMessage 组件 */}
{streamingMessage && (
<ChatMessage message={streamingMessage} isStreaming />
)}
{isLoading && !streamingContent && <TypingIndicator />}
{isLoading && !streamingMessage && <TypingIndicator />}
<div ref={messagesEndRef} />
</div>