feat: add external grpc and model autoloading

renovate/github.com-nomic-ai-gpt4all-gpt4all-bindings-golang-digest
Ettore Di Giacinto 1 year ago
parent 1d2ae46ddc
commit 94916749c5
  1. 4
      api/backend/embeddings.go
  2. 10
      api/backend/image.go
  3. 25
      api/backend/llm.go
  4. 42
      api/backend/transcript.go
  5. 72
      api/backend/tts.go
  6. 42
      api/localai/gallery.go
  7. 57
      api/localai/localai.go
  8. 24
      api/openai/transcription.go
  9. 17
      api/options/options.go
  10. 30
      main.go
  11. 54
      pkg/gallery/gallery.go
  12. 131
      pkg/model/initializers.go
  13. 11
      pkg/model/options.go
  14. 37
      pkg/utils/logging.go
  15. 5
      tests/models_fixtures/grpc.yaml

@ -30,6 +30,10 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
model.WithContext(o.Context), model.WithContext(o.Context),
} }
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
if c.Backend == "" { if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...) inferenceModel, err = loader.GreedyLoader(opts...)
} else { } else {

@ -15,12 +15,20 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
return nil, fmt.Errorf("endpoint only working with stablediffusion models") return nil, fmt.Errorf("endpoint only working with stablediffusion models")
} }
inferenceModel, err := loader.BackendLoader( opts := []model.Option{
model.WithBackendString(c.Backend), model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)), model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context), model.WithContext(o.Context),
model.WithModelFile(c.ImageGenerationAssets), model.WithModelFile(c.ImageGenerationAssets),
}
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
inferenceModel, err := loader.BackendLoader(
opts...,
) )
if err != nil { if err != nil {
return nil, err return nil, err

@ -1,14 +1,17 @@
package backend package backend
import ( import (
"os"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
) )
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
@ -27,12 +30,32 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
model.WithContext(o.Context), model.WithContext(o.Context),
} }
for k, v := range o.ExternalGRPCBackends {
opts = append(opts, model.WithExternalBackend(k, v))
}
if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
}
// Check if the modelFile exists, if it doesn't try to load it from the gallery
if o.AutoloadGalleries { // experimental
if _, err := os.Stat(modelFile); os.IsNotExist(err) {
utils.ResetDownloadTimers()
// if we failed to load the model, we try to download it
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction)
if err != nil {
return nil, err
}
}
}
if c.Backend == "" { if c.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...) inferenceModel, err = loader.GreedyLoader(opts...)
} else { } else {
opts = append(opts, model.WithBackendString(c.Backend))
inferenceModel, err = loader.BackendLoader(opts...) inferenceModel, err = loader.BackendLoader(opts...)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }

@ -0,0 +1,42 @@
package backend
import (
"context"
"fmt"
config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
model "github.com/go-skynet/LocalAI/pkg/model"
)
// ModelTranscription loads the whisper backend for the model configured in c
// and transcribes the audio file found at the given path.
//
// audio is the path to the audio file to transcribe; language is an optional
// spoken-language hint forwarded to the backend. The result is returned
// verbatim from the gRPC whisper service.
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) {
	opts := []model.Option{
		model.WithBackendString(model.WhisperBackend),
		model.WithModelFile(c.Model),
		model.WithContext(o.Context),
		model.WithThreads(uint32(c.Threads)),
		model.WithAssetDir(o.AssetsDestination),
	}

	// Propagate externally registered gRPC backends to the model loader.
	for k, v := range o.ExternalGRPCBackends {
		opts = append(opts, model.WithExternalBackend(k, v))
	}

	// Use the loader passed in by the caller rather than reaching into
	// o.Loader; callers currently pass o.Loader here, so behavior is
	// unchanged, but the parameter is no longer silently ignored.
	whisperModel, err := loader.BackendLoader(opts...)
	if err != nil {
		return nil, err
	}

	if whisperModel == nil {
		return nil, fmt.Errorf("could not load whisper model")
	}

	return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
		Dst:      audio,
		Language: language,
		Threads:  uint32(c.Threads),
	})
}

@ -0,0 +1,72 @@
package backend
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
)
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
// ModelTTS synthesizes text to speech with the piper backend and writes the
// result as a .wav file under o.AudioDir, returning the generated file path
// together with the backend's raw result.
//
// text is the text to synthesize; modelFile is the voice model file name,
// resolved relative to the loader's model path.
func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) {
	opts := []model.Option{
		model.WithBackendString(model.PiperBackend),
		model.WithModelFile(modelFile),
		model.WithContext(o.Context),
		model.WithAssetDir(o.AssetsDestination),
	}

	// Propagate externally registered gRPC backends to the model loader.
	for k, v := range o.ExternalGRPCBackends {
		opts = append(opts, model.WithExternalBackend(k, v))
	}

	// Use the loader passed in by the caller rather than o.Loader; callers
	// currently pass o.Loader here, so behavior is unchanged, but the
	// parameter is no longer silently ignored.
	piperModel, err := loader.BackendLoader(opts...)
	if err != nil {
		return "", nil, err
	}

	if piperModel == nil {
		return "", nil, fmt.Errorf("could not load piper model")
	}

	if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
		return "", nil, fmt.Errorf("failed creating audio directory: %s", err)
	}

	fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
	filePath := filepath.Join(o.AudioDir, fileName)

	// Reject model files that escape the model directory (path traversal).
	modelPath := filepath.Join(loader.ModelPath, modelFile)
	if err := utils.VerifyPath(modelPath, loader.ModelPath); err != nil {
		return "", nil, err
	}

	res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
		Text:  text,
		Model: modelPath,
		Dst:   filePath,
	})

	return filePath, res, err
}

@ -4,13 +4,15 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"strings"
"sync" "sync"
"time"
json "github.com/json-iterator/go" json "github.com/json-iterator/go"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/gallery"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -80,6 +82,8 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
case <-c.Done(): case <-c.Done():
return return
case op := <-g.C: case op := <-g.C:
utils.ResetDownloadTimers()
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0})
// updates the status with an error // updates the status with an error
@ -90,13 +94,17 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
// displayDownload displays the download progress // displayDownload displays the download progress
progressCallback := func(fileName string, current string, total string, percentage float64) { progressCallback := func(fileName string, current string, total string, percentage float64) {
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current})
displayDownload(fileName, current, total, percentage) utils.DisplayDownloadFunction(fileName, current, total, percentage)
} }
var err error var err error
// if the request contains a gallery name, we apply the gallery from the gallery list // if the request contains a gallery name, we apply the gallery from the gallery list
if op.galleryName != "" { if op.galleryName != "" {
if strings.Contains(op.galleryName, "@") {
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
} else {
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback)
}
} else { } else {
err = prepareModel(g.modelPath, op.req, cm, progressCallback) err = prepareModel(g.modelPath, op.req, cm, progressCallback)
} }
@ -119,31 +127,6 @@ func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) {
}() }()
} }
var lastProgress time.Time = time.Now()
var startTime time.Time = time.Now()
func displayDownload(fileName string, current string, total string, percentage float64) {
currentTime := time.Now()
if currentTime.Sub(lastProgress) >= 5*time.Second {
lastProgress = currentTime
// calculate ETA based on percentage and elapsed time
var eta time.Duration
if percentage > 0 {
elapsed := currentTime.Sub(startTime)
eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
}
if total != "" {
log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
} else {
log.Debug().Msgf("Downloading: %s", current)
}
}
}
type galleryModel struct { type galleryModel struct {
gallery.GalleryModel gallery.GalleryModel
ID string `json:"id"` ID string `json:"id"`
@ -165,10 +148,11 @@ func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galler
} }
for _, r := range requests { for _, r := range requests {
utils.ResetDownloadTimers()
if r.ID == "" { if r.ID == "" {
err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload) err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction)
} else { } else {
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload) err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction)
} }
} }

@ -1,17 +1,10 @@
package localai package localai
import ( import (
"context" "github.com/go-skynet/LocalAI/api/backend"
"fmt"
"os"
"path/filepath"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@ -20,22 +13,6 @@ type TTSRequest struct {
Input string `json:"input" yaml:"input"` Input string `json:"input" yaml:"input"`
} }
func generateUniqueFileName(dir, baseName, ext string) string {
counter := 1
fileName := baseName + ext
for {
filePath := filepath.Join(dir, fileName)
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
return fileName
}
counter++
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
}
}
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
@ -45,40 +22,10 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return err return err
} }
piperModel, err := o.Loader.BackendLoader( filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o)
model.WithBackendString(model.PiperBackend),
model.WithModelFile(input.Model),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination))
if err != nil { if err != nil {
return err return err
} }
if piperModel == nil {
return fmt.Errorf("could not load piper model")
}
if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
return fmt.Errorf("failed creating audio directory: %s", err)
}
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
filePath := filepath.Join(o.AudioDir, fileName)
modelPath := filepath.Join(o.Loader.ModelPath, input.Model)
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
return err
}
if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
Text: input.Input,
Model: modelPath,
Dst: filePath,
}); err != nil {
return err
}
return c.Download(filePath) return c.Download(filePath)
} }
} }

@ -1,7 +1,6 @@
package openai package openai
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -9,10 +8,9 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"github.com/go-skynet/LocalAI/api/backend"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -61,25 +59,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
log.Debug().Msgf("Audio file copied to: %+v", dst) log.Debug().Msgf("Audio file copied to: %+v", dst)
whisperModel, err := o.Loader.BackendLoader( tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o)
model.WithBackendString(model.WhisperBackend),
model.WithModelFile(config.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(config.Threads)),
model.WithAssetDir(o.AssetsDestination))
if err != nil {
return err
}
if whisperModel == nil {
return fmt.Errorf("could not load whisper model")
}
tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: dst,
Language: input.Language,
Threads: uint32(config.Threads),
})
if err != nil { if err != nil {
return err return err
} }

@ -28,6 +28,10 @@ type Option struct {
BackendAssets embed.FS BackendAssets embed.FS
AssetsDestination string AssetsDestination string
ExternalGRPCBackends map[string]string
AutoloadGalleries bool
} }
type AppOption func(*Option) type AppOption func(*Option)
@ -53,6 +57,19 @@ func WithCors(b bool) AppOption {
} }
} }
// EnableGalleriesAutoload is an option value (usable as an AppOption) that
// turns on the experimental autoloading of gallery models: when enabled, a
// missing model file may be installed from the configured galleries on demand.
var EnableGalleriesAutoload = func(o *Option) {
	o.AutoloadGalleries = true
}
// WithExternalBackend registers an external gRPC backend under the given
// name, mapping it to uri. The uri is interpreted downstream as either a
// network address or a local executable path — see the model loader.
func WithExternalBackend(name string, uri string) AppOption {
	return func(o *Option) {
		backends := o.ExternalGRPCBackends
		if backends == nil {
			backends = make(map[string]string)
			o.ExternalGRPCBackends = backends
		}
		backends[name] = uri
	}
}
func WithCorsAllowOrigins(b string) AppOption { func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) { return func(o *Option) {
o.CORSAllowOrigins = b o.CORSAllowOrigins = b

@ -4,6 +4,7 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"strings"
"syscall" "syscall"
api "github.com/go-skynet/LocalAI/api" api "github.com/go-skynet/LocalAI/api"
@ -40,6 +41,10 @@ func main() {
Name: "f16", Name: "f16",
EnvVars: []string{"F16"}, EnvVars: []string{"F16"},
}, },
&cli.BoolFlag{
Name: "autoload-galleries",
EnvVars: []string{"AUTOLOAD_GALLERIES"},
},
&cli.BoolFlag{ &cli.BoolFlag{
Name: "debug", Name: "debug",
EnvVars: []string{"DEBUG"}, EnvVars: []string{"DEBUG"},
@ -108,6 +113,11 @@ func main() {
EnvVars: []string{"BACKEND_ASSETS_PATH"}, EnvVars: []string{"BACKEND_ASSETS_PATH"},
Value: "/tmp/localai/backend_data", Value: "/tmp/localai/backend_data",
}, },
&cli.StringSliceFlag{
Name: "external-grpc-backends",
Usage: "A list of external grpc backends",
EnvVars: []string{"EXTERNAL_GRPC_BACKENDS"},
},
&cli.IntFlag{ &cli.IntFlag{
Name: "context-size", Name: "context-size",
Usage: "Default context size of the model", Usage: "Default context size of the model",
@ -138,7 +148,8 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
UsageText: `local-ai [options]`, UsageText: `local-ai [options]`,
Copyright: "Ettore Di Giacinto", Copyright: "Ettore Di Giacinto",
Action: func(ctx *cli.Context) error { Action: func(ctx *cli.Context) error {
app, err := api.App(
opts := []options.AppOption{
options.WithConfigFile(ctx.String("config-file")), options.WithConfigFile(ctx.String("config-file")),
options.WithJSONStringPreload(ctx.String("preload-models")), options.WithJSONStringPreload(ctx.String("preload-models")),
options.WithYAMLConfigPreload(ctx.String("preload-models-config")), options.WithYAMLConfigPreload(ctx.String("preload-models-config")),
@ -155,7 +166,22 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
options.WithThreads(ctx.Int("threads")), options.WithThreads(ctx.Int("threads")),
options.WithBackendAssets(backendAssets), options.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), options.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
options.WithUploadLimitMB(ctx.Int("upload-limit"))) options.WithUploadLimitMB(ctx.Int("upload-limit")),
}
externalgRPC := ctx.StringSlice("external-grpc-backends")
// split ":" to get backend name and the uri
for _, v := range externalgRPC {
backend := v[:strings.IndexByte(v, ':')]
uri := v[strings.IndexByte(v, ':')+1:]
opts = append(opts, options.WithExternalBackend(backend, uri))
}
if ctx.Bool("autoload-galleries") {
opts = append(opts, options.EnableGalleriesAutoload)
}
app, err := api.App(opts...)
if err != nil { if err != nil {
return err return err
} }

@ -18,23 +18,15 @@ type Gallery struct {
// Installs a model from the gallery (galleryname@modelname) // Installs a model from the gallery (galleryname@modelname)
func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error { func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
models, err := AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}
applyModel := func(model *GalleryModel) error { applyModel := func(model *GalleryModel) error {
config, err := GetGalleryConfigFromURL(model.URL) config, err := GetGalleryConfigFromURL(model.URL)
if err != nil { if err != nil {
return err return err
} }
installName := model.Name
if req.Name != "" { if req.Name != "" {
model.Name = req.Name installName = req.Name
} }
config.Files = append(config.Files, req.AdditionalFiles...) config.Files = append(config.Files, req.AdditionalFiles...)
@ -45,22 +37,60 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string,
return err return err
} }
if err := InstallModel(basePath, model.Name, &config, model.Overrides, downloadStatus); err != nil { if err := InstallModel(basePath, installName, &config, model.Overrides, downloadStatus); err != nil {
return err return err
} }
return nil return nil
} }
models, err := AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}
model, err := FindGallery(models, name)
if err != nil {
return err
}
return applyModel(model)
}
func FindGallery(models []*GalleryModel, name string) (*GalleryModel, error) {
// os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths.
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
for _, model := range models { for _, model := range models {
if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) { if name == fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name) {
return applyModel(model) return model, nil
}
}
return nil, fmt.Errorf("no gallery found with name %q", name)
}
// InstallModelFromGalleryByName loads a model from the gallery by specifying only the name (first match wins)
func InstallModelFromGalleryByName(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
models, err := AvailableGalleryModels(galleries, basePath)
if err != nil {
return err
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
var model *GalleryModel
for _, m := range models {
if name == m.Name {
model = m
} }
} }
if model == nil {
return fmt.Errorf("no model found with name %q", name) return fmt.Errorf("no model found with name %q", name)
} }
return InstallModelFromGallery(galleries, fmt.Sprintf("%s@%s", model.Gallery.Name, model.Name), basePath, req, downloadStatus)
}
// List available models // List available models
// Models galleries are a list of json files that are hosted on a remote server (for example github). // Models galleries are a list of json files that are hosted on a remote server (for example github).
// Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting. // Each json file contains a list of models that can be downloaded and optionally overrides to define a new model setting.

@ -19,8 +19,6 @@ import (
process "github.com/mudler/go-processmanager" process "github.com/mudler/go-processmanager"
) )
const tokenizerSuffix = ".tokenizer.json"
const ( const (
LlamaBackend = "llama" LlamaBackend = "llama"
BloomzBackend = "bloomz" BloomzBackend = "bloomz"
@ -45,7 +43,6 @@ const (
StableDiffusionBackend = "stablediffusion" StableDiffusionBackend = "stablediffusion"
PiperBackend = "piper" PiperBackend = "piper"
LCHuggingFaceBackend = "langchain-huggingface" LCHuggingFaceBackend = "langchain-huggingface"
//GGLLMFalconBackend = "falcon"
) )
var AutoLoadBackends []string = []string{ var AutoLoadBackends []string = []string{
@ -75,45 +72,28 @@ func (ml *ModelLoader) StopGRPC() {
} }
} }
// starts the grpcModelProcess for the backend, and returns a grpc client func (ml *ModelLoader) startProcess(grpcProcess, id string, serverAddress string) error {
// It also loads the model
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
return func(s string) (*grpc.Client, error) {
log.Debug().Msgf("Loading GRPC Model", backend, *o)
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
// Check if the file exists
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
}
// Make sure the process is executable // Make sure the process is executable
if err := os.Chmod(grpcProcess, 0755); err != nil { if err := os.Chmod(grpcProcess, 0755); err != nil {
return nil, err return err
} }
log.Debug().Msgf("Loading GRPC Process", grpcProcess) log.Debug().Msgf("Loading GRPC Process", grpcProcess)
port, err := freeport.GetFreePort()
if err != nil {
return nil, err
}
serverAddress := fmt.Sprintf("localhost:%d", port) log.Debug().Msgf("GRPC Service for %s will be running at: '%s'", id, serverAddress)
log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress)
grpcControlProcess := process.New( grpcControlProcess := process.New(
process.WithTemporaryStateDir(), process.WithTemporaryStateDir(),
process.WithName(grpcProcess), process.WithName(grpcProcess),
process.WithArgs("--addr", serverAddress)) process.WithArgs("--addr", serverAddress))
ml.grpcProcesses[o.modelFile] = grpcControlProcess ml.grpcProcesses[id] = grpcControlProcess
if err := grpcControlProcess.Run(); err != nil { if err := grpcControlProcess.Run(); err != nil {
return nil, err return err
} }
log.Debug().Msgf("GRPC Service state dir: %s", grpcControlProcess.StateDir())
// clean up process // clean up process
go func() { go func() {
c := make(chan os.Signal, 1) c := make(chan os.Signal, 1)
@ -128,7 +108,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
log.Debug().Msgf("Could not tail stderr") log.Debug().Msgf("Could not tail stderr")
} }
for line := range t.Lines { for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
} }
}() }()
go func() { go func() {
@ -137,13 +117,71 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
log.Debug().Msgf("Could not tail stdout") log.Debug().Msgf("Could not tail stdout")
} }
for line := range t.Lines { for line := range t.Lines {
log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) log.Debug().Msgf("GRPC(%s): stdout %s", strings.Join([]string{id, serverAddress}, "-"), line.Text)
} }
}() }()
return nil
}
// starts the grpcModelProcess for the backend, and returns a grpc client
// It also loads the model
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
return func(s string) (*grpc.Client, error) {
log.Debug().Msgf("Loading GRPC Model", backend, *o)
var client *grpc.Client
getFreeAddress := func() (string, error) {
port, err := freeport.GetFreePort()
if err != nil {
return "", fmt.Errorf("failed allocating free ports: %s", err.Error())
}
return fmt.Sprintf("127.0.0.1:%d", port), nil
}
// Check if the backend is provided as external
if uri, ok := o.externalBackends[backend]; ok {
log.Debug().Msgf("Loading external backend: %s", uri)
// check if uri is a file or a address
if _, err := os.Stat(uri); err == nil {
serverAddress, err := getFreeAddress()
if err != nil {
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
}
// Make sure the process is executable
if err := ml.startProcess(uri, o.modelFile, serverAddress); err != nil {
return nil, err
}
log.Debug().Msgf("GRPC Service Started")
client = grpc.NewClient(serverAddress)
} else {
// address
client = grpc.NewClient(uri)
}
} else {
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
// Check if the file exists
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
}
serverAddress, err := getFreeAddress()
if err != nil {
return nil, fmt.Errorf("failed allocating free ports: %s", err.Error())
}
// Make sure the process is executable
if err := ml.startProcess(grpcProcess, o.modelFile, serverAddress); err != nil {
return nil, err
}
log.Debug().Msgf("GRPC Service Started") log.Debug().Msgf("GRPC Service Started")
client := grpc.NewClient(serverAddress) client = grpc.NewClient(serverAddress)
}
// Wait for the service to start up // Wait for the service to start up
ready := false ready := false
@ -158,11 +196,6 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc
if !ready { if !ready {
log.Debug().Msgf("GRPC Service NOT ready") log.Debug().Msgf("GRPC Service NOT ready")
log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive())
log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:"))
log.Debug().Msgf(grpcControlProcess.ExitCode())
return nil, fmt.Errorf("grpc service not ready") return nil, fmt.Errorf("grpc service not ready")
} }
@ -189,6 +222,13 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
backend := strings.ToLower(o.backendString) backend := strings.ToLower(o.backendString)
// if an external backend is provided, use it
_, externalBackendExists := o.externalBackends[backend]
if externalBackendExists {
return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o))
}
switch backend { switch backend {
case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend, case LlamaBackend, LlamaGrammarBackend, GPTJBackend, DollyBackend,
MPTBackend, Gpt2Backend, FalconBackend, MPTBackend, Gpt2Backend, FalconBackend,
@ -209,8 +249,6 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err er
func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
o := NewOptions(opts...) o := NewOptions(opts...)
log.Debug().Msgf("Loading model '%s' greedly", o.modelFile)
// Is this really needed? BackendLoader already does this // Is this really needed? BackendLoader already does this
ml.mu.Lock() ml.mu.Lock()
if m := ml.checkIsLoaded(o.modelFile); m != nil { if m := ml.checkIsLoaded(o.modelFile); m != nil {
@ -221,16 +259,29 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
ml.mu.Unlock() ml.mu.Unlock()
var err error var err error
for _, b := range AutoLoadBackends { // autoload also external backends
log.Debug().Msgf("[%s] Attempting to load", b) allBackendsToAutoLoad := []string{}
allBackendsToAutoLoad = append(allBackendsToAutoLoad, AutoLoadBackends...)
for _, b := range o.externalBackends {
allBackendsToAutoLoad = append(allBackendsToAutoLoad, b)
}
log.Debug().Msgf("Loading model '%s' greedly from all the available backends: %s", o.modelFile, strings.Join(allBackendsToAutoLoad, ", "))
model, modelerr := ml.BackendLoader( for _, b := range allBackendsToAutoLoad {
log.Debug().Msgf("[%s] Attempting to load", b)
options := []Option{
WithBackendString(b), WithBackendString(b),
WithModelFile(o.modelFile), WithModelFile(o.modelFile),
WithLoadGRPCLLMModelOpts(o.gRPCOptions), WithLoadGRPCLLMModelOpts(o.gRPCOptions),
WithThreads(o.threads), WithThreads(o.threads),
WithAssetDir(o.assetDir), WithAssetDir(o.assetDir),
) }
for k, v := range o.externalBackends {
options = append(options, WithExternalBackend(k, v))
}
model, modelerr := ml.BackendLoader(options...)
if modelerr == nil && model != nil { if modelerr == nil && model != nil {
log.Debug().Msgf("[%s] Loads OK", b) log.Debug().Msgf("[%s] Loads OK", b)
return model, nil return model, nil

@ -14,10 +14,21 @@ type Options struct {
context context.Context context context.Context
gRPCOptions *pb.ModelOptions gRPCOptions *pb.ModelOptions
externalBackends map[string]string
} }
type Option func(*Options) type Option func(*Options)
// WithExternalBackend maps an external backend name to its uri in the loader
// options, lazily creating the map on first use.
func WithExternalBackend(name string, uri string) Option {
	return func(o *Options) {
		backends := o.externalBackends
		if backends == nil {
			backends = make(map[string]string)
			o.externalBackends = backends
		}
		backends[name] = uri
	}
}
func WithBackendString(backend string) Option { func WithBackendString(backend string) Option {
return func(o *Options) { return func(o *Options) {
o.backendString = backend o.backendString = backend

@ -0,0 +1,37 @@
package utils
import (
"time"
"github.com/rs/zerolog/log"
)
var lastProgress time.Time = time.Now()
var startTime time.Time = time.Now()
func ResetDownloadTimers() {
lastProgress = time.Now()
startTime = time.Now()
}
// DisplayDownloadFunction logs download progress, throttled to at most one
// line every five seconds, with an ETA estimated from the elapsed time and
// the completed percentage.
func DisplayDownloadFunction(fileName string, current string, total string, percentage float64) {
	now := time.Now()
	if now.Sub(lastProgress) < 5*time.Second {
		return
	}
	lastProgress = now

	// Extrapolate the remaining time from the fraction completed so far.
	var eta time.Duration
	if percentage > 0 {
		elapsed := now.Sub(startTime)
		eta = time.Duration(float64(elapsed)*(100/percentage) - float64(elapsed))
	}

	if total != "" {
		log.Debug().Msgf("Downloading %s: %s/%s (%.2f%%) ETA: %s", fileName, current, total, percentage, eta)
	} else {
		log.Debug().Msgf("Downloading: %s", current)
	}
}

@ -0,0 +1,5 @@
name: code-search-ada-code-001
backend: huggingface
embeddings: true
parameters:
model: all-MiniLM-L6-v2
Loading…
Cancel
Save