feat: backends improvements (#778)

renovate/github.com-nomic-ai-gpt4all-gpt4all-bindings-golang-digest
Ettore Di Giacinto 1 year ago committed by GitHub
commit 0eac0402e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      .github/workflows/test.yml
  2. 3
      .gitignore
  3. 7
      Dockerfile
  4. 13
      Makefile
  5. 79
      api/api_test.go
  6. 4
      api/backend/embeddings.go
  7. 10
      api/backend/image.go
  8. 28
      api/backend/llm.go
  9. 42
      api/backend/transcript.go
  10. 72
      api/backend/tts.go
  11. 44
      api/localai/gallery.go
  12. 57
      api/localai/localai.go
  13. 24
      api/openai/transcription.go
  14. 17
      api/options/options.go
  15. 49
      extra/grpc/huggingface/backend_pb2.py
  16. 297
      extra/grpc/huggingface/backend_pb2_grpc.py
  17. 67
      extra/grpc/huggingface/huggingface.py
  18. 4
      extra/requirements.txt
  19. 30
      main.go
  20. 56
      pkg/gallery/gallery.go
  21. 5
      pkg/grpc/tts/piper.go
  22. 190
      pkg/model/initializers.go
  23. 11
      pkg/model/options.go
  24. 37
      pkg/utils/logging.go
  25. 5
      tests/models_fixtures/grpc.yaml

@ -29,6 +29,7 @@ jobs:
sudo apt-get install -y ca-certificates cmake curl patch sudo apt-get install -y ca-certificates cmake curl patch
sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2 sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
sudo pip install -r extra/requirements.txt
sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \ sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \
curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \ curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \
@ -45,7 +46,6 @@ jobs:
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \ sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \ sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \
sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/ sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/
- name: Test - name: Test
run: | run: |
ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test

3
.gitignore vendored

@ -3,9 +3,10 @@ go-llama
/gpt4all /gpt4all
go-stable-diffusion go-stable-diffusion
go-piper go-piper
/go-bert
go-ggllm go-ggllm
/piper /piper
__pycache__/
*.a *.a
get-sources get-sources

@ -11,10 +11,15 @@ ARG TARGETARCH
ARG TARGETVARIANT ARG TARGETVARIANT
ENV BUILD_TYPE=${BUILD_TYPE} ENV BUILD_TYPE=${BUILD_TYPE}
ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/extra/grpc/huggingface/huggingface.py"
ARG GO_TAGS="stablediffusion tts" ARG GO_TAGS="stablediffusion tts"
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y ca-certificates cmake curl patch apt-get install -y ca-certificates cmake curl patch pip
# Extras requirements
COPY extra/requirements.txt /build/extra/requirements.txt
RUN pip install -r /build/extra/requirements.txt && rm -rf /build/extra/requirements.txt
# CuBLAS requirements # CuBLAS requirements
RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \ RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \

@ -310,7 +310,7 @@ test: prepare test-models/testmodel grpcs
@echo 'Running tests' @echo 'Running tests'
export GO_TAGS="tts stablediffusion" export GO_TAGS="tts stablediffusion"
$(MAKE) prepare-test $(MAKE) prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ HUGGINGFACE_GRPC=$(abspath ./)/extra/grpc/huggingface/huggingface.py TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg
$(MAKE) test-gpt4all $(MAKE) test-gpt4all
$(MAKE) test-llama $(MAKE) test-llama
@ -335,9 +335,7 @@ test-stablediffusion: prepare-test
test-container: test-container:
docker build --target requirements -t local-ai-test-container . docker build --target requirements -t local-ai-test-container .
docker run --name localai-tests -e GO_TAGS=$(GO_TAGS) -ti -v $(abspath ./):/build local-ai-test-container make test docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container
docker rm localai-tests
docker rmi local-ai-test-container
## Help: ## Help:
help: ## Show this help. help: ## Show this help.
@ -351,10 +349,15 @@ help: ## Show this help.
else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \ else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \
}' $(MAKEFILE_LIST) }' $(MAKEFILE_LIST)
protogen: protogen: protogen-go protogen-python
protogen-go:
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \ protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \
pkg/grpc/proto/backend.proto pkg/grpc/proto/backend.proto
protogen-python:
python -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/huggingface/ --grpc_python_out=extra/grpc/huggingface/ pkg/grpc/proto/backend.proto
## GRPC ## GRPC
backend-assets/grpc: backend-assets/grpc:

@ -125,6 +125,11 @@ var _ = Describe("API test", func() {
var cancel context.CancelFunc var cancel context.CancelFunc
var tmpdir string var tmpdir string
commonOpts := []options.AppOption{
options.WithDebug(true),
options.WithDisableMessage(true),
}
Context("API with ephemeral models", func() { Context("API with ephemeral models", func() {
BeforeEach(func() { BeforeEach(func() {
var err error var err error
@ -143,7 +148,7 @@ var _ = Describe("API test", func() {
Name: "bert2", Name: "bert2",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
Overrides: map[string]interface{}{"foo": "bar"}, Overrides: map[string]interface{}{"foo": "bar"},
AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}}, AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
}, },
} }
out, err := yaml.Marshal(g) out, err := yaml.Marshal(g)
@ -159,9 +164,10 @@ var _ = Describe("API test", func() {
} }
app, err = App( app, err = App(
options.WithContext(c), append(commonOpts,
options.WithGalleries(galleries), options.WithContext(c),
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir)) options.WithGalleries(galleries),
options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
@ -400,13 +406,14 @@ var _ = Describe("API test", func() {
} }
app, err = App( app, err = App(
options.WithContext(c), append(commonOpts,
options.WithAudioDir(tmpdir), options.WithContext(c),
options.WithImageDir(tmpdir), options.WithAudioDir(tmpdir),
options.WithGalleries(galleries), options.WithImageDir(tmpdir),
options.WithModelLoader(modelLoader), options.WithGalleries(galleries),
options.WithBackendAssets(backendAssets), options.WithModelLoader(modelLoader),
options.WithBackendAssetsOutput(tmpdir), options.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(tmpdir))...,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
@ -500,7 +507,12 @@ var _ = Describe("API test", func() {
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
var err error var err error
app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader)) app, err = App(
append(commonOpts,
options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
options.WithContext(c),
options.WithModelLoader(modelLoader),
)...)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
@ -524,7 +536,7 @@ var _ = Describe("API test", func() {
It("returns the models list", func() { It("returns the models list", func() {
models, err := client.ListModels(context.TODO()) models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(10)) Expect(len(models.Models)).To(Equal(11))
}) })
It("can generate completions", func() { It("can generate completions", func() {
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"}) resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"})
@ -555,7 +567,7 @@ var _ = Describe("API test", func() {
}) })
It("returns errors", func() { It("returns errors", func() {
backends := len(model.AutoLoadBackends) backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends))) Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends)))
@ -602,6 +614,36 @@ var _ = Describe("API test", func() {
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
}) })
Context("External gRPC calls", func() {
It("calculate embeddings with huggingface", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun", "cat"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
sunEmbedding := resp.Data[0].Embedding
resp2, err := client.CreateEmbeddings(
context.Background(),
openai.EmbeddingRequest{
Model: openai.AdaCodeSearchCode,
Input: []string{"sun"},
},
)
Expect(err).ToNot(HaveOccurred())
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
})
})
Context("backends", func() { Context("backends", func() {
It("runs rwkv completion", func() { It("runs rwkv completion", func() {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
@ -674,7 +716,12 @@ var _ = Describe("API test", func() {
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
var err error var err error
app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader), options.WithConfigFile(os.Getenv("CONFIG_FILE"))) app, err = App(
append(commonOpts,
options.WithContext(c),
options.WithModelLoader(modelLoader),
options.WithConfigFile(os.Getenv("CONFIG_FILE")))...,
)
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
@ -696,7 +743,7 @@ var _ = Describe("API test", func() {
It("can generate chat completions from config file", func() { It("can generate chat completions from config file", func() {
models, err := client.ListModels(context.TODO()) models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(12)) Expect(len(models.Models)).To(Equal(13))
}) })
It("can generate chat completions from config file", func() { It("can generate chat completions from config file", func() {
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}})

@ -30,6 +30,10 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
model.WithContext(o.Context), model.WithContext(o.Context),
} }
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
if c.Backend == "" { if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...) inferenceModel, err = loader.GreedyLoader(opts...)
} else { } else {

@ -15,12 +15,20 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
return nil, fmt.Errorf("endpoint only working with stablediffusion models") return nil, fmt.Errorf("endpoint only working with stablediffusion models")
} }
inferenceModel, err := loader.BackendLoader( opts := []model.Option{
model.WithBackendString(c.Backend), model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)), model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context), model.WithContext(o.Context),
model.WithModelFile(c.ImageGenerationAssets), model.WithModelFile(c.ImageGenerationAssets),
}
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
inferenceModel, err := loader.BackendLoader(
opts...,
) )
if err != nil { if err != nil {
return nil, err return nil, err

@ -1,14 +1,17 @@
package backend package backend
import ( import (
"os"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
) )
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
@ -27,12 +30,32 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
model.WithContext(o.Context), model.WithContext(o.Context),
} }
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
}
}
}
if c.Backend == "" { if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...) inferenceModel, err = loader.GreedyLoader(opts...)
} else { } else {
opts = append(opts, model.WithBackendString(c.Backend))
inferenceModel, err = loader.BackendLoader(opts...) inferenceModel, err = loader.BackendLoader(opts...)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -50,6 +73,9 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
return ss, err return ss, err
} else { } else {
reply, err := inferenceModel.Predict(o.Context, opts) reply, err := inferenceModel.Predict(o.Context, opts)
if err != nil {
return "", err
}
return reply.Message, err return reply.Message, err
} }
} }

@ -0,0 +1,42 @@
package backend
import (
"context"
"fmt"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
model "github.com/go-skynet/LocalAI/pkg/model"
)
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
opts := []model.Option{
model.WithBackendString(model.WhisperBackend),
model.WithModelFile(c.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination),
}
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
whisperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
return nil, err
}
if whisperModel == nil {
return nil, fmt.Errorf("could not load whisper model")
}
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Threads: uint32(c.Threads),
})
}

@ -0,0 +1,72 @@
package backend
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) {
opts := []model.Option{
model.WithBackendString(model.PiperBackend),
model.WithModelFile(modelFile),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination),
}
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
piperModel, err := o.Loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}
if piperModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName)
modelPath := filepath.Join(o.Loader.ModelPath, modelFile)
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
return "", nil, err
}
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
Text: text,
Model: modelPath,
Dst: filePath,
})
return filePath, res, err
}

@ -4,13 +4,15 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"strings"
"sync" "sync"
"time"
json "github.com/json-iterator/go" json "github.com/json-iterator/go"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -80,6 +82,8 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
case <-c.Done(): case <-c.Done():
return return
case op := <-g.C: case op := <-g.C:
utils.ResetDownloadTimers()
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error // updates the status with an error
@ -90,13 +94,17 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
// displayDownload displays the download progress // displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) { progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage) utils.DisplayDownloadFunction(fileName, current, total, percentage)
} }
var err error var err error
// if the request contains a gallery name, we apply the gallery from the gallery list // if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" { if op.galleryName != "" {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) if strings.Contains(op.galleryName, "@") {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
}
} else { } else {
err = prepareModel(g.modelPath, op.req, cm, progressCallback) err = prepareModel(g.modelPath, op.req, cm, progressCallback)
} }
@ -119,31 +127,6 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
}() }()
} }
var lastProgress time.Time = time.Now()
var startTime time.Time = time.Now()
func displayDownload(fileName string, current string, total string, percentage float64) {
currentTime := time.Now()
if currentTime.Sub(lastProgress) >= 5*time.Second {
lastProgress = currentTime
// calculate ETA based on percentage and elapsed time
var eta time.Duration
if percentage > 0 {
elapsed := currentTime.Sub(startTime)
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
}
if total != "" {
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
} else {
log.Debug().Msgf("Downloading: %s", current)
}
}
}
type galleryModel struct { type galleryModel struct {
gallery.GalleryModel gallery.GalleryModel
ID string `json:"id"` ID string `json:"id"`
@ -165,10 +148,11 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler
} }
for _, r := range requests { for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" { if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload) err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else { } else {
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload) err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} }
} }

@ -1,17 +1,10 @@
package localai package localai
import ( import (
"context" "github.com/go-skynet/LocalAI/api/backend"
"fmt"
"os"
"path/filepath"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@ -20,22 +13,6 @@ type TTSRequest struct {
Input string `json:"input" yaml:"input"` Input string `json:"input" yaml:"input"`
} }
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
@ -45,40 +22,10 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return err return err
} }
piperModel, err := o.Loader.BackendLoader( filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o)
model.WithBackendString(model.PiperBackend),
model.WithModelFile(input.Model),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination))
if err != nil { if err != nil {
return err return err
} }
if piperModel == nil {
return fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
return fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName)
modelPath := filepath.Join(o.Loader.ModelPath, input.Model)
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
return err
}
if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
Text: input.Input,
Model: modelPath,
Dst: filePath,
}); err != nil {
return err
}
return c.Download(filePath) return c.Download(filePath)
} }
} }

@ -1,7 +1,6 @@
package openai package openai
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -9,10 +8,9 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -61,25 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
log.Debug().Msgf("Audio file copied to: %+v", dst) log.Debug().Msgf("Audio file copied to: %+v", dst)
whisperModel, err := o.Loader.BackendLoader( tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
model.WithBackendString(model.WhisperBackend),
model.WithModelFile(config.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(config.Threads)),
model.WithAssetDir(o.AssetsDestination))
if err != nil {
return err
}
if whisperModel == nil {
return fmt.Errorf("could not load whisper model")
}
tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: dst,
Language: input.Language,
Threads: uint32(config.Threads),
})
if err != nil { if err != nil {
return err return err
} }

@ -28,6 +28,10 @@ type Option struct {
BackendAssets embed.FS BackendAssets embed.FS
AssetsDestination string AssetsDestination string
ExternalGRPCBackends map[string]string
AutoloadGalleries bool
} }
type AppOption func(*Option) type AppOption func(*Option)
@ -53,6 +57,19 @@ func WithCors(b bool) AppOption {
} }
} }
var EnableGalleriesAutoload = func(o *Option) {
o.AutoloadGalleries = true
}
func WithExternalBackend(name string, uri string) AppOption {
return func(o *Option) {
if o.ExternalGRPCBackends == nil {
o.ExternalGRPCBackends = make(map[string]string)
}
o.ExternalGRPCBackends[name] = uri
}
}
func WithCorsAllowOrigins(b string) AppOption { func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) { return func(o *Option) {
o.CORSAllowOrigins = b o.CORSAllowOrigins = b

@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: backend.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\xa4\x05\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t\"\xac\x02\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals)
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto'
_globals['_HEALTHMESSAGE']._serialized_start=26
_globals['_HEALTHMESSAGE']._serialized_end=41
_globals['_PREDICTOPTIONS']._serialized_start=44
_globals['_PREDICTOPTIONS']._serialized_end=720
_globals['_REPLY']._serialized_start=722
_globals['_REPLY']._serialized_end=746
_globals['_MODELOPTIONS']._serialized_start=749
_globals['_MODELOPTIONS']._serialized_end=1049
_globals['_RESULT']._serialized_start=1051
_globals['_RESULT']._serialized_end=1093
_globals['_EMBEDDINGRESULT']._serialized_start=1095
_globals['_EMBEDDINGRESULT']._serialized_end=1132
_globals['_TRANSCRIPTREQUEST']._serialized_start=1134
_globals['_TRANSCRIPTREQUEST']._serialized_end=1201
_globals['_TRANSCRIPTRESULT']._serialized_start=1203
_globals['_TRANSCRIPTRESULT']._serialized_end=1281
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1283
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1372
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1375
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1533
_globals['_TTSREQUEST']._serialized_start=1535
_globals['_TTSREQUEST']._serialized_end=1589
_globals['_BACKEND']._serialized_start=1592
_globals['_BACKEND']._serialized_end=2083
# @@protoc_insertion_point(module_scope)

@ -0,0 +1,297 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import backend_pb2 as backend__pb2
class BackendStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.Health = channel.unary_unary(
'/backend.Backend/Health',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.Reply.FromString,
)
self.Predict = channel.unary_unary(
'/backend.Backend/Predict',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.Reply.FromString,
)
self.LoadModel = channel.unary_unary(
'/backend.Backend/LoadModel',
request_serializer=backend__pb2.ModelOptions.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.PredictStream = channel.unary_stream(
'/backend.Backend/PredictStream',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.Reply.FromString,
)
self.Embedding = channel.unary_unary(
'/backend.Backend/Embedding',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.EmbeddingResult.FromString,
)
self.GenerateImage = channel.unary_unary(
'/backend.Backend/GenerateImage',
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.AudioTranscription = channel.unary_unary(
'/backend.Backend/AudioTranscription',
request_serializer=backend__pb2.TranscriptRequest.SerializeToString,
response_deserializer=backend__pb2.TranscriptResult.FromString,
)
self.TTS = channel.unary_unary(
'/backend.Backend/TTS',
request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
class BackendServicer(object):
"""Missing associated documentation comment in .proto file."""
def Health(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Predict(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def LoadModel(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def PredictStream(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def Embedding(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def GenerateImage(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def AudioTranscription(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def TTS(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = {
'Health': grpc.unary_unary_rpc_method_handler(
servicer.Health,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.Reply.SerializeToString,
),
'Predict': grpc.unary_unary_rpc_method_handler(
servicer.Predict,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.Reply.SerializeToString,
),
'LoadModel': grpc.unary_unary_rpc_method_handler(
servicer.LoadModel,
request_deserializer=backend__pb2.ModelOptions.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'PredictStream': grpc.unary_stream_rpc_method_handler(
servicer.PredictStream,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.Reply.SerializeToString,
),
'Embedding': grpc.unary_unary_rpc_method_handler(
servicer.Embedding,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.EmbeddingResult.SerializeToString,
),
'GenerateImage': grpc.unary_unary_rpc_method_handler(
servicer.GenerateImage,
request_deserializer=backend__pb2.GenerateImageRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'AudioTranscription': grpc.unary_unary_rpc_method_handler(
servicer.AudioTranscription,
request_deserializer=backend__pb2.TranscriptRequest.FromString,
response_serializer=backend__pb2.TranscriptResult.SerializeToString,
),
'TTS': grpc.unary_unary_rpc_method_handler(
servicer.TTS,
request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class Backend(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def Health(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.Reply.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Predict(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.Reply.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def LoadModel(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel',
backend__pb2.ModelOptions.SerializeToString,
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def PredictStream(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.Reply.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def Embedding(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.EmbeddingResult.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def GenerateImage(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage',
backend__pb2.GenerateImageRequest.SerializeToString,
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def AudioTranscription(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription',
backend__pb2.TranscriptRequest.SerializeToString,
backend__pb2.TranscriptResult.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def TTS(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS',
backend__pb2.TTSRequest.SerializeToString,
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@ -0,0 +1,67 @@
#!/usr/bin/env python3
import grpc
from concurrent import futures
import time
import backend_pb2
import backend_pb2_grpc
import argparse
import signal
import sys
import os
from sentence_transformers import SentenceTransformer
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
def Health(self, request, context):
return backend_pb2.Reply(message="OK")
def LoadModel(self, request, context):
model_name = request.Model
model_name = os.path.basename(model_name)
try:
self.model = SentenceTransformer(model_name)
except Exception as err:
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
# Implement your logic here for the LoadModel service
# Replace this with your desired response
return backend_pb2.Result(message="Model loaded successfully", success=True)
def Embedding(self, request, context):
# Implement your logic here for the Embedding service
# Replace this with your desired response
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
sentence_embeddings = self.model.encode(request.Embeddings)
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
def serve(address):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
server.add_insecure_port(address)
server.start()
print("Server started. Listening on: " + address, file=sys.stderr)
# Define the signal handler function
def signal_handler(sig, frame):
print("Received termination signal. Shutting down...")
server.stop(0)
sys.exit(0)
# Set the signal handlers for SIGINT and SIGTERM
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
try:
while True:
time.sleep(_ONE_DAY_IN_SECONDS)
except KeyboardInterrupt:
server.stop(0)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run the gRPC server.")
parser.add_argument(
"--addr", default="localhost:50051", help="The address to bind the server to."
)
args = parser.parse_args()
serve(args.addr)

@ -0,0 +1,4 @@
sentence_transformers
grpcio
google
protobuf

@ -4,6 +4,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"strings"
"syscall" "syscall"
api "github.com/go-skynet/LocalAI/api" api "github.com/go-skynet/LocalAI/api"
@ -40,6 +41,10 @@ func main() {
Name: "f16", Name: "f16",
EnvVars: []string{"F16"}, EnvVars: []string{"F16"},
}, },
&cli.BoolFlag{
Name: "autoload-galleries",
EnvVars: []string{"AUTOLOAD_GALLERIES"},
},
&cli.BoolFlag{ &cli.BoolFlag{
Name: "debug", Name: "debug",
EnvVars: []string{"DEBUG"}, EnvVars: []string{"DEBUG"},
@ -108,6 +113,11 @@ func main() {
EnvVars: []string{"BACKEND_ASSETS_PATH"}, EnvVars: []string{"BACKEND_ASSETS_PATH"},
Value: "/tmp/localai/backend_data", Value: "/tmp/localai/backend_data",
}, },
&cli.StringSliceFlag{
Name: "external-grpc-backends",
Usage: "A list of external grpc backends",
EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"},
},
&cli.IntFlag{ &cli.IntFlag{
Name: "context-size", Name: "context-size",
Usage: "Default context size of the model", Usage: "Default context size of the model",
@ -138,7 +148,8 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
UsageText: `local-ai [options]`, UsageText: `local-ai [options]`,
Copyright: "Ettore Di Giacinto", Copyright: "Ettore Di Giacinto",
Action: func(ctx *cli.Context) error { Action: func(ctx *cli.Context) error {
app, err := api.App(
opts := []options.AppOption{
options.WithConfigFile(ctx.String("config-file")), options.WithConfigFile(ctx.String("config-file")),
options.WithJSONStringPreload(ctx.String("preload-models")), options.WithJSONStringPreload(ctx.String("preload-models")),
options.WithYAMLConfigPreload(ctx.String("preload-models-config")), options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
@ -155,7 +166,22 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
options.WithThreads(ctx.Int("threads")), options.WithThreads(ctx.Int("threads")),
options.WithBackendAssets(backendAssets), options.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
options.WithUploadLimitMB(ctx.Int("upload-limit"))) options.WithUploadLimitMB(ctx.Int("upload-limit")),
}
externalgRPC := ctx.StringSlice("external-grpc-backends")
// split ":" to get backend name and the uri
for _, v := range externalgRPC {
backend := v[:strings.IndexByte(v, ':')]
uri := v[strings.IndexByte(v, ':')+1:]
opts = append(opts, options.WithExternalBackend(backend, uri))
}
if ctx.Bool("autoload-galleries") {
opts = append(opts, options.EnableGalleriesAutoload)
}
app, err := api.App(opts...)
if err != nil { if err != nil {
return err return err
} }

@ -18,23 +18,15 @@ type Gallery struct {
// Installs a model from the gallery (galleryname@modelname) // Installs a model from the gallery (galleryname@modelname)
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
models, err := AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}
applyModel := func(model *GalleryModel) error { applyModel := func(model *GalleryModel) error {
config, err := GetGalleryConfigFromURL(model.URL) config, err := GetGalleryConfigFromURL(model.URL)
if err != nil { if err != nil {
return err return err
} }
installName := model.Name
if req.Name != "" { if req.Name != "" {
model.Name = req.Name installName = req.Name
} }
config.Files = append(config.Files, req.AdditionalFiles...) config.Files = append(config.Files, req.AdditionalFiles...)
@ -45,20 +37,58 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
return err return err
} }
if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil {
return err return err
} }
return nil return nil
} }
models, err := AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}
model, err := FindGallery(models, name)
if err != nil {
return err
}
return applyModel(model)
}
func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) {
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
for _, model := range models { for _, model := range models {
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) { if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
return applyModel(model) return model, nil
} }
} }
return nil, fmt.Errorf("no gallery found with name %q", name)
}
// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins)
func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
models, err := AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
var model *GalleryModel
for _, m := range models {
if name == m.Name {
model = m
}
}
if model == nil {
return fmt.Errorf("no model found with name %q", name)
}
return fmt.Errorf("no model found with name %q", name) return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus)
} }
// List available models // List available models

@ -3,7 +3,9 @@ package tts
// This is a wrapper to statisfy the GRPC service interface // This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import ( import (
"fmt"
"os" "os"
"path/filepath"
"github.com/go-skynet/LocalAI/pkg/grpc/base" "github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
@ -16,6 +18,9 @@ type Piper struct {
} }
func (sd *Piper) Load(opts *pb.ModelOptions) error { func (sd *Piper) Load(opts *pb.ModelOptions) error {
if filepath.Ext(opts.Model) != ".onnx" {
return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.Model)
}
var err error var err error
// Note: the Model here is a path to a directory containing the model files // Note: the Model here is a path to a directory containing the model files
sd.piper, err = New(opts.LibrarySearchPath) sd.piper, err = New(opts.LibrarySearchPath)

@ -19,8 +19,6 @@ import (
process "github.com/mudler/go-processmanager" process "github.com/mudler/go-processmanager"
) )
const tokenizerSuffix = ".tokenizer.json"
const ( const (
LlamaBackend = "llama" LlamaBackend = "llama"
BloomzBackend = "bloomz" BloomzBackend = "bloomz"
@ -45,7 +43,6 @@ const (
StableDiffusionBackend = "stablediffusion" StableDiffusionBackend = "stablediffusion"
PiperBackend = "piper" PiperBackend = "piper"
LCHuggingFaceBackend = "langchain-huggingface" LCHuggingFaceBackend = "langchain-huggingface"
//GGLLMFalconBackend = "falcon"
) )
var AutoLoadBackends []string = []string{ var AutoLoadBackends []string = []string{
@ -62,6 +59,11 @@ var AutoLoadBackends []string = []string{
MPTBackend, MPTBackend,
ReplitBackend, ReplitBackend,
StarcoderBackend, StarcoderBackend,
BloomzBackend,
RwkvBackend,
WhisperBackend,
StableDiffusionBackend,
PiperBackend,
} }
func (ml *ModelLoader) StopGRPC() { func (ml *ModelLoader) StopGRPC() {
@ -70,75 +72,116 @@ func (ml *ModelLoader) StopGRPC() {
} }
} }
// starts the grpcModelProcess for the backend, and returns a grpc client func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
// It also loads the model // Make sure the process is executable
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) { if err := os.Chmod(grpcProcess, 0755); err != nil {
return func(s string) (*grpc.Client, error) { return err
log.Debug().Msgf("Loading GRPC Model", backend, *o) }
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) log.Debug().Msgf("Loading GRPC Process", grpcProcess)
// Check if the file exists log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
}
// Make sure the process is executable grpcControlProcess := process.New(
if err := os.Chmod(grpcProcess, 0755); err != nil { process.WithTemporaryStateDir(),
return nil, err process.WithName(grpcProcess),
} process.WithArgs("--addr", serverAddress))
ml.grpcProcesses[id] = grpcControlProcess
if err := grpcControlProcess.Run(); err != nil {
return err
}
log.Debug().Msgf("Loading GRPC Process", grpcProcess) log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
port, err := freeport.GetFreePort() // clean up process
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
grpcControlProcess.Stop()
}()
go func() {
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
if err != nil { if err != nil {
return nil, err log.Debug().Msgf("Could not tail stderr")
} }
for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
}
}()
go func() {
t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
if err != nil {
log.Debug().Msgf("Could not tail stdout")
}
for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
}
}()
serverAddress := fmt.Sprintf("localhost:%d", port) return nil
}
log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress)
grpcControlProcess := process.New( // starts the grpcModelProcess for the backend, and returns a grpc client
process.WithTemporaryStateDir(), // It also loads the model
process.WithName(grpcProcess), func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
process.WithArgs("--addr", serverAddress)) return func(s string) (*grpc.Client, error) {
log.Debug().Msgf("Loading GRPC Model", backend, *o)
ml.grpcProcesses[o.modelFile] = grpcControlProcess var client *grpc.Client
if err := grpcControlProcess.Run(); err != nil { getFreeAddress := func() (string, error) {
return nil, err port, err := freeport.GetFreePort()
if err != nil {
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
}
return fmt.Sprintf("127.0.0.1:%d", port), nil
} }
// clean up process // Check if the backend is provided as external
go func() { if uri, ok := o.externalBackends[backend]; ok {
c := make(chan os.Signal, 1) log.Debug().Msgf("Loading external backend: %s", uri)
signal.Notify(c, os.Interrupt, syscall.SIGTERM) // check if uri is a file or a address
<-c if _, err := os.Stat(uri); err == nil {
grpcControlProcess.Stop() serverAddress, err := getFreeAddress()
}() if err != nil {
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
go func() { }
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) // Make sure the process is executable
if err != nil { if err := ml.startProcess(uri, o.modelFile, serverAddress); err != nil {
log.Debug().Msgf("Could not tail stderr") return nil, err
}
log.Debug().Msgf("GRPC Service Started")
client = grpc.NewClient(serverAddress)
} else {
// address
client = grpc.NewClient(uri)
} }
for line := range t.Lines { } else {
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
// Check if the file exists
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
} }
}()
go func() { serverAddress, err := getFreeAddress()
t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
if err != nil { if err != nil {
log.Debug().Msgf("Could not tail stdout") return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
} }
for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) // Make sure the process is executable
if err := ml.startProcess(grpcProcess, o.modelFile, serverAddress); err != nil {
return nil, err
} }
}()
log.Debug().Msgf("GRPC Service Started") log.Debug().Msgf("GRPC Service Started")
client := grpc.NewClient(serverAddress) client = grpc.NewClient(serverAddress)
}
// Wait for the service to start up // Wait for the service to start up
ready := false ready := false
@ -153,11 +196,6 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
if !ready { if !ready {
log.Debug().Msgf("GRPC Service NOT ready") log.Debug().Msgf("GRPC Service NOT ready")
log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive())
log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:"))
log.Debug().Msgf(grpcControlProcess.ExitCode())
return nil, fmt.Errorf("grpc service not ready") return nil, fmt.Errorf("grpc service not ready")
} }
@ -168,10 +206,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
res, err := client.LoadModel(o.context, &options) res, err := client.LoadModel(o.context, &options)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("could not load model: %w", err)
} }
if !res.Success { if !res.Success {
return nil, fmt.Errorf("could not load model: %s", res.Message) return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
} }
return client, nil return client, nil
@ -184,6 +222,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
backend := strings.ToLower(o.backendString) backend := strings.ToLower(o.backendString)
// if an external backend is provided, use it
_, externalBackendExists := o.externalBackends[backend]
if externalBackendExists {
return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o))
}
switch backend { switch backend {
case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend, case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend,
MPTBackend, Gpt2Backend, FalconBackend, MPTBackend, Gpt2Backend, FalconBackend,
@ -204,8 +249,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
o := NewOptions(opts...) o := NewOptions(opts...)
log.Debug().Msgf("Loading model '%s' greedly", o.modelFile)
// Is this really needed? BackendLoader already does this // Is this really needed? BackendLoader already does this
ml.mu.Lock() ml.mu.Lock()
if m := ml.checkIsLoaded(o.modelFile); m != nil { if m := ml.checkIsLoaded(o.modelFile); m != nil {
@ -216,16 +259,29 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
ml.mu.Unlock() ml.mu.Unlock()
var err error var err error
for _, b := range AutoLoadBackends { // autoload also external backends
log.Debug().Msgf("[%s] Attempting to load", b) allBackendsToAutoLoad := []string{}
allBackendsToAutoLoad = append(allBackendsToAutoLoad, AutoLoadBackends...)
for _, b := range o.externalBackends {
allBackendsToAutoLoad = append(allBackendsToAutoLoad, b)
}
log.Debug().Msgf("Loading model '%s' greedly from all the available backends: %s", o.modelFile, strings.Join(allBackendsToAutoLoad, ", "))
model, modelerr := ml.BackendLoader( for _, b := range allBackendsToAutoLoad {
log.Debug().Msgf("[%s] Attempting to load", b)
options := []Option{
WithBackendString(b), WithBackendString(b),
WithModelFile(o.modelFile), WithModelFile(o.modelFile),
WithLoadGRPCLLMModelOpts(o.gRPCOptions), WithLoadGRPCLLMModelOpts(o.gRPCOptions),
WithThreads(o.threads), WithThreads(o.threads),
WithAssetDir(o.assetDir), WithAssetDir(o.assetDir),
) }
for k, v := range o.externalBackends {
options = append(options, WithExternalBackend(k, v))
}
model, modelerr := ml.BackendLoader(options...)
if modelerr == nil && model != nil { if modelerr == nil && model != nil {
log.Debug().Msgf("[%s] Loads OK", b) log.Debug().Msgf("[%s] Loads OK", b)
return model, nil return model, nil
@ -233,7 +289,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
err = multierror.Append(err, modelerr) err = multierror.Append(err, modelerr)
log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error()) log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error())
} else if model == nil { } else if model == nil {
err = multierror.Append(err, modelerr) err = multierror.Append(err, fmt.Errorf("backend returned no usable model"))
log.Debug().Msgf("[%s] Fails: %s", b, "backend returned no usable model") log.Debug().Msgf("[%s] Fails: %s", b, "backend returned no usable model")
} }
} }

@ -14,10 +14,21 @@ type Options struct {
context context.Context context context.Context
gRPCOptions *pb.ModelOptions gRPCOptions *pb.ModelOptions
externalBackends map[string]string
} }
type Option func(*Options) type Option func(*Options)
func WithExternalBackend(name string, uri string) Option {
return func(o *Options) {
if o.externalBackends == nil {
o.externalBackends = make(map[string]string)
}
o.externalBackends[name] = uri
}
}
func WithBackendString(backend string) Option { func WithBackendString(backend string) Option {
return func(o *Options) { return func(o *Options) {
o.backendString = backend o.backendString = backend

@ -0,0 +1,37 @@
package utils
import (
"time"
"github.com/rs/zerolog/log"
)
var lastProgress time.Time = time.Now()
var startTime time.Time = time.Now()
func ResetDownloadTimers() {
lastProgress = time.Now()
startTime = time.Now()
}
func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) {
currentTime := time.Now()
if currentTime.Sub(lastProgress) >= 5*time.Second {
lastProgress = currentTime
// calculate ETA based on percentage and elapsed time
var eta time.Duration
if percentage > 0 {
elapsed := currentTime.Sub(startTime)
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
}
if total != "" {
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
} else {
log.Debug().Msgf("Downloading: %s", current)
}
}
}

@ -0,0 +1,5 @@
name: code-search-ada-code-001
backend: huggingface
embeddings: true
parameters:
model: all-MiniLM-L6-v2
Loading…
Cancel
Save