feat: Update gpt4all, support multiple implementations in runtime (#472)

Signed-off-by: mudler <mudler@mocaccino.org>
renovate/github.com-imdario-mergo-1.x
Ettore Di Giacinto 2 years ago committed by GitHub
parent 42d753846e
commit 78ad4813df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      .gitignore
  2. 34
      Makefile
  3. 7
      api/api.go
  4. 2
      api/api_test.go
  5. 27
      api/backend_assets.go
  6. 16
      api/options.go
  7. 6
      assets.go
  8. 2
      go.mod
  9. 2
      go.sum
  10. 8
      main.go
  11. 51
      pkg/assets/extract.go
  12. 13
      pkg/model/initializers.go

3
.gitignore vendored

@ -25,3 +25,6 @@ release/
# just in case # just in case
.DS_Store .DS_Store
.idea .idea
# Generated during build
backend-assets/

@ -5,7 +5,7 @@ BINARY_NAME=local-ai
GOLLAMA_VERSION?=10caf37d8b73386708b4373975b8917e6b212c0e GOLLAMA_VERSION?=10caf37d8b73386708b4373975b8917e6b212c0e
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
GPT4ALL_VERSION?=337c7fecacfa4ae6779046513ab090687a5b0ef6 GPT4ALL_VERSION?=022f1cabe7dd2c911936b37510582f279069ba1e
GOGGMLTRANSFORMERS_VERSION?=13ccc22621bb21afecd38675a2b043498e2e756c GOGGMLTRANSFORMERS_VERSION?=13ccc22621bb21afecd38675a2b043498e2e756c
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=ccb05c3e1c6efd098017d114dcb58ab3262b40b2 RWKV_VERSION?=ccb05c3e1c6efd098017d114dcb58ab3262b40b2
@ -63,22 +63,13 @@ gpt4all:
git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all
cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1 cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1
# This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml.. # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml..
@find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} + @find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} + @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gptj_/g' {} + @find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gptj_/g' {} + @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.cpp" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gptj_/g' {} + @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.go" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/set_console_color/set_gptj_console_color/g' {} + @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/set_console_color/set_gptj_console_color/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.go" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.txt" -exec sed -i'' -e 's/llama_/gptjllama_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gptj_/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_gptj_replace/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_gptj_replace/g' {} +
@find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/regex_escape/gpt4allregex_escape/g' {} +
mv ./gpt4all/gpt4all-backend/llama.cpp/llama_util.h ./gpt4all/gpt4all-backend/llama.cpp/gptjllama_util.h
## BERT embeddings ## BERT embeddings
go-bert: go-bert:
@ -124,6 +115,12 @@ bloomz/libbloomz.a: bloomz
go-bert/libgobert.a: go-bert go-bert/libgobert.a: go-bert
$(MAKE) -C go-bert libgobert.a $(MAKE) -C go-bert libgobert.a
backend-assets/gpt4all: gpt4all/gpt4all-bindings/golang/libgpt4all.a
mkdir -p backend-assets/gpt4all
@cp gpt4all/gpt4all-bindings/golang/buildllm/*.so backend-assets/gpt4all/ || true
@cp gpt4all/gpt4all-bindings/golang/buildllm/*.dylib backend-assets/gpt4all/ || true
@cp gpt4all/gpt4all-bindings/golang/buildllm/*.dll backend-assets/gpt4all/ || true
gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all
$(MAKE) -C gpt4all/gpt4all-bindings/golang/ libgpt4all.a $(MAKE) -C gpt4all/gpt4all-bindings/golang/ libgpt4all.a
@ -188,7 +185,7 @@ rebuild: ## Rebuilds the project
$(MAKE) -C bloomz clean $(MAKE) -C bloomz clean
$(MAKE) build $(MAKE) build
prepare: prepare-sources gpt4all/gpt4all-bindings/golang/libgpt4all.a $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building prepare: prepare-sources backend-assets/gpt4all $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building
clean: ## Remove build related file clean: ## Remove build related file
rm -fr ./go-llama rm -fr ./go-llama
@ -196,6 +193,7 @@ clean: ## Remove build related file
rm -rf ./go-gpt2 rm -rf ./go-gpt2
rm -rf ./go-stable-diffusion rm -rf ./go-stable-diffusion
rm -rf ./go-ggml-transformers rm -rf ./go-ggml-transformers
rm -rf ./backend-assets
rm -rf ./go-rwkv rm -rf ./go-rwkv
rm -rf ./go-bert rm -rf ./go-bert
rm -rf ./bloomz rm -rf ./bloomz

@ -66,6 +66,13 @@ func App(opts ...AppOption) (*fiber.App, error) {
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
} }
} }
if options.assetsDestination != "" {
if err := PrepareBackendAssets(options.backendAssets, options.assetsDestination); err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}
// Default middleware config // Default middleware config
app.Use(recover.New()) app.Use(recover.New())

@ -257,7 +257,7 @@ var _ = Describe("API test", func() {
It("returns errors", func() { It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:")) Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 10 errors occurred:"))
}) })
It("transcribes audio", func() { It("transcribes audio", func() {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {

@ -0,0 +1,27 @@
package api
import (
"embed"
"os"
"path/filepath"
"github.com/go-skynet/LocalAI/pkg/assets"
"github.com/rs/zerolog/log"
)
// PrepareBackendAssets unpacks the embedded backend assets into dst and then
// exports the location of the gpt4all libraries through the environment.
//
// backendAssets is the embedded filesystem to unpack and dst is the directory
// handed to assets.ExtractFiles as the extraction root. After a successful
// extraction, GPT4ALL_IMPLEMENTATIONS_PATH is set to
// <dst>/backend-assets/gpt4all.
//
// It returns the extraction error, or the error reported while setting the
// environment variable; in both cases the debug line is not logged.
func PrepareBackendAssets(backendAssets embed.FS, dst string) error {
	// Extract files from the embedded FS
	if err := assets.ExtractFiles(backendAssets, dst); err != nil {
		return err
	}

	// Set GPT4ALL libs where we extracted the files
	// https://github.com/nomic-ai/gpt4all/commit/27e80e1d10985490c9fd4214e4bf458cfcf70896
	gpt4alldir := filepath.Join(dst, "backend-assets", "gpt4all")

	// Setenv can fail (for example on a value the OS rejects); reporting it
	// keeps the caller from assuming the variable is in place when it is not.
	if err := os.Setenv("GPT4ALL_IMPLEMENTATIONS_PATH", gpt4alldir); err != nil {
		return err
	}

	log.Debug().Msgf("GPT4ALL_IMPLEMENTATIONS_PATH: %s", gpt4alldir)

	return nil
}

@ -2,6 +2,7 @@ package api
import ( import (
"context" "context"
"embed"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
) )
@ -18,6 +19,9 @@ type Option struct {
preloadJSONModels string preloadJSONModels string
preloadModelsFromPath string preloadModelsFromPath string
corsAllowOrigins string corsAllowOrigins string
backendAssets embed.FS
assetsDestination string
} }
type AppOption func(*Option) type AppOption func(*Option)
@ -49,6 +53,18 @@ func WithCorsAllowOrigins(b string) AppOption {
} }
} }
// WithBackendAssetsOutput returns an AppOption that stores out in the
// Option's assetsDestination field.
func WithBackendAssetsOutput(out string) AppOption {
	setDestination := func(opts *Option) {
		opts.assetsDestination = out
	}
	return setDestination
}
// WithBackendAssets returns an AppOption that stores the embedded filesystem f
// in the Option's backendAssets field.
func WithBackendAssets(f embed.FS) AppOption {
	setAssets := func(opts *Option) {
		opts.backendAssets = f
	}
	return setAssets
}
func WithContext(ctx context.Context) AppOption { func WithContext(ctx context.Context) AppOption {
return func(o *Option) { return func(o *Option) {
o.context = ctx o.context = ctx

@ -0,0 +1,6 @@
package main
import "embed"
// backendAssets is the embedded filesystem populated by the go:embed
// directive below with the entries matched by the backend-assets/* pattern.
//go:embed backend-assets/*
var backendAssets embed.FS

@ -15,7 +15,7 @@ require (
github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-multierror v1.1.1
github.com/imdario/mergo v0.3.16 github.com/imdario/mergo v0.3.16
github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642 github.com/mudler/go-stable-diffusion v0.0.0-20230516152536-c0748eca3642
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5 github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c
github.com/onsi/ginkgo/v2 v2.9.7 github.com/onsi/ginkgo/v2 v2.9.7
github.com/onsi/gomega v1.27.7 github.com/onsi/gomega v1.27.7
github.com/otiai10/openaigo v1.1.0 github.com/otiai10/openaigo v1.1.0

@ -155,6 +155,8 @@ github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81c
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81cb54922/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230528235700-9eb81cb54922/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5 h1:99cF+V5wk7IInDAEM9HAlSHdLf/xoJR529Wr8lAG5KQ= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5 h1:99cF+V5wk7IInDAEM9HAlSHdLf/xoJR529Wr8lAG5KQ=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230531011104-5f940208e4f5/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c h1:KXYqUH6bdYbxnF67l8wayctaCZ4BQJQOsUyNke7HC0A=
github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230601151908-5175db27813c/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=

@ -80,6 +80,12 @@ func main() {
EnvVars: []string{"IMAGE_PATH"}, EnvVars: []string{"IMAGE_PATH"},
Value: "", Value: "",
}, },
&cli.StringFlag{
Name: "backend-assets-path",
DefaultText: "Path used to extract libraries that are required by some of the backends in runtime.",
EnvVars: []string{"BACKEND_ASSETS_PATH"},
Value: "/tmp/localai/backend_data",
},
&cli.IntFlag{ &cli.IntFlag{
Name: "context-size", Name: "context-size",
DefaultText: "Default context size of the model", DefaultText: "Default context size of the model",
@ -124,6 +130,8 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
api.WithCors(ctx.Bool("cors")), api.WithCors(ctx.Bool("cors")),
api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
api.WithThreads(ctx.Int("threads")), api.WithThreads(ctx.Int("threads")),
api.WithBackendAssets(backendAssets),
api.WithBackendAssetsOutput(ctx.String("backend-assets-path")),
api.WithUploadLimitMB(ctx.Int("upload-limit"))) api.WithUploadLimitMB(ctx.Int("upload-limit")))
if err != nil { if err != nil {
return err return err

@ -0,0 +1,51 @@
package assets
import (
"embed"
"fmt"
"io/fs"
"os"
"path/filepath"
)
// ExtractFiles copies every entry of the embedded filesystem content into
// extractDir, recreating the embedded directory layout underneath it.
//
// extractDir (and any missing parents) is created with mode 0755, embedded
// directories are recreated with mode 0755, and files are written with mode
// 0644, replacing the contents of files that already exist.
//
// The first failure aborts the walk. Returned errors name the path involved
// and wrap the underlying cause, so callers can inspect it with errors.Is or
// errors.As.
func ExtractFiles(content embed.FS, extractDir string) error {
	// Create the target directory if it doesn't exist
	if err := os.MkdirAll(extractDir, 0755); err != nil {
		return fmt.Errorf("failed to create directory %q: %w", extractDir, err)
	}

	// Walk through the embedded FS and extract files
	return fs.WalkDir(content, ".", func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}

		// Embedded paths are always slash-separated; convert them before
		// joining so the target path uses the host OS separator.
		targetFile := filepath.Join(extractDir, filepath.FromSlash(path))

		if d.IsDir() {
			// Reconstruct the directory structure in the target directory
			if err := os.MkdirAll(targetFile, 0755); err != nil {
				return fmt.Errorf("failed to create directory %q: %w", targetFile, err)
			}
			return nil
		}

		// Read the file from the embedded FS
		fileData, err := content.ReadFile(path)
		if err != nil {
			return fmt.Errorf("failed to read file %q: %w", path, err)
		}

		// Create the file in the target directory
		if err := os.WriteFile(targetFile, fileData, 0644); err != nil {
			return fmt.Errorf("failed to write file %q: %w", targetFile, err)
		}

		return nil
	})
}

@ -33,6 +33,7 @@ const (
Gpt4AllLlamaBackend = "gpt4all-llama" Gpt4AllLlamaBackend = "gpt4all-llama"
Gpt4AllMptBackend = "gpt4all-mpt" Gpt4AllMptBackend = "gpt4all-mpt"
Gpt4AllJBackend = "gpt4all-j" Gpt4AllJBackend = "gpt4all-j"
Gpt4All = "gpt4all"
BertEmbeddingsBackend = "bert-embeddings" BertEmbeddingsBackend = "bert-embeddings"
RwkvBackend = "rwkv" RwkvBackend = "rwkv"
WhisperBackend = "whisper" WhisperBackend = "whisper"
@ -42,9 +43,7 @@ const (
var backends []string = []string{ var backends []string = []string{
LlamaBackend, LlamaBackend,
Gpt4AllLlamaBackend, Gpt4All,
Gpt4AllMptBackend,
Gpt4AllJBackend,
RwkvBackend, RwkvBackend,
GPTNeoXBackend, GPTNeoXBackend,
WhisperBackend, WhisperBackend,
@ -153,12 +152,8 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
return ml.LoadModel(modelFile, stableDiffusion) return ml.LoadModel(modelFile, stableDiffusion)
case StarcoderBackend: case StarcoderBackend:
return ml.LoadModel(modelFile, starCoder) return ml.LoadModel(modelFile, starCoder)
case Gpt4AllLlamaBackend: case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.LLaMAType))) return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads))))
case Gpt4AllMptBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.MPTType)))
case Gpt4AllJBackend:
return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetModelType(gpt4all.GPTJType)))
case BertEmbeddingsBackend: case BertEmbeddingsBackend:
return ml.LoadModel(modelFile, bertEmbeddings) return ml.LoadModel(modelFile, bertEmbeddings)
case RwkvBackend: case RwkvBackend:

Loading…
Cancel
Save