From 78f3c3da487f87a2d61b384c6eec4c062ac68739 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Mon, 26 Jun 2023 12:25:38 +0200 Subject: [PATCH] refactor: consolidate usage of GetURI (#674) Signed-off-by: mudler --- api/api.go | 4 +-- api/api_test.go | 29 +++--------------- api/gallery.go | 36 +++++++++------------- pkg/gallery/gallery.go | 6 ++-- pkg/gallery/models.go | 11 +++++++ pkg/gallery/request.go | 58 ----------------------------------- pkg/gallery/request_test.go | 26 +--------------- pkg/utils/uri.go | 28 +++++++++++++++-- pkg/utils/uri_test.go | 36 ++++++++++++++++++++++ pkg/utils/utils_suite_test.go | 13 ++++++++ 10 files changed, 110 insertions(+), 137 deletions(-) create mode 100644 pkg/utils/uri_test.go create mode 100644 pkg/utils/utils_suite_test.go diff --git a/api/api.go b/api/api.go index 527258b..05aec9b 100644 --- a/api/api.go +++ b/api/api.go @@ -80,13 +80,13 @@ func App(opts ...AppOption) (*fiber.App, error) { app.Use(recover.New()) if options.preloadJSONModels != "" { - if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm); err != nil { + if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil { return nil, err } } if options.preloadModelsFromPath != "" { - if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm); err != nil { + if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil { return nil, err } } diff --git a/api/api_test.go b/api/api_test.go index 05d5e7b..e55ddfe 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -15,6 +15,7 @@ import ( . "github.com/go-skynet/LocalAI/api" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -56,30 +57,10 @@ func getModelStatus(url string) (response map[string]interface{}) { } func getModels(url string) (response []gallery.GalleryModel) { - - //url := "http://localhost:AI/models/apply" - - // Create the request payload - - // Create the HTTP request - resp, err := http.Get(url) - if err != nil { - return nil - } - defer resp.Body.Close() - - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - fmt.Println("Error reading response body:", err) - return - } - - // Unmarshal the response into a map[string]interface{} - err = json.Unmarshal(body, &response) - if err != nil { - fmt.Println("Error unmarshaling JSON response:", err) - return - } + utils.GetURI(url, func(url string, i []byte) error { + // Unmarshal YAML data into a struct + return json.Unmarshal(i, &response) + }) return } diff --git a/api/gallery.go b/api/gallery.go index 46d92f9..1c0cec9 100644 --- a/api/gallery.go +++ b/api/gallery.go @@ -48,9 +48,8 @@ func newGalleryApplier(modelPath string) *galleryApplier { // prepareModel applies a func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { - var config gallery.Config - err := req.Get(&config) + config, err := gallery.GetGalleryConfigFromURL(req.URL) if err != nil { return err } @@ -144,40 +143,35 @@ func displayDownload(fileName string, current string, total string, percentage f } } -func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { +type galleryModel struct { + gallery.GalleryModel + ID string `json:"id"` +} + +func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error { dat, err := os.ReadFile(s) if err != nil { return err } - var requests []gallery.GalleryModel - err = json.Unmarshal(dat, &requests) - if err != nil { - return err - } - - for _, r := range requests { - if err := prepareModel(modelPath, r, cm, displayDownload); err != nil { - return err - } - } - - return nil + return ApplyGalleryFromString(modelPath, string(dat), cm, galleries) } -func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { - var requests []gallery.GalleryModel +func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error { + var requests []galleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { return err } for _, r := range requests { - if err := prepareModel(modelPath, r, cm, displayDownload); err != nil { - return err + if r.ID == "" { + err = prepareModel(modelPath, r.GalleryModel, cm, displayDownload) + } else { + err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, displayDownload) } } - return nil + return err } func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go index d444034..aed5251 100644 --- a/pkg/gallery/gallery.go +++ b/pkg/gallery/gallery.go @@ -23,9 +23,7 @@ func InstallModelFromGallery(galleries []Gallery, name string, basePath string, } applyModel := func(model *GalleryModel) error { - var config Config - - err := model.Get(&config) + config, err := GetGalleryConfigFromURL(model.URL) if err != nil { return err } @@ -79,7 +77,7 @@ func AvailableGalleryModels(galleries []Gallery, basePath string) ([]*GalleryMod func getGalleryModels(gallery Gallery, basePath string) ([]*GalleryModel, error) { var models []*GalleryModel = []*GalleryModel{} - err := utils.GetURI(gallery.URL, func(d []byte) error { + err := utils.GetURI(gallery.URL, func(url string, d []byte) error { return yaml.Unmarshal(d, &models) }) if err != nil { diff --git a/pkg/gallery/models.go b/pkg/gallery/models.go index 424b424..4295a99 100644 --- a/pkg/gallery/models.go +++ b/pkg/gallery/models.go @@ -65,6 +65,17 @@ type PromptTemplate struct { Content string `yaml:"content"` } +func GetGalleryConfigFromURL(url string) (Config, error) { + var config Config + err := utils.GetURI(url, func(url string, d []byte) error { + return yaml.Unmarshal(d, &config) + }) + if err != nil { + return config, err + } + return config, nil +} + func ReadConfigFile(filePath string) (*Config, error) { // Read the YAML file yamlFile, err := os.ReadFile(filePath) diff --git a/pkg/gallery/request.go b/pkg/gallery/request.go index c8ccd5e..030ee16 100644 --- a/pkg/gallery/request.go +++ b/pkg/gallery/request.go @@ -1,14 +1,5 @@ package gallery -import ( - "fmt" - "net/url" - "strings" - - "github.com/go-skynet/LocalAI/pkg/utils" - "gopkg.in/yaml.v2" -) - // GalleryModel is the struct used to represent a model in the gallery returned by the endpoint. // It is used to install the model by resolving the URL and downloading the files. // The other fields are used to override the configuration of the model. @@ -34,52 +25,3 @@ type GalleryModel struct { const ( githubURI = "github:" ) - -func (request GalleryModel) DecodeURL() (string, error) { - input := request.URL - var rawURL string - - if strings.HasPrefix(input, githubURI) { - parts := strings.Split(input, ":") - repoParts := strings.Split(parts[1], "@") - branch := "main" - - if len(repoParts) > 1 { - branch = repoParts[1] - } - - repoPath := strings.Split(repoParts[0], "/") - org := repoPath[0] - project := repoPath[1] - projectPath := strings.Join(repoPath[2:], "/") - - rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) - } else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { - // Handle regular URLs - u, err := url.Parse(input) - if err != nil { - return "", fmt.Errorf("invalid URL: %w", err) - } - rawURL = u.String() - // check if it's a file path - } else if strings.HasPrefix(input, "file://") { - return input, nil - } else { - - return "", fmt.Errorf("invalid URL format: %s", input) - } - - return rawURL, nil -} - -// Get fetches a model from a URL and unmarshals it into a struct -func (request GalleryModel) Get(i interface{}) error { - url, err := request.DecodeURL() - if err != nil { - return err - } - - return utils.GetURI(url, func(d []byte) error { - return yaml.Unmarshal(d, i) - }) -} diff --git a/pkg/gallery/request_test.go b/pkg/gallery/request_test.go index 12a8d06..a9d54e3 100644 --- a/pkg/gallery/request_test.go +++ b/pkg/gallery/request_test.go @@ -6,37 +6,13 @@ import ( . "github.com/onsi/gomega" ) -type example struct { - Name string `yaml:"name"` -} - var _ = Describe("Gallery API tests", func() { - Context("requests", func() { It("parses github with a branch", func() { req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} - var e example - err := req.Get(&e) + e, err := GetGalleryConfigFromURL(req.URL) Expect(err).ToNot(HaveOccurred()) Expect(e.Name).To(Equal("gpt4all-j")) }) - It("parses github without a branch", func() { - req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} - str, err := req.DecodeURL() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) - }) - It("parses github without a branch", func() { - req := GalleryModel{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"} - str, err := req.DecodeURL() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) - }) - It("parses URLS", func() { - req := GalleryModel{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"} - str, err := req.DecodeURL() - Expect(err).ToNot(HaveOccurred()) - Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) - }) }) }) diff --git a/pkg/utils/uri.go b/pkg/utils/uri.go index 753a283..9552745 100644 --- a/pkg/utils/uri.go +++ b/pkg/utils/uri.go @@ -1,12 +1,34 @@ package utils import ( + "fmt" "io/ioutil" "net/http" "strings" ) -func GetURI(url string, f func(i []byte) error) error { +const ( + githubURI = "github:" +) + +func GetURI(url string, f func(url string, i []byte) error) error { + if strings.HasPrefix(url, githubURI) { + parts := strings.Split(url, ":") + repoParts := strings.Split(parts[1], "@") + branch := "main" + + if len(repoParts) > 1 { + branch = repoParts[1] + } + + repoPath := strings.Split(repoParts[0], "/") + org := repoPath[0] + project := repoPath[1] + projectPath := strings.Join(repoPath[2:], "/") + + url = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) + } + if strings.HasPrefix(url, "file://") { rawURL := strings.TrimPrefix(url, "file://") // Read the response body @@ -16,7 +38,7 @@ func GetURI(url string, f func(i []byte) error) error { } // Unmarshal YAML data into a struct - return f(body) + return f(url, body) } // Send a GET request to the URL @@ -33,5 +55,5 @@ func GetURI(url string, f func(i []byte) error) error { } // Unmarshal YAML data into a struct - return f(body) + return f(url, body) } diff --git a/pkg/utils/uri_test.go b/pkg/utils/uri_test.go new file mode 100644 index 0000000..79a9f4a --- /dev/null +++ b/pkg/utils/uri_test.go @@ -0,0 +1,36 @@ +package utils_test + +import ( + . "github.com/go-skynet/LocalAI/pkg/utils" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Gallery API tests", func() { + Context("URI", func() { + It("parses github with a branch", func() { + Expect( + GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml", func(url string, i []byte) error { + Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + return nil + }), + ).ToNot(HaveOccurred()) + }) + It("parses github without a branch", func() { + Expect( + GetURI("github:go-skynet/model-gallery/gpt4all-j.yaml@main", func(url string, i []byte) error { + Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + return nil + }), + ).ToNot(HaveOccurred()) + }) + It("parses github with urls", func() { + Expect( + GetURI("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml", func(url string, i []byte) error { + Expect(url).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) + return nil + }), + ).ToNot(HaveOccurred()) + }) + }) +}) diff --git a/pkg/utils/utils_suite_test.go b/pkg/utils/utils_suite_test.go new file mode 100644 index 0000000..8260e31 --- /dev/null +++ b/pkg/utils/utils_suite_test.go @@ -0,0 +1,13 @@ +package utils_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestUtils(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Utils test suite") +}