Files
ai-terminal-assistant/packages/ui/src/hooks/useChat.ts
T
kurihada f238368f87 feat(ui): 支持 Auto Edit 运行时切换立即生效
- 添加 config_update WebSocket 消息类型
- setAutoApprove 切换时立即发送配置更新到服务端
- 服务端实时更新会话的 auto-approve 配置
2025-12-15 19:52:06 +08:00

492 lines
15 KiB
TypeScript

/**
* Chat Hook
*
* 管理 WebSocket 连接和消息状态
*/
import { useState, useEffect, useCallback, useRef } from 'react';
import { createWebSocket, getMessages, type Message } from '../api/client.js';
import type { PermissionRequest } from '../components/PermissionDialog.js';
import type {
ConfigErrorPayload,
ToolStartPayload,
ToolEndPayload,
MessagePart,
ToolMessagePart,
AgentModeType,
} from '../api/types.js';
/**
 * Options accepted by the `useChat` hook.
 *
 * All callbacks are optional.
 */
interface UseChatOptions {
  /** Id of the session whose history is loaded and whose WebSocket is opened. */
  sessionId: string;
  /**
   * Called for generic failures: a history load error that is not a
   * "not found", a WebSocket error event, a non-config `error` message from
   * the server, or an attempt to send while the socket is not open.
   */
  onError?: (error: Error) => void;
  /**
   * Called when loading the history fails with an error whose message
   * contains "404" or "not found" (case-insensitive).
   */
  onSessionNotFound?: () => void;
  /** Called when a `session_updated` message carrying both an id and a name arrives. */
  onSessionUpdated?: (sessionId: string, name: string) => void;
  /** Configuration error callback (e.g. the API key is not configured). */
  onConfigError?: (error: ConfigErrorPayload) => void;
}
/**
 * State held by the `useChat` hook and spread into its return value.
 */
interface ChatState {
  /** Finalized messages: loaded history plus messages completed over the socket. */
  messages: Message[];
  /** True between the WebSocket `open` and `close` events. */
  isConnected: boolean;
  /** Set when a message is sent; cleared on `done`, `error` or cancel. */
  isLoading: boolean;
  /** In-progress streamed message; reuses the `Message` structure. */
  streamingMessage: Message | null;
  /** Pending permission request received from the server, if any. */
  permissionRequest: PermissionRequest | null;
  /** Agent mode (session level). */
  agentMode: AgentModeType;
  /** Whether file writes/edits are auto-approved (session level). */
  autoApprove: boolean;
}
/** Maximum number of automatic reconnect attempts after an unexpected close. */
const MAX_RECONNECT_ATTEMPTS = 5;

/** Delay before each reconnect attempt, in milliseconds. */
const RECONNECT_DELAY_MS = 3000;

/** Builds a locally generated id of the form `<prefix>-<epoch ms>-<random base36>`. */
function createLocalId(prefix: string): string {
  return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`;
}

/** Creates the empty assistant message that accumulates streamed parts. */
function createStreamingMessage(): Message {
  return {
    id: createLocalId('streaming'),
    role: 'assistant',
    timestamp: new Date().toISOString(),
    parts: [],
    content: '',
  };
}

/** Returns the state used on first render and whenever the session changes. */
function createInitialChatState(): ChatState {
  return {
    messages: [],
    isConnected: false,
    isLoading: false,
    streamingMessage: null,
    permissionRequest: null,
    agentMode: 'build',
    autoApprove: false,
  };
}

/**
 * Chat hook: loads a session's message history, keeps a WebSocket open for it
 * and exposes the resulting chat state plus actions to drive it.
 *
 * Whenever `sessionId` changes, the state is reset, the history is reloaded
 * and a fresh socket is opened; the previous socket is torn down.
 *
 * @param options - Session id and optional callbacks, see `UseChatOptions`.
 * @returns The current `ChatState` fields together with `sendMessage`,
 *   `cancelProcessing`, `reload`, `allowPermission`, `denyPermission`,
 *   `setAgentMode` and `setAutoApprove`.
 */
export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdated, onConfigError }: UseChatOptions) {
  const [state, setState] = useState<ChatState>(createInitialChatState);

  const wsRef = useRef<WebSocket | null>(null);
  const reconnectTimeoutRef = useRef<ReturnType<typeof setTimeout> | undefined>(undefined);
  const reconnectAttemptsRef = useRef(0);
  // True while we are deliberately closing the connection (session switch / unmount).
  const isClosingRef = useRef(false);
  // Session currently served by the mounted effect; null while none is active.
  // Used to discard history responses that belong to a previous session.
  const activeSessionIdRef = useRef<string | null>(null);

  // Keep the callbacks in refs so that new callback identities do not
  // re-create `connect` / `loadMessages` and retrigger the effect.
  const onErrorRef = useRef(onError);
  const onSessionNotFoundRef = useRef(onSessionNotFound);
  const onSessionUpdatedRef = useRef(onSessionUpdated);
  const onConfigErrorRef = useRef(onConfigError);
  onErrorRef.current = onError;
  onSessionNotFoundRef.current = onSessionNotFound;
  onSessionUpdatedRef.current = onSessionUpdated;
  onConfigErrorRef.current = onConfigError;

  // Load the message history of the session.
  const loadMessages = useCallback(async () => {
    try {
      const { data } = await getMessages(sessionId);
      // The session changed (or the hook unmounted) while the request was in
      // flight: applying the result would overwrite another session's messages.
      if (activeSessionIdRef.current !== sessionId) return;
      setState((prev) => ({ ...prev, messages: data }));
    } catch (error) {
      // Failures of a request for a session that is no longer active are irrelevant.
      if (activeSessionIdRef.current !== sessionId) return;
      // Session does not exist (404 or "Session not found"): let the owner re-create it.
      const msg = error instanceof Error ? error.message : '';
      if (msg.includes('404') || msg.toLowerCase().includes('not found')) {
        onSessionNotFoundRef.current?.();
        return;
      }
      onErrorRef.current?.(error instanceof Error ? error : new Error('Failed to load messages'));
    }
  }, [sessionId]);

  // Open the WebSocket for the session and wire up its handlers.
  const connect = useCallback(() => {
    // Never connect while a deliberate close is in progress.
    if (isClosingRef.current) return;
    // Do not open a second socket while one is open or still connecting.
    if (wsRef.current?.readyState === WebSocket.OPEN) return;
    if (wsRef.current?.readyState === WebSocket.CONNECTING) return;

    const ws = createWebSocket(sessionId);

    ws.onopen = () => {
      // A close was requested while the socket was still connecting: close it right away.
      if (isClosingRef.current) {
        ws.close();
        return;
      }
      reconnectAttemptsRef.current = 0; // Connected: reset the retry budget.
      setState((prev) => ({ ...prev, isConnected: true }));
    };

    ws.onclose = () => {
      setState((prev) => ({ ...prev, isConnected: false }));
      // Deliberate close: do not reconnect.
      if (isClosingRef.current) {
        isClosingRef.current = false;
        return;
      }
      // Unexpected close: retry a bounded number of times.
      if (reconnectAttemptsRef.current < MAX_RECONNECT_ATTEMPTS) {
        reconnectAttemptsRef.current++;
        reconnectTimeoutRef.current = setTimeout(connect, RECONNECT_DELAY_MS);
      }
    };

    ws.onerror = () => {
      // Errors caused by a deliberate close are not reported.
      if (isClosingRef.current) return;
      onErrorRef.current?.(new Error('WebSocket connection error'));
    };

    ws.onmessage = (event) => {
      try {
        const message = JSON.parse(event.data);
        switch (message.type) {
          case 'chunk': {
            const chunkContent = message.payload?.content || '';
            setState((prev) => {
              // Start a streaming message on the first chunk, otherwise extend the current one.
              const streaming = prev.streamingMessage || createStreamingMessage();
              // Copy the parts array so the previous state is not mutated.
              const parts = [...streaming.parts];
              const lastPart = parts[parts.length - 1];
              // Append to a trailing text part; otherwise start a new text part.
              if (lastPart?.type === 'text') {
                parts[parts.length - 1] = {
                  ...lastPart,
                  text: lastPart.text + chunkContent,
                };
              } else if (chunkContent) {
                parts.push({
                  type: 'text',
                  id: createLocalId('text'),
                  text: chunkContent,
                });
              }
              return {
                ...prev,
                streamingMessage: {
                  ...streaming,
                  parts,
                  content: (streaming.content || '') + chunkContent,
                },
              };
            });
            break;
          }
          case 'tool_start': {
            const payload = message.payload as ToolStartPayload;
            setState((prev) => {
              // A tool call may be the first event of a response.
              const streaming = prev.streamingMessage || createStreamingMessage();
              // Record the tool call as a running part.
              const toolPart: ToolMessagePart = {
                type: 'tool',
                id: payload.id,
                toolCallId: payload.id,
                toolName: payload.toolName,
                status: 'running',
                arguments: payload.arguments,
              };
              return {
                ...prev,
                streamingMessage: {
                  ...streaming,
                  parts: [...streaming.parts, toolPart],
                },
              };
            });
            break;
          }
          case 'tool_end': {
            const payload = message.payload as ToolEndPayload;
            setState((prev) => {
              if (!prev.streamingMessage) return prev;
              // Find the matching tool part and fill in its outcome.
              const parts = prev.streamingMessage.parts.map((part) => {
                if (part.type === 'tool' && part.id === payload.id) {
                  return {
                    ...part,
                    status: payload.status,
                    result: payload.result,
                    error: payload.error,
                    duration: payload.duration,
                  } as ToolMessagePart;
                }
                return part;
              });
              return {
                ...prev,
                streamingMessage: {
                  ...prev.streamingMessage,
                  parts,
                },
              };
            });
            break;
          }
          case 'done':
            setState((prev) => {
              // Finalize the streamed message, or build one from the payload
              // when nothing was streamed.
              const streaming = prev.streamingMessage;
              const content = message.payload?.content || streaming?.content || '';
              const newMessage: Message = streaming
                ? {
                    ...streaming,
                    id: message.payload?.id || streaming.id,
                    timestamp: message.payload?.timestamp || streaming.timestamp,
                    content,
                  }
                : {
                    id: message.payload?.id || createLocalId('assistant'),
                    role: 'assistant',
                    timestamp: message.payload?.timestamp || new Date().toISOString(),
                    parts: [{ type: 'text', id: `text-${Date.now()}`, text: content }],
                    content,
                  };
              return {
                ...prev,
                messages: [...prev.messages, newMessage],
                streamingMessage: null,
                isLoading: false,
              };
            });
            break;
          case 'message_received':
            // The server acknowledged the user message: build the full message object.
            setState((prev) => {
              const content = message.payload?.content || '';
              const userMessage: Message = {
                id: createLocalId('user'),
                role: 'user',
                timestamp: new Date().toISOString(),
                parts: [{ type: 'text', id: `text-${Date.now()}`, text: content }],
                content,
              };
              return {
                ...prev,
                messages: [...prev.messages, userMessage],
              };
            });
            break;
          case 'error':
            // Configuration errors get their own callback.
            if (message.payload?.type === 'config_error') {
              onConfigErrorRef.current?.(message.payload as ConfigErrorPayload);
            } else {
              onErrorRef.current?.(new Error(message.payload?.message || 'Unknown error'));
            }
            setState((prev) => ({ ...prev, isLoading: false, streamingMessage: null }));
            break;
          case 'session_updated':
            // Session metadata changed (e.g. its title).
            if (message.payload?.id && message.payload?.name) {
              onSessionUpdatedRef.current?.(message.payload.id, message.payload.name);
            }
            break;
          case 'permission_request':
            // The server asks the user to approve an action.
            if (message.payload) {
              setState((prev) => ({
                ...prev,
                permissionRequest: message.payload as PermissionRequest,
              }));
            }
            break;
        }
      } catch {
        // Ignore frames that cannot be parsed.
      }
    };

    wsRef.current = ws;
  }, [sessionId]);

  // Send a user message together with the current session-level settings.
  const sendMessage = useCallback(
    (content: string) => {
      if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
        onErrorRef.current?.(new Error('WebSocket not connected'));
        return;
      }
      setState((prev) => ({ ...prev, isLoading: true }));
      wsRef.current.send(
        JSON.stringify({
          type: 'message',
          sessionId,
          payload: {
            content,
            agentMode: state.agentMode,
            autoApprove: state.autoApprove,
          },
        })
      );
    },
    [sessionId, state.agentMode, state.autoApprove]
  );

  // Cancel the response currently being processed.
  const cancelProcessing = useCallback(() => {
    if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) return;
    wsRef.current.send(
      JSON.stringify({
        type: 'cancel',
        sessionId,
      })
    );
    // Keep whatever was already streamed as a regular message.
    setState((prev) => {
      const streaming = prev.streamingMessage;
      if (streaming && (streaming.content || streaming.parts.length > 0)) {
        const cancelledMessage: Message = {
          ...streaming,
          id: streaming.id || `cancelled-${Date.now()}`,
          content: streaming.content || '',
        };
        return {
          ...prev,
          messages: [...prev.messages, cancelledMessage],
          isLoading: false,
          streamingMessage: null,
        };
      }
      return { ...prev, isLoading: false, streamingMessage: null };
    });
  }, [sessionId]);

  // Answer a pending permission request and clear it from the state.
  const respondToPermission = useCallback(
    (requestId: string, allow: boolean, remember?: boolean) => {
      if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) {
        onErrorRef.current?.(new Error('WebSocket not connected'));
        return;
      }
      wsRef.current.send(
        JSON.stringify({
          type: 'permission_response',
          sessionId,
          payload: { requestId, allow, remember },
        })
      );
      setState((prev) => ({ ...prev, permissionRequest: null }));
    },
    [sessionId]
  );

  // Approve a permission request.
  const allowPermission = useCallback(
    (requestId: string, remember?: boolean) => {
      respondToPermission(requestId, true, remember);
    },
    [respondToPermission]
  );

  // Deny a permission request.
  const denyPermission = useCallback(
    (requestId: string, remember?: boolean) => {
      respondToPermission(requestId, false, remember);
    },
    [respondToPermission]
  );

  // Set the agent mode (session level); sent with the next message.
  const setAgentMode = useCallback((mode: AgentModeType) => {
    setState((prev) => ({ ...prev, agentMode: mode }));
  }, []);

  // Set auto-approve (session level, takes effect immediately).
  const setAutoApprove = useCallback(
    (enabled: boolean) => {
      setState((prev) => ({ ...prev, autoApprove: enabled }));
      // Push the change right away so it also applies to a run already in progress.
      if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
        wsRef.current.send(
          JSON.stringify({
            type: 'config_update',
            sessionId,
            payload: { autoApprove: enabled },
          })
        );
      }
    },
    [sessionId]
  );

  // (Re)initialize whenever the session changes.
  useEffect(() => {
    isClosingRef.current = false;
    activeSessionIdRef.current = sessionId;
    setState(createInitialChatState());
    reconnectAttemptsRef.current = 0;
    loadMessages();
    connect();

    return () => {
      // In-flight history requests for this session must no longer touch the state.
      activeSessionIdRef.current = null;
      clearTimeout(reconnectTimeoutRef.current);
      // Mark the close as deliberate so no error is reported and no reconnect is scheduled.
      isClosingRef.current = true;

      const ws = wsRef.current;
      if (!ws) return;
      // Drop the reference so later actions cannot use this socket.
      wsRef.current = null;
      // Detach the handlers so the teardown does not feed back into the state.
      ws.onclose = null;
      ws.onerror = null;
      ws.onmessage = null;
      ws.onopen = null;

      if (ws.readyState === WebSocket.OPEN) {
        ws.close();
      } else if (ws.readyState === WebSocket.CONNECTING) {
        // Closing a socket that is still connecting triggers a browser warning,
        // so close it as soon as it is established instead of leaving it open
        // forever. If the connection attempt fails, there is nothing to close.
        ws.onopen = () => ws.close();
      }
    };
  }, [sessionId, loadMessages, connect]);

  return {
    ...state,
    sendMessage,
    cancelProcessing,
    reload: loadMessages,
    allowPermission,
    denyPermission,
    setAgentMode,
    setAutoApprove,
  };
}