From 76c881043e8e427f0131eb026079e7fe917cc010 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 27 May 2023 09:26:33 +0200 Subject: [PATCH] feat: allow to preload models before startup via env var or configs (#391) --- api/api.go | 12 ++++++ api/gallery.go | 111 ++++++++++++++++++++++++++++++++----------------- api/options.go | 13 ++++++ main.go | 12 ++++++ 4 files changed, 111 insertions(+), 37 deletions(-) diff --git a/api/api.go b/api/api.go index 872d3ed..dd5f302 100644 --- a/api/api.go +++ b/api/api.go @@ -69,6 +69,18 @@ func App(opts ...AppOption) *fiber.App { // Default middleware config app.Use(recover.New()) + if options.preloadJSONModels != "" { + if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm); err != nil { + return nil + } + } + + if options.preloadModelsFromPath != "" { + if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm); err != nil { + return nil + } + } + if options.cors { if options.corsAllowOrigins == "" { app.Use(cors.New()) diff --git a/api/gallery.go b/api/gallery.go index 591b1b7..b5b74b0 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -2,10 +2,12 @@ package api import ( "context" + "encoding/json" "fmt" "io/ioutil" "net/http" "net/url" + "os" "strings" "sync" @@ -40,6 +42,43 @@ func newGalleryApplier(modelPath string) *galleryApplier { statuses: make(map[string]*galleryOpStatus), } } + +func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger) error { + url, err := req.DecodeURL() + if err != nil { + return err + } + + // Send a GET request to the URL + response, err := http.Get(url) + if err != nil { + return err + } + defer response.Body.Close() + + // Read the response body + body, err := ioutil.ReadAll(response.Body) + if err != nil { + return err + } + + // Unmarshal YAML data into a Config struct + var config gallery.Config + err = yaml.Unmarshal(body, &config) + if err != nil { + return err + } + + config.Files = append(config.Files, req.AdditionalFiles...) + + if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides); err != nil { + return err + } + + // Reload models + return cm.LoadConfigs(modelPath) +} + func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) { g.Lock() defer g.Unlock() @@ -66,52 +105,50 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) } - url, err := op.req.DecodeURL() - if err != nil { - updateError(err) - continue - } - - // Send a GET request to the URL - response, err := http.Get(url) - if err != nil { + if err := applyGallery(g.modelPath, op.req, cm); err != nil { updateError(err) continue } - defer response.Body.Close() - // Read the response body - body, err := ioutil.ReadAll(response.Body) - if err != nil { - updateError(err) - continue - } - - // Unmarshal YAML data into a Config struct - var config gallery.Config - err = yaml.Unmarshal(body, &config) - if err != nil { - updateError(fmt.Errorf("failed to unmarshal YAML: %v", err)) - continue - } + g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) + } + } + }() +} - config.Files = append(config.Files, op.req.AdditionalFiles...) +func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { + dat, err := os.ReadFile(s) + if err != nil { + return err + } + var requests []ApplyGalleryModelRequest + err = json.Unmarshal(dat, &requests) + if err != nil { + return err + } - if err := gallery.Apply(g.modelPath, op.req.Name, &config, op.req.Overrides); err != nil { - updateError(err) - continue - } + for _, r := range requests { + if err := applyGallery(modelPath, r, cm); err != nil { + return err + } + } - // Reload models - if err := cm.LoadConfigs(g.modelPath); err != nil { - updateError(err) - continue - } + return nil +} +func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { + var requests []ApplyGalleryModelRequest + err := json.Unmarshal([]byte(s), &requests) + if err != nil { + return err + } - g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) - } + for _, r := range requests { + if err := applyGallery(modelPath, r, cm); err != nil { + return err } - }() + } + + return nil } // endpoints diff --git a/api/options.go b/api/options.go index f99dda4..ea7497c 100644 --- a/api/options.go +++ b/api/options.go @@ -15,6 +15,8 @@ type Option struct { debug, disableMessage bool imageDir string cors bool + preloadJSONModels string + preloadModelsFromPath string corsAllowOrigins string } @@ -53,6 +55,17 @@ func WithContext(ctx context.Context) AppOption { } } +func WithYAMLConfigPreload(configFile string) AppOption { + return func(o *Option) { + o.preloadModelsFromPath = configFile + } +} + +func WithJSONStringPreload(configFile string) AppOption { + return func(o *Option) { + o.preloadJSONModels = configFile + } +} func WithConfigFile(configFile string) AppOption { return func(o *Option) { o.configFile = configFile diff --git a/main.go b/main.go index b5105fe..f391aff 100644 --- a/main.go +++ b/main.go @@ -53,6 +53,16 @@ func main() { EnvVars: []string{"MODELS_PATH"}, Value: filepath.Join(path, "models"), }, + &cli.StringFlag{ + Name: "preload-models", + DefaultText: "A List of models to apply in JSON at start", + EnvVars: []string{"PRELOAD_MODELS"}, + }, + &cli.StringFlag{ + Name: "preload-models-config", + DefaultText: "A List of models to apply at startup. Path to a YAML config file", + EnvVars: []string{"PRELOAD_MODELS_CONFIG"}, + }, &cli.StringFlag{ Name: "config-file", DefaultText: "Config file", @@ -103,6 +113,8 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) return api.App( api.WithConfigFile(ctx.String("config-file")), + api.WithJSONStringPreload(ctx.String("preload-models")), + api.WithYAMLConfigPreload(ctx.String("preload-models-config")), api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), api.WithContextSize(ctx.Int("context-size")), api.WithDebug(ctx.Bool("debug")),