feat: add rwkv support (#158)

Signed-off-by: mudler <mudler@mocaccino.org>
Ettore Di Giacinto committed 2 years ago via GitHub
parent 1ae7150810
commit 751b7eca62
Changed files (changed lines):

1. Makefile (53)
2. README.md (10)
3. api/api_test.go (2)
4. api/prediction.go (35)
5. go.mod (1)
6. go.sum (2)
7. pkg/model/loader.go (34)

Makefile
@@ -9,14 +9,18 @@ GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
 # renovate: datasource=git-refs packageNameTemplate=https://github.com/go-skynet/go-gpt2.cpp currentValueTemplate=master depNameTemplate=go-gpt2.cpp
 GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa
+# here until https://github.com/donomii/go-rwkv.cpp/pull/1 is merged
+RWKV_REPO?=https://github.com/mudler/go-rwkv.cpp
+RWKV_VERSION?=6ba15255b03016b5ecce36529b500d21815399a7
 GREEN  := $(shell tput -Txterm setaf 2)
 YELLOW := $(shell tput -Txterm setaf 3)
 WHITE  := $(shell tput -Txterm setaf 7)
 CYAN   := $(shell tput -Txterm setaf 6)
 RESET  := $(shell tput -Txterm sgr0)
-C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2
-LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2
+C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv
+LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv
 # Use this if you want to set the default behavior
 ifndef BUILD_TYPE
@@ -33,16 +37,6 @@ endif
 all: help

-## Build:
-build: prepare ## Build the project
-	$(info ${GREEN}I local-ai build info:${RESET})
-	$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
-	C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./
-
-generic-build: ## Build the project using generic
-	BUILD_TYPE="generic" $(MAKE) build
-
 ## GPT4ALL-J
 go-gpt4all-j:
 	git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j
@@ -57,10 +51,18 @@ go-gpt4all-j:
 	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
 	@find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +

+## RWKV
+go-rwkv:
+	git clone --recurse-submodules $(RWKV_REPO) go-rwkv
+	cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1
+
+go-rwkv/librwkv.a: go-rwkv
+	cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. && cp ggml/src/libggml.a ..
+
 go-gpt4all-j/libgptj.a: go-gpt4all-j
 	$(MAKE) -C go-gpt4all-j $(GENERIC_PREFIX)libgptj.a

-# CEREBRAS GPT
+## CEREBRAS GPT
 go-gpt2:
 	git clone --recurse-submodules https://github.com/go-skynet/go-gpt2.cpp go-gpt2
 	cd go-gpt2 && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1
@@ -75,7 +77,6 @@ go-gpt2:
 go-gpt2/libgpt2.a: go-gpt2
 	$(MAKE) -C go-gpt2 $(GENERIC_PREFIX)libgpt2.a

 go-llama:
 	git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
@@ -86,26 +87,40 @@ replace:
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j
 	$(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2
+	$(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv

-prepare-sources: go-llama go-gpt2 go-gpt4all-j
+prepare-sources: go-llama go-gpt2 go-gpt4all-j go-rwkv
 	$(GOCMD) mod download

-rebuild:
+## GENERIC
+rebuild: ## Rebuilds the project
 	$(MAKE) -C go-llama clean
 	$(MAKE) -C go-gpt4all-j clean
 	$(MAKE) -C go-gpt2 clean
+	$(MAKE) -C go-rwkv clean
 	$(MAKE) build

-prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a replace
+prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a go-rwkv/librwkv.a replace ## Prepares for building

 clean: ## Remove build related file
 	rm -fr ./go-llama
 	rm -rf ./go-gpt4all-j
 	rm -rf ./go-gpt2
+	rm -rf ./go-rwkv
 	rm -rf $(BINARY_NAME)

-## Run:
-run: prepare
+## Build:
+build: prepare ## Build the project
+	$(info ${GREEN}I local-ai build info:${RESET})
+	$(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET})
+	C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./
+
+generic-build: ## Build the project using generic
+	BUILD_TYPE="generic" $(MAKE) build
+
+## Run
+run: prepare ## run local-ai
 	C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) run ./main.go

 test-models/testmodel:

README.md
@@ -15,7 +15,7 @@
 - Supports multiple-models
 - Once loaded the first time, it keep models loaded in memory for faster inference
 - Support for prompt templates
-- Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp).
+- Doesn't shell-out, but uses C bindings for a faster inference and better performance.

 LocalAI is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud).
@@ -39,6 +39,7 @@ Tested with:
 - [GPT4ALL-J](https://gpt4all.io/models/ggml-gpt4all-j.bin)
 - Koala
 - [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml)
+- [RWKV](https://github.com/BlinkDL/RWKV-LM) with [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp)

 It should also be compatible with StableLM and GPTNeoX ggml models (untested)
@@ -506,6 +507,13 @@ LocalAI is a community-driven project. It was initially created by [mudler](http
 MIT

+## Golang bindings used
+
+- [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp)
+- [go-skynet/go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp)
+- [go-skynet/go-gpt2.cpp](https://github.com/go-skynet/go-gpt2.cpp)
+- [donomii/go-rwkv.cpp](https://github.com/donomii/go-rwkv.cpp)
+
 ## Acknowledgements

 - [llama.cpp](https://github.com/ggerganov/llama.cpp)

api/api_test.go
@@ -79,7 +79,7 @@ var _ = Describe("API test", func() {
 	It("returns errors", func() {
 		_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
 		Expect(err).To(HaveOccurred())
-		Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 4 errors occurred:"))
+		Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 5 errors occurred:"))
 	})
 })

api/prediction.go
@@ -6,6 +6,7 @@ import (
 	"strings"
 	"sync"

+	"github.com/donomii/go-rwkv.cpp"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	gpt2 "github.com/go-skynet/go-gpt2.cpp"
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
@@ -13,6 +14,8 @@ import (
 	"github.com/hashicorp/go-multierror"
 )

+const tokenizerSuffix = ".tokenizer.json"
+
 // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
 var mutexMap sync.Mutex
 var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
@@ -20,7 +23,7 @@ var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
 var loadedModels map[string]interface{} = map[string]interface{}{}
 var muModels sync.Mutex

-func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
+func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
 	switch strings.ToLower(backendString) {
 	case "llama":
 		return loader.LoadLLaMAModel(modelFile, llamaOpts...)
@@ -30,12 +33,14 @@ func backendLoader(backendString string, loader *model.ModelLoader, modelFile st
 		return loader.LoadGPT2Model(modelFile)
 	case "gptj":
 		return loader.LoadGPTJModel(modelFile)
+	case "rwkv":
+		return loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
 	default:
 		return nil, fmt.Errorf("backend unsupported: %s", backendString)
 	}
 }

-func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
+func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
 	updateModels := func(model interface{}) {
 		muModels.Lock()
 		defer muModels.Unlock()
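
Note: the `rwkv` case derives the tokenizer path from the model path via `tokenizerSuffix`, so a model file `rwkv-model.bin` is expected to ship with `rwkv-model.bin.tokenizer.json` next to it. A minimal sketch of an explicit-backend load from inside the `api` package (the model name and thread count here are illustrative, not from the commit):

```go
// Sketch, within the api package: force the rwkv backend instead of
// letting greedyLoader probe every backend in turn.
func loadRWKVExplicitly(loader *model.ModelLoader) (*rwkv.RwkvState, error) {
	// "rwkv-model.bin" is a hypothetical file in the loader's model path;
	// its tokenizer must sit alongside it as "rwkv-model.bin.tokenizer.json".
	m, err := backendLoader("rwkv", loader, "rwkv-model.bin", nil, 4)
	if err != nil {
		return nil, err
	}
	return m.(*rwkv.RwkvState), nil // backendLoader returns interface{}
}
```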
@@ -82,6 +87,14 @@ func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama
 		err = multierror.Append(err, modelerr)
 	}

+	model, modelerr = loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
+	if modelerr == nil {
+		updateModels(model)
+		return model, nil
+	} else {
+		err = multierror.Append(err, modelerr)
+	}
+
 	return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
 }
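
greedyLoader now probes five backends, appending each failure to a hashicorp/go-multierror before trying the next; that is why the API test above now expects "5 errors occurred". A distilled sketch of the accumulate-or-return pattern (not the literal code from this commit):

```go
// Each loader is tried in order; the first success wins, and all
// failures are collected into a single multierror for the caller.
func tryBackends(loaders []func() (interface{}, error)) (interface{}, error) {
	var err error
	for _, load := range loaders {
		m, e := load()
		if e == nil {
			return m, nil
		}
		err = multierror.Append(err, e)
	}
	return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
}
```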
@@ -101,9 +114,9 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 	var inferenceModel interface{}
 	var err error

 	if c.Backend == "" {
-		inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts)
+		inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts, uint32(c.Threads))
 	} else {
-		inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts)
+		inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts, uint32(c.Threads))
 	}

 	if err != nil {
 		return nil, err
@@ -112,6 +125,20 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 	var fn func() (string, error)

 	switch model := inferenceModel.(type) {
+	case *rwkv.RwkvState:
+		supportStreams = true
+
+		fn = func() (string, error) {
+			//model.ProcessInput("You are a chatbot that is very good at chatting. blah blah blah")
+			stopWord := "\n"
+			if len(c.StopWords) > 0 {
+				stopWord = c.StopWords[0]
+			}
+
+			response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
+
+			return response, nil
+		}
 	case *gpt2.StableLM:
 		fn = func() (string, error) {
 			// Generate the prediction using the language model
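
With `*rwkv.RwkvState` handled in ModelInference (including streaming via `tokenCallback`), RWKV models are served through the same OpenAI-compatible API as the other backends. A sketch using the sashabaranov/go-openai client that the API tests already rely on; the base URL and model name are illustrative assumptions:

```go
package main

import (
	"context"
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	cfg := openai.DefaultConfig("not-checked") // LocalAI ignores the API key
	cfg.BaseURL = "http://localhost:8080/v1"   // assumed local LocalAI instance
	client := openai.NewClientWithConfig(cfg)

	resp, err := client.CreateCompletion(context.Background(), openai.CompletionRequest{
		Model:  "rwkv-model.bin", // hypothetical model file in the models directory
		Prompt: "Q: What is RWKV?\nA:",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(resp.Choices[0].Text)
}
```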

go.mod
@@ -23,6 +23,7 @@ require (
 	github.com/StackExchange/wmi v1.2.1 // indirect
 	github.com/andybalholm/brotli v1.0.5 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
+	github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d // indirect
 	github.com/ghodss/yaml v1.0.0 // indirect
 	github.com/go-logr/logr v1.2.3 // indirect
 	github.com/go-ole/go-ole v1.2.6 // indirect

go.sum
@@ -12,6 +12,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d h1:lSHwlYf1H4WAWYgf7rjEVTGen1qmigUq2Egpu8mnQiY=
+github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d/go.mod h1:H6QBF7/Tz6DAEBDXQged4H1BvsmqY/K5FG9wQRGa01g=
 github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
 github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
 github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=

pkg/model/loader.go
@@ -12,6 +12,7 @@ import (
 	"github.com/rs/zerolog/log"

+	rwkv "github.com/donomii/go-rwkv.cpp"
 	gpt2 "github.com/go-skynet/go-gpt2.cpp"
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
 	llama "github.com/go-skynet/go-llama.cpp"
@@ -25,7 +26,7 @@ type ModelLoader struct {
 	gptmodels         map[string]*gptj.GPTJ
 	gpt2models        map[string]*gpt2.GPT2
 	gptstablelmmodels map[string]*gpt2.StableLM
+	rwkv              map[string]*rwkv.RwkvState
 	promptsTemplates  map[string]*template.Template
 }
@@ -36,6 +37,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
 		gptmodels:         make(map[string]*gptj.GPTJ),
 		gptstablelmmodels: make(map[string]*gpt2.StableLM),
 		models:            make(map[string]*llama.LLama),
+		rwkv:              make(map[string]*rwkv.RwkvState),
 		promptsTemplates:  make(map[string]*template.Template),
 	}
 }
@@ -218,6 +220,36 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
 	return model, err
 }

+func (ml *ModelLoader) LoadRWKV(modelName, tokenFile string, threads uint32) (*rwkv.RwkvState, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	log.Debug().Msgf("Loading model name: %s", modelName)
+
+	// Check if we already have a loaded model
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist")
+	}
+
+	if m, ok := ml.rwkv[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
+		return m, nil
+	}
+
+	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.ModelPath, modelName)
+	tokenPath := filepath.Join(ml.ModelPath, tokenFile)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
+	model := rwkv.LoadFiles(modelFile, tokenPath, threads)
+	if model == nil {
+		return nil, fmt.Errorf("could not load model")
+	}
+
+	ml.rwkv[modelName] = model
+	return model, nil
+}
+
 func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
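
A minimal sketch of driving the new loader directly, assuming a models directory containing `rwkv-model.bin` plus its `rwkv-model.bin.tokenizer.json` (both names illustrative); the `GenerateResponse` arguments mirror the call in api/prediction.go:

```go
package main

import (
	"fmt"

	model "github.com/go-skynet/LocalAI/pkg/model"
)

func main() {
	ml := model.NewModelLoader("/models") // assumed model directory

	// Tokenizer file follows the ".tokenizer.json" convention used by the API layer.
	state, err := ml.LoadRWKV("rwkv-model.bin", "rwkv-model.bin.tokenizer.json", 4)
	if err != nil {
		panic(err)
	}

	// max tokens, stop word, temperature, top-p, optional per-token callback.
	out := state.GenerateResponse(256, "\n", 0.8, 0.9, nil)
	fmt.Println(out)
}
```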
