feat(core): 实现 ask_user_question 工具的用户输入等待机制
- 创建 UserInputWaiter 管理用户输入等待状态 - 修改 agent-tool-executor 在 requiresUserInput 时等待用户回答 - 添加 onWaitingForInput 回调通知前端显示问题 - Server 端处理 waiting_for_input 广播和 user_input_response 消息 - 前端处理问题显示和用户回答提交 - 修复问题选项在流式输出时被禁用的问题
This commit is contained in:
@@ -21,10 +21,12 @@ import {
|
||||
getProviderRegistry,
|
||||
agentRegistry,
|
||||
agentEventEmitter,
|
||||
getUserInputWaiter,
|
||||
ConfigurationError,
|
||||
type AgentChatOptions,
|
||||
type ToolStartInfo,
|
||||
type ToolEndInfo,
|
||||
type WaitingForInputInfo,
|
||||
type TokenUsage,
|
||||
type DetailedCompressionResult,
|
||||
} from '@ai-assistant/core';
|
||||
@@ -380,6 +382,22 @@ export async function processMessage(
|
||||
},
|
||||
});
|
||||
},
|
||||
onWaitingForInput: (info: WaitingForInputInfo) => {
|
||||
// 检查是否已取消
|
||||
if (abortController.signal.aborted) return;
|
||||
|
||||
// 推送等待用户输入事件
|
||||
broadcastToSession(sessionId, {
|
||||
type: 'waiting_for_input',
|
||||
sessionId,
|
||||
payload: {
|
||||
id: info.id,
|
||||
toolName: info.toolName,
|
||||
questions: info.questions,
|
||||
arguments: info.args,
|
||||
},
|
||||
});
|
||||
},
|
||||
abortSignal: abortController.signal,
|
||||
};
|
||||
|
||||
@@ -614,3 +632,14 @@ export async function compressContext(
|
||||
|
||||
// Re-export TokenUsage for external use
|
||||
export type { TokenUsage };
|
||||
|
||||
/**
|
||||
* 提交用户输入响应(用于 ask_user_question 等工具)
|
||||
* @param toolCallId 工具调用 ID
|
||||
* @param answer 用户的回答
|
||||
* @returns 是否成功提交
|
||||
*/
|
||||
export function submitUserInput(toolCallId: string, answer: string): boolean {
|
||||
const userInputWaiter = getUserInputWaiter();
|
||||
return userInputWaiter.submitInput(toolCallId, answer);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,8 @@ export {
|
||||
// 上下文压缩相关
|
||||
getContextUsage,
|
||||
compressContext,
|
||||
// 用户输入响应
|
||||
submitUserInput,
|
||||
// 类型导出
|
||||
type TokenUsage,
|
||||
type CompressionResult,
|
||||
|
||||
@@ -82,7 +82,7 @@ export type AgentModeType = 'build' | 'plan';
|
||||
|
||||
// 客户端发送的消息
|
||||
export interface ClientMessage {
|
||||
type: 'message' | 'cancel' | 'tool_response' | 'permission_response' | 'config_update' | 'mode_switch';
|
||||
type: 'message' | 'cancel' | 'tool_response' | 'permission_response' | 'config_update' | 'mode_switch' | 'user_input_response';
|
||||
sessionId: string;
|
||||
payload?: {
|
||||
content?: string;
|
||||
@@ -95,6 +95,8 @@ export interface ClientMessage {
|
||||
// Agent mode fields
|
||||
agentMode?: AgentModeType;
|
||||
autoApprove?: boolean;
|
||||
// User input response fields (for ask_user_question)
|
||||
answer?: string;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -108,6 +110,7 @@ export interface ServerMessage {
|
||||
| 'tool_result'
|
||||
| 'tool_start' // 工具开始执行
|
||||
| 'tool_end' // 工具执行完成
|
||||
| 'waiting_for_input' // 等待用户输入(如 ask_user_question)
|
||||
| 'done'
|
||||
| 'cancelled'
|
||||
| 'error'
|
||||
@@ -140,6 +143,18 @@ export interface ToolEndPayload {
|
||||
duration?: number;
|
||||
}
|
||||
|
||||
// 等待用户输入事件 Payload
|
||||
export interface WaitingForInputPayload {
|
||||
/** 工具调用 ID,用于匹配用户回答 */
|
||||
id: string;
|
||||
/** 工具名称 */
|
||||
toolName: string;
|
||||
/** 问题列表 */
|
||||
questions: unknown[];
|
||||
/** 工具参数 */
|
||||
arguments: Record<string, unknown>;
|
||||
}
|
||||
|
||||
// ============ 子 Agent 事件 Payload ============
|
||||
|
||||
/** 子 Agent 开始事件 Payload */
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
import type { WSContext } from 'hono/ws';
|
||||
import { getSessionManager } from './session/manager.js';
|
||||
import { processMessage, cancelProcessing, getOrCreateAgent } from './agent/index.js';
|
||||
import { processMessage, cancelProcessing, getOrCreateAgent, submitUserInput } from './agent/index.js';
|
||||
import { handlePermissionResponse, setSessionAutoApprove } from './permission/handler.js';
|
||||
import type { ClientMessage, ServerMessage } from './types.js';
|
||||
|
||||
@@ -201,6 +201,33 @@ export async function handleWebSocketMessage(
|
||||
break;
|
||||
}
|
||||
|
||||
case 'user_input_response': {
|
||||
// 处理用户输入响应(用于 ask_user_question 等工具)
|
||||
const toolCallId = message.payload?.toolCallId;
|
||||
const answer = message.payload?.answer;
|
||||
|
||||
if (toolCallId && answer !== undefined) {
|
||||
const handled = submitUserInput(toolCallId, answer);
|
||||
if (!handled) {
|
||||
console.warn(`[WS] User input response for unknown tool call: ${toolCallId}`);
|
||||
broadcastToSession(sessionId, {
|
||||
type: 'error',
|
||||
sessionId,
|
||||
payload: { message: `No pending input request for tool call: ${toolCallId}` },
|
||||
});
|
||||
} else {
|
||||
console.log(`[WS] User input submitted for tool call: ${toolCallId}`);
|
||||
}
|
||||
} else {
|
||||
broadcastToSession(sessionId, {
|
||||
type: 'error',
|
||||
sessionId,
|
||||
payload: { message: 'Missing toolCallId or answer in user_input_response' },
|
||||
});
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
ws.send(
|
||||
JSON.stringify({
|
||||
|
||||
Reference in New Issue
Block a user