From 8c46635dc7bcb6f7c22fdc446f2354133f477f17 Mon Sep 17 00:00:00 2001 From: kurihada Date: Wed, 17 Dec 2025 00:44:25 +0800 Subject: [PATCH] =?UTF-8?q?feat(core):=20=E5=AE=9E=E7=8E=B0=20ask=5Fuser?= =?UTF-8?q?=5Fquestion=20=E5=B7=A5=E5=85=B7=E7=9A=84=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E8=BE=93=E5=85=A5=E7=AD=89=E5=BE=85=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 创建 UserInputWaiter 管理用户输入等待状态 - 修改 agent-tool-executor 在 requiresUserInput 时等待用户回答 - 添加 onWaitingForInput 回调通知前端显示问题 - Server 端处理 waiting_for_input 广播和 user_input_response 消息 - 前端处理问题显示和用户回答提交 - 修复问题选项在流式输出时被禁用的问题 --- .../core/src/core/agent-message-handler.ts | 3 +- packages/core/src/core/agent-tool-executor.ts | 75 +++++++++-- packages/core/src/core/agent.ts | 9 +- packages/core/src/core/index.ts | 8 ++ packages/core/src/core/user-input-waiter.ts | 117 ++++++++++++++++++ packages/core/src/index.ts | 6 +- packages/server/src/agent/adapter.ts | 29 +++++ packages/server/src/agent/index.ts | 2 + packages/server/src/types.ts | 17 ++- packages/server/src/ws.ts | 29 ++++- packages/ui/src/api/types.ts | 12 ++ packages/ui/src/components/ChatMessage.tsx | 9 +- packages/ui/src/hooks/useChat.ts | 88 ++++++++----- 13 files changed, 351 insertions(+), 53 deletions(-) create mode 100644 packages/core/src/core/user-input-waiter.ts diff --git a/packages/core/src/core/agent-message-handler.ts b/packages/core/src/core/agent-message-handler.ts index b8d04f5..ecd5179 100644 --- a/packages/core/src/core/agent-message-handler.ts +++ b/packages/core/src/core/agent-message-handler.ts @@ -278,7 +278,8 @@ export class AgentMessageHandler { onToolEnd({ id: toolCallId, status: output.success ? 'completed' : 'error', - result: output.success ? output.output : undefined, + // 传递完整的结果对象(包含 output 和 metadata),以支持 ask_user_question 等需要 metadata 的工具 + result: output.success ? { output: output.output, metadata: output.metadata } : undefined, error: output.success ? undefined : output.error, duration, }); diff --git a/packages/core/src/core/agent-tool-executor.ts b/packages/core/src/core/agent-tool-executor.ts index e7ff991..3814384 100644 --- a/packages/core/src/core/agent-tool-executor.ts +++ b/packages/core/src/core/agent-tool-executor.ts @@ -16,6 +16,7 @@ import { } from '../agent/index.js'; import { getHookManager } from '../hooks/index.js'; import { getGitManager } from '../git/index.js'; +import { getUserInputWaiter } from './user-input-waiter.js'; /** * 工具调用开始事件信息 @@ -37,6 +38,16 @@ export interface ToolEndInfo { duration?: number; } +/** + * 等待用户输入事件信息 + */ +export interface WaitingForInputInfo { + id: string; + toolName: string; + questions: unknown[]; + args: Record; +} + /** * 工具执行上下文 */ @@ -45,6 +56,8 @@ export interface ToolExecutionContext { agentMode: AgentInfo | null; onToolStart?: (info: ToolStartInfo) => void; onToolEnd?: (info: ToolEndInfo) => void; + /** 当工具需要用户输入时调用(如 ask_user_question) */ + onWaitingForInput?: (info: WaitingForInputInfo) => void; } /** @@ -175,8 +188,8 @@ export class AgentToolExecutor { context: ToolExecutionContext, hookManager: ReturnType ): Promise { - const callId = `${tool.name}-${Date.now()}`; - const { sessionId, onToolStart, onToolEnd } = context; + const callId = `${tool.name}-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`; + const { sessionId, onToolStart, onToolEnd, onWaitingForInput } = context; // 触发工具执行前 hook let finalArgs = args; @@ -221,7 +234,7 @@ export class AgentToolExecutor { // 执行工具 const startTime = Date.now(); let result = await tool.execute(finalArgs); - const duration = Date.now() - startTime; + let duration = Date.now() - startTime; // 触发工具执行后 hook if (hookManager) { @@ -243,14 +256,54 @@ export class AgentToolExecutor { } } - // 通知工具结束 - onToolEnd?.({ - id: callId, - status: result.success ? 'completed' : 'error', - result: result.success ? result.output : undefined, - error: result.success ? undefined : result.error, - duration, - }); + // 检查是否需要用户输入(如 ask_user_question 工具) + if (result.success && result.metadata?.requiresUserInput) { + // 通知前端等待用户输入 + const questions = (finalArgs.questions as unknown[]) || []; + onWaitingForInput?.({ + id: callId, + toolName: tool.name, + questions, + args: finalArgs, + }); + + // 等待用户输入 + const userInputWaiter = getUserInputWaiter(); + try { + const userAnswer = await userInputWaiter.waitForInput(callId, tool.name); + // 用户回答后,更新结果 + result = { + success: true, + output: `用户回答:\n${userAnswer}`, + metadata: { + ...result.metadata, + userAnswer, + requiresUserInput: false, // 已获得输入 + }, + }; + // 更新持续时间(包含等待用户输入的时间) + duration = Date.now() - startTime; + } catch (error) { + // 用户取消或超时 + result = { + success: false, + output: '', + error: error instanceof Error ? error.message : '等待用户输入失败', + }; + } + } + + // 通知工具结束(注意:这里不再调用 onToolEnd,因为 agent-message-handler 会在 tool-result chunk 中处理) + // 但对于需要用户输入的工具,我们需要在这里调用,因为结果已经更新 + if (result.metadata?.userAnswer !== undefined) { + onToolEnd?.({ + id: callId, + status: result.success ? 'completed' : 'error', + result: result.success ? { output: result.output, metadata: result.metadata } : undefined, + error: result.success ? undefined : result.error, + duration, + }); + } // 如果是 tool_search 调用,解析结果并注入发现的工具 if (tool.name === 'tool_search' && result.success) { diff --git a/packages/core/src/core/agent.ts b/packages/core/src/core/agent.ts index 3c8b1ae..f0cec59 100644 --- a/packages/core/src/core/agent.ts +++ b/packages/core/src/core/agent.ts @@ -19,13 +19,13 @@ import { todoManager } from '../tools/todo/todo-manager.js'; import { initTaskContext } from '../tools/task/index.js'; // 子模块 -import { AgentToolExecutor, type ToolStartInfo, type ToolEndInfo } from './agent-tool-executor.js'; +import { AgentToolExecutor, type ToolStartInfo, type ToolEndInfo, type WaitingForInputInfo } from './agent-tool-executor.js'; import { AgentMessageHandler, type DoomLoopInfo } from './agent-message-handler.js'; import { AgentModeManager } from './agent-mode-manager.js'; import { AgentVisionHandler } from './agent-vision-handler.js'; // 重新导出类型 -export type { ToolStartInfo, ToolEndInfo, DoomLoopInfo }; +export type { ToolStartInfo, ToolEndInfo, DoomLoopInfo, WaitingForInputInfo }; /** * Agent.chat() 选项 @@ -35,6 +35,8 @@ export interface AgentChatOptions { onToolStart?: (info: ToolStartInfo) => void; onToolEnd?: (info: ToolEndInfo) => void; onDoomLoop?: (info: DoomLoopInfo) => void; + /** 当工具需要用户输入时调用(如 ask_user_question) */ + onWaitingForInput?: (info: WaitingForInputInfo) => void; abortSignal?: AbortSignal; } @@ -155,7 +157,7 @@ export class Agent { async chat(userMessage: string | UserInput, options?: AgentChatOptions | ((text: string) => void)): Promise { // 兼容旧的 onStream 参数 const opts: AgentChatOptions = typeof options === 'function' ? { onStream: options } : (options || {}); - const { onStream, onToolStart, onToolEnd, onDoomLoop, abortSignal } = opts; + const { onStream, onToolStart, onToolEnd, onDoomLoop, onWaitingForInput, abortSignal } = opts; if (!this.toolExecutor) { throw new Error('工具注册表未初始化,请先调用 setRegistry()'); @@ -199,6 +201,7 @@ export class Agent { sessionId: this.sessionManager?.getSession()?.id || 'default', agentMode: this.modeManager.getCurrentMode(), onToolEnd, + onWaitingForInput, }); // 配置消息处理 diff --git a/packages/core/src/core/index.ts b/packages/core/src/core/index.ts index 65dd668..f11abb2 100644 --- a/packages/core/src/core/index.ts +++ b/packages/core/src/core/index.ts @@ -11,8 +11,16 @@ export { type ToolStartInfo, type ToolEndInfo, type ToolExecutionContext, + type WaitingForInputInfo, } from './agent-tool-executor.js'; +// 用户输入等待器 +export { + getUserInputWaiter, + UserInputWaiter, + type PendingInput, +} from './user-input-waiter.js'; + export { AgentMessageHandler, type DoomLoopInfo, diff --git a/packages/core/src/core/user-input-waiter.ts b/packages/core/src/core/user-input-waiter.ts new file mode 100644 index 0000000..c26bf77 --- /dev/null +++ b/packages/core/src/core/user-input-waiter.ts @@ -0,0 +1,117 @@ +/** + * 用户输入等待管理器 + * + * 用于处理需要用户输入的工具(如 ask_user_question)。 + * 当工具需要用户输入时,会创建一个等待器,阻塞工具执行直到用户提交回答。 + */ + +export interface PendingInput { + toolCallId: string; + toolName: string; + resolve: (answer: string) => void; + reject: (error: Error) => void; + createdAt: number; +} + +/** + * 用户输入等待管理器 + */ +class UserInputWaiter { + private pendingInputs: Map = new Map(); + // 超时时间:10 分钟 + private readonly timeout = 10 * 60 * 1000; + + /** + * 等待用户输入 + * @param toolCallId 工具调用 ID + * @param toolName 工具名称 + * @returns 用户输入的答案 + */ + async waitForInput(toolCallId: string, toolName: string): Promise { + return new Promise((resolve, reject) => { + const pending: PendingInput = { + toolCallId, + toolName, + resolve, + reject, + createdAt: Date.now(), + }; + + this.pendingInputs.set(toolCallId, pending); + + // 设置超时 + setTimeout(() => { + if (this.pendingInputs.has(toolCallId)) { + this.pendingInputs.delete(toolCallId); + reject(new Error(`等待用户输入超时 (${toolName})`)); + } + }, this.timeout); + }); + } + + /** + * 提交用户输入 + * @param toolCallId 工具调用 ID + * @param answer 用户的回答 + * @returns 是否成功提交 + */ + submitInput(toolCallId: string, answer: string): boolean { + const pending = this.pendingInputs.get(toolCallId); + if (!pending) { + return false; + } + + this.pendingInputs.delete(toolCallId); + pending.resolve(answer); + return true; + } + + /** + * 取消等待 + * @param toolCallId 工具调用 ID + * @param reason 取消原因 + */ + cancelInput(toolCallId: string, reason?: string): boolean { + const pending = this.pendingInputs.get(toolCallId); + if (!pending) { + return false; + } + + this.pendingInputs.delete(toolCallId); + pending.reject(new Error(reason || '用户取消了输入')); + return true; + } + + /** + * 检查是否有等待中的输入 + */ + hasPendingInput(toolCallId: string): boolean { + return this.pendingInputs.has(toolCallId); + } + + /** + * 获取所有等待中的输入 + */ + getPendingInputs(): PendingInput[] { + return Array.from(this.pendingInputs.values()); + } + + /** + * 清除所有等待(用于会话结束时) + */ + clearAll(reason?: string): void { + for (const pending of this.pendingInputs.values()) { + pending.reject(new Error(reason || '会话已结束')); + } + this.pendingInputs.clear(); + } +} + +// 全局单例 +const userInputWaiter = new UserInputWaiter(); + +export function getUserInputWaiter(): UserInputWaiter { + return userInputWaiter; +} + +export { UserInputWaiter }; diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index e77634e..1a5348d 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,5 +1,9 @@ export { Agent } from './core/agent.js'; -export type { AgentChatOptions, ToolStartInfo, ToolEndInfo, DoomLoopInfo } from './core/agent.js'; +export type { AgentChatOptions, ToolStartInfo, ToolEndInfo, DoomLoopInfo, WaitingForInputInfo } from './core/agent.js'; + +// User Input Waiter (用于 ask_user_question 等工具) +export { getUserInputWaiter, UserInputWaiter } from './core/user-input-waiter.js'; +export type { PendingInput } from './core/user-input-waiter.js'; // Doom Loop Detection export { diff --git a/packages/server/src/agent/adapter.ts b/packages/server/src/agent/adapter.ts index 641ddf0..9839130 100644 --- a/packages/server/src/agent/adapter.ts +++ b/packages/server/src/agent/adapter.ts @@ -21,10 +21,12 @@ import { getProviderRegistry, agentRegistry, agentEventEmitter, + getUserInputWaiter, ConfigurationError, type AgentChatOptions, type ToolStartInfo, type ToolEndInfo, + type WaitingForInputInfo, type TokenUsage, type DetailedCompressionResult, } from '@ai-assistant/core'; @@ -380,6 +382,22 @@ export async function processMessage( }, }); }, + onWaitingForInput: (info: WaitingForInputInfo) => { + // 检查是否已取消 + if (abortController.signal.aborted) return; + + // 推送等待用户输入事件 + broadcastToSession(sessionId, { + type: 'waiting_for_input', + sessionId, + payload: { + id: info.id, + toolName: info.toolName, + questions: info.questions, + arguments: info.args, + }, + }); + }, abortSignal: abortController.signal, }; @@ -614,3 +632,14 @@ export async function compressContext( // Re-export TokenUsage for external use export type { TokenUsage }; + +/** + * 提交用户输入响应(用于 ask_user_question 等工具) + * @param toolCallId 工具调用 ID + * @param answer 用户的回答 + * @returns 是否成功提交 + */ +export function submitUserInput(toolCallId: string, answer: string): boolean { + const userInputWaiter = getUserInputWaiter(); + return userInputWaiter.submitInput(toolCallId, answer); +} diff --git a/packages/server/src/agent/index.ts b/packages/server/src/agent/index.ts index 64faba2..bceff34 100644 --- a/packages/server/src/agent/index.ts +++ b/packages/server/src/agent/index.ts @@ -15,6 +15,8 @@ export { // 上下文压缩相关 getContextUsage, compressContext, + // 用户输入响应 + submitUserInput, // 类型导出 type TokenUsage, type CompressionResult, diff --git a/packages/server/src/types.ts b/packages/server/src/types.ts index 6c551e2..282fdbf 100644 --- a/packages/server/src/types.ts +++ b/packages/server/src/types.ts @@ -82,7 +82,7 @@ export type AgentModeType = 'build' | 'plan'; // 客户端发送的消息 export interface ClientMessage { - type: 'message' | 'cancel' | 'tool_response' | 'permission_response' | 'config_update' | 'mode_switch'; + type: 'message' | 'cancel' | 'tool_response' | 'permission_response' | 'config_update' | 'mode_switch' | 'user_input_response'; sessionId: string; payload?: { content?: string; @@ -95,6 +95,8 @@ export interface ClientMessage { // Agent mode fields agentMode?: AgentModeType; autoApprove?: boolean; + // User input response fields (for ask_user_question) + answer?: string; }; } @@ -108,6 +110,7 @@ export interface ServerMessage { | 'tool_result' | 'tool_start' // 工具开始执行 | 'tool_end' // 工具执行完成 + | 'waiting_for_input' // 等待用户输入(如 ask_user_question) | 'done' | 'cancelled' | 'error' @@ -140,6 +143,18 @@ export interface ToolEndPayload { duration?: number; } +// 等待用户输入事件 Payload +export interface WaitingForInputPayload { + /** 工具调用 ID,用于匹配用户回答 */ + id: string; + /** 工具名称 */ + toolName: string; + /** 问题列表 */ + questions: unknown[]; + /** 工具参数 */ + arguments: Record; +} + // ============ 子 Agent 事件 Payload ============ /** 子 Agent 开始事件 Payload */ diff --git a/packages/server/src/ws.ts b/packages/server/src/ws.ts index 2cc72b2..481f173 100644 --- a/packages/server/src/ws.ts +++ b/packages/server/src/ws.ts @@ -6,7 +6,7 @@ import type { WSContext } from 'hono/ws'; import { getSessionManager } from './session/manager.js'; -import { processMessage, cancelProcessing, getOrCreateAgent } from './agent/index.js'; +import { processMessage, cancelProcessing, getOrCreateAgent, submitUserInput } from './agent/index.js'; import { handlePermissionResponse, setSessionAutoApprove } from './permission/handler.js'; import type { ClientMessage, ServerMessage } from './types.js'; @@ -201,6 +201,33 @@ export async function handleWebSocketMessage( break; } + case 'user_input_response': { + // 处理用户输入响应(用于 ask_user_question 等工具) + const toolCallId = message.payload?.toolCallId; + const answer = message.payload?.answer; + + if (toolCallId && answer !== undefined) { + const handled = submitUserInput(toolCallId, answer); + if (!handled) { + console.warn(`[WS] User input response for unknown tool call: ${toolCallId}`); + broadcastToSession(sessionId, { + type: 'error', + sessionId, + payload: { message: `No pending input request for tool call: ${toolCallId}` }, + }); + } else { + console.log(`[WS] User input submitted for tool call: ${toolCallId}`); + } + } else { + broadcastToSession(sessionId, { + type: 'error', + sessionId, + payload: { message: 'Missing toolCallId or answer in user_input_response' }, + }); + } + break; + } + default: ws.send( JSON.stringify({ diff --git a/packages/ui/src/api/types.ts b/packages/ui/src/api/types.ts index 9deebf1..89e5659 100644 --- a/packages/ui/src/api/types.ts +++ b/packages/ui/src/api/types.ts @@ -954,6 +954,18 @@ export interface ToolEndPayload { duration?: number; } +/** 等待用户输入事件 Payload */ +export interface WaitingForInputPayload { + /** 工具调用 ID,用于匹配用户回答 */ + id: string; + /** 工具名称 */ + toolName: string; + /** 问题列表 */ + questions: Question[]; + /** 工具参数 */ + arguments: Record; +} + // ============ 子 Agent 事件 Payload ============ /** 子 Agent 开始事件 Payload */ diff --git a/packages/ui/src/components/ChatMessage.tsx b/packages/ui/src/components/ChatMessage.tsx index e842e1d..b5dfc2a 100644 --- a/packages/ui/src/components/ChatMessage.tsx +++ b/packages/ui/src/components/ChatMessage.tsx @@ -84,19 +84,22 @@ export const ChatMessage = forwardRef( ); case 'tool': return ; - case 'question': + case 'question': { + // 问题组件:即使在流式输出时也允许用户回答(除非已回答) + const questionPart = part as QuestionMessagePart; return ( onAnswerQuestion(part.id, answers) : undefined } - disabled={isStreaming} + disabled={questionPart.answered} /> ); + } case 'reasoning': return (
diff --git a/packages/ui/src/hooks/useChat.ts b/packages/ui/src/hooks/useChat.ts index a6f7c4a..d312a10 100644 --- a/packages/ui/src/hooks/useChat.ts +++ b/packages/ui/src/hooks/useChat.ts @@ -11,6 +11,7 @@ import type { ConfigErrorPayload, ToolStartPayload, ToolEndPayload, + WaitingForInputPayload, MessagePart, ToolMessagePart, QuestionMessagePart, @@ -320,6 +321,46 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate break; } + case 'waiting_for_input': { + // 工具需要用户输入(如 ask_user_question) + const payload = message.payload as WaitingForInputPayload; + setState((prev) => { + if (!prev.streamingMessage) return prev; + + // 创建 QuestionMessagePart + const questionPart: QuestionMessagePart = { + type: 'question', + id: payload.id, + questions: payload.questions, + answered: false, + }; + + // 查找并替换对应的工具 part(如果存在),或添加新的 + let found = false; + const parts = prev.streamingMessage.parts.map((part) => { + if (part.type === 'tool' && part.id === payload.id) { + found = true; + return questionPart; + } + return part; + }); + + // 如果没找到对应的工具 part,添加到末尾 + if (!found) { + parts.push(questionPart); + } + + return { + ...prev, + streamingMessage: { + ...prev.streamingMessage, + parts, + }, + }; + }); + break; + } + case 'done': setState((prev) => { // 使用流式消息或创建新消息 @@ -655,6 +696,21 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate // 回答问题(ask_user_question 工具) const answerQuestion = useCallback( (questionPartId: string, answers: string[]) => { + // 发送用户输入响应 + const answerText = answers.filter((a) => a).join('\n'); + if (answerText && wsRef.current && wsRef.current.readyState === WebSocket.OPEN) { + wsRef.current.send( + JSON.stringify({ + type: 'user_input_response', + sessionId, + payload: { + toolCallId: questionPartId, + answer: answerText, + }, + }) + ); + } + // 更新问题状态为已回答 setState((prev) => { // 更新流式消息中的问题 @@ -670,22 +726,6 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate return part; }); - // 发送回答作为用户消息 - const answerText = answers.filter((a) => a).join('\n'); - if (answerText && wsRef.current && wsRef.current.readyState === WebSocket.OPEN) { - wsRef.current.send( - JSON.stringify({ - type: 'message', - sessionId, - payload: { - content: answerText, - agentMode: state.agentMode, - autoApprove: state.autoApprove, - }, - }) - ); - } - return { ...prev, streamingMessage: { @@ -713,22 +753,6 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate return msg; }); - // 发送回答作为用户消息 - const answerText = answers.filter((a) => a).join('\n'); - if (answerText && wsRef.current && wsRef.current.readyState === WebSocket.OPEN) { - wsRef.current.send( - JSON.stringify({ - type: 'message', - sessionId, - payload: { - content: answerText, - agentMode: state.agentMode, - autoApprove: state.autoApprove, - }, - }) - ); - } - return { ...prev, messages }; }); },