From 751b7eca62e8a66a50506b151d60251c7f375ea3 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 3 May 2023 11:45:22 +0200 Subject: [PATCH] feat: add rwkv support (#158) Signed-off-by: mudler --- Makefile | 55 ++++++++++++++++++++++++++++----------------- README.md | 10 ++++++++- api/api_test.go | 2 +- api/prediction.go | 35 +++++++++++++++++++++++++---- go.mod | 1 + go.sum | 2 ++ pkg/model/loader.go | 36 +++++++++++++++++++++++++++-- 7 files changed, 113 insertions(+), 28 deletions(-) diff --git a/Makefile b/Makefile index e8b7929..8afe03e 100644 --- a/Makefile +++ b/Makefile @@ -9,14 +9,18 @@ GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109 # renovate: datasource=git-refs packageNameTemplate=https://github.com/go-skynet/go-gpt2.cpp currentValueTemplate=master depNameTemplate=go-gpt2.cpp GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa +# here until https://github.com/donomii/go-rwkv.cpp/pull/1 is merged +RWKV_REPO?=https://github.com/mudler/go-rwkv.cpp +RWKV_VERSION?=6ba15255b03016b5ecce36529b500d21815399a7 + GREEN := $(shell tput -Txterm setaf 2) YELLOW := $(shell tput -Txterm setaf 3) WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) -C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2 -LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2 +C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv +LIBRARY_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-gpt4all-j:$(shell pwd)/go-gpt2:$(shell pwd)/go-rwkv # Use this if you want to set the default behavior ifndef BUILD_TYPE @@ -33,16 +37,6 @@ endif all: help -## Build: - -build: prepare ## Build the project - $(info ${GREEN}I local-ai build info:${RESET}) - $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) - C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./ - -generic-build: ## Build the project using generic - BUILD_TYPE="generic" $(MAKE) build - ## GPT4ALL-J go-gpt4all-j: git clone --recurse-submodules https://github.com/go-skynet/go-gpt4all-j.cpp go-gpt4all-j @@ -57,11 +51,19 @@ go-gpt4all-j: @find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} + @find ./go-gpt4all-j -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} + +## RWKV +go-rwkv: + git clone --recurse-submodules $(RWKV_REPO) go-rwkv + cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1 + +go-rwkv/librwkv.a: go-rwkv + cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. && cp ggml/src/libggml.a .. + go-gpt4all-j/libgptj.a: go-gpt4all-j $(MAKE) -C go-gpt4all-j $(GENERIC_PREFIX)libgptj.a -# CEREBRAS GPT -go-gpt2: +## CEREBRAS GPT +go-gpt2: git clone --recurse-submodules https://github.com/go-skynet/go-gpt2.cpp go-gpt2 cd go-gpt2 && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1 # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml.. @@ -74,7 +76,6 @@ go-gpt2: go-gpt2/libgpt2.a: go-gpt2 $(MAKE) -C go-gpt2 $(GENERIC_PREFIX)libgpt2.a - go-llama: git clone -b $(GOLLAMA_VERSION) --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama @@ -86,26 +87,40 @@ replace: $(GOCMD) mod edit -replace github.com/go-skynet/go-llama.cpp=$(shell pwd)/go-llama $(GOCMD) mod edit -replace github.com/go-skynet/go-gpt4all-j.cpp=$(shell pwd)/go-gpt4all-j $(GOCMD) mod edit -replace github.com/go-skynet/go-gpt2.cpp=$(shell pwd)/go-gpt2 + $(GOCMD) mod edit -replace github.com/donomii/go-rwkv.cpp=$(shell pwd)/go-rwkv -prepare-sources: go-llama go-gpt2 go-gpt4all-j +prepare-sources: go-llama go-gpt2 go-gpt4all-j go-rwkv $(GOCMD) mod download -rebuild: +## GENERIC +rebuild: ## Rebuilds the project $(MAKE) -C go-llama clean $(MAKE) -C go-gpt4all-j clean $(MAKE) -C go-gpt2 clean + $(MAKE) -C go-rwkv clean $(MAKE) build -prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a replace +prepare: prepare-sources go-llama/libbinding.a go-gpt4all-j/libgptj.a go-gpt2/libgpt2.a go-rwkv/librwkv.a replace ## Prepares for building clean: ## Remove build related file rm -fr ./go-llama rm -rf ./go-gpt4all-j rm -rf ./go-gpt2 + rm -rf ./go-rwkv rm -rf $(BINARY_NAME) -## Run: -run: prepare +## Build: + +build: prepare ## Build the project + $(info ${GREEN}I local-ai build info:${RESET}) + $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) + C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -o $(BINARY_NAME) ./ + +generic-build: ## Build the project using generic + BUILD_TYPE="generic" $(MAKE) build + +## Run +run: prepare ## run local-ai C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) run ./main.go test-models/testmodel: diff --git a/README.md b/README.md index a06763a..03d6bb6 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,7 @@ - Supports multiple-models - Once loaded the first time, it keep models loaded in memory for faster inference - Support for prompt templates -- Doesn't shell-out, but uses C bindings for a faster inference and better performance. Uses [go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) and [go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp). +- Doesn't shell-out, but uses C bindings for a faster inference and better performance. LocalAI is a community-driven project, focused on making the AI accessible to anyone. Any contribution, feedback and PR is welcome! It was initially created by [mudler](https://github.com/mudler/) at the [SpectroCloud OSS Office](https://github.com/spectrocloud). @@ -39,6 +39,7 @@ Tested with: - [GPT4ALL-J](https://gpt4all.io/models/ggml-gpt4all-j.bin) - Koala - [cerebras-GPT with ggml](https://huggingface.co/lxe/Cerebras-GPT-2.7B-Alpaca-SP-ggml) +- [RWKV](https://github.com/BlinkDL/RWKV-LM) with [rwkv.cpp](https://github.com/saharNooby/rwkv.cpp) It should also be compatible with StableLM and GPTNeoX ggml models (untested) @@ -506,6 +507,13 @@ LocalAI is a community-driven project. It was initially created by [mudler](http MIT +## Golang bindings used + +- [go-skynet/go-llama.cpp](https://github.com/go-skynet/go-llama.cpp) +- [go-skynet/go-gpt4all-j.cpp](https://github.com/go-skynet/go-gpt4all-j.cpp) +- [go-skynet/go-gpt2.cpp](https://github.com/go-skynet/go-gpt2.cpp) +- [donomii/go-rwkv.cpp](https://github.com/donomii/go-rwkv.cpp) + ## Acknowledgements - [llama.cpp](https://github.com/ggerganov/llama.cpp) diff --git a/api/api_test.go b/api/api_test.go index 199ef14..9682a21 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -79,7 +79,7 @@ var _ = Describe("API test", func() { It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 4 errors occurred:")) + Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 5 errors occurred:")) }) }) diff --git a/api/prediction.go b/api/prediction.go index 4d2f77c..127a957 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -6,6 +6,7 @@ import ( "strings" "sync" + "github.com/donomii/go-rwkv.cpp" model "github.com/go-skynet/LocalAI/pkg/model" gpt2 "github.com/go-skynet/go-gpt2.cpp" gptj "github.com/go-skynet/go-gpt4all-j.cpp" @@ -13,6 +14,8 @@ import ( "github.com/hashicorp/go-multierror" ) +const tokenizerSuffix = ".tokenizer.json" + // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 var mutexMap sync.Mutex var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) @@ -20,7 +23,7 @@ var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) var loadedModels map[string]interface{} = map[string]interface{}{} var muModels sync.Mutex -func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) { +func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) { switch strings.ToLower(backendString) { case "llama": return loader.LoadLLaMAModel(modelFile, llamaOpts...) @@ -30,12 +33,14 @@ func backendLoader(backendString string, loader *model.ModelLoader, modelFile st return loader.LoadGPT2Model(modelFile) case "gptj": return loader.LoadGPTJModel(modelFile) + case "rwkv": + return loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads) default: return nil, fmt.Errorf("backend unsupported: %s", backendString) } } -func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) { +func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) { updateModels := func(model interface{}) { muModels.Lock() defer muModels.Unlock() @@ -82,6 +87,14 @@ func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama err = multierror.Append(err, modelerr) } + model, modelerr = loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads) + if modelerr == nil { + updateModels(model) + return model, nil + } else { + err = multierror.Append(err, modelerr) + } + return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error()) } @@ -101,9 +114,9 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback var inferenceModel interface{} var err error if c.Backend == "" { - inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts) + inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts, uint32(c.Threads)) } else { - inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts) + inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts, uint32(c.Threads)) } if err != nil { return nil, err @@ -112,6 +125,20 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback var fn func() (string, error) switch model := inferenceModel.(type) { + case *rwkv.RwkvState: + supportStreams = true + + fn = func() (string, error) { + //model.ProcessInput("You are a chatbot that is very good at chatting. blah blah blah") + stopWord := "\n" + if len(c.StopWords) > 0 { + stopWord = c.StopWords[0] + } + + response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) + + return response, nil + } case *gpt2.StableLM: fn = func() (string, error) { // Generate the prediction using the language model diff --git a/go.mod b/go.mod index 2f01790..b6b9dff 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/StackExchange/wmi v1.2.1 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect + github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/go-logr/logr v1.2.3 // indirect github.com/go-ole/go-ole v1.2.6 // indirect diff --git a/go.sum b/go.sum index 528e445..cb88cf4 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d h1:lSHwlYf1H4WAWYgf7rjEVTGen1qmigUq2Egpu8mnQiY= +github.com/donomii/go-rwkv.cpp v0.0.0-20230502223004-0a3db3d72e7d/go.mod h1:H6QBF7/Tz6DAEBDXQged4H1BvsmqY/K5FG9wQRGa01g= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0= diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 34826d1..7037e86 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -12,6 +12,7 @@ import ( "github.com/rs/zerolog/log" + rwkv "github.com/donomii/go-rwkv.cpp" gpt2 "github.com/go-skynet/go-gpt2.cpp" gptj "github.com/go-skynet/go-gpt4all-j.cpp" llama "github.com/go-skynet/go-llama.cpp" @@ -25,8 +26,8 @@ type ModelLoader struct { gptmodels map[string]*gptj.GPTJ gpt2models map[string]*gpt2.GPT2 gptstablelmmodels map[string]*gpt2.StableLM - - promptsTemplates map[string]*template.Template + rwkv map[string]*rwkv.RwkvState + promptsTemplates map[string]*template.Template } func NewModelLoader(modelPath string) *ModelLoader { @@ -36,6 +37,7 @@ func NewModelLoader(modelPath string) *ModelLoader { gptmodels: make(map[string]*gptj.GPTJ), gptstablelmmodels: make(map[string]*gpt2.StableLM), models: make(map[string]*llama.LLama), + rwkv: make(map[string]*rwkv.RwkvState), promptsTemplates: make(map[string]*template.Template), } } @@ -218,6 +220,36 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) { return model, err } +func (ml *ModelLoader) LoadRWKV(modelName, tokenFile string, threads uint32) (*rwkv.RwkvState, error) { + ml.mu.Lock() + defer ml.mu.Unlock() + + log.Debug().Msgf("Loading model name: %s", modelName) + + // Check if we already have a loaded model + if !ml.ExistsInModelPath(modelName) { + return nil, fmt.Errorf("model does not exist") + } + + if m, ok := ml.rwkv[modelName]; ok { + log.Debug().Msgf("Model already loaded in memory: %s", modelName) + return m, nil + } + + // Load the model and keep it in memory for later use + modelFile := filepath.Join(ml.ModelPath, modelName) + tokenPath := filepath.Join(ml.ModelPath, tokenFile) + log.Debug().Msgf("Loading model in memory from file: %s", modelFile) + + model := rwkv.LoadFiles(modelFile, tokenPath, threads) + if model == nil { + return nil, fmt.Errorf("could not load model") + } + + ml.rwkv[modelName] = model + return model, nil +} + func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) { ml.mu.Lock() defer ml.mu.Unlock()