From 47cc95fc9ff65bfd2643249d5a3ee19da57a2213 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 20 Jul 2023 00:40:26 +0200
Subject: [PATCH 1/6] feat: add all backends to autoload

Now that the gRPC backends no longer crash the main thread, we can just
greedily attempt all the backends we have available.

Signed-off-by: Ettore Di Giacinto
---
 pkg/model/initializers.go | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go
index 53fc684..d3b4bb3 100644
--- a/pkg/model/initializers.go
+++ b/pkg/model/initializers.go
@@ -62,6 +62,11 @@ var AutoLoadBackends []string = []string{
 	MPTBackend,
 	ReplitBackend,
 	StarcoderBackend,
+	BloomzBackend,
+	RwkvBackend,
+	WhisperBackend,
+	StableDiffusionBackend,
+	PiperBackend,
 }
 
 func (ml *ModelLoader) StopGRPC() {

From 1d2ae46ddcba32b4a3e1644931ad950a8b3be6f7 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 20 Jul 2023 01:36:34 +0200
Subject: [PATCH 2/6] tests: clean up logs

Signed-off-by: Ettore Di Giacinto
---
 api/api_test.go | 35 +++++++++++++++++++++++------------
 1 file changed, 23 insertions(+), 12 deletions(-)

diff --git a/api/api_test.go b/api/api_test.go
index 6970a8f..732076c 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -124,6 +124,8 @@ var _ = Describe("API test", func() {
 	var c context.Context
 	var cancel context.CancelFunc
 	var tmpdir string
+	commonOpts := []options.AppOption{options.WithDebug(false),
+		options.WithDisableMessage(true)}
 
 	Context("API with ephemeral models", func() {
 		BeforeEach(func() {
@@ -159,9 +161,10 @@ var _ = Describe("API test", func() {
 			}
 
 			app, err = App(
-				options.WithContext(c),
-				options.WithGalleries(galleries),
-				options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))
+				append(commonOpts,
+					options.WithContext(c),
+					options.WithGalleries(galleries),
+					options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir))...)
 			Expect(err).ToNot(HaveOccurred())
 			go app.Listen("127.0.0.1:9090")
 
@@ -400,13 +403,14 @@ var _ = Describe("API test", func() {
 			}
 
 			app, err = App(
-				options.WithContext(c),
-				options.WithAudioDir(tmpdir),
-				options.WithImageDir(tmpdir),
-				options.WithGalleries(galleries),
-				options.WithModelLoader(modelLoader),
-				options.WithBackendAssets(backendAssets),
-				options.WithBackendAssetsOutput(tmpdir),
+				append(commonOpts,
+					options.WithContext(c),
+					options.WithAudioDir(tmpdir),
+					options.WithImageDir(tmpdir),
+					options.WithGalleries(galleries),
+					options.WithModelLoader(modelLoader),
+					options.WithBackendAssets(backendAssets),
+					options.WithBackendAssetsOutput(tmpdir))...,
 			)
 			Expect(err).ToNot(HaveOccurred())
 			go app.Listen("127.0.0.1:9090")
@@ -500,7 +504,9 @@ var _ = Describe("API test", func() {
 			c, cancel = context.WithCancel(context.Background())
 
 			var err error
-			app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader))
+			app, err = App(
+				append(commonOpts,
+					options.WithContext(c), options.WithModelLoader(modelLoader))...)
Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -674,7 +680,12 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader), options.WithConfigFile(os.Getenv("CONFIG_FILE"))) + app, err = App( + append(commonOpts, + options.WithContext(c), + options.WithModelLoader(modelLoader), + options.WithConfigFile(os.Getenv("CONFIG_FILE")))..., + ) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") From 94916749c5551eaaa828df594db917c78f0b53f1 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 20 Jul 2023 22:10:12 +0200 Subject: [PATCH 3/6] feat: add external grpc and model autoloading --- api/backend/embeddings.go | 4 + api/backend/image.go | 10 +- api/backend/llm.go | 25 ++++- api/backend/transcript.go | 42 ++++++++ api/backend/tts.go | 72 +++++++++++++ api/localai/gallery.go | 44 +++----- api/localai/localai.go | 57 +--------- api/openai/transcription.go | 24 +---- api/options/options.go | 17 +++ main.go | 30 +++++- pkg/gallery/gallery.go | 56 +++++++--- pkg/model/initializers.go | 179 ++++++++++++++++++++------------ pkg/model/options.go | 11 ++ pkg/utils/logging.go | 37 +++++++ tests/models_fixtures/grpc.yaml | 5 + 15 files changed, 425 insertions(+), 188 deletions(-) create mode 100644 api/backend/transcript.go create mode 100644 api/backend/tts.go create mode 100644 pkg/utils/logging.go create mode 100644 tests/models_fixtures/grpc.yaml diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go index 0310347..53df785 100644 --- a/api/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -30,6 +30,10 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. model.WithContext(o.Context), } + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + if c.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) 
} else { diff --git a/api/backend/image.go b/api/backend/image.go index a631b3b..9e32d1d 100644 --- a/api/backend/image.go +++ b/api/backend/image.go @@ -15,12 +15,20 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat return nil, fmt.Errorf("endpoint only working with stablediffusion models") } - inferenceModel, err := loader.BackendLoader( + opts := []model.Option{ model.WithBackendString(c.Backend), model.WithAssetDir(o.AssetsDestination), model.WithThreads(uint32(c.Threads)), model.WithContext(o.Context), model.WithModelFile(c.ImageGenerationAssets), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + inferenceModel, err := loader.BackendLoader( + opts..., ) if err != nil { return nil, err diff --git a/api/backend/llm.go b/api/backend/llm.go index 8fcd6da..593eea3 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -1,14 +1,17 @@ package backend import ( + "os" "regexp" "strings" "sync" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/grpc" model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" ) func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { @@ -27,12 +30,32 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt model.WithContext(o.Context), } + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + if c.Backend != "" { + opts = append(opts, model.WithBackendString(c.Backend)) + } + + // Check if the modelFile exists, if it doesn't try to load it from the gallery + if o.AutoloadGalleries { // experimental + if _, err := os.Stat(modelFile); os.IsNotExist(err) { + utils.ResetDownloadTimers() + // if we failed to load the model, we try to download it + err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) + if err != nil { + return nil, err + } + } + } + if c.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) } else { - opts = append(opts, model.WithBackendString(c.Backend)) inferenceModel, err = loader.BackendLoader(opts...) } + if err != nil { return nil, err } diff --git a/api/backend/transcript.go b/api/backend/transcript.go new file mode 100644 index 0000000..b2f2501 --- /dev/null +++ b/api/backend/transcript.go @@ -0,0 +1,42 @@ +package backend + +import ( + "context" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) { + opts := []model.Option{ + model.WithBackendString(model.WhisperBackend), + model.WithModelFile(c.Model), + model.WithContext(o.Context), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.AssetsDestination), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + whisperModel, err := o.Loader.BackendLoader(opts...) 
+ if err != nil { + return nil, err + } + + if whisperModel == nil { + return nil, fmt.Errorf("could not load whisper model") + } + + return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ + Dst: audio, + Language: language, + Threads: uint32(c.Threads), + }) +} diff --git a/api/backend/tts.go b/api/backend/tts.go new file mode 100644 index 0000000..ac491e2 --- /dev/null +++ b/api/backend/tts.go @@ -0,0 +1,72 @@ +package backend + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" +) + +func generateUniqueFileName(dir, baseName, ext string) string { + counter := 1 + fileName := baseName + ext + + for { + filePath := filepath.Join(dir, fileName) + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return fileName + } + + counter++ + fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) + } +} + +func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) { + opts := []model.Option{ + model.WithBackendString(model.PiperBackend), + model.WithModelFile(modelFile), + model.WithContext(o.Context), + model.WithAssetDir(o.AssetsDestination), + } + + for k, v := range o.ExternalGRPCBackends { + opts = append(opts, model.WithExternalBackend(k, v)) + } + + piperModel, err := o.Loader.BackendLoader(opts...) + if err != nil { + return "", nil, err + } + + if piperModel == nil { + return "", nil, fmt.Errorf("could not load piper model") + } + + if err := os.MkdirAll(o.AudioDir, 0755); err != nil { + return "", nil, fmt.Errorf("failed creating audio directory: %s", err) + } + + fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") + filePath := filepath.Join(o.AudioDir, fileName) + + modelPath := filepath.Join(o.Loader.ModelPath, modelFile) + + if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { + return "", nil, err + } + + res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ + Text: text, + Model: modelPath, + Dst: filePath, + }) + + return filePath, res, err +} diff --git a/api/localai/gallery.go b/api/localai/gallery.go index feae294..ef4be14 100644 --- a/api/localai/gallery.go +++ b/api/localai/gallery.go @@ -4,13 +4,15 @@ import ( "context" "fmt" "os" + "strings" "sync" - "time" json "github.com/json-iterator/go" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" "github.com/google/uuid" "github.com/rs/zerolog/log" @@ -80,6 +82,8 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { case <-c.Done(): return case op := <-g.C: + utils.ResetDownloadTimers() + g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) // updates the status with an error @@ -90,13 +94,17 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { // displayDownload displays the download progress progressCallback := func(fileName string, current string, total string, percentage float64) { g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) - displayDownload(fileName, current, total, percentage) + utils.DisplayDownloadFunction(fileName, current, total, percentage) } var err error // if the request contains a gallery name, we apply the 
gallery from the gallery list if op.galleryName != "" { - err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + if strings.Contains(op.galleryName, "@") { + err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + } else { + err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) + } } else { err = prepareModel(g.modelPath, op.req, cm, progressCallback) } @@ -119,31 +127,6 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { }() } -var lastProgress time.Time = time.Now() -var startTime time.Time = time.Now() - -func displayDownload(fileName string, current string, total string, percentage float64) { - currentTime := time.Now() - - if currentTime.Sub(lastProgress) >= 5*time.Second { - - lastProgress = currentTime - - // calculate ETA based on percentage and elapsed time - var eta time.Duration - if percentage > 0 { - elapsed := currentTime.Sub(startTime) - eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed)) - } - - if total != "" { - log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta) - } else { - log.Debug().Msgf("Downloading: %s", current) - } - } -} - type galleryModel struct { gallery.GalleryModel ID string `json:"id"` @@ -165,10 +148,11 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler } for _, r := range requests { + utils.ResetDownloadTimers() if r.ID == "" { - err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload) + err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) } else { - err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload) + err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) } } diff --git a/api/localai/localai.go b/api/localai/localai.go index 7c57c92..49f7780 100644 --- a/api/localai/localai.go +++ b/api/localai/localai.go @@ -1,17 +1,10 @@ package localai import ( - "context" - "fmt" - "os" - "path/filepath" - + "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" ) @@ -20,22 +13,6 @@ type TTSRequest struct { Input string `json:"input" yaml:"input"` } -func generateUniqueFileName(dir, baseName, ext string) string { - counter := 1 - fileName := baseName + ext - - for { - filePath := filepath.Join(dir, fileName) - _, err := os.Stat(filePath) - if os.IsNotExist(err) { - return fileName - } - - counter++ - fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) - } -} - func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { @@ -45,40 +22,10 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return err } - piperModel, err := o.Loader.BackendLoader( - model.WithBackendString(model.PiperBackend), - model.WithModelFile(input.Model), - model.WithContext(o.Context), - model.WithAssetDir(o.AssetsDestination)) + filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o) if err != nil { return err } - - if piperModel == nil { - return fmt.Errorf("could not load piper model") - } - - if 
err := os.MkdirAll(o.AudioDir, 0755); err != nil { - return fmt.Errorf("failed creating audio directory: %s", err) - } - - fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") - filePath := filepath.Join(o.AudioDir, fileName) - - modelPath := filepath.Join(o.Loader.ModelPath, input.Model) - - if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { - return err - } - - if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ - Text: input.Input, - Model: modelPath, - Dst: filePath, - }); err != nil { - return err - } - return c.Download(filePath) } } diff --git a/api/openai/transcription.go b/api/openai/transcription.go index 346693c..4b4a65e 100644 --- a/api/openai/transcription.go +++ b/api/openai/transcription.go @@ -1,7 +1,6 @@ package openai import ( - "context" "fmt" "io" "net/http" @@ -9,10 +8,9 @@ import ( "path" "path/filepath" + "github.com/go-skynet/LocalAI/api/backend" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" - "github.com/go-skynet/LocalAI/pkg/grpc/proto" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -61,25 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe log.Debug().Msgf("Audio file copied to: %+v", dst) - whisperModel, err := o.Loader.BackendLoader( - model.WithBackendString(model.WhisperBackend), - model.WithModelFile(config.Model), - model.WithContext(o.Context), - model.WithThreads(uint32(config.Threads)), - model.WithAssetDir(o.AssetsDestination)) - if err != nil { - return err - } - - if whisperModel == nil { - return fmt.Errorf("could not load whisper model") - } - - tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ - Dst: dst, - Language: input.Language, - Threads: uint32(config.Threads), - }) + tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o) if err != nil { return err } diff --git a/api/options/options.go b/api/options/options.go index 06029b0..b326947 100644 --- a/api/options/options.go +++ b/api/options/options.go @@ -28,6 +28,10 @@ type Option struct { BackendAssets embed.FS AssetsDestination string + + ExternalGRPCBackends map[string]string + + AutoloadGalleries bool } type AppOption func(*Option) @@ -53,6 +57,19 @@ func WithCors(b bool) AppOption { } } +var EnableGalleriesAutoload = func(o *Option) { + o.AutoloadGalleries = true +} + +func WithExternalBackend(name string, uri string) AppOption { + return func(o *Option) { + if o.ExternalGRPCBackends == nil { + o.ExternalGRPCBackends = make(map[string]string) + } + o.ExternalGRPCBackends[name] = uri + } +} + func WithCorsAllowOrigins(b string) AppOption { return func(o *Option) { o.CORSAllowOrigins = b diff --git a/main.go b/main.go index 3f534b0..2cb8627 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "os" "os/signal" "path/filepath" + "strings" "syscall" api "github.com/go-skynet/LocalAI/api" @@ -40,6 +41,10 @@ func main() { Name: "f16", EnvVars: []string{"F16"}, }, + &cli.BoolFlag{ + Name: "autoload-galleries", + EnvVars: []string{"AUTOLOAD_GALLERIES"}, + }, &cli.BoolFlag{ Name: "debug", EnvVars: []string{"DEBUG"}, @@ -108,6 +113,11 @@ func main() { EnvVars: []string{"BACKEND_ASSETS_PATH"}, Value: "/tmp/localai/backend_data", }, + &cli.StringSliceFlag{ + Name: "external-grpc-backends", + Usage: "A list of external grpc backends", + EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"}, + }, &cli.IntFlag{ Name: "context-size", Usage: "Default 
context size of the model", @@ -138,7 +148,8 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit UsageText: `local-ai [options]`, Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { - app, err := api.App( + + opts := []options.AppOption{ options.WithConfigFile(ctx.String("config-file")), options.WithJSONStringPreload(ctx.String("preload-models")), options.WithYAMLConfigPreload(ctx.String("preload-models-config")), @@ -155,7 +166,22 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit options.WithThreads(ctx.Int("threads")), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - options.WithUploadLimitMB(ctx.Int("upload-limit"))) + options.WithUploadLimitMB(ctx.Int("upload-limit")), + } + + externalgRPC := ctx.StringSlice("external-grpc-backends") + // split ":" to get backend name and the uri + for _, v := range externalgRPC { + backend := v[:strings.IndexByte(v, ':')] + uri := v[strings.IndexByte(v, ':')+1:] + opts = append(opts, options.WithExternalBackend(backend, uri)) + } + + if ctx.Bool("autoload-galleries") { + opts = append(opts, options.EnableGalleriesAutoload) + } + + app, err := api.App(opts...) if err != nil { return err } diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index 8e08592..6fe05ed 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -18,23 +18,15 @@ type Gallery struct { // Installs a model from the gallery (galleryname@modelname) func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { - - // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. - name = strings.ReplaceAll(name, string(os.PathSeparator), "__") - - models, err := AvailableGalleryModels(galleries, basePath) - if err != nil { - return err - } - applyModel := func(model *GalleryModel) error { config, err := GetGalleryConfigFromURL(model.URL) if err != nil { return err } + installName := model.Name if req.Name != "" { - model.Name = req.Name + installName = req.Name } config.Files = append(config.Files, req.AdditionalFiles...) @@ -45,20 +37,58 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, return err } - if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { + if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil { return err } return nil } + models, err := AvailableGalleryModels(galleries, basePath) + if err != nil { + return err + } + + model, err := FindGallery(models, name) + if err != nil { + return err + } + + return applyModel(model) +} + +func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) { + // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. 
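+	// e.g. a name like "user/model" is therefore looked up as "user__model".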
+	name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
+
 	for _, model := range models {
 		if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
-			return applyModel(model)
+			return model, nil
 		}
 	}
+	return nil, fmt.Errorf("no gallery found with name %q", name)
+}
+
+// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins)
+func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
+	models, err := AvailableGalleryModels(galleries, basePath)
+	if err != nil {
+		return err
+	}
+
+	name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
+	var model *GalleryModel
+	for _, m := range models {
+		if model == nil && name == m.Name {
+			model = m
+		}
+	}
+
+	if model == nil {
+		return fmt.Errorf("no model found with name %q", name)
+	}
 
-	return fmt.Errorf("no model found with name %q", name)
+	return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus)
 }
 
 // List available models
diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go
index d3b4bb3..32c9afc 100644
--- a/pkg/model/initializers.go
+++ b/pkg/model/initializers.go
@@ -19,8 +19,6 @@ import (
 	process "github.com/mudler/go-processmanager"
 )
 
-const tokenizerSuffix = ".tokenizer.json"
-
 const (
 	LlamaBackend  = "llama"
 	BloomzBackend = "bloomz"
@@ -45,7 +43,6 @@ const (
 	StableDiffusionBackend = "stablediffusion"
 	PiperBackend           = "piper"
 	LCHuggingFaceBackend   = "langchain-huggingface"
-	//GGLLMFalconBackend = "falcon"
 )
 
 var AutoLoadBackends []string = []string{
@@ -75,75 +72,116 @@ func (ml *ModelLoader) StopGRPC() {
 	}
 }
 
-// starts the grpcModelProcess for the backend, and returns a grpc client
-// It also loads the model
-func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
-	return func(s string) (*grpc.Client, error) {
-		log.Debug().Msgf("Loading GRPC Model", backend, *o)
+func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
+	// Make sure the process is executable
+	if err := os.Chmod(grpcProcess, 0755); err != nil {
+		return err
+	}
 
-		grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
+	log.Debug().Msgf("Loading GRPC Process: %s", grpcProcess)
 
-		// Check if the file exists
-		if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
-			return nil, fmt.Errorf("grpc process not found: %s.
some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
-		}
+	log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
 
-		// Make sure the process is executable
-		if err := os.Chmod(grpcProcess, 0755); err != nil {
-			return nil, err
-		}
+	grpcControlProcess := process.New(
+		process.WithTemporaryStateDir(),
+		process.WithName(grpcProcess),
+		process.WithArgs("--addr", serverAddress))
+
+	ml.grpcProcesses[id] = grpcControlProcess
+
+	if err := grpcControlProcess.Run(); err != nil {
+		return err
+	}
 
-		log.Debug().Msgf("Loading GRPC Process", grpcProcess)
-		port, err := freeport.GetFreePort()
+	log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
+	// clean up process
+	go func() {
+		c := make(chan os.Signal, 1)
+		signal.Notify(c, os.Interrupt, syscall.SIGTERM)
+		<-c
+		grpcControlProcess.Stop()
+	}()
+
+	go func() {
+		t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
 		if err != nil {
-			return nil, err
+			log.Debug().Msgf("Could not tail stderr")
 		}
+		for line := range t.Lines {
+			log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
+		}
+	}()
+	go func() {
+		t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true})
+		if err != nil {
+			log.Debug().Msgf("Could not tail stdout")
+		}
+		for line := range t.Lines {
+			log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
+		}
+	}()
 
-		serverAddress := fmt.Sprintf("localhost:%d", port)
-
-		log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress)
+	return nil
+}
 
-		grpcControlProcess := process.New(
-			process.WithTemporaryStateDir(),
-			process.WithName(grpcProcess),
-			process.WithArgs("--addr", serverAddress))
+// starts the grpcModelProcess for the backend, and returns a grpc client
+// It also loads the model
+func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
+	return func(s string) (*grpc.Client, error) {
+		log.Debug().Msgf("Loading GRPC Model %s: %+v", backend, *o)
 
-		ml.grpcProcesses[o.modelFile] = grpcControlProcess
+		var client *grpc.Client
 
-		if err := grpcControlProcess.Run(); err != nil {
-			return nil, err
+		getFreeAddress := func() (string, error) {
+			port, err := freeport.GetFreePort()
+			if err != nil {
+				return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
+			}
+			return fmt.Sprintf("127.0.0.1:%d", port), nil
 		}
 
-		// clean up process
-		go func() {
-			c := make(chan os.Signal, 1)
-			signal.Notify(c, os.Interrupt, syscall.SIGTERM)
-			<-c
-			grpcControlProcess.Stop()
-		}()
-
-		go func() {
-			t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
-			if err != nil {
-				log.Debug().Msgf("Could not tail stderr")
+		// Check if the backend is provided as external
+		if uri, ok := o.externalBackends[backend]; ok {
+			log.Debug().Msgf("Loading external backend: %s", uri)
+			// check if uri is a file or an address
+			if _, err := os.Stat(uri); err == nil {
+				serverAddress, err := getFreeAddress()
+				if err != nil {
+					return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
+				}
+				// Make sure the process is executable
+				if err := ml.startProcess(uri, o.modelFile, serverAddress); err != nil {
+					return nil, err
+				}
+
+				log.Debug().Msgf("GRPC Service Started")
+
+				client = grpc.NewClient(serverAddress)
+			} else {
+				// address
+				client = grpc.NewClient(uri)
 			}
-			for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + } else { + grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) + // Check if the file exists + if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { + return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) } - }() - go func() { - t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) + + serverAddress, err := getFreeAddress() if err != nil { - log.Debug().Msgf("Could not tail stdout") + return nil, fmt.Errorf("failed allocating free ports: %s", err.Error()) } - for line := range t.Lines { - log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + + // Make sure the process is executable + if err := ml.startProcess(grpcProcess, o.modelFile, serverAddress); err != nil { + return nil, err } - }() - log.Debug().Msgf("GRPC Service Started") + log.Debug().Msgf("GRPC Service Started") - client := grpc.NewClient(serverAddress) + client = grpc.NewClient(serverAddress) + } // Wait for the service to start up ready := false @@ -158,11 +196,6 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc if !ready { log.Debug().Msgf("GRPC Service NOT ready") - log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive()) - log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:")) - - log.Debug().Msgf(grpcControlProcess.ExitCode()) - return nil, fmt.Errorf("grpc service not ready") } @@ -189,6 +222,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) backend := strings.ToLower(o.backendString) + + // if an external backend is provided, use it + _, externalBackendExists := o.externalBackends[backend] + if externalBackendExists { + return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o)) + } + switch backend { case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend, MPTBackend, Gpt2Backend, FalconBackend, @@ -209,8 +249,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { o := NewOptions(opts...) - log.Debug().Msgf("Loading model '%s' greedly", o.modelFile) - // Is this really needed? BackendLoader already does this ml.mu.Lock() if m := ml.checkIsLoaded(o.modelFile); m != nil { @@ -221,16 +259,29 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { ml.mu.Unlock() var err error - for _, b := range AutoLoadBackends { - log.Debug().Msgf("[%s] Attempting to load", b) + // autoload also external backends + allBackendsToAutoLoad := []string{} + allBackendsToAutoLoad = append(allBackendsToAutoLoad, AutoLoadBackends...) 
+	for b := range o.externalBackends {
+		allBackendsToAutoLoad = append(allBackendsToAutoLoad, b)
+	}
+	log.Debug().Msgf("Loading model '%s' greedily from all the available backends: %s", o.modelFile, strings.Join(allBackendsToAutoLoad, ", "))
 
-		model, modelerr := ml.BackendLoader(
+	for _, b := range allBackendsToAutoLoad {
+		log.Debug().Msgf("[%s] Attempting to load", b)
+		options := []Option{
 			WithBackendString(b),
 			WithModelFile(o.modelFile),
 			WithLoadGRPCLLMModelOpts(o.gRPCOptions),
 			WithThreads(o.threads),
 			WithAssetDir(o.assetDir),
-		)
+		}
+
+		for k, v := range o.externalBackends {
+			options = append(options, WithExternalBackend(k, v))
+		}
+
+		model, modelerr := ml.BackendLoader(options...)
 		if modelerr == nil && model != nil {
 			log.Debug().Msgf("[%s] Loads OK", b)
 			return model, nil
diff --git a/pkg/model/options.go b/pkg/model/options.go
index 298ebd4..466e9c2 100644
--- a/pkg/model/options.go
+++ b/pkg/model/options.go
@@ -14,10 +14,21 @@ type Options struct {
 	context     context.Context
 	gRPCOptions *pb.ModelOptions
+
+	externalBackends map[string]string
 }
 
 type Option func(*Options)
 
+func WithExternalBackend(name string, uri string) Option {
+	return func(o *Options) {
+		if o.externalBackends == nil {
+			o.externalBackends = make(map[string]string)
+		}
+		o.externalBackends[name] = uri
+	}
+}
+
 func WithBackendString(backend string) Option {
 	return func(o *Options) {
 		o.backendString = backend
diff --git a/pkg/utils/logging.go b/pkg/utils/logging.go
new file mode 100644
index 0000000..d69cbf8
--- /dev/null
+++ b/pkg/utils/logging.go
@@ -0,0 +1,37 @@
+package utils
+
+import (
+	"time"
+
+	"github.com/rs/zerolog/log"
+)
+
+var lastProgress time.Time = time.Now()
+var startTime time.Time = time.Now()
+
+func ResetDownloadTimers() {
+	lastProgress = time.Now()
+	startTime = time.Now()
+}
+
+func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) {
+	currentTime := time.Now()
+
+	if currentTime.Sub(lastProgress) >= 5*time.Second {
+
+		lastProgress = currentTime
+
+		// calculate ETA based on percentage and elapsed time
+		var eta time.Duration
+		if percentage > 0 {
+			elapsed := currentTime.Sub(startTime)
+			eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
+		}
+
+		if total != "" {
+			log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
+		} else {
+			log.Debug().Msgf("Downloading: %s", current)
+		}
+	}
+}
diff --git a/tests/models_fixtures/grpc.yaml b/tests/models_fixtures/grpc.yaml
new file mode 100644
index 0000000..31c406a
--- /dev/null
+++ b/tests/models_fixtures/grpc.yaml
@@ -0,0 +1,5 @@
+name: code-search-ada-code-001
+backend: huggingface
+embeddings: true
+parameters:
+  model: all-MiniLM-L6-v2
\ No newline at end of file

From 982a7e86a80670fa37e131e916ed4e9feb380f81 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 20 Jul 2023 22:10:42 +0200
Subject: [PATCH 4/6] feat: add huggingface embeddings backend

Signed-off-by: Ettore Di Giacinto
---
 Dockerfile                                 |   7 +-
 Makefile                                   |  13 +-
 api/api_test.go                            | 100 ++++++-
 extra/grpc/huggingface/backend_pb2.py      |  49 ++++
 extra/grpc/huggingface/backend_pb2_grpc.py | 297 +++++++++++++++++++++
 extra/grpc/huggingface/huggingface.py      |  67 +++++
 extra/requirements.txt                     |   4 +
 7 files changed, 529 insertions(+), 8 deletions(-)
 create mode 100644 extra/grpc/huggingface/backend_pb2.py
 create mode 100644 extra/grpc/huggingface/backend_pb2_grpc.py
 create mode 100755 extra/grpc/huggingface/huggingface.py
 create mode 100644 extra/requirements.txt

diff --git
a/Dockerfile b/Dockerfile index c22f5dc..5e39303 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,10 +11,15 @@ ARG TARGETARCH ARG TARGETVARIANT ENV BUILD_TYPE=${BUILD_TYPE} +ENV EXTERNAL_GRPC_BACKENDS="huggingface-embeddings:/build/extra/grpc/huggingface/huggingface.py" ARG GO_TAGS="stablediffusion tts" RUN apt-get update && \ - apt-get install -y ca-certificates cmake curl patch + apt-get install -y ca-certificates cmake curl patch pip + +# Extras requirements +COPY extra/requirements.txt /build/extra/requirements.txt +RUN pip install -r /build/extra/requirements.txt && rm -rf /build/extra/requirements.txt # CuBLAS requirements RUN if [ "${BUILD_TYPE}" = "cublas" ]; then \ diff --git a/Makefile b/Makefile index 5813ba2..afc71c4 100644 --- a/Makefile +++ b/Makefile @@ -313,7 +313,7 @@ test: prepare test-models/testmodel grpcs @echo 'Running tests' export GO_TAGS="tts stablediffusion" $(MAKE) prepare-test - TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + HUGGINGFACE_GRPC=$(abspath ./)/extra/grpc/huggingface/huggingface.py TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg $(MAKE) test-gpt4all $(MAKE) test-llama @@ -338,9 +338,7 @@ test-stablediffusion: prepare-test test-container: docker build --target requirements -t local-ai-test-container . - docker run --name localai-tests -e GO_TAGS=$(GO_TAGS) -ti -v $(abspath ./):/build local-ai-test-container make test - docker rm localai-tests - docker rmi local-ai-test-container + docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container ## Help: help: ## Show this help. @@ -354,10 +352,15 @@ help: ## Show this help. else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \ }' $(MAKEFILE_LIST) -protogen: +protogen: protogen-go protogen-python + +protogen-go: protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. 
--go-grpc_opt=paths=source_relative \ pkg/grpc/proto/backend.proto +protogen-python: + python -m grpc_tools.protoc -Ipkg/grpc/proto/ --python_out=extra/grpc/huggingface/ --grpc_python_out=extra/grpc/huggingface/ pkg/grpc/proto/backend.proto + ## GRPC backend-assets/grpc: diff --git a/api/api_test.go b/api/api_test.go index 732076c..1e53fa7 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -386,6 +386,102 @@ var _ = Describe("API test", func() { }) }) + Context("External gRPCs", func() { + BeforeEach(func() { + modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) + c, cancel = context.WithCancel(context.Background()) + + app, err := App( + append(commonOpts, + options.WithContext(c), + options.WithAudioDir(tmpdir), + options.WithImageDir(tmpdir), + options.WithModelLoader(modelLoader), + options.WithBackendAssets(backendAssets), + options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")), + options.WithBackendAssetsOutput(tmpdir))..., + ) + Expect(err).ToNot(HaveOccurred()) + go app.Listen("127.0.0.1:9090") + + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + + // Wait for API to be ready + client = openai.NewClientWithConfig(defaultConfig) + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + + AfterEach(func() { + cancel() + app.Shutdown() + os.RemoveAll(tmpdir) + }) + + Context("API query", func() { + BeforeEach(func() { + modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) + c, cancel = context.WithCancel(context.Background()) + + var err error + app, err = App( + append(commonOpts, + options.WithDebug(true), + options.WithContext(c), options.WithModelLoader(modelLoader))...) + Expect(err).ToNot(HaveOccurred()) + go app.Listen("127.0.0.1:9090") + + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + + client2 = openaigo.NewClient("") + client2.BaseURL = defaultConfig.BaseURL + + // Wait for API to be ready + client = openai.NewClientWithConfig(defaultConfig) + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + AfterEach(func() { + cancel() + app.Shutdown() + }) + + It("calculate embeddings with huggingface", func() { + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + resp, err := client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Model: openai.AdaCodeSearchCode, + Input: []string{"sun", "cat"}, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) + Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) + + sunEmbedding := resp.Data[0].Embedding + resp2, err := client.CreateEmbeddings( + context.Background(), + openai.EmbeddingRequest{ + Model: openai.AdaCodeSearchCode, + Input: []string{"sun"}, + }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) + Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) + }) + }) + }) + Context("Model gallery", func() { BeforeEach(func() { var err error @@ -530,7 +626,7 @@ var _ = Describe("API test", func() { It("returns the models list", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(10)) + Expect(len(models.Models)).To(Equal(11)) }) It("can generate completions", func() { resp, err := 
client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"}) @@ -707,7 +803,7 @@ var _ = Describe("API test", func() { It("can generate chat completions from config file", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(12)) + Expect(len(models.Models)).To(Equal(13)) }) It("can generate chat completions from config file", func() { resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) diff --git a/extra/grpc/huggingface/backend_pb2.py b/extra/grpc/huggingface/backend_pb2.py new file mode 100644 index 0000000..0dafdf5 --- /dev/null +++ b/extra/grpc/huggingface/backend_pb2.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: backend.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\xa4\x05\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! 
\x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t\"\xac\x02\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto' + _globals['_HEALTHMESSAGE']._serialized_start=26 + _globals['_HEALTHMESSAGE']._serialized_end=41 + _globals['_PREDICTOPTIONS']._serialized_start=44 + _globals['_PREDICTOPTIONS']._serialized_end=720 + _globals['_REPLY']._serialized_start=722 + _globals['_REPLY']._serialized_end=746 + _globals['_MODELOPTIONS']._serialized_start=749 + _globals['_MODELOPTIONS']._serialized_end=1049 + _globals['_RESULT']._serialized_start=1051 
+ _globals['_RESULT']._serialized_end=1093 + _globals['_EMBEDDINGRESULT']._serialized_start=1095 + _globals['_EMBEDDINGRESULT']._serialized_end=1132 + _globals['_TRANSCRIPTREQUEST']._serialized_start=1134 + _globals['_TRANSCRIPTREQUEST']._serialized_end=1201 + _globals['_TRANSCRIPTRESULT']._serialized_start=1203 + _globals['_TRANSCRIPTRESULT']._serialized_end=1281 + _globals['_TRANSCRIPTSEGMENT']._serialized_start=1283 + _globals['_TRANSCRIPTSEGMENT']._serialized_end=1372 + _globals['_GENERATEIMAGEREQUEST']._serialized_start=1375 + _globals['_GENERATEIMAGEREQUEST']._serialized_end=1533 + _globals['_TTSREQUEST']._serialized_start=1535 + _globals['_TTSREQUEST']._serialized_end=1589 + _globals['_BACKEND']._serialized_start=1592 + _globals['_BACKEND']._serialized_end=2083 +# @@protoc_insertion_point(module_scope) diff --git a/extra/grpc/huggingface/backend_pb2_grpc.py b/extra/grpc/huggingface/backend_pb2_grpc.py new file mode 100644 index 0000000..301c072 --- /dev/null +++ b/extra/grpc/huggingface/backend_pb2_grpc.py @@ -0,0 +1,297 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import backend_pb2 as backend__pb2 + + +class BackendStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Health = channel.unary_unary( + '/backend.Backend/Health', + request_serializer=backend__pb2.HealthMessage.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Predict = channel.unary_unary( + '/backend.Backend/Predict', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.LoadModel = channel.unary_unary( + '/backend.Backend/LoadModel', + request_serializer=backend__pb2.ModelOptions.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.PredictStream = channel.unary_stream( + '/backend.Backend/PredictStream', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.Reply.FromString, + ) + self.Embedding = channel.unary_unary( + '/backend.Backend/Embedding', + request_serializer=backend__pb2.PredictOptions.SerializeToString, + response_deserializer=backend__pb2.EmbeddingResult.FromString, + ) + self.GenerateImage = channel.unary_unary( + '/backend.Backend/GenerateImage', + request_serializer=backend__pb2.GenerateImageRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + self.AudioTranscription = channel.unary_unary( + '/backend.Backend/AudioTranscription', + request_serializer=backend__pb2.TranscriptRequest.SerializeToString, + response_deserializer=backend__pb2.TranscriptResult.FromString, + ) + self.TTS = channel.unary_unary( + '/backend.Backend/TTS', + request_serializer=backend__pb2.TTSRequest.SerializeToString, + response_deserializer=backend__pb2.Result.FromString, + ) + + +class BackendServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Health(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Predict(self, request, context): + """Missing associated documentation comment in .proto file.""" + 
context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def LoadModel(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def PredictStream(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Embedding(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GenerateImage(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def AudioTranscription(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def TTS(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_BackendServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Health': grpc.unary_unary_rpc_method_handler( + servicer.Health, + request_deserializer=backend__pb2.HealthMessage.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Predict': grpc.unary_unary_rpc_method_handler( + servicer.Predict, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'LoadModel': grpc.unary_unary_rpc_method_handler( + servicer.LoadModel, + request_deserializer=backend__pb2.ModelOptions.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'PredictStream': grpc.unary_stream_rpc_method_handler( + servicer.PredictStream, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.Reply.SerializeToString, + ), + 'Embedding': grpc.unary_unary_rpc_method_handler( + servicer.Embedding, + request_deserializer=backend__pb2.PredictOptions.FromString, + response_serializer=backend__pb2.EmbeddingResult.SerializeToString, + ), + 'GenerateImage': grpc.unary_unary_rpc_method_handler( + servicer.GenerateImage, + request_deserializer=backend__pb2.GenerateImageRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + 'AudioTranscription': grpc.unary_unary_rpc_method_handler( + servicer.AudioTranscription, + request_deserializer=backend__pb2.TranscriptRequest.FromString, + response_serializer=backend__pb2.TranscriptResult.SerializeToString, + ), + 'TTS': grpc.unary_unary_rpc_method_handler( + servicer.TTS, + request_deserializer=backend__pb2.TTSRequest.FromString, + response_serializer=backend__pb2.Result.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 
'backend.Backend', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Backend(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Health(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health', + backend__pb2.HealthMessage.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Predict(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def LoadModel(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel', + backend__pb2.ModelOptions.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def PredictStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.Reply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Embedding(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding', + backend__pb2.PredictOptions.SerializeToString, + backend__pb2.EmbeddingResult.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def GenerateImage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage', + backend__pb2.GenerateImageRequest.SerializeToString, + backend__pb2.Result.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def AudioTranscription(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription', + 
diff --git a/extra/grpc/huggingface/huggingface.py b/extra/grpc/huggingface/huggingface.py
new file mode 100755
index 0000000..adf9876
--- /dev/null
+++ b/extra/grpc/huggingface/huggingface.py
@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+import grpc
+from concurrent import futures
+import time
+import backend_pb2
+import backend_pb2_grpc
+import argparse
+import signal
+import sys
+import os
+from sentence_transformers import SentenceTransformer
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+
+# Implement the BackendServicer class with the service methods
+class BackendServicer(backend_pb2_grpc.BackendServicer):
+    def Health(self, request, context):
+        return backend_pb2.Reply(message="OK")
+    def LoadModel(self, request, context):
+        model_name = request.Model
+        model_name = os.path.basename(model_name)
+        try:
+            self.model = SentenceTransformer(model_name)
+        except Exception as err:
+            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
+        # Implement your logic here for the LoadModel service
+        # Replace this with your desired response
+        return backend_pb2.Result(message="Model loaded successfully", success=True)
+    def Embedding(self, request, context):
+        # Implement your logic here for the Embedding service
+        # Replace this with your desired response
+        print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr)
+        sentence_embeddings = self.model.encode(request.Embeddings)
+        return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings)
+
+
+def serve(address):
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
+    server.add_insecure_port(address)
+    server.start()
+    print("Server started. Listening on: " + address, file=sys.stderr)
+
+    # Define the signal handler function
+    def signal_handler(sig, frame):
+        print("Received termination signal. Shutting down...")
+        server.stop(0)
+        sys.exit(0)
+
+    # Set the signal handlers for SIGINT and SIGTERM
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    try:
+        while True:
+            time.sleep(_ONE_DAY_IN_SECONDS)
+    except KeyboardInterrupt:
+        server.stop(0)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Run the gRPC server.")
+    parser.add_argument(
+        "--addr", default="localhost:50051", help="The address to bind the server to."
+    )
+    args = parser.parse_args()
+
+    serve(args.addr)
\ No newline at end of file
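Taken together, LoadModel and Embedding above are the whole contract LocalAI exercises against this backend. A client-side smoke test, as a sketch: it assumes the server was started with --addr localhost:50051, that sentence_transformers can resolve the model name used here, and that PredictOptions carries the input text in a string Embeddings field (as the string concatenation in the servicer implies):

    import grpc
    import backend_pb2
    import backend_pb2_grpc

    # Drive the standalone backend over a plain insecure channel, mirroring
    # the calls LocalAI itself issues: LoadModel first, then Embedding.
    with grpc.insecure_channel("localhost:50051") as channel:
        stub = backend_pb2_grpc.BackendStub(channel)

        res = stub.LoadModel(backend_pb2.ModelOptions(Model="all-MiniLM-L6-v2"))
        assert res.success, res.message

        result = stub.Embedding(backend_pb2.PredictOptions(Embeddings="the sun is shining"))
        print(len(result.embeddings))  # MiniLM-class models yield 384 floats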
diff --git a/extra/requirements.txt b/extra/requirements.txt
new file mode 100644
index 0000000..9744afb
--- /dev/null
+++ b/extra/requirements.txt
@@ -0,0 +1,4 @@
+sentence_transformers
+grpcio
+google
+protobuf
\ No newline at end of file

From e459f114cd01bb3c3155098fe231b27f8324b004 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 20 Jul 2023 23:45:29 +0200
Subject: [PATCH 5/6] fix: fix tests, small refactors

Signed-off-by: Ettore Di Giacinto
---
 .github/workflows/test.yml |  2 +-
 .gitignore                 |  3 +-
 api/api_test.go            | 91 +++++++++++++-------------------
 3 files changed, 34 insertions(+), 62 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 5b8385c..5a0f502 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -29,6 +29,7 @@ jobs:
           sudo apt-get install -y ca-certificates cmake curl patch
           sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2
 
+          sudo pip install -r extra/requirements.txt
 
           sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \
           curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \
@@ -45,7 +46,6 @@ jobs:
           sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \
           sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \
           sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/
-
       - name: Test
         run: |
           ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test
diff --git a/.gitignore b/.gitignore
index 7b35ba9..a1c1c06 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,9 +3,10 @@ go-llama
 /gpt4all
 go-stable-diffusion
 go-piper
+/go-bert
 go-ggllm
 /piper
-
+__pycache__/
 *.a
 get-sources
diff --git a/api/api_test.go b/api/api_test.go
index 1e53fa7..2ffcb71 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -124,8 +124,11 @@ var _ = Describe("API test", func() {
 	var c context.Context
 	var cancel context.CancelFunc
 	var tmpdir string
-	commonOpts := []options.AppOption{options.WithDebug(false),
-		options.WithDisableMessage(true)}
+
+	commonOpts := []options.AppOption{
+		options.WithDebug(true),
+		options.WithDisableMessage(true),
+	}
 
 	Context("API with ephemeral models", func() {
 		BeforeEach(func() {
@@ -145,7 +148,7 @@ var _ = Describe("API test", func() {
 					Name:            "bert2",
 					URL:             "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml",
 					Overrides:       map[string]interface{}{"foo": "bar"},
-					AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
+					AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}},
 				},
 			}
 			out, err := yaml.Marshal(g)
@@ -421,64 +424,32 @@ var _ = Describe("API test", func() {
 			os.RemoveAll(tmpdir)
 		})
 
-		Context("API query", func() {
-			BeforeEach(func() {
-				modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
-				c, cancel = context.WithCancel(context.Background())
-
-				var err error
-				app, err = App(
-					append(commonOpts,
-						options.WithDebug(true),
-						options.WithContext(c), options.WithModelLoader(modelLoader))...)
-				Expect(err).ToNot(HaveOccurred())
-				go app.Listen("127.0.0.1:9090")
-
-				defaultConfig := openai.DefaultConfig("")
-				defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
-
-				client2 = openaigo.NewClient("")
-				client2.BaseURL = defaultConfig.BaseURL
-
-				// Wait for API to be ready
-				client = openai.NewClientWithConfig(defaultConfig)
-				Eventually(func() error {
-					_, err := client.ListModels(context.TODO())
-					return err
-				}, "2m").ShouldNot(HaveOccurred())
-			})
-			AfterEach(func() {
-				cancel()
-				app.Shutdown()
-			})
+		It("calculate embeddings with huggingface", func() {
+			if runtime.GOOS != "linux" {
+				Skip("test supported only on linux")
+			}
+			resp, err := client.CreateEmbeddings(
+				context.Background(),
+				openai.EmbeddingRequest{
+					Model: openai.AdaCodeSearchCode,
+					Input: []string{"sun", "cat"},
+				},
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
+			Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
 
-			It("calculate embeddings with huggingface", func() {
-				if runtime.GOOS != "linux" {
-					Skip("test supported only on linux")
-				}
-				resp, err := client.CreateEmbeddings(
-					context.Background(),
-					openai.EmbeddingRequest{
-						Model: openai.AdaCodeSearchCode,
-						Input: []string{"sun", "cat"},
-					},
-				)
-				Expect(err).ToNot(HaveOccurred())
-				Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
-				Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
-
-				sunEmbedding := resp.Data[0].Embedding
-				resp2, err := client.CreateEmbeddings(
-					context.Background(),
-					openai.EmbeddingRequest{
-						Model: openai.AdaCodeSearchCode,
-						Input: []string{"sun"},
-					},
-				)
-				Expect(err).ToNot(HaveOccurred())
-				Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
-				Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
-			})
+			sunEmbedding := resp.Data[0].Embedding
+			resp2, err := client.CreateEmbeddings(
+				context.Background(),
+				openai.EmbeddingRequest{
+					Model: openai.AdaCodeSearchCode,
+					Input: []string{"sun"},
+				},
+			)
+			Expect(err).ToNot(HaveOccurred())
+			Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
+			Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
 		})
 	})
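Patch 5 installs the Python requirements in CI so the huggingface backend can run during tests; the backend_pb2*.py modules that backend imports are generated, typically from backend.proto with grpcio-tools. A regeneration sketch (the proto path and output directory are assumptions inferred from this series' layout, and grpcio-tools itself is not pinned in extra/requirements.txt):

    # Regenerate the Python stubs for the backend service. Assumes
    # grpcio-tools is installed and backend.proto lives in pkg/grpc/proto.
    from grpc_tools import protoc

    protoc.main([
        "protoc",
        "-Ipkg/grpc/proto",                          # assumed proto search path
        "--python_out=extra/grpc/huggingface",       # emits backend_pb2.py
        "--grpc_python_out=extra/grpc/huggingface",  # emits backend_pb2_grpc.py
        "backend.proto",
    ])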
From c71c729bc240525af7eabfbeda711b77d8b88a12 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Fri, 21 Jul 2023 00:52:43 +0200
Subject: [PATCH 6/6] debug

---
 api/api_test.go           | 101 +++++++++++++------------------
 api/backend/llm.go        |   3 ++
 pkg/grpc/tts/piper.go     |   5 ++
 pkg/model/initializers.go |   6 +--
 4 files changed, 46 insertions(+), 69 deletions(-)

diff --git a/api/api_test.go b/api/api_test.go
index 2ffcb71..a602229 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -389,70 +389,6 @@ var _ = Describe("API test", func() {
 		})
 	})
 
-	Context("External gRPCs", func() {
-		BeforeEach(func() {
-			modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
-			c, cancel = context.WithCancel(context.Background())
-
-			app, err := App(
-				append(commonOpts,
-					options.WithContext(c),
-					options.WithAudioDir(tmpdir),
-					options.WithImageDir(tmpdir),
-					options.WithModelLoader(modelLoader),
-					options.WithBackendAssets(backendAssets),
-					options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
-					options.WithBackendAssetsOutput(tmpdir))...,
-			)
-			Expect(err).ToNot(HaveOccurred())
-			go app.Listen("127.0.0.1:9090")
-
-			defaultConfig := openai.DefaultConfig("")
-			defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
-
-			// Wait for API to be ready
-			client = openai.NewClientWithConfig(defaultConfig)
-			Eventually(func() error {
-				_, err := client.ListModels(context.TODO())
-				return err
-			}, "2m").ShouldNot(HaveOccurred())
-		})
-
-		AfterEach(func() {
-			cancel()
-			app.Shutdown()
-			os.RemoveAll(tmpdir)
-		})
-
-		It("calculate embeddings with huggingface", func() {
-			if runtime.GOOS != "linux" {
-				Skip("test supported only on linux")
-			}
-			resp, err := client.CreateEmbeddings(
-				context.Background(),
-				openai.EmbeddingRequest{
-					Model: openai.AdaCodeSearchCode,
-					Input: []string{"sun", "cat"},
-				},
-			)
-			Expect(err).ToNot(HaveOccurred())
-			Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
-			Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
-
-			sunEmbedding := resp.Data[0].Embedding
-			resp2, err := client.CreateEmbeddings(
-				context.Background(),
-				openai.EmbeddingRequest{
-					Model: openai.AdaCodeSearchCode,
-					Input: []string{"sun"},
-				},
-			)
-			Expect(err).ToNot(HaveOccurred())
-			Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
-			Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
-		})
-	})
-
 	Context("Model gallery", func() {
 		BeforeEach(func() {
 			var err error
@@ -573,7 +509,10 @@ var _ = Describe("API test", func() {
 			c, cancel = context.WithCancel(context.Background())
 
 			var err error
 			app, err = App(
 				append(commonOpts,
-					options.WithContext(c), options.WithModelLoader(modelLoader))...)
+					options.WithExternalBackend("huggingface", os.Getenv("HUGGINGFACE_GRPC")),
+					options.WithContext(c),
+					options.WithModelLoader(modelLoader),
+				)...)
 			Expect(err).ToNot(HaveOccurred())
 			go app.Listen("127.0.0.1:9090")
@@ -628,7 +567,7 @@ var _ = Describe("API test", func() {
 		})
 
 		It("returns errors", func() {
-			backends := len(model.AutoLoadBackends)
+			backends := len(model.AutoLoadBackends) + 1 // +1 for huggingface
 			_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
 			Expect(err).To(HaveOccurred())
 			Expect(err.Error()).To(ContainSubstring(fmt.Sprintf("error, status code: 500, message: could not load model - all backends returned error: %d errors occurred:", backends)))
@@ -675,6 +614,36 @@ var _ = Describe("API test", func() {
 			Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
 		})
 
+		Context("External gRPC calls", func() {
+			It("calculate embeddings with huggingface", func() {
+				if runtime.GOOS != "linux" {
+					Skip("test supported only on linux")
+				}
+				resp, err := client.CreateEmbeddings(
+					context.Background(),
+					openai.EmbeddingRequest{
+						Model: openai.AdaCodeSearchCode,
+						Input: []string{"sun", "cat"},
+					},
+				)
+				Expect(err).ToNot(HaveOccurred())
+				Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384))
+				Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384))
+
+				sunEmbedding := resp.Data[0].Embedding
+				resp2, err := client.CreateEmbeddings(
+					context.Background(),
+					openai.EmbeddingRequest{
+						Model: openai.AdaCodeSearchCode,
+						Input: []string{"sun"},
+					},
+				)
+				Expect(err).ToNot(HaveOccurred())
+				Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding))
+				Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding))
+			})
+		})
+
 		Context("backends", func() {
 			It("runs rwkv completion", func() {
 				if runtime.GOOS != "linux" {
diff --git a/api/backend/llm.go b/api/backend/llm.go
index 593eea3..23a5ca4 100644
--- a/api/backend/llm.go
+++ b/api/backend/llm.go
@@ -73,6 +73,9 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
 		return ss, err
 	} else {
 		reply, err := inferenceModel.Predict(o.Context, opts)
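+		// Predict returns a nil reply when the call fails; bail out before
+		// the reply.Message dereference below can panic.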
+		if err != nil {
+			return "", err
+		}
 		return reply.Message, err
 	}
 }
diff --git a/pkg/grpc/tts/piper.go b/pkg/grpc/tts/piper.go
index dbaa4b7..3bc85e0 100644
--- a/pkg/grpc/tts/piper.go
+++ b/pkg/grpc/tts/piper.go
@@ -3,7 +3,9 @@ package tts
 // This is a wrapper to satisfy the GRPC service interface
 // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
 import (
+	"fmt"
 	"os"
+	"path/filepath"
 
 	"github.com/go-skynet/LocalAI/pkg/grpc/base"
 	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
@@ -16,6 +18,9 @@ type Piper struct {
 }
 
 func (sd *Piper) Load(opts *pb.ModelOptions) error {
+	if filepath.Ext(opts.Model) != ".onnx" {
+		return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.Model)
+	}
 	var err error
 	// Note: the Model here is a path to a directory containing the model files
 	sd.piper, err = New(opts.LibrarySearchPath)
diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go
index 32c9afc..08bf6c4 100644
--- a/pkg/model/initializers.go
+++ b/pkg/model/initializers.go
@@ -206,10 +206,10 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
 
 		res, err := client.LoadModel(o.context, &options)
 		if err != nil {
-			return nil, err
+			return nil, fmt.Errorf("could not load model: %w", err)
 		}
 		if !res.Success {
-			return nil, fmt.Errorf("could not load model: %s", res.Message)
+			return nil, fmt.Errorf("could not load model (no success): %s", res.Message)
 		}
 
 		return client, nil
@@ -289,7 +289,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
 			err = multierror.Append(err, modelerr)
 			log.Debug().Msgf("[%s] Fails: %s", b, modelerr.Error())
 		} else if model == nil {
-			err = multierror.Append(err, modelerr)
+			err = multierror.Append(err, fmt.Errorf("backend returned no usable model"))
 			log.Debug().Msgf("[%s] Fails: %s", b, "backend returned no usable model")
 		}
 	}
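The GreedyLoader hunk above now records a distinct error when a backend starts but hands back no client, so the aggregated message counts every attempt the loader made. The control flow, paraphrased in Python for illustration (a sketch of the Go logic above, where hashicorp/go-multierror plays the collector role; not a LocalAI API):

    # Greedy fallback over candidate backends, collecting every failure so
    # the final error mirrors the "all backends returned error" message the
    # tests assert on.
    def greedy_load(backends, load):
        errors = []
        for name in backends:
            try:
                client = load(name)
            except Exception as exc:  # a failed load is recorded, not fatal
                errors.append(f"[{name}] {exc}")
                continue
            if client is None:
                errors.append(f"[{name}] backend returned no usable model")
                continue
            return client  # first backend that yields a usable client wins
        raise RuntimeError(
            "could not load model - all backends returned error: "
            f"{len(errors)} errors occurred: " + "; ".join(errors)
        )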