From f816dfae6598a4fada7ea6b85d2758848fcb9c9a Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Fri, 21 Apr 2023 00:06:55 +0200
Subject: [PATCH] Add support for stablelm (#48)

Signed-off-by: mudler
---
 Makefile            |  2 +-
 README.md           | 10 +++++++
 api/api.go          | 32 +++++++++++++++++++++--
 go.mod              |  2 +-
 go.sum              |  6 ++---
 main.go             |  4 ++-
 pkg/model/loader.go | 63 ++++++++++++++++++++++++++++++++++++++++++---
 7 files changed, 106 insertions(+), 13 deletions(-)

diff --git a/Makefile b/Makefile
index 59a62f6..3035be8 100644
--- a/Makefile
+++ b/Makefile
@@ -4,7 +4,7 @@ GOVET=$(GOCMD) vet
 BINARY_NAME=local-ai
 GOLLAMA_VERSION?=llama.cpp-5ecff35
 GOGPT4ALLJ_VERSION?=1f548782d80d48b9a0fac33aae6f129358787bc0
-GOGPT2_VERSION?=f15da66b097d6dacc30140d5def78d153e529e70
+GOGPT2_VERSION?=1c24f5b86ac428cd5e81dae1f1427b1463bd2b06
 
 GREEN  := $(shell tput -Txterm setaf 2)
 YELLOW := $(shell tput -Txterm setaf 3)
diff --git a/README.md b/README.md
index c74b0d5..afbec65 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,16 @@ LocalAI is a straightforward, drop-in replacement API compatible with OpenAI for
 
 It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) supports also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all) and [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml).
 
+Tested with:
+- Vicuna
+- Alpaca
+- [GPT4ALL](https://github.com/nomic-ai/gpt4all)
+- [GPT4ALL-J](https://gpt4all.io/models/ggml-gpt4all-j.bin)
+- Koala
+- [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml)
+
+It should also be compatible with StableLM and GPTNeoX ggml models (untested)
+
 Note: You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.
 
 ## Usage
diff --git a/api/api.go b/api/api.go
index 3ec1d81..60e750a 100644
--- a/api/api.go
+++ b/api/api.go
@@ -75,6 +75,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 	var model *llama.LLama
 	var gptModel *gptj.GPTJ
 	var gpt2Model *gpt2.GPT2
+	var stableLMModel *gpt2.StableLM
 
 	input := new(OpenAIRequest)
 	// Get input data from the request body
@@ -99,7 +100,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 	}
 
 	// Try to load the model with both
-	var llamaerr, gpt2err, gptjerr error
+	var llamaerr, gpt2err, gptjerr, stableerr error
 	llamaOpts := []llama.ModelOption{}
 	if ctx != 0 {
 		llamaOpts = append(llamaOpts, llama.SetContext(ctx))
@@ -115,7 +116,10 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		if gptjerr != nil {
 			gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
 			if gpt2err != nil {
-				return fmt.Errorf("llama: %s gpt: %s gpt2: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error()) // llama failed first, so we want to catch both errors
+				stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
+				if stableerr != nil {
+					return fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to report all the errors
+				}
 			}
 		}
 	}
@@ -182,6 +186,30 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 
 	var predFunc func() (string, error)
 	switch {
+	case stableLMModel != nil:
+		predFunc = func() (string, error) {
+			// Generate the prediction using the language model
+			predictOptions := []gpt2.PredictOption{
+				gpt2.SetTemperature(temperature),
+				gpt2.SetTopP(topP),
+				gpt2.SetTopK(topK),
+				gpt2.SetTokens(tokens),
+				gpt2.SetThreads(threads),
+			}
+
+			if input.Batch != 0 {
+				predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
+			}
+
+			if input.Seed != 0 {
+				predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
+			}
+
+			return stableLMModel.Predict(
+				predInput,
+				predictOptions...,
+			)
+		}
 	case gpt2Model != nil:
 		predFunc = func() (string, error) {
 			// Generate the prediction using the language model
diff --git a/go.mod b/go.mod
index 7707615..d90137b 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module github.com/go-skynet/LocalAI
 go 1.19
 
 require (
+	github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4
 	github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94
 	github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640
 	github.com/gofiber/fiber/v2 v2.42.0
@@ -13,7 +14,6 @@ require (
 	github.com/andybalholm/brotli v1.0.4 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
-	github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d // indirect
 	github.com/google/uuid v1.3.0 // indirect
 	github.com/klauspost/compress v1.15.9 // indirect
 	github.com/mattn/go-colorable v0.1.13 // indirect
diff --git a/go.sum b/go.sum
index b98a7d8..880540b 100644
--- a/go.sum
+++ b/go.sum
@@ -4,10 +4,8 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
 github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
-github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420164106-516b5871c74d h1:8crcrVuvpRzf6wejPtIFYGmMrSTfW94CYPJZIssT8zo=
-github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420164106-516b5871c74d/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
-github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d h1:Jabxk0NI5CLbY7PVODkRp1AQbEovS9gM6jGAOwyy5FI=
-github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
+github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4 h1:GkGuqnhDFKlCsT6Bo8sdY00A7rFXCzfU1nBOSS4ZnYM=
+github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c=
 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=
diff --git a/main.go b/main.go
index 6a2051b..4c29c42 100644
--- a/main.go
+++ b/main.go
@@ -66,9 +66,11 @@ Some of the models compatible are:
 - Koala
 - GPT4ALL
 - GPT4ALL-J
+- Cerebras
 - Alpaca
+- StableLM (ggml quantized)
 
-It uses llama.cpp and gpt4all as backend, supporting all the models supported by both.
+It uses llama.cpp, ggml and gpt4all as backends, with Golang C bindings.
 `,
 	UsageText: `local-ai [options]`,
 	Copyright: "go-skynet authors",
diff --git a/pkg/model/loader.go b/pkg/model/loader.go
index 1db1713..b3cce43 100644
--- a/pkg/model/loader.go
+++ b/pkg/model/loader.go
@@ -21,15 +21,23 @@ type ModelLoader struct {
 	modelPath string
 	mu        sync.Mutex
-	models     map[string]*llama.LLama
-	gptmodels  map[string]*gptj.GPTJ
-	gpt2models map[string]*gpt2.GPT2
+	models            map[string]*llama.LLama
+	gptmodels         map[string]*gptj.GPTJ
+	gpt2models        map[string]*gpt2.GPT2
+	gptstablelmmodels map[string]*gpt2.StableLM
 
 	promptsTemplates map[string]*template.Template
 }
 
 func NewModelLoader(modelPath string) *ModelLoader {
-	return &ModelLoader{modelPath: modelPath, gpt2models: make(map[string]*gpt2.GPT2), gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
+	return &ModelLoader{
+		modelPath:         modelPath,
+		gpt2models:        make(map[string]*gpt2.GPT2),
+		gptmodels:         make(map[string]*gptj.GPTJ),
+		gptstablelmmodels: make(map[string]*gpt2.StableLM),
+		models:            make(map[string]*llama.LLama),
+		promptsTemplates:  make(map[string]*template.Template),
+	}
 }
 
 func (ml *ModelLoader) ExistsInModelPath(s string) bool {
@@ -102,6 +110,38 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
 	return nil
 }
 
+func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	// Check if we already have a loaded model
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist")
+	}
+
+	if m, ok := ml.gptstablelmmodels[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
+		return m, nil
+	}
+
+	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.modelPath, modelName)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
+	model, err := gpt2.NewStableLM(modelFile)
+	if err != nil {
+		return nil, err
+	}
+
+	// If there is a prompt template, load it
+	if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
+		return nil, err
+	}
+
+	ml.gptstablelmmodels[modelName] = model
+	return model, err
+}
+
 func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
@@ -116,6 +156,13 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
 		return m, nil
 	}
 
+	// TODO: This needs refactoring, it's really bad to have it in here
+	// Check if we have a GPTStableLM model loaded instead - if we do we return an error so the API tries with StableLM
+	if _, ok := ml.gptstablelmmodels[modelName]; ok {
+		log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
+		return nil, fmt.Errorf("this model is a GPTStableLM one")
+	}
+
 	// Load the model and keep it in memory for later use
 	modelFile := filepath.Join(ml.modelPath, modelName)
 	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
@@ -154,6 +201,10 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
 		log.Debug().Msgf("Model is GPT2: %s", modelName)
 		return nil, fmt.Errorf("this model is a GPT2 one")
 	}
+	if _, ok := ml.gptstablelmmodels[modelName]; ok {
+		log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
+		return nil, fmt.Errorf("this model is a GPTStableLM one")
+	}
 
 	// Load the model and keep it in memory for later use
 	modelFile := filepath.Join(ml.modelPath, modelName)
@@ -199,6 +250,10 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
 		log.Debug().Msgf("Model is GPT2: %s", modelName)
 		return nil, fmt.Errorf("this model is a GPT2 one")
 	}
+	if _, ok := ml.gptstablelmmodels[modelName]; ok {
+		log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
+		return nil, fmt.Errorf("this model is a GPTStableLM one")
+	}
 
 	// Load the model and keep it in memory for later use
 	modelFile := filepath.Join(ml.modelPath, modelName)
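
Usage sketch (not part of the patch itself): the new loader can also be driven directly from Go, outside the API server. This is a minimal sketch under stated assumptions: the model file name "stablelm-3b-ggml.bin" and the import alias for the go-gpt2.cpp package are illustrative, while NewModelLoader, LoadStableLMModel, Predict and the gpt2.Set* options are the calls introduced or shown in this patch.

package main

import (
	"fmt"
	"log"

	model "github.com/go-skynet/LocalAI/pkg/model" // the ModelLoader patched above
	gpt2 "github.com/go-skynet/go-gpt2.cpp"        // provides StableLM and the predict options
)

func main() {
	// Point the loader at the directory holding the ggml model files.
	loader := model.NewModelLoader("./models")

	// "stablelm-3b-ggml.bin" is a placeholder name, not a file shipped by the project.
	m, err := loader.LoadStableLMModel("stablelm-3b-ggml.bin")
	if err != nil {
		log.Fatal(err)
	}

	// Same prediction options the API endpoint builds for StableLM models.
	out, err := m.Predict(
		"What is LocalAI?",
		gpt2.SetTemperature(0.9),
		gpt2.SetTopP(0.85),
		gpt2.SetTopK(40),
		gpt2.SetTokens(128),
		gpt2.SetThreads(4),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}

This mirrors the fallback chain in api.go: the endpoint tries llama.cpp first, then GPT4All-J, then GPT-2, and finally StableLM, returning the combined error only if all four loaders fail.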