From c806eae0de5bdd177d5e0efc8e97a63100c266a9 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Thu, 27 Apr 2023 06:18:18 +0200 Subject: [PATCH] feat: config files and SSE (#83) Signed-off-by: mudler Signed-off-by: Tyler Gillson Co-authored-by: Tyler Gillson --- .dockerignore | 1 + .gitignore | 3 +- Makefile | 4 +- README.md | 48 +- api/api.go | 409 +----------------- api/api_test.go | 67 ++- api/config.go | 100 +++++ api/openai.go | 396 +++++++++++++++++ api/prediction.go | 188 ++++++++ examples/README.md | 11 + examples/chatbot-ui/README.md | 26 ++ examples/chatbot-ui/docker-compose.yaml | 24 + examples/chatbot-ui/models/completion.tmpl | 1 + examples/chatbot-ui/models/gpt-3.5-turbo.yaml | 17 + examples/chatbot-ui/models/gpt4all.tmpl | 4 + main.go | 7 +- pkg/model/loader.go | 34 +- tests/fixtures/completion.tmpl | 1 + tests/fixtures/config.yaml | 28 ++ tests/fixtures/ggml-gpt4all-j.tmpl | 4 + tests/fixtures/gpt4.yaml | 14 + tests/fixtures/gpt4_2.yaml | 14 + 22 files changed, 983 insertions(+), 418 deletions(-) create mode 100644 api/config.go create mode 100644 api/openai.go create mode 100644 api/prediction.go create mode 100644 examples/README.md create mode 100644 examples/chatbot-ui/README.md create mode 100644 examples/chatbot-ui/docker-compose.yaml create mode 100644 examples/chatbot-ui/models/completion.tmpl create mode 100644 examples/chatbot-ui/models/gpt-3.5-turbo.yaml create mode 100644 examples/chatbot-ui/models/gpt4all.tmpl create mode 100644 tests/fixtures/completion.tmpl create mode 100644 tests/fixtures/config.yaml create mode 100644 tests/fixtures/ggml-gpt4all-j.tmpl create mode 100644 tests/fixtures/gpt4.yaml create mode 100644 tests/fixtures/gpt4_2.yaml diff --git a/.dockerignore b/.dockerignore index 604f0f2..dff08bd 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,2 @@ models +examples/chatbot-ui/models \ No newline at end of file diff --git a/.gitignore b/.gitignore index 4b4ab00..4edbf4c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,5 @@ local-ai !charts/* # Ignore models -models/*.bin -models/ggml-* +models/* test-models/ \ No newline at end of file diff --git a/Makefile b/Makefile index c7ad7d5..abbb96f 100644 --- a/Makefile +++ b/Makefile @@ -102,9 +102,11 @@ run: prepare test-models/testmodel: mkdir test-models wget https://huggingface.co/concedo/cerebras-111M-ggml/resolve/main/cerberas-111m-q4_0.bin -O test-models/testmodel + cp tests/fixtures/* test-models test: prepare test-models/testmodel - @C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} MODELS_PATH=$(abspath ./)/test-models $(GOCMD) test -v ./... + cp tests/fixtures/* test-models + @C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) test -v -timeout 20m ./... ## Help: help: ## Show this help. diff --git a/README.md b/README.md index df532d5..414a6c7 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,9 @@ git clone https://github.com/go-skynet/LocalAI cd LocalAI +# (optional) Checkout a specific LocalAI tag +# git checkout -b build + # copy your models to models/ cp your-model.bin models/ @@ -80,6 +83,9 @@ git clone https://github.com/go-skynet/LocalAI cd LocalAI +# (optional) Checkout a specific LocalAI tag +# git checkout -b build + # Download gpt4all-j to models/ wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j @@ -106,6 +112,12 @@ curl http://localhost:8080/v1/chat/completions -H "Content-Type: application/jso ``` +To build locally, run `make build` (see below). + +## Other examples + +To see other examples on how to integrate with other projects, see: [examples](https://github.com/go-skynet/LocalAI/tree/master/examples/). + ## Prompt templates The API doesn't inject a default prompt for talking to the model. You have to use a prompt similar to what's described in the standford-alpaca docs: https://github.com/tatsu-lab/stanford_alpaca#data-release. @@ -169,6 +181,9 @@ Once the server is running, you can start making requests to it using HTTP, usin +## Advanced configuration + + ### Supported OpenAI API endpoints You can check out the [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat/create). @@ -223,22 +238,11 @@ curl http://localhost:8080/v1/models -## Using other models - -gpt4all (https://github.com/nomic-ai/gpt4all) works as well, however the original model needs to be converted (same applies for old alpaca models, too): - -```bash -wget -O tokenizer.model https://huggingface.co/decapoda-research/llama-30b-hf/resolve/main/tokenizer.model -mkdir models -cp gpt4all.. models/ -git clone https://gist.github.com/eiz/828bddec6162a023114ce19146cb2b82 -pip install sentencepiece -python 828bddec6162a023114ce19146cb2b82/gistfile1.txt models tokenizer.model -# There will be a new model with the ".tmp" extension, you have to use that one! -``` +## Helm Chart Installation (run LocalAI in Kubernetes) +LocalAI can be installed inside Kubernetes with helm. -## Helm Chart Installation (run LocalAI in Kubernetes) +
The local-ai Helm chart supports two options for the LocalAI server's models directory: 1. Basic deployment with no persistent volume. You must manually update the Deployment to configure your own models directory. @@ -258,6 +262,12 @@ The local-ai Helm chart supports two options for the LocalAI server's models dir ``` This will update the local-ai Deployment to mount the PV that was provisioned by the DataVolume. +
+ +## Blog posts + +- https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65 + ## Windows compatibility It should work, however you need to make sure you give enough resources to the container. See https://github.com/go-skynet/LocalAI/issues/2 @@ -335,17 +345,25 @@ AutoGPT currently doesn't allow to set a different API URL, but there is a PR op +## Projects already using LocalAI to run local models + +Feel free to open up a PR to get your project listed! + +- [Kairos](https://github.com/kairos-io/kairos) +- [k8sgpt](https://github.com/k8sgpt-ai/k8sgpt#running-local-models) ## Short-term roadmap - [x] Mimic OpenAI API (https://github.com/go-skynet/LocalAI/issues/10) - [ ] Binary releases (https://github.com/go-skynet/LocalAI/issues/6) -- [ ] Upstream our golang bindings to llama.cpp (https://github.com/ggerganov/llama.cpp/issues/351) +- [ ] Upstream our golang bindings to llama.cpp (https://github.com/ggerganov/llama.cpp/issues/351) and gpt4all - [x] Multi-model support - [ ] Have a webUI! - [ ] Allow configuration of defaults for models. - [ ] Enable automatic downloading of models from a curated gallery, with only free-licensed models. +[![LocalAI Star history Chart](https://api.star-history.com/svg?repos=go-skynet/LocalAI&type=Date)](https://star-history.com/#go-skynet/LocalAI&Date) + ## License MIT diff --git a/api/api.go b/api/api.go index 5c401ad..85cbef2 100644 --- a/api/api.go +++ b/api/api.go @@ -1,16 +1,9 @@ package api import ( - "encoding/json" "errors" - "fmt" - "strings" - "sync" model "github.com/go-skynet/LocalAI/pkg/model" - gpt2 "github.com/go-skynet/go-gpt2.cpp" - gptj "github.com/go-skynet/go-gpt4all-j.cpp" - llama "github.com/go-skynet/go-llama.cpp" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/recover" @@ -18,375 +11,7 @@ import ( "github.com/rs/zerolog/log" ) -// APIError provides error information returned by the OpenAI API. -type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` -} - -type ErrorResponse struct { - Error *APIError `json:"error,omitempty"` -} - -type OpenAIResponse struct { - Created int `json:"created,omitempty"` - Object string `json:"chat.completion,omitempty"` - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Choices []Choice `json:"choices,omitempty"` -} - -type Choice struct { - Index int `json:"index,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - Message *Message `json:"message,omitempty"` - Text string `json:"text,omitempty"` -} - -type Message struct { - Role string `json:"role,omitempty"` - Content string `json:"content,omitempty"` -} - -type OpenAIModel struct { - ID string `json:"id"` - Object string `json:"object"` -} - -type OpenAIRequest struct { - Model string `json:"model"` - - // Prompt is read only by completion API calls - Prompt string `json:"prompt"` - - Stop string `json:"stop"` - - // Messages is read only by chat/completion API calls - Messages []Message `json:"messages"` - - Echo bool `json:"echo"` - // Common options between all the API calls - TopP float64 `json:"top_p"` - TopK int `json:"top_k"` - Temperature float64 `json:"temperature"` - Maxtokens int `json:"max_tokens"` - - N int `json:"n"` - - // Custom parameters - not present in the OpenAI API - Batch int `json:"batch"` - F16 bool `json:"f16kv"` - IgnoreEOS bool `json:"ignore_eos"` - RepeatPenalty float64 `json:"repeat_penalty"` - Keep int `json:"n_keep"` - - Seed int `json:"seed"` -} - -// https://platform.openai.com/docs/api-reference/completions -func openAIEndpoint(chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - var err error - var model *llama.LLama - var gptModel *gptj.GPTJ - var gpt2Model *gpt2.GPT2 - var stableLMModel *gpt2.StableLM - - input := new(OpenAIRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - modelFile := input.Model - received, _ := json.Marshal(input) - - log.Debug().Msgf("Request received: %s", string(received)) - - // Set model from bearer token, if available - bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) - - // If no model was specified, take the first available - if modelFile == "" { - models, _ := loader.ListModels() - if len(models) > 0 { - modelFile = models[0] - log.Debug().Msgf("No model specified, using: %s", modelFile) - } - } - - // If no model is found or specified, we bail out - if modelFile == "" && !bearerExists { - return fmt.Errorf("no model specified") - } - - // If a model is found in bearer token takes precedence - if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) - modelFile = bearer - } - - // Try to load the model - var llamaerr, gpt2err, gptjerr, stableerr error - llamaOpts := []llama.ModelOption{} - if ctx != 0 { - llamaOpts = append(llamaOpts, llama.SetContext(ctx)) - } - if f16 { - llamaOpts = append(llamaOpts, llama.EnableF16Memory) - } - - // TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation.. - model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...) - if llamaerr != nil { - gptModel, gptjerr = loader.LoadGPTJModel(modelFile) - if gptjerr != nil { - gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile) - if gpt2err != nil { - stableLMModel, stableerr = loader.LoadStableLMModel(modelFile) - if stableerr != nil { - return fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors - } - } - } - } - - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - // Set the parameters for the language model prediction - topP := input.TopP - if topP == 0 { - topP = 0.7 - } - topK := input.TopK - if topK == 0 { - topK = 80 - } - - temperature := input.Temperature - if temperature == 0 { - temperature = 0.9 - } - - tokens := input.Maxtokens - if tokens == 0 { - tokens = 512 - } - - predInput := input.Prompt - if chat { - mess := []string{} - // TODO: encode roles - for _, i := range input.Messages { - mess = append(mess, i.Content) - } - - predInput = strings.Join(mess, "\n") - } - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := loader.TemplatePrefix(modelFile, struct { - Input string - }{Input: predInput}) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } - - result := []Choice{} - - n := input.N - - if input.N == 0 { - n = 1 - } - - var predFunc func() (string, error) - switch { - case stableLMModel != nil: - predFunc = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []gpt2.PredictOption{ - gpt2.SetTemperature(temperature), - gpt2.SetTopP(topP), - gpt2.SetTopK(topK), - gpt2.SetTokens(tokens), - gpt2.SetThreads(threads), - } - - if input.Batch != 0 { - predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch)) - } - - if input.Seed != 0 { - predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed)) - } - - return stableLMModel.Predict( - predInput, - predictOptions..., - ) - } - case gpt2Model != nil: - predFunc = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []gpt2.PredictOption{ - gpt2.SetTemperature(temperature), - gpt2.SetTopP(topP), - gpt2.SetTopK(topK), - gpt2.SetTokens(tokens), - gpt2.SetThreads(threads), - } - - if input.Batch != 0 { - predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch)) - } - - if input.Seed != 0 { - predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed)) - } - - return gpt2Model.Predict( - predInput, - predictOptions..., - ) - } - case gptModel != nil: - predFunc = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []gptj.PredictOption{ - gptj.SetTemperature(temperature), - gptj.SetTopP(topP), - gptj.SetTopK(topK), - gptj.SetTokens(tokens), - gptj.SetThreads(threads), - } - - if input.Batch != 0 { - predictOptions = append(predictOptions, gptj.SetBatch(input.Batch)) - } - - if input.Seed != 0 { - predictOptions = append(predictOptions, gptj.SetSeed(input.Seed)) - } - - return gptModel.Predict( - predInput, - predictOptions..., - ) - } - case model != nil: - predFunc = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []llama.PredictOption{ - llama.SetTemperature(temperature), - llama.SetTopP(topP), - llama.SetTopK(topK), - llama.SetTokens(tokens), - llama.SetThreads(threads), - } - - if debug { - predictOptions = append(predictOptions, llama.Debug) - } - - if input.Stop != "" { - predictOptions = append(predictOptions, llama.SetStopWords(input.Stop)) - } - - if input.RepeatPenalty != 0 { - predictOptions = append(predictOptions, llama.SetPenalty(input.RepeatPenalty)) - } - - if input.Keep != 0 { - predictOptions = append(predictOptions, llama.SetNKeep(input.Keep)) - } - - if input.Batch != 0 { - predictOptions = append(predictOptions, llama.SetBatch(input.Batch)) - } - - if input.F16 { - predictOptions = append(predictOptions, llama.EnableF16KV) - } - - if input.IgnoreEOS { - predictOptions = append(predictOptions, llama.IgnoreEOS) - } - - if input.Seed != 0 { - predictOptions = append(predictOptions, llama.SetSeed(input.Seed)) - } - - return model.Predict( - predInput, - predictOptions..., - ) - } - } - - for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return err - } - - if input.Echo { - prediction = predInput + prediction - } - - if chat { - result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}}) - } else { - result = append(result, Choice{Text: prediction}) - } - } - - jsonResult, _ := json.Marshal(result) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - }) - } -} - -func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - models, err := loader.ListModels() - if err != nil { - return err - } - - dataModels := []OpenAIModel{} - for _, m := range models { - dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) - } - return c.JSON(struct { - Object string `json:"object"` - Data []OpenAIModel `json:"data"` - }{ - Object: "list", - Data: dataModels, - }) - } -} - -func App(loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App { +func App(configFile string, loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disableMessage bool) *fiber.App { zerolog.SetGlobalLevel(zerolog.InfoLevel) if debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) @@ -415,23 +40,35 @@ func App(loader *model.ModelLoader, threads, ctxSize int, f16 bool, debug, disab }, }) + cm := make(ConfigMerger) + if err := cm.LoadConfigs(loader.ModelPath); err != nil { + log.Error().Msgf("error loading config files: %s", err.Error()) + } + + if configFile != "" { + if err := cm.LoadConfigFile(configFile); err != nil { + log.Error().Msgf("error loading config file: %s", err.Error()) + } + } + + if debug { + for k, v := range cm { + log.Debug().Msgf("Model: %s (config: %+v)", k, v) + } + } // Default middleware config app.Use(recover.New()) app.Use(cors.New()) - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mu := map[string]*sync.Mutex{} - var mumutex = &sync.Mutex{} - // openAI compatible API endpoint - app.Post("/v1/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu)) - app.Post("/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu)) + app.Post("/v1/chat/completions", openAIEndpoint(cm, true, debug, loader, threads, ctxSize, f16)) + app.Post("/chat/completions", openAIEndpoint(cm, true, debug, loader, threads, ctxSize, f16)) - app.Post("/v1/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu)) - app.Post("/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu)) + app.Post("/v1/completions", openAIEndpoint(cm, false, debug, loader, threads, ctxSize, f16)) + app.Post("/completions", openAIEndpoint(cm, false, debug, loader, threads, ctxSize, f16)) - app.Get("/v1/models", listModels(loader)) - app.Get("/models", listModels(loader)) + app.Get("/v1/models", listModels(loader, cm)) + app.Get("/models", listModels(loader, cm)) return app } diff --git a/api/api_test.go b/api/api_test.go index 53d1516..6cd90e5 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -21,7 +21,7 @@ var _ = Describe("API test", func() { Context("API query", func() { BeforeEach(func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) - app = App(modelLoader, 1, 512, false, false, true) + app = App("", modelLoader, 1, 512, false, true, true) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -40,7 +40,7 @@ var _ = Describe("API test", func() { It("returns the models list", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) - Expect(len(models.Models)).To(Equal(1)) + Expect(len(models.Models)).To(Equal(3)) Expect(models.Models[0].ID).To(Equal("testmodel")) }) It("can generate completions", func() { @@ -49,10 +49,73 @@ var _ = Describe("API test", func() { Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).ToNot(BeEmpty()) }) + + It("can generate chat completions ", func() { + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices)).To(Equal(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + }) + + It("can generate completions from model configs", func() { + resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "gpt4all", Prompt: "abcdedfghikl"}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices)).To(Equal(1)) + Expect(resp.Choices[0].Text).ToNot(BeEmpty()) + }) + + It("can generate chat completions from model configs", func() { + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices)).To(Equal(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + }) + It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: llama: model does not exist")) }) + + }) + + Context("Config file", func() { + BeforeEach(func() { + modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) + app = App(os.Getenv("CONFIG_FILE"), modelLoader, 1, 512, false, true, true) + go app.Listen("127.0.0.1:9090") + + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + + // Wait for API to be ready + client = openai.NewClientWithConfig(defaultConfig) + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + AfterEach(func() { + app.Shutdown() + }) + It("can generate chat completions from config file", func() { + + models, err := client.ListModels(context.TODO()) + Expect(err).ToNot(HaveOccurred()) + Expect(len(models.Models)).To(Equal(5)) + Expect(models.Models[0].ID).To(Equal("testmodel")) + }) + It("can generate chat completions from config file", func() { + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices)).To(Equal(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + }) + It("can generate chat completions from config file", func() { + resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "abcdedfghikl"}}}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices)).To(Equal(1)) + Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) + }) }) }) diff --git a/api/config.go b/api/config.go new file mode 100644 index 0000000..848f25c --- /dev/null +++ b/api/config.go @@ -0,0 +1,100 @@ +package api + +import ( + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" +) + +type Config struct { + OpenAIRequest `yaml:"parameters"` + Name string `yaml:"name"` + StopWords []string `yaml:"stopwords"` + Cutstrings []string `yaml:"cutstrings"` + TrimSpace []string `yaml:"trimspace"` + ContextSize int `yaml:"context_size"` + F16 bool `yaml:"f16"` + Threads int `yaml:"threads"` + Debug bool `yaml:"debug"` + Roles map[string]string `yaml:"roles"` + TemplateConfig TemplateConfig `yaml:"template"` +} + +type TemplateConfig struct { + Completion string `yaml:"completion"` + Chat string `yaml:"chat"` +} + +type ConfigMerger map[string]Config + +func ReadConfigFile(file string) ([]*Config, error) { + c := &[]*Config{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + return *c, nil +} + +func ReadConfig(file string) (*Config, error) { + c := &Config{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + return c, nil +} + +func (cm ConfigMerger) LoadConfigFile(file string) error { + c, err := ReadConfigFile(file) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + cm[cc.Name] = *cc + } + return nil +} + +func (cm ConfigMerger) LoadConfig(file string) error { + c, err := ReadConfig(file) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) + } + + cm[c.Name] = *c + return nil +} + +func (cm ConfigMerger) LoadConfigs(path string) error { + files, err := ioutil.ReadDir(path) + if err != nil { + return err + } + + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") { + continue + } + c, err := ReadConfig(filepath.Join(path, file.Name())) + if err == nil { + cm[c.Name] = *c + } + } + + return nil +} diff --git a/api/openai.go b/api/openai.go new file mode 100644 index 0000000..3cb9b59 --- /dev/null +++ b/api/openai.go @@ -0,0 +1,396 @@ +package api + +import ( + "bufio" + "encoding/json" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +// APIError provides error information returned by the OpenAI API. +type APIError struct { + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` +} + +type ErrorResponse struct { + Error *APIError `json:"error,omitempty"` +} + +type OpenAIResponse struct { + Created int `json:"created,omitempty"` + Object string `json:"object,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Choices []Choice `json:"choices,omitempty"` +} + +type Choice struct { + Index int `json:"index,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Message *Message `json:"message,omitempty"` + Delta *Message `json:"delta,omitempty"` + Text string `json:"text,omitempty"` +} + +type Message struct { + Role string `json:"role,omitempty" yaml:"role"` + Content string `json:"content,omitempty" yaml:"content"` +} + +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` +} + +type OpenAIRequest struct { + Model string `json:"model" yaml:"model"` + + // Prompt is read only by completion API calls + Prompt string `json:"prompt" yaml:"prompt"` + + Stop string `json:"stop" yaml:"stop"` + + // Messages is read only by chat/completion API calls + Messages []Message `json:"messages" yaml:"messages"` + + Stream bool `json:"stream"` + Echo bool `json:"echo"` + // Common options between all the API calls + TopP float64 `json:"top_p" yaml:"top_p"` + TopK int `json:"top_k" yaml:"top_k"` + Temperature float64 `json:"temperature" yaml:"temperature"` + Maxtokens int `json:"max_tokens" yaml:"max_tokens"` + + N int `json:"n"` + + // Custom parameters - not present in the OpenAI API + Batch int `json:"batch" yaml:"batch"` + F16 bool `json:"f16" yaml:"f16"` + IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` + RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` + Keep int `json:"n_keep" yaml:"n_keep"` + + Seed int `json:"seed" yaml:"seed"` +} + +func defaultRequest(modelFile string) OpenAIRequest { + return OpenAIRequest{ + TopP: 0.7, + TopK: 80, + Maxtokens: 512, + Temperature: 0.9, + Model: modelFile, + } +} + +func updateConfig(config *Config, input *OpenAIRequest) { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != 0 { + config.TopK = input.TopK + } + if input.TopP != 0 { + config.TopP = input.TopP + } + + if input.Temperature != 0 { + config.Temperature = input.Temperature + } + + if input.Maxtokens != 0 { + config.Maxtokens = input.Maxtokens + } + + if input.Stop != "" { + config.StopWords = append(config.StopWords, input.Stop) + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.F16 { + config.F16 = input.F16 + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != 0 { + config.Seed = input.Seed + } +} + +var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) +var mu sync.Mutex = sync.Mutex{} + +// https://platform.openai.com/docs/api-reference/completions +func openAIEndpoint(cm ConfigMerger, chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + + input := new(OpenAIRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + if input.Stream { + log.Debug().Msgf("Stream request received") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + c.Set("Content-Type", "text/event-stream; charset=utf-8") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + modelFile := input.Model + received, _ := json.Marshal(input) + + log.Debug().Msgf("Request received: %s", string(received)) + + // Set model from bearer token, if available + bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) + + // If no model was specified, take the first available + if modelFile == "" && !bearerExists { + models, _ := loader.ListModels() + if len(models) > 0 { + modelFile = models[0] + log.Debug().Msgf("No model specified, using: %s", modelFile) + } else { + log.Debug().Msgf("No model specified, returning error") + return fmt.Errorf("no model specified") + } + } + + // If a model is found in bearer token takes precedence + if bearerExists { + log.Debug().Msgf("Using model from bearer token: %s", bearer) + modelFile = bearer + } + + // Load a config file if present after the model name + modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") + if _, err := os.Stat(modelConfig); err == nil { + if err := cm.LoadConfig(modelConfig); err != nil { + return fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + } + + var config *Config + cfg, exists := cm[modelFile] + if !exists { + config = &Config{ + OpenAIRequest: defaultRequest(modelFile), + } + } else { + config = &cfg + } + + // Set the parameters for the language model prediction + updateConfig(config, input) + + if threads != 0 { + config.Threads = threads + } + if ctx != 0 { + config.ContextSize = ctx + } + if f16 { + config.F16 = true + } + + if debug { + config.Debug = true + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + predInput := input.Prompt + if chat { + mess := []string{} + for _, i := range input.Messages { + r := config.Roles[i.Role] + if r == "" { + r = i.Role + } + + content := fmt.Sprint(r, " ", i.Content) + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + } + + templateFile := config.Model + if config.TemplateConfig.Chat != "" && chat { + templateFile = config.TemplateConfig.Chat + } + + if config.TemplateConfig.Completion != "" && !chat { + templateFile = config.TemplateConfig.Completion + } + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := loader.TemplatePrefix(templateFile, struct { + Input string + }{Input: predInput}) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } + + result := []Choice{} + + n := input.N + + if input.N == 0 { + n = 1 + } + + // get the model function to call for the result + predFunc, err := ModelInference(predInput, loader, *config) + if err != nil { + return err + } + + finetunePrediction := func(prediction string) string { + if config.Echo { + prediction = predInput + prediction + } + + for _, c := range config.Cutstrings { + mu.Lock() + reg, ok := cutstrings[c] + if !ok { + cutstrings[c] = regexp.MustCompile(c) + reg = cutstrings[c] + } + mu.Unlock() + prediction = reg.ReplaceAllString(prediction, "") + } + + for _, c := range config.TrimSpace { + prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) + } + return prediction + } + + for i := 0; i < n; i++ { + prediction, err := predFunc() + if err != nil { + return err + } + + prediction = finetunePrediction(prediction) + + if chat { + if input.Stream { + result = append(result, Choice{Delta: &Message{Role: "assistant", Content: prediction}}) + } else { + result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}}) + } + } else { + result = append(result, Choice{Text: prediction}) + } + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + } + if input.Stream && chat { + resp.Object = "chat.completion.chunk" + } else if chat { + resp.Object = "chat.completion" + } else { + resp.Object = "text_completion" + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + if input.Stream { + log.Debug().Msgf("Handling stream request") + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + fmt.Fprintf(w, "event: data\n") + w.Flush() + + fmt.Fprintf(w, "data: %s\n\n", jsonResult) + w.Flush() + + fmt.Fprintf(w, "event: data\n") + w.Flush() + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{Choice{FinishReason: "stop"}}, + } + respData, _ := json.Marshal(resp) + + fmt.Fprintf(w, "data: %s\n\n", respData) + w.Flush() + + // fmt.Fprintf(w, "data: [DONE]\n\n") + // w.Flush() + })) + return nil + } else { + // Return the prediction in the response body + return c.JSON(resp) + } + } +} + +func listModels(loader *model.ModelLoader, cm ConfigMerger) func(ctx *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + models, err := loader.ListModels() + if err != nil { + return err + } + var mm map[string]interface{} = map[string]interface{}{} + + dataModels := []OpenAIModel{} + for _, m := range models { + mm[m] = nil + dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) + } + + for k := range cm { + if _, exists := mm[k]; !exists { + dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) + } + } + + return c.JSON(struct { + Object string `json:"object"` + Data []OpenAIModel `json:"data"` + }{ + Object: "list", + Data: dataModels, + }) + } +} diff --git a/api/prediction.go b/api/prediction.go new file mode 100644 index 0000000..dfa8b60 --- /dev/null +++ b/api/prediction.go @@ -0,0 +1,188 @@ +package api + +import ( + "fmt" + "sync" + + model "github.com/go-skynet/LocalAI/pkg/model" + gpt2 "github.com/go-skynet/go-gpt2.cpp" + gptj "github.com/go-skynet/go-gpt4all-j.cpp" + llama "github.com/go-skynet/go-llama.cpp" +) + +// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 +var mutexMap sync.Mutex +var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) + +func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (string, error), error) { + var model *llama.LLama + var gptModel *gptj.GPTJ + var gpt2Model *gpt2.GPT2 + var stableLMModel *gpt2.StableLM + + modelFile := c.Model + + // Try to load the model + var llamaerr, gpt2err, gptjerr, stableerr error + llamaOpts := []llama.ModelOption{} + if c.ContextSize != 0 { + llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize)) + } + if c.F16 { + llamaOpts = append(llamaOpts, llama.EnableF16Memory) + } + + // TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation.. + model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...) + if llamaerr != nil { + gptModel, gptjerr = loader.LoadGPTJModel(modelFile) + if gptjerr != nil { + gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile) + if gpt2err != nil { + stableLMModel, stableerr = loader.LoadStableLMModel(modelFile) + if stableerr != nil { + return nil, fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors + } + } + } + } + + var fn func() (string, error) + + switch { + case stableLMModel != nil: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []gpt2.PredictOption{ + gpt2.SetTemperature(c.Temperature), + gpt2.SetTopP(c.TopP), + gpt2.SetTopK(c.TopK), + gpt2.SetTokens(c.Maxtokens), + gpt2.SetThreads(c.Threads), + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch)) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed)) + } + + return stableLMModel.Predict( + s, + predictOptions..., + ) + } + case gpt2Model != nil: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []gpt2.PredictOption{ + gpt2.SetTemperature(c.Temperature), + gpt2.SetTopP(c.TopP), + gpt2.SetTopK(c.TopK), + gpt2.SetTokens(c.Maxtokens), + gpt2.SetThreads(c.Threads), + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch)) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed)) + } + + return gpt2Model.Predict( + s, + predictOptions..., + ) + } + case gptModel != nil: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []gptj.PredictOption{ + gptj.SetTemperature(c.Temperature), + gptj.SetTopP(c.TopP), + gptj.SetTopK(c.TopK), + gptj.SetTokens(c.Maxtokens), + gptj.SetThreads(c.Threads), + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, gptj.SetBatch(c.Batch)) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, gptj.SetSeed(c.Seed)) + } + + return gptModel.Predict( + s, + predictOptions..., + ) + } + case model != nil: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []llama.PredictOption{ + llama.SetTemperature(c.Temperature), + llama.SetTopP(c.TopP), + llama.SetTopK(c.TopK), + llama.SetTokens(c.Maxtokens), + llama.SetThreads(c.Threads), + } + + if c.Debug { + predictOptions = append(predictOptions, llama.Debug) + } + + predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) + + if c.RepeatPenalty != 0 { + predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) + } + + if c.Keep != 0 { + predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) + } + + if c.F16 { + predictOptions = append(predictOptions, llama.EnableF16KV) + } + + if c.IgnoreEOS { + predictOptions = append(predictOptions, llama.IgnoreEOS) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) + } + + return model.Predict( + s, + predictOptions..., + ) + } + } + + return func() (string, error) { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[modelFile] + if !ok { + m := &sync.Mutex{} + mutexes[modelFile] = m + l = m + } + mutexMap.Unlock() + l.Lock() + defer l.Unlock() + + return fn() + }, nil +} diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..0c64623 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,11 @@ +# Examples + +Here is a list of projects that can easily be integrated with the LocalAI backend. + +## Projects + +- [chatbot-ui](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui/) (by [@mudler](https://github.com/mudler)) + +## Want to contribute? + +Create an issue, and put `Example: ` in the title! We will post your examples here. \ No newline at end of file diff --git a/examples/chatbot-ui/README.md b/examples/chatbot-ui/README.md new file mode 100644 index 0000000..ff181cb --- /dev/null +++ b/examples/chatbot-ui/README.md @@ -0,0 +1,26 @@ +# chatbot-ui + +Example of integration with [mckaywrigley/chatbot-ui](https://github.com/mckaywrigley/chatbot-ui). + +![Screenshot from 2023-04-26 23-59-55](https://user-images.githubusercontent.com/2420543/234715439-98d12e03-d3ce-4f94-ab54-2b256808e05e.png) + +## Setup + +```bash +# Clone LocalAI +git clone https://github.com/go-skynet/LocalAI + +cd LocalAI/examples/chatbot-ui + +# (optional) Checkout a specific LocalAI tag +# git checkout -b build + +# Download gpt4all-j to models/ +wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j + +# start with docker-compose +docker compose up -d --build +``` + +Open http://localhost:3000 for the Web UI. + diff --git a/examples/chatbot-ui/docker-compose.yaml b/examples/chatbot-ui/docker-compose.yaml new file mode 100644 index 0000000..c7782c3 --- /dev/null +++ b/examples/chatbot-ui/docker-compose.yaml @@ -0,0 +1,24 @@ +version: '3.6' + +services: + api: + image: quay.io/go-skynet/local-ai:latest + build: + context: ../../ + dockerfile: Dockerfile + ports: + - 8080:8080 + environment: + - DEBUG=true + - MODELS_PATH=/models + volumes: + - ./models:/models:cached + command: ["/usr/bin/local-ai" ] + + chatgpt: + image: ghcr.io/mckaywrigley/chatbot-ui:main + ports: + - 3000:3000 + environment: + - 'OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXX' + - 'OPENAI_API_HOST=http://api:8080' \ No newline at end of file diff --git a/examples/chatbot-ui/models/completion.tmpl b/examples/chatbot-ui/models/completion.tmpl new file mode 100644 index 0000000..9867cfc --- /dev/null +++ b/examples/chatbot-ui/models/completion.tmpl @@ -0,0 +1 @@ +{{.Input}} \ No newline at end of file diff --git a/examples/chatbot-ui/models/gpt-3.5-turbo.yaml b/examples/chatbot-ui/models/gpt-3.5-turbo.yaml new file mode 100644 index 0000000..6df1dbf --- /dev/null +++ b/examples/chatbot-ui/models/gpt-3.5-turbo.yaml @@ -0,0 +1,17 @@ +name: gpt-3.5-turbo +parameters: + model: ggml-gpt4all-j + top_k: 80 + temperature: 0.2 + top_p: 0.7 +context_size: 1024 +threads: 14 +stopwords: +- "HUMAN:" +- "GPT:" +roles: + user: " " + system: " " +template: + completion: completion + chat: gpt4all \ No newline at end of file diff --git a/examples/chatbot-ui/models/gpt4all.tmpl b/examples/chatbot-ui/models/gpt4all.tmpl new file mode 100644 index 0000000..f76b080 --- /dev/null +++ b/examples/chatbot-ui/models/gpt4all.tmpl @@ -0,0 +1,4 @@ +The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. +### Prompt: +{{.Input}} +### Response: diff --git a/main.go b/main.go index 4cb4741..61472af 100644 --- a/main.go +++ b/main.go @@ -50,6 +50,11 @@ func main() { EnvVars: []string{"MODELS_PATH"}, Value: path, }, + &cli.StringFlag{ + Name: "config-file", + DefaultText: "Config file", + EnvVars: []string{"CONFIG_FILE"}, + }, &cli.StringFlag{ Name: "address", DefaultText: "Bind address for the API server.", @@ -80,7 +85,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. UsageText: `local-ai [options]`, Copyright: "go-skynet authors", Action: func(ctx *cli.Context) error { - return api.App(model.NewModelLoader(ctx.String("models-path")), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address")) + return api.App(ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false).Listen(ctx.String("address")) }, } diff --git a/pkg/model/loader.go b/pkg/model/loader.go index b3cce43..6b1539c 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -18,7 +18,7 @@ import ( ) type ModelLoader struct { - modelPath string + ModelPath string mu sync.Mutex models map[string]*llama.LLama @@ -31,7 +31,7 @@ type ModelLoader struct { func NewModelLoader(modelPath string) *ModelLoader { return &ModelLoader{ - modelPath: modelPath, + ModelPath: modelPath, gpt2models: make(map[string]*gpt2.GPT2), gptmodels: make(map[string]*gptj.GPTJ), gptstablelmmodels: make(map[string]*gpt2.StableLM), @@ -41,12 +41,12 @@ func NewModelLoader(modelPath string) *ModelLoader { } func (ml *ModelLoader) ExistsInModelPath(s string) bool { - _, err := os.Stat(filepath.Join(ml.modelPath, s)) + _, err := os.Stat(filepath.Join(ml.ModelPath, s)) return err == nil } func (ml *ModelLoader) ListModels() ([]string, error) { - files, err := ioutil.ReadDir(ml.modelPath) + files, err := ioutil.ReadDir(ml.ModelPath) if err != nil { return []string{}, err } @@ -70,7 +70,19 @@ func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, m, ok := ml.promptsTemplates[modelName] if !ok { - return "", fmt.Errorf("no prompt template available") + modelFile := filepath.Join(ml.ModelPath, modelName) + if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil { + return "", err + } + + t, exists := ml.promptsTemplates[modelName] + if exists { + m = t + } + + } + if m == nil { + return "", nil } var buf bytes.Buffer @@ -88,14 +100,14 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error { } // Check if the model path exists - // skip any error here - we run anyway if a template is not exist + // skip any error here - we run anyway if a template does not exist modelTemplateFile := fmt.Sprintf("%s.tmpl", modelName) if !ml.ExistsInModelPath(modelTemplateFile) { return nil } - dat, err := os.ReadFile(filepath.Join(ml.modelPath, modelTemplateFile)) + dat, err := os.ReadFile(filepath.Join(ml.ModelPath, modelTemplateFile)) if err != nil { return err } @@ -125,7 +137,7 @@ func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, erro } // Load the model and keep it in memory for later use - modelFile := filepath.Join(ml.modelPath, modelName) + modelFile := filepath.Join(ml.ModelPath, modelName) log.Debug().Msgf("Loading model in memory from file: %s", modelFile) model, err := gpt2.NewStableLM(modelFile) @@ -164,7 +176,7 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) { } // Load the model and keep it in memory for later use - modelFile := filepath.Join(ml.modelPath, modelName) + modelFile := filepath.Join(ml.ModelPath, modelName) log.Debug().Msgf("Loading model in memory from file: %s", modelFile) model, err := gpt2.New(modelFile) @@ -207,7 +219,7 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) { } // Load the model and keep it in memory for later use - modelFile := filepath.Join(ml.modelPath, modelName) + modelFile := filepath.Join(ml.ModelPath, modelName) log.Debug().Msgf("Loading model in memory from file: %s", modelFile) model, err := gptj.New(modelFile) @@ -256,7 +268,7 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio } // Load the model and keep it in memory for later use - modelFile := filepath.Join(ml.modelPath, modelName) + modelFile := filepath.Join(ml.ModelPath, modelName) log.Debug().Msgf("Loading model in memory from file: %s", modelFile) model, err := llama.New(modelFile, opts...) diff --git a/tests/fixtures/completion.tmpl b/tests/fixtures/completion.tmpl new file mode 100644 index 0000000..9867cfc --- /dev/null +++ b/tests/fixtures/completion.tmpl @@ -0,0 +1 @@ +{{.Input}} \ No newline at end of file diff --git a/tests/fixtures/config.yaml b/tests/fixtures/config.yaml new file mode 100644 index 0000000..866b74b --- /dev/null +++ b/tests/fixtures/config.yaml @@ -0,0 +1,28 @@ +- name: list1 + parameters: + model: testmodel + context_size: 512 + threads: 10 + stopwords: + - "HUMAN:" + - "### Response:" + roles: + user: "HUMAN:" + system: "GPT:" + template: + completion: completion + chat: ggml-gpt4all-j +- name: list2 + parameters: + model: testmodel + context_size: 512 + threads: 10 + stopwords: + - "HUMAN:" + - "### Response:" + roles: + user: "HUMAN:" + system: "GPT:" + template: + completion: completion + chat: ggml-gpt4all-j \ No newline at end of file diff --git a/tests/fixtures/ggml-gpt4all-j.tmpl b/tests/fixtures/ggml-gpt4all-j.tmpl new file mode 100644 index 0000000..f76b080 --- /dev/null +++ b/tests/fixtures/ggml-gpt4all-j.tmpl @@ -0,0 +1,4 @@ +The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. +### Prompt: +{{.Input}} +### Response: diff --git a/tests/fixtures/gpt4.yaml b/tests/fixtures/gpt4.yaml new file mode 100644 index 0000000..c2f9bec --- /dev/null +++ b/tests/fixtures/gpt4.yaml @@ -0,0 +1,14 @@ +name: gpt4all +parameters: + model: testmodel +context_size: 512 +threads: 10 +stopwords: +- "HUMAN:" +- "### Response:" +roles: + user: "HUMAN:" + system: "GPT:" +template: + completion: completion + chat: ggml-gpt4all-j \ No newline at end of file diff --git a/tests/fixtures/gpt4_2.yaml b/tests/fixtures/gpt4_2.yaml new file mode 100644 index 0000000..60722f4 --- /dev/null +++ b/tests/fixtures/gpt4_2.yaml @@ -0,0 +1,14 @@ +name: gpt4all-2 +parameters: + model: testmodel +context_size: 1024 +threads: 5 +stopwords: +- "HUMAN:" +- "### Response:" +roles: + user: "HUMAN:" + system: "GPT:" +template: + completion: completion + chat: ggml-gpt4all-j \ No newline at end of file