From d62aef20166e3a7b2169c0ae5c0752013d4caa83 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Tue, 6 Jun 2023 17:23:19 +0200 Subject: [PATCH] feat: add experimental support for falcon-7b (#516) Signed-off-by: mudler --- Makefile | 2 +- api/api_test.go | 2 +- api/prediction.go | 24 ++++++++++++++++++++++++ pkg/model/initializers.go | 12 ++++++++++-- 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 7a68a1d..e131685 100644 --- a/Makefile +++ b/Makefile @@ -6,7 +6,7 @@ BINARY_NAME=local-ai GOLLAMA_VERSION?=cca84ed55fb920ccdd6158958b2c9b773ce17eea GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all GPT4ALL_VERSION?=f7498c9 -GOGGMLTRANSFORMERS_VERSION?=6fb862c72bc04568120e711b176defe116d3751e +GOGGMLTRANSFORMERS_VERSION?=bd765bb6f3b38a63f915f3725e488aad492eedd4 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=1e18b2490e7e32f6b00e16f6a9ec0dd3a3d09266 WHISPER_CPP_VERSION?=5b9e59bc07dd76320354f2af6be29f16dbcb21e7 diff --git a/api/api_test.go b/api/api_test.go index 54118b8..6dc697a 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -287,7 +287,7 @@ var _ = Describe("API test", func() { It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:")) + Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 11 errors occurred:")) }) It("transcribes audio", func() { if runtime.GOOS != "linux" { diff --git a/api/prediction.go b/api/prediction.go index 8aad422..04a5b95 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -368,6 +368,30 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) } + return model.Predict( + s, + predictOptions..., + ) + } + case *transformers.Falcon: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []transformers.PredictOption{ + transformers.SetTemperature(c.Temperature), + transformers.SetTopP(c.TopP), + transformers.SetTopK(c.TopK), + transformers.SetTokens(c.Maxtokens), + transformers.SetThreads(c.Threads), + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) + } + return model.Predict( s, predictOptions..., diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index b2c23b7..7de487d 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -34,6 +34,7 @@ const ( Gpt4AllMptBackend = "gpt4all-mpt" Gpt4AllJBackend = "gpt4all-j" Gpt4All = "gpt4all" + FalconBackend = "falcon" BertEmbeddingsBackend = "bert-embeddings" RwkvBackend = "rwkv" WhisperBackend = "whisper" @@ -41,7 +42,7 @@ const ( LCHuggingFaceBackend = "langchain-huggingface" ) -var backends []string = []string{ +var autoLoadBackends []string = []string{ LlamaBackend, Gpt4All, RwkvBackend, @@ -51,6 +52,7 @@ var backends []string = []string{ GPTJBackend, Gpt2Backend, DollyBackend, + FalconBackend, MPTBackend, ReplitBackend, StarcoderBackend, @@ -81,6 +83,10 @@ var gptJ = func(modelFile string) (interface{}, error) { return transformers.NewGPTJ(modelFile) } +var falcon = func(modelFile string) (interface{}, error) { + return transformers.NewFalcon(modelFile) +} + var bertEmbeddings = func(modelFile string) (interface{}, error) { return bert.New(modelFile) } @@ -144,6 +150,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla return ml.LoadModel(modelFile, mpt) case Gpt2Backend: return ml.LoadModel(modelFile, transformersLM) + case FalconBackend: + return ml.LoadModel(modelFile, falcon) case GPTNeoXBackend: return ml.LoadModel(modelFile, gptNeoX) case ReplitBackend: @@ -180,7 +188,7 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt ml.mu.Unlock() var err error - for _, b := range backends { + for _, b := range autoLoadBackends { if b == BloomzBackend || b == WhisperBackend || b == RwkvBackend { // do not autoload bloomz/whisper/rwkv continue }