Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/ripe-buses-punch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@effect/ai": patch
---

Fix the accumulation logic for response parts in the AI `Chat` module
17 changes: 10 additions & 7 deletions packages/ai/ai/src/Chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import type { PersistenceBackingError } from "@effect/experimental/Persistence"
import { BackingPersistence } from "@effect/experimental/Persistence"
import * as Channel from "effect/Channel"
import * as Chunk from "effect/Chunk"
import * as Context from "effect/Context"
import * as Duration from "effect/Duration"
import * as Effect from "effect/Effect"
Expand Down Expand Up @@ -379,24 +380,26 @@ export const empty: Effect.Effect<Service> = Effect.gen(function*() {
),
streamText: Effect.fnUntraced(
function*(options) {
let combined: Prompt.Prompt = Prompt.empty
let parts = Chunk.empty<Response.AnyPart>()
return Stream.fromChannel(Channel.acquireUseRelease(
semaphore.take(1).pipe(
Effect.zipRight(Ref.get(history)),
Effect.map((history) => Prompt.merge(history, Prompt.make(options.prompt)))
),
(prompt) =>
LanguageModel.streamText({ ...options, prompt }).pipe(
Stream.mapChunksEffect(Effect.fnUntraced(function*(chunk) {
const parts = Array.from(chunk)
combined = Prompt.merge(combined, Prompt.fromResponseParts(parts))
Stream.mapChunks((chunk) => {
parts = Chunk.appendAll(parts, chunk)
return chunk
})),
}),
Stream.toChannel
),
(parts) =>
(prompt) =>
Effect.zipRight(
Ref.set(history, Prompt.merge(parts, combined)),
Ref.set(
history,
Prompt.merge(prompt, Prompt.fromResponseParts(Array.from(parts)))
),
semaphore.release(1)
)
)).pipe(
Expand Down
165 changes: 68 additions & 97 deletions packages/ai/ai/src/Prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1511,38 +1511,6 @@ export const make = (input: RawInput): Prompt => {
*/
export const fromMessages = (messages: ReadonlyArray<Message>): Prompt => {
  // Hand the message list straight to the internal `makePrompt` constructor.
  return makePrompt(messages)
}

// Lookup table keyed by every `Response.AnyPart["type"]`. The boolean stored
// for a type is exactly what `isValidPart` returns for a part of that type.
// `as const` keeps each flag as a literal `true` / `false`, and `satisfies`
// makes the compiler reject a missing or misspelled part type.
const VALID_RESPONSE_PART_MAP = {
  // Response metadata
  "response-metadata": false,
  // Text parts
  text: true,
  "text-start": false,
  "text-delta": true,
  "text-end": false,
  // Reasoning parts
  reasoning: true,
  "reasoning-start": false,
  "reasoning-delta": true,
  "reasoning-end": false,
  // File and source parts
  file: false,
  source: false,
  // Tool parameter streaming parts
  "tool-params-start": false,
  "tool-params-delta": false,
  "tool-params-end": false,
  // Tool call / result parts
  "tool-call": true,
  "tool-result": true,
  // Terminal parts
  finish: false,
  error: false
} as const satisfies Record<Response.AnyPart["type"], boolean>

/** Type of the `VALID_RESPONSE_PART_MAP` lookup table. */
type ValidResponseParts = typeof VALID_RESPONSE_PART_MAP

/**
 * Union of the `Response.AnyPart` members whose `type` is flagged `true` in
 * `VALID_RESPONSE_PART_MAP`; types flagged `false` contribute `never`.
 */
type ValidResponsePart = {
  [Type in keyof ValidResponseParts]: ValidResponseParts[Type] extends true
    ? Extract<Response.AnyPart, { type: Type }>
    : never
}[keyof ValidResponseParts]

/**
 * Type guard that looks up `part.type` in `VALID_RESPONSE_PART_MAP` and
 * narrows `part` to `ValidResponsePart` when the stored flag is `true`.
 */
const isValidPart = (part: Response.AnyPart): part is ValidResponsePart =>
  VALID_RESPONSE_PART_MAP[part.type]

/**
* Creates a Prompt from the response parts of a previous interaction with a
* large language model.
Expand Down Expand Up @@ -1590,93 +1558,96 @@ export const fromResponseParts = (parts: ReadonlyArray<Response.AnyPart>): Promp
const assistantParts: Array<AssistantMessagePart> = []
const toolParts: Array<ToolMessagePart> = []

const textDeltas: Array<string> = []
function flushTextDeltas() {
if (textDeltas.length > 0) {
const text = textDeltas.join("")
if (text.length > 0) {
assistantParts.push(makePart("text", { text }))
}
textDeltas.length = 0
}
}
const activeTextDeltas = new Map<string, { text: string }>()
const activeReasoningDeltas = new Map<string, { text: string }>()

const reasoningDeltas: Array<string> = []
function flushReasoningDeltas() {
if (reasoningDeltas.length > 0) {
const text = reasoningDeltas.join("")
if (text.length > 0) {
assistantParts.push(makePart("reasoning", { text }))
for (const part of parts) {
switch (part.type) {
// Text Parts
case "text": {
assistantParts.push(makePart("text", { text: part.text }))
break
}
reasoningDeltas.length = 0
}
}

function flushDeltas() {
flushTextDeltas()
flushReasoningDeltas()
}

for (const part of parts) {
if (isValidPart(part)) {
switch (part.type) {
case "text": {
flushDeltas()
assistantParts.push(makePart("text", { text: part.text }))
break
}
case "text-delta": {
flushReasoningDeltas()
textDeltas.push(part.delta)
break
}
case "reasoning": {
flushDeltas()
assistantParts.push(makePart("reasoning", { text: part.text }))
break
// Text Parts (streaming)
case "text-start": {
activeTextDeltas.set(part.id, { text: "" })
break
}
case "text-delta": {
if (activeTextDeltas.has(part.id)) {
activeTextDeltas.get(part.id)!.text += part.delta
}
case "reasoning-delta": {
flushTextDeltas()
reasoningDeltas.push(part.delta)
break
break
}
case "text-end": {
if (activeTextDeltas.has(part.id)) {
assistantParts.push(makePart("text", activeTextDeltas.get(part.id)!))
}
case "tool-call": {
flushDeltas()
assistantParts.push(makePart("tool-call", {
id: part.id,
name: part.providerName ?? part.name,
params: part.params,
providerExecuted: part.providerExecuted ?? false
}))
break
break
}

// Reasoning Parts
case "reasoning": {
assistantParts.push(makePart("reasoning", { text: part.text }))
break
}

// Reasoning Parts (streaming)
case "reasoning-start": {
activeReasoningDeltas.set(part.id, { text: "" })
break
}
case "reasoning-delta": {
if (activeReasoningDeltas.has(part.id)) {
activeReasoningDeltas.get(part.id)!.text += part.delta
}
case "tool-result": {
flushDeltas()
toolParts.push(makePart("tool-result", {
id: part.id,
name: part.providerName ?? part.name,
isFailure: part.isFailure,
result: part.encodedResult
}))
break
break
}
case "reasoning-end": {
if (activeReasoningDeltas.has(part.id)) {
assistantParts.push(makePart("reasoning", activeReasoningDeltas.get(part.id)!))
}
break
}

// Tool Call Parts
case "tool-call": {
assistantParts.push(makePart("tool-call", {
id: part.id,
name: part.providerName ?? part.name,
params: part.params,
providerExecuted: part.providerExecuted ?? false
}))
break
}

// Tool Result Parts
case "tool-result": {
toolParts.push(makePart("tool-result", {
id: part.id,
name: part.providerName ?? part.name,
isFailure: part.isFailure,
result: part.encodedResult
}))
}
}
}

flushDeltas()

if (assistantParts.length === 0 && toolParts.length === 0) {
return empty
}

const messages: Array<Message> = []

if (assistantParts.length > 0) {
messages.push(makeMessage("assistant", { content: assistantParts }))
}

if (toolParts.length > 0) {
messages.push(makeMessage("tool", { content: toolParts }))
}

return makePrompt(messages)
}

Expand Down