From d54b9788fa238ef0d0d9351a8b2cd7b97c6afcf1 Mon Sep 17 00:00:00 2001 From: kurihada Date: Mon, 15 Dec 2025 13:58:10 +0800 Subject: [PATCH] =?UTF-8?q?fix(storage):=20=E4=BF=AE=E5=A4=8D=20multi-step?= =?UTF-8?q?=20=E5=AF=B9=E8=AF=9D=E4=BA=A7=E7=94=9F=E5=A4=9A=E4=B8=AA=20ass?= =?UTF-8?q?istant=20=E6=B6=88=E6=81=AF=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 同一轮对话的后续 assistant 消息追加 Parts 到现有消息 - 修复 AI SDK tool-call 参数字段名(input 而非 args) - 修复 setToolCompleted/setToolError 中 discriminated union 类型访问 --- packages/core/src/session/manager.ts | 48 +++++++++++++++-------- packages/core/src/session/storage/part.ts | 12 +++++- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/packages/core/src/session/manager.ts b/packages/core/src/session/manager.ts index 2d5989e..faf6f7c 100644 --- a/packages/core/src/session/manager.ts +++ b/packages/core/src/session/manager.ts @@ -397,39 +397,55 @@ export class SessionManager { toolCallPartIds.clear(); } else if (message.role === 'assistant') { - // Assistant 消息 - const messageInfo = await MessageStorage.create(sessionId, 'assistant', { - parentId: currentUserMsgId ?? undefined, - }); - currentAssistantMsgId = messageInfo.id; - const partIds: string[] = []; + // Assistant 消息:如果当前轮次已有 assistant 消息,则追加 Parts + let messageId: string; + let existingPartIds: string[] = []; + + if (currentAssistantMsgId) { + // 同一轮对话的后续 assistant 消息,追加到现有消息 + messageId = currentAssistantMsgId; + const existingMsg = await MessageStorage.get(sessionId, messageId); + existingPartIds = existingMsg?.partIds ?? []; + } else { + // 新的 assistant 消息 + const messageInfo = await MessageStorage.create(sessionId, 'assistant', { + parentId: currentUserMsgId ?? undefined, + }); + messageId = messageInfo.id; + currentAssistantMsgId = messageId; + } + + const newPartIds: string[] = []; if (typeof message.content === 'string') { - const part = await PartStorage.createText(messageInfo.id, message.content); - partIds.push(part.id); + const part = await PartStorage.createText(messageId, message.content); + newPartIds.push(part.id); } else if (Array.isArray(message.content)) { for (const item of message.content) { const itemType = (item as { type: string }).type; if (itemType === 'text') { - const part = await PartStorage.createText(messageInfo.id, (item as { text: string }).text); - partIds.push(part.id); + const part = await PartStorage.createText(messageId, (item as { text: string }).text); + newPartIds.push(part.id); } else if (itemType === 'tool-call') { - const toolCall = item as unknown as { toolCallId: string; toolName: string; args: Record }; + // AI SDK 的 tool-call 使用 input 字段存储参数(不是 args) + const toolCall = item as unknown as { toolCallId: string; toolName: string; input: Record }; // 创建 running 状态的工具 Part const part = await PartStorage.createToolRunning( - messageInfo.id, + messageId, toolCall.toolCallId, toolCall.toolName, - toolCall.args ?? {} + (toolCall.input as Record) ?? {} ); - partIds.push(part.id); + newPartIds.push(part.id); toolCallPartIds.set(toolCall.toolCallId, part.id); } } } - if (partIds.length > 0) { - await MessageStorage.update(sessionId, messageInfo.id, { partIds }); + if (newPartIds.length > 0) { + // 合并已有的和新的 partIds + const allPartIds = [...existingPartIds, ...newPartIds]; + await MessageStorage.update(sessionId, messageId, { partIds: allPartIds }); } } else if (message.role === 'tool' && currentAssistantMsgId) { diff --git a/packages/core/src/session/storage/part.ts b/packages/core/src/session/storage/part.ts index 6fd02b3..7ecf57d 100644 --- a/packages/core/src/session/storage/part.ts +++ b/packages/core/src/session/storage/part.ts @@ -98,7 +98,11 @@ export async function setToolCompleted( ): Promise { // 先获取当前 part 以获取 input const part = await get(messageId, partId) as ToolPart | null; - const input = part?.state.status !== 'pending' ? part?.state.input : {}; + // 从 running 状态获取 input(需要类型断言因为 discriminated union) + const state = part?.state; + const input = state && state.status !== 'pending' + ? (state as { input: Record }).input + : {}; return updateToolState(messageId, partId, { status: 'completed', @@ -119,7 +123,11 @@ export async function setToolError( ): Promise { // 先获取当前 part 以获取 input const part = await get(messageId, partId) as ToolPart | null; - const input = part?.state.status !== 'pending' ? part?.state.input : {}; + // 从 running 状态获取 input(需要类型断言因为 discriminated union) + const state = part?.state; + const input = state && state.status !== 'pending' + ? (state as { input: Record }).input + : {}; return updateToolState(messageId, partId, { status: 'error',