import { describe, it, expect, vi, beforeEach } from 'vitest';

// ---------------------------------------------------------------------------
// Mocks -- must be declared before importing the module under test.
// ---------------------------------------------------------------------------

// The middleware module reads `config.port` at module scope to build the
// allowedHosts set, so we need the mock in place before the import.
vi.mock('../src/config/index.js', () => ({
  config: {
    port: 3000,
  },
}));

// Silent logger stub; `child()` hands back another silent logger.
vi.mock('../src/utils/logger.js', () => ({
  logger: {
    error: vi.fn(),
    warn: vi.fn(),
    info: vi.fn(),
    debug: vi.fn(),
    child: vi.fn(() => ({
      error: vi.fn(),
      warn: vi.fn(),
      info: vi.fn(),
      debug: vi.fn(),
    })),
  },
}));

// Identity stub: the message is returned unchanged.
vi.mock('../src/utils/errors.js', () => ({
  sanitizeErrorMessage: vi.fn((msg: string) => msg),
}));

import { dnsRebindingGuard } from '../src/server/middleware.js';

// ---------------------------------------------------------------------------
// Helpers -- lightweight Express req/res/next fakes
// ---------------------------------------------------------------------------

/** Minimal request fake carrying only the fields these tests populate. */
interface FakeRequest {
  headers: Record<string, string | undefined>;
  ip?: string;
  method?: string;
  originalUrl?: string;
}

/** Minimal response fake that records the status code and JSON body it is given. */
interface FakeResponse {
  statusCode: number;
  body: unknown;
  status: (code: number) => FakeResponse;
  json: (data: unknown) => FakeResponse;
}

/**
 * Builds a fake GET request for `/test` coming from 127.0.0.1.
 *
 * @param host - Value for the `Host` header; when omitted the request has no
 *   `Host` header at all.
 * @returns The fake request.
 */
function createReq(host?: string): FakeRequest {
  return {
    headers: host !== undefined ? { host } : {},
    ip: '127.0.0.1',
    method: 'GET',
    originalUrl: '/test',
  };
}

/**
 * Builds a fake response whose `status()` and `json()` record their argument
 * and return the same object so calls can be chained.
 *
 * @returns The fake response, starting at status 200 with no body.
 */
function createRes(): FakeResponse {
  const res: FakeResponse = {
    statusCode: 200,
    body: undefined,
    status(code: number) {
      res.statusCode = code;
      return res;
    },
    json(data: unknown) {
      res.body = data;
      return res;
    },
  };
  return res;
}

// ---------------------------------------------------------------------------
// dnsRebindingGuard
// ---------------------------------------------------------------------------

describe('dnsRebindingGuard', () => {
  let next: ReturnType<typeof vi.fn>;

  beforeEach(() => {
    // Fresh spy per test so call counts never leak between cases.
    next = vi.fn();
  });

  /**
   * Invokes the guard with the fakes. The fakes only partially implement the
   * framework's request/response shapes, so the cast is confined to this one
   * place instead of being repeated in every test.
   */
  function runGuard(req: FakeRequest, res: FakeResponse): void {
    dnsRebindingGuard(req as any, res as any, next);
  }

  it('allows requests with Host: 127.0.0.1', () => {
    const res = createRes();

    runGuard(createReq('127.0.0.1'), res);

    expect(next).toHaveBeenCalledTimes(1);
    expect(res.statusCode).toBe(200);
  });

  it('allows requests with Host: localhost', () => {
    const res = createRes();

    runGuard(createReq('localhost'), res);

    expect(next).toHaveBeenCalledTimes(1);
  });

  // The port used below is the same value as the mocked `config.port`.
  it('allows requests with Host: localhost:', () => {
    const res = createRes();

    runGuard(createReq('localhost:3000'), res);

    expect(next).toHaveBeenCalledTimes(1);
  });

  it('allows requests with Host: 127.0.0.1:', () => {
    const res = createRes();

    runGuard(createReq('127.0.0.1:3000'), res);

    expect(next).toHaveBeenCalledTimes(1);
  });

  it('blocks requests with Host: evil.com', () => {
    const res = createRes();

    runGuard(createReq('evil.com'), res);

    expect(next).not.toHaveBeenCalled();
    expect(res.statusCode).toBe(403);
    expect(res.body).toEqual({ error: 'Forbidden' });
  });

  it('blocks requests with no Host header', () => {
    const res = createRes();

    // createReq() without an argument yields an empty headers object.
    runGuard(createReq(), res);

    expect(next).not.toHaveBeenCalled();
    expect(res.statusCode).toBe(403);
    expect(res.body).toEqual({ error: 'Forbidden' });
  });
});