diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5b8385c..5a0f502 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,6 +29,7 @@ jobs: sudo apt-get install -y ca-certificates cmake curl patch sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2 + sudo pip install -r extra/requirements.txt sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \ curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \ @@ -45,7 +46,6 @@ jobs: sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \ sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \ sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/ - - name: Test run: | ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test diff --git a/.gitignore b/.gitignore index 7b35ba9..a1c1c06 100644 --- a/.gitignore +++ b/.gitignore @@ -3,9 +3,10 @@ go-llama /gpt4all go-stable-diffusion go-piper +/go-bert go-ggllm /piper - +__pycache__/ *.a get-sources diff --git a/Dockerfile b/Dockerfile index c22f5dc..5e39303 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,10 +11,15 @@ ARG TARGETARCH ARG TARGETVARIANT ENV BUILD_TYPE=${BUILD_TYPE} +ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/extra/grpc/huggingface/huggingface.py" ARG GO_TAGS="stablediffusion tts" RUN apt-get update && \ - apt-get install -y ca-certificates cmake curl patch + apt-get install -y ca-certificates cmake curl patch pip + +# Extras requirements +COPY extra/requirements.txt /build/extra/requirements.txt +RUN pip install -r /build/extra/requirements.txt && rm -rf /build/extra/requirements.txt # CuBLAS requirements RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \ diff --git a/Makefile b/Makefile index 1145bcf..bb9df23 100644 --- a/Makefile +++ b/Makefile @@ -310,7 +310,7 @@ test: prepare test-models/testmodel grpcs @echo 'Running tests' export GO_TAGS="tts stablediffusion" $(MAKE) prepare-test - TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + HUGGINGFACE_GRPC=$(abspath ./)/extra/grpc/huggingface/huggingface.py TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg $(MAKE) test-gpt4all $(MAKE) test-llama @@ -335,9 +335,7 @@ test-stablediffusion: prepare-test test-container: docker build --target requirements -t local-ai-test-container . - docker run --name localai-tests -e GO_TAGS=$(GO_TAGS) -ti -v $(abspath ./):/build local-ai-test-container make test - docker rm localai-tests - docker rmi local-ai-test-container + docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container ## Help: help: ## Show this help. @@ -351,10 +349,15 @@ help: ## Show this help. else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \ }' $(MAKEFILE_LIST) -protogen: +protogen: protogen-go protogen-python + +protogen-go: protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \ pkg/grpc/proto/backend.proto +protogen-python: + python -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/huggingface/ --grpc_python_out=extra/grpc/huggingface/ pkg/grpc/proto/backend.proto + ## GRPC backend-assets/grpc: diff --git a/api/api_test.go b/api/api_test.go index 6970a8f..a602229 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -125,6 +125,11 @@ var _ = Describe("API test", func() { var cancel context.CancelFunc var tmpdir string + commonOpts := []options.AppOption{ + options.WithDebug(true), + options.WithDisableMessage(true), + } + Context("API with ephemeral models", func() { BeforeEach(func() { var err error @@ -143,7 +148,7 @@ var _ = Describe("API test", func() { Name: "bert2", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", Overrides: map[string]interface{}{"foo": "bar"}, - AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}}, + AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}}, }, } out, err := yaml.Marshal(g) @@ -159,9 +164,10 @@ var _ = Describe("API test", func() { } app, err = App( - options.WithContext(c), - options.WithGalleries(galleries), - options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir)) + append(commonOpts, + options.WithContext(c), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -400,13 +406,14 @@ var _ = Describe("API test", func() { } app, err = App( - options.WithContext(c), - options.WithAudioDir(tmpdir), - options.WithImageDir(tmpdir), - options.WithGalleries(galleries), - options.WithModelLoader(modelLoader), - options.WithBackendAssets(backendAssets), - options.WithBackendAssetsOutput(tmpdir), + append(commonOpts, + options.WithContext(c), + options.WithAudioDir(tmpdir), + options.WithImageDir(tmpdir), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), + options.WithBackendAssets(backendAssets), + options.WithBackendAssetsOutput(tmpdir))..., ) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -500,7 +507,12 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader)) + app, err = App( + append(commonOpts, + options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), + options.WithContext(c), + options.WithModelLoader(modelLoader), + )...) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -524,7 +536,7 @@ var _ = Describe("API test", func() { It("returns the models list", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(10)) + Expect(len(models.Models)).To(Equal(11)) }) It("can generate completions", func() { resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"}) @@ -555,7 +567,7 @@ var _ = Describe("API test", func() { }) It("returns errors", func() { - backends := len(model.AutoLoadBackends) + backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends))) @@ -602,6 +614,36 @@ var _ = Describe("API test", func() { Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) }) + Context("External gRPC calls", func() { + It("calculate embeddings with huggingface", func() { + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + resp, err := client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Model: openai.AdaCodeSearchCode, + Input: []string{"sun", "cat"}, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) + Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) + + sunEmbedding := resp.Data[0].Embedding + resp2, err := client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Model: openai.AdaCodeSearchCode, + Input: []string{"sun"}, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) + Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) + }) + }) + Context("backends", func() { It("runs rwkv completion", func() { if runtime.GOOS != "linux" { @@ -674,7 +716,12 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader), options.WithConfigFile(os.Getenv("CONFIG_FILE"))) + app, err = App( + append(commonOpts, + options.WithContext(c), + options.WithModelLoader(modelLoader), + options.WithConfigFile(os.Getenv("CONFIG_FILE")))..., + ) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -696,7 +743,7 @@ var _ = Describe("API test", func() { It("can generate chat completions from config file", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(12)) + Expect(len(models.Models)).To(Equal(13)) }) It("can generate chat completions from config file", func() { resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go index 0310347..53df785 100644 --- a/api/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -30,6 +30,10 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. model.WithContext(o.Context), } + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + if c.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) } else { diff --git a/api/backend/image.go b/api/backend/image.go index a631b3b..9e32d1d 100644 --- a/api/backend/image.go +++ b/api/backend/image.go @@ -15,12 +15,20 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat return nil, fmt.Errorf("endpoint only working with stablediffusion models") } - inferenceModel, err := loader.BackendLoader( + opts := []model.Option{ model.WithBackendString(c.Backend), model.WithAssetDir(o.AssetsDestination), model.WithThreads(uint32(c.Threads)), model.WithContext(o.Context), model.WithModelFile(c.ImageGenerationAssets), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + inferenceModel, err := loader.BackendLoader( + opts..., ) if err != nil { return nil, err diff --git a/api/backend/llm.go b/api/backend/llm.go index 8fcd6da..23a5ca4 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -1,14 +1,17 @@ package backend import ( + "os" "regexp" "strings" "sync" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/grpc" model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" ) func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { @@ -27,12 +30,32 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt model.WithContext(o.Context), } + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + if c.Backend != "" { + opts = append(opts, model.WithBackendString(c.Backend)) + } + + // Check if the modelFile exists, if it doesn't try to load it from the gallery + if o.AutoloadGalleries { // experimental + if _, err := os.Stat(modelFile); os.IsNotExist(err) { + utils.ResetDownloadTimers() + // if we failed to load the model, we try to download it + err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + if err != nil { + return nil, err + } + } + } + if c.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) } else { - opts = append(opts, model.WithBackendString(c.Backend)) inferenceModel, err = loader.BackendLoader(opts...) } + if err != nil { return nil, err } @@ -50,6 +73,9 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt return ss, err } else { reply, err := inferenceModel.Predict(o.Context, opts) + if err != nil { + return "", err + } return reply.Message, err } } diff --git a/api/backend/transcript.go b/api/backend/transcript.go new file mode 100644 index 0000000..b2f2501 --- /dev/null +++ b/api/backend/transcript.go @@ -0,0 +1,42 @@ +package backend + +import ( + "context" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) { + opts := []model.Option{ + model.WithBackendString(model.WhisperBackend), + model.WithModelFile(c.Model), + model.WithContext(o.Context), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.AssetsDestination), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + whisperModel, err := o.Loader.BackendLoader(opts...) + if err != nil { + return nil, err + } + + if whisperModel == nil { + return nil, fmt.Errorf("could not load whisper model") + } + + return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ + Dst: audio, + Language: language, + Threads: uint32(c.Threads), + }) +} diff --git a/api/backend/tts.go b/api/backend/tts.go new file mode 100644 index 0000000..ac491e2 --- /dev/null +++ b/api/backend/tts.go @@ -0,0 +1,72 @@ +package backend + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" +) + +func generateUniqueFileName(dir, baseName, ext string) string { + counter := 1 + fileName := baseName + ext + + for { + filePath := filepath.Join(dir, fileName) + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return fileName + } + + counter++ + fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) + } +} + +func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) { + opts := []model.Option{ + model.WithBackendString(model.PiperBackend), + model.WithModelFile(modelFile), + model.WithContext(o.Context), + model.WithAssetDir(o.AssetsDestination), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + piperModel, err := o.Loader.BackendLoader(opts...) + if err != nil { + return "", nil, err + } + + if piperModel == nil { + return "", nil, fmt.Errorf("could not load piper model") + } + + if err := os.MkdirAll(o.AudioDir, 0755); err != nil { + return "", nil, fmt.Errorf("failed creating audio directory: %s", err) + } + + fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") + filePath := filepath.Join(o.AudioDir, fileName) + + modelPath := filepath.Join(o.Loader.ModelPath, modelFile) + + if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { + return "", nil, err + } + + res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ + Text: text, + Model: modelPath, + Dst: filePath, + }) + + return filePath, res, err +} diff --git a/api/localai/gallery.go b/api/localai/gallery.go index feae294..ef4be14 100644 --- a/api/localai/gallery.go +++ b/api/localai/gallery.go @@ -4,13 +4,15 @@ import ( "context" "fmt" "os" + "strings" "sync" - "time" json "github.com/json-iterator/go" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" "github.com/google/uuid" "github.com/rs/zerolog/log" @@ -80,6 +82,8 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { case <-c.Done(): return case op := <-g.C: + utils.ResetDownloadTimers() + g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) // updates the status with an error @@ -90,13 +94,17 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { // displayDownload displays the download progress progressCallback := func(fileName string, current string, total string, percentage float64) { g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - displayDownload(fileName, current, total, percentage) + utils.DisplayDownloadFunction(fileName, current, total, percentage) } var err error // if the request contains a gallery name, we apply the gallery from the gallery list if op.galleryName != "" { - err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + if strings.Contains(op.galleryName, "@") { + err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + } else { + err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + } } else { err = prepareModel(g.modelPath, op.req, cm, progressCallback) } @@ -119,31 +127,6 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { }() } -var lastProgress time.Time = time.Now() -var startTime time.Time = time.Now() - -func displayDownload(fileName string, current string, total string, percentage float64) { - currentTime := time.Now() - - if currentTime.Sub(lastProgress) >= 5*time.Second { - - lastProgress = currentTime - - // calculate ETA based on percentage and elapsed time - var eta time.Duration - if percentage > 0 { - elapsed := currentTime.Sub(startTime) - eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed)) - } - - if total != "" { - log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta) - } else { - log.Debug().Msgf("Downloading: %s", current) - } - } -} - type galleryModel struct { gallery.GalleryModel ID string `json:"id"` @@ -165,10 +148,11 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler } for _, r := range requests { + utils.ResetDownloadTimers() if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload) + err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) } else { - err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload) + err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) } } diff --git a/api/localai/localai.go b/api/localai/localai.go index 7c57c92..49f7780 100644 --- a/api/localai/localai.go +++ b/api/localai/localai.go @@ -1,17 +1,10 @@ package localai import ( - "context" - "fmt" - "os" - "path/filepath" - + "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" ) @@ -20,22 +13,6 @@ type TTSRequest struct { Input string `json:"input" yaml:"input"` } -func generateUniqueFileName(dir, baseName, ext string) string { - counter := 1 - fileName := baseName + ext - - for { - filePath := filepath.Join(dir, fileName) - _, err := os.Stat(filePath) - if os.IsNotExist(err) { - return fileName - } - - counter++ - fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) - } -} - func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { @@ -45,40 +22,10 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return err } - piperModel, err := o.Loader.BackendLoader( - model.WithBackendString(model.PiperBackend), - model.WithModelFile(input.Model), - model.WithContext(o.Context), - model.WithAssetDir(o.AssetsDestination)) + filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o) if err != nil { return err } - - if piperModel == nil { - return fmt.Errorf("could not load piper model") - } - - if err := os.MkdirAll(o.AudioDir, 0755); err != nil { - return fmt.Errorf("failed creating audio directory: %s", err) - } - - fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") - filePath := filepath.Join(o.AudioDir, fileName) - - modelPath := filepath.Join(o.Loader.ModelPath, input.Model) - - if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { - return err - } - - if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ - Text: input.Input, - Model: modelPath, - Dst: filePath, - }); err != nil { - return err - } - return c.Download(filePath) } } diff --git a/api/openai/transcription.go b/api/openai/transcription.go index 346693c..4b4a65e 100644 --- a/api/openai/transcription.go +++ b/api/openai/transcription.go @@ -1,7 +1,6 @@ package openai import ( - "context" "fmt" "io" "net/http" @@ -9,10 +8,9 @@ import ( "path" "path/filepath" + "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -61,25 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe log.Debug().Msgf("Audio file copied to: %+v", dst) - whisperModel, err := o.Loader.BackendLoader( - model.WithBackendString(model.WhisperBackend), - model.WithModelFile(config.Model), - model.WithContext(o.Context), - model.WithThreads(uint32(config.Threads)), - model.WithAssetDir(o.AssetsDestination)) - if err != nil { - return err - } - - if whisperModel == nil { - return fmt.Errorf("could not load whisper model") - } - - tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ - Dst: dst, - Language: input.Language, - Threads: uint32(config.Threads), - }) + tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o) if err != nil { return err } diff --git a/api/options/options.go b/api/options/options.go index 06029b0..b326947 100644 --- a/api/options/options.go +++ b/api/options/options.go @@ -28,6 +28,10 @@ type Option struct { BackendAssets embed.FS AssetsDestination string + + ExternalGRPCBackends map[string]string + + AutoloadGalleries bool } type AppOption func(*Option) @@ -53,6 +57,19 @@ func WithCors(b bool) AppOption { } } +var EnableGalleriesAutoload = func(o *Option) { + o.AutoloadGalleries = true +} + +func WithExternalBackend(name string, uri string) AppOption { + return func(o *Option) { + if o.ExternalGRPCBackends == nil { + o.ExternalGRPCBackends = make(map[string]string) + } + o.ExternalGRPCBackends[name] = uri + } +} + func WithCorsAllowOrigins(b string) AppOption { return func(o *Option) { o.CORSAllowOrigins = b diff --git a/extra/grpc/huggingface/backend_pb2.py b/extra/grpc/huggingface/backend_pb2.py new file mode 100644 index 0000000..0dafdf5 --- /dev/null +++ b/extra/grpc/huggingface/backend_pb2.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: backend.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\xa4\x05\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t\"\xac\x02\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto' + _globals['_HEALTHMESSAGE']._serialized_start=26 + _globals['_HEALTHMESSAGE']._serialized_end=41 + _globals['_PREDICTOPTIONS']._serialized_start=44 + _globals['_PREDICTOPTIONS']._serialized_end=720 + _globals['_REPLY']._serialized_start=722 + _globals['_REPLY']._serialized_end=746 + _globals['_MODELOPTIONS']._serialized_start=749 + _globals['_MODELOPTIONS']._serialized_end=1049 + _globals['_RESULT']._serialized_start=1051 + _globals['_RESULT']._serialized_end=1093 + _globals['_EMBEDDINGRESULT']._serialized_start=1095 + _globals['_EMBEDDINGRESULT']._serialized_end=1132 + _globals['_TRANSCRIPTREQUEST']._serialized_start=1134 + _globals['_TRANSCRIPTREQUEST']._serialized_end=1201 + _globals['_TRANSCRIPTRESULT']._serialized_start=1203 + _globals['_TRANSCRIPTRESULT']._serialized_end=1281 + _globals['_TRANSCRIPTSEGMENT']._serialized_start=1283 + _globals['_TRANSCRIPTSEGMENT']._serialized_end=1372 + _globals['_GENERATEIMAGEREQUEST']._serialized_start=1375 + _globals['_GENERATEIMAGEREQUEST']._serialized_end=1533 + _globals['_TTSREQUEST']._serialized_start=1535 + _globals['_TTSREQUEST']._serialized_end=1589 + _globals['_BACKEND']._serialized_start=1592 + _globals['_BACKEND']._serialized_end=2083 +# @@protoc_insertion_point(module_scope) diff --git a/extra/grpc/huggingface/backend_pb2_grpc.py b/extra/grpc/huggingface/backend_pb2_grpc.py new file mode 100644 index 0000000..301c072 --- /dev/null +++ b/extra/grpc/huggingface/backend_pb2_grpc.py @@ -0,0 +1,297 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import backend_pb2 as backend__pb2 + + +class BackendStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Health = channel.unary_unary( + '/backend.Backend/Health', + request_serializer=backend__pb2.HealthMessage.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Predict = channel.unary_unary( + '/backend.Backend/Predict', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.LoadModel = channel.unary_unary( + '/backend.Backend/LoadModel', + request_serializer=backend__pb2.ModelOptions.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.PredictStream = channel.unary_stream( + '/backend.Backend/PredictStream', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Embedding = channel.unary_unary( + '/backend.Backend/Embedding', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.EmbeddingResult.FromString, + ) + self.GenerateImage = channel.unary_unary( + '/backend.Backend/GenerateImage', + request_serializer=backend__pb2.GenerateImageRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.AudioTranscription = channel.unary_unary( + '/backend.Backend/AudioTranscription', + request_serializer=backend__pb2.TranscriptRequest.SerializeToString, + response_deserializer=backend__pb2.TranscriptResult.FromString, + ) + self.TTS = channel.unary_unary( + '/backend.Backend/TTS', + request_serializer=backend__pb2.TTSRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + + +class BackendServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Health(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def LoadModel(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PredictStream(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Embedding(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GenerateImage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def AudioTranscription(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TTS(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_BackendServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Health': grpc.unary_unary_rpc_method_handler( + servicer.Health, + request_deserializer=backend__pb2.HealthMessage.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Predict': grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'LoadModel': grpc.unary_unary_rpc_method_handler( + servicer.LoadModel, + request_deserializer=backend__pb2.ModelOptions.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'PredictStream': grpc.unary_stream_rpc_method_handler( + servicer.PredictStream, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Embedding': grpc.unary_unary_rpc_method_handler( + servicer.Embedding, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.EmbeddingResult.SerializeToString, + ), + 'GenerateImage': grpc.unary_unary_rpc_method_handler( + servicer.GenerateImage, + request_deserializer=backend__pb2.GenerateImageRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'AudioTranscription': grpc.unary_unary_rpc_method_handler( + servicer.AudioTranscription, + request_deserializer=backend__pb2.TranscriptRequest.FromString, + response_serializer=backend__pb2.TranscriptResult.SerializeToString, + ), + 'TTS': grpc.unary_unary_rpc_method_handler( + servicer.TTS, + request_deserializer=backend__pb2.TTSRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'backend.Backend', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Backend(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Health(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health', + backend__pb2.HealthMessage.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Predict(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def LoadModel(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel', + backend__pb2.ModelOptions.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PredictStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Embedding(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.EmbeddingResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GenerateImage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage', + backend__pb2.GenerateImageRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def AudioTranscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription', + backend__pb2.TranscriptRequest.SerializeToString, + backend__pb2.TranscriptResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def TTS(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS', + backend__pb2.TTSRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/extra/grpc/huggingface/huggingface.py b/extra/grpc/huggingface/huggingface.py new file mode 100755 index 0000000..adf9876 --- /dev/null +++ b/extra/grpc/huggingface/huggingface.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +import grpc +from concurrent import futures +import time +import backend_pb2 +import backend_pb2_grpc +import argparse +import signal +import sys +import os +from sentence_transformers import SentenceTransformer + +_ONE_DAY_IN_SECONDS = 60 * 60 * 24 + +# Implement the BackendServicer class with the service methods +class BackendServicer(backend_pb2_grpc.BackendServicer): + def Health(self, request, context): + return backend_pb2.Reply(message="OK") + def LoadModel(self, request, context): + model_name = request.Model + model_name = os.path.basename(model_name) + try: + self.model = SentenceTransformer(model_name) + except Exception as err: + return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") + # Implement your logic here for the LoadModel service + # Replace this with your desired response + return backend_pb2.Result(message="Model loaded successfully", success=True) + def Embedding(self, request, context): + # Implement your logic here for the Embedding service + # Replace this with your desired response + print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr) + sentence_embeddings = self.model.encode(request.Embeddings) + return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings) + + +def serve(address): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) + server.add_insecure_port(address) + server.start() + print("Server started. Listening on: " + address, file=sys.stderr) + + # Define the signal handler function + def signal_handler(sig, frame): + print("Received termination signal. Shutting down...") + server.stop(0) + sys.exit(0) + + # Set the signal handlers for SIGINT and SIGTERM + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + try: + while True: + time.sleep(_ONE_DAY_IN_SECONDS) + except KeyboardInterrupt: + server.stop(0) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run the gRPC server.") + parser.add_argument( + "--addr", default="localhost:50051", help="The address to bind the server to." + ) + args = parser.parse_args() + + serve(args.addr) \ No newline at end of file diff --git a/extra/requirements.txt b/extra/requirements.txt new file mode 100644 index 0000000..9744afb --- /dev/null +++ b/extra/requirements.txt @@ -0,0 +1,4 @@ +sentence_transformers +grpcio +google +protobuf \ No newline at end of file diff --git a/main.go b/main.go index 3f534b0..2cb8627 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" api "github.com/go-skynet/LocalAI/api" @@ -40,6 +41,10 @@ func main() { Name: "f16", EnvVars: []string{"F16"}, }, + &cli.BoolFlag{ + Name: "autoload-galleries", + EnvVars: []string{"AUTOLOAD_GALLERIES"}, + }, &cli.BoolFlag{ Name: "debug", EnvVars: []string{"DEBUG"}, @@ -108,6 +113,11 @@ func main() { EnvVars: []string{"BACKEND_ASSETS_PATH"}, Value: "/tmp/localai/backend_data", }, + &cli.StringSliceFlag{ + Name: "external-grpc-backends", + Usage: "A list of external grpc backends", + EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"}, + }, &cli.IntFlag{ Name: "context-size", Usage: "Default context size of the model", @@ -138,7 +148,8 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit UsageText: `local-ai [options]`, Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { - app, err := api.App( + + opts := []options.AppOption{ options.WithConfigFile(ctx.String("config-file")), options.WithJSONStringPreload(ctx.String("preload-models")), options.WithYAMLConfigPreload(ctx.String("preload-models-config")), @@ -155,7 +166,22 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithThreads(ctx.Int("threads")), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - options.WithUploadLimitMB(ctx.Int("upload-limit"))) + options.WithUploadLimitMB(ctx.Int("upload-limit")), + } + + externalgRPC := ctx.StringSlice("external-grpc-backends") + // split ":" to get backend name and the uri + for _, v := range externalgRPC { + backend := v[:strings.IndexByte(v, ':')] + uri := v[strings.IndexByte(v, ':')+1:] + opts = append(opts, options.WithExternalBackend(backend, uri)) + } + + if ctx.Bool("autoload-galleries") { + opts = append(opts, options.EnableGalleriesAutoload) + } + + app, err := api.App(opts...) if err != nil { return err } diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index 8e08592..6fe05ed 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -18,23 +18,15 @@ type Gallery struct { // Installs a model from the gallery (galleryname@modelname) func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { - - // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. - name = strings.ReplaceAll(name, string(os.PathSeparator), "__") - - models, err := AvailableGalleryModels(galleries, basePath) - if err != nil { - return err - } - applyModel := func(model *GalleryModel) error { config, err := GetGalleryConfigFromURL(model.URL) if err != nil { return err } + installName := model.Name if req.Name != "" { - model.Name = req.Name + installName = req.Name } config.Files = append(config.Files, req.AdditionalFiles...) @@ -45,20 +37,58 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, return err } - if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { + if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil { return err } return nil } + models, err := AvailableGalleryModels(galleries, basePath) + if err != nil { + return err + } + + model, err := FindGallery(models, name) + if err != nil { + return err + } + + return applyModel(model) +} + +func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) { + // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. + name = strings.ReplaceAll(name, string(os.PathSeparator), "__") + for _, model := range models { if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) { - return applyModel(model) + return model, nil } } + return nil, fmt.Errorf("no gallery found with name %q", name) +} + +// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins) +func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { + models, err := AvailableGalleryModels(galleries, basePath) + if err != nil { + return err + } + + name = strings.ReplaceAll(name, string(os.PathSeparator), "__") + var model *GalleryModel + for _, m := range models { + if name == m.Name { + model = m + } + } + + if model == nil { + return fmt.Errorf("no model found with name %q", name) + } - return fmt.Errorf("no model found with name %q", name) + return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus) } // List available models diff --git a/pkg/grpc/tts/piper.go b/pkg/grpc/tts/piper.go index dbaa4b7..3bc85e0 100644 --- a/pkg/grpc/tts/piper.go +++ b/pkg/grpc/tts/piper.go @@ -3,7 +3,9 @@ package tts // This is a wrapper to statisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( + "fmt" "os" + "path/filepath" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" @@ -16,6 +18,9 @@ type Piper struct { } func (sd *Piper) Load(opts *pb.ModelOptions) error { + if filepath.Ext(opts.Model) != ".onnx" { + return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.Model) + } var err error // Note: the Model here is a path to a directory containing the model files sd.piper, err = New(opts.LibrarySearchPath) diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 53fc684..08bf6c4 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -19,8 +19,6 @@ import ( process "github.com/mudler/go-processmanager" ) -const tokenizerSuffix = ".tokenizer.json" - const ( LlamaBackend = "llama" BloomzBackend = "bloomz" @@ -45,7 +43,6 @@ const ( StableDiffusionBackend = "stablediffusion" PiperBackend = "piper" LCHuggingFaceBackend = "langchain-huggingface" - //GGLLMFalconBackend = "falcon" ) var AutoLoadBackends []string = []string{ @@ -62,6 +59,11 @@ var AutoLoadBackends []string = []string{ MPTBackend, ReplitBackend, StarcoderBackend, + BloomzBackend, + RwkvBackend, + WhisperBackend, + StableDiffusionBackend, + PiperBackend, } func (ml *ModelLoader) StopGRPC() { @@ -70,75 +72,116 @@ func (ml *ModelLoader) StopGRPC() { } } -// starts the grpcModelProcess for the backend, and returns a grpc client -// It also loads the model -func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) { - return func(s string) (*grpc.Client, error) { - log.Debug().Msgf("Loading GRPC Model", backend, *o) +func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error { + // Make sure the process is executable + if err := os.Chmod(grpcProcess, 0755); err != nil { + return err + } - grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) + log.Debug().Msgf("Loading GRPC Process", grpcProcess) - // Check if the file exists - if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { - return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) - } + log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress) - // Make sure the process is executable - if err := os.Chmod(grpcProcess, 0755); err != nil { - return nil, err - } + grpcControlProcess := process.New( + process.WithTemporaryStateDir(), + process.WithName(grpcProcess), + process.WithArgs("--addr", serverAddress)) + + ml.grpcProcesses[id] = grpcControlProcess + + if err := grpcControlProcess.Run(); err != nil { + return err + } - log.Debug().Msgf("Loading GRPC Process", grpcProcess) - port, err := freeport.GetFreePort() + log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir()) + // clean up process + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + grpcControlProcess.Stop() + }() + + go func() { + t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) if err != nil { - return nil, err + log.Debug().Msgf("Could not tail stderr") } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text) + } + }() + go func() { + t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stdout") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text) + } + }() - serverAddress := fmt.Sprintf("localhost:%d", port) - - log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress) + return nil +} - grpcControlProcess := process.New( - process.WithTemporaryStateDir(), - process.WithName(grpcProcess), - process.WithArgs("--addr", serverAddress)) +// starts the grpcModelProcess for the backend, and returns a grpc client +// It also loads the model +func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) { + return func(s string) (*grpc.Client, error) { + log.Debug().Msgf("Loading GRPC Model", backend, *o) - ml.grpcProcesses[o.modelFile] = grpcControlProcess + var client *grpc.Client - if err := grpcControlProcess.Run(); err != nil { - return nil, err + getFreeAddress := func() (string, error) { + port, err := freeport.GetFreePort() + if err != nil { + return "", fmt.Errorf("failed allocating free ports: %s", err.Error()) + } + return fmt.Sprintf("127.0.0.1:%d", port), nil } - // clean up process - go func() { - c := make(chan os.Signal, 1) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - <-c - grpcControlProcess.Stop() - }() - - go func() { - t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) - if err != nil { - log.Debug().Msgf("Could not tail stderr") + // Check if the backend is provided as external + if uri, ok := o.externalBackends[backend]; ok { + log.Debug().Msgf("Loading external backend: %s", uri) + // check if uri is a file or a address + if _, err := os.Stat(uri); err == nil { + serverAddress, err := getFreeAddress() + if err != nil { + return nil, fmt.Errorf("failed allocating free ports: %s", err.Error()) + } + // Make sure the process is executable + if err := ml.startProcess(uri, o.modelFile, serverAddress); err != nil { + return nil, err + } + + log.Debug().Msgf("GRPC Service Started") + + client = grpc.NewClient(serverAddress) + } else { + // address + client = grpc.NewClient(uri) } - for line := range t.Lines { - log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + } else { + grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) + // Check if the file exists + if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { + return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) } - }() - go func() { - t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) + + serverAddress, err := getFreeAddress() if err != nil { - log.Debug().Msgf("Could not tail stdout") + return nil, fmt.Errorf("failed allocating free ports: %s", err.Error()) } - for line := range t.Lines { - log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + + // Make sure the process is executable + if err := ml.startProcess(grpcProcess, o.modelFile, serverAddress); err != nil { + return nil, err } - }() - log.Debug().Msgf("GRPC Service Started") + log.Debug().Msgf("GRPC Service Started") - client := grpc.NewClient(serverAddress) + client = grpc.NewClient(serverAddress) + } // Wait for the service to start up ready := false @@ -153,11 +196,6 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc if !ready { log.Debug().Msgf("GRPC Service NOT ready") - log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive()) - log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:")) - - log.Debug().Msgf(grpcControlProcess.ExitCode()) - return nil, fmt.Errorf("grpc service not ready") } @@ -168,10 +206,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc res, err := client.LoadModel(o.context, &options) if err != nil { - return nil, err + return nil, fmt.Errorf("could not load model: %w", err) } if !res.Success { - return nil, fmt.Errorf("could not load model: %s", res.Message) + return nil, fmt.Errorf("could not load model (no success): %s", res.Message) } return client, nil @@ -184,6 +222,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) backend := strings.ToLower(o.backendString) + + // if an external backend is provided, use it + _, externalBackendExists := o.externalBackends[backend] + if externalBackendExists { + return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o)) + } + switch backend { case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend, MPTBackend, Gpt2Backend, FalconBackend, @@ -204,8 +249,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { o := NewOptions(opts...) - log.Debug().Msgf("Loading model '%s' greedly", o.modelFile) - // Is this really needed? BackendLoader already does this ml.mu.Lock() if m := ml.checkIsLoaded(o.modelFile); m != nil { @@ -216,16 +259,29 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { ml.mu.Unlock() var err error - for _, b := range AutoLoadBackends { - log.Debug().Msgf("[%s] Attempting to load", b) + // autoload also external backends + allBackendsToAutoLoad := []string{} + allBackendsToAutoLoad = append(allBackendsToAutoLoad, AutoLoadBackends...) + for _, b := range o.externalBackends { + allBackendsToAutoLoad = append(allBackendsToAutoLoad, b) + } + log.Debug().Msgf("Loading model '%s' greedly from all the available backends: %s", o.modelFile, strings.Join(allBackendsToAutoLoad, ", ")) - model, modelerr := ml.BackendLoader( + for _, b := range allBackendsToAutoLoad { + log.Debug().Msgf("[%s] Attempting to load", b) + options := []Option{ WithBackendString(b), WithModelFile(o.modelFile), WithLoadGRPCLLMModelOpts(o.gRPCOptions), WithThreads(o.threads), WithAssetDir(o.assetDir), - ) + } + + for k, v := range o.externalBackends { + options = append(options, WithExternalBackend(k, v)) + } + + model, modelerr := ml.BackendLoader(options...) if modelerr == nil && model != nil { log.Debug().Msgf("[%s] Loads OK", b) return model, nil @@ -233,7 +289,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { err = multierror.Append(err, modelerr) log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error()) } else if model == nil { - err = multierror.Append(err, modelerr) + err = multierror.Append(err, fmt.Errorf("backend returned no usable model")) log.Debug().Msgf("[%s] Fails: %s", b, "backend returned no usable model") } } diff --git a/pkg/model/options.go b/pkg/model/options.go index 298ebd4..466e9c2 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -14,10 +14,21 @@ type Options struct { context context.Context gRPCOptions *pb.ModelOptions + + externalBackends map[string]string } type Option func(*Options) +func WithExternalBackend(name string, uri string) Option { + return func(o *Options) { + if o.externalBackends == nil { + o.externalBackends = make(map[string]string) + } + o.externalBackends[name] = uri + } +} + func WithBackendString(backend string) Option { return func(o *Options) { o.backendString = backend diff --git a/pkg/utils/logging.go b/pkg/utils/logging.go new file mode 100644 index 0000000..d69cbf8 --- /dev/null +++ b/pkg/utils/logging.go @@ -0,0 +1,37 @@ +package utils + +import ( + "time" + + "github.com/rs/zerolog/log" +) + +var lastProgress time.Time = time.Now() +var startTime time.Time = time.Now() + +func ResetDownloadTimers() { + lastProgress = time.Now() + startTime = time.Now() +} + +func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) { + currentTime := time.Now() + + if currentTime.Sub(lastProgress) >= 5*time.Second { + + lastProgress = currentTime + + // calculate ETA based on percentage and elapsed time + var eta time.Duration + if percentage > 0 { + elapsed := currentTime.Sub(startTime) + eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed)) + } + + if total != "" { + log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta) + } else { + log.Debug().Msgf("Downloading: %s", current) + } + } +} diff --git a/tests/models_fixtures/grpc.yaml b/tests/models_fixtures/grpc.yaml new file mode 100644 index 0000000..31c406a --- /dev/null +++ b/tests/models_fixtures/grpc.yaml @@ -0,0 +1,5 @@ +name: code-search-ada-code-001 +backend: huggingface +embeddings: true +parameters: + model: all-MiniLM-L6-v2 \ No newline at end of file