From cc9aa9eb3ff6bc511a4988379dbde1cc853d1239 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 18 May 2023 15:59:03 +0200
Subject: [PATCH] feat: add /models/apply endpoint to prepare models (#286)

---
 Makefile | 6 +-
 README.md | 7 +-
 api/api.go | 27 +-
 api/api_test.go | 13 +-
 api/config.go | 44 +++-
 api/gallery.go | 146 +++++++++++
 api/openai.go | 16 +-
 main.go | 7 +-
 pkg/gallery/gallery_suite_test.go | 13 +
 pkg/gallery/models.go | 237 ++++++++++++++++++
 pkg/gallery/models_test.go | 30 +++
 pkg/model/initializers.go | 3 +-
 tests/fixtures/gallery_simple.yaml | 40 +++
 .../completion.tmpl | 0
 .../{fixtures => models_fixtures}/config.yaml | 0
 .../embeddings.yaml | 0
 .../ggml-gpt4all-j.tmpl | 0
 tests/{fixtures => models_fixtures}/gpt4.yaml | 0
 .../{fixtures => models_fixtures}/gpt4_2.yaml | 0
 tests/{fixtures => models_fixtures}/rwkv.yaml | 0
 .../rwkv_chat.tmpl | 0
 .../rwkv_completion.tmpl | 0
 .../whisper.yaml | 0
 23 files changed, 556 insertions(+), 33 deletions(-)
 create mode 100644 api/gallery.go
 create mode 100644 pkg/gallery/gallery_suite_test.go
 create mode 100644 pkg/gallery/models.go
 create mode 100644 pkg/gallery/models_test.go
 create mode 100644 tests/fixtures/gallery_simple.yaml
 rename tests/{fixtures => models_fixtures}/completion.tmpl (100%)
 rename tests/{fixtures => models_fixtures}/config.yaml (100%)
 rename tests/{fixtures => models_fixtures}/embeddings.yaml (100%)
 rename tests/{fixtures => models_fixtures}/ggml-gpt4all-j.tmpl (100%)
 rename tests/{fixtures => models_fixtures}/gpt4.yaml (100%)
 rename tests/{fixtures => models_fixtures}/gpt4_2.yaml (100%)
 rename tests/{fixtures => models_fixtures}/rwkv.yaml (100%)
 rename tests/{fixtures => models_fixtures}/rwkv_chat.tmpl (100%)
 rename tests/{fixtures => models_fixtures}/rwkv_completion.tmpl (100%)
 rename tests/{fixtures => models_fixtures}/whisper.yaml (100%)

diff --git a/Makefile b/Makefile
index 523f152..ea75023 100644
--- a/Makefile
+++ b/Makefile
@@ -211,11 +211,11 @@ test-models/testmodel:
	wget https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav
	wget https://huggingface.co/imxcstar/rwkv-4-raven-ggml/resolve/main/RWKV-4-Raven-1B5-v11-Eng99%25-Other1%25-20230425-ctx4096-16_Q4_2.bin -O test-models/rwkv
	wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json
-	cp tests/fixtures/* test-models
+	cp tests/models_fixtures/* test-models

test: prepare test-models/testmodel
-	cp tests/fixtures/* test-models
-	@C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api
+	cp tests/models_fixtures/* test-models
+	C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo -v -r ./api ./pkg

## Help:
help: ## Show this help.
diff --git a/README.md b/README.md
index 1cbb414..3a4cbb9 100644
--- a/README.md
+++ b/README.md
@@ -11,11 +11,12 @@

**LocalAI** is a drop-in replacement REST API compatible with OpenAI API specifications for local inferencing. It allows you to run models locally or on-prem with consumer-grade hardware, supporting multiple model families compatible with the `ggml` format.
For a list of the supported model families, see [the model compatibility table below](https://github.com/go-skynet/LocalAI#model-compatibility-table).

-- OpenAI drop-in alternative REST API
+- Local, OpenAI drop-in alternative REST API. You own your data.
- Supports multiple models, Audio transcription, Text generation with GPTs, Image generation with stable diffusion (experimental)
- Once loaded the first time, it keeps models loaded in memory for faster inference
- Support for prompt templates
- Doesn't shell-out, but uses C++ bindings for faster inference and better performance.
+- NO GPU required. NO Internet access is required either. Optional GPU acceleration is available for `llama.cpp`-compatible LLMs. [See building instructions](https://github.com/go-skynet/LocalAI#cublas).

LocalAI is a community-driven project, focused on making AI accessible to anyone. Any contribution, feedback and PR is welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud).
@@ -434,7 +435,7 @@ local-ai --models-path <model_path> [--address <address>] [--threads <num_threads>]
@@ -567,6 +568,8 @@ Note: CuBLAS support is experimental, and has not been tested on real HW. please
make BUILD_TYPE=cublas build
```
+More information is available in the upstream PR: https://github.com/ggerganov/llama.cpp/pull/1412
+
### Windows compatibility
diff --git a/api/api.go b/api/api.go
index ecf56b0..ec7c981 100644
--- a/api/api.go
+++ b/api/api.go
@@ -1,6 +1,7 @@
package api

import (
+	"context"
	"errors"

	model "github.com/go-skynet/LocalAI/pkg/model"
@@ -12,7 +13,7 @@ import (
	"github.com/rs/zerolog/log"
)

-func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
+func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App {
	zerolog.SetGlobalLevel(zerolog.InfoLevel)
	if debug {
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
@@ -48,7 +49,7 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
		}))
	}

-	cm := make(ConfigMerger)
+	cm := NewConfigMerger()
	if err := cm.LoadConfigs(loader.ModelPath); err != nil {
		log.Error().Msgf("error loading config files: %s", err.Error())
	}
@@ -60,39 +61,51 @@ func App(configFile string, loader *model.ModelLoader, uploadLimitMB, threads, c
	}

	if debug {
-		for k, v := range cm {
-			log.Debug().Msgf("Model: %s (config: %+v)", k, v)
+		for _, v := range cm.ListConfigs() {
+			cfg, _ := cm.GetConfig(v)
+			log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
		}
	}

	// Default middleware config
	app.Use(recover.New())
	app.Use(cors.New())

+	// LocalAI API endpoints
+	applier := newGalleryApplier(loader.ModelPath)
+	applier.start(c, cm)
+	app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C))
+	app.Get("/models/jobs/:uid", getOpStatus(applier))
+
	// openAI compatible API endpoint
+
+	// chat
	app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))
	app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16))

+	// edit
	app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))
	app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16))

+	// completion
	app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))
	app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16))

+	// embeddings
	app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
	app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))
-
-	// /v1/engines/{engine_id}/embeddings
-	app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16))

+	// audio
	app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16))

+	// images
	app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir))

	if imageDir != "" {
		app.Static("/generated-images", imageDir)
	}

+	// models
	app.Get("/v1/models", listModels(loader, cm))
	app.Get("/models", listModels(loader, cm))
diff --git a/api/api_test.go b/api/api_test.go
index f2af038..1a5d7d4 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -22,10 +22,14 @@ var _ = Describe("API test", func() {
	var modelLoader *model.ModelLoader
	var client *openai.Client
	var client2 *openaigo.Client
+	var c context.Context
+	var cancel context.CancelFunc

	Context("API query", func() {
		BeforeEach(func() {
			modelLoader =
model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App("", modelLoader, 15, 1, 512, false, true, true, "") + c, cancel = context.WithCancel(context.Background()) + + app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "") go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -42,6 +46,7 @@ var _ = Describe("API test", func() { }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { + cancel() app.Shutdown() }) It("returns the models list", func() { @@ -140,7 +145,9 @@ var _ = Describe("API test", func() { Context("Config file", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App(os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "") + c, cancel = context.WithCancel(context.Background()) + + app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "") go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -155,10 +162,10 @@ var _ = Describe("API test", func() { }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { + cancel() app.Shutdown() }) It("can generate chat completions from config file", func() { - models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) Expect(len(models.Models)).To(Equal(12)) diff --git a/api/config.go b/api/config.go index 7379978..7e0d826 100644 --- a/api/config.go +++ b/api/config.go @@ -7,6 +7,7 @@ import ( "os" "path/filepath" "strings" + "sync" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" @@ -43,8 +44,16 @@ type TemplateConfig struct { Edit string `yaml:"edit"` } -type ConfigMerger map[string]Config +type ConfigMerger struct { + configs map[string]Config + sync.Mutex +} +func NewConfigMerger() *ConfigMerger { + return &ConfigMerger{ + configs: make(map[string]Config), + } +} func ReadConfigFile(file string) ([]*Config, error) { c := &[]*Config{} f, err := os.ReadFile(file) @@ -72,28 +81,51 @@ func ReadConfig(file string) (*Config, error) { } func (cm ConfigMerger) LoadConfigFile(file string) error { + cm.Lock() + defer cm.Unlock() c, err := ReadConfigFile(file) if err != nil { return fmt.Errorf("cannot load config file: %w", err) } for _, cc := range c { - cm[cc.Name] = *cc + cm.configs[cc.Name] = *cc } return nil } func (cm ConfigMerger) LoadConfig(file string) error { + cm.Lock() + defer cm.Unlock() c, err := ReadConfig(file) if err != nil { return fmt.Errorf("cannot read config file: %w", err) } - cm[c.Name] = *c + cm.configs[c.Name] = *c return nil } +func (cm ConfigMerger) GetConfig(m string) (Config, bool) { + cm.Lock() + defer cm.Unlock() + v, exists := cm.configs[m] + return v, exists +} + +func (cm ConfigMerger) ListConfigs() []string { + cm.Lock() + defer cm.Unlock() + var res []string + for k := range cm.configs { + res = append(res, k) + } + return res +} + func (cm ConfigMerger) LoadConfigs(path string) error { + cm.Lock() + defer cm.Unlock() files, err := ioutil.ReadDir(path) if err != nil { return err @@ -106,7 +138,7 @@ func (cm ConfigMerger) LoadConfigs(path string) error { } c, err := ReadConfig(filepath.Join(path, file.Name())) if err == nil { - cm[c.Name] = *c + cm.configs[c.Name] = *c } } @@ -253,7 +285,7 @@ func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (strin return modelFile, input, nil } -func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { +func readConfig(modelFile string, 
input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { // Load a config file if present after the model name modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") if _, err := os.Stat(modelConfig); err == nil { @@ -263,7 +295,7 @@ func readConfig(modelFile string, input *OpenAIRequest, cm ConfigMerger, loader } var config *Config - cfg, exists := cm[modelFile] + cfg, exists := cm.GetConfig(modelFile) if !exists { config = &Config{ OpenAIRequest: defaultRequest(modelFile), diff --git a/api/gallery.go b/api/gallery.go new file mode 100644 index 0000000..5378c7b --- /dev/null +++ b/api/gallery.go @@ -0,0 +1,146 @@ +package api + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "sync" + + "github.com/go-skynet/LocalAI/pkg/gallery" + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "gopkg.in/yaml.v3" +) + +type galleryOp struct { + req ApplyGalleryModelRequest + id string +} + +type galleryOpStatus struct { + Error error `json:"error"` + Processed bool `json:"processed"` + Message string `json:"message"` +} + +type galleryApplier struct { + modelPath string + sync.Mutex + C chan galleryOp + statuses map[string]*galleryOpStatus +} + +func newGalleryApplier(modelPath string) *galleryApplier { + return &galleryApplier{ + modelPath: modelPath, + C: make(chan galleryOp), + statuses: make(map[string]*galleryOpStatus), + } +} +func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) { + g.Lock() + defer g.Unlock() + g.statuses[s] = op +} + +func (g *galleryApplier) getstatus(s string) *galleryOpStatus { + g.Lock() + defer g.Unlock() + + return g.statuses[s] +} + +func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { + go func() { + for { + select { + case <-c.Done(): + return + case op := <-g.C: + g.updatestatus(op.id, &galleryOpStatus{Message: "processing"}) + + updateError := func(e error) { + g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) + } + // Send a GET request to the URL + response, err := http.Get(op.req.URL) + if err != nil { + updateError(err) + continue + } + defer response.Body.Close() + + // Read the response body + body, err := ioutil.ReadAll(response.Body) + if err != nil { + updateError(err) + continue + } + + // Unmarshal YAML data into a Config struct + var config gallery.Config + err = yaml.Unmarshal(body, &config) + if err != nil { + updateError(fmt.Errorf("failed to unmarshal YAML: %v", err)) + continue + } + + if err := gallery.Apply(g.modelPath, op.req.Name, &config); err != nil { + updateError(err) + continue + } + + // Reload models + if err := cm.LoadConfigs(g.modelPath); err != nil { + updateError(err) + continue + } + + g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) + } + } + }() +} + +// endpoints + +type ApplyGalleryModelRequest struct { + URL string `json:"url"` + Name string `json:"name"` +} + +func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + + status := g.getstatus(c.Params("uid")) + if status == nil { + return fmt.Errorf("could not find any status for ID") + } + + return c.JSON(status) + } +} + +func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + input := new(ApplyGalleryModelRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + uuid, err := uuid.NewUUID() + if err != 
nil { + return err + } + g <- galleryOp{ + req: *input, + id: uuid.String(), + } + return c.JSON(struct { + ID string `json:"uid"` + StatusURL string `json:"status"` + }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) + } +} diff --git a/api/openai.go b/api/openai.go index 52d6597..0a85349 100644 --- a/api/openai.go +++ b/api/openai.go @@ -142,7 +142,7 @@ func defaultRequest(modelFile string) OpenAIRequest { } // https://platform.openai.com/docs/api-reference/completions -func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { model, input, err := readInput(c, loader, true) @@ -199,7 +199,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, } // https://platform.openai.com/docs/api-reference/embeddings -func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { model, input, err := readInput(c, loader, true) if err != nil { @@ -256,7 +256,7 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, } } -func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool { @@ -378,7 +378,7 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread } } -func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { model, input, err := readInput(c, loader, true) if err != nil { @@ -449,7 +449,7 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread * */ -func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error { +func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { m, input, err := readInput(c, loader, false) if err != nil { @@ -574,7 +574,7 @@ func imageEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, image } // https://platform.openai.com/docs/api-reference/audio/create -func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { m, input, err := readInput(c, loader, false) if err != nil { @@ -641,7 +641,7 @@ func transcriptEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, } } -func listModels(loader 
*model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error {
+func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		models, err := loader.ListModels()
		if err != nil {
@@ -655,7 +655,7 @@ func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx)
			dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"})
		}

-		for k := range cm {
+		for _, k := range cm.ListConfigs() {
			if _, exists := mm[k]; !exists {
				dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"})
			}
diff --git a/main.go b/main.go
index 2490e19..f3ffc03 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,7 @@
package main

import (
+	"context"
	"fmt"
	"os"
	"path/filepath"
@@ -57,9 +58,9 @@ func main() {
			Value: ":8080",
		},
		&cli.StringFlag{
-			Name:        "image-dir",
+			Name:        "image-path",
			DefaultText: "Image directory",
-			EnvVars:     []string{"IMAGE_DIR"},
+			EnvVars:     []string{"IMAGE_PATH"},
			Value:       "",
		},
		&cli.IntFlag{
@@ -93,7 +94,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
		Copyright: "go-skynet authors",
		Action: func(ctx *cli.Context) error {
			fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
-			return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-dir")).Listen(ctx.String("address"))
+			return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address"))
		},
	}
diff --git a/pkg/gallery/gallery_suite_test.go b/pkg/gallery/gallery_suite_test.go
new file mode 100644
index 0000000..44256bc
--- /dev/null
+++ b/pkg/gallery/gallery_suite_test.go
@@ -0,0 +1,13 @@
+package gallery_test
+
+import (
+	"testing"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+func TestGallery(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "Gallery test suite")
+}
diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go
new file mode 100644
index 0000000..bd9e137
--- /dev/null
+++ b/pkg/gallery/models.go
@@ -0,0 +1,237 @@
+package gallery
+
+import (
+	"crypto/sha256"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+
+	"github.com/rs/zerolog/log"
+	"gopkg.in/yaml.v2"
+)
+
+/*
+
+description: |
+    foo
+license: ""
+
+urls:
+-
+-
+
+name: "bar"
+
+config_file: |
+    # Note: the name will be injected, or generated from the alias requested by the user
+    threads: 14
+
+files:
+    - filename: ""
+      sha256: ""
+      uri: ""
+
+prompt_templates:
+    - name: ""
+      content: ""
+
+*/
+
+type Config struct {
+	Description     string           `yaml:"description"`
+	License         string           `yaml:"license"`
+	URLs            []string         `yaml:"urls"`
+	Name            string           `yaml:"name"`
+	ConfigFile      string           `yaml:"config_file"`
+	Files           []File           `yaml:"files"`
+	PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
+}
+
+type File struct {
+	Filename string `yaml:"filename"`
+	SHA256   string `yaml:"sha256"`
+	URI      string `yaml:"uri"`
+}
+
+type PromptTemplate struct {
+	Name    string `yaml:"name"`
+	Content string `yaml:"content"`
+}
+
+func ReadConfigFile(filePath string) (*Config, error) {
+	// Read the YAML file
+	yamlFile, err := os.ReadFile(filePath)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read YAML file: %v", err)
+	}
+
+	// Unmarshal YAML data into a Config struct
+	var config Config
+	err = yaml.Unmarshal(yamlFile, &config)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal YAML: %v", err)
+	}
+
+	return &config, nil
+}
+
+func Apply(basePath, nameOverride string, config *Config) error {
+	// Create base path if it doesn't exist
+	err := os.MkdirAll(basePath, 0755)
+	if err != nil {
+		return fmt.Errorf("failed to create base path: %v", err)
+	}
+
+	// Download files and verify their SHA
+	for _, file := range config.Files {
+		log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)
+
+		// Create file path
+		filePath := filepath.Join(basePath, file.Filename)
+
+		// Check if the file already exists
+		_, err := os.Stat(filePath)
+		if err == nil {
+			// File exists, check SHA
+			if file.SHA256 != "" {
+				// Verify SHA
+				calculatedSHA, err := calculateSHA(filePath)
+				if err != nil {
+					return fmt.Errorf("failed to calculate SHA for file %q: %v", file.Filename, err)
+				}
+				if calculatedSHA == file.SHA256 {
+					// SHA matches, skip downloading
+					log.Debug().Msgf("File %q already exists and matches the SHA. Skipping download", file.Filename)
+					continue
+				}
+				// SHA doesn't match, delete the file and download again
+				err = os.Remove(filePath)
+				if err != nil {
+					return fmt.Errorf("failed to remove existing file %q: %v", file.Filename, err)
+				}
+				log.Debug().Msgf("Removed %q (SHA doesn't match)", filePath)
+
+			} else {
+				// SHA is missing, skip downloading
+				log.Debug().Msgf("File %q already exists. Skipping download", file.Filename)
Skipping download", file.Filename) + continue + } + } else if !os.IsNotExist(err) { + // Error occurred while checking file existence + return fmt.Errorf("failed to check file %q existence: %v", file.Filename, err) + } + + log.Debug().Msgf("Downloading %q", file.URI) + + // Download file + resp, err := http.Get(file.URI) + if err != nil { + return fmt.Errorf("failed to download file %q: %v", file.Filename, err) + } + defer resp.Body.Close() + + // Create parent directory + err = os.MkdirAll(filepath.Dir(filePath), 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory for file %q: %v", file.Filename, err) + } + + // Create and write file content + outFile, err := os.Create(filePath) + if err != nil { + return fmt.Errorf("failed to create file %q: %v", file.Filename, err) + } + defer outFile.Close() + + if file.SHA256 != "" { + log.Debug().Msgf("Download and verifying %q", file.Filename) + + // Write file content and calculate SHA + hash := sha256.New() + _, err = io.Copy(io.MultiWriter(outFile, hash), resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %q: %v", file.Filename, err) + } + + // Verify SHA + calculatedSHA := fmt.Sprintf("%x", hash.Sum(nil)) + if calculatedSHA != file.SHA256 { + return fmt.Errorf("SHA mismatch for file %q ( calculated: %s != metadata: %s )", file.Filename, calculatedSHA, file.SHA256) + } + } else { + log.Debug().Msgf("SHA missing for %q. Skipping validation", file.Filename) + _, err = io.Copy(outFile, resp.Body) + if err != nil { + return fmt.Errorf("failed to write file %q: %v", file.Filename, err) + } + } + + log.Debug().Msgf("File %q downloaded and verified", file.Filename) + } + + // Write prompt template contents to separate files + for _, template := range config.PromptTemplates { + // Create file path + filePath := filepath.Join(basePath, template.Name+".tmpl") + + // Create parent directory + err := os.MkdirAll(filepath.Dir(filePath), 0755) + if err != nil { + return fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err) + } + // Create and write file content + err = os.WriteFile(filePath, []byte(template.Content), 0644) + if err != nil { + return fmt.Errorf("failed to write prompt template %q: %v", template.Name, err) + } + + log.Debug().Msgf("Prompt template %q written", template.Name) + } + + name := config.Name + if nameOverride != "" { + name = nameOverride + } + + configFilePath := filepath.Join(basePath, name+".yaml") + + // Read and update config file as map[string]interface{} + configMap := make(map[string]interface{}) + err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap) + if err != nil { + return fmt.Errorf("failed to unmarshal config YAML: %v", err) + } + + configMap["name"] = name + + // Write updated config file + updatedConfigYAML, err := yaml.Marshal(configMap) + if err != nil { + return fmt.Errorf("failed to marshal updated config YAML: %v", err) + } + + err = os.WriteFile(configFilePath, updatedConfigYAML, 0644) + if err != nil { + return fmt.Errorf("failed to write updated config file: %v", err) + } + + log.Debug().Msgf("Written config file %s", configFilePath) + return nil +} + +func calculateSHA(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", err + } + + return fmt.Sprintf("%x", hash.Sum(nil)), nil +} diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go new file 
index 0000000..123948a
--- /dev/null
+++ b/pkg/gallery/models_test.go
@@ -0,0 +1,30 @@
+package gallery_test
+
+import (
+	"os"
+	"path/filepath"
+
+	. "github.com/go-skynet/LocalAI/pkg/gallery"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("Model test", func() {
+	Context("Downloading", func() {
+		It("applies model correctly", func() {
+			tempdir, err := os.MkdirTemp("", "test")
+			Expect(err).ToNot(HaveOccurred())
+			defer os.RemoveAll(tempdir)
+			c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml"))
+			Expect(err).ToNot(HaveOccurred())
+
+			err = Apply(tempdir, "", c)
+			Expect(err).ToNot(HaveOccurred())
+
+			for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} {
+				_, err = os.Stat(filepath.Join(tempdir, f))
+				Expect(err).ToNot(HaveOccurred())
+			}
+		})
+	})
+})
diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go
index 74c05f2..b5e43a3 100644
--- a/pkg/model/initializers.go
+++ b/pkg/model/initializers.go
@@ -164,11 +164,12 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
}

func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (interface{}, error) {
-	log.Debug().Msgf("Loading models greedly")
+	log.Debug().Msgf("Loading model '%s' greedily", modelFile)

	ml.mu.Lock()
	m, exists := ml.models[modelFile]
	if exists {
+		log.Debug().Msgf("Model '%s' already loaded", modelFile)
		ml.mu.Unlock()
		return m, nil
	}
diff --git a/tests/fixtures/gallery_simple.yaml b/tests/fixtures/gallery_simple.yaml
new file mode 100644
index 0000000..058733f
--- /dev/null
+++ b/tests/fixtures/gallery_simple.yaml
@@ -0,0 +1,40 @@
+name: "cerebras"
+description: |
+    cerebras
+license: "Apache 2.0"
+
+config_file: |
+    parameters:
+      model: cerebras
+      top_k: 80
+      temperature: 0.2
+      top_p: 0.7
+    context_size: 1024
+    stopwords:
+    - "HUMAN:"
+    - "GPT:"
+    roles:
+      user: ""
+      system: ""
+    template:
+      completion: "cerebras-completion"
+      chat: cerebras-chat
+
+files:
+    - filename: "cerebras"
+      sha256: "c947051ae4dba9530ca55d923a7a484acd65664c8633462c8ccd4bb7848f2c65"
+      uri: "https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerebras-111m-q4_2.bin"
+
+prompt_templates:
+    - name: "cerebras-completion"
+      content: |
+        Complete the prompt
+        ### Prompt:
+        {{.Input}}
+        ### Response:
+    - name: "cerebras-chat"
+      content: |
+        The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.
+ ### Prompt: + {{.Input}} + ### Response: \ No newline at end of file diff --git a/tests/fixtures/completion.tmpl b/tests/models_fixtures/completion.tmpl similarity index 100% rename from tests/fixtures/completion.tmpl rename to tests/models_fixtures/completion.tmpl diff --git a/tests/fixtures/config.yaml b/tests/models_fixtures/config.yaml similarity index 100% rename from tests/fixtures/config.yaml rename to tests/models_fixtures/config.yaml diff --git a/tests/fixtures/embeddings.yaml b/tests/models_fixtures/embeddings.yaml similarity index 100% rename from tests/fixtures/embeddings.yaml rename to tests/models_fixtures/embeddings.yaml diff --git a/tests/fixtures/ggml-gpt4all-j.tmpl b/tests/models_fixtures/ggml-gpt4all-j.tmpl similarity index 100% rename from tests/fixtures/ggml-gpt4all-j.tmpl rename to tests/models_fixtures/ggml-gpt4all-j.tmpl diff --git a/tests/fixtures/gpt4.yaml b/tests/models_fixtures/gpt4.yaml similarity index 100% rename from tests/fixtures/gpt4.yaml rename to tests/models_fixtures/gpt4.yaml diff --git a/tests/fixtures/gpt4_2.yaml b/tests/models_fixtures/gpt4_2.yaml similarity index 100% rename from tests/fixtures/gpt4_2.yaml rename to tests/models_fixtures/gpt4_2.yaml diff --git a/tests/fixtures/rwkv.yaml b/tests/models_fixtures/rwkv.yaml similarity index 100% rename from tests/fixtures/rwkv.yaml rename to tests/models_fixtures/rwkv.yaml diff --git a/tests/fixtures/rwkv_chat.tmpl b/tests/models_fixtures/rwkv_chat.tmpl similarity index 100% rename from tests/fixtures/rwkv_chat.tmpl rename to tests/models_fixtures/rwkv_chat.tmpl diff --git a/tests/fixtures/rwkv_completion.tmpl b/tests/models_fixtures/rwkv_completion.tmpl similarity index 100% rename from tests/fixtures/rwkv_completion.tmpl rename to tests/models_fixtures/rwkv_completion.tmpl diff --git a/tests/fixtures/whisper.yaml b/tests/models_fixtures/whisper.yaml similarity index 100% rename from tests/fixtures/whisper.yaml rename to tests/models_fixtures/whisper.yaml
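
For reviewers who want to try the new routes end to end: `POST /models/apply` enqueues a `galleryOp` on the applier channel and returns immediately with a job `uid` and a status URL, and `GET /models/jobs/:uid` reports the job's progress. Below is a minimal client sketch (not part of the patch). It assumes a LocalAI instance listening on the default `:8080` address from `main.go`; the gallery config URL and model name are hypothetical placeholders, while the endpoint paths and JSON field names (`url`, `name`, `uid`, `status`, `processed`, `message`) come from `api/gallery.go` above.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// Request/response shapes mirroring the structs added in api/gallery.go.
type applyRequest struct {
	URL  string `json:"url"`
	Name string `json:"name"`
}

type applyResponse struct {
	UID       string `json:"uid"`
	StatusURL string `json:"status"`
}

type jobStatus struct {
	Processed bool   `json:"processed"`
	Message   string `json:"message"`
	// the status payload also carries an "error" field when the op fails
}

func main() {
	base := "http://localhost:8080" // default address from main.go; adjust as needed

	// Enqueue the job: LocalAI fetches the gallery YAML, downloads the model
	// files, verifies their SHA256, and writes the model config under models-path.
	payload, err := json.Marshal(applyRequest{
		URL:  "https://example.com/gallery/cerebras.yaml", // hypothetical gallery config URL
		Name: "cerebras",                                  // optional name override
	})
	if err != nil {
		panic(err)
	}
	resp, err := http.Post(base+"/models/apply", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	var job applyResponse
	if err := json.NewDecoder(resp.Body).Decode(&job); err != nil {
		panic(err)
	}
	resp.Body.Close()
	fmt.Println("job enqueued:", job.UID)

	// Poll the status endpoint until the applier marks the op as processed.
	for {
		time.Sleep(time.Second)
		st, err := http.Get(base + "/models/jobs/" + job.UID)
		if err != nil {
			panic(err)
		}
		var s jobStatus
		if err := json.NewDecoder(st.Body).Decode(&s); err != nil {
			panic(err)
		}
		st.Body.Close()
		fmt.Println("status:", s.Message)
		if s.Processed {
			break
		}
	}
}
```

Because `gallery.Apply` runs in the single goroutine started by `galleryApplier.start`, large model downloads never block the HTTP request itself; polling the status endpoint until `processed` is `true` (or an `error` is reported) is the intended way to detect completion.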