package api_test import ( "bytes" "context" "embed" "encoding/json" "fmt" "io/ioutil" "net/http" "os" "path/filepath" "runtime" . "github.com/go-skynet/LocalAI/api" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "gopkg.in/yaml.v3" openaigo "github.com/otiai10/openaigo" "github.com/sashabaranov/go-openai" ) type modelApplyRequest struct { ID string `json:"id"` URL string `json:"url"` Name string `json:"name"` Overrides map[string]string `json:"overrides"` } func getModelStatus(url string) (response map[string]interface{}) { // Create the HTTP request resp, err := http.Get(url) if err != nil { fmt.Println("Error creating request:", err) return } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { fmt.Println("Error reading response body:", err) return } // Unmarshal the response into a map[string]interface{} err = json.Unmarshal(body, &response) if err != nil { fmt.Println("Error unmarshaling JSON response:", err) return } return } func getModels(url string) (response []gallery.GalleryModel) { utils.GetURI(url, func(url string, i []byte) error { // Unmarshal YAML data into a struct return json.Unmarshal(i, &response) }) return } func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { //url := "http://localhost:AI/models/apply" // Create the request payload payload, err := json.Marshal(request) if err != nil { fmt.Println("Error marshaling JSON:", err) return } // Create the HTTP request req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) if err != nil { fmt.Println("Error creating request:", err) return } req.Header.Set("Content-Type", "application/json") // Make the request client := &http.Client{} resp, err := client.Do(req) if err != nil { fmt.Println("Error making request:", err) return } defer resp.Body.Close() body, err := ioutil.ReadAll(resp.Body) if err != nil { fmt.Println("Error reading response body:", err) return } // Unmarshal the response into a map[string]interface{} err = json.Unmarshal(body, &response) if err != nil { fmt.Println("Error unmarshaling JSON response:", err) return } return } //go:embed backend-assets/* var backendAssets embed.FS var _ = Describe("API test", func() { var app *fiber.App var modelLoader *model.ModelLoader var client *openai.Client var client2 *openaigo.Client var c context.Context var cancel context.CancelFunc var tmpdir string Context("API with ephemeral models", func() { BeforeEach(func() { var err error tmpdir, err = os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) modelLoader = model.NewModelLoader(tmpdir) c, cancel = context.WithCancel(context.Background()) g := []gallery.GalleryModel{ { Name: "bert", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", }, { Name: "bert2", URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", Overrides: map[string]interface{}{"foo": "bar"}, AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}}, }, } out, err := yaml.Marshal(g) Expect(err).ToNot(HaveOccurred()) err = ioutil.WriteFile(filepath.Join(tmpdir, "gallery_simple.yaml"), out, 0644) Expect(err).ToNot(HaveOccurred()) galleries := []gallery.Gallery{ { Name: "test", URL: "file://" + filepath.Join(tmpdir, "gallery_simple.yaml"), }, } app, err = App(WithContext(c), WithGalleries(galleries), WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL // Wait for API to be ready client = openai.NewClientWithConfig(defaultConfig) Eventually(func() error { _, err := client.ListModels(context.TODO()) return err }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { cancel() app.Shutdown() os.RemoveAll(tmpdir) }) Context("Applying models", func() { It("applies models from a gallery", func() { models := getModels("http://127.0.0.1:9090/models/available") Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ ID: "test@bert2", }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) resp := map[string]interface{}{} Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) resp = response return response["processed"].(bool) }, "360s").Should(Equal(true)) Expect(resp["message"]).ToNot(ContainSubstring("error")) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml")) Expect(err).ToNot(HaveOccurred()) _, err = os.ReadFile(filepath.Join(tmpdir, "foo.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("bert-embeddings")) Expect(content["foo"]).To(Equal("bar")) models = getModels("http://127.0.0.1:9090/models/available") Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2"))) for _, m := range models { if m.Name == "bert2" { Expect(m.Installed).To(BeTrue()) } else { Expect(m.Installed).To(BeFalse()) } } }) It("overrides models", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", Name: "bert", Overrides: map[string]string{ "backend": "llama", }, }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) return response["processed"].(bool) }, "360s").Should(Equal(true)) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("llama")) }) It("apply models without overrides", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", Name: "bert", Overrides: map[string]string{}, }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) return response["processed"].(bool) }, "360s").Should(Equal(true)) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("bert-embeddings")) }) It("runs openllama", Label("llama"), func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: "github:go-skynet/model-gallery/openllama_3b.yaml", Name: "openllama_3b", Overrides: map[string]string{}, }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) return response["processed"].(bool) }, "360s").Should(Equal(true)) resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).To(ContainSubstring("five")) }) It("runs gpt4all", Label("gpt4all"), func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: "github:go-skynet/model-gallery/gpt4all-j.yaml", Name: "gpt4all-j", Overrides: map[string]string{}, }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) return response["processed"].(bool) }, "360s").Should(Equal(true)) resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well")) }) }) }) Context("API query", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) c, cancel = context.WithCancel(context.Background()) var err error app, err = App(WithContext(c), WithModelLoader(modelLoader)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL // Wait for API to be ready client = openai.NewClientWithConfig(defaultConfig) Eventually(func() error { _, err := client.ListModels(context.TODO()) return err }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { cancel() app.Shutdown() }) It("returns the models list", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) Expect(len(models.Models)).To(Equal(10)) }) It("can generate completions", func() { resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel", Prompt: "abcdedfghikl"}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).ToNot(BeEmpty()) }) It("can generate chat completions ", func() { resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate completions from model configs", func() { resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: "abcdedfghikl"}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).ToNot(BeEmpty()) }) It("can generate chat completions from model configs", func() { resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 11 errors occurred:")) }) It("transcribes audio", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateTranscription( context.Background(), openai.AudioRequest{ Model: openai.Whisper1, FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"), }, ) Expect(err).ToNot(HaveOccurred()) Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting")) }) It("calculate embeddings", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ Model: openai.AdaEmbeddingV2, Input: []string{"sun", "cat"}, }, ) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) sunEmbedding := resp.Data[0].Embedding resp2, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ Model: openai.AdaEmbeddingV2, Input: []string{"sun"}, }, ) Expect(err).ToNot(HaveOccurred()) Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) }) Context("backends", func() { It("runs rwkv", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices) > 0).To(BeTrue()) Expect(resp.Choices[0].Text).To(Equal(" five.")) }) }) }) Context("Config file", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) c, cancel = context.WithCancel(context.Background()) var err error app, err = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE"))) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL // Wait for API to be ready client = openai.NewClientWithConfig(defaultConfig) Eventually(func() error { _, err := client.ListModels(context.TODO()) return err }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { cancel() app.Shutdown() }) It("can generate chat completions from config file", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) Expect(len(models.Models)).To(Equal(12)) }) It("can generate chat completions from config file", func() { resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate chat completions from config file", func() { resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate edit completions from config file", func() { request := openaigo.EditCreateRequestBody{ Model: "list2", Instruction: "foo", Input: "bar", } resp, err := client2.CreateEdit(context.Background(), request) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).ToNot(BeEmpty()) }) }) })