Skip to content

Commit 50518df

Browse files
committed
feat(rime): add WebSocket streaming support
Adds opt-in WS streaming to the Rime TTS plugin via use_websocket=True. The pattern mirrors the Cartesia plugin: single-context JSON+base64 WS, ConnectionPool with mark_refreshed_on_get=True, blingfire sentence tokenizer, weakref.WeakSet for stream cleanup.

- New SynthesizeStream class with input/send/recv task split
- _connect_ws / _close_ws (eos shutdown, mirrors Deepgram)
- _model_params helper consolidates the arcana/mist option-walking shared between the WS query string and the HTTP body
- update_options invalidates the pool when the WS URL changes, computed via before/after _ws_url() diff
- Capabilities flips streaming and aligned_transcript on with the flag
- Routes to /ws3 only (mistv2 stays HTTP-only)
1 parent 8283a5a commit 50518df

2 files changed

Lines changed: 229 additions & 29 deletions

File tree

livekit-plugins/livekit-plugins-rime/livekit/plugins/rime/tts.py

Lines changed: 222 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,22 @@
1515
from __future__ import annotations
1616

1717
import asyncio
18+
import base64
19+
import json
1820
import os
21+
import weakref
1922
from dataclasses import dataclass, replace
23+
from urllib.parse import urlencode
2024

2125
import aiohttp
2226

2327
from livekit.agents import (
2428
APIConnectionError,
2529
APIConnectOptions,
30+
APIError,
2631
APIStatusError,
2732
APITimeoutError,
33+
tokenize,
2834
tts,
2935
utils,
3036
)
@@ -34,6 +40,7 @@
3440
NotGivenOr,
3541
)
3642
from livekit.agents.utils import is_given
43+
from livekit.agents.voice.io import TimedString
3744

3845
from .langs import TTSLangs
3946
from .log import logger
@@ -43,6 +50,8 @@
4350
ARCANA_MODEL_TIMEOUT = 60 * 4
4451
MIST_MODEL_TIMEOUT = 30
4552
RIME_BASE_URL = "https://users.rime.ai/v1/rime-tts"
53+
RIME_WS_BASE_URL = "wss://users-ws.rime.ai"
54+
NUM_CHANNELS = 1
4655

4756

4857
@dataclass
@@ -73,9 +82,6 @@ class _MistOptions:
7382
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN
7483

7584

76-
NUM_CHANNELS = 1
77-
78-
7985
def _is_mist_model(model: TTSModels | str) -> bool:
8086
return "mist" in model
8187

@@ -86,11 +92,40 @@ def _timeout_for_model(model: TTSModels | str) -> int:
8692
return MIST_MODEL_TIMEOUT
8793

8894

95+
def _model_params(opts: _TTSOptions) -> dict[str, object]:
96+
"""Per-model option fields shared between the HTTP body and the WS query string."""
97+
params: dict[str, object] = {}
98+
if opts.model == "arcana" and opts.arcana_options is not None:
99+
ao = opts.arcana_options
100+
if is_given(ao.lang):
101+
params["lang"] = ao.lang
102+
if is_given(ao.repetition_penalty):
103+
params["repetition_penalty"] = ao.repetition_penalty
104+
if is_given(ao.temperature):
105+
params["temperature"] = ao.temperature
106+
if is_given(ao.top_p):
107+
params["top_p"] = ao.top_p
108+
if is_given(ao.max_tokens):
109+
params["max_tokens"] = ao.max_tokens
110+
elif _is_mist_model(opts.model) and opts.mist_options is not None:
111+
mo = opts.mist_options
112+
if is_given(mo.lang):
113+
params["lang"] = mo.lang
114+
if is_given(mo.speed_alpha):
115+
params["speedAlpha"] = mo.speed_alpha
116+
if is_given(mo.pause_between_brackets):
117+
params["pauseBetweenBrackets"] = mo.pause_between_brackets
118+
if is_given(mo.phonemize_between_brackets):
119+
params["phonemizeBetweenBrackets"] = mo.phonemize_between_brackets
120+
return params
121+
122+
89123
class TTS(tts.TTS):
90124
def __init__(
91125
self,
92126
*,
93127
base_url: str = RIME_BASE_URL,
128+
ws_base_url: str = RIME_WS_BASE_URL,
94129
model: TTSModels | str = "arcana",
95130
speaker: NotGivenOr[ArcanaVoices | str] = NOT_GIVEN,
96131
lang: TTSLangs | str = "eng",
@@ -107,10 +142,14 @@ def __init__(
107142
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
108143
api_key: NotGivenOr[str] = NOT_GIVEN,
109144
http_session: aiohttp.ClientSession | None = None,
145+
use_websocket: bool = False,
146+
segment: NotGivenOr[str] = NOT_GIVEN,
147+
tokenizer: NotGivenOr[tokenize.SentenceTokenizer] = NOT_GIVEN,
110148
) -> None:
111149
super().__init__(
112150
capabilities=tts.TTSCapabilities(
113-
streaming=False,
151+
streaming=use_websocket,
152+
aligned_transcript=use_websocket,
114153
),
115154
sample_rate=sample_rate,
116155
num_channels=NUM_CHANNELS,
@@ -148,9 +187,23 @@ def __init__(
148187
)
149188
self._session = http_session
150189
self._base_url = base_url
190+
self._ws_base_url = ws_base_url
191+
self._use_websocket = use_websocket
192+
self._segment = segment if is_given(segment) else "bySentence"
151193

152194
self._total_timeout = _timeout_for_model(model)
153195

196+
self._streams: weakref.WeakSet[SynthesizeStream] = weakref.WeakSet()
197+
self._sentence_tokenizer = (
198+
tokenizer if is_given(tokenizer) else tokenize.blingfire.SentenceTokenizer()
199+
)
200+
self._pool = utils.ConnectionPool[aiohttp.ClientWebSocketResponse](
201+
connect_cb=self._connect_ws,
202+
close_cb=self._close_ws,
203+
max_session_duration=300,
204+
mark_refreshed_on_get=True,
205+
)
206+
154207
@property
155208
def model(self) -> str:
156209
return self._opts.model
@@ -165,6 +218,58 @@ def _ensure_session(self) -> aiohttp.ClientSession:
165218

166219
return self._session
167220

221+
def _ws_url(self) -> str:
222+
params: dict[str, object] = {
223+
"speaker": self._opts.speaker,
224+
"modelId": self._opts.model,
225+
"audioFormat": "pcm",
226+
"samplingRate": self._sample_rate,
227+
"segment": self._segment,
228+
**_model_params(self._opts),
229+
}
230+
return f"{self._ws_base_url}/ws3?{urlencode(params)}"
231+
232+
async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
233+
session = self._ensure_session()
234+
return await asyncio.wait_for(
235+
session.ws_connect(
236+
self._ws_url(), headers={"Authorization": f"Bearer {self._api_key}"}
237+
),
238+
timeout,
239+
)
240+
241+
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
242+
try:
243+
await ws.send_str(json.dumps({"operation": "eos"}))
244+
try:
245+
await asyncio.wait_for(ws.receive(), timeout=1.0)
246+
except asyncio.TimeoutError:
247+
pass
248+
except Exception as e:
249+
logger.warning(f"Error during Rime WS close sequence: {e}")
250+
finally:
251+
await ws.close()
252+
253+
def prewarm(self) -> None:
254+
self._pool.prewarm()
255+
256+
def stream(
257+
self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
258+
) -> SynthesizeStream:
259+
if not self._use_websocket:
260+
raise RuntimeError(
261+
"Rime TTS streaming requires use_websocket=True at construction time"
262+
)
263+
s = SynthesizeStream(tts=self, conn_options=conn_options)
264+
self._streams.add(s)
265+
return s
266+
267+
async def aclose(self) -> None:
268+
for s in list(self._streams):
269+
await s.aclose()
270+
self._streams.clear()
271+
await self._pool.aclose()
272+
168273
def synthesize(
169274
self, text: str, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
170275
) -> ChunkedStream:
@@ -189,6 +294,8 @@ def update_options(
189294
phonemize_between_brackets: NotGivenOr[bool] = NOT_GIVEN,
190295
base_url: NotGivenOr[str] = NOT_GIVEN,
191296
) -> None:
297+
# WS URL is bound at pool connect; invalidate if any URL-affecting param changed.
298+
prev_ws_url = self._ws_url() if self._use_websocket else None
192299
if is_given(base_url):
193300
self._base_url = base_url
194301
if is_given(model):
@@ -231,6 +338,9 @@ def update_options(
231338
if is_given(phonemize_between_brackets):
232339
self._opts.mist_options.phonemize_between_brackets = phonemize_between_brackets
233340

341+
if prev_ws_url is not None and self._ws_url() != prev_ws_url:
342+
self._pool.invalidate()
343+
234344

235345
class ChunkedStream(tts.ChunkedStream):
236346
"""Synthesize using the chunked api endpoint"""
@@ -245,38 +355,18 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
245355
"speaker": self._opts.speaker,
246356
"text": self._input_text,
247357
"modelId": self._opts.model,
358+
**_model_params(self._opts),
248359
}
249360
format = "audio/pcm"
250-
if self._opts.model == "arcana":
251-
arcana_opts = self._opts.arcana_options
252-
assert arcana_opts is not None
253-
if is_given(arcana_opts.repetition_penalty):
254-
payload["repetition_penalty"] = arcana_opts.repetition_penalty
255-
if is_given(arcana_opts.temperature):
256-
payload["temperature"] = arcana_opts.temperature
257-
if is_given(arcana_opts.top_p):
258-
payload["top_p"] = arcana_opts.top_p
259-
if is_given(arcana_opts.max_tokens):
260-
payload["max_tokens"] = arcana_opts.max_tokens
261-
if is_given(arcana_opts.lang):
262-
payload["lang"] = arcana_opts.lang
263-
if is_given(arcana_opts.sample_rate):
264-
payload["samplingRate"] = arcana_opts.sample_rate
265-
elif _is_mist_model(self._opts.model):
361+
if self._opts.model == "arcana" and self._opts.arcana_options is not None:
362+
if is_given(self._opts.arcana_options.sample_rate):
363+
payload["samplingRate"] = self._opts.arcana_options.sample_rate
364+
elif _is_mist_model(self._opts.model) and self._opts.mist_options is not None:
266365
mist_opts = self._opts.mist_options
267-
assert mist_opts is not None
268-
if is_given(mist_opts.lang):
269-
payload["lang"] = mist_opts.lang
270366
if is_given(mist_opts.sample_rate):
271367
payload["samplingRate"] = mist_opts.sample_rate
272-
if is_given(mist_opts.speed_alpha):
273-
payload["speedAlpha"] = mist_opts.speed_alpha
274368
if self._opts.model == "mistv2" and is_given(mist_opts.reduce_latency):
275369
payload["reduceLatency"] = mist_opts.reduce_latency
276-
if is_given(mist_opts.pause_between_brackets):
277-
payload["pauseBetweenBrackets"] = mist_opts.pause_between_brackets
278-
if is_given(mist_opts.phonemize_between_brackets):
279-
payload["phonemizeBetweenBrackets"] = mist_opts.phonemize_between_brackets
280370

281371
try:
282372
async with self._tts._ensure_session().post(
@@ -316,3 +406,106 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
316406
) from None
317407
except Exception as e:
318408
raise APIConnectionError() from e
409+
410+
411+
class SynthesizeStream(tts.SynthesizeStream):
412+
"""One stream = one utterance. Server-side bySentence segmentation by default;
413+
pass segment="immediate" on the TTS to disable server buffering when the agent
414+
is already feeding sentence-tokenized text."""
415+
416+
def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
417+
super().__init__(tts=tts, conn_options=conn_options)
418+
self._tts: TTS = tts
419+
420+
async def _run(self, output_emitter: tts.AudioEmitter) -> None:
421+
request_id = utils.shortuuid()
422+
context_id = utils.shortuuid()
423+
output_emitter.initialize(
424+
request_id=request_id,
425+
sample_rate=self._tts.sample_rate,
426+
num_channels=NUM_CHANNELS,
427+
mime_type="audio/pcm",
428+
stream=True,
429+
)
430+
output_emitter.start_segment(segment_id=context_id)
431+
432+
sent_stream = self._tts._sentence_tokenizer.stream()
433+
434+
async def _input_task() -> None:
435+
async for data in self._input_ch:
436+
if isinstance(data, self._FlushSentinel):
437+
sent_stream.flush()
438+
continue
439+
sent_stream.push_text(data)
440+
sent_stream.end_input()
441+
442+
async def _send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
443+
sent_count = 0
444+
async for ev in sent_stream:
445+
pkt = {"text": ev.token + " ", "contextId": context_id}
446+
self._mark_started()
447+
await ws.send_str(json.dumps(pkt))
448+
sent_count += 1
449+
# Empty input: server returns notReady, never emits done — fail fast.
450+
if sent_count == 0:
451+
raise APIError("Rime WS: no text was provided to synthesize")
452+
# Per-utterance flush — eos would close the pooled WS.
453+
await ws.send_str(json.dumps({"operation": "flush", "contextId": context_id}))
454+
455+
async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
456+
while True:
457+
msg = await ws.receive(timeout=self._conn_options.timeout)
458+
if msg.type in (
459+
aiohttp.WSMsgType.CLOSE,
460+
aiohttp.WSMsgType.CLOSED,
461+
aiohttp.WSMsgType.CLOSING,
462+
):
463+
raise APIStatusError(
464+
"Rime ws closed unexpectedly",
465+
request_id=request_id,
466+
)
467+
if msg.type != aiohttp.WSMsgType.TEXT:
468+
logger.warning("unexpected Rime ws message type %s", msg.type)
469+
continue
470+
data = json.loads(msg.data)
471+
t = data.get("type")
472+
if t == "chunk":
473+
output_emitter.push(base64.b64decode(data["data"]))
474+
elif t == "timestamps":
475+
wt = data.get("word_timestamps") or {}
476+
words = wt.get("words") or []
477+
starts = wt.get("start") or []
478+
ends = wt.get("end") or []
479+
for w, s, e in zip(words, starts, ends, strict=False):
480+
output_emitter.push_timed_transcript(
481+
TimedString(text=w + " ", start_time=s, end_time=e)
482+
)
483+
elif t == "done":
484+
output_emitter.end_input()
485+
break
486+
elif t == "error":
487+
msg_text = data.get("message", "(no message)")
488+
raise APIError(f"Rime ws error: {msg_text}")
489+
490+
try:
491+
async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
492+
tasks = [
493+
asyncio.create_task(_input_task()),
494+
asyncio.create_task(_send_task(ws)),
495+
asyncio.create_task(_recv_task(ws)),
496+
]
497+
try:
498+
await asyncio.gather(*tasks)
499+
finally:
500+
await sent_stream.aclose()
501+
await utils.aio.gracefully_cancel(*tasks)
502+
except asyncio.TimeoutError:
503+
raise APITimeoutError() from None
504+
except aiohttp.ClientResponseError as e:
505+
raise APIStatusError(
506+
message=e.message, status_code=e.status, request_id=None, body=None
507+
) from None
508+
except APIError:
509+
raise
510+
except Exception as e:
511+
raise APIConnectionError(f"Rime WS error: {e}") from e

tests/test_tts.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,13 @@ async def test_tts_synthesize_error_propagation():
431431
},
432432
id="google",
433433
),
434+
pytest.param(
435+
lambda: {
436+
"tts": rime.TTS(use_websocket=True),
437+
"proxy-upstream": "users-ws.rime.ai:443",
438+
},
439+
id="rime",
440+
),
434441
pytest.param(
435442
lambda: {
436443
"tts": tts.StreamAdapter(tts=inworld.TTS()),

0 commit comments

Comments (0)