1515from __future__ import annotations
1616
1717import asyncio
18+ import base64
19+ import json
1820import os
21+ import weakref
1922from dataclasses import dataclass , replace
23+ from urllib .parse import urlencode
2024
2125import aiohttp
2226
2327from livekit .agents import (
2428 APIConnectionError ,
2529 APIConnectOptions ,
30+ APIError ,
2631 APIStatusError ,
2732 APITimeoutError ,
33+ tokenize ,
2834 tts ,
2935 utils ,
3036)
3440 NotGivenOr ,
3541)
3642from livekit .agents .utils import is_given
43+ from livekit .agents .voice .io import TimedString
3744
3845from .langs import TTSLangs
3946from .log import logger
# Per-model-family request timeouts (presumably seconds — confirm against how
# the value returned by _timeout_for_model() is consumed).
ARCANA_MODEL_TIMEOUT = 60 * 4
MIST_MODEL_TIMEOUT = 30
# Default endpoints: HTTP synthesis URL and WebSocket host ("/ws3" is appended
# to the latter when the WebSocket URL is built).
RIME_BASE_URL = "https://users.rime.ai/v1/rime-tts"
RIME_WS_BASE_URL = "wss://users-ws.rime.ai"
# Channel count handed to the base TTS and to the audio emitter.
NUM_CHANNELS = 1
4655
4756
4857@dataclass
@@ -73,9 +82,6 @@ class _MistOptions:
7382 phonemize_between_brackets : NotGivenOr [bool ] = NOT_GIVEN
7483
7584
76- NUM_CHANNELS = 1
77-
78-
def _is_mist_model(model: TTSModels | str) -> bool:
    """Return True when the model identifier contains the substring ``"mist"``."""
    return "mist" in model
8187
@@ -86,11 +92,40 @@ def _timeout_for_model(model: TTSModels | str) -> int:
8692 return MIST_MODEL_TIMEOUT
8793
8894
def _model_params(opts: _TTSOptions) -> dict[str, object]:
    """Per-model option fields shared between the HTTP body and the WS query string.

    Only options that were explicitly given are included. Returns an empty dict
    when the model is neither "arcana" nor a mist model, or when the matching
    options object is ``None``.
    """
    # (attribute on the options object, key sent to the API), in emission order.
    if opts.model == "arcana" and opts.arcana_options is not None:
        source: object = opts.arcana_options
        field_map = (
            ("lang", "lang"),
            ("repetition_penalty", "repetition_penalty"),
            ("temperature", "temperature"),
            ("top_p", "top_p"),
            ("max_tokens", "max_tokens"),
        )
    elif _is_mist_model(opts.model) and opts.mist_options is not None:
        source = opts.mist_options
        field_map = (
            ("lang", "lang"),
            ("speed_alpha", "speedAlpha"),
            ("pause_between_brackets", "pauseBetweenBrackets"),
            ("phonemize_between_brackets", "phonemizeBetweenBrackets"),
        )
    else:
        return {}

    params: dict[str, object] = {}
    for attr, api_key in field_map:
        value = getattr(source, attr)
        if is_given(value):
            params[api_key] = value
    return params
121+
122+
89123class TTS (tts .TTS ):
90124 def __init__ (
91125 self ,
92126 * ,
93127 base_url : str = RIME_BASE_URL ,
128+ ws_base_url : str = RIME_WS_BASE_URL ,
94129 model : TTSModels | str = "arcana" ,
95130 speaker : NotGivenOr [ArcanaVoices | str ] = NOT_GIVEN ,
96131 lang : TTSLangs | str = "eng" ,
@@ -107,10 +142,14 @@ def __init__(
107142 phonemize_between_brackets : NotGivenOr [bool ] = NOT_GIVEN ,
108143 api_key : NotGivenOr [str ] = NOT_GIVEN ,
109144 http_session : aiohttp .ClientSession | None = None ,
145+ use_websocket : bool = False ,
146+ segment : NotGivenOr [str ] = NOT_GIVEN ,
147+ tokenizer : NotGivenOr [tokenize .SentenceTokenizer ] = NOT_GIVEN ,
110148 ) -> None :
111149 super ().__init__ (
112150 capabilities = tts .TTSCapabilities (
113- streaming = False ,
151+ streaming = use_websocket ,
152+ aligned_transcript = use_websocket ,
114153 ),
115154 sample_rate = sample_rate ,
116155 num_channels = NUM_CHANNELS ,
@@ -148,9 +187,23 @@ def __init__(
148187 )
149188 self ._session = http_session
150189 self ._base_url = base_url
190+ self ._ws_base_url = ws_base_url
191+ self ._use_websocket = use_websocket
192+ self ._segment = segment if is_given (segment ) else "bySentence"
151193
152194 self ._total_timeout = _timeout_for_model (model )
153195
196+ self ._streams : weakref .WeakSet [SynthesizeStream ] = weakref .WeakSet ()
197+ self ._sentence_tokenizer = (
198+ tokenizer if is_given (tokenizer ) else tokenize .blingfire .SentenceTokenizer ()
199+ )
200+ self ._pool = utils .ConnectionPool [aiohttp .ClientWebSocketResponse ](
201+ connect_cb = self ._connect_ws ,
202+ close_cb = self ._close_ws ,
203+ max_session_duration = 300 ,
204+ mark_refreshed_on_get = True ,
205+ )
206+
@property
def model(self) -> str:
    """Model identifier stored in this TTS instance's options."""
    return self._opts.model
@@ -165,6 +218,58 @@ def _ensure_session(self) -> aiohttp.ClientSession:
165218
166219 return self ._session
167220
def _ws_url(self) -> str:
    """Build the WebSocket URL, with all synthesis settings in the query string.

    The query carries the speaker, model id, PCM audio format, sample rate and
    segmentation mode, plus whatever per-model fields ``_model_params`` yields.
    """
    params: dict[str, object] = {
        "speaker": self._opts.speaker,
        "modelId": self._opts.model,
        "audioFormat": "pcm",
        "samplingRate": self._sample_rate,
        "segment": self._segment,
        **_model_params(self._opts),
    }
    # urlencode() stringifies Python booleans as "True"/"False"; send the
    # conventional lowercase form instead.
    # NOTE(review): confirm against the Rime WS API that lowercase is expected.
    query = urlencode(
        {
            key: (str(value).lower() if isinstance(value, bool) else value)
            for key, value in params.items()
        }
    )
    return f"{self._ws_base_url}/ws3?{query}"
231+
async def _connect_ws(self, timeout: float) -> aiohttp.ClientWebSocketResponse:
    """Open a WebSocket to the URL from ``_ws_url()`` using the shared session.

    Used as the connection pool's ``connect_cb``. The API key is sent as a
    Bearer token; ``asyncio.TimeoutError`` is raised if the handshake does not
    finish within ``timeout``.
    """
    session = self._ensure_session()
    auth_headers = {"Authorization": f"Bearer {self._api_key}"}
    return await asyncio.wait_for(
        session.ws_connect(self._ws_url(), headers=auth_headers),
        timeout,
    )
240+
async def _close_ws(self, ws: aiohttp.ClientWebSocketResponse) -> None:
    """Best-effort shutdown of a pooled WebSocket (the pool's ``close_cb``).

    Sends an ``eos`` operation, waits up to one second for a reply, then closes
    the socket. Failures in the handshake are logged, never raised; the socket
    is closed in all cases.
    """
    try:
        # Nothing to say goodbye to if the socket is already closed; sending
        # would only raise and produce a spurious warning.
        if not ws.closed:
            await ws.send_str(json.dumps({"operation": "eos"}))
            try:
                await asyncio.wait_for(ws.receive(), timeout=1.0)
            except asyncio.TimeoutError:
                # No reply within the grace period: proceed to close anyway.
                pass
    except Exception as e:
        logger.warning("Error during Rime WS close sequence: %s", e)
    finally:
        await ws.close()
252+
def prewarm(self) -> None:
    """Prewarm the WebSocket connection pool.

    Does nothing unless the instance was built with ``use_websocket=True``:
    ``stream()`` refuses to run without that flag, so a prewarmed socket would
    otherwise sit unused.
    """
    if not self._use_websocket:
        return
    self._pool.prewarm()
255+
def stream(
    self, *, conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
) -> SynthesizeStream:
    """Create a WebSocket-backed synthesis stream and track it for cleanup.

    Raises:
        RuntimeError: if the TTS was not constructed with ``use_websocket=True``.
    """
    if not self._use_websocket:
        raise RuntimeError(
            "Rime TTS streaming requires use_websocket=True at construction time"
        )
    new_stream = SynthesizeStream(tts=self, conn_options=conn_options)
    # Tracked in a weak set so aclose() can close streams that are still alive.
    self._streams.add(new_stream)
    return new_stream
266+
async def aclose(self) -> None:
    """Close every tracked stream, then the WebSocket connection pool.

    The pool is closed and the stream registry cleared even when closing one of
    the streams raises; that exception still propagates to the caller.
    """
    try:
        # Iterate over a snapshot: the registry may change while we await.
        for open_stream in list(self._streams):
            await open_stream.aclose()
    finally:
        self._streams.clear()
        await self._pool.aclose()
272+
168273 def synthesize (
169274 self , text : str , * , conn_options : APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS
170275 ) -> ChunkedStream :
@@ -189,6 +294,8 @@ def update_options(
189294 phonemize_between_brackets : NotGivenOr [bool ] = NOT_GIVEN ,
190295 base_url : NotGivenOr [str ] = NOT_GIVEN ,
191296 ) -> None :
297+ # WS URL is bound at pool connect; invalidate if any URL-affecting param changed.
298+ prev_ws_url = self ._ws_url () if self ._use_websocket else None
192299 if is_given (base_url ):
193300 self ._base_url = base_url
194301 if is_given (model ):
@@ -231,6 +338,9 @@ def update_options(
231338 if is_given (phonemize_between_brackets ):
232339 self ._opts .mist_options .phonemize_between_brackets = phonemize_between_brackets
233340
341+ if prev_ws_url is not None and self ._ws_url () != prev_ws_url :
342+ self ._pool .invalidate ()
343+
234344
235345class ChunkedStream (tts .ChunkedStream ):
236346 """Synthesize using the chunked api endpoint"""
@@ -245,38 +355,18 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
245355 "speaker" : self ._opts .speaker ,
246356 "text" : self ._input_text ,
247357 "modelId" : self ._opts .model ,
358+ ** _model_params (self ._opts ),
248359 }
249360 format = "audio/pcm"
250- if self ._opts .model == "arcana" :
251- arcana_opts = self ._opts .arcana_options
252- assert arcana_opts is not None
253- if is_given (arcana_opts .repetition_penalty ):
254- payload ["repetition_penalty" ] = arcana_opts .repetition_penalty
255- if is_given (arcana_opts .temperature ):
256- payload ["temperature" ] = arcana_opts .temperature
257- if is_given (arcana_opts .top_p ):
258- payload ["top_p" ] = arcana_opts .top_p
259- if is_given (arcana_opts .max_tokens ):
260- payload ["max_tokens" ] = arcana_opts .max_tokens
261- if is_given (arcana_opts .lang ):
262- payload ["lang" ] = arcana_opts .lang
263- if is_given (arcana_opts .sample_rate ):
264- payload ["samplingRate" ] = arcana_opts .sample_rate
265- elif _is_mist_model (self ._opts .model ):
361+ if self ._opts .model == "arcana" and self ._opts .arcana_options is not None :
362+ if is_given (self ._opts .arcana_options .sample_rate ):
363+ payload ["samplingRate" ] = self ._opts .arcana_options .sample_rate
364+ elif _is_mist_model (self ._opts .model ) and self ._opts .mist_options is not None :
266365 mist_opts = self ._opts .mist_options
267- assert mist_opts is not None
268- if is_given (mist_opts .lang ):
269- payload ["lang" ] = mist_opts .lang
270366 if is_given (mist_opts .sample_rate ):
271367 payload ["samplingRate" ] = mist_opts .sample_rate
272- if is_given (mist_opts .speed_alpha ):
273- payload ["speedAlpha" ] = mist_opts .speed_alpha
274368 if self ._opts .model == "mistv2" and is_given (mist_opts .reduce_latency ):
275369 payload ["reduceLatency" ] = mist_opts .reduce_latency
276- if is_given (mist_opts .pause_between_brackets ):
277- payload ["pauseBetweenBrackets" ] = mist_opts .pause_between_brackets
278- if is_given (mist_opts .phonemize_between_brackets ):
279- payload ["phonemizeBetweenBrackets" ] = mist_opts .phonemize_between_brackets
280370
281371 try :
282372 async with self ._tts ._ensure_session ().post (
@@ -316,3 +406,106 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
316406 ) from None
317407 except Exception as e :
318408 raise APIConnectionError () from e
409+
410+
class SynthesizeStream(tts.SynthesizeStream):
    """Streaming synthesis over a pooled Rime WebSocket: one stream = one utterance.

    Server-side ``bySentence`` segmentation is the default; pass
    ``segment="immediate"`` on the TTS to disable server buffering when the agent
    is already feeding sentence-tokenized text.
    """

    def __init__(self, *, tts: TTS, conn_options: APIConnectOptions) -> None:
        super().__init__(tts=tts, conn_options=conn_options)
        self._tts: TTS = tts

    async def _run(self, output_emitter: tts.AudioEmitter) -> None:
        """Synthesize one utterance and feed audio and word timings to the emitter.

        Three tasks run concurrently on one pooled WebSocket: one moves pushed
        text into the sentence tokenizer, one sends each tokenized sentence
        followed by a ``flush`` operation, and one decodes server messages
        until a ``done`` message arrives.

        Raises:
            APITimeoutError: on ``asyncio.TimeoutError`` (connect or receive).
            APIStatusError: if the socket closes mid-utterance or aiohttp
                reports a ``ClientResponseError``.
            APIError: if the server sends an ``error`` message or no text was
                provided.
            APIConnectionError: for any other failure.
        """
        request_id = utils.shortuuid()
        context_id = utils.shortuuid()
        output_emitter.initialize(
            request_id=request_id,
            sample_rate=self._tts.sample_rate,
            num_channels=NUM_CHANNELS,
            mime_type="audio/pcm",
            stream=True,
        )
        output_emitter.start_segment(segment_id=context_id)

        sent_stream = self._tts._sentence_tokenizer.stream()

        async def _input_task() -> None:
            # Forward pushed text to the tokenizer; a flush sentinel flushes it.
            async for data in self._input_ch:
                if isinstance(data, self._FlushSentinel):
                    sent_stream.flush()
                    continue
                sent_stream.push_text(data)
            sent_stream.end_input()

        async def _send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            sent_count = 0
            async for ev in sent_stream:
                pkt = {"text": ev.token + " ", "contextId": context_id}
                self._mark_started()
                await ws.send_str(json.dumps(pkt))
                sent_count += 1
            # Empty input: server returns notReady, never emits done — fail fast.
            # NOTE(review): consider whether this error should be non-retryable,
            # since retrying cannot produce text that was never pushed.
            if sent_count == 0:
                raise APIError("Rime WS: no text was provided to synthesize")
            # Per-utterance flush — eos would close the pooled WS.
            await ws.send_str(json.dumps({"operation": "flush", "contextId": context_id}))

        async def _recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
            closed_types = (
                aiohttp.WSMsgType.CLOSE,
                aiohttp.WSMsgType.CLOSED,
                aiohttp.WSMsgType.CLOSING,
            )
            while True:
                msg = await ws.receive(timeout=self._conn_options.timeout)
                if msg.type in closed_types:
                    raise APIStatusError(
                        "Rime ws closed unexpectedly",
                        request_id=request_id,
                    )
                if msg.type != aiohttp.WSMsgType.TEXT:
                    logger.warning("unexpected Rime ws message type %s", msg.type)
                    continue

                data = json.loads(msg.data)
                kind = data.get("type")
                if kind == "chunk":
                    # Audio arrives base64-encoded inside the JSON message.
                    output_emitter.push(base64.b64decode(data["data"]))
                elif kind == "timestamps":
                    wt = data.get("word_timestamps") or {}
                    words = wt.get("words") or []
                    starts = wt.get("start") or []
                    ends = wt.get("end") or []
                    # zip() already stops at the shortest list; the `strict`
                    # keyword is omitted because it only exists on Python 3.10+.
                    for word, start, end in zip(words, starts, ends):  # noqa: B905
                        output_emitter.push_timed_transcript(
                            TimedString(text=word + " ", start_time=start, end_time=end)
                        )
                elif kind == "done":
                    output_emitter.end_input()
                    break
                elif kind == "error":
                    msg_text = data.get("message", "(no message)")
                    raise APIError(f"Rime ws error: {msg_text}")

        try:
            async with self._tts._pool.connection(timeout=self._conn_options.timeout) as ws:
                tasks = [
                    asyncio.create_task(_input_task()),
                    asyncio.create_task(_send_task(ws)),
                    asyncio.create_task(_recv_task(ws)),
                ]
                try:
                    await asyncio.gather(*tasks)
                finally:
                    await sent_stream.aclose()
                    await utils.aio.gracefully_cancel(*tasks)
        except asyncio.TimeoutError:
            raise APITimeoutError() from None
        except aiohttp.ClientResponseError as e:
            raise APIStatusError(
                message=e.message, status_code=e.status, request_id=None, body=None
            ) from None
        except APIError:
            # Already a domain error; let it through unwrapped.
            raise
        except Exception as e:
            raise APIConnectionError(f"Rime WS error: {e}") from e
0 commit comments