Commit 68dd976
feat(tts): add support for streaming mode (#8291)
* feat(tts): add support for streaming mode

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Send first audio, make sure it's 16

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
1 parent 2c44b06 commit 68dd976

File tree

13 files changed: +369 −0 lines

backend/backend.proto

Lines changed: 1 addition & 0 deletions

```diff
@@ -17,6 +17,7 @@ service Backend {
   rpc GenerateVideo(GenerateVideoRequest) returns (Result) {}
   rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
   rpc TTS(TTSRequest) returns (Result) {}
+  rpc TTSStream(TTSRequest) returns (stream Reply) {}
   rpc SoundGeneration(SoundGenerationRequest) returns (Result) {}
   rpc TokenizeString(PredictOptions) returns (TokenizationResponse) {}
   rpc Status(HealthMessage) returns (StatusResponse) {}
```
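Since `TTSStream` is a server-streaming RPC, callers iterate over the reply stream. A minimal client sketch, assuming the `backend_pb2`/`backend_pb2_grpc` stubs generated from this proto (the same stubs the test file in this commit uses) and the convention established by the voxcpm backend below: the first `Reply` carries a JSON sample-rate message, later replies carry raw PCM in `audio`:

```python
import json

import grpc
import backend_pb2
import backend_pb2_grpc


def stream_tts(text, address="localhost:50051"):
    """Yield (sample_rate, pcm_bytes) pairs from the TTSStream RPC."""
    with grpc.insecure_channel(address) as channel:
        stub = backend_pb2_grpc.BackendStub(channel)
        sample_rate = None
        for reply in stub.TTSStream(backend_pb2.TTSRequest(text=text)):
            if reply.message and sample_rate is None:
                # First reply carries {"sample_rate": N} as JSON in `message`
                sample_rate = json.loads(reply.message)["sample_rate"]
            elif reply.audio:
                # Subsequent replies carry raw 16-bit PCM in `audio`
                yield sample_rate, reply.audio
```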

backend/python/voxcpm/backend.py

Lines changed: 84 additions & 0 deletions

```diff
@@ -207,6 +207,90 @@ def TTS(self, request, context):
 
         return backend_pb2.Result(success=True)
 
+    def TTSStream(self, request, context):
+        try:
+            # Get generation parameters from options with defaults
+            cfg_value = self.options.get("cfg_value", 2.0)
+            inference_timesteps = self.options.get("inference_timesteps", 10)
+            normalize = self.options.get("normalize", False)
+            denoise = self.options.get("denoise", False)
+            retry_badcase = self.options.get("retry_badcase", True)
+            retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3)
+            retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0)
+
+            # Handle voice cloning via prompt_wav_path and prompt_text
+            prompt_wav_path = None
+            prompt_text = None
+
+            # Priority: request.voice > AudioPath > options
+            if hasattr(request, 'voice') and request.voice:
+                # If voice is provided, try to use it as a path
+                if os.path.exists(request.voice):
+                    prompt_wav_path = request.voice
+                elif hasattr(request, 'ModelFile') and request.ModelFile:
+                    model_file_base = os.path.dirname(request.ModelFile)
+                    potential_path = os.path.join(model_file_base, request.voice)
+                    if os.path.exists(potential_path):
+                        prompt_wav_path = potential_path
+                elif hasattr(request, 'ModelPath') and request.ModelPath:
+                    potential_path = os.path.join(request.ModelPath, request.voice)
+                    if os.path.exists(potential_path):
+                        prompt_wav_path = potential_path
+
+            if hasattr(request, 'AudioPath') and request.AudioPath:
+                if os.path.isabs(request.AudioPath):
+                    prompt_wav_path = request.AudioPath
+                elif hasattr(request, 'ModelFile') and request.ModelFile:
+                    model_file_base = os.path.dirname(request.ModelFile)
+                    prompt_wav_path = os.path.join(model_file_base, request.AudioPath)
+                elif hasattr(request, 'ModelPath') and request.ModelPath:
+                    prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath)
+                else:
+                    prompt_wav_path = request.AudioPath
+
+            # Get prompt_text from options if available
+            if "prompt_text" in self.options:
+                prompt_text = self.options["prompt_text"]
+
+            # Prepare text
+            text = request.text.strip()
+
+            # Get sample rate from model (needed for WAV header)
+            sample_rate = self.model.tts_model.sample_rate
+
+            print(f"Streaming audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, sample_rate: {sample_rate}", file=sys.stderr)
+
+            # Send the sample rate as the first message (JSON in the message field)
+            # Format: {"sample_rate": 16000} so the caller can parse it
+            import json
+            sample_rate_info = json.dumps({"sample_rate": int(sample_rate)})
+            yield backend_pb2.Reply(message=bytes(sample_rate_info, 'utf-8'))
+
+            # Stream audio chunks
+            for chunk in self.model.generate_streaming(
+                text=text,
+                prompt_wav_path=prompt_wav_path,
+                prompt_text=prompt_text,
+                cfg_value=cfg_value,
+                inference_timesteps=inference_timesteps,
+                normalize=normalize,
+                denoise=denoise,
+                retry_badcase=retry_badcase,
+                retry_badcase_max_times=retry_badcase_max_times,
+                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
+            ):
+                # Convert the numpy array to int16 PCM and then to bytes,
+                # clipping values into the int16 range
+                chunk_int16 = np.clip(chunk * 32767, -32768, 32767).astype(np.int16)
+                chunk_bytes = chunk_int16.tobytes()
+                yield backend_pb2.Reply(audio=chunk_bytes)
+
+        except Exception as err:
+            print(f"Error in TTSStream: {err}", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+            # Yield an error reply
+            yield backend_pb2.Reply(message=bytes(f"Error: {err}", 'utf-8'))
+
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                          options=[
```
backend/python/voxcpm/test.py

Lines changed: 36 additions & 0 deletions

```diff
@@ -49,3 +49,39 @@ def test_load_model(self):
             self.fail("LoadModel service failed")
         finally:
             self.tearDown()
+
+    def test_tts_stream(self):
+        """
+        This method tests if TTS streaming works correctly
+        """
+        try:
+            self.setUp()
+            print("Starting test_tts_stream")
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5"))
+                print(response)
+                self.assertTrue(response.success)
+                self.assertEqual(response.message, "Model loaded successfully")
+
+                # Test TTSStream
+                tts_request = backend_pb2.TTSRequest(text="VoxCPM is an innovative end-to-end TTS model from ModelBest. This is a streaming test.", dst="test_stream.wav")
+                chunks_received = 0
+                total_audio_bytes = 0
+
+                for reply in stub.TTSStream(tts_request):
+                    # Verify that we receive audio chunks
+                    if reply.audio:
+                        chunks_received += 1
+                        total_audio_bytes += len(reply.audio)
+                        self.assertGreater(len(reply.audio), 0, "Audio chunk should not be empty")
+
+                # Verify that we received at least one chunk
+                self.assertGreater(chunks_received, 0, "Should receive at least one audio chunk")
+                self.assertGreater(total_audio_bytes, 0, "Total audio bytes should be greater than 0")
+                print(f"Received {chunks_received} chunks with {total_audio_bytes} total bytes")
+        except Exception as err:
+            print(err)
+            self.fail("TTSStream service failed")
+        finally:
+            self.tearDown()
```

core/backend/tts.go

Lines changed: 102 additions & 0 deletions

```diff
@@ -1,12 +1,16 @@
 package backend
 
 import (
+	"bytes"
 	"context"
+	"encoding/binary"
+	"encoding/json"
 	"fmt"
 	"os"
 	"path/filepath"
 
 	"github.com/mudler/LocalAI/core/config"
+	laudio "github.com/mudler/LocalAI/pkg/audio"
 
 	"github.com/mudler/LocalAI/pkg/grpc/proto"
 	"github.com/mudler/LocalAI/pkg/model"
@@ -74,3 +78,101 @@ func ModelTTS(
 
 	return filePath, res, err
 }
+
+func ModelTTSStream(
+	text,
+	voice,
+	language string,
+	loader *model.ModelLoader,
+	appConfig *config.ApplicationConfig,
+	modelConfig config.ModelConfig,
+	audioCallback func([]byte) error,
+) error {
+	opts := ModelOptions(modelConfig, appConfig)
+	ttsModel, err := loader.Load(opts...)
+	if err != nil {
+		return err
+	}
+
+	if ttsModel == nil {
+		return fmt.Errorf("could not load tts model %q", modelConfig.Model)
+	}
+
+	// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
+	// This should be addressed in a follow up PR soon.
+	// Copying it over nearly verbatim, as TTS backends are not functional without this.
+	modelPath := ""
+	// Checking first that it exists and is not outside ModelPath
+	// TODO: we should actually first check if the modelFile is looking like
+	// a FS path
+	mp := filepath.Join(loader.ModelPath, modelConfig.Model)
+	if _, err := os.Stat(mp); err == nil {
+		if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil {
+			return err
+		}
+		modelPath = mp
+	} else {
+		modelPath = modelConfig.Model // skip this step if it fails?????
+	}
+
+	var sampleRate uint32 = 16000 // default
+	headerSent := false
+	var callbackErr error
+
+	err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{
+		Text:     text,
+		Model:    modelPath,
+		Voice:    voice,
+		Language: &language,
+	}, func(reply *proto.Reply) {
+		// First message contains sample rate info
+		if !headerSent && len(reply.Message) > 0 {
+			var info map[string]interface{}
+			if json.Unmarshal(reply.Message, &info) == nil {
+				if sr, ok := info["sample_rate"].(float64); ok {
+					sampleRate = uint32(sr)
+				}
+			}
+			// Send WAV header with placeholder size (0xFFFFFFFF for streaming)
+			header := laudio.WAVHeader{
+				ChunkID:       [4]byte{'R', 'I', 'F', 'F'},
+				ChunkSize:     0xFFFFFFFF, // Unknown size for streaming
+				Format:        [4]byte{'W', 'A', 'V', 'E'},
+				Subchunk1ID:   [4]byte{'f', 'm', 't', ' '},
+				Subchunk1Size: 16,
+				AudioFormat:   1, // PCM
+				NumChannels:   1, // Mono
+				SampleRate:    sampleRate,
+				ByteRate:      sampleRate * 2, // SampleRate * BlockAlign
+				BlockAlign:    2,              // 16-bit = 2 bytes
+				BitsPerSample: 16,
+				Subchunk2ID:   [4]byte{'d', 'a', 't', 'a'},
+				Subchunk2Size: 0xFFFFFFFF, // Unknown size for streaming
+			}
+
+			var buf bytes.Buffer
+			if writeErr := binary.Write(&buf, binary.LittleEndian, header); writeErr != nil {
+				callbackErr = writeErr
+				return
+			}
+
+			if writeErr := audioCallback(buf.Bytes()); writeErr != nil {
+				callbackErr = writeErr
+				return
+			}
+			headerSent = true
+		}
+
+		// Stream audio chunks
+		if len(reply.Audio) > 0 {
+			if writeErr := audioCallback(reply.Audio); writeErr != nil {
+				callbackErr = writeErr
+			}
+		}
+	})
+
+	if callbackErr != nil {
+		return callbackErr
+	}
+	return err
+}
```
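`ModelTTSStream` prefixes the PCM stream with a standard 44-byte WAV header whose size fields are `0xFFFFFFFF`, since the total length is unknown while streaming. A sketch of how a consumer might validate that header and recover the format (field order taken from the `laudio.WAVHeader` literal above; the function name is illustrative):

```python
import struct

def parse_streaming_wav_header(header: bytes):
    """Validate a 44-byte PCM WAV header and recover its format fields."""
    (chunk_id, chunk_size, wave, sub1_id, sub1_size, audio_fmt, channels,
     sample_rate, byte_rate, block_align, bits, sub2_id, sub2_size) = \
        struct.unpack("<4sI4s4sIHHIIHH4sI", header[:44])
    assert chunk_id == b"RIFF" and wave == b"WAVE" and audio_fmt == 1  # PCM only
    # 0xFFFFFFFF sizes signal "length unknown": keep reading data until EOF
    streaming = chunk_size == 0xFFFFFFFF and sub2_size == 0xFFFFFFFF
    return sample_rate, channels, bits, streaming
```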

core/http/endpoints/localai/tts.go

Lines changed: 25 additions & 0 deletions

```diff
@@ -50,6 +50,31 @@ func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig
 		cfg.Voice = input.Voice
 	}
 
+	// Handle streaming TTS
+	if input.Stream {
+		// Set headers for streaming audio
+		c.Response().Header().Set("Content-Type", "audio/wav")
+		c.Response().Header().Set("Transfer-Encoding", "chunked")
+		c.Response().Header().Set("Cache-Control", "no-cache")
+		c.Response().Header().Set("Connection", "keep-alive")
+
+		// Stream audio chunks as they're generated
+		err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error {
+			_, writeErr := c.Response().Write(audioChunk)
+			if writeErr != nil {
+				return writeErr
+			}
+			c.Response().Flush()
+			return nil
+		})
+		if err != nil {
+			return err
+		}
+
+		return nil
+	}
+
+	// Non-streaming TTS (existing behavior)
 	filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg)
 	if err != nil {
 		return err
```
core/schema/localai.go

Lines changed: 1 addition & 0 deletions

```diff
@@ -53,6 +53,7 @@ type TTSRequest struct {
 	Backend  string `json:"backend" yaml:"backend"`
 	Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model
 	Format   string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format
+	Stream   bool   `json:"stream,omitempty" yaml:"stream,omitempty"` // (optional) enable streaming TTS
 }
 
 // @Description VAD request body
```

docs/content/features/text-to-audio.md

Lines changed: 35 additions & 0 deletions

````diff
@@ -29,6 +29,41 @@ curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
 
 Returns an `audio/wav` file.
 
+## Streaming TTS
+
+LocalAI supports streaming TTS generation, allowing audio to be played as it's generated. This is useful for real-time applications and reduces latency.
+
+To enable streaming, add `"stream": true` to your request:
+
+```bash
+curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
+  "input": "Hello world, this is a streaming test",
+  "model": "voxcpm",
+  "stream": true
+}' | aplay
+```
+
+The audio will be streamed chunk-by-chunk as it's generated, allowing playback to start before generation completes. This is particularly useful for long texts or when you want to minimize perceived latency.
+
+You can also pipe the streamed audio directly to audio players like `aplay` (Linux) or save it to a file:
+
+```bash
+# Stream to aplay (Linux)
+curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
+  "input": "This is a longer text that will be streamed as it is generated",
+  "model": "voxcpm",
+  "stream": true
+}' | aplay
+
+# Stream to a file
+curl http://localhost:8080/tts -H "Content-Type: application/json" -d '{
+  "input": "Streaming audio to file",
+  "model": "voxcpm",
+  "stream": true
+}' > output.wav
+```
+
+Note: Streaming TTS is currently supported by the `voxcpm` backend. Other backends will fall back to non-streaming mode if streaming is not supported.
 
 ## Backends
````
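The endpoint can also be consumed programmatically. A sketch using Python's `requests` with `stream=True`, piping chunks to `aplay` as they arrive (the payload mirrors the curl examples in the docs above):

```python
import subprocess

import requests

resp = requests.post(
    "http://localhost:8080/tts",
    json={"input": "Hello from a streaming client", "model": "voxcpm", "stream": True},
    stream=True,  # don't buffer the whole body; iterate chunks as they arrive
)
resp.raise_for_status()

# Pipe the WAV header + PCM chunks straight into aplay (Linux)
player = subprocess.Popen(["aplay"], stdin=subprocess.PIPE)
for chunk in resp.iter_content(chunk_size=4096):
    if chunk:
        player.stdin.write(chunk)
player.stdin.close()
player.wait()
```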

pkg/grpc/backend.go

Lines changed: 1 addition & 0 deletions

```diff
@@ -41,6 +41,7 @@ type Backend interface {
 	GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
+	TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...grpc.CallOption) error
 	SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
 	Detect(ctx context.Context, in *pb.DetectOptions, opts ...grpc.CallOption) (*pb.DetectResponse, error)
 	AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
```

pkg/grpc/base/base.go

Lines changed: 4 additions & 0 deletions

```diff
@@ -65,6 +65,10 @@ func (llm *Base) TTS(*pb.TTSRequest) error {
 	return fmt.Errorf("unimplemented")
 }
 
+func (llm *Base) TTSStream(*pb.TTSRequest, chan []byte) error {
+	return fmt.Errorf("unimplemented")
+}
+
 func (llm *Base) SoundGeneration(*pb.SoundGenerationRequest) error {
 	return fmt.Errorf("unimplemented")
 }
```
