Files
ai-terminal-assistant/packages/server/src/ws.ts
T
kurihada 48a11ff077 feat: 添加 :new 系统命令创建新会话
- Core: 新增 :new/:n 命令返回 new_session action
- Server: 处理 new_session action 创建新会话
- UI: useChat 添加 onSessionSwitch 回调
- Web/Desktop: ChatPage 和 App 实现会话切换逻辑
2025-12-17 19:36:47 +08:00

353 lines
10 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* WebSocket Handler
*
* 处理实时双向通信,主要用于 AI 对话流
*/
import type { WSContext } from 'hono/ws';
import { getSessionManager } from './session/manager.js';
import { processMessage, cancelProcessing, getOrCreateAgent, submitUserInput, clearAgentHistory } from './agent/index.js';
import { handlePermissionResponse, setSessionAutoApprove } from './permission/handler.js';
import type { ClientMessage, ServerMessage } from './types.js';
import {
isSystemCommand,
executeSystemCommand,
initializeSystemCommands,
MessageStorage,
} from '@ai-assistant/core';
// Register the built-in ":"-prefixed system commands once, when this module is first loaded.
initializeSystemCommands();

/**
 * Live WebSocket connections, keyed by session id.
 * Each session id maps to the set of sockets currently attached to that session.
 */
const connections = new Map<string, Set<WSContext>>();
/**
 * Returns the set of sockets currently attached to `sessionId`.
 *
 * If the session has no registered connections, a new empty `Set` is returned. That set is
 * not stored in the connection map, so mutating it does not register anything.
 */
export function getSessionConnections(sessionId: string): Set<WSContext> {
  const attached = connections.get(sessionId);
  return attached ?? new Set<WSContext>();
}
/**
 * Sends `message` (serialized once as JSON) to every socket attached to `sessionId`.
 *
 * Does nothing when the session has no registered connections. A failure on one socket is
 * logged and does not stop delivery to the remaining sockets.
 */
export function broadcastToSession(sessionId: string, message: ServerMessage): void {
  const targets = connections.get(sessionId);
  if (targets === undefined) {
    return;
  }

  const payload = JSON.stringify(message);
  targets.forEach((socket) => {
    try {
      socket.send(payload);
    } catch (error) {
      console.error('Failed to send message:', error);
    }
  });
}
/**
 * Attaches a newly opened WebSocket to `sessionId`.
 *
 * If the session does not exist, the socket receives an `error` message and is closed with
 * code 4004 without being registered. Otherwise the socket is added to the session's
 * connection set, the session is marked `'active'`, and a `connected` message is sent back.
 */
export function handleWebSocket(ws: WSContext, sessionId: string): void {
  const sessionManager = getSessionManager();

  // Guard: unknown session -> report and close with a custom application close code.
  if (!sessionManager.exists(sessionId)) {
    const notFound = {
      type: 'error',
      sessionId,
      payload: { message: 'Session not found' },
    } as ServerMessage;
    ws.send(JSON.stringify(notFound));
    ws.close(4004, 'Session not found');
    return;
  }

  // Register the socket, creating the per-session set on first use.
  let sockets = connections.get(sessionId);
  if (sockets === undefined) {
    sockets = new Set<WSContext>();
    connections.set(sessionId, sockets);
  }
  sockets.add(ws);

  sessionManager.updateStatus(sessionId, 'active');

  const greeting = {
    type: 'connected',
    sessionId,
    payload: { message: 'Connected to session' },
  } as ServerMessage;
  ws.send(JSON.stringify(greeting));

  console.log(`[WS] Client connected to session: ${sessionId}`);
}
/**
 * Converts a raw WebSocket frame payload into text.
 *
 * Handles string frames, `ArrayBuffer` / `SharedArrayBuffer`, typed-array / `DataView` views
 * (e.g. `Uint8Array`, Node `Buffer`) and `Blob`; anything else falls back to `String()`.
 */
async function decodeIncomingData(data: unknown): Promise<string> {
  if (typeof data === 'string') {
    return data;
  }
  // `SharedArrayBuffer` is not defined in every JS runtime, so guard the reference instead of
  // letting a bare `instanceof SharedArrayBuffer` throw a ReferenceError.
  const isSharedBuffer =
    typeof SharedArrayBuffer !== 'undefined' && data instanceof SharedArrayBuffer;
  if (data instanceof ArrayBuffer || isSharedBuffer) {
    return new TextDecoder().decode(data as ArrayBuffer);
  }
  // Without this branch a plain `Uint8Array` would be stringified by `String()` as
  // comma-separated byte values ("123,34,...") and then fail JSON parsing.
  if (ArrayBuffer.isView(data)) {
    return new TextDecoder().decode(data);
  }
  if (typeof Blob !== 'undefined' && data instanceof Blob) {
    return data.text();
  }
  return String(data);
}

/**
 * Rewrites `@path` file mentions into explicit paths so the model recognizes them as files.
 *
 * - `@src/app.ts`  -> `./src/app.ts`
 * - `@./a`, `@../a`, `@/abs/a` keep their own prefix (no `././a` or `.//abs/a` artefacts;
 *   an absolute path must not be turned into a relative one)
 * - an `@` glued to a preceding word/email local part (`me@example.com`) is not a mention
 *   and is left untouched
 */
function expandFileMentions(content: string): string {
  return content.replace(
    /(^|[^\w.+-])@([\w./-]+)/g,
    (_match: string, lead: string, path: string) =>
      lead + (path.startsWith('.') || path.startsWith('/') ? path : `./${path}`)
  );
}

/**
 * Handles one inbound WebSocket frame for `sessionId`.
 *
 * The frame is decoded to text, parsed as a JSON `ClientMessage` object and dispatched on its
 * `type`. Errors thrown while decoding, parsing or dispatching (including errors from awaited
 * handlers such as system commands) are caught and reported as an `error` message to the
 * sending socket only, instead of being rethrown.
 *
 * @param ws        Socket the frame arrived on; used for replies aimed only at this client.
 * @param sessionId Session the socket is attached to.
 * @param data      Raw frame payload (string, binary buffer/view, or Blob).
 */
export async function handleWebSocketMessage(
  ws: WSContext,
  sessionId: string,
  data: unknown
): Promise<void> {
  try {
    const text = await decodeIncomingData(data);
    const parsed: unknown = JSON.parse(text);
    // JSON.parse also yields null / numbers / strings; reject those explicitly so the client
    // gets a meaningful error instead of "Cannot read properties of null".
    if (typeof parsed !== 'object' || parsed === null) {
      throw new Error('Invalid message: expected a JSON object');
    }
    const message = parsed as ClientMessage;

    switch (message.type) {
      case 'message': {
        // User chat message.
        const originalContent = message.payload?.content || '';
        const agentMode = message.payload?.agentMode as 'build' | 'plan' | undefined;
        const autoApprove = message.payload?.autoApprove as boolean | undefined;

        // ":"-prefixed input is a system command, not a prompt for the agent.
        if (isSystemCommand(originalContent)) {
          await handleSystemCommand(sessionId, originalContent);
          break;
        }

        // Make @file mentions explicit paths so the model treats them as files.
        const content = expandFileMentions(originalContent);

        // Acknowledge receipt to every client, echoing the user's original text.
        broadcastToSession(sessionId, {
          type: 'message_received',
          sessionId,
          payload: { content: originalContent },
        });

        // Run the agent without blocking this handler; message persistence is handled by
        // the core agent.
        void processMessage(sessionId, content, { agentMode, autoApprove }).catch((error) => {
          console.error('[WS] Agent processing error:', error);
        });
        break;
      }

      case 'cancel': {
        // Abort the in-flight operation for this session.
        cancelProcessing(sessionId);
        broadcastToSession(sessionId, {
          type: 'cancelled',
          sessionId,
          payload: { message: 'Operation cancelled' },
        });
        break;
      }

      case 'tool_response': {
        // Tool execution result (human-confirmation flow).
        // TODO: handle tool responses.
        break;
      }

      case 'permission_response': {
        // Answer to a pending permission prompt; a missing `allow` counts as a denial.
        const { requestId, allow, remember } = message.payload || {};
        if (requestId) {
          const handled = handlePermissionResponse(requestId, allow ?? false, remember);
          if (!handled) {
            console.warn(`[WS] Permission response for unknown request: ${requestId}`);
          }
        }
        break;
      }

      case 'config_update': {
        // Live config change (e.g. the Auto Edit toggle).
        const autoApprove = message.payload?.autoApprove;
        if (typeof autoApprove === 'boolean') {
          if (autoApprove) {
            setSessionAutoApprove(sessionId, {
              file: { write: 'allow', edit: 'allow' },
            });
          } else {
            setSessionAutoApprove(sessionId, null);
          }
          console.log(`[WS] Config updated for session ${sessionId}: autoApprove=${autoApprove}`);
        }
        break;
      }

      case 'mode_switch': {
        // Switch between Build and Plan modes; values other than these two are ignored.
        const mode = message.payload?.agentMode as 'build' | 'plan' | undefined;
        if (mode === 'build' || mode === 'plan') {
          try {
            const agent = await getOrCreateAgent(sessionId);
            if (agent && typeof agent.switchMode === 'function') {
              agent.switchMode(mode, true); // keep the conversation history
              broadcastToSession(sessionId, {
                type: 'mode_switched',
                sessionId,
                payload: { mode },
              });
              console.log(`[WS] Mode switched for session ${sessionId}: ${mode}`);
            } else {
              console.warn(`[WS] Agent does not support switchMode for session ${sessionId}`);
            }
          } catch (error) {
            console.error(`[WS] Failed to switch mode for session ${sessionId}:`, error);
            broadcastToSession(sessionId, {
              type: 'error',
              sessionId,
              payload: { message: 'Failed to switch mode' },
            });
          }
        }
        break;
      }

      case 'user_input_response': {
        // Answer for a tool that asked the user a question (e.g. ask_user_question).
        const toolCallId = message.payload?.toolCallId;
        const answer = message.payload?.answer;
        if (toolCallId && answer !== undefined) {
          const handled = submitUserInput(toolCallId, answer);
          if (!handled) {
            console.warn(`[WS] User input response for unknown tool call: ${toolCallId}`);
            broadcastToSession(sessionId, {
              type: 'error',
              sessionId,
              payload: { message: `No pending input request for tool call: ${toolCallId}` },
            });
          } else {
            console.log(`[WS] User input submitted for tool call: ${toolCallId}`);
          }
        } else {
          broadcastToSession(sessionId, {
            type: 'error',
            sessionId,
            payload: { message: 'Missing toolCallId or answer in user_input_response' },
          });
        }
        break;
      }

      default:
        ws.send(
          JSON.stringify({
            type: 'error',
            sessionId,
            // eslint-disable-next-line @typescript-eslint/no-explicit-any -- unknown wire type
            payload: { message: `Unknown message type: ${(message as any).type}` },
          } as ServerMessage)
        );
    }
  } catch (error) {
    ws.send(
      JSON.stringify({
        type: 'error',
        sessionId,
        payload: {
          message: error instanceof Error ? error.message : 'Failed to process message',
        },
      } as ServerMessage)
    );
  }
}
/**
 * Detaches a closed socket from `sessionId`.
 *
 * When the last socket of a session goes away, the session's entry is removed from the
 * connection map and the session is marked `'idle'`. Deleting a socket that was never
 * registered is a harmless no-op. A disconnect line is logged in every case.
 */
export function handleWebSocketClose(ws: WSContext, sessionId: string): void {
  const sockets = connections.get(sessionId);
  sockets?.delete(ws);

  if (sockets?.size === 0) {
    connections.delete(sessionId);
    getSessionManager().updateStatus(sessionId, 'idle');
  }

  console.log(`[WS] Client disconnected from session: ${sessionId}`);
}
/**
 * Reports how many sessions currently have at least one registered socket (`sessions`)
 * and the total number of registered sockets across all sessions (`connections`).
 */
export function getConnectionStats(): { sessions: number; connections: number } {
  const sessionCount = connections.size;
  const connectionCount = Array.from(connections.values()).reduce(
    (sum, sockets) => sum + sockets.size,
    0
  );
  return { sessions: sessionCount, connections: connectionCount };
}
/**
 * Executes a ":"-prefixed system command for `sessionId` and broadcasts the outcome.
 *
 * Side effects, only when the command succeeded and returned an action:
 * - `clear_messages`: deletes the stored messages and the agent's in-memory history for this
 *   session, then marks the session `'idle'`.
 * - `new_session`: creates a new session in the current session's `workdir` (falling back to
 *   `process.cwd()`); only `workdir` is carried over. The returned action is replaced so that
 *   it carries the new session's id.
 *
 * The `system_command_result` message is broadcast on the ORIGINAL session id; clients learn
 * the new session id from `payload.action.sessionId`.
 */
async function handleSystemCommand(sessionId: string, content: string): Promise<void> {
  const sessionManager = getSessionManager();
  console.log(`[WS] System command: ${content}`);

  const result = await executeSystemCommand(content, { sessionId });
  // Follow-up actions are only honoured for successful commands.
  const action = result.success ? result.action : undefined;

  if (action?.type === 'clear_messages') {
    await MessageStorage.removeBySession(sessionId);
    clearAgentHistory(sessionId);
    sessionManager.updateStatus(sessionId, 'idle');
  } else if (action?.type === 'new_session') {
    const workdir = sessionManager.get(sessionId)?.workdir || process.cwd();
    const created = await sessionManager.create({ workdir });
    result.action = { type: 'new_session', sessionId: created.id };
  }

  const { success, message, error } = result;
  broadcastToSession(sessionId, {
    type: 'system_command_result',
    sessionId,
    payload: {
      success,
      message,
      error,
      action: result.action,
    },
  });
}