diff --git a/.env b/.env index c524859..73e3174 100644 --- a/.env +++ b/.env @@ -1,6 +1,30 @@ +## Set number of threads. +## Note: prefer the number of physical cores. Overbooking the CPU degrades performance notably. # THREADS=14 + +## Specify a different bind address (defaults to ":8080") +# ADDRESS=127.0.0.1:8080 + +## Default models context size # CONTEXT_SIZE=512 + +## Default path for models MODELS_PATH=/models + +## Enable debug mode # DEBUG=true -# BUILD_TYPE=generic -# REBUILD=true + +## Specify a build type. Available: cublas, openblas. +# BUILD_TYPE=openblas + +## Uncomment and set to false to disable rebuilding from source +# REBUILD=false + +## Enable image generation with stablediffusion (requires REBUILD=true) +# GO_TAGS=stablediffusion + +## Path where to store generated images +# IMAGE_PATH=/tmp + +## Specify a default upload limit in MB (whisper) +# UPLOAD_LIMIT \ No newline at end of file diff --git a/README.md b/README.md index 007802a..e71a56d 100644 --- a/README.md +++ b/README.md @@ -869,17 +869,18 @@ curl http://localhost:8080/models/apply -H "Content-Type: application/json" -d ' "uri": "", "sha256": "", "name": "" - } + }, + "overrides": { "backend": "...", "f16": true } ] } ``` -An optional, list of additional files can be specified to be downloaded. The `name` allows to override the model name. +An optional, list of additional files can be specified to be downloaded within `files`. The `name` allows to override the model name. Finally it is possible to override the model config file with `override`. Returns an `uuid` and an `url` to follow up the state of the process: ```json -{ "uid":"251475c9-f666-11ed-95e0-9a8a4480ac58", "status":"http://localhost:8080/models/jobs/251475c9-f666-11ed-95e0-9a8a4480ac58"} +{ "uuid":"251475c9-f666-11ed-95e0-9a8a4480ac58", "status":"http://localhost:8080/models/jobs/251475c9-f666-11ed-95e0-9a8a4480ac58"} ``` To see a collection example of curated models definition files, see the [model-gallery](https://github.com/go-skynet/model-gallery). diff --git a/api/api.go b/api/api.go index ec7c981..b81a89f 100644 --- a/api/api.go +++ b/api/api.go @@ -74,7 +74,7 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload applier := newGalleryApplier(loader.ModelPath) applier.start(c, cm) app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C)) - app.Get("/models/jobs/:uid", getOpStatus(applier)) + app.Get("/models/jobs/:uuid", getOpStatus(applier)) // openAI compatible API endpoint diff --git a/api/api_test.go b/api/api_test.go index 1a5d7d4..f061527 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -1,7 +1,12 @@ package api_test import ( + "bytes" "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" "os" "path/filepath" "runtime" @@ -11,11 +16,85 @@ import ( "github.com/gofiber/fiber/v2" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "gopkg.in/yaml.v3" openaigo "github.com/otiai10/openaigo" "github.com/sashabaranov/go-openai" ) +type modelApplyRequest struct { + URL string `json:"url"` + Name string `json:"name"` + Overrides map[string]string `json:"overrides"` +} + +func getModelStatus(url string) (response map[string]interface{}) { + // Create the HTTP request + resp, err := http.Get(url) + if err != nil { + fmt.Println("Error creating request:", err) + return + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println("Error reading response body:", err) + return + } + + // Unmarshal the response into a map[string]interface{} + err = json.Unmarshal(body, &response) + if err != nil { + fmt.Println("Error unmarshaling JSON response:", err) + return + } + return +} +func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { + + //url := "http://localhost:AI/models/apply" + + // Create the request payload + + payload, err := json.Marshal(request) + if err != nil { + fmt.Println("Error marshaling JSON:", err) + return + } + + // Create the HTTP request + req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) + if err != nil { + fmt.Println("Error creating request:", err) + return + } + req.Header.Set("Content-Type", "application/json") + + // Make the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + fmt.Println("Error making request:", err) + return + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println("Error reading response body:", err) + return + } + + // Unmarshal the response into a map[string]interface{} + err = json.Unmarshal(body, &response) + if err != nil { + fmt.Println("Error unmarshaling JSON response:", err) + return + } + return +} + var _ = Describe("API test", func() { var app *fiber.App @@ -24,6 +103,96 @@ var _ = Describe("API test", func() { var client2 *openaigo.Client var c context.Context var cancel context.CancelFunc + var tmpdir string + + Context("API with ephemeral models", func() { + BeforeEach(func() { + var err error + tmpdir, err = os.MkdirTemp("", "") + Expect(err).ToNot(HaveOccurred()) + + modelLoader = model.NewModelLoader(tmpdir) + c, cancel = context.WithCancel(context.Background()) + + app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "") + go app.Listen("127.0.0.1:9090") + + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + + client2 = openaigo.NewClient("") + client2.BaseURL = defaultConfig.BaseURL + + // Wait for API to be ready + client = openai.NewClientWithConfig(defaultConfig) + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + + AfterEach(func() { + cancel() + app.Shutdown() + os.RemoveAll(tmpdir) + }) + + Context("Applying models", func() { + It("overrides models", func() { + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", + Name: "bert", + Overrides: map[string]string{ + "backend": "llama", + }, + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + return response["processed"].(bool) + }, "360s").Should(Equal(true)) + + dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) + Expect(err).ToNot(HaveOccurred()) + + content := map[string]interface{}{} + err = yaml.Unmarshal(dat, &content) + Expect(err).ToNot(HaveOccurred()) + Expect(content["backend"]).To(Equal("llama")) + }) + It("apply models without overrides", func() { + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", + Name: "bert", + Overrides: map[string]string{}, + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + return response["processed"].(bool) + }, "360s").Should(Equal(true)) + + dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) + Expect(err).ToNot(HaveOccurred()) + + content := map[string]interface{}{} + err = yaml.Unmarshal(dat, &content) + Expect(err).ToNot(HaveOccurred()) + Expect(content["backend"]).To(Equal("bert-embeddings")) + }) + }) + }) + Context("API query", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) diff --git a/api/gallery.go b/api/gallery.go index fac4862..591b1b7 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -97,7 +97,7 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { config.Files = append(config.Files, op.req.AdditionalFiles...) - if err := gallery.Apply(g.modelPath, op.req.Name, &config); err != nil { + if err := gallery.Apply(g.modelPath, op.req.Name, &config, op.req.Overrides); err != nil { updateError(err) continue } @@ -117,9 +117,10 @@ func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { // endpoints type ApplyGalleryModelRequest struct { - URL string `json:"url"` - Name string `json:"name"` - AdditionalFiles []gallery.File `json:"files"` + URL string `json:"url"` + Name string `json:"name"` + Overrides map[string]interface{} `json:"overrides"` + AdditionalFiles []gallery.File `json:"files"` } const ( @@ -162,7 +163,7 @@ func (request ApplyGalleryModelRequest) DecodeURL() (string, error) { func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - status := g.getstatus(c.Params("uid")) + status := g.getstatus(c.Params("uuid")) if status == nil { return fmt.Errorf("could not find any status for ID") } @@ -188,7 +189,7 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) fun id: uuid.String(), } return c.JSON(struct { - ID string `json:"uid"` + ID string `json:"uuid"` StatusURL string `json:"status"` }{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) } diff --git a/examples/langchain-chroma/README.md b/examples/langchain-chroma/README.md index 17207a0..9fd9e31 100644 --- a/examples/langchain-chroma/README.md +++ b/examples/langchain-chroma/README.md @@ -36,6 +36,8 @@ pip install -r requirements.txt In this step we will create a local vector database from our document set, so later we can ask questions on it with the LLM. +Note: **OPENAI_API_KEY** is not required. However the library might fail if no API_KEY is passed by, so an arbitrary string can be used. + ```bash export OPENAI_API_BASE=http://localhost:8080/v1 export OPENAI_API_KEY=sk- diff --git a/examples/langchain-python/README.md b/examples/langchain-python/README.md index a98c48f..2472aab 100644 --- a/examples/langchain-python/README.md +++ b/examples/langchain-python/README.md @@ -26,6 +26,7 @@ pip install langchain pip install openai export OPENAI_API_BASE=http://localhost:8080 +# Note: **OPENAI_API_KEY** is not required. However the library might fail if no API_KEY is passed by, so an arbitrary string can be used. export OPENAI_API_KEY=sk- python test.py diff --git a/examples/query_data/README.md b/examples/query_data/README.md index f7a4e1f..c4e384c 100644 --- a/examples/query_data/README.md +++ b/examples/query_data/README.md @@ -35,6 +35,8 @@ docker-compose up -d --build In this step we will create a local vector database from our document set, so later we can ask questions on it with the LLM. +Note: **OPENAI_API_KEY** is not required. However the library might fail if no API_KEY is passed by, so an arbitrary string can be used. + ```bash export OPENAI_API_BASE=http://localhost:8080/v1 export OPENAI_API_KEY=sk- diff --git a/go.mod b/go.mod index e26e067..1682e50 100644 --- a/go.mod +++ b/go.mod @@ -9,16 +9,15 @@ require ( github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4 github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245 - github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c github.com/go-skynet/go-llama.cpp v0.0.0-20230520082618-a298043ef5f1 github.com/gofiber/fiber/v2 v2.46.0 github.com/google/uuid v1.3.0 github.com/hashicorp/go-multierror v1.1.1 + github.com/imdario/mergo v0.3.15 github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642 github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd github.com/onsi/ginkgo/v2 v2.9.5 github.com/onsi/gomega v1.27.7 - github.com/otiai10/copy v1.11.0 github.com/otiai10/openaigo v1.1.0 github.com/rs/zerolog v1.29.1 github.com/sashabaranov/go-openai v1.9.4 @@ -52,6 +51,7 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.18 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect + github.com/otiai10/mint v1.5.1 // indirect github.com/philhofer/fwd v1.1.2 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index f70fe7d..c1ff1ac 100644 --- a/go.sum +++ b/go.sum @@ -16,16 +16,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be h1:3Hic97PY6hcw/SY44RuR7kyONkxd744RFeRrqckzwNQ= -github.com/donomii/go-rwkv.cpp v0.0.0-20230503112711-af62fcc432be/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= -github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2 h1:YNbUAyIRtaLODitigJU1EM5ubmMu5FmHtYAayJD6Vbg= -github.com/donomii/go-rwkv.cpp v0.0.0-20230510174014-07166da10cb2/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56 h1:s8/MZdicstKi5fn9D9mKGIQ/q6IWCYCk/BM68i8v51w= github.com/donomii/go-rwkv.cpp v0.0.0-20230515123100-6fdd0c338e56/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35 h1:sMg/SgnMPS/HNUO/2kGm72vl8R9TmNIwgLFr2TNwR3g= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230508180809-bf2449dfae35/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a h1:MlyiDLNCM/wjbv8U5Elj18NvaAgl61SGiRUpqQz5dfs= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230509153812-1d17cd5bb37a/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230515153606-95b02d76b04d h1:uxKTbiRnplE2SubchneSf4NChtxLJtOy9VdHnQMT0d0= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230515153606-95b02d76b04d/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= @@ -46,30 +38,12 @@ github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7 github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= -github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f h1:GW8RQa1RVeDF1dOuAP/y6xWVC+BRtf9tJOuEza6Asbg= -github.com/go-skynet/bloomz.cpp v0.0.0-20230510195113-ad7e89a0885f/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf h1:VJfSn8hIDE+K5+h38M3iAyFXrxpRExMKRdTk33UDxsw= github.com/go-skynet/bloomz.cpp v0.0.0-20230510223001-e9366e82abdf/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= -github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea h1:8Isk9D+Auth5OuXVAQPC3MO+5zF/2S7mvs2JZLw6a+8= -github.com/go-skynet/go-bert.cpp v0.0.0-20230510101404-7bb183b147ea/go.mod h1:NHwIVvsg7Jh6p0M4uBLVmSMEaPUia6O6yjXUpLWVJmQ= -github.com/go-skynet/go-bert.cpp v0.0.0-20230510124618-ec771ec71557 h1:LD66fKtvP2lmyuuKL8pBat/pVTKUbLs3L5fM/5lyi4w= -github.com/go-skynet/go-bert.cpp v0.0.0-20230510124618-ec771ec71557/go.mod h1:NHwIVvsg7Jh6p0M4uBLVmSMEaPUia6O6yjXUpLWVJmQ= github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4 h1:+3KPDf4Wv1VHOkzAfZnlj9qakLSYggTpm80AswhD/FU= github.com/go-skynet/go-bert.cpp v0.0.0-20230516063724-cea1ed76a7f4/go.mod h1:VY0s5KoAI2jRCvQXKuDeEEe8KG7VaWifSNJSk+E1KtY= -github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6 h1:XshpypO6ekU09CI19vuzke2a1Es1lV5ZaxA7CUehu0E= -github.com/go-skynet/go-gpt2.cpp v0.0.0-20230509180201-d49823284cc6/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245 h1:IcfYY5uH0DdDXEJKJ8bq0WZCd9guPPd3xllaWNy8LOk= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230512145559-7bff56f02245/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= -github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c h1:48I7jpLNGiQeBmF0SFVVbREh8vlG0zN13v9LH5ctXis= -github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230422090028-1f7bff57f66c/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI= -github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b h1:qqxrjY8fYDXQahmCMTCACahm1tbiqHLPUHALkFLyBfo= -github.com/go-skynet/go-llama.cpp v0.0.0-20230510072905-70593fccbe4b/go.mod h1:DLfsPD7tYYnpksERH83HSf7qVNW3FIwmz7/zfYO0/6I= -github.com/go-skynet/go-llama.cpp v0.0.0-20230516230554-b7bbefbe0b84 h1:f5iYF75bAr73Tl8AdtFD5Urs/2bsHKPh52K++jLbsfk= -github.com/go-skynet/go-llama.cpp v0.0.0-20230516230554-b7bbefbe0b84/go.mod h1:jxyQ26t1aKC5Gn782w9WWh5n1133PxCOfkuc01xM4RQ= -github.com/go-skynet/go-llama.cpp v0.0.0-20230518171914-33f8c2db53bf h1:D9CLQwr1eqSnV0DM7YGOKhSfNajj2qOA7XAD6+p1/HI= -github.com/go-skynet/go-llama.cpp v0.0.0-20230518171914-33f8c2db53bf/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws= -github.com/go-skynet/go-llama.cpp v0.0.0-20230519203945-3ee537e8cb52 h1:VU64ntI6BuWRMFMzx8OTH26k5BJqY8vgW0Igyn9MYKc= -github.com/go-skynet/go-llama.cpp v0.0.0-20230519203945-3ee537e8cb52/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws= github.com/go-skynet/go-llama.cpp v0.0.0-20230520082618-a298043ef5f1 h1:i0oM2MERUgMIRmjOcv22TDQULxbmY8o9rZKLKKyWXLo= github.com/go-skynet/go-llama.cpp v0.0.0-20230520082618-a298043ef5f1/go.mod h1:oA0r4BW8ndyjTMGi1tulsNd7sdg3Ql8MaVFuT1zF6ws= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= @@ -91,6 +65,8 @@ github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brv github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= +github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM= +github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/klauspost/compress v1.16.3 h1:XuJt9zzcnaz6a16/OU53ZjWp/v7/42WcR5t2a0PcNQY= @@ -113,31 +89,18 @@ github.com/mattn/go-isatty v0.0.18 h1:DOKFKCQ7FNG2L1rbrmstDN4QVRdS89Nkh85u68Uwp9 github.com/mattn/go-isatty v0.0.18/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24 h1:XfRD/bZom6u4zji7aB0urIVOsPe43KlkzSRrVhlzaOM= -github.com/mudler/go-stable-diffusion v0.0.0-20230516104333-2f32a16b5b24/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642 h1:KTkh3lOUsGqQyP4v+oa38sPFdrZtNnM4HaxTb3epdYs= github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc h1:OPavP/SUsVWVYPhSUZKZeX8yDSQzf4G+BmUmwzrLTyI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230516143155-79d6243fe1bc/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230518171731-546600fb6878 h1:3MFUW2a1Aqm2nMF5f+PNGq55cbxIzkRHQX/o7JVysAo= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230518171731-546600fb6878/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230518183047-94f401889042 h1:bMU75tTgyw6mYDa5NbOHlZ1KYqQgz7heQFfTwuvCGww= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230518183047-94f401889042/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd h1:kMnZASxCNc8GsPuAV94tltEsfT6T+esuB+rgzdjwFVM= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230519014017-914519e772fd/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/onsi/ginkgo/v2 v2.9.4 h1:xR7vG4IXt5RWx6FfIjyAtsoMAtnc3C/rFXBBd2AjZwE= -github.com/onsi/ginkgo/v2 v2.9.4/go.mod h1:gCQYp2Q+kSoIj7ykSVb9nskRSsR6PUj4AiLywzIhbKM= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= -github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= -github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= github.com/onsi/gomega v1.27.7/go.mod h1:1p8OOlwo2iUUDsHnOrjE5UKYJ+e3W8eQ3qSlRahPmr4= -github.com/otiai10/copy v1.11.0 h1:OKBD80J/mLBrwnzXqGtFCzprFSGioo30JcmR4APsNwc= -github.com/otiai10/copy v1.11.0/go.mod h1:rSaLseMUsZFFbsFGc7wCJnnkTAvdc5L6VWxPE4308Ww= github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= +github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= github.com/otiai10/openaigo v1.1.0 h1:zRvGBqZUW5PCMgdkJNsPVTBd8tOLCMTipXE5wD2pdTg= github.com/otiai10/openaigo v1.1.0/go.mod h1:792bx6AWTS61weDi2EzKpHHnTF4eDMAlJ5GvAk/mgPg= github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= @@ -153,8 +116,6 @@ github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.9.3 h1:uNak3Rn5pPsKRs9bdT7RqRZEyej/zdZOEI2/8wvrFtM= -github.com/sashabaranov/go-openai v1.9.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.9.4 h1:KanoCEoowAI45jVXlenMCckutSRr39qOmSi9MyPBfZM= github.com/sashabaranov/go-openai v1.9.4/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= @@ -198,8 +159,6 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= -golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -237,8 +196,6 @@ golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= -golang.org/x/tools v0.8.0 h1:vSDcovVPld282ceKgDimkRSC8kpaH1dgyc9UMzlt84Y= -golang.org/x/tools v0.8.0/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index 8c61380..f4f86ae 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" + "github.com/imdario/mergo" "github.com/rs/zerolog/log" "gopkg.in/yaml.v2" ) @@ -92,13 +93,17 @@ func verifyPath(path, basePath string) error { return inTrustedRoot(c, basePath) } -func Apply(basePath, nameOverride string, config *Config) error { +func Apply(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}) error { // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0755) if err != nil { return fmt.Errorf("failed to create base path: %v", err) } + if len(configOverrides) > 0 { + log.Debug().Msgf("Config overrides %+v", configOverrides) + } + // Download files and verify their SHA for _, file := range config.Files { log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename) @@ -231,6 +236,10 @@ func Apply(basePath, nameOverride string, config *Config) error { configMap["name"] = name + if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil { + return err + } + // Write updated config file updatedConfigYAML, err := yaml.Marshal(configMap) if err != nil { diff --git a/pkg/gallery/models_test.go b/pkg/gallery/models_test.go index 980b3a9..f0e580e 100644 --- a/pkg/gallery/models_test.go +++ b/pkg/gallery/models_test.go @@ -7,6 +7,7 @@ import ( . "github.com/go-skynet/LocalAI/pkg/gallery" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" + "gopkg.in/yaml.v3" ) var _ = Describe("Model test", func() { @@ -18,14 +19,25 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "", c) + err = Apply(tempdir, "", c, map[string]interface{}{}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { _, err = os.Stat(filepath.Join(tempdir, f)) Expect(err).ToNot(HaveOccurred()) } + + content := map[string]interface{}{} + + dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = yaml.Unmarshal(dat, content) + Expect(err).ToNot(HaveOccurred()) + + Expect(content["context_size"]).To(Equal(1024)) }) + It("renames model correctly", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) @@ -33,14 +45,41 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "foo", c) + err = Apply(tempdir, "foo", c, map[string]interface{}{}) + Expect(err).ToNot(HaveOccurred()) + + for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { + _, err = os.Stat(filepath.Join(tempdir, f)) + Expect(err).ToNot(HaveOccurred()) + } + }) + + It("overrides parameters", func() { + tempdir, err := os.MkdirTemp("", "test") + Expect(err).ToNot(HaveOccurred()) + defer os.RemoveAll(tempdir) + c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = Apply(tempdir, "foo", c, map[string]interface{}{"backend": "foo"}) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { _, err = os.Stat(filepath.Join(tempdir, f)) Expect(err).ToNot(HaveOccurred()) } + + content := map[string]interface{}{} + + dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml")) + Expect(err).ToNot(HaveOccurred()) + + err = yaml.Unmarshal(dat, content) + Expect(err).ToNot(HaveOccurred()) + + Expect(content["backend"]).To(Equal("foo")) }) + It("catches path traversals", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) @@ -48,7 +87,7 @@ var _ = Describe("Model test", func() { c, err := ReadConfigFile(filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) - err = Apply(tempdir, "../../../foo", c) + err = Apply(tempdir, "../../../foo", c, map[string]interface{}{}) Expect(err).To(HaveOccurred()) }) })