diff --git a/packages/core/src/core/agent.ts b/packages/core/src/core/agent.ts index 3341eb9..69f8038 100644 --- a/packages/core/src/core/agent.ts +++ b/packages/core/src/core/agent.ts @@ -23,11 +23,35 @@ import { getProviderRegistry, resolveApiKey } from '../provider/index.js'; import { getHookManager } from '../hooks/index.js'; import { getGitManager } from '../git/index.js'; +/** + * 工具调用开始事件信息 + */ +export interface ToolStartInfo { + id: string; + toolName: string; + args: Record; +} + +/** + * 工具调用结束事件信息 + */ +export interface ToolEndInfo { + id: string; + status: 'completed' | 'error'; + result?: unknown; + error?: string; + duration?: number; +} + /** * Agent.chat() 选项 */ export interface AgentChatOptions { onStream?: (text: string) => void; + /** 工具开始执行回调 */ + onToolStart?: (info: ToolStartInfo) => void; + /** 工具执行完成回调 */ + onToolEnd?: (info: ToolEndInfo) => void; abortSignal?: AbortSignal; } @@ -324,7 +348,10 @@ export class Agent { async chat(userMessage: string | UserInput, options?: AgentChatOptions | ((text: string) => void)): Promise { // 兼容旧的 onStream 参数 const opts: AgentChatOptions = typeof options === 'function' ? { onStream: options } : (options || {}); - const { onStream, abortSignal } = opts; + const { onStream, onToolStart, onToolEnd, abortSignal } = opts; + + // 工具调用时间跟踪 + const toolStartTimes = new Map(); // 处理带图片的消息 let processedMessage = userMessage; @@ -405,19 +432,55 @@ export class Agent { abortSignal, // 支持取消 onChunk: ({ chunk }) => { if (chunk.type === 'tool-call') { - onStream(`\n[调用工具: ${chunk.toolName}]\n`); + // AI SDK 中工具参数字段名为 input + const toolCallChunk = chunk as { toolCallId: string; toolName: string; input: unknown }; + const toolCallId = toolCallChunk.toolCallId || `tool-${Date.now()}`; + + // 记录开始时间 + toolStartTimes.set(toolCallId, Date.now()); + + // 调用 onToolStart 回调 + if (onToolStart) { + onToolStart({ + id: toolCallId, + toolName: toolCallChunk.toolName, + args: (toolCallChunk.input as Record) || {}, + }); + } else { + // 仅在没有 onToolStart 回调时输出文本(向后兼容 CLI) + onStream?.(`\n[调用工具: ${toolCallChunk.toolName}]\n`); + } } else if (chunk.type === 'tool-result') { - const output = (chunk as { output?: ToolResult }).output; + const toolResultChunk = chunk as { toolCallId: string; output?: ToolResult }; + const toolCallId = toolResultChunk.toolCallId || ''; + const output = toolResultChunk.output; + + // 计算执行时长 + const startTime = toolStartTimes.get(toolCallId); + const duration = startTime ? Date.now() - startTime : undefined; + toolStartTimes.delete(toolCallId); + if (output && typeof output === 'object') { - if (output.success) { - // 截断过长的输出 - const displayOutput = - output.output.length > 500 - ? output.output.substring(0, 500) + '...(截断)' - : output.output; - onStream(`[结果: ${displayOutput}]\n`); + // 调用 onToolEnd 回调 + if (onToolEnd) { + onToolEnd({ + id: toolCallId, + status: output.success ? 'completed' : 'error', + result: output.success ? output.output : undefined, + error: output.success ? undefined : output.error, + duration, + }); } else { - onStream(`[错误: ${output.error}]\n`); + // 仅在没有 onToolEnd 回调时输出文本(向后兼容 CLI) + if (output.success) { + const displayOutput = + output.output.length > 500 + ? output.output.substring(0, 500) + '...(截断)' + : output.output; + onStream?.(`[结果: ${displayOutput}]\n`); + } else { + onStream?.(`[错误: ${output.error}]\n`); + } } } } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index d534234..604c792 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -1,5 +1,5 @@ export { Agent } from './core/agent.js'; -export type { AgentChatOptions } from './core/agent.js'; +export type { AgentChatOptions, ToolStartInfo, ToolEndInfo } from './core/agent.js'; export { toolRegistry, todoManager, initTaskContext, updateTaskDescription, updateSkillDescription } from './tools/index.js'; export { loadConfig, saveConfig, getConfig, loadVisionConfig, ConfigurationError } from './utils/config.js'; export type { VisionConfig } from './utils/config.js'; diff --git a/packages/core/src/session/converter.ts b/packages/core/src/session/converter.ts index 09dcc41..c45e1d0 100644 --- a/packages/core/src/session/converter.ts +++ b/packages/core/src/session/converter.ts @@ -45,13 +45,14 @@ function messageToModelMessages(msg: Message): ModelMessage[] { } // 添加工具调用部分(只有 running 或已完成的工具) + // AI SDK v5 使用 input 字段(不是 args) for (const toolPart of toolParts) { if (toolPart.state.status !== 'pending') { assistantContent.push({ type: 'tool-call', toolCallId: toolPart.toolCallId, toolName: toolPart.toolName, - args: toolPart.state.input, + input: toolPart.state.input, }); } } @@ -87,12 +88,17 @@ function messageToModelMessages(msg: Message): ModelMessage[] { const output = state.status === 'completed' ? (state as { output: unknown }).output : (state as { error: string }).error; + // 获取 input(AI SDK v5 要求 tool-result 必须包含 input) + const input = state.status !== 'pending' + ? (state as { input: Record }).input + : {}; return { type: 'tool-result' as const, toolCallId: toolPart.toolCallId, toolName: toolPart.toolName, - result: output, + input, + output, }; }); @@ -134,12 +140,18 @@ export function toModelMessages(messages: Message[]): ModelMessage[] { /** * 获取工具调用的输入参数(兼容不同状态) + * 注意:AI SDK 的 input 是 unknown 类型,这里做安全转换 */ export function getToolInput(toolPart: ToolPart): Record { if (toolPart.state.status === 'pending') { return {}; } - return toolPart.state.input; + const input = toolPart.state.input; + // 安全转换:如果是对象返回对象,否则返回空对象 + if (input && typeof input === 'object' && !Array.isArray(input)) { + return input as Record; + } + return {}; } /** diff --git a/packages/core/src/session/manager.ts b/packages/core/src/session/manager.ts index faf6f7c..41850ca 100644 --- a/packages/core/src/session/manager.ts +++ b/packages/core/src/session/manager.ts @@ -224,7 +224,8 @@ export class SessionManager { } else if (role === 'assistant') { // Assistant 消息:文本 + 工具调用 const content: unknown[] = []; - const completedTools: Array<{ toolCallId: string; toolName: string; output: unknown }> = []; + // input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等) + const completedTools: Array<{ toolCallId: string; toolName: string; input: unknown; output: unknown }> = []; for (const part of parts) { if (part.type === 'text') { @@ -232,11 +233,12 @@ export class SessionManager { } else if (part.type === 'tool') { // 只有非 pending 状态的工具调用才添加到 AI SDK 消息 if (part.state.status !== 'pending') { + // AI SDK v5 使用 input 字段(不是 args) content.push({ type: 'tool-call', toolCallId: part.toolCallId, toolName: part.toolName, - args: part.state.input, + input: part.state.input, }); // 收集已完成的工具结果 @@ -244,12 +246,14 @@ export class SessionManager { completedTools.push({ toolCallId: part.toolCallId, toolName: part.toolName, + input: part.state.input, output: part.state.output, }); } else if (part.state.status === 'error') { completedTools.push({ toolCallId: part.toolCallId, toolName: part.toolName, + input: part.state.input, output: part.state.error, }); } @@ -273,6 +277,7 @@ export class SessionManager { } // 添加 tool 消息(如果有已完成的工具) + // AI SDK v5 要求 tool-result 必须包含 input 和 output 字段 if (completedTools.length > 0) { result.push({ role: 'tool', @@ -280,7 +285,8 @@ export class SessionManager { type: 'tool-result', toolCallId: t.toolCallId, toolName: t.toolName, - result: t.output, + input: t.input, + output: t.output, })), } as unknown as ModelMessage); } @@ -454,7 +460,8 @@ export class SessionManager { for (const item of message.content) { const itemType = (item as { type: string }).type; if (itemType === 'tool-result') { - const toolResult = item as unknown as { toolCallId: string; toolName: string; result: unknown }; + // AI SDK v5 使用 output 字段存储结果(不是 result) + const toolResult = item as unknown as { toolCallId: string; toolName: string; output: unknown }; const partId = toolCallPartIds.get(toolResult.toolCallId); if (partId) { // 更新工具状态为 completed @@ -463,7 +470,7 @@ export class SessionManager { const startTime = part?.type === 'tool' && part.state.status === 'running' ? part.state.time.start : Date.now(); - await PartStorage.setToolCompleted(currentAssistantMsgId, partId, toolResult.result, startTime); + await PartStorage.setToolCompleted(currentAssistantMsgId, partId, toolResult.output, startTime); } } } diff --git a/packages/core/src/session/parts.ts b/packages/core/src/session/parts.ts index 6fd8256..c212f3b 100644 --- a/packages/core/src/session/parts.ts +++ b/packages/core/src/session/parts.ts @@ -27,20 +27,22 @@ export type ToolStatePending = z.infer; /** * 工具状态机 - Running(执行中) + * 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等) */ export const ToolStateRunningSchema = z.object({ status: z.literal('running'), - input: z.record(z.string(), z.unknown()), + input: z.unknown(), time: z.object({ start: z.number() }), }); export type ToolStateRunning = z.infer; /** * 工具状态机 - Completed(执行完成) + * 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等) */ export const ToolStateCompletedSchema = z.object({ status: z.literal('completed'), - input: z.record(z.string(), z.unknown()), + input: z.unknown(), output: z.unknown(), time: z.object({ start: z.number(), end: z.number() }), }); @@ -48,10 +50,11 @@ export type ToolStateCompleted = z.infer; /** * 工具状态机 - Error(执行出错) + * 注意:input 使用 unknown 类型以兼容 AI SDK(可能是对象、字符串等) */ export const ToolStateErrorSchema = z.object({ status: z.literal('error'), - input: z.record(z.string(), z.unknown()), + input: z.unknown(), error: z.string(), time: z.object({ start: z.number(), end: z.number() }), }); diff --git a/packages/desktop/src/pages/Chat.tsx b/packages/desktop/src/pages/Chat.tsx index 8ab5bb1..64be9d7 100644 --- a/packages/desktop/src/pages/Chat.tsx +++ b/packages/desktop/src/pages/Chat.tsx @@ -9,7 +9,6 @@ import { toast } from 'sonner'; import { useChat, ChatMessage, - StreamingMessage, TypingIndicator, ChatInput, } from '@ai-assistant/ui'; @@ -46,7 +45,7 @@ export function ChatPage({ messages, isConnected, isLoading, - streamingContent, + streamingMessage, sendMessage, cancelProcessing, } = useChat({ @@ -73,7 +72,7 @@ export function ChatPage({ // 自动滚动到底部 useEffect(() => { messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); - }, [messages, streamingContent]); + }, [messages, streamingMessage]); // 空状态组件 const EmptyState = () => ( @@ -270,9 +269,12 @@ export function ChatPage({ ))} - {streamingContent && } + {/* 流式消息 - 复用 ChatMessage 组件 */} + {streamingMessage && ( + + )} - {isLoading && !streamingContent && } + {isLoading && !streamingMessage && }
diff --git a/packages/server/src/agent/adapter.ts b/packages/server/src/agent/adapter.ts index 031a2e1..86aa905 100644 --- a/packages/server/src/agent/adapter.ts +++ b/packages/server/src/agent/adapter.ts @@ -67,11 +67,33 @@ interface SessionManagerConstructor { new (): SessionManagerInstance; } +/** + * 工具开始信息 + */ +interface ToolStartInfo { + id: string; + toolName: string; + args: Record; +} + +/** + * 工具结束信息 + */ +interface ToolEndInfo { + id: string; + status: 'completed' | 'error'; + result?: unknown; + error?: string; + duration?: number; +} + /** * Chat 选项接口 */ interface ChatOptions { onStream?: (chunk: string) => void; + onToolStart?: (info: ToolStartInfo) => void; + onToolEnd?: (info: ToolEndInfo) => void; abortSignal?: AbortSignal; } @@ -397,7 +419,7 @@ export async function processMessage(sessionId: string, content: string): Promis payload: { content: chunk }, }); - // 检测工具调用 + // 检测工具调用(向后兼容 - SSE 日志) if (chunk.includes('[调用工具:')) { const match = chunk.match(/\[调用工具: (.+?)\]/); if (match) { @@ -405,6 +427,38 @@ export async function processMessage(sessionId: string, content: string): Promis } } }, + onToolStart: (info) => { + // 检查是否已取消 + if (abortController.signal.aborted) return; + + // 推送工具开始事件 + broadcastToSession(sessionId, { + type: 'tool_start', + sessionId, + payload: { + id: info.id, + toolName: info.toolName, + arguments: info.args, + }, + }); + }, + onToolEnd: (info) => { + // 检查是否已取消 + if (abortController.signal.aborted) return; + + // 推送工具结束事件 + broadcastToSession(sessionId, { + type: 'tool_end', + sessionId, + payload: { + id: info.id, + status: info.status, + result: info.result, + error: info.error, + duration: info.duration, + }, + }); + }, abortSignal: abortController.signal, }); diff --git a/packages/server/src/types.ts b/packages/server/src/types.ts index 7563bd9..776d254 100644 --- a/packages/server/src/types.ts +++ b/packages/server/src/types.ts @@ -110,6 +110,8 @@ export interface ServerMessage { | 'chunk' | 'tool_call' | 'tool_result' + | 'tool_start' // 工具开始执行 + | 'tool_end' // 工具执行完成 | 'done' | 'cancelled' | 'error' @@ -119,6 +121,22 @@ export interface ServerMessage { payload?: unknown; } +// 工具开始事件 Payload +export interface ToolStartPayload { + id: string; + toolName: string; + arguments: Record; +} + +// 工具结束事件 Payload +export interface ToolEndPayload { + id: string; + status: 'completed' | 'error'; + result?: unknown; + error?: string; + duration?: number; +} + // ============ Permission 相关 ============ export type PermissionType = 'bash' | 'file' | 'git' | 'web'; diff --git a/packages/ui/src/api/types.ts b/packages/ui/src/api/types.ts index c551728..14e2033 100644 --- a/packages/ui/src/api/types.ts +++ b/packages/ui/src/api/types.ts @@ -877,3 +877,29 @@ export interface FileSearchResponse { }; } +// ============ 流式工具调用事件 ============ + +/** 工具开始事件 Payload */ +export interface ToolStartPayload { + /** 工具调用唯一 ID */ + id: string; + /** 工具名称 */ + toolName: string; + /** 调用参数 */ + arguments: Record; +} + +/** 工具结束事件 Payload */ +export interface ToolEndPayload { + /** 对应 tool_start 的 ID */ + id: string; + /** 执行状态 */ + status: 'completed' | 'error'; + /** 执行结果 */ + result?: unknown; + /** 错误信息 */ + error?: string; + /** 执行时长 (ms) */ + duration?: number; +} + diff --git a/packages/ui/src/components/ChatMessage.tsx b/packages/ui/src/components/ChatMessage.tsx index ba27761..8760025 100644 --- a/packages/ui/src/components/ChatMessage.tsx +++ b/packages/ui/src/components/ChatMessage.tsx @@ -25,10 +25,12 @@ import type { Message, ToolCallInfo, ToolCallStatus, ToolMessagePart } from '../ interface ChatMessageProps { message: Message; + /** 是否为流式输出中(显示打字光标) */ + isStreaming?: boolean; } export const ChatMessage = forwardRef( - ({ message }, ref) => { + ({ message, isStreaming = false }, ref) => { const isUser = message.role === 'user'; const [copied, setCopied] = useState(false); @@ -42,18 +44,39 @@ export const ChatMessage = forwardRef( const renderContent = () => { // 优先使用 parts 数组(保持原始顺序) if (message.parts && message.parts.length > 0) { + // 查找最后一个文本 part 的索引(用于显示打字光标) + let lastTextPartIndex = -1; + if (isStreaming) { + for (let i = message.parts.length - 1; i >= 0; i--) { + if (message.parts[i].type === 'text') { + lastTextPartIndex = i; + break; + } + } + } + return (
- {message.parts.map((part) => { + {message.parts.map((part, index) => { switch (part.type) { case 'text': - if (!part.text) return null; + if (!part.text && index !== lastTextPartIndex) return null; return isUser ? (
) : ( - +
+ + {/* 流式输出时在最后一个文本末尾显示打字光标 */} + {isStreaming && index === lastTextPartIndex && ( + + )} +
); case 'tool': return ; diff --git a/packages/ui/src/hooks/useChat.ts b/packages/ui/src/hooks/useChat.ts index 4d9a94b..7fab05b 100644 --- a/packages/ui/src/hooks/useChat.ts +++ b/packages/ui/src/hooks/useChat.ts @@ -7,7 +7,13 @@ import { useState, useEffect, useCallback, useRef } from 'react'; import { createWebSocket, getMessages, type Message } from '../api/client.js'; import type { PermissionRequest } from '../components/PermissionDialog.js'; -import type { ConfigErrorPayload } from '../api/types.js'; +import type { + ConfigErrorPayload, + ToolStartPayload, + ToolEndPayload, + MessagePart, + ToolMessagePart, +} from '../api/types.js'; interface UseChatOptions { sessionId: string; @@ -22,7 +28,8 @@ interface ChatState { messages: Message[]; isConnected: boolean; isLoading: boolean; - streamingContent: string; + /** 流式消息对象,复用 Message 结构 */ + streamingMessage: Message | null; permissionRequest: PermissionRequest | null; } @@ -31,7 +38,7 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate messages: [], isConnected: false, isLoading: false, - streamingContent: '', + streamingMessage: null, permissionRequest: null, }); @@ -114,27 +121,136 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate const message = JSON.parse(event.data); switch (message.type) { - case 'chunk': - setState((prev) => ({ - ...prev, - streamingContent: prev.streamingContent + (message.payload?.content || ''), - })); + case 'chunk': { + const chunkContent = message.payload?.content || ''; + setState((prev) => { + // 初始化或获取当前流式消息 + const streaming = prev.streamingMessage || { + id: `streaming-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`, + role: 'assistant' as const, + timestamp: new Date().toISOString(), + parts: [] as MessagePart[], + content: '', + }; + + // 复制 parts 数组以进行修改 + const parts = [...streaming.parts]; + const lastPart = parts[parts.length - 1]; + + // 如果最后一个 part 是 text,追加内容;否则创建新 text part + if (lastPart?.type === 'text') { + parts[parts.length - 1] = { + ...lastPart, + text: lastPart.text + chunkContent, + }; + } else if (chunkContent) { + parts.push({ + type: 'text', + id: `text-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`, + text: chunkContent, + }); + } + + return { + ...prev, + streamingMessage: { + ...streaming, + parts, + content: (streaming.content || '') + chunkContent, + }, + }; + }); break; + } + + case 'tool_start': { + const payload = message.payload as ToolStartPayload; + setState((prev) => { + // 初始化或获取当前流式消息 + const streaming = prev.streamingMessage || { + id: `streaming-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`, + role: 'assistant' as const, + timestamp: new Date().toISOString(), + parts: [] as MessagePart[], + content: '', + }; + + // 添加工具调用 part + const toolPart: ToolMessagePart = { + type: 'tool', + id: payload.id, + toolCallId: payload.id, + toolName: payload.toolName, + status: 'running', + arguments: payload.arguments, + }; + + return { + ...prev, + streamingMessage: { + ...streaming, + parts: [...streaming.parts, toolPart], + }, + }; + }); + break; + } + + case 'tool_end': { + const payload = message.payload as ToolEndPayload; + setState((prev) => { + if (!prev.streamingMessage) return prev; + + // 查找并更新对应的工具 part + const parts = prev.streamingMessage.parts.map((part) => { + if (part.type === 'tool' && part.id === payload.id) { + return { + ...part, + status: payload.status, + result: payload.result, + error: payload.error, + duration: payload.duration, + } as ToolMessagePart; + } + return part; + }); + + return { + ...prev, + streamingMessage: { + ...prev.streamingMessage, + parts, + }, + }; + }); + break; + } case 'done': setState((prev) => { - const content = message.payload?.content || prev.streamingContent; - const newMessage: Message = { - id: message.payload?.id || `assistant-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`, - role: 'assistant', - timestamp: message.payload?.timestamp || new Date().toISOString(), - parts: [{ type: 'text', id: `text-${Date.now()}`, text: content }], - content, - }; + // 使用流式消息或创建新消息 + const streaming = prev.streamingMessage; + const content = message.payload?.content || streaming?.content || ''; + + const newMessage: Message = streaming + ? { + ...streaming, + id: message.payload?.id || streaming.id, + timestamp: message.payload?.timestamp || streaming.timestamp, + content, + } + : { + id: message.payload?.id || `assistant-${Date.now()}-${Math.random().toString(36).slice(2, 9)}`, + role: 'assistant', + timestamp: message.payload?.timestamp || new Date().toISOString(), + parts: [{ type: 'text', id: `text-${Date.now()}`, text: content }], + content, + }; + return { ...prev, messages: [...prev.messages, newMessage], - streamingContent: '', + streamingMessage: null, isLoading: false, }; }); @@ -165,7 +281,7 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate } else { onErrorRef.current?.(new Error(message.payload?.message || 'Unknown error')); } - setState((prev) => ({ ...prev, isLoading: false, streamingContent: '' })); + setState((prev) => ({ ...prev, isLoading: false, streamingMessage: null })); break; case 'session_updated': @@ -225,7 +341,7 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate }) ); - setState((prev) => ({ ...prev, isLoading: false, streamingContent: '' })); + setState((prev) => ({ ...prev, isLoading: false, streamingMessage: null })); }, [sessionId]); // 发送权限响应 @@ -274,7 +390,7 @@ export function useChat({ sessionId, onError, onSessionNotFound, onSessionUpdate messages: [], isConnected: false, isLoading: false, - streamingContent: '', + streamingMessage: null, permissionRequest: null, }); reconnectAttemptsRef.current = 0; diff --git a/packages/web/src/pages/Chat.tsx b/packages/web/src/pages/Chat.tsx index dfccab8..7369248 100644 --- a/packages/web/src/pages/Chat.tsx +++ b/packages/web/src/pages/Chat.tsx @@ -9,7 +9,6 @@ import { toast } from 'sonner'; import { useChat, ChatMessage, - StreamingMessage, TypingIndicator, ChatInput, PermissionDialog, @@ -52,7 +51,7 @@ export function ChatPage({ messages, isConnected, isLoading, - streamingContent, + streamingMessage, sendMessage, cancelProcessing, permissionRequest, @@ -83,7 +82,7 @@ export function ChatPage({ // 自动滚动到底部 useEffect(() => { messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' }); - }, [messages, streamingContent]); + }, [messages, streamingMessage]); // 空状态组件 const EmptyState = () => ( @@ -290,9 +289,12 @@ export function ChatPage({ ))} - {streamingContent && } + {/* 流式消息 - 复用 ChatMessage 组件 */} + {streamingMessage && ( + + )} - {isLoading && !streamingContent && } + {isLoading && !streamingMessage && }