/**
 * WebSocket handler tests.
 *
 * Covers connection handling, message routing, and broadcasting.
 *
 * Note: message persistence has moved to the Core layer; the server-side
 * WebSocket module is only responsible for broadcasting and routing.
 *
 * The connection registry behind `getSessionConnections` / `getConnectionStats`
 * is never reset between tests, so every test that registers a connection uses
 * a session ID from `getUniqueSessionId()` to avoid observing connections left
 * behind by earlier tests.
 */
import { describe, it, expect, beforeEach, vi } from 'vitest';

// Mock functions shared with the mocked session manager below.
const mockExists = vi.fn();
const mockUpdateStatus = vi.fn();
const mockGet = vi.fn();

// Mock dependencies before imports (vi.mock calls are hoisted by Vitest).
vi.mock('../../src/session/manager.js', () => ({
  getSessionManager: vi.fn(() => ({
    exists: mockExists,
    updateStatus: mockUpdateStatus,
    get: mockGet,
  })),
}));

vi.mock('../../src/agent/index.js', () => ({
  processMessage: vi.fn().mockResolvedValue(undefined),
  cancelProcessing: vi.fn(),
}));

vi.mock('../../src/permission/handler.js', () => ({
  handlePermissionResponse: vi.fn().mockReturnValue(true),
}));

import {
  handleWebSocket,
  handleWebSocketMessage,
  handleWebSocketClose,
  broadcastToSession,
  getSessionConnections,
  getConnectionStats,
} from '../../src/ws.js';
import { processMessage, cancelProcessing } from '../../src/agent/index.js';
import { handlePermissionResponse } from '../../src/permission/handler.js';

// Minimal stand-in for the WSContext the handlers receive; only `send` and
// `close` are stubbed.
function createMockWSContext() {
  return {
    send: vi.fn(),
    close: vi.fn(),
  };
}

// Counter used to generate unique session IDs for test isolation.
let testCounter = 0;
function getUniqueSessionId(prefix = 'session') {
  return `${prefix}-${Date.now()}-${testCounter++}`;
}

describe('WebSocket Handler', () => {
  beforeEach(() => {
    // clearAllMocks only clears call history; implementations installed with
    // mockReturnValue/mockImplementation survive it, so per-test overrides of
    // shared mocks must use the *Once variants.
    vi.clearAllMocks();
    mockExists.mockReturnValue(false);
  });

  describe('handleWebSocket - 连接处理', () => {
    it('无效 session 发送错误并关闭连接', () => {
      const ws = createMockWSContext();
      mockExists.mockReturnValue(false);

      handleWebSocket(ws as any, 'invalid-session');

      expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"error"'));
      expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('Session not found'));
      expect(ws.close).toHaveBeenCalledWith(4004, 'Session not found');
    });

    it('有效 session 注册连接并发送 connected 消息', () => {
      const sessionId = getUniqueSessionId('connect');
      const ws = createMockWSContext();
      mockExists.mockReturnValue(true);

      handleWebSocket(ws as any, sessionId);

      expect(ws.close).not.toHaveBeenCalled();
      expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"connected"'));
      expect(mockUpdateStatus).toHaveBeenCalledWith(sessionId, 'active');
    });

    it('连接被添加到 connections map', () => {
      const sessionId = getUniqueSessionId('register');
      const ws = createMockWSContext();
      mockExists.mockReturnValue(true);

      handleWebSocket(ws as any, sessionId);

      const connections = getSessionConnections(sessionId);
      expect(connections.has(ws as any)).toBe(true);
    });
  });

  describe('handleWebSocketMessage - 消息处理', () => {
    beforeEach(() => {
      mockExists.mockReturnValue(true);
    });

    it('处理 message 类型消息', async () => {
      const sessionId = getUniqueSessionId('message');
      const ws = createMockWSContext();
      handleWebSocket(ws as any, sessionId);

      const message = JSON.stringify({
        type: 'message',
        payload: { content: 'Hello AI' },
      });

      await handleWebSocketMessage(ws as any, sessionId, message);

      // Should broadcast a message_received acknowledgement to the session.
      expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"message_received"'));
      // Should forward the content to processMessage.
      expect(processMessage).toHaveBeenCalledWith(sessionId, 'Hello AI', expect.any(Object));
    });

    it('处理 cancel 类型消息', async () => {
      const sessionId = getUniqueSessionId('cancel');
      const ws = createMockWSContext();
      handleWebSocket(ws as any, sessionId);

      const message = JSON.stringify({ type: 'cancel' });

      await handleWebSocketMessage(ws as any, sessionId, message);

      expect(cancelProcessing).toHaveBeenCalledWith(sessionId);
      // Should broadcast a cancelled message to the session.
      expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"cancelled"'));
    });

    it('处理 permission_response 类型消息', async () => {
      const ws = createMockWSContext();
      const message = JSON.stringify({
        type: 'permission_response',
        payload: { requestId: 'req-123', allow: true, remember: false },
      });

      await handleWebSocketMessage(ws as any, 'session-1', message);

      expect(handlePermissionResponse).toHaveBeenCalledWith('req-123', true, false);
    });

    it('未知消息类型返回错误', async () => {
      const ws = createMockWSContext();
      const message = JSON.stringify({ type: 'unknown_type' });

      await handleWebSocketMessage(ws as any, 'session-1', message);

      expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"error"'));
      expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('Unknown message type'));
    });

    it('无效 JSON 返回错误', async () => {
      const ws = createMockWSContext();

      await handleWebSocketMessage(ws as any, 'session-1', 'invalid json');

      expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"error"'));
    });

    it('处理 ArrayBuffer 数据', async () => {
      const sessionId = getUniqueSessionId('arraybuffer');
      const ws = createMockWSContext();
      handleWebSocket(ws as any, sessionId);

      const message = { type: 'message', payload: { content: 'ArrayBuffer test' } };
      const encoder = new TextEncoder();
      const buffer = encoder.encode(JSON.stringify(message)).buffer;

      await handleWebSocketMessage(ws as any, sessionId, buffer);

      expect(processMessage).toHaveBeenCalledWith(sessionId, 'ArrayBuffer test', expect.any(Object));
    });

    it('空 content 处理正确', async () => {
      const sessionId = getUniqueSessionId('empty-content');
      const ws = createMockWSContext();
      handleWebSocket(ws as any, sessionId);

      const message = JSON.stringify({
        type: 'message',
        payload: {},
      });

      await handleWebSocketMessage(ws as any, sessionId, message);

      expect(processMessage).toHaveBeenCalledWith(sessionId, '', expect.any(Object));
    });

    it('处理 Blob 数据', async () => {
      const sessionId = getUniqueSessionId('blob');
      const ws = createMockWSContext();
      handleWebSocket(ws as any, sessionId);

      const message = { type: 'message', payload: { content: 'Blob test' } };
      const blob = new Blob([JSON.stringify(message)]);

      await handleWebSocketMessage(ws as any, sessionId, blob);

      expect(processMessage).toHaveBeenCalledWith(sessionId, 'Blob test', expect.any(Object));
    });

    it('处理非标准数据类型(转为字符串)', async () => {
      const ws = createMockWSContext();
      // A non-string, non-binary payload; the handler is expected to coerce it
      // with String(), which invokes this object's toString().
      const message = { type: 'cancel' };
      const objData = { toString: () => JSON.stringify(message) };

      await handleWebSocketMessage(ws as any, 'session-1', objData);

      expect(cancelProcessing).toHaveBeenCalledWith('session-1');
    });

    it('处理 tool_response 类型消息(TODO 场景)', async () => {
      const ws = createMockWSContext();
      const message = JSON.stringify({
        type: 'tool_response',
        payload: { toolId: 'tool-123', result: 'success' },
      });

      // Per this test's name, tool_response handling is still a TODO, so
      // nothing observable happens...
      await handleWebSocketMessage(ws as any, 'session-1', message);

      // ...but it must not be reported as an error.
      expect(ws.send).not.toHaveBeenCalledWith(expect.stringContaining('"type":"error"'));
    });

    it('permission_response 无 requestId 时不调用 handler', async () => {
      const ws = createMockWSContext();
      const message = JSON.stringify({
        type: 'permission_response',
        payload: { allow: true },
      });

      await handleWebSocketMessage(ws as any, 'session-1', message);

      expect(handlePermissionResponse).not.toHaveBeenCalled();
    });

    it('permission_response handler 返回 false 时打印警告', async () => {
      const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
      // One-shot override: a persistent mockReturnValue(false) would survive
      // vi.clearAllMocks() and leak into every later test.
      vi.mocked(handlePermissionResponse).mockReturnValueOnce(false);

      try {
        const ws = createMockWSContext();
        const message = JSON.stringify({
          type: 'permission_response',
          payload: { requestId: 'unknown-req', allow: true },
        });

        await handleWebSocketMessage(ws as any, 'session-1', message);

        expect(consoleWarnSpy).toHaveBeenCalledWith(expect.stringContaining('unknown-req'));
      } finally {
        // Restore even when an assertion fails so console.warn is not
        // silenced for the rest of the run.
        consoleWarnSpy.mockRestore();
      }
    });
  });

  describe('handleWebSocketClose - 关闭处理', () => {
    it('从 connections 中移除连接', () => {
      const sessionId = getUniqueSessionId('close');
      const ws = createMockWSContext();
      mockExists.mockReturnValue(true);

      handleWebSocket(ws as any, sessionId);
      expect(getSessionConnections(sessionId).size).toBe(1);

      handleWebSocketClose(ws as any, sessionId);
      expect(getSessionConnections(sessionId).size).toBe(0);
    });

    it('最后一个连接关闭时更新 session 状态为 idle', () => {
      const sessionId = getUniqueSessionId('idle');
      const ws = createMockWSContext();
      mockExists.mockReturnValue(true);

      handleWebSocket(ws as any, sessionId);
      handleWebSocketClose(ws as any, sessionId);

      expect(mockUpdateStatus).toHaveBeenLastCalledWith(sessionId, 'idle');
    });

    it('多个连接时关闭一个不影响其他', () => {
      const sessionId = getUniqueSessionId('multi');
      const ws1 = createMockWSContext();
      const ws2 = createMockWSContext();
      mockExists.mockReturnValue(true);

      handleWebSocket(ws1 as any, sessionId);
      handleWebSocket(ws2 as any, sessionId);
      expect(getSessionConnections(sessionId).size).toBe(2);

      handleWebSocketClose(ws1 as any, sessionId);
      expect(getSessionConnections(sessionId).size).toBe(1);
      expect(getSessionConnections(sessionId).has(ws2 as any)).toBe(true);
    });
  });

  describe('broadcastToSession - 会话广播', () => {
    it('向所有连接发送消息', () => {
      const sessionId = getUniqueSessionId('broadcast');
      const ws1 = createMockWSContext();
      const ws2 = createMockWSContext();
      mockExists.mockReturnValue(true);

      handleWebSocket(ws1 as any, sessionId);
      handleWebSocket(ws2 as any, sessionId);

      // Drop the "connected" sends so only the broadcast is counted.
      ws1.send.mockClear();
      ws2.send.mockClear();

      broadcastToSession(sessionId, {
        type: 'chunk',
        sessionId,
        payload: { content: 'test' },
      });

      expect(ws1.send).toHaveBeenCalledTimes(1);
      expect(ws2.send).toHaveBeenCalledTimes(1);
      expect(ws1.send).toHaveBeenCalledWith(expect.stringContaining('"type":"chunk"'));
    });

    it('无连接时不抛出错误', () => {
      expect(() => {
        broadcastToSession('non-existent', {
          type: 'chunk',
          sessionId: 'non-existent',
          payload: {},
        });
      }).not.toThrow();
    });

    it('发送失败时继续发送给其他连接', () => {
      const sessionId = getUniqueSessionId('broadcast-fail');
      mockExists.mockReturnValue(true);
      const ws1 = createMockWSContext();
      const ws2 = createMockWSContext();

      handleWebSocket(ws1 as any, sessionId);
      handleWebSocket(ws2 as any, sessionId);

      // handleWebSocket already sent "connected" to ws2; clear it so the
      // assertions below prove the broadcast itself reached ws2.
      ws2.send.mockClear();

      // A failing send on one connection must not stop delivery to the others.
      ws1.send.mockImplementation(() => {
        throw new Error('Connection closed');
      });

      expect(() => {
        broadcastToSession(sessionId, {
          type: 'chunk',
          sessionId,
          payload: {},
        });
      }).not.toThrow();

      expect(ws2.send).toHaveBeenCalledTimes(1);
      expect(ws2.send).toHaveBeenCalledWith(expect.stringContaining('"type":"chunk"'));
    });
  });

  describe('getSessionConnections - 获取会话连接', () => {
    it('返回会话的所有连接', () => {
      const sessionId = getUniqueSessionId('get-conns');
      const ws = createMockWSContext();
      mockExists.mockReturnValue(true);

      handleWebSocket(ws as any, sessionId);

      const connections = getSessionConnections(sessionId);
      expect(connections.size).toBe(1);
    });

    it('不存在的会话返回空 Set', () => {
      const connections = getSessionConnections('non-existent-unique-12345');
      expect(connections.size).toBe(0);
    });
  });

  describe('getConnectionStats - 连接统计', () => {
    it('返回正确的统计信息', () => {
      // Connections registered by earlier tests are never removed, so measure
      // the delta from a baseline rather than absolute counts.
      const initialStats = getConnectionStats();

      const sessionId1 = getUniqueSessionId('stats-1');
      const sessionId2 = getUniqueSessionId('stats-2');
      const ws1 = createMockWSContext();
      const ws2 = createMockWSContext();
      const ws3 = createMockWSContext();
      mockExists.mockReturnValue(true);

      handleWebSocket(ws1 as any, sessionId1);
      handleWebSocket(ws2 as any, sessionId1);
      handleWebSocket(ws3 as any, sessionId2);

      const stats = getConnectionStats();
      // We added 2 sessions and 3 connections.
      expect(stats.sessions).toBe(initialStats.sessions + 2);
      expect(stats.connections).toBe(initialStats.connections + 3);
    });
  });
});