Skip to content

Commit b6459dd

Browse files
nanoandrew4 and mudler
authored
feat(api): Add transcribe response format request parameter & adjust STT backends (#8318)
* WIP response format implementation for audio transcriptions
  (cherry picked from commit e271dd7)
  Signed-off-by: Andres Smith <andressmithdev@pm.me>

* Rework transcript response_format and add more formats
  (cherry picked from commit 6a93a8f)
  Signed-off-by: Andres Smith <andressmithdev@pm.me>

* Add test and replace go-openai package with official openai go client
  (cherry picked from commit f25d1a0)
  Signed-off-by: Andres Smith <andressmithdev@pm.me>

* Fix faster-whisper backend and refactor transcription formatting to also work on CLI
  Signed-off-by: Andres Smith <andressmithdev@pm.me>
  (cherry picked from commit 69a9397)
  Signed-off-by: Andres Smith <andressmithdev@pm.me>

---------

Signed-off-by: Andres Smith <andressmithdev@pm.me>
Co-authored-by: nanoandrew4 <nanoandrew4@gmail.com>
Co-authored-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
1 parent 397f7f0 commit b6459dd

File tree

18 files changed

+353
-184
lines changed

18 files changed

+353
-184
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ LocalAI
3636
models/*
3737
test-models/
3838
test-dir/
39+
tests/e2e-aio/backends
40+
tests/e2e-aio/models
3941

4042
release/
4143

backend/go/whisper/gowhisper.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
130130
segments := []*pb.TranscriptSegment{}
131131
text := ""
132132
for i := range int(segsLen) {
133-
s := CppGetSegmentStart(i)
134-
t := CppGetSegmentEnd(i)
133+
// segment start/end conversion factor taken from https://github.com/ggml-org/whisper.cpp/blob/master/examples/cli/cli.cpp#L895
134+
s := CppGetSegmentStart(i) * (10000000)
135+
t := CppGetSegmentEnd(i) * (10000000)
135136
txt := strings.Clone(CppGetSegmentText(i))
136137
tokens := make([]int32, CppNTokens(i))
137138

backend/python/faster-whisper/backend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def LoadModel(self, request, context):
4040
device = "mps"
4141
try:
4242
print("Preparing models, please wait", file=sys.stderr)
43-
self.model = WhisperModel(request.Model, device=device, compute_type="float16")
43+
self.model = WhisperModel(request.Model, device=device, compute_type="default")
4444
except Exception as err:
4545
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
4646
# Implement your logic here for the LoadModel service
@@ -55,11 +55,12 @@ def AudioTranscription(self, request, context):
5555
id = 0
5656
for segment in segments:
5757
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
58-
resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=segment.start, end=segment.end, text=segment.text))
58+
resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=int(segment.start)*1e9, end=int(segment.end)*1e9, text=segment.text))
5959
text += segment.text
60-
id += 1
60+
id += 1
6161
except Exception as err:
6262
print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
63+
raise err
6364

6465
return backend_pb2.TranscriptResult(segments=resultSegments, text=text)
6566

core/backend/transcript.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ import (
1212
"github.com/mudler/LocalAI/pkg/model"
1313
)
1414

15-
func ModelTranscription(audio, language string, translate bool, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
16-
15+
func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
1716
if modelConfig.Backend == "" {
1817
modelConfig.Backend = model.WhisperBackend
1918
}

core/cli/transcript.go

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,42 @@ package cli
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"fmt"
8+
"strings"
79

810
"github.com/mudler/LocalAI/core/backend"
911
cliContext "github.com/mudler/LocalAI/core/cli/context"
1012
"github.com/mudler/LocalAI/core/config"
13+
"github.com/mudler/LocalAI/core/gallery"
14+
"github.com/mudler/LocalAI/core/schema"
15+
"github.com/mudler/LocalAI/pkg/format"
1116
"github.com/mudler/LocalAI/pkg/model"
1217
"github.com/mudler/LocalAI/pkg/system"
1318
"github.com/mudler/xlog"
1419
)
1520

1621
type TranscriptCMD struct {
17-
Filename string `arg:""`
22+
Filename string `arg:"" name:"file" help:"Audio file to transcribe" type:"path"`
1823

19-
Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
20-
Model string `short:"m" required:"" help:"Model name to run the TTS"`
21-
Language string `short:"l" help:"Language of the audio file"`
22-
Translate bool `short:"c" help:"Translate the transcription to english"`
23-
Diarize bool `short:"d" help:"Mark speaker turns"`
24-
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
25-
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
26-
Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
24+
Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
25+
Model string `short:"m" required:"" help:"Model name to run the TTS"`
26+
Language string `short:"l" help:"Language of the audio file"`
27+
Translate bool `short:"c" help:"Translate the transcription to English"`
28+
Diarize bool `short:"d" help:"Mark speaker turns"`
29+
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
30+
BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"`
31+
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
32+
BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"`
33+
Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
34+
ResponseFormat schema.TranscriptionResponseFormatType `short:"f" default:"" help:"Response format for Whisper models, can be one of (txt, lrc, srt, vtt, json, json_verbose)"`
35+
PrettyPrint bool `help:"Used with response_format json or json_verbose for pretty printing"`
2736
}
2837

2938
func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
3039
systemState, err := system.GetSystemState(
40+
system.WithBackendPath(t.BackendsPath),
3141
system.WithModelPath(t.ModelsPath),
3242
)
3343
if err != nil {
@@ -40,6 +50,11 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
4050

4151
cl := config.NewModelConfigLoader(t.ModelsPath)
4252
ml := model.NewModelLoader(systemState)
53+
54+
if err := gallery.RegisterBackends(systemState, ml); err != nil {
55+
xlog.Error("error registering external backends", "error", err)
56+
}
57+
4358
if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil {
4459
return err
4560
}
@@ -62,8 +77,29 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
6277
if err != nil {
6378
return err
6479
}
65-
for _, segment := range tr.Segments {
66-
fmt.Println(segment.Start.String(), "-", segment.Text)
80+
81+
switch t.ResponseFormat {
82+
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText:
83+
fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat))
84+
case schema.TranscriptionResponseFormatJson:
85+
tr.Segments = nil
86+
fallthrough
87+
case schema.TranscriptionResponseFormatJsonVerbose:
88+
var mtr []byte
89+
var err error
90+
if t.PrettyPrint {
91+
mtr, err = json.MarshalIndent(tr, "", " ")
92+
} else {
93+
mtr, err = json.Marshal(tr)
94+
}
95+
if err != nil {
96+
return err
97+
}
98+
fmt.Println(string(mtr))
99+
default:
100+
for _, segment := range tr.Segments {
101+
fmt.Println(segment.Start.String(), "-", strings.TrimSpace(segment.Text))
102+
}
67103
}
68104
return nil
69105
}

core/http/endpoints/openai/transcription.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package openai
22

33
import (
4+
"errors"
45
"io"
56
"net/http"
67
"os"
@@ -12,6 +13,7 @@ import (
1213
"github.com/mudler/LocalAI/core/config"
1314
"github.com/mudler/LocalAI/core/http/middleware"
1415
"github.com/mudler/LocalAI/core/schema"
16+
"github.com/mudler/LocalAI/pkg/format"
1517
model "github.com/mudler/LocalAI/pkg/model"
1618

1719
"github.com/mudler/xlog"
@@ -38,6 +40,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
3840

3941
diarize := c.FormValue("diarize") != "false"
4042
prompt := c.FormValue("prompt")
43+
responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format"))
4144

4245
// retrieve the file data from the request
4346
file, err := c.FormFile("file")
@@ -76,7 +79,17 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
7679
}
7780

7881
xlog.Debug("Transcribed", "transcription", tr)
79-
// TODO: handle different outputs here
80-
return c.JSON(http.StatusOK, tr)
82+
83+
switch responseFormat {
84+
case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatText, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt:
85+
return c.String(http.StatusOK, format.TranscriptionResponse(tr, responseFormat))
86+
case schema.TranscriptionResponseFormatJson:
87+
tr.Segments = nil
88+
fallthrough
89+
case schema.TranscriptionResponseFormatJsonVerbose, "": // maintain backwards compatibility
90+
return c.JSON(http.StatusOK, tr)
91+
default:
92+
return errors.New("invalid response_format")
93+
}
8194
}
8295
}

core/schema/openai.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,17 @@ type ImageGenerationResponseFormat string
107107

108108
type ChatCompletionResponseFormatType string
109109

110+
type TranscriptionResponseFormatType string
111+
112+
const (
113+
TranscriptionResponseFormatText = TranscriptionResponseFormatType("txt")
114+
TranscriptionResponseFormatSrt = TranscriptionResponseFormatType("srt")
115+
TranscriptionResponseFormatVtt = TranscriptionResponseFormatType("vtt")
116+
TranscriptionResponseFormatLrc = TranscriptionResponseFormatType("lrc")
117+
TranscriptionResponseFormatJson = TranscriptionResponseFormatType("json")
118+
TranscriptionResponseFormatJsonVerbose = TranscriptionResponseFormatType("json_verbose")
119+
)
120+
110121
type ChatCompletionResponseFormat struct {
111122
Type ChatCompletionResponseFormatType `json:"type,omitempty"`
112123
}

core/schema/transcription.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ type TranscriptionSegment struct {
1111
}
1212

1313
type TranscriptionResult struct {
14-
Segments []TranscriptionSegment `json:"segments"`
14+
Segments []TranscriptionSegment `json:"segments,omitempty"`
1515
Text string `json:"text"`
1616
}

core/startup/model_preload.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@ import (
1818
"github.com/mudler/xlog"
1919
)
2020

21-
const (
22-
YAML_EXTENSION = ".yaml"
23-
)
24-
2521
// InstallModels will preload models from the given list of URLs and galleries
2622
// It will download the model if it is not already present in the model path
2723
// It will also try to resolve if the model is an embedded model YAML configuration

0 commit comments

Comments (0)