Skip to content

Commit f4850e3

Browse files
committed
refactor(genie): extract pollWaiter to simplify _handleSendMessage
Replace manual concurrency code (statusQueue, notifyGenerator, waiterDone, waiterError, IIFE promise chain) with a reusable pollWaiter async generator that bridges callback-based waiter.wait({ onProgress }) into a for-await-of loop. Signed-off-by: Jorge Calvar <jorge.calvar@databricks.com>
1 parent 89c6a56 commit f4850e3

File tree

3 files changed

+249
-81
lines changed

3 files changed

+249
-81
lines changed

packages/appkit/src/plugins/genie/genie.ts

Lines changed: 32 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import { createLogger } from "../../logging";
1111
import { Plugin, toPlugin } from "../../plugin";
1212
import { genieStreamDefaults } from "./defaults";
1313
import { genieManifest } from "./manifest";
14+
import { pollWaiter } from "./poll-waiter";
1415
import type {
1516
GenieAttachmentResponse,
1617
GenieConversationHistoryResponse,
@@ -155,110 +156,60 @@ export class GeniePlugin extends Plugin {
155156
const workspaceClient = getWorkspaceClient();
156157

157158
try {
158-
// Status events queue bridging onProgress → generator
159-
const statusQueue: string[] = [];
160-
let notifyGenerator: () => void = () => {};
161-
let waiterDone = false;
162-
163-
const onProgress = async (message: GenieMessage): Promise<void> => {
164-
if (message.status) {
165-
statusQueue.push(message.status);
166-
notifyGenerator();
167-
}
168-
};
169-
170-
let resultConversationId = "";
171-
let resultMessageId = "";
172-
let completedMessage: GenieMessage =
173-
undefined as unknown as GenieMessage;
174-
let waiterError: Error | null = null;
175-
176-
// Launch Genie API call
177-
const waiterPromise = (async () => {
178-
let messageWaiter: CreateMessageWaiter;
179-
180-
if (conversationId) {
181-
messageWaiter = await workspaceClient.genie.createMessage({
159+
// Step 1: API call → get waiter + IDs
160+
let messageWaiter: CreateMessageWaiter;
161+
let resultConversationId: string;
162+
let resultMessageId: string;
163+
164+
if (conversationId) {
165+
messageWaiter = await workspaceClient.genie.createMessage({
166+
space_id: spaceId,
167+
conversation_id: conversationId,
168+
content,
169+
});
170+
resultConversationId = conversationId;
171+
resultMessageId = messageWaiter.message_id ?? "";
172+
} else {
173+
const startWaiter: StartConversationWaiter =
174+
await workspaceClient.genie.startConversation({
182175
space_id: spaceId,
183-
conversation_id: conversationId,
184176
content,
185177
});
186-
resultConversationId = conversationId;
187-
} else {
188-
const startWaiter: StartConversationWaiter =
189-
await workspaceClient.genie.startConversation({
190-
space_id: spaceId,
191-
content,
192-
});
193-
resultConversationId = startWaiter.conversation_id;
194-
resultMessageId = startWaiter.message_id;
195-
messageWaiter = startWaiter as unknown as CreateMessageWaiter;
196-
}
197-
198-
const result = await messageWaiter.wait({ onProgress });
199-
completedMessage = result;
200-
resultMessageId = result.message_id;
201-
return result;
202-
})()
203-
.catch((err: Error) => {
204-
waiterError = err;
205-
})
206-
.finally(() => {
207-
waiterDone = true;
208-
notifyGenerator();
209-
});
210-
211-
// Wait for first status or waiter completion to get IDs
212-
await new Promise<void>((resolve) => {
213-
notifyGenerator = resolve;
214-
if (waiterDone) resolve();
215-
});
216-
217-
// If the API call failed before anything started, yield error and exit
218-
if (waiterError) {
219-
throw waiterError;
178+
resultConversationId = startWaiter.conversation_id;
179+
resultMessageId = startWaiter.message_id;
180+
messageWaiter = startWaiter as unknown as CreateMessageWaiter;
220181
}
221182

222-
// Yield message_start
183+
// Step 2: Yield message_start immediately — IDs are available from API response
223184
yield {
224185
type: "message_start" as const,
225186
conversationId: resultConversationId,
226187
messageId: resultMessageId,
227188
spaceId,
228189
};
229190

230-
// Drain status events
231-
while (!waiterDone || statusQueue.length > 0) {
232-
while (statusQueue.length > 0) {
233-
const status = statusQueue.shift();
234-
if (status) {
235-
yield { type: "status" as const, status };
191+
// Step 3: Poll for status updates and completion
192+
let completedMessage!: GenieMessage;
193+
for await (const event of pollWaiter(messageWaiter)) {
194+
if (event.type === "progress") {
195+
if (event.value.status) {
196+
yield { type: "status" as const, status: event.value.status };
236197
}
198+
} else {
199+
completedMessage = event.value;
200+
resultMessageId = event.value.message_id;
237201
}
238-
239-
if (!waiterDone) {
240-
await new Promise<void>((resolve) => {
241-
notifyGenerator = resolve;
242-
if (waiterDone) resolve();
243-
});
244-
}
245-
}
246-
247-
// Check if waiter failed during polling
248-
await waiterPromise;
249-
if (waiterError) {
250-
throw waiterError;
251202
}
252203

253-
// Build cleaned message response
204+
// Step 4: Build cleaned message response
254205
const messageResponse = toMessageResponse(completedMessage);
255206

256207
yield {
257208
type: "message_result" as const,
258209
message: messageResponse,
259210
};
260211

261-
// Fetch query results for each query attachment
212+
// Step 5: Fetch query results for each query attachment
262213
const attachments = messageResponse.attachments ?? [];
263214
for (const att of attachments) {
264215
if (att.query?.statementId && att.attachmentId) {
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/**
2+
* Structural interface matching the SDK's `Waiter.wait()` shape
3+
* without importing the SDK directly.
4+
*/
5+
export interface Pollable<P> {
6+
wait(options?: {
7+
onProgress?: (p: P) => Promise<void>;
8+
timeout?: unknown;
9+
}): Promise<P>;
10+
}
11+
12+
export type PollEvent<P> =
13+
| { type: "progress"; value: P }
14+
| { type: "completed"; value: P };
15+
16+
/**
17+
* Bridges a callback-based waiter into an async generator.
18+
*
19+
* The SDK's `waiter.wait({ onProgress })` API uses a callback to report
20+
* progress and returns a promise that resolves with the final result.
21+
* This function converts that push-based model into a pull-based async
22+
* generator so callers can simply `for await (const event of pollWaiter(w))`.
23+
*
24+
* Yields `{ type: "progress", value }` for each `onProgress` callback,
25+
* then `{ type: "completed", value }` for the final result.
26+
* Throws if the waiter rejects.
27+
*/
28+
export async function* pollWaiter<P>(
29+
waiter: Pollable<P>,
30+
options?: { timeout?: unknown },
31+
): AsyncGenerator<PollEvent<P>> {
32+
// --- shared state between the onProgress callback and the generator loop ---
33+
const queue: P[] = []; // progress values waiting to be yielded
34+
let notify: () => void = () => {}; // resolves the generator's "sleep" promise
35+
let done = false; // true once waiter.wait() settles (success or error)
36+
let result!: P;
37+
let error: unknown = null;
38+
39+
// Start the waiter in the background (not awaited — runs concurrently
40+
// with the generator loop below). The onProgress callback pushes values
41+
// into the queue and wakes the generator via notify().
42+
waiter
43+
.wait({
44+
onProgress: async (p: P) => {
45+
queue.push(p);
46+
notify();
47+
},
48+
...(options?.timeout != null ? { timeout: options.timeout } : {}),
49+
})
50+
.then((r) => {
51+
result = r;
52+
done = true;
53+
notify();
54+
})
55+
.catch((err) => {
56+
error = err;
57+
done = true;
58+
notify();
59+
});
60+
61+
// Drain progress events as they arrive. The loop exits once the waiter
62+
// has settled AND the queue is empty.
63+
while (!done || queue.length > 0) {
64+
// Yield all queued progress values before sleeping.
65+
while (queue.length > 0) {
66+
const value = queue.shift() as P;
67+
yield { type: "progress", value };
68+
}
69+
70+
// Nothing in the queue yet and the waiter hasn't settled — sleep until
71+
// the next onProgress call or waiter settlement wakes us via notify().
72+
//
73+
// Race-condition guard: after setting `notify = resolve`, we re-check
74+
// `done` and `queue.length`. If either changed between the outer while
75+
// check and this point (possible via microtask), we resolve immediately
76+
// so the loop doesn't hang.
77+
if (!done) {
78+
await new Promise<void>((resolve) => {
79+
notify = resolve;
80+
if (done || queue.length > 0) resolve();
81+
});
82+
}
83+
}
84+
85+
// The waiter settled. If it rejected, propagate the error.
86+
if (error !== null) {
87+
throw error;
88+
}
89+
90+
// Final event: the completed result from waiter.wait().
91+
yield { type: "completed", value: result };
92+
}
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import { describe, expect, test, vi } from "vitest";
2+
import { type Pollable, type PollEvent, pollWaiter } from "../poll-waiter";
3+
4+
function createMockWaiter<P>(opts: {
5+
progressValues?: P[];
6+
result: P;
7+
error?: Error;
8+
delay?: number;
9+
}): Pollable<P> {
10+
return {
11+
wait: vi.fn().mockImplementation(async (options: any = {}) => {
12+
if (opts.progressValues) {
13+
for (const value of opts.progressValues) {
14+
if (opts.delay) {
15+
await new Promise((r) => setTimeout(r, opts.delay));
16+
}
17+
if (options.onProgress) {
18+
await options.onProgress(value);
19+
}
20+
}
21+
}
22+
if (opts.error) throw opts.error;
23+
return opts.result;
24+
}),
25+
};
26+
}
27+
28+
async function collect<P>(
29+
gen: AsyncGenerator<PollEvent<P>>,
30+
): Promise<PollEvent<P>[]> {
31+
const events: PollEvent<P>[] = [];
32+
for await (const event of gen) {
33+
events.push(event);
34+
}
35+
return events;
36+
}
37+
38+
describe("pollWaiter", () => {
39+
test("yields progress events then completed", async () => {
40+
const waiter = createMockWaiter({
41+
progressValues: [{ status: "A" }, { status: "B" }],
42+
result: { status: "DONE" },
43+
});
44+
45+
const events = await collect(pollWaiter(waiter));
46+
47+
expect(events).toEqual([
48+
{ type: "progress", value: { status: "A" } },
49+
{ type: "progress", value: { status: "B" } },
50+
{ type: "completed", value: { status: "DONE" } },
51+
]);
52+
});
53+
54+
test("yields only completed when no progress events", async () => {
55+
const waiter = createMockWaiter({
56+
result: { value: 42 },
57+
});
58+
59+
const events = await collect(pollWaiter(waiter));
60+
61+
expect(events).toEqual([{ type: "completed", value: { value: 42 } }]);
62+
});
63+
64+
test("throws when waiter rejects", async () => {
65+
const waiter = createMockWaiter({
66+
result: null as any,
67+
error: new Error("boom"),
68+
});
69+
70+
const events: PollEvent<any>[] = [];
71+
await expect(async () => {
72+
for await (const event of pollWaiter(waiter)) {
73+
events.push(event);
74+
}
75+
}).rejects.toThrow("boom");
76+
77+
expect(events).toEqual([]);
78+
});
79+
80+
test("throws after yielding progress if waiter fails mid-poll", async () => {
81+
const waiter = createMockWaiter({
82+
progressValues: [{ status: "A" }],
83+
result: null as any,
84+
error: new Error("mid-poll failure"),
85+
});
86+
87+
const events: PollEvent<any>[] = [];
88+
await expect(async () => {
89+
for await (const event of pollWaiter(waiter)) {
90+
events.push(event);
91+
}
92+
}).rejects.toThrow("mid-poll failure");
93+
94+
expect(events).toEqual([{ type: "progress", value: { status: "A" } }]);
95+
});
96+
97+
test("handles async delays between progress callbacks", async () => {
98+
const waiter = createMockWaiter({
99+
progressValues: [{ n: 1 }, { n: 2 }, { n: 3 }],
100+
result: { n: 99 },
101+
delay: 10,
102+
});
103+
104+
const events = await collect(pollWaiter(waiter));
105+
106+
expect(events).toHaveLength(4);
107+
expect(events[0]).toEqual({ type: "progress", value: { n: 1 } });
108+
expect(events[1]).toEqual({ type: "progress", value: { n: 2 } });
109+
expect(events[2]).toEqual({ type: "progress", value: { n: 3 } });
110+
expect(events[3]).toEqual({ type: "completed", value: { n: 99 } });
111+
});
112+
113+
test("passes timeout option through to waiter.wait()", async () => {
114+
const waiter = createMockWaiter({
115+
result: { done: true },
116+
});
117+
118+
const timeoutValue = { ms: 5000 };
119+
await collect(pollWaiter(waiter, { timeout: timeoutValue }));
120+
121+
expect(waiter.wait).toHaveBeenCalledWith(
122+
expect.objectContaining({ timeout: timeoutValue }),
123+
);
124+
});
125+
});

0 commit comments

Comments
 (0)