fix(storage): 修复 multi-step 对话产生多个 assistant 消息问题

- 同一轮对话的后续 assistant 消息追加 Parts 到现有消息
- 修复 AI SDK tool-call 参数字段名(input 而非 args)
- 修复 setToolCompleted/setToolError 中 discriminated union 类型访问
This commit is contained in:
2025-12-15 13:58:10 +08:00
parent 9f456c1029
commit d54b9788fa
2 changed files with 42 additions and 18 deletions
+32 -16
View File
@@ -397,39 +397,55 @@ export class SessionManager {
toolCallPartIds.clear(); toolCallPartIds.clear();
} else if (message.role === 'assistant') { } else if (message.role === 'assistant') {
// Assistant 消息 // Assistant 消息:如果当前轮次已有 assistant 消息,则追加 Parts
const messageInfo = await MessageStorage.create(sessionId, 'assistant', { let messageId: string;
parentId: currentUserMsgId ?? undefined, let existingPartIds: string[] = [];
});
currentAssistantMsgId = messageInfo.id; if (currentAssistantMsgId) {
const partIds: string[] = []; // 同一轮对话的后续 assistant 消息,追加到现有消息
messageId = currentAssistantMsgId;
const existingMsg = await MessageStorage.get(sessionId, messageId);
existingPartIds = existingMsg?.partIds ?? [];
} else {
// 新的 assistant 消息
const messageInfo = await MessageStorage.create(sessionId, 'assistant', {
parentId: currentUserMsgId ?? undefined,
});
messageId = messageInfo.id;
currentAssistantMsgId = messageId;
}
const newPartIds: string[] = [];
if (typeof message.content === 'string') { if (typeof message.content === 'string') {
const part = await PartStorage.createText(messageInfo.id, message.content); const part = await PartStorage.createText(messageId, message.content);
partIds.push(part.id); newPartIds.push(part.id);
} else if (Array.isArray(message.content)) { } else if (Array.isArray(message.content)) {
for (const item of message.content) { for (const item of message.content) {
const itemType = (item as { type: string }).type; const itemType = (item as { type: string }).type;
if (itemType === 'text') { if (itemType === 'text') {
const part = await PartStorage.createText(messageInfo.id, (item as { text: string }).text); const part = await PartStorage.createText(messageId, (item as { text: string }).text);
partIds.push(part.id); newPartIds.push(part.id);
} else if (itemType === 'tool-call') { } else if (itemType === 'tool-call') {
const toolCall = item as unknown as { toolCallId: string; toolName: string; args: Record<string, unknown> }; // AI SDK 的 tool-call 使用 input 字段存储参数(不是 args)
const toolCall = item as unknown as { toolCallId: string; toolName: string; input: Record<string, unknown> };
// 创建 running 状态的工具 Part // 创建 running 状态的工具 Part
const part = await PartStorage.createToolRunning( const part = await PartStorage.createToolRunning(
messageInfo.id, messageId,
toolCall.toolCallId, toolCall.toolCallId,
toolCall.toolName, toolCall.toolName,
toolCall.args ?? {} (toolCall.input as Record<string, unknown>) ?? {}
); );
partIds.push(part.id); newPartIds.push(part.id);
toolCallPartIds.set(toolCall.toolCallId, part.id); toolCallPartIds.set(toolCall.toolCallId, part.id);
} }
} }
} }
if (partIds.length > 0) { if (newPartIds.length > 0) {
await MessageStorage.update(sessionId, messageInfo.id, { partIds }); // 合并已有的和新的 partIds
const allPartIds = [...existingPartIds, ...newPartIds];
await MessageStorage.update(sessionId, messageId, { partIds: allPartIds });
} }
} else if (message.role === 'tool' && currentAssistantMsgId) { } else if (message.role === 'tool' && currentAssistantMsgId) {
+10 -2
View File
@@ -98,7 +98,11 @@ export async function setToolCompleted(
): Promise<ToolPart> { ): Promise<ToolPart> {
// 先获取当前 part 以获取 input // 先获取当前 part 以获取 input
const part = await get(messageId, partId) as ToolPart | null; const part = await get(messageId, partId) as ToolPart | null;
const input = part?.state.status !== 'pending' ? part?.state.input : {}; // 从 running 状态获取 input(需要类型断言因为 discriminated union
const state = part?.state;
const input = state && state.status !== 'pending'
? (state as { input: Record<string, unknown> }).input
: {};
return updateToolState(messageId, partId, { return updateToolState(messageId, partId, {
status: 'completed', status: 'completed',
@@ -119,7 +123,11 @@ export async function setToolError(
): Promise<ToolPart> { ): Promise<ToolPart> {
// 先获取当前 part 以获取 input // 先获取当前 part 以获取 input
const part = await get(messageId, partId) as ToolPart | null; const part = await get(messageId, partId) as ToolPart | null;
const input = part?.state.status !== 'pending' ? part?.state.input : {}; // 从 running 状态获取 input(需要类型断言因为 discriminated union
const state = part?.state;
const input = state && state.status !== 'pending'
? (state as { input: Record<string, unknown> }).input
: {};
return updateToolState(messageId, partId, { return updateToolState(messageId, partId, {
status: 'error', status: 'error',