From 630ce9fd4ba300b985d72b14a630b69aca75c7f6 Mon Sep 17 00:00:00 2001 From: kurihada Date: Thu, 11 Dec 2025 23:12:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=20Hook=20=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 参考 open-code 的实现,添加工具执行前后的 hook 功能: - 添加 Hook 类型定义 (tool.execute.before/after, file.edited/created/deleted 等) - 实现 HookManager 管理器,支持插件注册和事件触发 - 实现配置文件加载器,支持 .ai-assistant.json/jsonc 格式 - 支持 glob 模式匹配文件触发 shell 命令 - 集成到 Agent 工具执行流程 - 添加 minimatch 依赖用于 glob 匹配 - 编写完整测试用例 (27 个测试) 配置示例: ```json { "hooks": { "file_edited": { "*.ts": [{ "command": ["npx", "tsc", "--noEmit"] }] } } } ``` --- package-lock.json | 37 +++ package.json | 1 + src/core/agent.ts | 70 ++++- src/hooks/config-loader.ts | 232 +++++++++++++++ src/hooks/index.ts | 48 ++++ src/hooks/manager.ts | 495 ++++++++++++++++++++++++++++++++ src/hooks/types.ts | 215 ++++++++++++++ tests/hooks/hooks.test.ts | 570 +++++++++++++++++++++++++++++++++++++ 8 files changed, 1667 insertions(+), 1 deletion(-) create mode 100644 src/hooks/config-loader.ts create mode 100644 src/hooks/index.ts create mode 100644 src/hooks/manager.ts create mode 100644 src/hooks/types.ts create mode 100644 tests/hooks/hooks.test.ts diff --git a/package-lock.json b/package-lock.json index 139a009..470687a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -18,6 +18,7 @@ "commander": "^12.1.0", "inquirer": "^12.0.0", "js-yaml": "^4.1.1", + "minimatch": "^10.1.1", "ora": "^8.1.0", "qwen-ai-provider-v5": "^1.0.2", "tree-sitter-bash": "^0.25.1", @@ -972,6 +973,27 @@ } } }, + "node_modules/@isaacs/balanced-match": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/@isaacs/balanced-match/-/balanced-match-4.0.1.tgz", + "integrity": "sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==", + "license": "MIT", + "engines": { + "node": "20 || >=22" + } + }, + "node_modules/@isaacs/brace-expansion": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/@isaacs/brace-expansion/-/brace-expansion-5.0.0.tgz", + "integrity": "sha512-ZT55BDLV0yv0RBm2czMiZ+SqCGO7AvmOM3G/w2xhVPH+te0aKgFjmBvGlL1dH+ql2tgGO3MVrbb3jCKyvpgnxA==", + "license": "MIT", + "dependencies": { + "@isaacs/balanced-match": "^4.0.1" + }, + "engines": { + "node": "20 || >=22" + } + }, "node_modules/@jridgewell/resolve-uri": { "version": "3.1.2", "resolved": "https://registry.npmjs.org/@jridgewell/resolve-uri/-/resolve-uri-3.1.2.tgz", @@ -2438,6 +2460,21 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/minimatch": { + "version": "10.1.1", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-10.1.1.tgz", + "integrity": "sha512-enIvLvRAFZYXJzkCYG5RKmPfrFArdLv+R+lbQ53BmIMLIry74bjKzX6iHAm8WYamJkhSSEabrWN5D97XnKObjQ==", + "license": "BlueOak-1.0.0", + "dependencies": { + "@isaacs/brace-expansion": "^5.0.0" + }, + "engines": { + "node": "20 || >=22" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/ms": { "version": "2.1.3", "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", diff --git a/package.json b/package.json index 746d179..993ab05 100644 --- a/package.json +++ b/package.json @@ -35,6 +35,7 @@ "commander": "^12.1.0", "inquirer": "^12.0.0", "js-yaml": "^4.1.1", + "minimatch": "^10.1.1", "ora": "^8.1.0", "qwen-ai-provider-v5": "^1.0.2", "tree-sitter-bash": "^0.25.1", diff --git a/src/core/agent.ts b/src/core/agent.ts index 2931526..735871c 100644 --- a/src/core/agent.ts +++ b/src/core/agent.ts @@ -19,6 +19,7 @@ import type { AgentInfo, ImageData } from '../agent/types.js'; import { agentRegistry, AgentExecutor } from '../agent/index.js'; import { loadVisionConfig } from '../utils/config.js'; import { getModelFactory } from './providers.js'; +import { getHookManager } from '../hooks/index.js'; export class Agent { private getModel: (model: string) => LanguageModel; @@ -142,6 +143,7 @@ export class Agent { private getVercelTools(): Record { const vercelTools: Record = {}; const availableTools = this.getAvailableTools(); + const hookManager = getHookManager(); for (const tool of availableTools) { const schema = buildZodSchema(tool.parameters); @@ -150,7 +152,73 @@ export class Agent { description: tool.description, inputSchema: schema, execute: async (params) => { - const result = await tool.execute(params as Record); + const args = params as Record; + const callId = `${tool.name}-${Date.now()}`; + const sessionId = this.sessionManager?.getSession()?.id || 'default'; + + // 触发工具执行前 hook + let finalArgs = args; + if (hookManager) { + const beforeOutput = await hookManager.triggerToolExecuteBefore({ + tool: tool.name, + sessionId, + callId, + args, + }); + + // 如果 hook 指定跳过,直接返回 + if (beforeOutput.skip && beforeOutput.skipResult) { + return beforeOutput.skipResult; + } + + finalArgs = beforeOutput.args; + } + + // 执行工具 + const startTime = Date.now(); + let result = await tool.execute(finalArgs); + const duration = Date.now() - startTime; + + // 触发工具执行后 hook + if (hookManager) { + const afterOutput = await hookManager.triggerToolExecuteAfter( + { + tool: tool.name, + sessionId, + callId, + args: finalArgs, + duration, + }, + result + ); + result = afterOutput.result; + + // 对于文件操作工具,触发相应的文件 hook + if (result.success) { + const filePath = finalArgs.path as string | undefined; + if (filePath) { + if (tool.name === 'write_file') { + await hookManager.triggerFileCreated({ + path: filePath, + tool: tool.name, + sessionId, + }); + } else if (tool.name === 'edit_file') { + await hookManager.triggerFileEdited({ + path: filePath, + tool: tool.name, + sessionId, + }); + } else if (tool.name === 'delete_file') { + await hookManager.triggerFileDeleted({ + path: filePath, + tool: tool.name, + sessionId, + }); + } + } + } + } // 如果是 tool_search 调用,解析结果并注入发现的工具 if (tool.name === 'tool_search' && result.success) { diff --git a/src/hooks/config-loader.ts b/src/hooks/config-loader.ts new file mode 100644 index 0000000..49758a7 --- /dev/null +++ b/src/hooks/config-loader.ts @@ -0,0 +1,232 @@ +/** + * Hook 配置加载器 + * + * 从项目配置文件加载 hook 配置 + * 支持 .ai-assistant.json, .ai-assistant.jsonc, ai-assistant.config.json 等格式 + */ + +import * as fs from 'fs/promises'; +import * as path from 'path'; +import type { HookConfig, ShellCommandConfig, FileHookConfig } from './types.js'; + +// 支持的配置文件名 +const CONFIG_FILE_NAMES = [ + '.ai-assistant.json', + '.ai-assistant.jsonc', + 'ai-assistant.config.json', + '.ai-assistantrc', + '.ai-assistantrc.json', +]; + +/** + * 完整的配置文件结构 + */ +export interface ProjectConfig { + /** Hook 配置 */ + hooks?: HookConfig; + /** 插件列表 */ + plugins?: string[]; + /** 其他配置... */ + [key: string]: unknown; +} + +/** + * 移除 JSON 中的注释(支持 JSONC 格式) + */ +function stripJsonComments(jsonString: string): string { + // 移除单行注释 // ... + let result = jsonString.replace(/\/\/.*$/gm, ''); + // 移除多行注释 /* ... */ + result = result.replace(/\/\*[\s\S]*?\*\//g, ''); + return result; +} + +/** + * 解析 JSON 文件(支持 JSONC) + */ +async function parseJsonFile(filePath: string): Promise { + try { + const content = await fs.readFile(filePath, 'utf-8'); + const cleanContent = stripJsonComments(content); + return JSON.parse(cleanContent); + } catch { + return null; + } +} + +/** + * 在目录中查找配置文件 + */ +async function findConfigFile(directory: string): Promise { + for (const fileName of CONFIG_FILE_NAMES) { + const filePath = path.join(directory, fileName); + try { + await fs.access(filePath); + return filePath; + } catch { + // 文件不存在,继续查找 + } + } + return null; +} + +/** + * 验证 ShellCommandConfig + */ +function validateShellCommandConfig(config: unknown): config is ShellCommandConfig { + if (typeof config !== 'object' || config === null) return false; + const obj = config as Record; + + // command 必须是非空字符串数组 + if (!Array.isArray(obj.command) || obj.command.length === 0) return false; + if (!obj.command.every((c) => typeof c === 'string')) return false; + + // environment 如果存在,必须是对象 + if (obj.environment !== undefined) { + if (typeof obj.environment !== 'object' || obj.environment === null) return false; + const env = obj.environment as Record; + if (!Object.values(env).every((v) => typeof v === 'string')) return false; + } + + // timeout 如果存在,必须是正数 + if (obj.timeout !== undefined) { + if (typeof obj.timeout !== 'number' || obj.timeout <= 0) return false; + } + + // cwd 如果存在,必须是字符串 + if (obj.cwd !== undefined) { + if (typeof obj.cwd !== 'string') return false; + } + + return true; +} + +/** + * 验证 FileHookConfig + */ +function validateFileHookConfig(config: unknown): config is FileHookConfig { + if (typeof config !== 'object' || config === null) return false; + const obj = config as Record; + + for (const [pattern, commands] of Object.entries(obj)) { + // pattern 必须是非空字符串 + if (typeof pattern !== 'string' || pattern.length === 0) return false; + + // commands 必须是 ShellCommandConfig 数组 + if (!Array.isArray(commands)) return false; + if (!commands.every(validateShellCommandConfig)) return false; + } + + return true; +} + +/** + * 验证 HookConfig + */ +function validateHookConfig(config: unknown): config is HookConfig { + if (typeof config !== 'object' || config === null) return false; + const obj = config as Record; + + // file_edited + if (obj.file_edited !== undefined) { + if (!validateFileHookConfig(obj.file_edited)) return false; + } + + // file_created + if (obj.file_created !== undefined) { + if (!validateFileHookConfig(obj.file_created)) return false; + } + + // file_deleted + if (obj.file_deleted !== undefined) { + if (!validateFileHookConfig(obj.file_deleted)) return false; + } + + // session_completed + if (obj.session_completed !== undefined) { + if (!Array.isArray(obj.session_completed)) return false; + if (!obj.session_completed.every(validateShellCommandConfig)) return false; + } + + return true; +} + +/** + * 加载项目配置 + */ +export async function loadProjectConfig(directory: string): Promise { + const configPath = await findConfigFile(directory); + if (!configPath) return null; + + const config = await parseJsonFile(configPath); + return config; +} + +/** + * 加载 Hook 配置 + */ +export async function loadHookConfig(directory: string): Promise { + const projectConfig = await loadProjectConfig(directory); + if (!projectConfig?.hooks) return null; + + // 验证配置 + if (!validateHookConfig(projectConfig.hooks)) { + console.warn('Invalid hook configuration in project config file'); + return null; + } + + return projectConfig.hooks; +} + +/** + * 加载插件列表 + */ +export async function loadPluginList(directory: string): Promise { + const projectConfig = await loadProjectConfig(directory); + if (!projectConfig?.plugins) return []; + + // 验证插件列表 + if (!Array.isArray(projectConfig.plugins)) return []; + if (!projectConfig.plugins.every((p) => typeof p === 'string')) return []; + + return projectConfig.plugins; +} + +/** + * 创建默认配置文件 + */ +export async function createDefaultConfig(directory: string): Promise { + const configPath = path.join(directory, '.ai-assistant.json'); + + const defaultConfig: ProjectConfig = { + hooks: { + file_edited: { + '*.ts': [ + { + command: ['npx', 'tsc', '--noEmit'], + timeout: 30000, + }, + ], + '*.{js,jsx,ts,tsx}': [ + { + command: ['npx', 'eslint', '--fix'], + timeout: 30000, + }, + ], + }, + file_created: {}, + file_deleted: {}, + session_completed: [], + }, + plugins: [], + }; + + await fs.writeFile(configPath, JSON.stringify(defaultConfig, null, 2)); +} + +/** + * 获取配置文件路径(如果存在) + */ +export async function getConfigFilePath(directory: string): Promise { + return findConfigFile(directory); +} diff --git a/src/hooks/index.ts b/src/hooks/index.ts new file mode 100644 index 0000000..d38ccd3 --- /dev/null +++ b/src/hooks/index.ts @@ -0,0 +1,48 @@ +/** + * Hook 系统模块 + * + * 提供工具执行前后的 hook 功能,支持自定义命令执行 + * 参考 open-code 的实现 + */ + +// Hook 管理器 +export { + HookManager, + getHookManager, + initHookManager, + resetHookManager, +} from './manager.js'; + +// 配置加载 +export { + loadProjectConfig, + loadHookConfig, + loadPluginList, + createDefaultConfig, + getConfigFilePath, + type ProjectConfig, +} from './config-loader.js'; + +// 类型导出 +export type { + HookType, + HookConfig, + HookEvent, + HookEventListener, + ShellCommandConfig, + FileHookConfig, + Hooks, + Plugin, + PluginInput, + ToolExecuteBeforeInput, + ToolExecuteBeforeOutput, + ToolExecuteAfterInput, + ToolExecuteAfterOutput, + SessionStartInput, + SessionEndInput, + MessageBeforeInput, + MessageBeforeOutput, + MessageAfterInput, + FileChangeInput, + FileChangeOutput, +} from './types.js'; diff --git a/src/hooks/manager.ts b/src/hooks/manager.ts new file mode 100644 index 0000000..5dc9785 --- /dev/null +++ b/src/hooks/manager.ts @@ -0,0 +1,495 @@ +/** + * Hook 管理器 + * + * 负责 hook 的注册、触发和管理 + */ + +import { spawn } from 'child_process'; +import { minimatch } from 'minimatch'; +import type { + Hooks, + HookType, + HookConfig, + HookEvent, + HookEventListener, + ShellCommandConfig, + FileHookConfig, + ToolExecuteBeforeInput, + ToolExecuteBeforeOutput, + ToolExecuteAfterInput, + ToolExecuteAfterOutput, + SessionStartInput, + SessionEndInput, + MessageBeforeInput, + MessageBeforeOutput, + MessageAfterInput, + FileChangeInput, + FileChangeOutput, + Plugin, + PluginInput, +} from './types.js'; + +/** + * Hook 管理器 + */ +export class HookManager { + /** 已注册的 hooks */ + private hooks: Hooks[] = []; + + /** 配置型 hooks(从配置文件加载) */ + private configHooks: HookConfig | null = null; + + /** 事件监听器 */ + private eventListeners: HookEventListener[] = []; + + /** 当前工作目录 */ + private workdir: string; + + /** 会话 ID */ + private sessionId: string; + + constructor(workdir: string, sessionId?: string) { + this.workdir = workdir; + this.sessionId = sessionId || 'default'; + } + + /** + * 注册插件 + */ + async registerPlugin(plugin: Plugin): Promise { + const input: PluginInput = { + workdir: this.workdir, + sessionId: this.sessionId, + }; + + try { + const hooks = await plugin(input); + this.hooks.push(hooks); + } catch (error) { + console.error('Failed to register plugin:', error); + } + } + + /** + * 注册 hooks 对象 + */ + registerHooks(hooks: Hooks): void { + this.hooks.push(hooks); + } + + /** + * 设置配置型 hooks + */ + setConfigHooks(config: HookConfig): void { + this.configHooks = config; + } + + /** + * 添加事件监听器 + */ + addEventListener(listener: HookEventListener): void { + this.eventListeners.push(listener); + } + + /** + * 移除事件监听器 + */ + removeEventListener(listener: HookEventListener): void { + const index = this.eventListeners.indexOf(listener); + if (index !== -1) { + this.eventListeners.splice(index, 1); + } + } + + /** + * 发送事件 + */ + private emitEvent(type: HookType, data: unknown): void { + const event: HookEvent = { + type, + timestamp: Date.now(), + data, + }; + + for (const listener of this.eventListeners) { + try { + listener(event); + } catch (error) { + console.error('Event listener error:', error); + } + } + } + + /** + * 触发工具执行前 hook + */ + async triggerToolExecuteBefore( + input: ToolExecuteBeforeInput + ): Promise { + const output: ToolExecuteBeforeOutput = { + args: { ...input.args }, + }; + + for (const hook of this.hooks) { + if (hook['tool.execute.before']) { + try { + await hook['tool.execute.before'](input, output); + } catch (error) { + console.error('Hook tool.execute.before error:', error); + } + } + } + + this.emitEvent('tool.execute.before', { input, output }); + return output; + } + + /** + * 触发工具执行后 hook + */ + async triggerToolExecuteAfter( + input: ToolExecuteAfterInput, + result: ToolExecuteAfterOutput['result'] + ): Promise { + const output: ToolExecuteAfterOutput = { + result: { ...result }, + }; + + for (const hook of this.hooks) { + if (hook['tool.execute.after']) { + try { + await hook['tool.execute.after'](input, output); + } catch (error) { + console.error('Hook tool.execute.after error:', error); + } + } + } + + this.emitEvent('tool.execute.after', { input, output }); + return output; + } + + /** + * 触发会话开始 hook + */ + async triggerSessionStart(input: SessionStartInput): Promise { + for (const hook of this.hooks) { + if (hook['session.start']) { + try { + await hook['session.start'](input); + } catch (error) { + console.error('Hook session.start error:', error); + } + } + } + + this.emitEvent('session.start', input); + } + + /** + * 触发会话结束 hook + */ + async triggerSessionEnd(input: SessionEndInput): Promise { + for (const hook of this.hooks) { + if (hook['session.end']) { + try { + await hook['session.end'](input); + } catch (error) { + console.error('Hook session.end error:', error); + } + } + } + + // 执行配置型 session_completed hooks + if (this.configHooks?.session_completed) { + await this.executeShellCommands(this.configHooks.session_completed); + } + + this.emitEvent('session.end', input); + } + + /** + * 触发消息前 hook + */ + async triggerMessageBefore( + input: MessageBeforeInput + ): Promise { + const output: MessageBeforeOutput = { + content: input.content, + }; + + for (const hook of this.hooks) { + if (hook['message.before']) { + try { + await hook['message.before'](input, output); + } catch (error) { + console.error('Hook message.before error:', error); + } + } + } + + this.emitEvent('message.before', { input, output }); + return output; + } + + /** + * 触发消息后 hook + */ + async triggerMessageAfter(input: MessageAfterInput): Promise { + for (const hook of this.hooks) { + if (hook['message.after']) { + try { + await hook['message.after'](input); + } catch (error) { + console.error('Hook message.after error:', error); + } + } + } + + this.emitEvent('message.after', input); + } + + /** + * 触发文件编辑 hook + */ + async triggerFileEdited(input: FileChangeInput): Promise { + const output: FileChangeOutput = {}; + + // 执行插件 hooks + for (const hook of this.hooks) { + if (hook['file.edited']) { + try { + await hook['file.edited'](input, output); + } catch (error) { + console.error('Hook file.edited error:', error); + } + } + } + + // 执行配置型 hooks + if (this.configHooks?.file_edited) { + const results = await this.executeFileHooks( + input.path, + this.configHooks.file_edited + ); + output.commandResults = results; + } + + this.emitEvent('file.edited', { input, output }); + return output; + } + + /** + * 触发文件创建 hook + */ + async triggerFileCreated(input: FileChangeInput): Promise { + const output: FileChangeOutput = {}; + + // 执行插件 hooks + for (const hook of this.hooks) { + if (hook['file.created']) { + try { + await hook['file.created'](input, output); + } catch (error) { + console.error('Hook file.created error:', error); + } + } + } + + // 执行配置型 hooks + if (this.configHooks?.file_created) { + const results = await this.executeFileHooks( + input.path, + this.configHooks.file_created + ); + output.commandResults = results; + } + + this.emitEvent('file.created', { input, output }); + return output; + } + + /** + * 触发文件删除 hook + */ + async triggerFileDeleted(input: FileChangeInput): Promise { + const output: FileChangeOutput = {}; + + // 执行插件 hooks + for (const hook of this.hooks) { + if (hook['file.deleted']) { + try { + await hook['file.deleted'](input, output); + } catch (error) { + console.error('Hook file.deleted error:', error); + } + } + } + + // 执行配置型 hooks + if (this.configHooks?.file_deleted) { + const results = await this.executeFileHooks( + input.path, + this.configHooks.file_deleted + ); + output.commandResults = results; + } + + this.emitEvent('file.deleted', { input, output }); + return output; + } + + /** + * 执行文件 hooks + * 根据文件路径匹配 glob 模式并执行对应命令 + */ + private async executeFileHooks( + filePath: string, + config: FileHookConfig + ): Promise { + const results: FileChangeOutput['commandResults'] = []; + + for (const [pattern, commands] of Object.entries(config)) { + // 使用 minimatch 进行 glob 匹配 + if (minimatch(filePath, pattern, { matchBase: true })) { + const commandResults = await this.executeShellCommands(commands, { + FILE_PATH: filePath, + }); + results.push(...commandResults); + } + } + + return results; + } + + /** + * 执行 shell 命令列表 + */ + private async executeShellCommands( + commands: ShellCommandConfig[], + extraEnv?: Record + ): Promise> { + const results: Array<{ + command: string[]; + success: boolean; + output?: string; + error?: string; + }> = []; + + for (const cmdConfig of commands) { + const result = await this.executeShellCommand(cmdConfig, extraEnv); + results.push(result); + } + + return results; + } + + /** + * 执行单个 shell 命令 + */ + private executeShellCommand( + config: ShellCommandConfig, + extraEnv?: Record + ): Promise<{ command: string[]; success: boolean; output?: string; error?: string }> { + return new Promise((resolve) => { + const [cmd, ...args] = config.command; + const timeout = config.timeout || 30000; + const cwd = config.cwd || this.workdir; + + const env = { + ...process.env, + ...config.environment, + ...extraEnv, + }; + + let stdout = ''; + let stderr = ''; + + const child = spawn(cmd, args, { + cwd, + env, + shell: true, + }); + + const timer = setTimeout(() => { + child.kill('SIGTERM'); + resolve({ + command: config.command, + success: false, + error: `Command timed out after ${timeout}ms`, + }); + }, timeout); + + child.stdout?.on('data', (data) => { + stdout += data.toString(); + }); + + child.stderr?.on('data', (data) => { + stderr += data.toString(); + }); + + child.on('close', (code) => { + clearTimeout(timer); + resolve({ + command: config.command, + success: code === 0, + output: stdout.trim() || undefined, + error: stderr.trim() || undefined, + }); + }); + + child.on('error', (error) => { + clearTimeout(timer); + resolve({ + command: config.command, + success: false, + error: error.message, + }); + }); + }); + } + + /** + * 获取所有已注册的 hooks 数量 + */ + getHookCount(): number { + return this.hooks.length; + } + + /** + * 清空所有 hooks + */ + clear(): void { + this.hooks = []; + this.configHooks = null; + this.eventListeners = []; + } +} + +// 全局 Hook 管理器实例 +let globalHookManager: HookManager | null = null; + +/** + * 获取全局 Hook 管理器 + */ +export function getHookManager(): HookManager | null { + return globalHookManager; +} + +/** + * 初始化全局 Hook 管理器 + */ +export function initHookManager(workdir: string, sessionId?: string): HookManager { + globalHookManager = new HookManager(workdir, sessionId); + return globalHookManager; +} + +/** + * 重置全局 Hook 管理器 + */ +export function resetHookManager(): void { + if (globalHookManager) { + globalHookManager.clear(); + globalHookManager = null; + } +} diff --git a/src/hooks/types.ts b/src/hooks/types.ts new file mode 100644 index 0000000..b081812 --- /dev/null +++ b/src/hooks/types.ts @@ -0,0 +1,215 @@ +/** + * Hook 系统类型定义 + * + * 参考 open-code 的 hook 实现 + */ + +import type { Tool, ToolResult } from '../types/index.js'; + +/** + * Hook 类型枚举 + */ +export type HookType = + | 'tool.execute.before' // 工具执行前 + | 'tool.execute.after' // 工具执行后 + | 'session.start' // 会话开始 + | 'session.end' // 会话结束 + | 'message.before' // 消息发送前 + | 'message.after' // 消息接收后 + | 'file.edited' // 文件被编辑后 + | 'file.created' // 文件被创建后 + | 'file.deleted'; // 文件被删除后 + +/** + * 工具执行前 Hook 的输入 + */ +export interface ToolExecuteBeforeInput { + tool: string; + sessionId: string; + callId: string; + args: Record; +} + +/** + * 工具执行前 Hook 的输出(可修改) + */ +export interface ToolExecuteBeforeOutput { + args: Record; + /** 设为 true 可阻止工具执行 */ + skip?: boolean; + /** 跳过时返回的结果 */ + skipResult?: ToolResult; +} + +/** + * 工具执行后 Hook 的输入 + */ +export interface ToolExecuteAfterInput { + tool: string; + sessionId: string; + callId: string; + args: Record; + duration: number; // 执行时长(毫秒) +} + +/** + * 工具执行后 Hook 的输出(可修改) + */ +export interface ToolExecuteAfterOutput { + result: ToolResult; +} + +/** + * 会话开始 Hook 的输入 + */ +export interface SessionStartInput { + sessionId: string; + workdir: string; +} + +/** + * 会话结束 Hook 的输入 + */ +export interface SessionEndInput { + sessionId: string; + messageCount: number; + duration: number; // 会话时长(毫秒) +} + +/** + * 消息前 Hook 的输入 + */ +export interface MessageBeforeInput { + sessionId: string; + content: string; +} + +/** + * 消息前 Hook 的输出(可修改) + */ +export interface MessageBeforeOutput { + content: string; + /** 设为 true 可阻止消息发送 */ + skip?: boolean; +} + +/** + * 消息后 Hook 的输入 + */ +export interface MessageAfterInput { + sessionId: string; + content: string; + toolCalls: number; +} + +/** + * 文件变更 Hook 的输入 + */ +export interface FileChangeInput { + path: string; + tool: string; + sessionId: string; +} + +/** + * 文件变更 Hook 的输出 + */ +export interface FileChangeOutput { + /** 执行的命令结果 */ + commandResults?: Array<{ + command: string[]; + success: boolean; + output?: string; + error?: string; + }>; +} + +/** + * Shell 命令配置 + */ +export interface ShellCommandConfig { + /** 命令数组,第一个元素是命令,后面是参数 */ + command: string[]; + /** 环境变量 */ + environment?: Record; + /** 超时时间(毫秒),默认 30000 */ + timeout?: number; + /** 工作目录,默认使用当前目录 */ + cwd?: string; +} + +/** + * 文件 Hook 配置 + * 支持 glob 模式匹配文件 + */ +export interface FileHookConfig { + /** glob 模式 -> 命令配置列表 */ + [pattern: string]: ShellCommandConfig[]; +} + +/** + * Hook 配置 + */ +export interface HookConfig { + /** 文件编辑后执行的 hook */ + file_edited?: FileHookConfig; + /** 文件创建后执行的 hook */ + file_created?: FileHookConfig; + /** 文件删除后执行的 hook */ + file_deleted?: FileHookConfig; + /** 会话完成后执行的命令 */ + session_completed?: ShellCommandConfig[]; +} + +/** + * Hook 函数类型 + */ +export type HookFunction = ( + input: Input, + output: Output +) => Promise; + +/** + * Hook 定义接口 + */ +export interface Hooks { + 'tool.execute.before'?: HookFunction; + 'tool.execute.after'?: HookFunction; + 'session.start'?: (input: SessionStartInput) => Promise; + 'session.end'?: (input: SessionEndInput) => Promise; + 'message.before'?: HookFunction; + 'message.after'?: (input: MessageAfterInput) => Promise; + 'file.edited'?: HookFunction; + 'file.created'?: HookFunction; + 'file.deleted'?: HookFunction; +} + +/** + * 插件输入 + */ +export interface PluginInput { + /** 当前工作目录 */ + workdir: string; + /** 会话 ID */ + sessionId?: string; +} + +/** + * 插件定义 + * 一个插件是一个函数,接收 PluginInput 返回 Hooks + */ +export type Plugin = (input: PluginInput) => Promise; + +/** + * Hook 事件 + */ +export interface HookEvent { + type: HookType; + timestamp: number; + data: unknown; +} + +/** + * Hook 事件监听器 + */ +export type HookEventListener = (event: HookEvent) => void; diff --git a/tests/hooks/hooks.test.ts b/tests/hooks/hooks.test.ts new file mode 100644 index 0000000..fd9e47d --- /dev/null +++ b/tests/hooks/hooks.test.ts @@ -0,0 +1,570 @@ +/** + * Hook 系统测试 + */ + +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import * as path from 'path'; +import * as fs from 'fs/promises'; +import * as os from 'os'; +import { + HookManager, + initHookManager, + getHookManager, + resetHookManager, + loadHookConfig, + loadProjectConfig, + type Hooks, + type HookConfig, +} from '../../src/hooks/index.js'; + +describe('HookManager', () => { + let tempDir: string; + let manager: HookManager; + + beforeEach(async () => { + tempDir = path.join(os.tmpdir(), `hooks-test-${Date.now()}`); + await fs.mkdir(tempDir, { recursive: true }); + manager = new HookManager(tempDir, 'test-session'); + }); + + afterEach(async () => { + resetHookManager(); + try { + await fs.rm(tempDir, { recursive: true, force: true }); + } catch { + // ignore cleanup errors + } + }); + + describe('Plugin Registration', () => { + it('should register hooks from plugin', async () => { + const hooks: Hooks = { + 'tool.execute.before': async (input, output) => { + output.args = { ...output.args, injected: true }; + }, + }; + + manager.registerHooks(hooks); + expect(manager.getHookCount()).toBe(1); + }); + + it('should register multiple plugins', async () => { + const hooks1: Hooks = { + 'tool.execute.before': async () => {}, + }; + const hooks2: Hooks = { + 'tool.execute.after': async () => {}, + }; + + manager.registerHooks(hooks1); + manager.registerHooks(hooks2); + expect(manager.getHookCount()).toBe(2); + }); + }); + + describe('Tool Execute Before Hook', () => { + it('should trigger tool.execute.before hook', async () => { + let triggered = false; + const hooks: Hooks = { + 'tool.execute.before': async (input, output) => { + triggered = true; + expect(input.tool).toBe('test_tool'); + expect(input.sessionId).toBe('test-session'); + expect(input.args).toEqual({ foo: 'bar' }); + }, + }; + + manager.registerHooks(hooks); + + await manager.triggerToolExecuteBefore({ + tool: 'test_tool', + sessionId: 'test-session', + callId: 'call-1', + args: { foo: 'bar' }, + }); + + expect(triggered).toBe(true); + }); + + it('should allow hook to modify args', async () => { + const hooks: Hooks = { + 'tool.execute.before': async (input, output) => { + output.args = { ...output.args, modified: true }; + }, + }; + + manager.registerHooks(hooks); + + const result = await manager.triggerToolExecuteBefore({ + tool: 'test_tool', + sessionId: 'test-session', + callId: 'call-1', + args: { original: true }, + }); + + expect(result.args).toEqual({ original: true, modified: true }); + }); + + it('should allow hook to skip execution', async () => { + const hooks: Hooks = { + 'tool.execute.before': async (input, output) => { + output.skip = true; + output.skipResult = { + success: false, + output: '', + error: 'Blocked by hook', + }; + }, + }; + + manager.registerHooks(hooks); + + const result = await manager.triggerToolExecuteBefore({ + tool: 'test_tool', + sessionId: 'test-session', + callId: 'call-1', + args: {}, + }); + + expect(result.skip).toBe(true); + expect(result.skipResult?.error).toBe('Blocked by hook'); + }); + }); + + describe('Tool Execute After Hook', () => { + it('should trigger tool.execute.after hook', async () => { + let triggered = false; + const hooks: Hooks = { + 'tool.execute.after': async (input, output) => { + triggered = true; + expect(input.tool).toBe('test_tool'); + expect(input.duration).toBeGreaterThanOrEqual(0); + }, + }; + + manager.registerHooks(hooks); + + await manager.triggerToolExecuteAfter( + { + tool: 'test_tool', + sessionId: 'test-session', + callId: 'call-1', + args: {}, + duration: 100, + }, + { success: true, output: 'test output' } + ); + + expect(triggered).toBe(true); + }); + + it('should allow hook to modify result', async () => { + const hooks: Hooks = { + 'tool.execute.after': async (input, output) => { + output.result = { + ...output.result, + output: output.result.output + ' (modified)', + }; + }, + }; + + manager.registerHooks(hooks); + + const result = await manager.triggerToolExecuteAfter( + { + tool: 'test_tool', + sessionId: 'test-session', + callId: 'call-1', + args: {}, + duration: 100, + }, + { success: true, output: 'original' } + ); + + expect(result.result.output).toBe('original (modified)'); + }); + }); + + describe('Session Hooks', () => { + it('should trigger session.start hook', async () => { + let triggered = false; + const hooks: Hooks = { + 'session.start': async (input) => { + triggered = true; + expect(input.sessionId).toBe('session-123'); + expect(input.workdir).toBe('/test/dir'); + }, + }; + + manager.registerHooks(hooks); + + await manager.triggerSessionStart({ + sessionId: 'session-123', + workdir: '/test/dir', + }); + + expect(triggered).toBe(true); + }); + + it('should trigger session.end hook', async () => { + let triggered = false; + const hooks: Hooks = { + 'session.end': async (input) => { + triggered = true; + expect(input.messageCount).toBe(10); + expect(input.duration).toBe(5000); + }, + }; + + manager.registerHooks(hooks); + + await manager.triggerSessionEnd({ + sessionId: 'session-123', + messageCount: 10, + duration: 5000, + }); + + expect(triggered).toBe(true); + }); + }); + + describe('Message Hooks', () => { + it('should trigger message.before hook', async () => { + const hooks: Hooks = { + 'message.before': async (input, output) => { + output.content = input.content.toUpperCase(); + }, + }; + + manager.registerHooks(hooks); + + const result = await manager.triggerMessageBefore({ + sessionId: 'test-session', + content: 'hello world', + }); + + expect(result.content).toBe('HELLO WORLD'); + }); + + it('should allow message.before to skip', async () => { + const hooks: Hooks = { + 'message.before': async (input, output) => { + if (input.content.includes('forbidden')) { + output.skip = true; + } + }, + }; + + manager.registerHooks(hooks); + + const result = await manager.triggerMessageBefore({ + sessionId: 'test-session', + content: 'this is forbidden', + }); + + expect(result.skip).toBe(true); + }); + + it('should trigger message.after hook', async () => { + let triggered = false; + const hooks: Hooks = { + 'message.after': async (input) => { + triggered = true; + expect(input.toolCalls).toBe(3); + }, + }; + + manager.registerHooks(hooks); + + await manager.triggerMessageAfter({ + sessionId: 'test-session', + content: 'response', + toolCalls: 3, + }); + + expect(triggered).toBe(true); + }); + }); + + describe('File Change Hooks', () => { + it('should trigger file.edited hook', async () => { + let triggered = false; + const hooks: Hooks = { + 'file.edited': async (input, output) => { + triggered = true; + expect(input.path).toBe('/test/file.ts'); + expect(input.tool).toBe('edit_file'); + }, + }; + + manager.registerHooks(hooks); + + await manager.triggerFileEdited({ + path: '/test/file.ts', + tool: 'edit_file', + sessionId: 'test-session', + }); + + expect(triggered).toBe(true); + }); + + it('should trigger file.created hook', async () => { + let triggered = false; + const hooks: Hooks = { + 'file.created': async (input, output) => { + triggered = true; + expect(input.path).toBe('/test/new-file.ts'); + }, + }; + + manager.registerHooks(hooks); + + await manager.triggerFileCreated({ + path: '/test/new-file.ts', + tool: 'write_file', + sessionId: 'test-session', + }); + + expect(triggered).toBe(true); + }); + + it('should trigger file.deleted hook', async () => { + let triggered = false; + const hooks: Hooks = { + 'file.deleted': async (input, output) => { + triggered = true; + expect(input.path).toBe('/test/old-file.ts'); + }, + }; + + manager.registerHooks(hooks); + + await manager.triggerFileDeleted({ + path: '/test/old-file.ts', + tool: 'delete_file', + sessionId: 'test-session', + }); + + expect(triggered).toBe(true); + }); + }); + + describe('Config Hooks', () => { + it('should execute file hooks matching glob pattern', async () => { + const config: HookConfig = { + file_edited: { + '*.ts': [ + { + command: ['echo', 'TypeScript file edited'], + timeout: 5000, + }, + ], + }, + }; + + manager.setConfigHooks(config); + + const result = await manager.triggerFileEdited({ + path: 'test.ts', + tool: 'edit_file', + sessionId: 'test-session', + }); + + expect(result.commandResults).toBeDefined(); + expect(result.commandResults?.length).toBe(1); + expect(result.commandResults?.[0].success).toBe(true); + expect(result.commandResults?.[0].output).toContain('TypeScript file edited'); + }); + + it('should not execute hooks for non-matching patterns', async () => { + const config: HookConfig = { + file_edited: { + '*.ts': [ + { + command: ['echo', 'TypeScript'], + }, + ], + }, + }; + + manager.setConfigHooks(config); + + const result = await manager.triggerFileEdited({ + path: 'test.js', + tool: 'edit_file', + sessionId: 'test-session', + }); + + expect(result.commandResults).toBeDefined(); + expect(result.commandResults?.length).toBe(0); + }); + }); + + describe('Event Listeners', () => { + it('should emit events to listeners', async () => { + const events: any[] = []; + manager.addEventListener((event) => { + events.push(event); + }); + + await manager.triggerToolExecuteBefore({ + tool: 'test_tool', + sessionId: 'test-session', + callId: 'call-1', + args: {}, + }); + + expect(events.length).toBe(1); + expect(events[0].type).toBe('tool.execute.before'); + expect(events[0].timestamp).toBeGreaterThan(0); + }); + + it('should remove event listener', async () => { + const events: any[] = []; + const listener = (event: any) => events.push(event); + + manager.addEventListener(listener); + manager.removeEventListener(listener); + + await manager.triggerToolExecuteBefore({ + tool: 'test_tool', + sessionId: 'test-session', + callId: 'call-1', + args: {}, + }); + + expect(events.length).toBe(0); + }); + }); + + describe('Error Handling', () => { + it('should continue execution when hook throws error', async () => { + const hooks: Hooks = { + 'tool.execute.before': async () => { + throw new Error('Hook error'); + }, + }; + + manager.registerHooks(hooks); + + // Should not throw + const result = await manager.triggerToolExecuteBefore({ + tool: 'test_tool', + sessionId: 'test-session', + callId: 'call-1', + args: { foo: 'bar' }, + }); + + // Args should remain unchanged + expect(result.args).toEqual({ foo: 'bar' }); + }); + }); +}); + +describe('Global Hook Manager', () => { + afterEach(() => { + resetHookManager(); + }); + + it('should initialize global hook manager', () => { + const manager = initHookManager('/test/dir', 'session-1'); + expect(manager).toBeInstanceOf(HookManager); + expect(getHookManager()).toBe(manager); + }); + + it('should return null before initialization', () => { + expect(getHookManager()).toBeNull(); + }); + + it('should reset global hook manager', () => { + initHookManager('/test/dir'); + resetHookManager(); + expect(getHookManager()).toBeNull(); + }); +}); + +describe('Config Loader', () => { + let tempDir: string; + + beforeEach(async () => { + tempDir = path.join(os.tmpdir(), `config-test-${Date.now()}`); + await fs.mkdir(tempDir, { recursive: true }); + }); + + afterEach(async () => { + try { + await fs.rm(tempDir, { recursive: true, force: true }); + } catch { + // ignore cleanup errors + } + }); + + it('should load project config from .ai-assistant.json', async () => { + const config = { + hooks: { + file_edited: { + '*.ts': [{ command: ['echo', 'test'] }], + }, + }, + plugins: ['plugin-a', 'plugin-b'], + }; + + await fs.writeFile( + path.join(tempDir, '.ai-assistant.json'), + JSON.stringify(config) + ); + + const loaded = await loadProjectConfig(tempDir); + expect(loaded).not.toBeNull(); + expect(loaded?.hooks?.file_edited).toBeDefined(); + expect(loaded?.plugins).toEqual(['plugin-a', 'plugin-b']); + }); + + it('should load hook config', async () => { + const config = { + hooks: { + file_edited: { + '*.ts': [{ command: ['npm', 'run', 'lint'] }], + }, + session_completed: [{ command: ['echo', 'done'] }], + }, + }; + + await fs.writeFile( + path.join(tempDir, '.ai-assistant.json'), + JSON.stringify(config) + ); + + const hookConfig = await loadHookConfig(tempDir); + expect(hookConfig).not.toBeNull(); + expect(hookConfig?.file_edited?.['*.ts']).toHaveLength(1); + expect(hookConfig?.session_completed).toHaveLength(1); + }); + + it('should return null for missing config', async () => { + const config = await loadProjectConfig(tempDir); + expect(config).toBeNull(); + }); + + it('should support JSONC format with comments', async () => { + const configContent = `{ + // This is a comment + "hooks": { + /* Multi-line + comment */ + "file_edited": { + "*.ts": [{ "command": ["echo", "test"] }] + } + } + }`; + + await fs.writeFile( + path.join(tempDir, '.ai-assistant.jsonc'), + configContent + ); + + const loaded = await loadProjectConfig(tempDir); + expect(loaded).not.toBeNull(); + expect(loaded?.hooks?.file_edited?.['*.ts']).toBeDefined(); + }); +});