Files
ai-terminal-assistant/packages/ui/src/hooks/useChat.ts
T
kurihada 791c4a4616 fix(ui): 修复工具调用重复显示问题
在 useChat hook 中添加去重逻辑:
- tool_start: 检查是否已存在相同 ID 的工具调用,存在则跳过
- subagent:tool_start: 同样添加去重检查

问题原因:服务端可能因 AI SDK 触发两次 tool-call chunk
导致发送重复的 tool_start 事件,前端之前没有去重逻辑
2025-12-16 23:02:34 +08:00

664 lines
22 KiB
TypeScript

/**
* Chat Hook
*
* 管理 WebSocket 连接和消息状态
*/
import { useState, useEffect, useCallback, useRef } from 'react';
import { createWebSocket, getMessages, type Message } from '../api/client.js';
import type { PermissionRequest } from '../components/PermissionDialog.js';
import type {
ConfigErrorPayload,
ToolStartPayload,
ToolEndPayload,
MessagePart,
ToolMessagePart,
AgentModeType,
SubagentStartPayload,
SubagentEndPayload,
SubagentStreamPayload,
SubagentToolStartPayload,
SubagentToolEndPayload,
SubagentState,
SubagentToolInfo,
} from '../api/types.js';
/** Options accepted by the `useChat` hook. */
interface UseChatOptions {
/** Session whose history is loaded and whose WebSocket is opened. */
sessionId: string;
/** Called with an Error when loading history fails, the socket errors, or the server sends a non-config `error` event. */
onError?: (error: Error) => void;
/** Called when loading history fails with a message containing "404" or "not found". */
onSessionNotFound?: () => void;
/** Called when a `session_updated` event carries both an id and a name. */
onSessionUpdated?: (sessionId: string, name: string) => void;
/** Configuration error callback (e.g. the API key is not configured). */
onConfigError?: (error: ConfigErrorPayload) => void;
}
/** Internal state of the `useChat` hook; spread into the hook's return value. */
interface ChatState {
/** Loaded history plus messages finalized during this session. */
messages: Message[];
/** True between the socket's `open` and `close` events. */
isConnected: boolean;
/** Set when a message is sent; cleared on `done`, `error`, or cancel. */
isLoading: boolean;
/** In-progress streamed message; reuses the Message structure. */
streamingMessage: Message | null;
/** Pending permission request received from the server, if any. */
permissionRequest: PermissionRequest | null;
/** Agent mode (session level). */
agentMode: AgentModeType;
/** Whether file writes/edits are auto-approved (session level). */
autoApprove: boolean;
/** Name of the agent that is currently executing. */
currentAgent: string;
/** State of the sub-agent that is currently executing. */
currentSubagent: SubagentState | null;
}
/** Builds a client-side id such as `text-1734361354000-k3j9x2a`. */
function createLocalId(prefix: string): string {
  return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
}

/** Creates the empty assistant message that accumulates streamed parts. */
function createStreamingMessage(agentName: string): Message {
  return {
    id: createLocalId('streaming'),
    role: 'assistant',
    timestamp: new Date().toISOString(),
    parts: [],
    content: '',
    metadata: { agentName },
  };
}

/** Returns the state a freshly mounted (or freshly switched) session starts from. */
function createInitialState(): ChatState {
  return {
    messages: [],
    isConnected: false,
    isLoading: false,
    streamingMessage: null,
    permissionRequest: null,
    agentMode: 'build',
    autoApprove: false,
    currentAgent: 'build',
    currentSubagent: null,
  };
}

/**
 * Chat hook: loads a session's message history and keeps a WebSocket open to
 * stream assistant output, tool calls, sub-agent activity and permission
 * requests into React state.
 *
 * @param options - Session id plus optional error/session callbacks. The
 *   callbacks are read through refs, so changing them does not reconnect.
 * @returns The current chat state together with actions: `sendMessage`,
 *   `cancelProcessing`, `reload`, `allowPermission`, `denyPermission`,
 *   `setAgentMode` and `setAutoApprove`.
 */
export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdated, onConfigError }: UseChatOptions) {
  const [state, setState] = useState<ChatState>(createInitialState);
  const wsRef = useRef<WebSocket | null>(null);
  const reconnectTimeoutRef = useRef<ReturnType<typeof setTimeout>>();
  const reconnectAttemptsRef = useRef(0);
  const maxReconnectAttempts = 5;
  // True while we are deliberately tearing the connection down (session switch / unmount).
  const isClosingRef = useRef(false);
  // Monotonic counter identifying the latest history request; older responses are ignored.
  const loadRequestIdRef = useRef(0);

  // Keep callbacks in refs so that new callback identities do not re-create
  // `connect`/`loadMessages` and trigger an endless reconnect loop.
  const onErrorRef = useRef(onError);
  const onSessionNotFoundRef = useRef(onSessionNotFound);
  const onSessionUpdatedRef = useRef(onSessionUpdated);
  const onConfigErrorRef = useRef(onConfigError);
  onErrorRef.current = onError;
  onSessionNotFoundRef.current = onSessionNotFound;
  onSessionUpdatedRef.current = onSessionUpdated;
  onConfigErrorRef.current = onConfigError;

  // Load the message history for the current session.
  const loadMessages = useCallback(async () => {
    const requestId = ++loadRequestIdRef.current;
    // A newer request, a session switch or an unmount makes this response stale.
    const isStale = () => requestId !== loadRequestIdRef.current;
    try {
      const { data } = await getMessages(sessionId);
      if (isStale()) return;
      setState((prev) => ({ ...prev, messages: data }));
    } catch (error) {
      if (isStale()) return;
      // Session does not exist (404 or "Session not found"): let the owner re-create it.
      const msg = error instanceof Error ? error.message : '';
      if (msg.includes('404') || msg.toLowerCase().includes('not found')) {
        onSessionNotFoundRef.current?.();
        return;
      }
      onErrorRef.current?.(error instanceof Error ? error : new Error('Failed to load messages'));
    }
  }, [sessionId]);

  // Open the WebSocket and wire up all server event handlers.
  const connect = useCallback(() => {
    // Do not connect while a teardown is in progress.
    if (isClosingRef.current) return;
    // Do not open a second socket while one is open or still connecting.
    if (wsRef.current?.readyState === WebSocket.OPEN) return;
    if (wsRef.current?.readyState === WebSocket.CONNECTING) return;

    const ws = createWebSocket(sessionId);

    ws.onopen = () => {
      // The component went away while we were connecting: close immediately.
      if (isClosingRef.current) {
        ws.close();
        return;
      }
      reconnectAttemptsRef.current = 0; // Connected: reset the retry budget.
      setState((prev) => ({ ...prev, isConnected: true }));
    };

    ws.onclose = () => {
      setState((prev) => ({ ...prev, isConnected: false }));
      // A deliberate close must not trigger a reconnect.
      if (isClosingRef.current) {
        isClosingRef.current = false;
        return;
      }
      // Retry a bounded number of times.
      if (reconnectAttemptsRef.current < maxReconnectAttempts) {
        reconnectAttemptsRef.current++;
        reconnectTimeoutRef.current = setTimeout(connect, 3000);
      }
    };

    ws.onerror = () => {
      // Errors caused by a deliberate close are not reported.
      if (isClosingRef.current) return;
      onErrorRef.current?.(new Error('WebSocket connection error'));
    };

    ws.onmessage = (event) => {
      try {
        const message = JSON.parse(event.data);
        switch (message.type) {
          case 'chunk': {
            const chunkContent = message.payload?.content || '';
            setState((prev) => {
              // Start a streaming message on the first chunk, otherwise continue it.
              const streaming = prev.streamingMessage || createStreamingMessage(prev.currentAgent);
              const parts = [...streaming.parts];
              const lastPart = parts[parts.length - 1];
              // Append to a trailing text part; otherwise open a new text part.
              if (lastPart?.type === 'text') {
                parts[parts.length - 1] = {
                  ...lastPart,
                  text: lastPart.text + chunkContent,
                };
              } else if (chunkContent) {
                parts.push({
                  type: 'text',
                  id: createLocalId('text'),
                  text: chunkContent,
                });
              }
              return {
                ...prev,
                streamingMessage: {
                  ...streaming,
                  parts,
                  content: (streaming.content || '') + chunkContent,
                },
              };
            });
            break;
          }

          case 'tool_start': {
            const payload = message.payload as ToolStartPayload;
            // State updaters run outside this try/catch, so reject a missing payload here.
            if (!payload) break;
            setState((prev) => {
              const streaming = prev.streamingMessage || createStreamingMessage(prev.currentAgent);
              // De-duplicate: the server may emit tool_start twice for the same call id.
              const alreadyShown = streaming.parts.some(
                (part) => part.type === 'tool' && part.id === payload.id
              );
              if (alreadyShown) {
                return prev;
              }
              const toolPart: ToolMessagePart = {
                type: 'tool',
                id: payload.id,
                toolCallId: payload.id,
                toolName: payload.toolName,
                status: 'running',
                arguments: payload.arguments,
              };
              // The `task` tool hands control to the sub-agent named in its arguments.
              const newAgent =
                payload.toolName === 'task' && payload.arguments?.subagent_type
                  ? (payload.arguments.subagent_type as string)
                  : prev.currentAgent;
              return {
                ...prev,
                currentAgent: newAgent,
                streamingMessage: {
                  ...streaming,
                  parts: [...streaming.parts, toolPart],
                  metadata: { ...streaming.metadata, agentName: newAgent },
                },
              };
            });
            break;
          }

          case 'tool_end': {
            const payload = message.payload as ToolEndPayload;
            if (!payload) break;
            setState((prev) => {
              if (!prev.streamingMessage) return prev;
              // Update the matching tool part with its outcome.
              const parts = prev.streamingMessage.parts.map((part) => {
                if (part.type === 'tool' && part.id === payload.id) {
                  return {
                    ...part,
                    status: payload.status,
                    result: payload.result,
                    error: payload.error,
                    duration: payload.duration,
                  } as ToolMessagePart;
                }
                return part;
              });
              // When the finished tool was `task`, control returns to the main agent.
              const completedTool = prev.streamingMessage.parts.find(
                (part) => part.type === 'tool' && part.id === payload.id
              );
              const isTaskTool = completedTool?.type === 'tool' && completedTool.toolName === 'task';
              const newAgent = isTaskTool ? prev.agentMode : prev.currentAgent;
              return {
                ...prev,
                currentAgent: newAgent,
                streamingMessage: {
                  ...prev.streamingMessage,
                  parts,
                  metadata: { ...prev.streamingMessage.metadata, agentName: newAgent },
                },
              };
            });
            break;
          }

          case 'done':
            setState((prev) => {
              // Finalize the streamed message, or build one if nothing was streamed.
              const streaming = prev.streamingMessage;
              const content = message.payload?.content || streaming?.content || '';
              // Prefer the agent name reported by the server, else the session's agent mode.
              const agentName = message.payload?.agentName || prev.agentMode;
              const newMessage: Message = streaming
                ? {
                    ...streaming,
                    id: message.payload?.id || streaming.id,
                    timestamp: message.payload?.timestamp || streaming.timestamp,
                    content,
                    metadata: { ...streaming.metadata, agentName },
                  }
                : {
                    id: message.payload?.id || createLocalId('assistant'),
                    role: 'assistant',
                    timestamp: message.payload?.timestamp || new Date().toISOString(),
                    parts: [{ type: 'text', id: createLocalId('text'), text: content }],
                    content,
                    metadata: { agentName },
                  };
              return {
                ...prev,
                messages: [...prev.messages, newMessage],
                streamingMessage: null,
                isLoading: false,
                currentAgent: prev.agentMode, // Back to the main agent.
              };
            });
            break;

          case 'message_received':
            // The server acknowledged the user's message: add it to the transcript.
            setState((prev) => {
              const content = message.payload?.content || '';
              const userMessage: Message = {
                id: createLocalId('user'),
                role: 'user',
                timestamp: new Date().toISOString(),
                parts: [{ type: 'text', id: createLocalId('text'), text: content }],
                content,
              };
              return {
                ...prev,
                messages: [...prev.messages, userMessage],
              };
            });
            break;

          case 'error':
            // Configuration errors get their own callback.
            if (message.payload?.type === 'config_error') {
              onConfigErrorRef.current?.(message.payload as ConfigErrorPayload);
            } else {
              onErrorRef.current?.(new Error(message.payload?.message || 'Unknown error'));
            }
            // The turn is over: drop the partial message and, as `done` does,
            // hand control back to the main agent.
            setState((prev) => ({
              ...prev,
              isLoading: false,
              streamingMessage: null,
              currentAgent: prev.agentMode,
            }));
            break;

          case 'session_updated':
            // Session info changed (e.g. its title).
            if (message.payload?.id && message.payload?.name) {
              onSessionUpdatedRef.current?.(message.payload.id, message.payload.name);
            }
            break;

          case 'permission_request':
            if (message.payload) {
              setState((prev) => ({
                ...prev,
                permissionRequest: message.payload as PermissionRequest,
              }));
            }
            break;

          // ============ Sub-agent events ============
          case 'subagent:start': {
            const payload = message.payload as SubagentStartPayload;
            if (!payload) break;
            setState((prev) => ({
              ...prev,
              currentAgent: payload.agentName,
              currentSubagent: {
                id: payload.agentId,
                name: payload.agentName,
                description: payload.description,
                status: 'running',
                tools: [],
                streamContent: '',
              },
            }));
            break;
          }

          case 'subagent:tool_start': {
            const payload = message.payload as SubagentToolStartPayload;
            if (!payload) break;
            setState((prev) => {
              if (!prev.currentSubagent || prev.currentSubagent.id !== payload.agentId) {
                return prev;
              }
              // De-duplicate repeated tool_start events for the same call id.
              const alreadyShown = prev.currentSubagent.tools.some(
                (tool) => tool.id === payload.toolCallId
              );
              if (alreadyShown) {
                return prev;
              }
              const newTool: SubagentToolInfo = {
                id: payload.toolCallId,
                toolName: payload.toolName,
                status: 'running',
                args: payload.args,
              };
              return {
                ...prev,
                currentSubagent: {
                  ...prev.currentSubagent,
                  tools: [...prev.currentSubagent.tools, newTool],
                },
              };
            });
            break;
          }

          case 'subagent:stream': {
            const payload = message.payload as SubagentStreamPayload;
            if (!payload) break;
            setState((prev) => {
              if (!prev.currentSubagent || prev.currentSubagent.id !== payload.agentId) {
                return prev;
              }
              return {
                ...prev,
                currentSubagent: {
                  ...prev.currentSubagent,
                  streamContent: prev.currentSubagent.streamContent + payload.content,
                },
              };
            });
            break;
          }

          case 'subagent:tool_end': {
            const payload = message.payload as SubagentToolEndPayload;
            if (!payload) break;
            setState((prev) => {
              if (!prev.currentSubagent || prev.currentSubagent.id !== payload.agentId) {
                return prev;
              }
              const updatedTools = prev.currentSubagent.tools.map((tool) => {
                if (tool.id === payload.toolCallId) {
                  return {
                    ...tool,
                    status: payload.status === 'completed' ? 'completed' : 'error',
                    result: payload.result,
                    error: payload.error,
                    duration: payload.duration,
                  } as SubagentToolInfo;
                }
                return tool;
              });
              return {
                ...prev,
                currentSubagent: {
                  ...prev.currentSubagent,
                  tools: updatedTools,
                },
              };
            });
            break;
          }

          case 'subagent:end': {
            const payload = message.payload as SubagentEndPayload;
            if (!payload) break;
            setState((prev) => {
              // Only the sub-agent this event refers to is affected.
              if (!prev.currentSubagent || prev.currentSubagent.id !== payload.agentId) {
                return prev;
              }
              return {
                ...prev,
                currentAgent: prev.agentMode, // Back to the main agent.
                currentSubagent: {
                  ...prev.currentSubagent,
                  status: payload.success ? 'completed' : 'error',
                  duration: payload.duration,
                  error: payload.error,
                },
              };
            });
            // Keep the final status visible briefly, then clear it. Clear only the
            // sub-agent that ended, so one started in the meantime is not wiped.
            setTimeout(() => {
              setState((prev) =>
                prev.currentSubagent?.id === payload.agentId
                  ? { ...prev, currentSubagent: null }
                  : prev
              );
            }, 1000);
            break;
          }
        }
      } catch {
        // Malformed frames are ignored on purpose.
      }
    };

    wsRef.current = ws;
  }, [sessionId]);

  // Send a user message together with the session-level settings.
  const sendMessage = useCallback(
    (content: string) => {
      if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
        onErrorRef.current?.(new Error('WebSocket not connected'));
        return;
      }
      setState((prev) => ({ ...prev, isLoading: true }));
      wsRef.current.send(
        JSON.stringify({
          type: 'message',
          sessionId,
          payload: {
            content,
            agentMode: state.agentMode,
            autoApprove: state.autoApprove,
          },
        })
      );
    },
    [sessionId, state.agentMode, state.autoApprove]
  );

  // Ask the server to cancel the current turn and keep what was streamed so far.
  const cancelProcessing = useCallback(() => {
    if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) return;
    wsRef.current.send(
      JSON.stringify({
        type: 'cancel',
        sessionId,
      })
    );
    setState((prev) => {
      const streaming = prev.streamingMessage;
      // The turn is over, so control returns to the main agent (as on `done`).
      const settled = {
        isLoading: false,
        streamingMessage: null,
        currentAgent: prev.agentMode,
      };
      // Preserve partial output as a regular message when there is any.
      if (streaming && (streaming.content || streaming.parts.length > 0)) {
        const cancelledMessage: Message = {
          ...streaming,
          id: streaming.id || `cancelled-${Date.now()}`,
          content: streaming.content || '',
        };
        return {
          ...prev,
          ...settled,
          messages: [...prev.messages, cancelledMessage],
        };
      }
      return { ...prev, ...settled };
    });
  }, [sessionId]);

  // Answer a pending permission request and clear it from state.
  const respondToPermission = useCallback(
    (requestId: string, allow: boolean, remember?: boolean) => {
      if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
        onErrorRef.current?.(new Error('WebSocket not connected'));
        return;
      }
      wsRef.current.send(
        JSON.stringify({
          type: 'permission_response',
          sessionId,
          payload: { requestId, allow, remember },
        })
      );
      setState((prev) => ({ ...prev, permissionRequest: null }));
    },
    [sessionId]
  );

  // Approve a permission request.
  const allowPermission = useCallback(
    (requestId: string, remember?: boolean) => {
      respondToPermission(requestId, true, remember);
    },
    [respondToPermission]
  );

  // Reject a permission request.
  const denyPermission = useCallback(
    (requestId: string, remember?: boolean) => {
      respondToPermission(requestId, false, remember);
    },
    [respondToPermission]
  );

  // Set the agent mode (session level).
  const setAgentMode = useCallback((mode: AgentModeType) => {
    setState((prev) => ({ ...prev, agentMode: mode }));
  }, []);

  // Set auto-approve (session level, effective immediately).
  const setAutoApprove = useCallback(
    (enabled: boolean) => {
      setState((prev) => ({ ...prev, autoApprove: enabled }));
      // Push the change right away so it applies to a turn that is already running.
      if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
        wsRef.current.send(
          JSON.stringify({
            type: 'config_update',
            sessionId,
            payload: { autoApprove: enabled },
          })
        );
      }
    },
    [sessionId]
  );

  // (Re)initialize whenever the session changes.
  useEffect(() => {
    isClosingRef.current = false;
    setState(createInitialState());
    reconnectAttemptsRef.current = 0;
    loadMessages();
    connect();

    return () => {
      clearTimeout(reconnectTimeoutRef.current);
      // Invalidate any history request that is still in flight for this session.
      loadRequestIdRef.current++;
      // Mark the teardown as deliberate so no error callback or reconnect fires.
      isClosingRef.current = true;

      const ws = wsRef.current;
      if (!ws) return;
      wsRef.current = null;
      // Detach handlers so the old socket can no longer touch state.
      ws.onclose = null;
      ws.onerror = null;
      ws.onmessage = null;
      if (ws.readyState === WebSocket.CONNECTING) {
        // Closing now would log a "closed before established" warning, so close
        // as soon as the handshake completes; otherwise the socket would leak.
        ws.onopen = () => ws.close();
      } else {
        ws.onopen = null;
        if (ws.readyState === WebSocket.OPEN) {
          ws.close();
        }
      }
    };
  }, [loadMessages, connect]);

  return {
    ...state,
    sendMessage,
    cancelProcessing,
    reload: loadMessages,
    allowPermission,
    denyPermission,
    setAgentMode,
    setAutoApprove,
  };
}