feat(core): 实现 ask_user_question 工具的用户输入等待机制

- 创建 UserInputWaiter 管理用户输入等待状态
- 修改 agent-tool-executor 在 requiresUserInput 时等待用户回答
- 添加 onWaitingForInput 回调通知前端显示问题
- Server 端处理 waiting_for_input 广播和 user_input_response 消息
- 前端处理问题显示和用户回答提交
- 修复问题选项在流式输出时被禁用的问题
This commit is contained in:
2025-12-17 00:44:25 +08:00
parent a4e8037108
commit 8c46635dc7
13 changed files with 351 additions and 53 deletions
@@ -278,7 +278,8 @@ export class AgentMessageHandler {
onToolEnd({ onToolEnd({
id: toolCallId, id: toolCallId,
status: output.success ? 'completed' : 'error', status: output.success ? 'completed' : 'error',
result: output.success ? output.output : undefined, // 传递完整的结果对象(包含 output 和 metadata),以支持 ask_user_question 等需要 metadata 的工具
result: output.success ? { output: output.output, metadata: output.metadata } : undefined,
error: output.success ? undefined : output.error, error: output.success ? undefined : output.error,
duration, duration,
}); });
+64 -11
View File
@@ -16,6 +16,7 @@ import {
} from '../agent/index.js'; } from '../agent/index.js';
import { getHookManager } from '../hooks/index.js'; import { getHookManager } from '../hooks/index.js';
import { getGitManager } from '../git/index.js'; import { getGitManager } from '../git/index.js';
import { getUserInputWaiter } from './user-input-waiter.js';
/** /**
* 工具调用开始事件信息 * 工具调用开始事件信息
@@ -37,6 +38,16 @@ export interface ToolEndInfo {
duration?: number; duration?: number;
} }
/**
* 等待用户输入事件信息
*/
export interface WaitingForInputInfo {
id: string;
toolName: string;
questions: unknown[];
args: Record<string, unknown>;
}
/** /**
* 工具执行上下文 * 工具执行上下文
*/ */
@@ -45,6 +56,8 @@ export interface ToolExecutionContext {
agentMode: AgentInfo | null; agentMode: AgentInfo | null;
onToolStart?: (info: ToolStartInfo) => void; onToolStart?: (info: ToolStartInfo) => void;
onToolEnd?: (info: ToolEndInfo) => void; onToolEnd?: (info: ToolEndInfo) => void;
/** 当工具需要用户输入时调用(如 ask_user_question */
onWaitingForInput?: (info: WaitingForInputInfo) => void;
} }
/** /**
@@ -175,8 +188,8 @@ export class AgentToolExecutor {
context: ToolExecutionContext, context: ToolExecutionContext,
hookManager: ReturnType<typeof getHookManager> hookManager: ReturnType<typeof getHookManager>
): Promise<ToolResult> { ): Promise<ToolResult> {
const callId = `${tool.name}-${Date.now()}`; const callId = `${tool.name}-${Date.now()}-${Math.random().toString(36).slice(2, 7)}`;
const { sessionId, onToolStart, onToolEnd } = context; const { sessionId, onToolStart, onToolEnd, onWaitingForInput } = context;
// 触发工具执行前 hook // 触发工具执行前 hook
let finalArgs = args; let finalArgs = args;
@@ -221,7 +234,7 @@ export class AgentToolExecutor {
// 执行工具 // 执行工具
const startTime = Date.now(); const startTime = Date.now();
let result = await tool.execute(finalArgs); let result = await tool.execute(finalArgs);
const duration = Date.now() - startTime; let duration = Date.now() - startTime;
// 触发工具执行后 hook // 触发工具执行后 hook
if (hookManager) { if (hookManager) {
@@ -243,14 +256,54 @@ export class AgentToolExecutor {
} }
} }
// 通知工具结束 // 检查是否需要用户输入(如 ask_user_question 工具)
onToolEnd?.({ if (result.success && result.metadata?.requiresUserInput) {
id: callId, // 通知前端等待用户输入
status: result.success ? 'completed' : 'error', const questions = (finalArgs.questions as unknown[]) || [];
result: result.success ? result.output : undefined, onWaitingForInput?.({
error: result.success ? undefined : result.error, id: callId,
duration, toolName: tool.name,
}); questions,
args: finalArgs,
});
// 等待用户输入
const userInputWaiter = getUserInputWaiter();
try {
const userAnswer = await userInputWaiter.waitForInput(callId, tool.name);
// 用户回答后,更新结果
result = {
success: true,
output: `用户回答:\n${userAnswer}`,
metadata: {
...result.metadata,
userAnswer,
requiresUserInput: false, // 已获得输入
},
};
// 更新持续时间(包含等待用户输入的时间)
duration = Date.now() - startTime;
} catch (error) {
// 用户取消或超时
result = {
success: false,
output: '',
error: error instanceof Error ? error.message : '等待用户输入失败',
};
}
}
// 通知工具结束(注意:这里不再调用 onToolEnd,因为 agent-message-handler 会在 tool-result chunk 中处理)
// 但对于需要用户输入的工具,我们需要在这里调用,因为结果已经更新
if (result.metadata?.userAnswer !== undefined) {
onToolEnd?.({
id: callId,
status: result.success ? 'completed' : 'error',
result: result.success ? { output: result.output, metadata: result.metadata } : undefined,
error: result.success ? undefined : result.error,
duration,
});
}
// 如果是 tool_search 调用,解析结果并注入发现的工具 // 如果是 tool_search 调用,解析结果并注入发现的工具
if (tool.name === 'tool_search' && result.success) { if (tool.name === 'tool_search' && result.success) {
+6 -3
View File
@@ -19,13 +19,13 @@ import { todoManager } from '../tools/todo/todo-manager.js';
import { initTaskContext } from '../tools/task/index.js'; import { initTaskContext } from '../tools/task/index.js';
// 子模块 // 子模块
import { AgentToolExecutor, type ToolStartInfo, type ToolEndInfo } from './agent-tool-executor.js'; import { AgentToolExecutor, type ToolStartInfo, type ToolEndInfo, type WaitingForInputInfo } from './agent-tool-executor.js';
import { AgentMessageHandler, type DoomLoopInfo } from './agent-message-handler.js'; import { AgentMessageHandler, type DoomLoopInfo } from './agent-message-handler.js';
import { AgentModeManager } from './agent-mode-manager.js'; import { AgentModeManager } from './agent-mode-manager.js';
import { AgentVisionHandler } from './agent-vision-handler.js'; import { AgentVisionHandler } from './agent-vision-handler.js';
// 重新导出类型 // 重新导出类型
export type { ToolStartInfo, ToolEndInfo, DoomLoopInfo }; export type { ToolStartInfo, ToolEndInfo, DoomLoopInfo, WaitingForInputInfo };
/** /**
* Agent.chat() 选项 * Agent.chat() 选项
@@ -35,6 +35,8 @@ export interface AgentChatOptions {
onToolStart?: (info: ToolStartInfo) => void; onToolStart?: (info: ToolStartInfo) => void;
onToolEnd?: (info: ToolEndInfo) => void; onToolEnd?: (info: ToolEndInfo) => void;
onDoomLoop?: (info: DoomLoopInfo) => void; onDoomLoop?: (info: DoomLoopInfo) => void;
/** 当工具需要用户输入时调用(如 ask_user_question */
onWaitingForInput?: (info: WaitingForInputInfo) => void;
abortSignal?: AbortSignal; abortSignal?: AbortSignal;
} }
@@ -155,7 +157,7 @@ export class Agent {
async chat(userMessage: string | UserInput, options?: AgentChatOptions | ((text: string) => void)): Promise<ChatResult> { async chat(userMessage: string | UserInput, options?: AgentChatOptions | ((text: string) => void)): Promise<ChatResult> {
// 兼容旧的 onStream 参数 // 兼容旧的 onStream 参数
const opts: AgentChatOptions = typeof options === 'function' ? { onStream: options } : (options || {}); const opts: AgentChatOptions = typeof options === 'function' ? { onStream: options } : (options || {});
const { onStream, onToolStart, onToolEnd, onDoomLoop, abortSignal } = opts; const { onStream, onToolStart, onToolEnd, onDoomLoop, onWaitingForInput, abortSignal } = opts;
if (!this.toolExecutor) { if (!this.toolExecutor) {
throw new Error('工具注册表未初始化,请先调用 setRegistry()'); throw new Error('工具注册表未初始化,请先调用 setRegistry()');
@@ -199,6 +201,7 @@ export class Agent {
sessionId: this.sessionManager?.getSession()?.id || 'default', sessionId: this.sessionManager?.getSession()?.id || 'default',
agentMode: this.modeManager.getCurrentMode(), agentMode: this.modeManager.getCurrentMode(),
onToolEnd, onToolEnd,
onWaitingForInput,
}); });
// 配置消息处理 // 配置消息处理
+8
View File
@@ -11,8 +11,16 @@ export {
type ToolStartInfo, type ToolStartInfo,
type ToolEndInfo, type ToolEndInfo,
type ToolExecutionContext, type ToolExecutionContext,
type WaitingForInputInfo,
} from './agent-tool-executor.js'; } from './agent-tool-executor.js';
// 用户输入等待器
export {
getUserInputWaiter,
UserInputWaiter,
type PendingInput,
} from './user-input-waiter.js';
export { export {
AgentMessageHandler, AgentMessageHandler,
type DoomLoopInfo, type DoomLoopInfo,
+117
View File
@@ -0,0 +1,117 @@
/**
* 用户输入等待管理器
*
* 用于处理需要用户输入的工具(如 ask_user_question)。
* 当工具需要用户输入时,会创建一个等待器,阻塞工具执行直到用户提交回答。
*/
export interface PendingInput {
toolCallId: string;
toolName: string;
resolve: (answer: string) => void;
reject: (error: Error) => void;
createdAt: number;
}
/**
* 用户输入等待管理器
*/
class UserInputWaiter {
private pendingInputs: Map<string, PendingInput> = new Map();
// 超时时间:10 分钟
private readonly timeout = 10 * 60 * 1000;
/**
* 等待用户输入
* @param toolCallId 工具调用 ID
* @param toolName 工具名称
* @returns 用户输入的答案
*/
async waitForInput(toolCallId: string, toolName: string): Promise<string> {
return new Promise<string>((resolve, reject) => {
const pending: PendingInput = {
toolCallId,
toolName,
resolve,
reject,
createdAt: Date.now(),
};
this.pendingInputs.set(toolCallId, pending);
// 设置超时
setTimeout(() => {
if (this.pendingInputs.has(toolCallId)) {
this.pendingInputs.delete(toolCallId);
reject(new Error(`等待用户输入超时 (${toolName})`));
}
}, this.timeout);
});
}
/**
* 提交用户输入
* @param toolCallId 工具调用 ID
* @param answer 用户的回答
* @returns 是否成功提交
*/
submitInput(toolCallId: string, answer: string): boolean {
const pending = this.pendingInputs.get(toolCallId);
if (!pending) {
return false;
}
this.pendingInputs.delete(toolCallId);
pending.resolve(answer);
return true;
}
/**
* 取消等待
* @param toolCallId 工具调用 ID
* @param reason 取消原因
*/
cancelInput(toolCallId: string, reason?: string): boolean {
const pending = this.pendingInputs.get(toolCallId);
if (!pending) {
return false;
}
this.pendingInputs.delete(toolCallId);
pending.reject(new Error(reason || '用户取消了输入'));
return true;
}
/**
* 检查是否有等待中的输入
*/
hasPendingInput(toolCallId: string): boolean {
return this.pendingInputs.has(toolCallId);
}
/**
* 获取所有等待中的输入
*/
getPendingInputs(): PendingInput[] {
return Array.from(this.pendingInputs.values());
}
/**
* 清除所有等待(用于会话结束时)
*/
clearAll(reason?: string): void {
for (const pending of this.pendingInputs.values()) {
pending.reject(new Error(reason || '会话已结束'));
}
this.pendingInputs.clear();
}
}
// 全局单例
const userInputWaiter = new UserInputWaiter();
export function getUserInputWaiter(): UserInputWaiter {
return userInputWaiter;
}
export { UserInputWaiter };
+5 -1
View File
@@ -1,5 +1,9 @@
export { Agent } from './core/agent.js'; export { Agent } from './core/agent.js';
export type { AgentChatOptions, ToolStartInfo, ToolEndInfo, DoomLoopInfo } from './core/agent.js'; export type { AgentChatOptions, ToolStartInfo, ToolEndInfo, DoomLoopInfo, WaitingForInputInfo } from './core/agent.js';
// User Input Waiter (用于 ask_user_question 等工具)
export { getUserInputWaiter, UserInputWaiter } from './core/user-input-waiter.js';
export type { PendingInput } from './core/user-input-waiter.js';
// Doom Loop Detection // Doom Loop Detection
export { export {
+29
View File
@@ -21,10 +21,12 @@ import {
getProviderRegistry, getProviderRegistry,
agentRegistry, agentRegistry,
agentEventEmitter, agentEventEmitter,
getUserInputWaiter,
ConfigurationError, ConfigurationError,
type AgentChatOptions, type AgentChatOptions,
type ToolStartInfo, type ToolStartInfo,
type ToolEndInfo, type ToolEndInfo,
type WaitingForInputInfo,
type TokenUsage, type TokenUsage,
type DetailedCompressionResult, type DetailedCompressionResult,
} from '@ai-assistant/core'; } from '@ai-assistant/core';
@@ -380,6 +382,22 @@ export async function processMessage(
}, },
}); });
}, },
onWaitingForInput: (info: WaitingForInputInfo) => {
// 检查是否已取消
if (abortController.signal.aborted) return;
// 推送等待用户输入事件
broadcastToSession(sessionId, {
type: 'waiting_for_input',
sessionId,
payload: {
id: info.id,
toolName: info.toolName,
questions: info.questions,
arguments: info.args,
},
});
},
abortSignal: abortController.signal, abortSignal: abortController.signal,
}; };
@@ -614,3 +632,14 @@ export async function compressContext(
// Re-export TokenUsage for external use // Re-export TokenUsage for external use
export type { TokenUsage }; export type { TokenUsage };
/**
* 提交用户输入响应(用于 ask_user_question 等工具)
* @param toolCallId 工具调用 ID
* @param answer 用户的回答
* @returns 是否成功提交
*/
export function submitUserInput(toolCallId: string, answer: string): boolean {
const userInputWaiter = getUserInputWaiter();
return userInputWaiter.submitInput(toolCallId, answer);
}
+2
View File
@@ -15,6 +15,8 @@ export {
// 上下文压缩相关 // 上下文压缩相关
getContextUsage, getContextUsage,
compressContext, compressContext,
// 用户输入响应
submitUserInput,
// 类型导出 // 类型导出
type TokenUsage, type TokenUsage,
type CompressionResult, type CompressionResult,
+16 -1
View File
@@ -82,7 +82,7 @@ export type AgentModeType = 'build' | 'plan';
// 客户端发送的消息 // 客户端发送的消息
export interface ClientMessage { export interface ClientMessage {
type: 'message' | 'cancel' | 'tool_response' | 'permission_response' | 'config_update' | 'mode_switch'; type: 'message' | 'cancel' | 'tool_response' | 'permission_response' | 'config_update' | 'mode_switch' | 'user_input_response';
sessionId: string; sessionId: string;
payload?: { payload?: {
content?: string; content?: string;
@@ -95,6 +95,8 @@ export interface ClientMessage {
// Agent mode fields // Agent mode fields
agentMode?: AgentModeType; agentMode?: AgentModeType;
autoApprove?: boolean; autoApprove?: boolean;
// User input response fields (for ask_user_question)
answer?: string;
}; };
} }
@@ -108,6 +110,7 @@ export interface ServerMessage {
| 'tool_result' | 'tool_result'
| 'tool_start' // 工具开始执行 | 'tool_start' // 工具开始执行
| 'tool_end' // 工具执行完成 | 'tool_end' // 工具执行完成
| 'waiting_for_input' // 等待用户输入(如 ask_user_question
| 'done' | 'done'
| 'cancelled' | 'cancelled'
| 'error' | 'error'
@@ -140,6 +143,18 @@ export interface ToolEndPayload {
duration?: number; duration?: number;
} }
// 等待用户输入事件 Payload
export interface WaitingForInputPayload {
/** 工具调用 ID,用于匹配用户回答 */
id: string;
/** 工具名称 */
toolName: string;
/** 问题列表 */
questions: unknown[];
/** 工具参数 */
arguments: Record<string, unknown>;
}
// ============ 子 Agent 事件 Payload ============ // ============ 子 Agent 事件 Payload ============
/** 子 Agent 开始事件 Payload */ /** 子 Agent 开始事件 Payload */
+28 -1
View File
@@ -6,7 +6,7 @@
import type { WSContext } from 'hono/ws'; import type { WSContext } from 'hono/ws';
import { getSessionManager } from './session/manager.js'; import { getSessionManager } from './session/manager.js';
import { processMessage, cancelProcessing, getOrCreateAgent } from './agent/index.js'; import { processMessage, cancelProcessing, getOrCreateAgent, submitUserInput } from './agent/index.js';
import { handlePermissionResponse, setSessionAutoApprove } from './permission/handler.js'; import { handlePermissionResponse, setSessionAutoApprove } from './permission/handler.js';
import type { ClientMessage, ServerMessage } from './types.js'; import type { ClientMessage, ServerMessage } from './types.js';
@@ -201,6 +201,33 @@ export async function handleWebSocketMessage(
break; break;
} }
case 'user_input_response': {
// 处理用户输入响应(用于 ask_user_question 等工具)
const toolCallId = message.payload?.toolCallId;
const answer = message.payload?.answer;
if (toolCallId && answer !== undefined) {
const handled = submitUserInput(toolCallId, answer);
if (!handled) {
console.warn(`[WS] User input response for unknown tool call: ${toolCallId}`);
broadcastToSession(sessionId, {
type: 'error',
sessionId,
payload: { message: `No pending input request for tool call: ${toolCallId}` },
});
} else {
console.log(`[WS] User input submitted for tool call: ${toolCallId}`);
}
} else {
broadcastToSession(sessionId, {
type: 'error',
sessionId,
payload: { message: 'Missing toolCallId or answer in user_input_response' },
});
}
break;
}
default: default:
ws.send( ws.send(
JSON.stringify({ JSON.stringify({
+12
View File
@@ -954,6 +954,18 @@ export interface ToolEndPayload {
duration?: number; duration?: number;
} }
/** 等待用户输入事件 Payload */
export interface WaitingForInputPayload {
/** 工具调用 ID,用于匹配用户回答 */
id: string;
/** 工具名称 */
toolName: string;
/** 问题列表 */
questions: Question[];
/** 工具参数 */
arguments: Record<string, unknown>;
}
// ============ 子 Agent 事件 Payload ============ // ============ 子 Agent 事件 Payload ============
/** 子 Agent 开始事件 Payload */ /** 子 Agent 开始事件 Payload */
+6 -3
View File
@@ -84,19 +84,22 @@ export const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>(
); );
case 'tool': case 'tool':
return <ToolPartItem key={part.id} part={part} />; return <ToolPartItem key={part.id} part={part} />;
case 'question': case 'question': {
// 问题组件:即使在流式输出时也允许用户回答(除非已回答)
const questionPart = part as QuestionMessagePart;
return ( return (
<AskUserQuestion <AskUserQuestion
key={part.id} key={part.id}
part={part as QuestionMessagePart} part={questionPart}
onAnswer={ onAnswer={
onAnswerQuestion onAnswerQuestion
? (answers) => onAnswerQuestion(part.id, answers) ? (answers) => onAnswerQuestion(part.id, answers)
: undefined : undefined
} }
disabled={isStreaming} disabled={questionPart.answered}
/> />
); );
}
case 'reasoning': case 'reasoning':
return ( return (
<div key={part.id} className="text-fg-muted italic border-l-2 border-line pl-3"> <div key={part.id} className="text-fg-muted italic border-l-2 border-line pl-3">
+56 -32
View File
@@ -11,6 +11,7 @@ import type {
ConfigErrorPayload, ConfigErrorPayload,
ToolStartPayload, ToolStartPayload,
ToolEndPayload, ToolEndPayload,
WaitingForInputPayload,
MessagePart, MessagePart,
ToolMessagePart, ToolMessagePart,
QuestionMessagePart, QuestionMessagePart,
@@ -320,6 +321,46 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
break; break;
} }
case 'waiting_for_input': {
// 工具需要用户输入(如 ask_user_question
const payload = message.payload as WaitingForInputPayload;
setState((prev) => {
if (!prev.streamingMessage) return prev;
// 创建 QuestionMessagePart
const questionPart: QuestionMessagePart = {
type: 'question',
id: payload.id,
questions: payload.questions,
answered: false,
};
// 查找并替换对应的工具 part(如果存在),或添加新的
let found = false;
const parts = prev.streamingMessage.parts.map((part) => {
if (part.type === 'tool' && part.id === payload.id) {
found = true;
return questionPart;
}
return part;
});
// 如果没找到对应的工具 part,添加到末尾
if (!found) {
parts.push(questionPart);
}
return {
...prev,
streamingMessage: {
...prev.streamingMessage,
parts,
},
};
});
break;
}
case 'done': case 'done':
setState((prev) => { setState((prev) => {
// 使用流式消息或创建新消息 // 使用流式消息或创建新消息
@@ -655,6 +696,21 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
// 回答问题(ask_user_question 工具) // 回答问题(ask_user_question 工具)
const answerQuestion = useCallback( const answerQuestion = useCallback(
(questionPartId: string, answers: string[]) => { (questionPartId: string, answers: string[]) => {
// 发送用户输入响应
const answerText = answers.filter((a) => a).join('\n');
if (answerText && wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
wsRef.current.send(
JSON.stringify({
type: 'user_input_response',
sessionId,
payload: {
toolCallId: questionPartId,
answer: answerText,
},
})
);
}
// 更新问题状态为已回答 // 更新问题状态为已回答
setState((prev) => { setState((prev) => {
// 更新流式消息中的问题 // 更新流式消息中的问题
@@ -670,22 +726,6 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
return part; return part;
}); });
// 发送回答作为用户消息
const answerText = answers.filter((a) => a).join('\n');
if (answerText && wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
wsRef.current.send(
JSON.stringify({
type: 'message',
sessionId,
payload: {
content: answerText,
agentMode: state.agentMode,
autoApprove: state.autoApprove,
},
})
);
}
return { return {
...prev, ...prev,
streamingMessage: { streamingMessage: {
@@ -713,22 +753,6 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate
return msg; return msg;
}); });
// 发送回答作为用户消息
const answerText = answers.filter((a) => a).join('\n');
if (answerText && wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
wsRef.current.send(
JSON.stringify({
type: 'message',
sessionId,
payload: {
content: answerText,
agentMode: state.agentMode,
autoApprove: state.autoApprove,
},
})
);
}
return { ...prev, messages }; return { ...prev, messages };
}); });
}, },