Add support for stablelm (#48)

Signed-off-by: mudler <mudler@mocaccino.org>
add/first-example
Ettore Di Giacinto 2 years ago committed by GitHub
parent 142bcd66ca
commit f816dfae65
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      Makefile
  2. 10
      README.md
  3. 32
      api/api.go
  4. 2
      go.mod
  5. 6
      go.sum
  6. 4
      main.go
  7. 57
      pkg/model/loader.go

@ -4,7 +4,7 @@ GOVET=$(GOCMD) vet
BINARY_NAME=local-ai BINARY_NAME=local-ai
GOLLAMA_VERSION?=llama.cpp-5ecff35 GOLLAMA_VERSION?=llama.cpp-5ecff35
GOGPT4ALLJ_VERSION?=1f548782d80d48b9a0fac33aae6f129358787bc0 GOGPT4ALLJ_VERSION?=1f548782d80d48b9a0fac33aae6f129358787bc0
GOGPT2_VERSION?=f15da66b097d6dacc30140d5def78d153e529e70 GOGPT2_VERSION?=1c24f5b86ac428cd5e81dae1f1427b1463bd2b06
GREEN := $(shell tput -Txterm setaf 2) GREEN := $(shell tput -Txterm setaf 2)
YELLOW := $(shell tput -Txterm setaf 3) YELLOW := $(shell tput -Txterm setaf 3)

@ -19,6 +19,16 @@ LocalAI is a straightforward, drop-in replacement API compatible with OpenAI for
It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) supports also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all) and [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml). It is compatible with the models supported by [llama.cpp](https://github.com/ggerganov/llama.cpp) supports also [GPT4ALL-J](https://github.com/nomic-ai/gpt4all) and [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml).
Tested with:
- Vicuna
- Alpaca
- [GPT4ALL](https://github.com/nomic-ai/gpt4all)
- [GPT4ALL-J](https://gpt4all.io/models/ggml-gpt4all-j.bin)
- Koala
- [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml)
It should also be compatible with StableLM and GPTNeoX ggml models (untested)
Note: You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`. Note: You might need to convert older models to the new format, see [here](https://github.com/ggerganov/llama.cpp#using-gpt4all) for instance to run `gpt4all`.
## Usage ## Usage

@ -75,6 +75,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
var model *llama.LLama var model *llama.LLama
var gptModel *gptj.GPTJ var gptModel *gptj.GPTJ
var gpt2Model *gpt2.GPT2 var gpt2Model *gpt2.GPT2
var stableLMModel *gpt2.StableLM
input := new(OpenAIRequest) input := new(OpenAIRequest)
// Get input data from the request body // Get input data from the request body
@ -99,7 +100,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
} }
// Try to load the model with both // Try to load the model with both
var llamaerr, gpt2err, gptjerr error var llamaerr, gpt2err, gptjerr, stableerr error
llamaOpts := []llama.ModelOption{} llamaOpts := []llama.ModelOption{}
if ctx != 0 { if ctx != 0 {
llamaOpts = append(llamaOpts, llama.SetContext(ctx)) llamaOpts = append(llamaOpts, llama.SetContext(ctx))
@ -115,7 +116,10 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
if gptjerr != nil { if gptjerr != nil {
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile) gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
if gpt2err != nil { if gpt2err != nil {
return fmt.Errorf("llama: %s gpt: %s gpt2: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error()) // llama failed first, so we want to catch both errors stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
if stableerr != nil {
return fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error()) // llama failed first, so we want to catch both errors
}
} }
} }
} }
@ -182,6 +186,30 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
var predFunc func() (string, error) var predFunc func() (string, error)
switch { switch {
case stableLMModel != nil:
predFunc = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []gpt2.PredictOption{
gpt2.SetTemperature(temperature),
gpt2.SetTopP(topP),
gpt2.SetTopK(topK),
gpt2.SetTokens(tokens),
gpt2.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, gpt2.SetBatch(input.Batch))
}
if input.Seed != 0 {
predictOptions = append(predictOptions, gpt2.SetSeed(input.Seed))
}
return stableLMModel.Predict(
predInput,
predictOptions...,
)
}
case gpt2Model != nil: case gpt2Model != nil:
predFunc = func() (string, error) { predFunc = func() (string, error) {
// Generate the prediction using the language model // Generate the prediction using the language model

@ -3,6 +3,7 @@ module github.com/go-skynet/LocalAI
go 1.19 go 1.19
require ( require (
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94
github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640
github.com/gofiber/fiber/v2 v2.42.0 github.com/gofiber/fiber/v2 v2.42.0
@ -13,7 +14,6 @@ require (
require ( require (
github.com/andybalholm/brotli v1.0.4 // indirect github.com/andybalholm/brotli v1.0.4 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d // indirect
github.com/google/uuid v1.3.0 // indirect github.com/google/uuid v1.3.0 // indirect
github.com/klauspost/compress v1.15.9 // indirect github.com/klauspost/compress v1.15.9 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect

@ -4,10 +4,8 @@ github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSV
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420164106-516b5871c74d h1:8crcrVuvpRzf6wejPtIFYGmMrSTfW94CYPJZIssT8zo= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4 h1:GkGuqnhDFKlCsT6Bo8sdY00A7rFXCzfU1nBOSS4ZnYM=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420164106-516b5871c74d/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM= github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d h1:Jabxk0NI5CLbY7PVODkRp1AQbEovS9gM6jGAOwyy5FI=
github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420165404-f15da66b097d/go.mod h1:1Wj/xbkMfwQSOrhNYK178IzqQHstZbRfhx4s8p1M5VM=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c= github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrrMvlIq+g0/ltXjDdLeNtz0uc4wJ4Qs15GFU4ba4c=
github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI= github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E= github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=

@ -66,9 +66,11 @@ Some of the models compatible are:
- Koala - Koala
- GPT4ALL - GPT4ALL
- GPT4ALL-J - GPT4ALL-J
- Cerebras
- Alpaca - Alpaca
- StableLM (ggml quantized)
It uses llama.cpp and gpt4all as backend, supporting all the models supported by both. It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
`, `,
UsageText: `local-ai [options]`, UsageText: `local-ai [options]`,
Copyright: "go-skynet authors", Copyright: "go-skynet authors",

@ -24,12 +24,20 @@ type ModelLoader struct {
models map[string]*llama.LLama models map[string]*llama.LLama
gptmodels map[string]*gptj.GPTJ gptmodels map[string]*gptj.GPTJ
gpt2models map[string]*gpt2.GPT2 gpt2models map[string]*gpt2.GPT2
gptstablelmmodels map[string]*gpt2.StableLM
promptsTemplates map[string]*template.Template promptsTemplates map[string]*template.Template
} }
func NewModelLoader(modelPath string) *ModelLoader { func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{modelPath: modelPath, gpt2models: make(map[string]*gpt2.GPT2), gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)} return &ModelLoader{
modelPath: modelPath,
gpt2models: make(map[string]*gpt2.GPT2),
gptmodels: make(map[string]*gptj.GPTJ),
gptstablelmmodels: make(map[string]*gpt2.StableLM),
models: make(map[string]*llama.LLama),
promptsTemplates: make(map[string]*template.Template),
}
} }
func (ml *ModelLoader) ExistsInModelPath(s string) bool { func (ml *ModelLoader) ExistsInModelPath(s string) bool {
@ -102,6 +110,38 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
return nil return nil
} }
func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
// Check if we already have a loaded model
if !ml.ExistsInModelPath(modelName) {
return nil, fmt.Errorf("model does not exist")
}
if m, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil
}
// Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
model, err := gpt2.NewStableLM(modelFile)
if err != nil {
return nil, err
}
// If there is a prompt template, load it
if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
return nil, err
}
ml.gptstablelmmodels[modelName] = model
return model, err
}
func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) { func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
ml.mu.Lock() ml.mu.Lock()
defer ml.mu.Unlock() defer ml.mu.Unlock()
@ -116,6 +156,13 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
return m, nil return m, nil
} }
// TODO: This needs refactoring, it's really bad to have it in here
// Check if we have a GPTStable model loaded instead - if we do we return an error so the API tries with StableLM
if _, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
return nil, fmt.Errorf("this model is a GPTStableLM one")
}
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.modelPath, modelName)
log.Debug().Msgf("Loading model in memory from file: %s", modelFile) log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
@ -154,6 +201,10 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
log.Debug().Msgf("Model is GPT2: %s", modelName) log.Debug().Msgf("Model is GPT2: %s", modelName)
return nil, fmt.Errorf("this model is a GPT2 one") return nil, fmt.Errorf("this model is a GPT2 one")
} }
if _, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
return nil, fmt.Errorf("this model is a GPTStableLM one")
}
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.modelPath, modelName)
@ -199,6 +250,10 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
log.Debug().Msgf("Model is GPT2: %s", modelName) log.Debug().Msgf("Model is GPT2: %s", modelName)
return nil, fmt.Errorf("this model is a GPT2 one") return nil, fmt.Errorf("this model is a GPT2 one")
} }
if _, ok := ml.gptstablelmmodels[modelName]; ok {
log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
return nil, fmt.Errorf("this model is a GPTStableLM one")
}
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
modelFile := filepath.Join(ml.modelPath, modelName) modelFile := filepath.Join(ml.modelPath, modelName)

Loading…
Cancel
Save