5b20420ccd
- auth/token.ts: 50% → 100% - 新增 authMiddleware 中间件完整测试 - 覆盖本地 IP 检测、远程认证、跳过路径等场景 - 新增 getAuthContext 测试 - ws.ts: 90% → 98% - 新增 Blob/非标准数据类型处理测试 - 新增 addMessage 返回 null 场景测试 - 新增 tool_response 和 permission_response 边界测试 - sse.ts: 新增事件格式化和统计测试 测试数量: 344 → 369 (+25) 总体覆盖率: 80.82% → 82.98%
435 lines
13 KiB
TypeScript
435 lines
13 KiB
TypeScript
/**
|
|
* WebSocket Handler 测试
|
|
*
|
|
* 测试 WebSocket 连接处理、消息路由、广播功能等
|
|
*/
|
|
|
|
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
|
|
|
// Session-manager mock functions shared by the vi.mock factory below and by the tests.
// vi.mock calls are hoisted above every other statement in the file, so these spies are
// created through vi.hoisted: that guarantees they are initialized before the factory can
// reach them. Plain top-level `const`s would still be in the temporal dead zone if ws.ts
// happened to call getSessionManager() while it is being imported.
const { mockExists, mockUpdateStatus, mockAddMessage, mockGet } = vi.hoisted(() => ({
  mockExists: vi.fn(),
  mockUpdateStatus: vi.fn(),
  mockAddMessage: vi.fn(),
  mockGet: vi.fn(),
}));
|
|
|
|
// Swap the collaborators of ws.ts for mocks. Vitest hoists these vi.mock calls above the
// imports below, so ws.ts is only ever linked against the mocked modules.

// Session manager: each getSessionManager() call returns an object whose methods are the
// shared spies declared above, letting individual tests program and inspect them.
vi.mock('../../src/session/manager.js', () => {
  const buildManager = () => ({
    exists: mockExists,
    updateStatus: mockUpdateStatus,
    addMessage: mockAddMessage,
    get: mockGet,
  });
  return { getSessionManager: vi.fn(buildManager) };
});

// Agent: processing resolves immediately with no result; cancellation is a bare spy.
vi.mock('../../src/agent/index.js', () => {
  const processMessageMock = vi.fn().mockResolvedValue(undefined);
  const cancelProcessingMock = vi.fn();
  return {
    processMessage: processMessageMock,
    cancelProcessing: cancelProcessingMock,
  };
});

// Permission handler: by default every permission response is reported as handled.
vi.mock('../../src/permission/handler.js', () => {
  const handlePermissionResponseMock = vi.fn().mockReturnValue(true);
  return { handlePermissionResponse: handlePermissionResponseMock };
});
|
|
|
|
import {
|
|
handleWebSocket,
|
|
handleWebSocketMessage,
|
|
handleWebSocketClose,
|
|
broadcastToSession,
|
|
getSessionConnections,
|
|
getConnectionStats,
|
|
} from '../../src/ws.js';
|
|
import { processMessage, cancelProcessing } from '../../src/agent/index.js';
|
|
import { handlePermissionResponse } from '../../src/permission/handler.js';
|
|
|
|
/**
 * Build a fake WebSocket context exposing only the two methods these tests observe:
 * `send` (outgoing frames) and `close` (connection teardown), both as vitest spies.
 */
function createMockWSContext() {
  const send = vi.fn();
  const close = vi.fn();
  return { send, close };
}
|
|
|
|
/** Sequence number mixed into generated session IDs so no two tests share a session. */
let testCounter = 0;

/**
 * Produce a session ID that is unique within this test run. Connections registered by
 * earlier tests stay registered, so tests that need an isolated view of a session's
 * connections should use an ID from here rather than a fixed literal.
 *
 * @param prefix - leading label that makes the ID easy to recognise (default `'session'`)
 * @returns an ID of the form `<prefix>-<epoch millis>-<sequence>`
 */
function getUniqueSessionId(prefix = 'session') {
  const sequence = testCounter;
  testCounter += 1;
  return [prefix, Date.now(), sequence].join('-');
}
|
|
|
|
describe('WebSocket Handler', () => {
|
|
beforeEach(() => {
  // vi.clearAllMocks() only wipes recorded calls; implementations installed by a test
  // (mockReturnValue / mockResolvedValue / mockImplementation) survive it. Re-establish
  // every default here so an override made in one test cannot leak into the next.
  vi.clearAllMocks();
  mockExists.mockReturnValue(false);
  mockAddMessage.mockReturnValue({ id: 'msg-1', role: 'user', content: '', timestamp: Date.now() });
  // Same defaults the vi.mock factories install for the agent and permission modules
  vi.mocked(processMessage).mockResolvedValue(undefined);
  vi.mocked(handlePermissionResponse).mockReturnValue(true);
});
|
|
|
|
describe('handleWebSocket - 连接处理', () => {
  it('无效 session 发送错误并关闭连接', () => {
    const ws = createMockWSContext();
    mockExists.mockReturnValue(false);

    handleWebSocket(ws as any, 'invalid-session');

    // The client receives an error frame naming the missing session, and the socket
    // is closed with code 4004.
    expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"error"'));
    expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('Session not found'));
    expect(ws.close).toHaveBeenCalledWith(4004, 'Session not found');
  });

  it('有效 session 注册连接并发送 connected 消息', () => {
    // Fresh session ID: registered connections persist across tests, so reusing a shared
    // literal would let this socket receive other suites' broadcasts.
    const sessionId = getUniqueSessionId('connect');
    const ws = createMockWSContext();
    mockExists.mockReturnValue(true);

    handleWebSocket(ws as any, sessionId);

    expect(ws.close).not.toHaveBeenCalled();
    expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"connected"'));
    expect(mockUpdateStatus).toHaveBeenCalledWith(sessionId, 'active');
  });

  it('连接被添加到 connections map', () => {
    const sessionId = getUniqueSessionId('register');
    const ws = createMockWSContext();
    mockExists.mockReturnValue(true);

    handleWebSocket(ws as any, sessionId);

    // With an isolated session the registry must hold exactly this one socket
    const connections = getSessionConnections(sessionId);
    expect(connections.size).toBe(1);
    expect(connections.has(ws as any)).toBe(true);
  });
});
|
|
|
|
describe('handleWebSocketMessage - 消息处理', () => {
  beforeEach(() => {
    // Every test in this group talks to a session that exists
    mockExists.mockReturnValue(true);
  });

  it('处理 message 类型消息', async () => {
    const ws = createMockWSContext();
    const message = JSON.stringify({
      type: 'message',
      payload: { content: 'Hello AI' },
    });

    await handleWebSocketMessage(ws as any, 'session-1', message);

    expect(mockAddMessage).toHaveBeenCalledWith('session-1', {
      role: 'user',
      content: 'Hello AI',
    });
    expect(processMessage).toHaveBeenCalledWith('session-1', 'Hello AI');
  });

  it('处理 cancel 类型消息', async () => {
    // Register under a fresh session so the broadcast reaches only this test's socket and
    // the connection does not linger in a session shared with other tests.
    const sessionId = getUniqueSessionId('cancel');
    const ws = createMockWSContext();
    handleWebSocket(ws as any, sessionId);

    const message = JSON.stringify({ type: 'cancel' });
    await handleWebSocketMessage(ws as any, sessionId, message);

    expect(cancelProcessing).toHaveBeenCalledWith(sessionId);
    // A "cancelled" message should be broadcast to the session's connections
    expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"cancelled"'));
  });

  it('处理 permission_response 类型消息', async () => {
    const ws = createMockWSContext();
    const message = JSON.stringify({
      type: 'permission_response',
      payload: { requestId: 'req-123', allow: true, remember: false },
    });

    await handleWebSocketMessage(ws as any, 'session-1', message);

    expect(handlePermissionResponse).toHaveBeenCalledWith('req-123', true, false);
  });

  it('未知消息类型返回错误', async () => {
    const ws = createMockWSContext();
    const message = JSON.stringify({ type: 'unknown_type' });

    await handleWebSocketMessage(ws as any, 'session-1', message);

    expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"error"'));
    expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('Unknown message type'));
  });

  it('无效 JSON 返回错误', async () => {
    const ws = createMockWSContext();

    await handleWebSocketMessage(ws as any, 'session-1', 'invalid json');

    expect(ws.send).toHaveBeenCalledWith(expect.stringContaining('"type":"error"'));
  });

  it('处理 ArrayBuffer 数据', async () => {
    const ws = createMockWSContext();
    const message = { type: 'message', payload: { content: 'ArrayBuffer test' } };
    // Deliver the JSON as raw UTF-8 bytes, as a binary WebSocket frame would
    const buffer = new TextEncoder().encode(JSON.stringify(message)).buffer;

    await handleWebSocketMessage(ws as any, 'session-1', buffer);

    expect(mockAddMessage).toHaveBeenCalledWith('session-1', {
      role: 'user',
      content: 'ArrayBuffer test',
    });
  });

  it('空 content 处理正确', async () => {
    const ws = createMockWSContext();
    const message = JSON.stringify({
      type: 'message',
      payload: {},
    });

    await handleWebSocketMessage(ws as any, 'session-1', message);

    // A missing content field is stored as an empty string
    expect(mockAddMessage).toHaveBeenCalledWith('session-1', {
      role: 'user',
      content: '',
    });
  });

  it('处理 Blob 数据', async () => {
    const ws = createMockWSContext();
    const message = { type: 'message', payload: { content: 'Blob test' } };
    const blob = new Blob([JSON.stringify(message)]);

    await handleWebSocketMessage(ws as any, 'session-1', blob);

    expect(mockAddMessage).toHaveBeenCalledWith('session-1', {
      role: 'user',
      content: 'Blob test',
    });
  });

  it('处理非标准数据类型(转为字符串)', async () => {
    const ws = createMockWSContext();
    // A plain object whose toString() yields the JSON; the handler converts it via String()
    const message = { type: 'cancel' };
    const objData = { toString: () => JSON.stringify(message) };

    await handleWebSocketMessage(ws as any, 'session-1', objData);

    expect(cancelProcessing).toHaveBeenCalledWith('session-1');
  });

  it('addMessage 返回 null 时不调用 processMessage', async () => {
    const ws = createMockWSContext();
    mockAddMessage.mockReturnValue(null);
    const message = JSON.stringify({
      type: 'message',
      payload: { content: 'Test' },
    });

    await handleWebSocketMessage(ws as any, 'session-1', message);

    expect(mockAddMessage).toHaveBeenCalled();
    expect(processMessage).not.toHaveBeenCalled();
  });

  it('处理 tool_response 类型消息(TODO 场景)', async () => {
    const ws = createMockWSContext();
    const message = JSON.stringify({
      type: 'tool_response',
      payload: { toolId: 'tool-123', result: 'success' },
    });

    // tool_response handling is still a TODO, so the message is accepted without action
    await handleWebSocketMessage(ws as any, 'session-1', message);

    // ...but it must not be rejected with an error frame
    expect(ws.send).not.toHaveBeenCalledWith(expect.stringContaining('"type":"error"'));
  });

  it('permission_response 无 requestId 时不调用 handler', async () => {
    const ws = createMockWSContext();
    const message = JSON.stringify({
      type: 'permission_response',
      payload: { allow: true },
    });

    await handleWebSocketMessage(ws as any, 'session-1', message);

    expect(handlePermissionResponse).not.toHaveBeenCalled();
  });

  it('permission_response handler 返回 false 时打印警告', async () => {
    const consoleWarnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {});
    // One-shot override: vi.clearAllMocks() does not reset implementations, so a
    // persistent mockReturnValue(false) would leak into every later test.
    vi.mocked(handlePermissionResponse).mockReturnValueOnce(false);

    try {
      const ws = createMockWSContext();
      const message = JSON.stringify({
        type: 'permission_response',
        payload: { requestId: 'unknown-req', allow: true },
      });

      await handleWebSocketMessage(ws as any, 'session-1', message);

      expect(consoleWarnSpy).toHaveBeenCalledWith(expect.stringContaining('unknown-req'));
    } finally {
      // Restore console.warn even when the assertion above fails
      consoleWarnSpy.mockRestore();
    }
  });
});
|
|
|
|
describe('handleWebSocketClose - 关闭处理', () => {
  // Every socket in this group attaches to a session that exists
  beforeEach(() => {
    mockExists.mockReturnValue(true);
  });

  /** Open a mocked socket on `sessionId` through the real handler and return it. */
  const connect = (sessionId: string) => {
    const socket = createMockWSContext();
    handleWebSocket(socket as any, sessionId);
    return socket;
  };

  it('从 connections 中移除连接', () => {
    const sessionId = getUniqueSessionId('close');
    const socket = connect(sessionId);
    expect(getSessionConnections(sessionId).size).toBe(1);

    handleWebSocketClose(socket as any, sessionId);

    expect(getSessionConnections(sessionId).size).toBe(0);
  });

  it('最后一个连接关闭时更新 session 状态为 idle', () => {
    const sessionId = getUniqueSessionId('idle');
    const socket = connect(sessionId);

    handleWebSocketClose(socket as any, sessionId);

    // The final status update after the last socket leaves must mark the session idle
    expect(mockUpdateStatus).toHaveBeenLastCalledWith(sessionId, 'idle');
  });

  it('多个连接时关闭一个不影响其他', () => {
    const sessionId = getUniqueSessionId('multi');
    const first = connect(sessionId);
    const second = connect(sessionId);
    expect(getSessionConnections(sessionId).size).toBe(2);

    handleWebSocketClose(first as any, sessionId);

    const remaining = getSessionConnections(sessionId);
    expect(remaining.size).toBe(1);
    expect(remaining.has(second as any)).toBe(true);
  });
});
|
|
|
|
describe('broadcastToSession - 会话广播', () => {
  it('向所有连接发送消息', () => {
    // Isolated session so sockets left registered by other tests do not take part
    const sessionId = getUniqueSessionId('broadcast');
    const ws1 = createMockWSContext();
    const ws2 = createMockWSContext();
    mockExists.mockReturnValue(true);

    handleWebSocket(ws1 as any, sessionId);
    handleWebSocket(ws2 as any, sessionId);

    // Forget the "connected" greetings sent during registration so only the broadcast counts
    ws1.send.mockClear();
    ws2.send.mockClear();

    broadcastToSession(sessionId, {
      type: 'chunk',
      sessionId,
      payload: { content: 'test' },
    });

    expect(ws1.send).toHaveBeenCalledTimes(1);
    expect(ws2.send).toHaveBeenCalledTimes(1);
    expect(ws1.send).toHaveBeenCalledWith(expect.stringContaining('"type":"chunk"'));
    expect(ws2.send).toHaveBeenCalledWith(expect.stringContaining('"type":"chunk"'));
  });

  it('无连接时不抛出错误', () => {
    const sessionId = getUniqueSessionId('no-conns');

    expect(() => {
      broadcastToSession(sessionId, {
        type: 'chunk',
        sessionId,
        payload: {},
      });
    }).not.toThrow();
  });

  it('发送失败时继续发送给其他连接', () => {
    const sessionId = getUniqueSessionId('broadcast-fail');
    mockExists.mockReturnValue(true);

    const ws1 = createMockWSContext();
    const ws2 = createMockWSContext();
    handleWebSocket(ws1 as any, sessionId);
    handleWebSocket(ws2 as any, sessionId);

    // Drop the registration-time "connected" sends. Otherwise ws2.send has already been
    // called and the final assertion would pass even if broadcasting stopped at ws1.
    ws1.send.mockClear();
    ws2.send.mockClear();

    // ws1 is registered first; assuming connections are visited in registration order,
    // its failure happens before ws2 is reached.
    ws1.send.mockImplementation(() => {
      throw new Error('Connection closed');
    });

    expect(() => {
      broadcastToSession(sessionId, {
        type: 'chunk',
        sessionId,
        payload: {},
      });
    }).not.toThrow();

    expect(ws1.send).toHaveBeenCalledTimes(1);
    expect(ws2.send).toHaveBeenCalledWith(expect.stringContaining('"type":"chunk"'));
  });
});
|
|
|
|
describe('getSessionConnections - 获取会话连接', () => {
  it('返回会话的所有连接', () => {
    mockExists.mockReturnValue(true);
    const sessionId = getUniqueSessionId('get-conns');

    handleWebSocket(createMockWSContext() as any, sessionId);

    // One socket registered on an otherwise unused session
    expect(getSessionConnections(sessionId).size).toBe(1);
  });

  it('不存在的会话返回空 Set', () => {
    const unknownSession = 'non-existent-unique-12345';

    expect(getSessionConnections(unknownSession).size).toBe(0);
  });
});
|
|
|
|
describe('getConnectionStats - 连接统计', () => {
  it('返回正确的统计信息', () => {
    // Earlier tests leave connections registered, so measure deltas against a baseline
    const before = getConnectionStats();

    const firstSession = getUniqueSessionId('stats-1');
    const secondSession = getUniqueSessionId('stats-2');
    mockExists.mockReturnValue(true);

    // Two sockets on the first session, one on the second
    for (const target of [firstSession, firstSession, secondSession]) {
      handleWebSocket(createMockWSContext() as any, target);
    }

    const after = getConnectionStats();
    expect(after.sessions - before.sessions).toBe(2);
    expect(after.connections - before.connections).toBe(3);
  });
});
|
|
});
|