From 032dee256f1db6ea17a6c2eb68b195af41bbd5fc Mon Sep 17 00:00:00 2001 From: Matthew Campbell Date: Thu, 11 May 2023 19:05:07 +0700 Subject: [PATCH] Keep whisper models in memory (#233) --- api/openai.go | 7 ++++++- pkg/model/loader.go | 45 +++++++++++++++++++++++++++++++++++++++--- pkg/whisper/whisper.go | 9 +-------- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/api/openai.go b/api/openai.go index 171aa68..dcd2110 100644 --- a/api/openai.go +++ b/api/openai.go @@ -436,7 +436,12 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, log.Debug().Msgf("Audio file copied to: %+v", dst) - tr, err := whisper.Transcript(filepath.Join(loader.ModelPath, config.Model), dst, input.Language) + whisperModel, err := loader.WhisperLoader("whisper", config.Model) + if err != nil { + return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + } + + tr, err := whisper.Transcript(whisperModel, dst, input.Language) if err != nil { return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 3679a46..2542248 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -11,6 +11,7 @@ import ( "text/template" rwkv "github.com/donomii/go-rwkv.cpp" + whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" bloomz "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" gpt2 "github.com/go-skynet/go-gpt2.cpp" @@ -32,9 +33,9 @@ type ModelLoader struct { redpajama map[string]*gpt2.RedPajama rwkv map[string]*rwkv.RwkvState bloomz map[string]*bloomz.Bloomz - - bert map[string]*bert.Bert - promptsTemplates map[string]*template.Template + bert map[string]*bert.Bert + promptsTemplates map[string]*template.Template + whisperModels map[string]whisper.Model } func NewModelLoader(modelPath string) *ModelLoader { @@ -50,6 +51,7 @@ func NewModelLoader(modelPath string) *ModelLoader { bloomz: make(map[string]*bloomz.Bloomz), bert: make(map[string]*bert.Bert), promptsTemplates: make(map[string]*template.Template), + whisperModels: make(map[string]whisper.Model), } } @@ -422,6 +424,33 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio return model, err } +func (ml *ModelLoader) LoadWhisperModel(modelName string) (whisper.Model, error) { + ml.mu.Lock() + defer ml.mu.Unlock() + + // Check if we already have a loaded model + if !ml.ExistsInModelPath(modelName) { + return nil, fmt.Errorf("model does not exist -- %s", modelName) + } + + if m, ok := ml.whisperModels[modelName]; ok { + log.Debug().Msgf("Model already loaded in memory: %s", modelName) + return m, nil + } + + // Load the model and keep it in memory for later use + modelFile := filepath.Join(ml.ModelPath, modelName) + log.Debug().Msgf("Loading model in memory from file: %s", modelFile) + + model, err := whisper.New(modelFile) + if err != nil { + return nil, err + } + + ml.whisperModels[modelName] = model + return model, err +} + const tokenizerSuffix = ".tokenizer.json" var loadedModels map[string]interface{} = map[string]interface{}{} @@ -452,6 +481,16 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla } } +func (ml *ModelLoader) WhisperLoader(backendString string, modelFile string) (model whisper.Model, err error) { + //TODO expose more whisper options in next PR + switch strings.ToLower(backendString) { + case "whisper": + return ml.LoadWhisperModel(modelFile) + default: + return nil, fmt.Errorf("whisper backend unsupported: %s", backendString) + } +} + func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) { updateModels := func(model interface{}) { muModels.Lock() diff --git a/pkg/whisper/whisper.go b/pkg/whisper/whisper.go index 4077f86..ae84742 100644 --- a/pkg/whisper/whisper.go +++ b/pkg/whisper/whisper.go @@ -28,7 +28,7 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(modelpath, audiopath, language string) (string, error) { +func Transcript(model whisper.Model, audiopath, language string) (string, error) { dir, err := os.MkdirTemp("", "whisper") if err != nil { @@ -58,13 +58,6 @@ func Transcript(modelpath, audiopath, language string) (string, error) { data := buf.AsFloat32Buffer().Data - // Load the model - model, err := whisper.New(modelpath) - if err != nil { - return "", err - } - defer model.Close() - // Process samples context, err := model.NewContext() if err != nil {