From fd1df4e9717b60a404a5ac5f15b887d4737acca1 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 12 May 2023 10:04:20 +0200 Subject: [PATCH] whisper: add tests and allow to set upload size (#237) --- .github/workflows/test.yml | 4 ++-- Makefile | 5 ++++- api/api.go | 3 ++- api/api_test.go | 22 +++++++++++++++++----- api/openai.go | 26 ++++++++++++++++---------- main.go | 8 +++++++- pkg/whisper/whisper.go | 6 +++++- 7 files changed, 53 insertions(+), 21 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 54ad8b8..17e3c80 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,7 +21,7 @@ jobs: - name: Dependencies run: | sudo apt-get update - sudo apt-get install build-essential + sudo apt-get install build-essential ffmpeg - name: Test run: | make test @@ -38,7 +38,7 @@ jobs: - name: Dependencies run: | brew update - brew install sdl2 + brew install sdl2 ffmpeg - name: Test run: | make test \ No newline at end of file diff --git a/Makefile b/Makefile index f3b28f2..d1cf783 100644 --- a/Makefile +++ b/Makefile @@ -179,12 +179,15 @@ run: prepare ## run local-ai test-models/testmodel: mkdir test-models + mkdir test-dir wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel + wget https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en + wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav cp tests/fixtures/* test-models test: prepare test-models/testmodel cp tests/fixtures/* test-models - @C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api + @C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api ## Help: help: ## Show this help. diff --git a/api/api.go b/api/api.go index a855534..59489f7 100644 --- a/api/api.go +++ b/api/api.go @@ -12,7 +12,7 @@ import ( "github.com/rs/zerolog/log" ) -func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App { +func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App { zerolog.SetGlobalLevel(zerolog.InfoLevel) if debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) @@ -20,6 +20,7 @@ func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16 // Return errors as JSON responses app := fiber.New(fiber.Config{ + BodyLimit: uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB DisableStartupMessage: disableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { diff --git a/api/api_test.go b/api/api_test.go index de9fc34..5829bd1 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -3,6 +3,7 @@ package api_test import ( "context" "os" + "path/filepath" . "github.com/go-skynet/LocalAI/api" "github.com/go-skynet/LocalAI/pkg/model" @@ -23,7 +24,7 @@ var _ = Describe("API test", func() { Context("API query", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App("", modelLoader, 1, 512, false, true, true) + app = App("", modelLoader, 15, 1, 512, false, true, true) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -45,7 +46,7 @@ var _ = Describe("API test", func() { It("returns the models list", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(3)) + Expect(len(models.Models)).To(Equal(4)) Expect(models.Models[0].ID).To(Equal("testmodel")) }) It("can generate completions", func() { @@ -81,13 +82,23 @@ var _ = Describe("API test", func() { Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:")) }) - + PIt("transcribes audio", func() { + resp, err := client.CreateTranscription( + context.Background(), + openai.AudioRequest{ + Model: openai.Whisper1, + FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"), + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting")) + }) }) Context("Config file", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App(os.Getenv("CONFIG_FILE"), modelLoader, 1, 512, false, true, true) + app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -108,7 +119,7 @@ var _ = Describe("API test", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(5)) + Expect(len(models.Models)).To(Equal(6)) Expect(models.Models[0].ID).To(Equal("testmodel")) }) It("can generate chat completions from config file", func() { @@ -134,5 +145,6 @@ var _ = Describe("API test", func() { Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).ToNot(BeEmpty()) }) + }) }) diff --git a/api/openai.go b/api/openai.go index 1045507..7b65135 100644 --- a/api/openai.go +++ b/api/openai.go @@ -409,14 +409,13 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, // retrieve the file data from the request file, err := c.FormFile("file") if err != nil { - return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + return err } f, err := file.Open() if err != nil { - return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + return err } defer f.Close() - log.Debug().Msgf("Audio file: %+v", file) dir, err := os.MkdirTemp("", "whisper") @@ -428,26 +427,33 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, dst := filepath.Join(dir, path.Base(file.Filename)) dstFile, err := os.Create(dst) if err != nil { - return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + return err } if _, err := io.Copy(dstFile, f); err != nil { - log.Debug().Msgf("Audio file %+v - %+v - err %+v", file.Filename, dst, err) + log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) return err } log.Debug().Msgf("Audio file copied to: %+v", dst) - whisperModel, err := loader.BackendLoader("whisper", config.Model, []llama.ModelOption{}, uint32(config.Threads)) + whisperModel, err := loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads)) if err != nil { - return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + return err } - w := whisperModel.(whisper.Model) + if whisperModel == nil { + return fmt.Errorf("could not load whisper model") + } + + w, ok := whisperModel.(whisper.Model) + if !ok { + return fmt.Errorf("loader returned non-whisper object") + } - tr, err := whisperutil.Transcript(w, dst, input.Language) + tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) if err != nil { - return c.Status(http.StatusBadRequest).JSON(fiber.Map{"error": err.Error()}) + return err } log.Debug().Msgf("Trascribed: %+v", tr) diff --git a/main.go b/main.go index a57ac72..275fb31 100644 --- a/main.go +++ b/main.go @@ -62,6 +62,12 @@ func main() { EnvVars: []string{"CONTEXT_SIZE"}, Value: 512, }, + &cli.IntFlag{ + Name: "upload-limit", + DefaultText: "Default upload-limit. MB", + EnvVars: []string{"UPLOAD_LIMIT"}, + Value: 15, + }, }, Description: ` LocalAI is a drop-in replacement OpenAI API which runs inference locally. @@ -81,7 +87,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. Copyright: "go-skynet authors", Action: func(ctx *cli.Context) error { fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) - return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address")) + return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address")) }, } diff --git a/pkg/whisper/whisper.go b/pkg/whisper/whisper.go index ae84742..41c4587 100644 --- a/pkg/whisper/whisper.go +++ b/pkg/whisper/whisper.go @@ -28,7 +28,7 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string) (string, error) { +func Transcript(model whisper.Model, audiopath, language string, threads uint) (string, error) { dir, err := os.MkdirTemp("", "whisper") if err != nil { @@ -65,8 +65,12 @@ func Transcript(model whisper.Model, audiopath, language string) (string, error) } + context.SetThreads(threads) + if language != "" { context.SetLanguage(language) + } else { + context.SetLanguage("auto") } if err := context.Process(data, nil); err != nil {