From 8250391e495c71bc7dcc4dcd7d24a0df52497aa4 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Fri, 12 May 2023 11:36:35 +0200
Subject: [PATCH] Add support for gptneox/replit (#238)

---
 Makefile                  |  2 +-
 README.md                 |  4 +++-
 api/api_test.go           |  2 +-
 api/prediction.go         | 48 +++++++++++++++++++++++++++++++++++++++
 pkg/model/initializers.go | 16 +++++++++++++
 5 files changed, 69 insertions(+), 3 deletions(-)

diff --git a/Makefile b/Makefile
index d1cf783..b6471f3 100644
--- a/Makefile
+++ b/Makefile
@@ -6,7 +6,7 @@ BINARY_NAME=local-ai
 GOLLAMA_VERSION?=70593fccbe4b01dedaab805b0f25cb58192c7b38
 GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all
 GPT4ALL_VERSION?=3657f9417e17edf378c27d0a9274a1bf41caa914
-GOGPT2_VERSION?=6a10572a23328e18a62cfdb45e4a3c8ddbe75f25
+GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
 RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47
 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993
diff --git a/README.md b/README.md
index 04f43d3..dbd45ed 100644
--- a/README.md
+++ b/README.md
@@ -92,7 +92,7 @@ It should also be compatible with StableLM and GPTNeoX ggml models (untested).
 
 Depending on the model you are attempting to run might need more RAM or CPU resources. Check out also [here](https://github.com/ggerganov/llama.cpp#memorydisk-requirements) for `ggml` based backends. `rwkv` is less expensive on resources.
 
-### Feature support matrix
+### Model compatibility table
 
@@ -106,6 +106,8 @@ Depending on the model you are attempting to run might need more RAM or CPU reso
 | dolly | Dolly | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
 | redpajama | RedPajama | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
 | stableLM | StableLM GPT/NeoX | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
+| replit | Replit | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
+| gptneox | GPT NeoX | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
 | starcoder | Starcoder | yes | no | no | no | https://github.com/ggerganov/ggml | https://github.com/go-skynet/go-gpt2.cpp |
 | bloomz | Bloom | yes | no | no | no | https://github.com/NouamaneTazi/bloomz.cpp | https://github.com/go-skynet/bloomz.cpp |
 | rwkv | RWKV | yes | no | no | yes | https://github.com/saharNooby/rwkv.cpp | https://github.com/donomii/go-rwkv.cpp |
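Side note on the rows added above: replit and gptneox become selectable backend names. As a minimal usage sketch (not part of the patch), the snippet below exercises one of the new backends through LocalAI's OpenAI-compatible completion endpoint, using the same go-openai client that api/api_test.go relies on; the base URL, port, and model file name are illustrative assumptions.

package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	cfg := openai.DefaultConfig("")          // LocalAI needs no API key
	cfg.BaseURL = "http://localhost:8080/v1" // assumed local instance address
	client := openai.NewClientWithConfig(cfg)

	resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{
		Model:  "gptneox-model.bin", // hypothetical ggml file in the models path
		Prompt: "Once upon a time",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Choices[0].Text)
}
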
diff --git a/api/api_test.go b/api/api_test.go
index 5829bd1..1189cdb 100644
--- a/api/api_test.go
+++ b/api/api_test.go
@@ -80,7 +80,7 @@ var _ = Describe("API test", func() {
 		It("returns errors", func() {
 			_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
 			Expect(err).To(HaveOccurred())
-			Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:"))
+			Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:"))
 		})
 		PIt("transcribes audio", func() {
 			resp, err := client.CreateTranscription(
diff --git a/api/prediction.go b/api/prediction.go
index b128e7e..f31ffd5 100644
--- a/api/prediction.go
+++ b/api/prediction.go
@@ -199,6 +199,54 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 
 			return response, nil
 		}
+	case *gpt2.GPTNeoX:
+		fn = func() (string, error) {
+			// Generate the prediction using the language model
+			predictOptions := []gpt2.PredictOption{
+				gpt2.SetTemperature(c.Temperature),
+				gpt2.SetTopP(c.TopP),
+				gpt2.SetTopK(c.TopK),
+				gpt2.SetTokens(c.Maxtokens),
+				gpt2.SetThreads(c.Threads),
+			}
+
+			if c.Batch != 0 {
+				predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
+			}
+
+			if c.Seed != 0 {
+				predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
+			}
+
+			return model.Predict(
+				s,
+				predictOptions...,
+			)
+		}
+	case *gpt2.Replit:
+		fn = func() (string, error) {
+			// Generate the prediction using the language model
+			predictOptions := []gpt2.PredictOption{
+				gpt2.SetTemperature(c.Temperature),
+				gpt2.SetTopP(c.TopP),
+				gpt2.SetTopK(c.TopK),
+				gpt2.SetTokens(c.Maxtokens),
+				gpt2.SetThreads(c.Threads),
+			}
+
+			if c.Batch != 0 {
+				predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
+			}
+
+			if c.Seed != 0 {
+				predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
+			}
+
+			return model.Predict(
+				s,
+				predictOptions...,
+			)
+		}
 	case *gpt2.Starcoder:
 		fn = func() (string, error) {
 			// Generate the prediction using the language model
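A remark on the hunk above: the new GPTNeoX and Replit branches are identical to each other and to the Starcoder branch that follows, differing only in the matched type. As a hypothetical follow-up cleanup (not part of this patch), the shared option construction could be hoisted into one helper; every call used below already appears in the diff:

// gpt2PredictOptions is a hypothetical helper collecting the predict
// options that the GPTNeoX, Replit, and Starcoder branches all build
// identically from the request config.
func gpt2PredictOptions(c Config) []gpt2.PredictOption {
	predictOptions := []gpt2.PredictOption{
		gpt2.SetTemperature(c.Temperature),
		gpt2.SetTopP(c.TopP),
		gpt2.SetTopK(c.TopK),
		gpt2.SetTokens(c.Maxtokens),
		gpt2.SetThreads(c.Threads),
	}
	// Batch and seed are optional; zero means "not set".
	if c.Batch != 0 {
		predictOptions = append(predictOptions, gpt2.SetBatch(c.Batch))
	}
	if c.Seed != 0 {
		predictOptions = append(predictOptions, gpt2.SetSeed(c.Seed))
	}
	return predictOptions
}

Each case body would then reduce to return model.Predict(s, gpt2PredictOptions(c)...).
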
GPTNeoXBackend = "gptneox" + ReplitBackend = "replit" Gpt2Backend = "gpt2" Gpt4AllLlamaBackend = "gpt4all-llama" Gpt4AllMptBackend = "gpt4all-mpt" @@ -45,6 +47,8 @@ var backends []string = []string{ StableLMBackend, DollyBackend, RedPajamaBackend, + GPTNeoXBackend, + ReplitBackend, BertEmbeddingsBackend, StarcoderBackend, } @@ -61,6 +65,14 @@ var dolly = func(modelFile string) (interface{}, error) { return gpt2.NewDolly(modelFile) } +var gptNeoX = func(modelFile string) (interface{}, error) { + return gpt2.NewGPTNeoX(modelFile) +} + +var replit = func(modelFile string) (interface{}, error) { + return gpt2.NewReplit(modelFile) +} + var stableLM = func(modelFile string) (interface{}, error) { return gpt2.NewStableLM(modelFile) } @@ -116,6 +128,10 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla return ml.LoadModel(modelFile, redPajama) case Gpt2Backend: return ml.LoadModel(modelFile, gpt2LM) + case GPTNeoXBackend: + return ml.LoadModel(modelFile, gptNeoX) + case ReplitBackend: + return ml.LoadModel(modelFile, replit) case StarcoderBackend: return ml.LoadModel(modelFile, starCoder) case Gpt4AllLlamaBackend: