feat: add experimental support for falcon-7b (#516)

Signed-off-by: mudler <mudler@mocaccino.org>
renovate/github.com-imdario-mergo-1.x
Ettore Di Giacinto 1 year ago committed by GitHub
parent 25e9483add
commit d62aef2016
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      Makefile
  2. 2
      api/api_test.go
  3. 24
      api/prediction.go
  4. 12
      pkg/model/initializers.go

@ -6,7 +6,7 @@ BINARY_NAME=local-ai
GOLLAMA_VERSION?=cca84ed55fb920ccdd6158958b2c9b773ce17eea GOLLAMA_VERSION?=cca84ed55fb920ccdd6158958b2c9b773ce17eea
GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all GPT4ALL_REPO?=https://github.com/go-skynet/gpt4all
GPT4ALL_VERSION?=f7498c9 GPT4ALL_VERSION?=f7498c9
GOGGMLTRANSFORMERS_VERSION?=6fb862c72bc04568120e711b176defe116d3751e GOGGMLTRANSFORMERS_VERSION?=bd765bb6f3b38a63f915f3725e488aad492eedd4
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=1e18b2490e7e32f6b00e16f6a9ec0dd3a3d09266 RWKV_VERSION?=1e18b2490e7e32f6b00e16f6a9ec0dd3a3d09266
WHISPER_CPP_VERSION?=5b9e59bc07dd76320354f2af6be29f16dbcb21e7 WHISPER_CPP_VERSION?=5b9e59bc07dd76320354f2af6be29f16dbcb21e7

@ -287,7 +287,7 @@ var _ = Describe("API test", func() {
It("returns errors", func() { It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:")) Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 11 errors occurred:"))
}) })
It("transcribes audio", func() { It("transcribes audio", func() {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {

@ -368,6 +368,30 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
} }
return model.Predict(
s,
predictOptions...,
)
}
case *transformers.Falcon:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []transformers.PredictOption{
transformers.SetTemperature(c.Temperature),
transformers.SetTopP(c.TopP),
transformers.SetTopK(c.TopK),
transformers.SetTokens(c.Maxtokens),
transformers.SetThreads(c.Threads),
}
if c.Batch != 0 {
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch))
}
if c.Seed != 0 {
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed))
}
return model.Predict( return model.Predict(
s, s,
predictOptions..., predictOptions...,

@ -34,6 +34,7 @@ const (
Gpt4AllMptBackend = "gpt4all-mpt" Gpt4AllMptBackend = "gpt4all-mpt"
Gpt4AllJBackend = "gpt4all-j" Gpt4AllJBackend = "gpt4all-j"
Gpt4All = "gpt4all" Gpt4All = "gpt4all"
FalconBackend = "falcon"
BertEmbeddingsBackend = "bert-embeddings" BertEmbeddingsBackend = "bert-embeddings"
RwkvBackend = "rwkv" RwkvBackend = "rwkv"
WhisperBackend = "whisper" WhisperBackend = "whisper"
@ -41,7 +42,7 @@ const (
LCHuggingFaceBackend = "langchain-huggingface" LCHuggingFaceBackend = "langchain-huggingface"
) )
var backends []string = []string{ var autoLoadBackends []string = []string{
LlamaBackend, LlamaBackend,
Gpt4All, Gpt4All,
RwkvBackend, RwkvBackend,
@ -51,6 +52,7 @@ var backends []string = []string{
GPTJBackend, GPTJBackend,
Gpt2Backend, Gpt2Backend,
DollyBackend, DollyBackend,
FalconBackend,
MPTBackend, MPTBackend,
ReplitBackend, ReplitBackend,
StarcoderBackend, StarcoderBackend,
@ -81,6 +83,10 @@ var gptJ = func(modelFile string) (interface{}, error) {
return transformers.NewGPTJ(modelFile) return transformers.NewGPTJ(modelFile)
} }
var falcon = func(modelFile string) (interface{}, error) {
return transformers.NewFalcon(modelFile)
}
var bertEmbeddings = func(modelFile string) (interface{}, error) { var bertEmbeddings = func(modelFile string) (interface{}, error) {
return bert.New(modelFile) return bert.New(modelFile)
} }
@ -144,6 +150,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
return ml.LoadModel(modelFile, mpt) return ml.LoadModel(modelFile, mpt)
case Gpt2Backend: case Gpt2Backend:
return ml.LoadModel(modelFile, transformersLM) return ml.LoadModel(modelFile, transformersLM)
case FalconBackend:
return ml.LoadModel(modelFile, falcon)
case GPTNeoXBackend: case GPTNeoXBackend:
return ml.LoadModel(modelFile, gptNeoX) return ml.LoadModel(modelFile, gptNeoX)
case ReplitBackend: case ReplitBackend:
@ -180,7 +188,7 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt
ml.mu.Unlock() ml.mu.Unlock()
var err error var err error
for _, b := range backends { for _, b := range autoLoadBackends {
if b == BloomzBackend || b == WhisperBackend || b == RwkvBackend { // do not autoload bloomz/whisper/rwkv if b == BloomzBackend || b == WhisperBackend || b == RwkvBackend { // do not autoload bloomz/whisper/rwkv
continue continue
} }

Loading…
Cancel
Save