From b816009db0e43d3bd979c598f56e9431b76a9157 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 01/12] feat: add falcon ggllm via grpc client Signed-off-by: Ettore Di Giacinto --- .gitignore | 9 +- Makefile | 36 +- api/api.go | 1 + api/localai.go | 6 +- api/openai.go | 18 +- api/prediction.go | 379 +++++++++--- cmd/grpc/falcon/main.go | 25 + go.mod | 23 +- go.sum | 162 +++--- pkg/grpc/client.go | 98 ++++ pkg/grpc/interface.go | 11 + pkg/grpc/llm/falcon/falcon.go | 136 +++++ pkg/grpc/llm/ggml/starcoder.go | 0 pkg/grpc/proto/llmserver.pb.go | 870 ++++++++++++++++++++++++++++ pkg/grpc/proto/llmserver.proto | 82 +++ pkg/grpc/proto/llmserver_grpc.pb.go | 241 ++++++++ pkg/grpc/server.go | 76 +++ pkg/model/initializers.go | 182 +++++- pkg/model/loader.go | 3 + pkg/model/options.go | 62 ++ 20 files changed, 2194 insertions(+), 226 deletions(-) create mode 100644 cmd/grpc/falcon/main.go create mode 100644 pkg/grpc/client.go create mode 100644 pkg/grpc/interface.go create mode 100644 pkg/grpc/llm/falcon/falcon.go create mode 100644 pkg/grpc/llm/ggml/starcoder.go create mode 100644 pkg/grpc/proto/llmserver.pb.go create mode 100644 pkg/grpc/proto/llmserver.proto create mode 100644 pkg/grpc/proto/llmserver_grpc.pb.go create mode 100644 pkg/grpc/server.go create mode 100644 pkg/model/options.go diff --git a/.gitignore b/.gitignore index 8ad9f22..8819ad7 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,13 @@ go-llama gpt4all go-stable-diffusion +go-piper +go-ggllm +piper + +*.a +get-sources + go-ggml-transformers go-gpt2 go-rwkv @@ -29,4 +36,4 @@ release/ # Generated during build backend-assets/ -/ggml-metal.metal \ No newline at end of file +/ggml-metal.metal diff --git a/Makefile b/Makefile index d885b94..abac2b4 100644 --- a/Makefile +++ b/Makefile @@ -41,6 +41,9 @@ BLOOMZ_VERSION?=1834e77b83faafe912ad4092ccf7f77937349e2f # stablediffusion version STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632 +# Go-ggllm +GOGGLLM_VERSION?=862477d16eefb0805261c19c9b0d053e3b2b684b + export BUILD_TYPE?= CGO_LDFLAGS?= CUDA_LIBPATH?=/usr/local/cuda/lib64/ @@ -126,6 +129,14 @@ gpt4all: @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.c" -exec sed -i'' -e 's/clear_numa_thread_affinity/gpt4all__clear_numa_thread_affinity/g' {} + @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/clear_numa_thread_affinity/gpt4all__clear_numa_thread_affinity/g' {} + +## go-ggllm +go-ggllm: + git clone --recurse-submodules https://github.com/mudler/go-ggllm.cpp go-ggllm + cd go-ggllm && git checkout -b build $(GOGGLLM_VERSION) && git submodule update --init --recursive --depth 1 + +go-ggllm/libggllm.a: go-ggllm + $(MAKE) -C go-ggllm BUILD_TYPE=$(BUILD_TYPE) libggllm.a + ## go-piper go-piper: git clone --recurse-submodules https://github.com/mudler/go-piper go-piper @@ -238,7 +249,7 @@ go-llama/libbinding.a: go-llama go-piper/libpiper_binding.a: $(MAKE) -C go-piper libpiper_binding.a example/main -get-sources: go-llama go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion +get-sources: go-llama go-ggllm go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion touch $@ replace: @@ -251,6 +262,7 @@ replace: $(GOCMD) mod edit -replace github.com/go-skynet/bloomz.cpp=$(shell pwd)/bloomz $(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(shell pwd)/go-stable-diffusion $(GOCMD) mod edit -replace github.com/mudler/go-piper=$(shell pwd)/go-piper + $(GOCMD) mod edit -replace github.com/mudler/go-ggllm.cpp=$(shell pwd)/go-ggllm prepare-sources: get-sources replace $(GOCMD) mod download @@ -267,9 +279,10 @@ rebuild: ## Rebuilds the project $(MAKE) -C go-bert clean $(MAKE) -C bloomz clean $(MAKE) -C go-piper clean + $(MAKE) -C go-ggllm clean $(MAKE) build -prepare: prepare-sources backend-assets/gpt4all $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building +prepare: prepare-sources backend-assets/gpt4all grpcs $(OPTIONAL_TARGETS) go-ggllm/libggllm.a go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building touch $@ clean: ## Remove build related file @@ -285,6 +298,7 @@ clean: ## Remove build related file rm -rf ./bloomz rm -rf ./whisper.cpp rm -rf ./go-piper + rm -rf ./go-ggllm rm -rf $(BINARY_NAME) rm -rf release/ @@ -296,7 +310,7 @@ build: prepare ## Build the project $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) $(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET}) - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./ + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./ ifeq ($(BUILD_TYPE),metal) cp go-llama/build/bin/ggml-metal.metal . endif @@ -341,3 +355,19 @@ help: ## Show this help. if (/^[a-zA-Z_-]+:.*?##.*$$/) {printf " ${YELLOW}%-20s${GREEN}%s${RESET}\n", $$1, $$2} \ else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \ }' $(MAKEFILE_LIST) + +protogen: + protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \ + pkg/grpc/proto/llmserver.proto + +## GRPC + +backend-assets/grpc: + mkdir -p backend-assets/grpc + +falcon-grpc: backend-assets/grpc + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/ + + +grpcs: falcon-grpc \ No newline at end of file diff --git a/api/api.go b/api/api.go index 543e756..1438f1f 100644 --- a/api/api.go +++ b/api/api.go @@ -75,6 +75,7 @@ func App(opts ...AppOption) (*fiber.App, error) { if options.assetsDestination != "" { // Extract files from the embedded FS err := assets.ExtractFiles(options.backendAssets, options.assetsDestination) + log.Debug().Msgf("Extracting backend assets files to %s", options.assetsDestination) if err != nil { log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) } diff --git a/api/localai.go b/api/localai.go index b719689..66eda5a 100644 --- a/api/localai.go +++ b/api/localai.go @@ -8,7 +8,6 @@ import ( model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/tts" "github.com/go-skynet/LocalAI/pkg/utils" - llama "github.com/go-skynet/go-llama.cpp" "github.com/gofiber/fiber/v2" ) @@ -42,7 +41,10 @@ func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return err } - piperModel, err := o.loader.BackendLoader(model.PiperBackend, input.Model, []llama.ModelOption{}, uint32(0), o.assetsDestination) + piperModel, err := o.loader.BackendLoader( + model.WithBackendString(model.PiperBackend), + model.WithModelFile(input.Model), + model.WithAssetDir(o.assetsDestination)) if err != nil { return err } diff --git a/api/openai.go b/api/openai.go index 77d2c8e..c39b1cc 100644 --- a/api/openai.go +++ b/api/openai.go @@ -20,7 +20,6 @@ import ( "github.com/go-skynet/LocalAI/pkg/grammar" model "github.com/go-skynet/LocalAI/pkg/model" whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" - llama "github.com/go-skynet/go-llama.cpp" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" @@ -362,6 +361,13 @@ func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { } } +func isEOS(s string) bool { + if s == "<|endoftext|>" { + return true + } + + return false +} func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { @@ -380,7 +386,9 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { } log.Debug().Msgf("Sending goroutine: %s", s) - responses <- resp + if s != "" && !isEOS(s) { + responses <- resp + } return true }) close(responses) @@ -905,7 +913,11 @@ func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { log.Debug().Msgf("Audio file copied to: %+v", dst) - whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads), o.assetsDestination) + whisperModel, err := o.loader.BackendLoader( + model.WithBackendString(model.WhisperBackend), + model.WithModelFile(config.Model), + model.WithThreads(uint32(config.Threads)), + model.WithAssetDir(o.assetsDestination)) if err != nil { return err } diff --git a/api/prediction.go b/api/prediction.go index 7daa730..b9b5710 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -1,6 +1,7 @@ package api import ( + "context" "fmt" "os" "path/filepath" @@ -9,6 +10,8 @@ import ( "sync" "github.com/donomii/go-rwkv.cpp" + "github.com/go-skynet/LocalAI/pkg/grpc" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/LocalAI/pkg/langchain" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/stablediffusion" @@ -16,6 +19,7 @@ import ( bert "github.com/go-skynet/go-bert.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp" llama "github.com/go-skynet/go-llama.cpp" + gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" ) @@ -23,6 +27,160 @@ import ( var mutexMap sync.Mutex var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) +func gRPCModelOpts(c Config) *pb.ModelOptions { + b := 512 + if c.Batch != 0 { + b = c.Batch + } + return &pb.ModelOptions{ + ContextSize: int32(c.ContextSize), + Seed: int32(c.Seed), + NBatch: int32(b), + NGPULayers: int32(c.NGPULayers), + MMap: c.MMap, + MainGPU: c.MainGPU, + TensorSplit: c.TensorSplit, + } +} + +// func defaultGGLLMOpts(c Config) []ggllm.ModelOption { +// ggllmOpts := []ggllm.ModelOption{} +// if c.ContextSize != 0 { +// ggllmOpts = append(ggllmOpts, ggllm.SetContext(c.ContextSize)) +// } +// // F16 doesn't seem to produce good output at all! +// //if c.F16 { +// // llamaOpts = append(llamaOpts, llama.EnableF16Memory) +// //} + +// if c.NGPULayers != 0 { +// ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(c.NGPULayers)) +// } + +// ggllmOpts = append(ggllmOpts, ggllm.SetMMap(c.MMap)) +// ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(c.MainGPU)) +// ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(c.TensorSplit)) +// if c.Batch != 0 { +// ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(c.Batch)) +// } else { +// ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512)) +// } + +// return ggllmOpts +// } + +func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { + promptCachePath := "" + if c.PromptCachePath != "" { + p := filepath.Join(modelPath, c.PromptCachePath) + os.MkdirAll(filepath.Dir(p), 0755) + promptCachePath = p + } + return &pb.PredictOptions{ + Temperature: float32(c.Temperature), + TopP: float32(c.TopP), + TopK: int32(c.TopK), + Tokens: int32(c.Maxtokens), + Threads: int32(c.Threads), + PromptCacheAll: c.PromptCacheAll, + PromptCacheRO: c.PromptCacheRO, + PromptCachePath: promptCachePath, + Mirostat: int32(c.Mirostat), + MirostatETA: float32(c.MirostatETA), + MirostatTAU: float32(c.MirostatTAU), + Debug: c.Debug, + StopPrompts: c.StopWords, + Repeat: int32(c.RepeatPenalty), + NKeep: int32(c.Keep), + Batch: int32(c.Batch), + IgnoreEOS: c.IgnoreEOS, + Seed: int32(c.Seed), + FrequencyPenalty: float32(c.FrequencyPenalty), + MLock: c.MMlock, + MMap: c.MMap, + MainGPU: c.MainGPU, + TensorSplit: c.TensorSplit, + TailFreeSamplingZ: float32(c.TFZ), + TypicalP: float32(c.TypicalP), + } +} + +// func buildGGLLMPredictOptions(c Config, modelPath string) []ggllm.PredictOption { +// // Generate the prediction using the language model +// predictOptions := []ggllm.PredictOption{ +// ggllm.SetTemperature(c.Temperature), +// ggllm.SetTopP(c.TopP), +// ggllm.SetTopK(c.TopK), +// ggllm.SetTokens(c.Maxtokens), +// ggllm.SetThreads(c.Threads), +// } + +// if c.PromptCacheAll { +// predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll) +// } + +// if c.PromptCacheRO { +// predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO) +// } + +// if c.PromptCachePath != "" { +// // Create parent directory +// p := filepath.Join(modelPath, c.PromptCachePath) +// os.MkdirAll(filepath.Dir(p), 0755) +// predictOptions = append(predictOptions, ggllm.SetPathPromptCache(p)) +// } + +// if c.Mirostat != 0 { +// predictOptions = append(predictOptions, ggllm.SetMirostat(c.Mirostat)) +// } + +// if c.MirostatETA != 0 { +// predictOptions = append(predictOptions, ggllm.SetMirostatETA(c.MirostatETA)) +// } + +// if c.MirostatTAU != 0 { +// predictOptions = append(predictOptions, ggllm.SetMirostatTAU(c.MirostatTAU)) +// } + +// if c.Debug { +// predictOptions = append(predictOptions, ggllm.Debug) +// } + +// predictOptions = append(predictOptions, ggllm.SetStopWords(c.StopWords...)) + +// if c.RepeatPenalty != 0 { +// predictOptions = append(predictOptions, ggllm.SetPenalty(c.RepeatPenalty)) +// } + +// if c.Keep != 0 { +// predictOptions = append(predictOptions, ggllm.SetNKeep(c.Keep)) +// } + +// if c.Batch != 0 { +// predictOptions = append(predictOptions, ggllm.SetBatch(c.Batch)) +// } + +// if c.IgnoreEOS { +// predictOptions = append(predictOptions, ggllm.IgnoreEOS) +// } + +// if c.Seed != 0 { +// predictOptions = append(predictOptions, ggllm.SetSeed(c.Seed)) +// } + +// //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) + +// predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(c.FrequencyPenalty)) +// predictOptions = append(predictOptions, ggllm.SetMlock(c.MMlock)) +// predictOptions = append(predictOptions, ggllm.SetMemoryMap(c.MMap)) +// predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(c.MainGPU)) +// predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(c.TensorSplit)) +// predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(c.TFZ)) +// predictOptions = append(predictOptions, ggllm.SetTypicalP(c.TypicalP)) + +// return predictOptions +// } + func defaultLLamaOpts(c Config) []llama.ModelOption { llamaOpts := []llama.ModelOption{} if c.ContextSize != 0 { @@ -59,11 +217,99 @@ func defaultLLamaOpts(c Config) []llama.ModelOption { return llamaOpts } +func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption { + // Generate the prediction using the language model + predictOptions := []llama.PredictOption{ + llama.SetTemperature(c.Temperature), + llama.SetTopP(c.TopP), + llama.SetTopK(c.TopK), + llama.SetTokens(c.Maxtokens), + llama.SetThreads(c.Threads), + } + + if c.PromptCacheAll { + predictOptions = append(predictOptions, llama.EnablePromptCacheAll) + } + + if c.PromptCacheRO { + predictOptions = append(predictOptions, llama.EnablePromptCacheRO) + } + + predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar)) + + if c.PromptCachePath != "" { + // Create parent directory + p := filepath.Join(modelPath, c.PromptCachePath) + os.MkdirAll(filepath.Dir(p), 0755) + predictOptions = append(predictOptions, llama.SetPathPromptCache(p)) + } + + if c.Mirostat != 0 { + predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) + } + + if c.MirostatETA != 0 { + predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) + } + + if c.MirostatTAU != 0 { + predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) + } + + if c.Debug { + predictOptions = append(predictOptions, llama.Debug) + } + + predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) + + if c.RepeatPenalty != 0 { + predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) + } + + if c.Keep != 0 { + predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) + } + + if c.F16 { + predictOptions = append(predictOptions, llama.EnableF16KV) + } + + if c.IgnoreEOS { + predictOptions = append(predictOptions, llama.IgnoreEOS) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) + } + + //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) + + predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty)) + predictOptions = append(predictOptions, llama.SetMlock(c.MMlock)) + predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap)) + predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU)) + predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit)) + predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ)) + predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP)) + + return predictOptions +} + func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { if c.Backend != model.StableDiffusionBackend { return nil, fmt.Errorf("endpoint only working with stablediffusion models") } - inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads), o.assetsDestination) + + inferenceModel, err := loader.BackendLoader( + model.WithBackendString(c.Backend), + model.WithAssetDir(o.assetsDestination), + model.WithThreads(uint32(c.Threads)), + model.WithModelFile(c.ImageGenerationAssets), + ) if err != nil { return nil, err } @@ -106,13 +352,24 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, modelFile := c.Model llamaOpts := defaultLLamaOpts(c) + grpcOpts := gRPCModelOpts(c) var inferenceModel interface{} var err error + + opts := []model.Option{ + model.WithLlamaOpts(llamaOpts...), + model.WithLoadGRPCOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.assetsDestination), + model.WithModelFile(modelFile), + } + if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) + inferenceModel, err = loader.GreedyLoader(opts...) } else { - inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) + opts = append(opts, model.WithBackendString(c.Backend)) + inferenceModel, err = loader.BackendLoader(opts...) } if err != nil { return nil, err @@ -171,100 +428,29 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, }, nil } -func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption { - // Generate the prediction using the language model - predictOptions := []llama.PredictOption{ - llama.SetTemperature(c.Temperature), - llama.SetTopP(c.TopP), - llama.SetTopK(c.TopK), - llama.SetTokens(c.Maxtokens), - llama.SetThreads(c.Threads), - } - - if c.PromptCacheAll { - predictOptions = append(predictOptions, llama.EnablePromptCacheAll) - } - - if c.PromptCacheRO { - predictOptions = append(predictOptions, llama.EnablePromptCacheRO) - } - - predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar)) - - if c.PromptCachePath != "" { - // Create parent directory - p := filepath.Join(modelPath, c.PromptCachePath) - os.MkdirAll(filepath.Dir(p), 0755) - predictOptions = append(predictOptions, llama.SetPathPromptCache(p)) - } - - if c.Mirostat != 0 { - predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) - } - - if c.MirostatETA != 0 { - predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) - } - - if c.MirostatTAU != 0 { - predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) - } - - if c.Debug { - predictOptions = append(predictOptions, llama.Debug) - } - - predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) - - if c.RepeatPenalty != 0 { - predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) - } - - if c.Keep != 0 { - predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) - } - - if c.F16 { - predictOptions = append(predictOptions, llama.EnableF16KV) - } - - if c.IgnoreEOS { - predictOptions = append(predictOptions, llama.IgnoreEOS) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) - } - - //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) - - predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty)) - predictOptions = append(predictOptions, llama.SetMlock(c.MMlock)) - predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap)) - predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU)) - predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit)) - predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ)) - predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP)) - - return predictOptions -} - func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) { supportStreams := false modelFile := c.Model llamaOpts := defaultLLamaOpts(c) + grpcOpts := gRPCModelOpts(c) var inferenceModel interface{} var err error + + opts := []model.Option{ + model.WithLlamaOpts(llamaOpts...), + model.WithLoadGRPCOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.assetsDestination), + model.WithModelFile(modelFile), + } + if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) + inferenceModel, err = loader.GreedyLoader(opts...) } else { - inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) + opts = append(opts, model.WithBackendString(c.Backend)) + inferenceModel, err = loader.BackendLoader(opts...) } if err != nil { return nil, err @@ -552,6 +738,25 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to model.SetTokenCallback(nil) return str, er } + case *grpc.Client: + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + supportStreams = true + fn = func() (string, error) { + + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + if tokenCallback != nil { + ss := "" + err := model.PredictStream(context.TODO(), opts, func(s string) { + tokenCallback(s) + ss += s + }) + return ss, err + } else { + reply, err := model.Predict(context.TODO(), opts) + return reply.Message, err + } + } case *langchain.HuggingFace: fn = func() (string, error) { diff --git a/cmd/grpc/falcon/main.go b/cmd/grpc/falcon/main.go new file mode 100644 index 0000000..9ccead4 --- /dev/null +++ b/cmd/grpc/falcon/main.go @@ -0,0 +1,25 @@ +package main + +// GRPC Falcon server + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + falcon "github.com/go-skynet/LocalAI/pkg/grpc/llm/falcon" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &falcon.LLM{}); err != nil { + panic(err) + } +} diff --git a/go.mod b/go.mod index 0f65978..1d6268c 100644 --- a/go.mod +++ b/go.mod @@ -13,20 +13,25 @@ require ( github.com/gofiber/fiber/v2 v2.47.0 github.com/google/uuid v1.3.0 github.com/hashicorp/go-multierror v1.1.1 + github.com/hpcloud/tail v1.0.0 github.com/imdario/mergo v0.3.16 github.com/json-iterator/go v1.1.12 github.com/mholt/archiver/v3 v3.5.1 + github.com/mudler/go-ggllm.cpp v0.0.0-20230708215552-a6504d5bc137 + github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230708212935-d611d107479f github.com/onsi/ginkgo/v2 v2.11.0 github.com/onsi/gomega v1.27.8 github.com/otiai10/openaigo v1.5.2 + github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/rs/zerolog v1.29.1 github.com/sashabaranov/go-openai v1.13.0 - github.com/swaggo/swag v1.16.1 github.com/tmc/langchaingo v0.0.0-20230709010448-a875e6bc0c54 github.com/urfave/cli/v2 v2.25.7 github.com/valyala/fasthttp v1.48.0 + google.golang.org/grpc v1.56.2 + google.golang.org/protobuf v1.30.0 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -34,8 +39,10 @@ require ( require ( github.com/dlclark/regexp2 v1.8.1 // indirect github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.2 // indirect github.com/klauspost/pgzip v1.2.5 // indirect + github.com/kr/text v0.2.0 // indirect github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nwaples/rardecode v1.1.0 // indirect @@ -43,33 +50,27 @@ require ( github.com/pkoukk/tiktoken-go v0.1.2 // indirect github.com/ulikunitz/xz v0.5.9 // indirect github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect + google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/fsnotify.v1 v1.4.7 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) require ( - github.com/KyleBanks/depth v1.2.1 // indirect - github.com/PuerkitoBio/purell v1.1.1 // indirect - github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/go-audio/audio v1.0.0 // indirect github.com/go-audio/riff v1.0.0 // indirect github.com/go-logr/logr v1.2.4 // indirect - github.com/go-openapi/jsonpointer v0.19.5 // indirect - github.com/go-openapi/jsonreference v0.19.6 // indirect - github.com/go-openapi/spec v0.20.4 // indirect - github.com/go-openapi/swag v0.22.3 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect - github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.16.3 // indirect - github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 - github.com/otiai10/mint v1.6.1 // indirect github.com/philhofer/fwd v1.1.2 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 81f81e7..2906f50 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,3 @@ -github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= -github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= -github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI= -github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= -github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= -github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= @@ -19,13 +13,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674 h1:G70Yf/QOCEL1v24idWnGd6rJsbqiGkJAJnMaWaolzEg= -github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s= github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27 h1:boeMTUUBtnLU8JElZJHXrsUzROJar9/t6vGOFjkrhhI= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= @@ -36,47 +29,28 @@ github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= -github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= -github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= -github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= -github.com/go-openapi/jsonreference v0.19.6 h1:UBIxjkht+AWIgYzCDSv2GN+E/togfwXUJFRTWhl2Jjs= -github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= -github.com/go-openapi/spec v0.20.4 h1:O8hJrt0UMnhHcluhIdUgCLRWyM2x7QkBXRvOs7m+O1M= -github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= -github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= -github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= -github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= -github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa h1:gxr68r/6EWroay4iI81jxqGCDbKotY4+CiwdUkBz2NQ= -github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= -github.com/go-skynet/go-bert.cpp v0.0.0-20230607105116-6069103f54b9 h1:wRGbDwNwPmSzoXVw/HLzXY4blpRvPWg7QW2OA0WKezA= -github.com/go-skynet/go-bert.cpp v0.0.0-20230607105116-6069103f54b9/go.mod h1:pXKCpYYXujMeAvgJHU6WoMfvYbr84563+J8+Ebkyr5U= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230617123349-32b9223ccdb1 h1:jVGgzDSfpjD/0jl/ChpGI+O4EHSAeeU6DK7IyhH8PK8= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230617123349-32b9223ccdb1/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230620192816-a459d2726792 h1:rozZ9gWGzq0ZhBsNCWqfLTRCebaxwTsxLMnflwe6rDU= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230620192816-a459d2726792/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230626202628-8e31841dcddc h1:SrNxH4U8W6cqurbxpXxm9rzifeDsCgecRT73kT0BRq0= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230626202628-8e31841dcddc/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230630204211-3fec197a1dc4 h1:LScGc8yWTS9wbS2RTOq6s+waeHElLIQDJg2SUCwrO3E= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230630204211-3fec197a1dc4/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-llama.cpp v0.0.0-20230616223721-7ad833b67070 h1:T771FjB1yQw8j4P5x4ayFrUPNTglzxRIqDjaNkMVIME= -github.com/go-skynet/go-llama.cpp v0.0.0-20230616223721-7ad833b67070/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230626215901-f104111358e8 h1:Knh5QUvI/68erb/yWtrVa/3hvoQdENF2dH0hL2HNPrI= -github.com/go-skynet/go-llama.cpp v0.0.0-20230626215901-f104111358e8/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230627195533-582753605210 h1:9bm+vsiR3UI7xlU0G0cMU2Swq78RysoFVkSONvrujF8= -github.com/go-skynet/go-llama.cpp v0.0.0-20230627195533-582753605210/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230628194133-42ba44838369 h1:lSX1NWzRvRS2MlACvyvVVUnqXhKiuMAoN3DO5TbCe8M= -github.com/go-skynet/go-llama.cpp v0.0.0-20230628194133-42ba44838369/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230703203849-ffa57fbc3a12 h1:cfGZiZana0gPD0i8nmyOGTUQGb4N8PYqaBqhhukREPc= -github.com/go-skynet/go-llama.cpp v0.0.0-20230703203849-ffa57fbc3a12/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofiber/fiber/v2 v2.47.0 h1:EN5lHVCc+Pyqh5OEsk8fzRiifgwpbrP0rulQ4iNf3fs= github.com/gofiber/fiber/v2 v2.47.0/go.mod h1:mbFMVN1lQuzziTkkakgtKKdjfsXSw9BKR5lmcNksUoU= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw= github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -89,11 +63,11 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= -github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= @@ -103,16 +77,12 @@ github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQs github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/pgzip v1.2.5 h1:qnWYvvKqedOF2ulHpMG72XQol4ILEJ8k2wwRl/Km8oE= github.com/klauspost/pgzip v1.2.5/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= -github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -128,33 +98,29 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU= -github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= -github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks= -github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230620230702-09ae04cee90c h1:axNtjd5k6Xs4Ck7B7VRRQu6q5lQzTsjdWmaJkDADopU= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230620230702-09ae04cee90c/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230628182915-a67f8132e165 h1:zcnIdoSeLueTDxUD2A1qnyaSp8uh0Ay7OgHeBwpxSeg= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230628182915-a67f8132e165/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230708212935-d611d107479f h1:FtXRIjsBvoBQ5xmA26QbzyG4RjV2U5lOpUgP4npITOM= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230708212935-d611d107479f/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= +github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d h1:/lAg9vPAAU+s35cDMCx1IyeMn+4OYfCBPqi08Q8vXDg= +github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d/go.mod h1:HGGAOJhipApckwNV8ZTliRJqxctUv3xRY+zbQEwuytc= github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ= github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU= github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc= github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ= -github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= -github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= -github.com/otiai10/mint v1.6.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= -github.com/otiai10/openaigo v1.2.0 h1:Whq+uvgqw8NdIsVdixtBKCAI6OdfCJiGPlhUnYJQ6Ag= -github.com/otiai10/openaigo v1.2.0/go.mod h1:792bx6AWTS61weDi2EzKpHHnTF4eDMAlJ5GvAk/mgPg= -github.com/otiai10/openaigo v1.4.0 h1:BeacKb2Q5bVejjOKHFJxL2WFYal3QxwkrKtKuoU5LNU= -github.com/otiai10/openaigo v1.4.0/go.mod h1:kIaXc3V+Xy5JLplcBxehVyGYDtufHp3PFPy04jOwOAI= +github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg= github.com/otiai10/openaigo v1.5.2 h1:YnNDisZmA4syArF3IxMCIrfgZOq30PLV219gPY7n2z8= github.com/otiai10/openaigo v1.5.2/go.mod h1:kIaXc3V+Xy5JLplcBxehVyGYDtufHp3PFPy04jOwOAI= +github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= +github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw= github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0= @@ -172,8 +138,6 @@ github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc= -github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.13.0 h1:EAusFfnhaMaaUspUZ2+MbB/ZcVeD4epJmTOlZ+8AcAE= github.com/sashabaranov/go-openai v1.13.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= @@ -181,26 +145,14 @@ github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94/go.mod h1:90zrgN3 github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/swaggo/swag v1.16.1 h1:fTNRhKstPKxcnoKsytm4sahr8FaYzUcT7i1/3nd/fBg= -github.com/swaggo/swag v1.16.1/go.mod h1:9/LMvHycG3NFHfR6LwvikHv5iFvmPADQ359cKikGxto= github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= -github.com/tmc/langchaingo v0.0.0-20230616220619-1b3da4433944 h1:EE9fvNENTdRc/yI/1zAs7VFbmDk6JZ7EbBIFl+TsCm0= -github.com/tmc/langchaingo v0.0.0-20230616220619-1b3da4433944/go.mod h1:6l1WoyqVDwkv7cFlY3gfcTv8yVowVyuutKv8PGlQCWI= -github.com/tmc/langchaingo v0.0.0-20230625081011-4d9d55dbcaba h1:NpAI9C0y9T4jwP7XFShwYJKGf/ggyCgZEtL/7lLRPwE= -github.com/tmc/langchaingo v0.0.0-20230625081011-4d9d55dbcaba/go.mod h1:tz9cjA9BW8/lWx/T5njr3ZLHK/dfPyr/0ICSMThmY2g= -github.com/tmc/langchaingo v0.0.0-20230625234550-7ea734523e39 h1:SpOEFXx5xXLypFnwNRQj7yOC3rMvSylGA5BQW/FAwYc= -github.com/tmc/langchaingo v0.0.0-20230625234550-7ea734523e39/go.mod h1:tz9cjA9BW8/lWx/T5njr3ZLHK/dfPyr/0ICSMThmY2g= -github.com/tmc/langchaingo v0.0.0-20230627220614-633853b5ac3b h1:xUxtya/3KRDn1rcCVZucp2KhjdqSZat9j0hOshSVh2Q= -github.com/tmc/langchaingo v0.0.0-20230627220614-633853b5ac3b/go.mod h1:F1k7uRBLM8jMMEPV3dVtWVNc+W91nxOBRKbJWM/LwpM= -github.com/tmc/langchaingo v0.0.0-20230628165432-e510561c17f9 h1:BooyHg3f058lrPcTLdfC7HTfjO5OGZAgwciQJ5e85l0= -github.com/tmc/langchaingo v0.0.0-20230628165432-e510561c17f9/go.mod h1:F1k7uRBLM8jMMEPV3dVtWVNc+W91nxOBRKbJWM/LwpM= github.com/tmc/langchaingo v0.0.0-20230709010448-a875e6bc0c54 h1:MZSC3/pdBzkoPG49uTRvtEepOQKdbdgaT1aLtaEwxx4= github.com/tmc/langchaingo v0.0.0-20230709010448-a875e6bc0c54/go.mod h1:RsMJqgUynOtr2jWNhUF41R3j6SDkKq9c8UfE0nJYBb4= github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= @@ -228,25 +180,34 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -270,6 +231,7 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM= @@ -278,15 +240,33 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= +google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= +google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI= +google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/op/go-logging.v1 v1.0.0-20160211212156-b2cb9fa56473/go.mod h1:N1eN2tsCx0Ydtgjl4cqmbRCsY4/+z4cYDeqwZTk6zog= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go new file mode 100644 index 0000000..f63a89a --- /dev/null +++ b/pkg/grpc/client.go @@ -0,0 +1,98 @@ +package grpc + +import ( + "context" + "fmt" + "io" + "time" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type Client struct { + address string +} + +func NewClient(address string) *Client { + return &Client{ + address: address, + } +} + +func (c *Client) HealthCheck(ctx context.Context) bool { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + fmt.Println(err) + return false + } + defer conn.Close() + client := pb.NewLLMClient(conn) + + // The healthcheck call shouldn't take long time + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + res, err := client.Health(ctx, &pb.HealthMessage{}) + if err != nil { + fmt.Println(err) + + return false + } + + if res.Message == "OK" { + return true + } + return false +} + +func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewLLMClient(conn) + + return client.Predict(ctx, in, opts...) +} + +func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewLLMClient(conn) + return client.LoadModel(ctx, in, opts...) +} + +func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s string), opts ...grpc.CallOption) error { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return err + } + defer conn.Close() + client := pb.NewLLMClient(conn) + + stream, err := client.PredictStream(ctx, in, opts...) + if err != nil { + return err + } + + for { + feature, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + fmt.Println("Error", err) + + return err + } + f(feature.GetMessage()) + } + + return nil +} diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go new file mode 100644 index 0000000..8ac851a --- /dev/null +++ b/pkg/grpc/interface.go @@ -0,0 +1,11 @@ +package grpc + +import ( + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" +) + +type LLM interface { + Predict(*pb.PredictOptions) (string, error) + PredictStream(*pb.PredictOptions, chan string) + Load(*pb.ModelOptions) error +} diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/grpc/llm/falcon/falcon.go new file mode 100644 index 0000000..a0a53be --- /dev/null +++ b/pkg/grpc/llm/falcon/falcon.go @@ -0,0 +1,136 @@ +package falcon + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + ggllm "github.com/mudler/go-ggllm.cpp" +) + +type LLM struct { + falcon *ggllm.Falcon +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + ggllmOpts := []ggllm.ModelOption{} + if opts.ContextSize != 0 { + ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize))) + } + // F16 doesn't seem to produce good output at all! + //if c.F16 { + // llamaOpts = append(llamaOpts, llama.EnableF16Memory) + //} + + if opts.NGPULayers != 0 { + ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(int(opts.NGPULayers))) + } + + ggllmOpts = append(ggllmOpts, ggllm.SetMMap(opts.MMap)) + ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(opts.MainGPU)) + ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(opts.TensorSplit)) + if opts.NBatch != 0 { + ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(int(opts.NBatch))) + } else { + ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512)) + } + + model, err := ggllm.New(opts.Model, ggllmOpts...) + llm.falcon = model + return err +} + +func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { + predictOptions := []ggllm.PredictOption{ + ggllm.SetTemperature(float64(opts.Temperature)), + ggllm.SetTopP(float64(opts.TopP)), + ggllm.SetTopK(int(opts.TopK)), + ggllm.SetTokens(int(opts.Tokens)), + ggllm.SetThreads(int(opts.Threads)), + } + + if opts.PromptCacheAll { + predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll) + } + + if opts.PromptCacheRO { + predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO) + } + + // Expected absolute path + if opts.PromptCachePath != "" { + predictOptions = append(predictOptions, ggllm.SetPathPromptCache(opts.PromptCachePath)) + } + + if opts.Mirostat != 0 { + predictOptions = append(predictOptions, ggllm.SetMirostat(int(opts.Mirostat))) + } + + if opts.MirostatETA != 0 { + predictOptions = append(predictOptions, ggllm.SetMirostatETA(float64(opts.MirostatETA))) + } + + if opts.MirostatTAU != 0 { + predictOptions = append(predictOptions, ggllm.SetMirostatTAU(float64(opts.MirostatTAU))) + } + + if opts.Debug { + predictOptions = append(predictOptions, ggllm.Debug) + } + + predictOptions = append(predictOptions, ggllm.SetStopWords(opts.StopPrompts...)) + + if opts.PresencePenalty != 0 { + predictOptions = append(predictOptions, ggllm.SetPenalty(float64(opts.PresencePenalty))) + } + + if opts.NKeep != 0 { + predictOptions = append(predictOptions, ggllm.SetNKeep(int(opts.NKeep))) + } + + if opts.Batch != 0 { + predictOptions = append(predictOptions, ggllm.SetBatch(int(opts.Batch))) + } + + if opts.IgnoreEOS { + predictOptions = append(predictOptions, ggllm.IgnoreEOS) + } + + if opts.Seed != 0 { + predictOptions = append(predictOptions, ggllm.SetSeed(int(opts.Seed))) + } + + //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) + + predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(float64(opts.FrequencyPenalty))) + predictOptions = append(predictOptions, ggllm.SetMlock(opts.MLock)) + predictOptions = append(predictOptions, ggllm.SetMemoryMap(opts.MMap)) + predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(opts.MainGPU)) + predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(opts.TensorSplit)) + predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ))) + predictOptions = append(predictOptions, ggllm.SetTypicalP(float64(opts.TypicalP))) + return predictOptions +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { + predictOptions := buildPredictOptions(opts) + + predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { + results <- token + return true + })) + + go func() { + _, err := llm.falcon.Predict(opts.Prompt, predictOptions...) + if err != nil { + fmt.Println("err: ", err) + } + close(results) + }() +} diff --git a/pkg/grpc/llm/ggml/starcoder.go b/pkg/grpc/llm/ggml/starcoder.go new file mode 100644 index 0000000..e69de29 diff --git a/pkg/grpc/proto/llmserver.pb.go b/pkg/grpc/proto/llmserver.pb.go new file mode 100644 index 0000000..067c3a1 --- /dev/null +++ b/pkg/grpc/proto/llmserver.pb.go @@ -0,0 +1,870 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.15.8 +// source: pkg/grpc/proto/llmserver.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type HealthMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *HealthMessage) Reset() { + *x = HealthMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HealthMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthMessage) ProtoMessage() {} + +func (x *HealthMessage) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthMessage.ProtoReflect.Descriptor instead. +func (*HealthMessage) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{0} +} + +// The request message containing the user's name. +type PredictOptions struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Prompt string `protobuf:"bytes,1,opt,name=Prompt,proto3" json:"Prompt,omitempty"` + Seed int32 `protobuf:"varint,2,opt,name=Seed,proto3" json:"Seed,omitempty"` + Threads int32 `protobuf:"varint,3,opt,name=Threads,proto3" json:"Threads,omitempty"` + Tokens int32 `protobuf:"varint,4,opt,name=Tokens,proto3" json:"Tokens,omitempty"` + TopK int32 `protobuf:"varint,5,opt,name=TopK,proto3" json:"TopK,omitempty"` + Repeat int32 `protobuf:"varint,6,opt,name=Repeat,proto3" json:"Repeat,omitempty"` + Batch int32 `protobuf:"varint,7,opt,name=Batch,proto3" json:"Batch,omitempty"` + NKeep int32 `protobuf:"varint,8,opt,name=NKeep,proto3" json:"NKeep,omitempty"` + Temperature float32 `protobuf:"fixed32,9,opt,name=Temperature,proto3" json:"Temperature,omitempty"` + Penalty float32 `protobuf:"fixed32,10,opt,name=Penalty,proto3" json:"Penalty,omitempty"` + F16KV bool `protobuf:"varint,11,opt,name=F16KV,proto3" json:"F16KV,omitempty"` + DebugMode bool `protobuf:"varint,12,opt,name=DebugMode,proto3" json:"DebugMode,omitempty"` + StopPrompts []string `protobuf:"bytes,13,rep,name=StopPrompts,proto3" json:"StopPrompts,omitempty"` + IgnoreEOS bool `protobuf:"varint,14,opt,name=IgnoreEOS,proto3" json:"IgnoreEOS,omitempty"` + TailFreeSamplingZ float32 `protobuf:"fixed32,15,opt,name=TailFreeSamplingZ,proto3" json:"TailFreeSamplingZ,omitempty"` + TypicalP float32 `protobuf:"fixed32,16,opt,name=TypicalP,proto3" json:"TypicalP,omitempty"` + FrequencyPenalty float32 `protobuf:"fixed32,17,opt,name=FrequencyPenalty,proto3" json:"FrequencyPenalty,omitempty"` + PresencePenalty float32 `protobuf:"fixed32,18,opt,name=PresencePenalty,proto3" json:"PresencePenalty,omitempty"` + Mirostat int32 `protobuf:"varint,19,opt,name=Mirostat,proto3" json:"Mirostat,omitempty"` + MirostatETA float32 `protobuf:"fixed32,20,opt,name=MirostatETA,proto3" json:"MirostatETA,omitempty"` + MirostatTAU float32 `protobuf:"fixed32,21,opt,name=MirostatTAU,proto3" json:"MirostatTAU,omitempty"` + PenalizeNL bool `protobuf:"varint,22,opt,name=PenalizeNL,proto3" json:"PenalizeNL,omitempty"` + LogitBias string `protobuf:"bytes,23,opt,name=LogitBias,proto3" json:"LogitBias,omitempty"` + PathPromptCache string `protobuf:"bytes,24,opt,name=PathPromptCache,proto3" json:"PathPromptCache,omitempty"` + MLock bool `protobuf:"varint,25,opt,name=MLock,proto3" json:"MLock,omitempty"` + MMap bool `protobuf:"varint,26,opt,name=MMap,proto3" json:"MMap,omitempty"` + PromptCacheAll bool `protobuf:"varint,27,opt,name=PromptCacheAll,proto3" json:"PromptCacheAll,omitempty"` + PromptCacheRO bool `protobuf:"varint,28,opt,name=PromptCacheRO,proto3" json:"PromptCacheRO,omitempty"` + Grammar string `protobuf:"bytes,29,opt,name=Grammar,proto3" json:"Grammar,omitempty"` + MainGPU string `protobuf:"bytes,30,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` + TensorSplit string `protobuf:"bytes,31,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` + TopP float32 `protobuf:"fixed32,32,opt,name=TopP,proto3" json:"TopP,omitempty"` + PromptCachePath string `protobuf:"bytes,33,opt,name=PromptCachePath,proto3" json:"PromptCachePath,omitempty"` + Debug bool `protobuf:"varint,34,opt,name=Debug,proto3" json:"Debug,omitempty"` +} + +func (x *PredictOptions) Reset() { + *x = PredictOptions{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PredictOptions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PredictOptions) ProtoMessage() {} + +func (x *PredictOptions) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PredictOptions.ProtoReflect.Descriptor instead. +func (*PredictOptions) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{1} +} + +func (x *PredictOptions) GetPrompt() string { + if x != nil { + return x.Prompt + } + return "" +} + +func (x *PredictOptions) GetSeed() int32 { + if x != nil { + return x.Seed + } + return 0 +} + +func (x *PredictOptions) GetThreads() int32 { + if x != nil { + return x.Threads + } + return 0 +} + +func (x *PredictOptions) GetTokens() int32 { + if x != nil { + return x.Tokens + } + return 0 +} + +func (x *PredictOptions) GetTopK() int32 { + if x != nil { + return x.TopK + } + return 0 +} + +func (x *PredictOptions) GetRepeat() int32 { + if x != nil { + return x.Repeat + } + return 0 +} + +func (x *PredictOptions) GetBatch() int32 { + if x != nil { + return x.Batch + } + return 0 +} + +func (x *PredictOptions) GetNKeep() int32 { + if x != nil { + return x.NKeep + } + return 0 +} + +func (x *PredictOptions) GetTemperature() float32 { + if x != nil { + return x.Temperature + } + return 0 +} + +func (x *PredictOptions) GetPenalty() float32 { + if x != nil { + return x.Penalty + } + return 0 +} + +func (x *PredictOptions) GetF16KV() bool { + if x != nil { + return x.F16KV + } + return false +} + +func (x *PredictOptions) GetDebugMode() bool { + if x != nil { + return x.DebugMode + } + return false +} + +func (x *PredictOptions) GetStopPrompts() []string { + if x != nil { + return x.StopPrompts + } + return nil +} + +func (x *PredictOptions) GetIgnoreEOS() bool { + if x != nil { + return x.IgnoreEOS + } + return false +} + +func (x *PredictOptions) GetTailFreeSamplingZ() float32 { + if x != nil { + return x.TailFreeSamplingZ + } + return 0 +} + +func (x *PredictOptions) GetTypicalP() float32 { + if x != nil { + return x.TypicalP + } + return 0 +} + +func (x *PredictOptions) GetFrequencyPenalty() float32 { + if x != nil { + return x.FrequencyPenalty + } + return 0 +} + +func (x *PredictOptions) GetPresencePenalty() float32 { + if x != nil { + return x.PresencePenalty + } + return 0 +} + +func (x *PredictOptions) GetMirostat() int32 { + if x != nil { + return x.Mirostat + } + return 0 +} + +func (x *PredictOptions) GetMirostatETA() float32 { + if x != nil { + return x.MirostatETA + } + return 0 +} + +func (x *PredictOptions) GetMirostatTAU() float32 { + if x != nil { + return x.MirostatTAU + } + return 0 +} + +func (x *PredictOptions) GetPenalizeNL() bool { + if x != nil { + return x.PenalizeNL + } + return false +} + +func (x *PredictOptions) GetLogitBias() string { + if x != nil { + return x.LogitBias + } + return "" +} + +func (x *PredictOptions) GetPathPromptCache() string { + if x != nil { + return x.PathPromptCache + } + return "" +} + +func (x *PredictOptions) GetMLock() bool { + if x != nil { + return x.MLock + } + return false +} + +func (x *PredictOptions) GetMMap() bool { + if x != nil { + return x.MMap + } + return false +} + +func (x *PredictOptions) GetPromptCacheAll() bool { + if x != nil { + return x.PromptCacheAll + } + return false +} + +func (x *PredictOptions) GetPromptCacheRO() bool { + if x != nil { + return x.PromptCacheRO + } + return false +} + +func (x *PredictOptions) GetGrammar() string { + if x != nil { + return x.Grammar + } + return "" +} + +func (x *PredictOptions) GetMainGPU() string { + if x != nil { + return x.MainGPU + } + return "" +} + +func (x *PredictOptions) GetTensorSplit() string { + if x != nil { + return x.TensorSplit + } + return "" +} + +func (x *PredictOptions) GetTopP() float32 { + if x != nil { + return x.TopP + } + return 0 +} + +func (x *PredictOptions) GetPromptCachePath() string { + if x != nil { + return x.PromptCachePath + } + return "" +} + +func (x *PredictOptions) GetDebug() bool { + if x != nil { + return x.Debug + } + return false +} + +// The response message containing the result +type Reply struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *Reply) Reset() { + *x = Reply{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Reply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Reply) ProtoMessage() {} + +func (x *Reply) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Reply.ProtoReflect.Descriptor instead. +func (*Reply) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{2} +} + +func (x *Reply) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type ModelOptions struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Model string `protobuf:"bytes,1,opt,name=Model,proto3" json:"Model,omitempty"` + ContextSize int32 `protobuf:"varint,2,opt,name=ContextSize,proto3" json:"ContextSize,omitempty"` + Seed int32 `protobuf:"varint,3,opt,name=Seed,proto3" json:"Seed,omitempty"` + NBatch int32 `protobuf:"varint,4,opt,name=NBatch,proto3" json:"NBatch,omitempty"` + F16Memory bool `protobuf:"varint,5,opt,name=F16Memory,proto3" json:"F16Memory,omitempty"` + MLock bool `protobuf:"varint,6,opt,name=MLock,proto3" json:"MLock,omitempty"` + MMap bool `protobuf:"varint,7,opt,name=MMap,proto3" json:"MMap,omitempty"` + VocabOnly bool `protobuf:"varint,8,opt,name=VocabOnly,proto3" json:"VocabOnly,omitempty"` + LowVRAM bool `protobuf:"varint,9,opt,name=LowVRAM,proto3" json:"LowVRAM,omitempty"` + Embeddings bool `protobuf:"varint,10,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` + NUMA bool `protobuf:"varint,11,opt,name=NUMA,proto3" json:"NUMA,omitempty"` + NGPULayers int32 `protobuf:"varint,12,opt,name=NGPULayers,proto3" json:"NGPULayers,omitempty"` + MainGPU string `protobuf:"bytes,13,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` + TensorSplit string `protobuf:"bytes,14,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` +} + +func (x *ModelOptions) Reset() { + *x = ModelOptions{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ModelOptions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ModelOptions) ProtoMessage() {} + +func (x *ModelOptions) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ModelOptions.ProtoReflect.Descriptor instead. +func (*ModelOptions) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{3} +} + +func (x *ModelOptions) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +func (x *ModelOptions) GetContextSize() int32 { + if x != nil { + return x.ContextSize + } + return 0 +} + +func (x *ModelOptions) GetSeed() int32 { + if x != nil { + return x.Seed + } + return 0 +} + +func (x *ModelOptions) GetNBatch() int32 { + if x != nil { + return x.NBatch + } + return 0 +} + +func (x *ModelOptions) GetF16Memory() bool { + if x != nil { + return x.F16Memory + } + return false +} + +func (x *ModelOptions) GetMLock() bool { + if x != nil { + return x.MLock + } + return false +} + +func (x *ModelOptions) GetMMap() bool { + if x != nil { + return x.MMap + } + return false +} + +func (x *ModelOptions) GetVocabOnly() bool { + if x != nil { + return x.VocabOnly + } + return false +} + +func (x *ModelOptions) GetLowVRAM() bool { + if x != nil { + return x.LowVRAM + } + return false +} + +func (x *ModelOptions) GetEmbeddings() bool { + if x != nil { + return x.Embeddings + } + return false +} + +func (x *ModelOptions) GetNUMA() bool { + if x != nil { + return x.NUMA + } + return false +} + +func (x *ModelOptions) GetNGPULayers() int32 { + if x != nil { + return x.NGPULayers + } + return 0 +} + +func (x *ModelOptions) GetMainGPU() string { + if x != nil { + return x.MainGPU + } + return "" +} + +func (x *ModelOptions) GetTensorSplit() string { + if x != nil { + return x.TensorSplit + } + return "" +} + +type Result struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` + Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` +} + +func (x *Result) Reset() { + *x = Result{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Result) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Result) ProtoMessage() {} + +func (x *Result) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Result.ProtoReflect.Descriptor instead. +func (*Result) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{4} +} + +func (x *Result) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *Result) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +var File_pkg_grpc_proto_llmserver_proto protoreflect.FileDescriptor + +var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ + 0x0a, 0x1e, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2f, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x12, 0x03, 0x6c, 0x6c, 0x6d, 0x22, 0x0f, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x80, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69, + 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12, + 0x16, 0x0a, 0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, 0x4b, 0x18, + 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x4b, 0x12, 0x16, 0x0a, 0x06, 0x52, + 0x65, 0x70, 0x65, 0x61, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x52, 0x65, 0x70, + 0x65, 0x61, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x18, 0x07, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x4b, 0x65, + 0x65, 0x70, 0x18, 0x08, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x4e, 0x4b, 0x65, 0x65, 0x70, 0x12, + 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x09, + 0x20, 0x01, 0x28, 0x02, 0x52, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72, + 0x65, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x0a, 0x20, 0x01, + 0x28, 0x02, 0x52, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x46, + 0x31, 0x36, 0x4b, 0x56, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x46, 0x31, 0x36, 0x4b, + 0x56, 0x12, 0x1c, 0x0a, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x18, 0x0c, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x12, + 0x20, 0x0a, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x73, 0x18, 0x0d, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, + 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, 0x53, 0x18, 0x0e, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, 0x53, 0x12, + 0x2c, 0x0a, 0x11, 0x54, 0x61, 0x69, 0x6c, 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c, + 0x69, 0x6e, 0x67, 0x5a, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x02, 0x52, 0x11, 0x54, 0x61, 0x69, 0x6c, + 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x5a, 0x12, 0x1a, 0x0a, + 0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x18, 0x10, 0x20, 0x01, 0x28, 0x02, 0x52, + 0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x12, 0x2a, 0x0a, 0x10, 0x46, 0x72, 0x65, + 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x11, 0x20, + 0x01, 0x28, 0x02, 0x52, 0x10, 0x46, 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x50, 0x65, + 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x63, + 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x12, 0x20, 0x01, 0x28, 0x02, 0x52, 0x0f, + 0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x63, 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, + 0x1a, 0x0a, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x18, 0x13, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x12, 0x20, 0x0a, 0x0b, 0x4d, + 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x18, 0x14, 0x20, 0x01, 0x28, 0x02, + 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x12, 0x20, 0x0a, + 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, 0x55, 0x18, 0x15, 0x20, 0x01, + 0x28, 0x02, 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, 0x55, 0x12, + 0x1e, 0x0a, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x18, 0x16, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x12, + 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x18, 0x17, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x28, 0x0a, + 0x0f, 0x50, 0x61, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, + 0x18, 0x18, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x6d, + 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, + 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, + 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, + 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, + 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12, 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, + 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, + 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, + 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, + 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, + 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, + 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x18, 0x21, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, + 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x18, 0x22, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70, + 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x82, 0x03, 0x0a, + 0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x14, 0x0a, + 0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4d, 0x6f, + 0x64, 0x65, 0x6c, 0x12, 0x20, 0x0a, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x53, 0x69, + 0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, + 0x74, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x42, 0x61, + 0x74, 0x63, 0x68, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x4e, 0x42, 0x61, 0x74, 0x63, + 0x68, 0x12, 0x1c, 0x0a, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x12, + 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, + 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x07, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x1c, 0x0a, 0x09, 0x56, 0x6f, 0x63, + 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x56, 0x6f, + 0x63, 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x4c, 0x6f, 0x77, 0x56, 0x52, + 0x41, 0x4d, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x4c, 0x6f, 0x77, 0x56, 0x52, 0x41, + 0x4d, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, + 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, + 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x12, 0x1e, 0x0a, 0x0a, 0x4e, 0x47, 0x50, 0x55, 0x4c, 0x61, 0x79, + 0x65, 0x72, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x4e, 0x47, 0x50, 0x55, 0x4c, + 0x61, 0x79, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, + 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, + 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x0e, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, + 0x74, 0x22, 0x3c, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, + 0xc4, 0x01, 0x0a, 0x03, 0x4c, 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x12, 0x12, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, + 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, + 0x79, 0x22, 0x00, 0x12, 0x2c, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x13, + 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, + 0x00, 0x12, 0x2d, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x11, + 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x1a, 0x0b, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, + 0x12, 0x34, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, + 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, + 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x42, 0x57, 0x0a, 0x1b, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, + 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x6c, 0x6c, 0x6d, 0x73, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, + 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, + 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_grpc_proto_llmserver_proto_rawDescOnce sync.Once + file_pkg_grpc_proto_llmserver_proto_rawDescData = file_pkg_grpc_proto_llmserver_proto_rawDesc +) + +func file_pkg_grpc_proto_llmserver_proto_rawDescGZIP() []byte { + file_pkg_grpc_proto_llmserver_proto_rawDescOnce.Do(func() { + file_pkg_grpc_proto_llmserver_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_proto_llmserver_proto_rawDescData) + }) + return file_pkg_grpc_proto_llmserver_proto_rawDescData +} + +var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_pkg_grpc_proto_llmserver_proto_goTypes = []interface{}{ + (*HealthMessage)(nil), // 0: llm.HealthMessage + (*PredictOptions)(nil), // 1: llm.PredictOptions + (*Reply)(nil), // 2: llm.Reply + (*ModelOptions)(nil), // 3: llm.ModelOptions + (*Result)(nil), // 4: llm.Result +} +var file_pkg_grpc_proto_llmserver_proto_depIdxs = []int32{ + 0, // 0: llm.LLM.Health:input_type -> llm.HealthMessage + 1, // 1: llm.LLM.Predict:input_type -> llm.PredictOptions + 3, // 2: llm.LLM.LoadModel:input_type -> llm.ModelOptions + 1, // 3: llm.LLM.PredictStream:input_type -> llm.PredictOptions + 2, // 4: llm.LLM.Health:output_type -> llm.Reply + 2, // 5: llm.LLM.Predict:output_type -> llm.Reply + 4, // 6: llm.LLM.LoadModel:output_type -> llm.Result + 2, // 7: llm.LLM.PredictStream:output_type -> llm.Reply + 4, // [4:8] is the sub-list for method output_type + 0, // [0:4] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_pkg_grpc_proto_llmserver_proto_init() } +func file_pkg_grpc_proto_llmserver_proto_init() { + if File_pkg_grpc_proto_llmserver_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_grpc_proto_llmserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HealthMessage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_llmserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PredictOptions); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_llmserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Reply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_llmserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ModelOptions); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_llmserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Result); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_grpc_proto_llmserver_proto_rawDesc, + NumEnums: 0, + NumMessages: 5, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_pkg_grpc_proto_llmserver_proto_goTypes, + DependencyIndexes: file_pkg_grpc_proto_llmserver_proto_depIdxs, + MessageInfos: file_pkg_grpc_proto_llmserver_proto_msgTypes, + }.Build() + File_pkg_grpc_proto_llmserver_proto = out.File + file_pkg_grpc_proto_llmserver_proto_rawDesc = nil + file_pkg_grpc_proto_llmserver_proto_goTypes = nil + file_pkg_grpc_proto_llmserver_proto_depIdxs = nil +} diff --git a/pkg/grpc/proto/llmserver.proto b/pkg/grpc/proto/llmserver.proto new file mode 100644 index 0000000..ba20806 --- /dev/null +++ b/pkg/grpc/proto/llmserver.proto @@ -0,0 +1,82 @@ +syntax = "proto3"; + +option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto"; +option java_multiple_files = true; +option java_package = "io.skynet.localai.llmserver"; +option java_outer_classname = "LLMServer"; + +package llm; + +service LLM { + rpc Health(HealthMessage) returns (Reply) {} + rpc Predict(PredictOptions) returns (Reply) {} + rpc LoadModel(ModelOptions) returns (Result) {} + rpc PredictStream(PredictOptions) returns (stream Reply) {} +} + +message HealthMessage {} + +// The request message containing the user's name. +message PredictOptions { + string Prompt = 1; + int32 Seed = 2; + int32 Threads = 3; + int32 Tokens = 4; + int32 TopK = 5; + int32 Repeat = 6; + int32 Batch = 7; + int32 NKeep = 8; + float Temperature = 9; + float Penalty = 10; + bool F16KV = 11; + bool DebugMode = 12; + repeated string StopPrompts = 13; + bool IgnoreEOS = 14; + float TailFreeSamplingZ = 15; + float TypicalP = 16; + float FrequencyPenalty = 17; + float PresencePenalty = 18; + int32 Mirostat = 19; + float MirostatETA = 20; + float MirostatTAU = 21; + bool PenalizeNL = 22; + string LogitBias = 23; + string PathPromptCache = 24; + bool MLock = 25; + bool MMap = 26; + bool PromptCacheAll = 27; + bool PromptCacheRO = 28; + string Grammar = 29; + string MainGPU = 30; + string TensorSplit = 31; + float TopP = 32; + string PromptCachePath = 33; + bool Debug = 34; +} + +// The response message containing the result +message Reply { + string message = 1; +} + +message ModelOptions { + string Model = 1; + int32 ContextSize = 2; + int32 Seed = 3; + int32 NBatch = 4; + bool F16Memory = 5; + bool MLock = 6; + bool MMap = 7; + bool VocabOnly = 8; + bool LowVRAM = 9; + bool Embeddings = 10; + bool NUMA = 11; + int32 NGPULayers = 12; + string MainGPU = 13; + string TensorSplit = 14; +} + +message Result { + string message = 1; + bool success = 2; +} \ No newline at end of file diff --git a/pkg/grpc/proto/llmserver_grpc.pb.go b/pkg/grpc/proto/llmserver_grpc.pb.go new file mode 100644 index 0000000..6cfd981 --- /dev/null +++ b/pkg/grpc/proto/llmserver_grpc.pb.go @@ -0,0 +1,241 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.15.8 +// source: pkg/grpc/proto/llmserver.proto + +package proto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// LLMClient is the client API for LLM service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type LLMClient interface { + Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) + Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) + LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) + PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error) +} + +type lLMClient struct { + cc grpc.ClientConnInterface +} + +func NewLLMClient(cc grpc.ClientConnInterface) LLMClient { + return &lLMClient{cc} +} + +func (c *lLMClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { + out := new(Reply) + err := c.cc.Invoke(ctx, "/llm.LLM/Health", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *lLMClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { + out := new(Reply) + err := c.cc.Invoke(ctx, "/llm.LLM/Predict", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *lLMClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { + out := new(Result) + err := c.cc.Invoke(ctx, "/llm.LLM/LoadModel", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *lLMClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &LLM_ServiceDesc.Streams[0], "/llm.LLM/PredictStream", opts...) + if err != nil { + return nil, err + } + x := &lLMPredictStreamClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type LLM_PredictStreamClient interface { + Recv() (*Reply, error) + grpc.ClientStream +} + +type lLMPredictStreamClient struct { + grpc.ClientStream +} + +func (x *lLMPredictStreamClient) Recv() (*Reply, error) { + m := new(Reply) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +// LLMServer is the server API for LLM service. +// All implementations must embed UnimplementedLLMServer +// for forward compatibility +type LLMServer interface { + Health(context.Context, *HealthMessage) (*Reply, error) + Predict(context.Context, *PredictOptions) (*Reply, error) + LoadModel(context.Context, *ModelOptions) (*Result, error) + PredictStream(*PredictOptions, LLM_PredictStreamServer) error + mustEmbedUnimplementedLLMServer() +} + +// UnimplementedLLMServer must be embedded to have forward compatible implementations. +type UnimplementedLLMServer struct { +} + +func (UnimplementedLLMServer) Health(context.Context, *HealthMessage) (*Reply, error) { + return nil, status.Errorf(codes.Unimplemented, "method Health not implemented") +} +func (UnimplementedLLMServer) Predict(context.Context, *PredictOptions) (*Reply, error) { + return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented") +} +func (UnimplementedLLMServer) LoadModel(context.Context, *ModelOptions) (*Result, error) { + return nil, status.Errorf(codes.Unimplemented, "method LoadModel not implemented") +} +func (UnimplementedLLMServer) PredictStream(*PredictOptions, LLM_PredictStreamServer) error { + return status.Errorf(codes.Unimplemented, "method PredictStream not implemented") +} +func (UnimplementedLLMServer) mustEmbedUnimplementedLLMServer() {} + +// UnsafeLLMServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to LLMServer will +// result in compilation errors. +type UnsafeLLMServer interface { + mustEmbedUnimplementedLLMServer() +} + +func RegisterLLMServer(s grpc.ServiceRegistrar, srv LLMServer) { + s.RegisterService(&LLM_ServiceDesc, srv) +} + +func _LLM_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HealthMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(LLMServer).Health(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/llm.LLM/Health", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(LLMServer).Health(ctx, req.(*HealthMessage)) + } + return interceptor(ctx, in, info, handler) +} + +func _LLM_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PredictOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(LLMServer).Predict(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/llm.LLM/Predict", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(LLMServer).Predict(ctx, req.(*PredictOptions)) + } + return interceptor(ctx, in, info, handler) +} + +func _LLM_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ModelOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(LLMServer).LoadModel(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/llm.LLM/LoadModel", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(LLMServer).LoadModel(ctx, req.(*ModelOptions)) + } + return interceptor(ctx, in, info, handler) +} + +func _LLM_PredictStream_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(PredictOptions) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(LLMServer).PredictStream(m, &lLMPredictStreamServer{stream}) +} + +type LLM_PredictStreamServer interface { + Send(*Reply) error + grpc.ServerStream +} + +type lLMPredictStreamServer struct { + grpc.ServerStream +} + +func (x *lLMPredictStreamServer) Send(m *Reply) error { + return x.ServerStream.SendMsg(m) +} + +// LLM_ServiceDesc is the grpc.ServiceDesc for LLM service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var LLM_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "llm.LLM", + HandlerType: (*LLMServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Health", + Handler: _LLM_Health_Handler, + }, + { + MethodName: "Predict", + Handler: _LLM_Predict_Handler, + }, + { + MethodName: "LoadModel", + Handler: _LLM_LoadModel_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "PredictStream", + Handler: _LLM_PredictStream_Handler, + ServerStreams: true, + }, + }, + Metadata: "pkg/grpc/proto/llmserver.proto", +} diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go new file mode 100644 index 0000000..d449593 --- /dev/null +++ b/pkg/grpc/server.go @@ -0,0 +1,76 @@ +package grpc + +import ( + "context" + "fmt" + "log" + "net" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "google.golang.org/grpc" +) + +// A GRPC Server that allows to run LLM inference. +// It is used by the LLMServices to expose the LLM functionalities that are called by the client. +// The GRPC Service is general, trying to encompass all the possible LLM options models. +// It depends on the real implementer then what can be done or not. +// +// The server is implemented as a GRPC service, with the following methods: +// - Predict: to run the inference with options +// - PredictStream: to run the inference with options and stream the results + +// server is used to implement helloworld.GreeterServer. +type server struct { + pb.UnimplementedLLMServer + llm LLM +} + +func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) { + return &pb.Reply{Message: "OK"}, nil +} + +func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { + err := s.llm.Load(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Loading succeeded", Success: true}, nil +} + +func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { + result, err := s.llm.Predict(in) + return &pb.Reply{Message: result}, err +} + +func (s *server) PredictStream(in *pb.PredictOptions, stream pb.LLM_PredictStreamServer) error { + + resultChan := make(chan string) + + done := make(chan bool) + go func() { + for result := range resultChan { + stream.Send(&pb.Reply{Message: result}) + } + done <- true + }() + + s.llm.PredictStream(in, resultChan) + <-done + + return nil +} + +func StartServer(address string, model LLM) error { + lis, err := net.Listen("tcp", address) + if err != nil { + return err + } + s := grpc.NewServer() + pb.RegisterLLMServer(s, &server{llm: model}) + log.Printf("gRPC Server listening at %v", lis.Addr()) + if err := s.Serve(lis); err != nil { + return err + } + + return nil +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 3849f85..5dba7ce 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -1,12 +1,16 @@ package model import ( + "context" "fmt" + "os" "path/filepath" "strings" + "time" rwkv "github.com/donomii/go-rwkv.cpp" whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + grpc "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/langchain" "github.com/go-skynet/LocalAI/pkg/stablediffusion" "github.com/go-skynet/LocalAI/pkg/tts" @@ -15,8 +19,12 @@ import ( transformers "github.com/go-skynet/go-ggml-transformers.cpp" llama "github.com/go-skynet/go-llama.cpp" "github.com/hashicorp/go-multierror" + "github.com/hpcloud/tail" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" + "github.com/phayes/freeport" "github.com/rs/zerolog/log" + + process "github.com/mudler/go-processmanager" ) const tokenizerSuffix = ".tokenizer.json" @@ -42,22 +50,24 @@ const ( StableDiffusionBackend = "stablediffusion" PiperBackend = "piper" LCHuggingFaceBackend = "langchain-huggingface" + //GGLLMFalconBackend = "falcon" ) var autoLoadBackends []string = []string{ LlamaBackend, Gpt4All, RwkvBackend, - GPTNeoXBackend, + //GGLLMFalconBackend, WhisperBackend, BertEmbeddingsBackend, + GPTNeoXBackend, GPTJBackend, Gpt2Backend, DollyBackend, - FalconBackend, MPTBackend, ReplitBackend, StarcoderBackend, + FalconBackend, BloomzBackend, } @@ -73,6 +83,12 @@ var dolly = func(modelFile string) (interface{}, error) { return transformers.NewDolly(modelFile) } +// func ggllmFalcon(opts ...ggllm.ModelOption) func(string) (interface{}, error) { +// return func(s string) (interface{}, error) { +// return ggllm.New(s, opts...) +// } +// } + var gptNeoX = func(modelFile string) (interface{}, error) { return transformers.NewGPTNeoX(modelFile) } @@ -143,55 +159,157 @@ func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) } } -func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (model interface{}, err error) { - log.Debug().Msgf("Loading model %s from %s", backendString, modelFile) - switch strings.ToLower(backendString) { +// starts the grpcModelProcess for the backend, and returns a grpc client +// It also loads the model +func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (interface{}, error) { + return func(s string) (interface{}, error) { + log.Debug().Msgf("Loading GRPC Model", backend, *o) + + grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) + + // Make sure the process is executable + if err := os.Chmod(grpcProcess, 0755); err != nil { + return nil, err + } + + log.Debug().Msgf("Loading GRPC Process", grpcProcess) + port, err := freeport.GetFreePort() + if err != nil { + return nil, err + } + + serverAddress := fmt.Sprintf("localhost:%d", port) + + log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress) + + grpcControlProcess := process.New( + process.WithTemporaryStateDir(), + process.WithName(grpcProcess), + process.WithArgs("--addr", serverAddress)) + + ml.grpcProcesses[o.modelFile] = grpcControlProcess + + if err := grpcControlProcess.Run(); err != nil { + return nil, err + } + + go func() { + t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stderr") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + } + }() + go func() { + t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stdout") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + } + }() + + log.Debug().Msgf("GRPC Service Started") + + client := grpc.NewClient(serverAddress) + + // Wait for the service to start up + ready := false + for i := 0; i < 10; i++ { + if client.HealthCheck(context.Background()) { + log.Debug().Msgf("GRPC Service Ready") + ready = true + break + } + time.Sleep(1 * time.Second) + } + + if !ready { + log.Debug().Msgf("GRPC Service NOT ready") + log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive()) + log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:")) + + log.Debug().Msgf(grpcControlProcess.ExitCode()) + + return nil, fmt.Errorf("grpc service not ready") + } + + options := *o.gRPCOptions + options.Model = s + + log.Debug().Msgf("GRPC: Loading model with options: %+v", options) + + res, err := client.LoadModel(context.TODO(), &options) + if err != nil { + return nil, err + } + if !res.Success { + return nil, fmt.Errorf("could not load model: %s", res.Message) + } + + return client, nil + } +} + +func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err error) { + + //backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (model interface{}, err error) { + + o := NewOptions(opts...) + + log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) + switch strings.ToLower(o.backendString) { case LlamaBackend: - return ml.LoadModel(modelFile, llamaLM(llamaOpts...)) + return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...)) case BloomzBackend: - return ml.LoadModel(modelFile, bloomzLM) + return ml.LoadModel(o.modelFile, bloomzLM) case GPTJBackend: - return ml.LoadModel(modelFile, gptJ) + return ml.LoadModel(o.modelFile, gptJ) case DollyBackend: - return ml.LoadModel(modelFile, dolly) + return ml.LoadModel(o.modelFile, dolly) case MPTBackend: - return ml.LoadModel(modelFile, mpt) + return ml.LoadModel(o.modelFile, mpt) case Gpt2Backend: - return ml.LoadModel(modelFile, transformersLM) + return ml.LoadModel(o.modelFile, transformersLM) case FalconBackend: - return ml.LoadModel(modelFile, falcon) + return ml.LoadModel(o.modelFile, ml.grpcModel(FalconBackend, o)) case GPTNeoXBackend: - return ml.LoadModel(modelFile, gptNeoX) + return ml.LoadModel(o.modelFile, gptNeoX) case ReplitBackend: - return ml.LoadModel(modelFile, replit) + return ml.LoadModel(o.modelFile, replit) case StableDiffusionBackend: - return ml.LoadModel(modelFile, stableDiffusion) + return ml.LoadModel(o.modelFile, stableDiffusion) case PiperBackend: - return ml.LoadModel(modelFile, piperTTS(filepath.Join(assetDir, "backend-assets", "espeak-ng-data"))) + return ml.LoadModel(o.modelFile, piperTTS(filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data"))) case StarcoderBackend: - return ml.LoadModel(modelFile, starCoder) + return ml.LoadModel(o.modelFile, starCoder) case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All: - return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetLibrarySearchPath(filepath.Join(assetDir, "backend-assets", "gpt4all")))) + return ml.LoadModel(o.modelFile, gpt4allLM(gpt4all.SetThreads(int(o.threads)), gpt4all.SetLibrarySearchPath(filepath.Join(o.assetDir, "backend-assets", "gpt4all")))) case BertEmbeddingsBackend: - return ml.LoadModel(modelFile, bertEmbeddings) + return ml.LoadModel(o.modelFile, bertEmbeddings) case RwkvBackend: - return ml.LoadModel(modelFile, rwkvLM(filepath.Join(ml.ModelPath, modelFile+tokenizerSuffix), threads)) + return ml.LoadModel(o.modelFile, rwkvLM(filepath.Join(ml.ModelPath, o.modelFile+tokenizerSuffix), o.threads)) case WhisperBackend: - return ml.LoadModel(modelFile, whisperModel) + return ml.LoadModel(o.modelFile, whisperModel) case LCHuggingFaceBackend: - return ml.LoadModel(modelFile, lcHuggingFace) + return ml.LoadModel(o.modelFile, lcHuggingFace) default: - return nil, fmt.Errorf("backend unsupported: %s", backendString) + return nil, fmt.Errorf("backend unsupported: %s", o.backendString) } } -func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (interface{}, error) { - log.Debug().Msgf("Loading model '%s' greedly", modelFile) +func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) { + o := NewOptions(opts...) + + log.Debug().Msgf("Loading model '%s' greedly", o.modelFile) ml.mu.Lock() - m, exists := ml.models[modelFile] + m, exists := ml.models[o.modelFile] if exists { - log.Debug().Msgf("Model '%s' already loaded", modelFile) + log.Debug().Msgf("Model '%s' already loaded", o.modelFile) ml.mu.Unlock() return m, nil } @@ -203,7 +321,15 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt continue } log.Debug().Msgf("[%s] Attempting to load", b) - model, modelerr := ml.BackendLoader(b, modelFile, llamaOpts, threads, assetDir) + + model, modelerr := ml.BackendLoader( + WithBackendString(b), + WithModelFile(o.modelFile), + WithLlamaOpts(o.llamaOpts...), + WithLoadGRPCOpts(o.gRPCOptions), + WithThreads(o.threads), + WithAssetDir(o.assetDir), + ) if modelerr == nil && model != nil { log.Debug().Msgf("[%s] Loads OK", b) return model, nil diff --git a/pkg/model/loader.go b/pkg/model/loader.go index ddc7b6e..35f3cef 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -10,6 +10,7 @@ import ( "sync" "text/template" + process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" ) @@ -18,6 +19,7 @@ type ModelLoader struct { mu sync.Mutex // TODO: this needs generics models map[string]interface{} + grpcProcesses map[string]*process.Process promptsTemplates map[string]*template.Template } @@ -26,6 +28,7 @@ func NewModelLoader(modelPath string) *ModelLoader { ModelPath: modelPath, models: make(map[string]interface{}), promptsTemplates: make(map[string]*template.Template), + grpcProcesses: make(map[string]*process.Process), } } diff --git a/pkg/model/options.go b/pkg/model/options.go new file mode 100644 index 0000000..3716330 --- /dev/null +++ b/pkg/model/options.go @@ -0,0 +1,62 @@ +package model + +import ( + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + llama "github.com/go-skynet/go-llama.cpp" +) + +type Options struct { + backendString string + modelFile string + llamaOpts []llama.ModelOption + threads uint32 + assetDir string + + gRPCOptions *pb.ModelOptions +} + +type Option func(*Options) + +func WithBackendString(backend string) Option { + return func(o *Options) { + o.backendString = backend + } +} + +func WithModelFile(modelFile string) Option { + return func(o *Options) { + o.modelFile = modelFile + } +} + +func WithLoadGRPCOpts(opts *pb.ModelOptions) Option { + return func(o *Options) { + o.gRPCOptions = opts + } +} + +func WithLlamaOpts(opts ...llama.ModelOption) Option { + return func(o *Options) { + o.llamaOpts = append(o.llamaOpts, opts...) + } +} + +func WithThreads(threads uint32) Option { + return func(o *Options) { + o.threads = threads + } +} + +func WithAssetDir(assetDir string) Option { + return func(o *Options) { + o.assetDir = assetDir + } +} + +func NewOptions(opts ...Option) *Options { + o := &Options{} + for _, opt := range opts { + opt(o) + } + return o +} From 58f6aab637ca67f9e49a8da9ac2ce3a9f5efdb01 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 02/12] feat: move llama to a grpc Signed-off-by: Ettore Di Giacinto --- Makefile | 9 +- api/prediction.go | 298 ++++------------------------ cmd/grpc/llama/main.go | 25 +++ pkg/grpc/client.go | 11 + pkg/grpc/interface.go | 1 + pkg/grpc/llm/falcon/falcon.go | 4 + pkg/grpc/llm/llama/llama.go | 165 +++++++++++++++ pkg/grpc/proto/llmserver.pb.go | 205 +++++++++++++------ pkg/grpc/proto/llmserver.proto | 8 +- pkg/grpc/proto/llmserver_grpc.pb.go | 36 ++++ pkg/grpc/server.go | 9 + pkg/model/initializers.go | 15 +- pkg/model/options.go | 8 - 13 files changed, 454 insertions(+), 340 deletions(-) create mode 100644 cmd/grpc/llama/main.go create mode 100644 pkg/grpc/llm/llama/llama.go diff --git a/Makefile b/Makefile index abac2b4..3514161 100644 --- a/Makefile +++ b/Makefile @@ -67,8 +67,8 @@ WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) -C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz -LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz +C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz +LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz ifeq ($(BUILD_TYPE),openblas) CGO_LDFLAGS+=-lopenblas @@ -369,5 +369,8 @@ falcon-grpc: backend-assets/grpc CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \ $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/ +llama-grpc: backend-assets/grpc + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/ -grpcs: falcon-grpc \ No newline at end of file +grpcs: falcon-grpc llama-grpc \ No newline at end of file diff --git a/api/prediction.go b/api/prediction.go index b9b5710..970f06e 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -18,7 +18,6 @@ import ( "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp" - llama "github.com/go-skynet/go-llama.cpp" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" ) @@ -36,6 +35,11 @@ func gRPCModelOpts(c Config) *pb.ModelOptions { ContextSize: int32(c.ContextSize), Seed: int32(c.Seed), NBatch: int32(b), + F16Memory: c.F16, + MLock: c.MMlock, + NUMA: c.NUMA, + Embeddings: c.Embeddings, + LowVRAM: c.LowVRAM, NGPULayers: int32(c.NGPULayers), MMap: c.MMap, MainGPU: c.MainGPU, @@ -43,32 +47,6 @@ func gRPCModelOpts(c Config) *pb.ModelOptions { } } -// func defaultGGLLMOpts(c Config) []ggllm.ModelOption { -// ggllmOpts := []ggllm.ModelOption{} -// if c.ContextSize != 0 { -// ggllmOpts = append(ggllmOpts, ggllm.SetContext(c.ContextSize)) -// } -// // F16 doesn't seem to produce good output at all! -// //if c.F16 { -// // llamaOpts = append(llamaOpts, llama.EnableF16Memory) -// //} - -// if c.NGPULayers != 0 { -// ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(c.NGPULayers)) -// } - -// ggllmOpts = append(ggllmOpts, ggllm.SetMMap(c.MMap)) -// ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(c.MainGPU)) -// ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(c.TensorSplit)) -// if c.Batch != 0 { -// ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(c.Batch)) -// } else { -// ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512)) -// } - -// return ggllmOpts -// } - func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { promptCachePath := "" if c.PromptCachePath != "" { @@ -77,14 +55,18 @@ func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { promptCachePath = p } return &pb.PredictOptions{ - Temperature: float32(c.Temperature), - TopP: float32(c.TopP), - TopK: int32(c.TopK), - Tokens: int32(c.Maxtokens), - Threads: int32(c.Threads), - PromptCacheAll: c.PromptCacheAll, - PromptCacheRO: c.PromptCacheRO, - PromptCachePath: promptCachePath, + Temperature: float32(c.Temperature), + TopP: float32(c.TopP), + TopK: int32(c.TopK), + Tokens: int32(c.Maxtokens), + Threads: int32(c.Threads), + PromptCacheAll: c.PromptCacheAll, + PromptCacheRO: c.PromptCacheRO, + PromptCachePath: promptCachePath, + F16KV: c.F16, + DebugMode: c.Debug, + Grammar: c.Grammar, + Mirostat: int32(c.Mirostat), MirostatETA: float32(c.MirostatETA), MirostatTAU: float32(c.MirostatTAU), @@ -105,200 +87,6 @@ func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { } } -// func buildGGLLMPredictOptions(c Config, modelPath string) []ggllm.PredictOption { -// // Generate the prediction using the language model -// predictOptions := []ggllm.PredictOption{ -// ggllm.SetTemperature(c.Temperature), -// ggllm.SetTopP(c.TopP), -// ggllm.SetTopK(c.TopK), -// ggllm.SetTokens(c.Maxtokens), -// ggllm.SetThreads(c.Threads), -// } - -// if c.PromptCacheAll { -// predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll) -// } - -// if c.PromptCacheRO { -// predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO) -// } - -// if c.PromptCachePath != "" { -// // Create parent directory -// p := filepath.Join(modelPath, c.PromptCachePath) -// os.MkdirAll(filepath.Dir(p), 0755) -// predictOptions = append(predictOptions, ggllm.SetPathPromptCache(p)) -// } - -// if c.Mirostat != 0 { -// predictOptions = append(predictOptions, ggllm.SetMirostat(c.Mirostat)) -// } - -// if c.MirostatETA != 0 { -// predictOptions = append(predictOptions, ggllm.SetMirostatETA(c.MirostatETA)) -// } - -// if c.MirostatTAU != 0 { -// predictOptions = append(predictOptions, ggllm.SetMirostatTAU(c.MirostatTAU)) -// } - -// if c.Debug { -// predictOptions = append(predictOptions, ggllm.Debug) -// } - -// predictOptions = append(predictOptions, ggllm.SetStopWords(c.StopWords...)) - -// if c.RepeatPenalty != 0 { -// predictOptions = append(predictOptions, ggllm.SetPenalty(c.RepeatPenalty)) -// } - -// if c.Keep != 0 { -// predictOptions = append(predictOptions, ggllm.SetNKeep(c.Keep)) -// } - -// if c.Batch != 0 { -// predictOptions = append(predictOptions, ggllm.SetBatch(c.Batch)) -// } - -// if c.IgnoreEOS { -// predictOptions = append(predictOptions, ggllm.IgnoreEOS) -// } - -// if c.Seed != 0 { -// predictOptions = append(predictOptions, ggllm.SetSeed(c.Seed)) -// } - -// //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) - -// predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(c.FrequencyPenalty)) -// predictOptions = append(predictOptions, ggllm.SetMlock(c.MMlock)) -// predictOptions = append(predictOptions, ggllm.SetMemoryMap(c.MMap)) -// predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(c.MainGPU)) -// predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(c.TensorSplit)) -// predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(c.TFZ)) -// predictOptions = append(predictOptions, ggllm.SetTypicalP(c.TypicalP)) - -// return predictOptions -// } - -func defaultLLamaOpts(c Config) []llama.ModelOption { - llamaOpts := []llama.ModelOption{} - if c.ContextSize != 0 { - llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize)) - } - if c.F16 { - llamaOpts = append(llamaOpts, llama.EnableF16Memory) - } - if c.Embeddings { - llamaOpts = append(llamaOpts, llama.EnableEmbeddings) - } - - if c.NGPULayers != 0 { - llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers)) - } - - llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap)) - llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU)) - llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit)) - if c.Batch != 0 { - llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch)) - } else { - llamaOpts = append(llamaOpts, llama.SetNBatch(512)) - } - - if c.NUMA { - llamaOpts = append(llamaOpts, llama.EnableNUMA) - } - - if c.LowVRAM { - llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) - } - - return llamaOpts -} - -func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption { - // Generate the prediction using the language model - predictOptions := []llama.PredictOption{ - llama.SetTemperature(c.Temperature), - llama.SetTopP(c.TopP), - llama.SetTopK(c.TopK), - llama.SetTokens(c.Maxtokens), - llama.SetThreads(c.Threads), - } - - if c.PromptCacheAll { - predictOptions = append(predictOptions, llama.EnablePromptCacheAll) - } - - if c.PromptCacheRO { - predictOptions = append(predictOptions, llama.EnablePromptCacheRO) - } - - predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar)) - - if c.PromptCachePath != "" { - // Create parent directory - p := filepath.Join(modelPath, c.PromptCachePath) - os.MkdirAll(filepath.Dir(p), 0755) - predictOptions = append(predictOptions, llama.SetPathPromptCache(p)) - } - - if c.Mirostat != 0 { - predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) - } - - if c.MirostatETA != 0 { - predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) - } - - if c.MirostatTAU != 0 { - predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) - } - - if c.Debug { - predictOptions = append(predictOptions, llama.Debug) - } - - predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) - - if c.RepeatPenalty != 0 { - predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) - } - - if c.Keep != 0 { - predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) - } - - if c.F16 { - predictOptions = append(predictOptions, llama.EnableF16KV) - } - - if c.IgnoreEOS { - predictOptions = append(predictOptions, llama.IgnoreEOS) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) - } - - //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) - - predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty)) - predictOptions = append(predictOptions, llama.SetMlock(c.MMlock)) - predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap)) - predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU)) - predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit)) - predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ)) - predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP)) - - return predictOptions -} - func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { if c.Backend != model.StableDiffusionBackend { return nil, fmt.Errorf("endpoint only working with stablediffusion models") @@ -351,14 +139,12 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, modelFile := c.Model - llamaOpts := defaultLLamaOpts(c) grpcOpts := gRPCModelOpts(c) var inferenceModel interface{} var err error opts := []model.Option{ - model.WithLlamaOpts(llamaOpts...), model.WithLoadGRPCOpts(grpcOpts), model.WithThreads(uint32(c.Threads)), model.WithAssetDir(o.assetsDestination), @@ -377,14 +163,34 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, var fn func() ([]float32, error) switch model := inferenceModel.(type) { - case *llama.LLama: + case *grpc.Client: fn = func() ([]float32, error) { - predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) + predictOptions := gRPCPredictOpts(c, loader.ModelPath) if len(tokens) > 0 { - return model.TokenEmbeddings(tokens, predictOptions...) + embeds := []int32{} + + for _, t := range tokens { + embeds = append(embeds, int32(t)) + } + predictOptions.EmbeddingTokens = embeds + + res, err := model.Embeddings(context.TODO(), predictOptions) + if err != nil { + return nil, err + } + + return res.Embeddings, nil + } + predictOptions.Embeddings = s + + res, err := model.Embeddings(context.TODO(), predictOptions) + if err != nil { + return nil, err } - return model.Embeddings(s, predictOptions...) + + return res.Embeddings, nil } + // bert embeddings case *bert.Bert: fn = func() ([]float32, error) { @@ -432,14 +238,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to supportStreams := false modelFile := c.Model - llamaOpts := defaultLLamaOpts(c) grpcOpts := gRPCModelOpts(c) var inferenceModel interface{} var err error opts := []model.Option{ - model.WithLlamaOpts(llamaOpts...), model.WithLoadGRPCOpts(grpcOpts), model.WithThreads(uint32(c.Threads)), model.WithAssetDir(o.assetsDestination), @@ -708,26 +512,6 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch)) } - str, er := model.Predict( - s, - predictOptions..., - ) - // Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels) - // For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}} - // after a stream event has occurred - model.SetTokenCallback(nil) - return str, er - } - case *llama.LLama: - supportStreams = true - fn = func() (string, error) { - - if tokenCallback != nil { - model.SetTokenCallback(tokenCallback) - } - - predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) - str, er := model.Predict( s, predictOptions..., diff --git a/cmd/grpc/llama/main.go b/cmd/grpc/llama/main.go new file mode 100644 index 0000000..d75ef48 --- /dev/null +++ b/cmd/grpc/llama/main.go @@ -0,0 +1,25 @@ +package main + +// GRPC Falcon server + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil { + panic(err) + } +} diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index f63a89a..06628eb 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -47,6 +47,17 @@ func (c *Client) HealthCheck(ctx context.Context) bool { return false } +func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewLLMClient(conn) + + return client.Embedding(ctx, in, opts...) +} + func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) { conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 8ac851a..70b830f 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -8,4 +8,5 @@ type LLM interface { Predict(*pb.PredictOptions) (string, error) PredictStream(*pb.PredictOptions, chan string) Load(*pb.ModelOptions) error + Embeddings(*pb.PredictOptions) ([]float32, error) } diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/grpc/llm/falcon/falcon.go index a0a53be..5d8cf75 100644 --- a/pkg/grpc/llm/falcon/falcon.go +++ b/pkg/grpc/llm/falcon/falcon.go @@ -42,6 +42,10 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error { return err } +func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { predictOptions := []ggllm.PredictOption{ ggllm.SetTemperature(float64(opts.Temperature)), diff --git a/pkg/grpc/llm/llama/llama.go b/pkg/grpc/llm/llama/llama.go new file mode 100644 index 0000000..a31e274 --- /dev/null +++ b/pkg/grpc/llm/llama/llama.go @@ -0,0 +1,165 @@ +package llama + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/go-llama.cpp" +) + +type LLM struct { + llama *llama.LLama +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + llamaOpts := []llama.ModelOption{} + + if opts.ContextSize != 0 { + llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize))) + } + if opts.F16Memory { + llamaOpts = append(llamaOpts, llama.EnableF16Memory) + } + if opts.Embeddings { + llamaOpts = append(llamaOpts, llama.EnableEmbeddings) + } + if opts.NGPULayers != 0 { + llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers))) + } + + llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap)) + llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU)) + llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit)) + if opts.NBatch != 0 { + llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch))) + } else { + llamaOpts = append(llamaOpts, llama.SetNBatch(512)) + } + + if opts.NUMA { + llamaOpts = append(llamaOpts, llama.EnableNUMA) + } + + if opts.LowVRAM { + llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) + } + + model, err := llama.New(opts.Model, llamaOpts...) + llm.llama = model + return err +} + +func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption { + predictOptions := []llama.PredictOption{ + llama.SetTemperature(float64(opts.Temperature)), + llama.SetTopP(float64(opts.TopP)), + llama.SetTopK(int(opts.TopK)), + llama.SetTokens(int(opts.Tokens)), + llama.SetThreads(int(opts.Threads)), + } + + if opts.PromptCacheAll { + predictOptions = append(predictOptions, llama.EnablePromptCacheAll) + } + + if opts.PromptCacheRO { + predictOptions = append(predictOptions, llama.EnablePromptCacheRO) + } + + predictOptions = append(predictOptions, llama.WithGrammar(opts.Grammar)) + + // Expected absolute path + if opts.PromptCachePath != "" { + predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath)) + } + + if opts.Mirostat != 0 { + predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat))) + } + + if opts.MirostatETA != 0 { + predictOptions = append(predictOptions, llama.SetMirostatETA(float64(opts.MirostatETA))) + } + + if opts.MirostatTAU != 0 { + predictOptions = append(predictOptions, llama.SetMirostatTAU(float64(opts.MirostatTAU))) + } + + if opts.Debug { + predictOptions = append(predictOptions, llama.Debug) + } + + predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...)) + + if opts.PresencePenalty != 0 { + predictOptions = append(predictOptions, llama.SetPenalty(float64(opts.PresencePenalty))) + } + + if opts.NKeep != 0 { + predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep))) + } + + if opts.Batch != 0 { + predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch))) + } + + if opts.F16KV { + predictOptions = append(predictOptions, llama.EnableF16KV) + } + + if opts.IgnoreEOS { + predictOptions = append(predictOptions, llama.IgnoreEOS) + } + + if opts.Seed != 0 { + predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed))) + } + + //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) + + predictOptions = append(predictOptions, llama.SetFrequencyPenalty(float64(opts.FrequencyPenalty))) + predictOptions = append(predictOptions, llama.SetMlock(opts.MLock)) + predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap)) + predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU)) + predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit)) + predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ))) + predictOptions = append(predictOptions, llama.SetTypicalP(float64(opts.TypicalP))) + return predictOptions +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { + predictOptions := buildPredictOptions(opts) + + predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { + results <- token + return true + })) + + go func() { + _, err := llm.llama.Predict(opts.Prompt, predictOptions...) + if err != nil { + fmt.Println("err: ", err) + } + close(results) + }() +} + +func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + predictOptions := buildPredictOptions(opts) + + if len(opts.EmbeddingTokens) > 0 { + tokens := []int{} + for _, t := range opts.EmbeddingTokens { + tokens = append(tokens, int(t)) + } + return llm.llama.TokenEmbeddings(tokens, predictOptions...) + } + + return llm.llama.Embeddings(opts.Embeddings, predictOptions...) +} diff --git a/pkg/grpc/proto/llmserver.pb.go b/pkg/grpc/proto/llmserver.pb.go index 067c3a1..d54c393 100644 --- a/pkg/grpc/proto/llmserver.pb.go +++ b/pkg/grpc/proto/llmserver.pb.go @@ -87,7 +87,6 @@ type PredictOptions struct { MirostatTAU float32 `protobuf:"fixed32,21,opt,name=MirostatTAU,proto3" json:"MirostatTAU,omitempty"` PenalizeNL bool `protobuf:"varint,22,opt,name=PenalizeNL,proto3" json:"PenalizeNL,omitempty"` LogitBias string `protobuf:"bytes,23,opt,name=LogitBias,proto3" json:"LogitBias,omitempty"` - PathPromptCache string `protobuf:"bytes,24,opt,name=PathPromptCache,proto3" json:"PathPromptCache,omitempty"` MLock bool `protobuf:"varint,25,opt,name=MLock,proto3" json:"MLock,omitempty"` MMap bool `protobuf:"varint,26,opt,name=MMap,proto3" json:"MMap,omitempty"` PromptCacheAll bool `protobuf:"varint,27,opt,name=PromptCacheAll,proto3" json:"PromptCacheAll,omitempty"` @@ -98,6 +97,8 @@ type PredictOptions struct { TopP float32 `protobuf:"fixed32,32,opt,name=TopP,proto3" json:"TopP,omitempty"` PromptCachePath string `protobuf:"bytes,33,opt,name=PromptCachePath,proto3" json:"PromptCachePath,omitempty"` Debug bool `protobuf:"varint,34,opt,name=Debug,proto3" json:"Debug,omitempty"` + EmbeddingTokens []int32 `protobuf:"varint,35,rep,packed,name=EmbeddingTokens,proto3" json:"EmbeddingTokens,omitempty"` + Embeddings string `protobuf:"bytes,36,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` } func (x *PredictOptions) Reset() { @@ -293,13 +294,6 @@ func (x *PredictOptions) GetLogitBias() string { return "" } -func (x *PredictOptions) GetPathPromptCache() string { - if x != nil { - return x.PathPromptCache - } - return "" -} - func (x *PredictOptions) GetMLock() bool { if x != nil { return x.MLock @@ -370,6 +364,20 @@ func (x *PredictOptions) GetDebug() bool { return false } +func (x *PredictOptions) GetEmbeddingTokens() []int32 { + if x != nil { + return x.EmbeddingTokens + } + return nil +} + +func (x *PredictOptions) GetEmbeddings() string { + if x != nil { + return x.Embeddings + } + return "" +} + // The response message containing the result type Reply struct { state protoimpl.MessageState @@ -624,13 +632,60 @@ func (x *Result) GetSuccess() bool { return false } +type EmbeddingResult struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Embeddings []float32 `protobuf:"fixed32,1,rep,packed,name=embeddings,proto3" json:"embeddings,omitempty"` +} + +func (x *EmbeddingResult) Reset() { + *x = EmbeddingResult{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EmbeddingResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbeddingResult) ProtoMessage() {} + +func (x *EmbeddingResult) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbeddingResult.ProtoReflect.Descriptor instead. +func (*EmbeddingResult) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{5} +} + +func (x *EmbeddingResult) GetEmbeddings() []float32 { + if x != nil { + return x.Embeddings + } + return nil +} + var File_pkg_grpc_proto_llmserver_proto protoreflect.FileDescriptor var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ 0x0a, 0x1e, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x6c, 0x6c, 0x6d, 0x22, 0x0f, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x80, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xa0, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, @@ -673,28 +728,30 @@ var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ 0x1e, 0x0a, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x18, 0x16, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x18, 0x17, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x28, 0x0a, - 0x0f, 0x50, 0x61, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, - 0x18, 0x18, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x6d, - 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, - 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, - 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, - 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, - 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, - 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12, 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, - 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, - 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, - 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, - 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, - 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, - 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, - 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, - 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x18, 0x21, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, - 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x18, 0x22, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70, + 0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x14, 0x0a, + 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c, + 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12, + 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, + 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, + 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, + 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12, + 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, + 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, + 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, + 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12, + 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, + 0x74, 0x68, 0x18, 0x21, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, + 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62, + 0x75, 0x67, 0x18, 0x22, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x12, + 0x28, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x73, 0x18, 0x23, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, + 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62, + 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x45, + 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x82, 0x03, 0x0a, 0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x14, 0x0a, @@ -724,26 +781,33 @@ var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ 0x74, 0x22, 0x3c, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, - 0xc4, 0x01, 0x0a, 0x03, 0x4c, 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, - 0x68, 0x12, 0x12, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, - 0x79, 0x22, 0x00, 0x12, 0x2c, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x13, - 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, - 0x00, 0x12, 0x2d, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x11, - 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x1a, 0x0b, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, - 0x12, 0x34, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, - 0x6d, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, + 0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, + 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x02, 0x52, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, + 0x67, 0x73, 0x32, 0xfe, 0x01, 0x0a, 0x03, 0x4c, 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, + 0x61, 0x6c, 0x74, 0x68, 0x12, 0x12, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, + 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2c, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, + 0x74, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, - 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x42, 0x57, 0x0a, 0x1b, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, - 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x6c, 0x6c, 0x6d, 0x73, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, - 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, - 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, + 0x6c, 0x12, 0x11, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0b, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, + 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, + 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x38, 0x0a, 0x09, 0x45, 0x6d, 0x62, + 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, + 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x14, 0x2e, 0x6c, 0x6c, + 0x6d, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x22, 0x00, 0x42, 0x57, 0x0a, 0x1b, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, + 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x01, 0x5a, + 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, + 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, + 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -758,25 +822,28 @@ func file_pkg_grpc_proto_llmserver_proto_rawDescGZIP() []byte { return file_pkg_grpc_proto_llmserver_proto_rawDescData } -var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_pkg_grpc_proto_llmserver_proto_goTypes = []interface{}{ - (*HealthMessage)(nil), // 0: llm.HealthMessage - (*PredictOptions)(nil), // 1: llm.PredictOptions - (*Reply)(nil), // 2: llm.Reply - (*ModelOptions)(nil), // 3: llm.ModelOptions - (*Result)(nil), // 4: llm.Result + (*HealthMessage)(nil), // 0: llm.HealthMessage + (*PredictOptions)(nil), // 1: llm.PredictOptions + (*Reply)(nil), // 2: llm.Reply + (*ModelOptions)(nil), // 3: llm.ModelOptions + (*Result)(nil), // 4: llm.Result + (*EmbeddingResult)(nil), // 5: llm.EmbeddingResult } var file_pkg_grpc_proto_llmserver_proto_depIdxs = []int32{ 0, // 0: llm.LLM.Health:input_type -> llm.HealthMessage 1, // 1: llm.LLM.Predict:input_type -> llm.PredictOptions 3, // 2: llm.LLM.LoadModel:input_type -> llm.ModelOptions 1, // 3: llm.LLM.PredictStream:input_type -> llm.PredictOptions - 2, // 4: llm.LLM.Health:output_type -> llm.Reply - 2, // 5: llm.LLM.Predict:output_type -> llm.Reply - 4, // 6: llm.LLM.LoadModel:output_type -> llm.Result - 2, // 7: llm.LLM.PredictStream:output_type -> llm.Reply - 4, // [4:8] is the sub-list for method output_type - 0, // [0:4] is the sub-list for method input_type + 1, // 4: llm.LLM.Embedding:input_type -> llm.PredictOptions + 2, // 5: llm.LLM.Health:output_type -> llm.Reply + 2, // 6: llm.LLM.Predict:output_type -> llm.Reply + 4, // 7: llm.LLM.LoadModel:output_type -> llm.Result + 2, // 8: llm.LLM.PredictStream:output_type -> llm.Reply + 5, // 9: llm.LLM.Embedding:output_type -> llm.EmbeddingResult + 5, // [5:10] is the sub-list for method output_type + 0, // [0:5] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -848,6 +915,18 @@ func file_pkg_grpc_proto_llmserver_proto_init() { return nil } } + file_pkg_grpc_proto_llmserver_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EmbeddingResult); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -855,7 +934,7 @@ func file_pkg_grpc_proto_llmserver_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_pkg_grpc_proto_llmserver_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 6, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/grpc/proto/llmserver.proto b/pkg/grpc/proto/llmserver.proto index ba20806..b6fa4cd 100644 --- a/pkg/grpc/proto/llmserver.proto +++ b/pkg/grpc/proto/llmserver.proto @@ -12,6 +12,7 @@ service LLM { rpc Predict(PredictOptions) returns (Reply) {} rpc LoadModel(ModelOptions) returns (Result) {} rpc PredictStream(PredictOptions) returns (stream Reply) {} + rpc Embedding(PredictOptions) returns (EmbeddingResult) {} } message HealthMessage {} @@ -41,7 +42,6 @@ message PredictOptions { float MirostatTAU = 21; bool PenalizeNL = 22; string LogitBias = 23; - string PathPromptCache = 24; bool MLock = 25; bool MMap = 26; bool PromptCacheAll = 27; @@ -52,6 +52,8 @@ message PredictOptions { float TopP = 32; string PromptCachePath = 33; bool Debug = 34; + repeated int32 EmbeddingTokens = 35; + string Embeddings = 36; } // The response message containing the result @@ -79,4 +81,8 @@ message ModelOptions { message Result { string message = 1; bool success = 2; +} + +message EmbeddingResult { + repeated float embeddings = 1; } \ No newline at end of file diff --git a/pkg/grpc/proto/llmserver_grpc.pb.go b/pkg/grpc/proto/llmserver_grpc.pb.go index 6cfd981..c028218 100644 --- a/pkg/grpc/proto/llmserver_grpc.pb.go +++ b/pkg/grpc/proto/llmserver_grpc.pb.go @@ -26,6 +26,7 @@ type LLMClient interface { Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error) + Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) } type lLMClient struct { @@ -95,6 +96,15 @@ func (x *lLMPredictStreamClient) Recv() (*Reply, error) { return m, nil } +func (c *lLMClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { + out := new(EmbeddingResult) + err := c.cc.Invoke(ctx, "/llm.LLM/Embedding", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // LLMServer is the server API for LLM service. // All implementations must embed UnimplementedLLMServer // for forward compatibility @@ -103,6 +113,7 @@ type LLMServer interface { Predict(context.Context, *PredictOptions) (*Reply, error) LoadModel(context.Context, *ModelOptions) (*Result, error) PredictStream(*PredictOptions, LLM_PredictStreamServer) error + Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) mustEmbedUnimplementedLLMServer() } @@ -122,6 +133,9 @@ func (UnimplementedLLMServer) LoadModel(context.Context, *ModelOptions) (*Result func (UnimplementedLLMServer) PredictStream(*PredictOptions, LLM_PredictStreamServer) error { return status.Errorf(codes.Unimplemented, "method PredictStream not implemented") } +func (UnimplementedLLMServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) { + return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented") +} func (UnimplementedLLMServer) mustEmbedUnimplementedLLMServer() {} // UnsafeLLMServer may be embedded to opt out of forward compatibility for this service. @@ -210,6 +224,24 @@ func (x *lLMPredictStreamServer) Send(m *Reply) error { return x.ServerStream.SendMsg(m) } +func _LLM_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PredictOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(LLMServer).Embedding(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/llm.LLM/Embedding", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(LLMServer).Embedding(ctx, req.(*PredictOptions)) + } + return interceptor(ctx, in, info, handler) +} + // LLM_ServiceDesc is the grpc.ServiceDesc for LLM service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -229,6 +261,10 @@ var LLM_ServiceDesc = grpc.ServiceDesc{ MethodName: "LoadModel", Handler: _LLM_LoadModel_Handler, }, + { + MethodName: "Embedding", + Handler: _LLM_Embedding_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index d449593..9e4c88a 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -29,6 +29,15 @@ func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, e return &pb.Reply{Message: "OK"}, nil } +func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { + embeds, err := s.llm.Embeddings(in) + if err != nil { + return nil, err + } + + return &pb.EmbeddingResult{Embeddings: embeds}, nil +} + func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { err := s.llm.Load(in) if err != nil { diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 5dba7ce..1acde4c 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -17,7 +17,6 @@ import ( bloomz "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp" - llama "github.com/go-skynet/go-llama.cpp" "github.com/hashicorp/go-multierror" "github.com/hpcloud/tail" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" @@ -135,11 +134,11 @@ var lcHuggingFace = func(repoId string) (interface{}, error) { return langchain.NewHuggingFace(repoId) } -func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) { - return func(s string) (interface{}, error) { - return llama.New(s, opts...) - } -} +// func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) { +// return func(s string) (interface{}, error) { +// return llama.New(s, opts...) +// } +// } func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) { return func(s string) (interface{}, error) { @@ -263,7 +262,8 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err err log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) switch strings.ToLower(o.backendString) { case LlamaBackend: - return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...)) + // return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...)) + return ml.LoadModel(o.modelFile, ml.grpcModel(LlamaBackend, o)) case BloomzBackend: return ml.LoadModel(o.modelFile, bloomzLM) case GPTJBackend: @@ -325,7 +325,6 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) { model, modelerr := ml.BackendLoader( WithBackendString(b), WithModelFile(o.modelFile), - WithLlamaOpts(o.llamaOpts...), WithLoadGRPCOpts(o.gRPCOptions), WithThreads(o.threads), WithAssetDir(o.assetDir), diff --git a/pkg/model/options.go b/pkg/model/options.go index 3716330..31e54cb 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -2,13 +2,11 @@ package model import ( pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - llama "github.com/go-skynet/go-llama.cpp" ) type Options struct { backendString string modelFile string - llamaOpts []llama.ModelOption threads uint32 assetDir string @@ -35,12 +33,6 @@ func WithLoadGRPCOpts(opts *pb.ModelOptions) Option { } } -func WithLlamaOpts(opts ...llama.ModelOption) Option { - return func(o *Options) { - o.llamaOpts = append(o.llamaOpts, opts...) - } -} - func WithThreads(threads uint32) Option { return func(o *Options) { o.threads = threads From ae533cadef85c0f046060c332ebbc7cc6c4794cc Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 03/12] feat: move gpt4all to a grpc service Signed-off-by: Ettore Di Giacinto --- .gitignore | 2 +- Makefile | 30 +++------ api/prediction.go | 33 +--------- cmd/grpc/gpt4all/main.go | 23 +++++++ pkg/grpc/llm/gpt4all/gpt4all.go | 61 ++++++++++++++++++ pkg/grpc/proto/llmserver.pb.go | 110 +++++++++++++++++++------------- pkg/grpc/proto/llmserver.proto | 2 + pkg/model/initializers.go | 16 +++-- 8 files changed, 170 insertions(+), 107 deletions(-) create mode 100644 cmd/grpc/gpt4all/main.go create mode 100644 pkg/grpc/llm/gpt4all/gpt4all.go diff --git a/.gitignore b/.gitignore index 8819ad7..a40bf19 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ # go-llama build artifacts go-llama -gpt4all +/gpt4all go-stable-diffusion go-piper go-ggllm diff --git a/Makefile b/Makefile index 3514161..df7a16e 100644 --- a/Makefile +++ b/Makefile @@ -110,24 +110,6 @@ all: help gpt4all: git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1 - # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml.. - @find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.m" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} + - @find ./gpt4all/gpt4all-backend -type f -name "llama_util.h" -execdir mv {} "llama_gpt4all_util.h" \; - @find ./gpt4all -type f -name "*.cmake" -exec sed -i'' -e 's/llama_util/llama_gpt4all_util/g' {} + - @find ./gpt4all -type f -name "*.txt" -exec sed -i'' -e 's/llama_util/llama_gpt4all_util/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.cpp" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.go" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/set_numa_thread_affinity/gpt4all_set_numa_thread_affinity/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.c" -exec sed -i'' -e 's/set_numa_thread_affinity/gpt4all__set_numa_thread_affinity/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.c" -exec sed -i'' -e 's/clear_numa_thread_affinity/gpt4all__clear_numa_thread_affinity/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/clear_numa_thread_affinity/gpt4all__clear_numa_thread_affinity/g' {} + ## go-ggllm go-ggllm: @@ -282,7 +264,7 @@ rebuild: ## Rebuilds the project $(MAKE) -C go-ggllm clean $(MAKE) build -prepare: prepare-sources backend-assets/gpt4all grpcs $(OPTIONAL_TARGETS) go-ggllm/libggllm.a go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building +prepare: prepare-sources grpcs go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a $(OPTIONAL_TARGETS) touch $@ clean: ## Remove build related file @@ -365,12 +347,16 @@ protogen: backend-assets/grpc: mkdir -p backend-assets/grpc -falcon-grpc: backend-assets/grpc +falcon-grpc: backend-assets/grpc go-ggllm/libggllm.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \ $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/ -llama-grpc: backend-assets/grpc +llama-grpc: backend-assets/grpc go-llama/libbinding.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \ $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/ -grpcs: falcon-grpc llama-grpc \ No newline at end of file +gpt4all-grpc: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/ + +grpcs: falcon-grpc llama-grpc gpt4all-grpc \ No newline at end of file diff --git a/api/prediction.go b/api/prediction.go index 970f06e..f24376c 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -18,8 +18,6 @@ import ( "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp" - - gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" ) // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 @@ -43,6 +41,7 @@ func gRPCModelOpts(c Config) *pb.ModelOptions { NGPULayers: int32(c.NGPULayers), MMap: c.MMap, MainGPU: c.MainGPU, + Threads: int32(c.Threads), TensorSplit: c.TensorSplit, } } @@ -492,36 +491,6 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to predictOptions..., ) } - case *gpt4all.Model: - supportStreams = true - - fn = func() (string, error) { - if tokenCallback != nil { - model.SetTokenCallback(tokenCallback) - } - - // Generate the prediction using the language model - predictOptions := []gpt4all.PredictOption{ - gpt4all.SetTemperature(c.Temperature), - gpt4all.SetTopP(c.TopP), - gpt4all.SetTopK(c.TopK), - gpt4all.SetTokens(c.Maxtokens), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch)) - } - - str, er := model.Predict( - s, - predictOptions..., - ) - // Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels) - // For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}} - // after a stream event has occurred - model.SetTokenCallback(nil) - return str, er - } case *grpc.Client: // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported supportStreams = true diff --git a/cmd/grpc/gpt4all/main.go b/cmd/grpc/gpt4all/main.go new file mode 100644 index 0000000..a784d40 --- /dev/null +++ b/cmd/grpc/gpt4all/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + gpt4all "github.com/go-skynet/LocalAI/pkg/grpc/llm/gpt4all" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &gpt4all.LLM{}); err != nil { + panic(err) + } +} diff --git a/pkg/grpc/llm/gpt4all/gpt4all.go b/pkg/grpc/llm/gpt4all/gpt4all.go new file mode 100644 index 0000000..0d7dac5 --- /dev/null +++ b/pkg/grpc/llm/gpt4all/gpt4all.go @@ -0,0 +1,61 @@ +package gpt4all + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" +) + +type LLM struct { + gpt4all *gpt4all.Model +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + model, err := gpt4all.New(opts.Model, + gpt4all.SetThreads(int(opts.Threads)), + gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath)) + llm.gpt4all = model + return err +} + +func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption { + predictOptions := []gpt4all.PredictOption{ + gpt4all.SetTemperature(float64(opts.Temperature)), + gpt4all.SetTopP(float64(opts.TopP)), + gpt4all.SetTopK(int(opts.TopK)), + gpt4all.SetTokens(int(opts.Tokens)), + } + + if opts.Batch != 0 { + predictOptions = append(predictOptions, gpt4all.SetBatch(int(opts.Batch))) + } + return predictOptions +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { + predictOptions := buildPredictOptions(opts) + + go func() { + llm.gpt4all.SetTokenCallback(func(token string) bool { + results <- token + return true + }) + _, err := llm.gpt4all.Predict(opts.Prompt, predictOptions...) + if err != nil { + fmt.Println("err: ", err) + } + llm.gpt4all.SetTokenCallback(nil) + close(results) + }() +} + +func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return []float32{}, fmt.Errorf("not implemented") +} diff --git a/pkg/grpc/proto/llmserver.pb.go b/pkg/grpc/proto/llmserver.pb.go index d54c393..d8bdcd2 100644 --- a/pkg/grpc/proto/llmserver.pb.go +++ b/pkg/grpc/proto/llmserver.pb.go @@ -431,20 +431,22 @@ type ModelOptions struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Model string `protobuf:"bytes,1,opt,name=Model,proto3" json:"Model,omitempty"` - ContextSize int32 `protobuf:"varint,2,opt,name=ContextSize,proto3" json:"ContextSize,omitempty"` - Seed int32 `protobuf:"varint,3,opt,name=Seed,proto3" json:"Seed,omitempty"` - NBatch int32 `protobuf:"varint,4,opt,name=NBatch,proto3" json:"NBatch,omitempty"` - F16Memory bool `protobuf:"varint,5,opt,name=F16Memory,proto3" json:"F16Memory,omitempty"` - MLock bool `protobuf:"varint,6,opt,name=MLock,proto3" json:"MLock,omitempty"` - MMap bool `protobuf:"varint,7,opt,name=MMap,proto3" json:"MMap,omitempty"` - VocabOnly bool `protobuf:"varint,8,opt,name=VocabOnly,proto3" json:"VocabOnly,omitempty"` - LowVRAM bool `protobuf:"varint,9,opt,name=LowVRAM,proto3" json:"LowVRAM,omitempty"` - Embeddings bool `protobuf:"varint,10,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` - NUMA bool `protobuf:"varint,11,opt,name=NUMA,proto3" json:"NUMA,omitempty"` - NGPULayers int32 `protobuf:"varint,12,opt,name=NGPULayers,proto3" json:"NGPULayers,omitempty"` - MainGPU string `protobuf:"bytes,13,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` - TensorSplit string `protobuf:"bytes,14,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` + Model string `protobuf:"bytes,1,opt,name=Model,proto3" json:"Model,omitempty"` + ContextSize int32 `protobuf:"varint,2,opt,name=ContextSize,proto3" json:"ContextSize,omitempty"` + Seed int32 `protobuf:"varint,3,opt,name=Seed,proto3" json:"Seed,omitempty"` + NBatch int32 `protobuf:"varint,4,opt,name=NBatch,proto3" json:"NBatch,omitempty"` + F16Memory bool `protobuf:"varint,5,opt,name=F16Memory,proto3" json:"F16Memory,omitempty"` + MLock bool `protobuf:"varint,6,opt,name=MLock,proto3" json:"MLock,omitempty"` + MMap bool `protobuf:"varint,7,opt,name=MMap,proto3" json:"MMap,omitempty"` + VocabOnly bool `protobuf:"varint,8,opt,name=VocabOnly,proto3" json:"VocabOnly,omitempty"` + LowVRAM bool `protobuf:"varint,9,opt,name=LowVRAM,proto3" json:"LowVRAM,omitempty"` + Embeddings bool `protobuf:"varint,10,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` + NUMA bool `protobuf:"varint,11,opt,name=NUMA,proto3" json:"NUMA,omitempty"` + NGPULayers int32 `protobuf:"varint,12,opt,name=NGPULayers,proto3" json:"NGPULayers,omitempty"` + MainGPU string `protobuf:"bytes,13,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` + TensorSplit string `protobuf:"bytes,14,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` + Threads int32 `protobuf:"varint,15,opt,name=Threads,proto3" json:"Threads,omitempty"` + LibrarySearchPath string `protobuf:"bytes,16,opt,name=LibrarySearchPath,proto3" json:"LibrarySearchPath,omitempty"` } func (x *ModelOptions) Reset() { @@ -577,6 +579,20 @@ func (x *ModelOptions) GetTensorSplit() string { return "" } +func (x *ModelOptions) GetThreads() int32 { + if x != nil { + return x.Threads + } + return 0 +} + +func (x *ModelOptions) GetLibrarySearchPath() string { + if x != nil { + return x.LibrarySearchPath + } + return "" +} + type Result struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -753,7 +769,7 @@ var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x82, 0x03, 0x0a, + 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xca, 0x03, 0x0a, 0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x20, 0x0a, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x53, 0x69, @@ -778,36 +794,40 @@ var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, - 0x74, 0x22, 0x3c, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, - 0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, - 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x02, 0x52, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, - 0x67, 0x73, 0x32, 0xfe, 0x01, 0x0a, 0x03, 0x4c, 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, - 0x61, 0x6c, 0x74, 0x68, 0x12, 0x12, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, - 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, - 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2c, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, - 0x74, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, - 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, - 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, - 0x6c, 0x12, 0x11, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0b, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, - 0x74, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, - 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, - 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, - 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x38, 0x0a, 0x09, 0x45, 0x6d, 0x62, - 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, - 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x14, 0x2e, 0x6c, 0x6c, - 0x6d, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, - 0x74, 0x22, 0x00, 0x42, 0x57, 0x0a, 0x1b, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, - 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x01, 0x5a, - 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, - 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, - 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x74, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x18, 0x0f, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12, 0x2c, 0x0a, 0x11, 0x4c, + 0x69, 0x62, 0x72, 0x61, 0x72, 0x79, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, 0x74, 0x68, + 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x11, 0x4c, 0x69, 0x62, 0x72, 0x61, 0x72, 0x79, 0x53, + 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x06, 0x52, 0x65, 0x73, + 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, + 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, + 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, 0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, + 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x6d, + 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x02, 0x52, 0x0a, + 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x32, 0xfe, 0x01, 0x0a, 0x03, 0x4c, + 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x12, 0x2e, 0x6c, + 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2c, + 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, + 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, + 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x09, + 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x11, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, + 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0b, 0x2e, 0x6c, + 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x0d, 0x50, + 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x6c, + 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, + 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, + 0x01, 0x12, 0x38, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x13, + 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, + 0x6f, 0x6e, 0x73, 0x1a, 0x14, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, + 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x42, 0x57, 0x0a, 0x1b, 0x69, + 0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, + 0x2e, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, + 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, + 0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/pkg/grpc/proto/llmserver.proto b/pkg/grpc/proto/llmserver.proto index b6fa4cd..32fe0ff 100644 --- a/pkg/grpc/proto/llmserver.proto +++ b/pkg/grpc/proto/llmserver.proto @@ -76,6 +76,8 @@ message ModelOptions { int32 NGPULayers = 12; string MainGPU = 13; string TensorSplit = 14; + int32 Threads = 15; + string LibrarySearchPath = 16; } message Result { diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 1acde4c..3a0c5ea 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -19,7 +19,6 @@ import ( transformers "github.com/go-skynet/go-ggml-transformers.cpp" "github.com/hashicorp/go-multierror" "github.com/hpcloud/tail" - gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" "github.com/phayes/freeport" "github.com/rs/zerolog/log" @@ -140,11 +139,11 @@ var lcHuggingFace = func(repoId string) (interface{}, error) { // } // } -func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) { - return func(s string) (interface{}, error) { - return gpt4all.New(s, opts...) - } -} +// func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) { +// return func(s string) (interface{}, error) { +// return gpt4all.New(s, opts...) +// } +// } func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) { return func(s string) (interface{}, error) { @@ -287,7 +286,10 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err err case StarcoderBackend: return ml.LoadModel(o.modelFile, starCoder) case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All: - return ml.LoadModel(o.modelFile, gpt4allLM(gpt4all.SetThreads(int(o.threads)), gpt4all.SetLibrarySearchPath(filepath.Join(o.assetDir, "backend-assets", "gpt4all")))) + o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all") + return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt4All, o)) + + // return ml.LoadModel(o.modelFile, gpt4allLM(gpt4all.SetThreads(int(o.threads)), gpt4all.SetLibrarySearchPath(filepath.Join(o.assetDir, "backend-assets", "gpt4all")))) case BertEmbeddingsBackend: return ml.LoadModel(o.modelFile, bertEmbeddings) case RwkvBackend: From f2f1d7fe72c8205f3740c41de53b4b868f5d72cf Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 04/12] feat: use gRPC for transformers Signed-off-by: Ettore Di Giacinto --- Makefile | 45 ++++-- api/prediction.go | 194 +------------------------ cmd/grpc/dolly/main.go | 23 +++ cmd/grpc/gpt2/main.go | 23 +++ cmd/grpc/gptj/main.go | 23 +++ cmd/grpc/gptneox/main.go | 23 +++ cmd/grpc/mpt/main.go | 23 +++ cmd/grpc/replit/main.go | 23 +++ cmd/grpc/starcoder/main.go | 23 +++ pkg/grpc/llm/ggml/starcoder.go | 0 pkg/grpc/llm/transformers/dolly.go | 42 ++++++ pkg/grpc/llm/transformers/gpt2.go | 42 ++++++ pkg/grpc/llm/transformers/gptj.go | 42 ++++++ pkg/grpc/llm/transformers/gptneox.go | 42 ++++++ pkg/grpc/llm/transformers/mpt.go | 42 ++++++ pkg/grpc/llm/transformers/predict.go | 26 ++++ pkg/grpc/llm/transformers/replit.go | 42 ++++++ pkg/grpc/llm/transformers/starcoder.go | 42 ++++++ pkg/model/initializers.go | 56 +------ 19 files changed, 518 insertions(+), 258 deletions(-) create mode 100644 cmd/grpc/dolly/main.go create mode 100644 cmd/grpc/gpt2/main.go create mode 100644 cmd/grpc/gptj/main.go create mode 100644 cmd/grpc/gptneox/main.go create mode 100644 cmd/grpc/mpt/main.go create mode 100644 cmd/grpc/replit/main.go create mode 100644 cmd/grpc/starcoder/main.go delete mode 100644 pkg/grpc/llm/ggml/starcoder.go create mode 100644 pkg/grpc/llm/transformers/dolly.go create mode 100644 pkg/grpc/llm/transformers/gpt2.go create mode 100644 pkg/grpc/llm/transformers/gptj.go create mode 100644 pkg/grpc/llm/transformers/gptneox.go create mode 100644 pkg/grpc/llm/transformers/mpt.go create mode 100644 pkg/grpc/llm/transformers/predict.go create mode 100644 pkg/grpc/llm/transformers/replit.go create mode 100644 pkg/grpc/llm/transformers/starcoder.go diff --git a/Makefile b/Makefile index df7a16e..610cc6f 100644 --- a/Makefile +++ b/Makefile @@ -189,21 +189,6 @@ gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all go-ggml-transformers: git clone --recurse-submodules https://github.com/go-skynet/go-ggml-transformers.cpp go-ggml-transformers cd go-ggml-transformers && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1 - # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml.. - @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/set_numa_thread_affinity/transformers_set_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/set_numa_thread_affinity/transformers_set_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/clear_numa_thread_affinity/transformers_clear_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/clear_numa_thread_affinity/transformers_clear_numa_thread_affinity/g' {} + go-ggml-transformers/libtransformers.a: go-ggml-transformers $(MAKE) -C go-ggml-transformers libtransformers.a @@ -359,4 +344,32 @@ gpt4all-grpc: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-binding CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \ $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/ -grpcs: falcon-grpc llama-grpc gpt4all-grpc \ No newline at end of file +dolly-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/ + +gpt2-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/ + +gptj-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/ + +gptneox-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/ + +mpt-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/ + +replit-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/ + +starcoder-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/ + +grpcs: falcon-grpc llama-grpc gpt4all-grpc dolly-grpc gpt2-grpc gptj-grpc gptneox-grpc mpt-grpc replit-grpc starcoder-grpc \ No newline at end of file diff --git a/api/prediction.go b/api/prediction.go index f24376c..4a9c1c8 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -17,7 +17,6 @@ import ( "github.com/go-skynet/LocalAI/pkg/stablediffusion" "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" - transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 @@ -244,7 +243,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to opts := []model.Option{ model.WithLoadGRPCOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), + model.WithThreads(uint32(c.Threads)), // GPT4all uses this model.WithAssetDir(o.assetsDestination), model.WithModelFile(modelFile), } @@ -279,102 +278,6 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to return response, nil } - case *transformers.GPTNeoX: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Replit: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Starcoder: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.MPT: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } case *bloomz.Bloomz: fn = func() (string, error) { // Generate the prediction using the language model @@ -395,102 +298,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to predictOptions..., ) } - case *transformers.Falcon: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.GPTJ: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Dolly: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.GPT2: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - return model.Predict( - s, - predictOptions..., - ) - } case *grpc.Client: // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported supportStreams = true diff --git a/cmd/grpc/dolly/main.go b/cmd/grpc/dolly/main.go new file mode 100644 index 0000000..43bba92 --- /dev/null +++ b/cmd/grpc/dolly/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Dolly{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/gpt2/main.go b/cmd/grpc/gpt2/main.go new file mode 100644 index 0000000..d9fe275 --- /dev/null +++ b/cmd/grpc/gpt2/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.GPT2{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/gptj/main.go b/cmd/grpc/gptj/main.go new file mode 100644 index 0000000..27d8210 --- /dev/null +++ b/cmd/grpc/gptj/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.GPTJ{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/gptneox/main.go b/cmd/grpc/gptneox/main.go new file mode 100644 index 0000000..3d005ca --- /dev/null +++ b/cmd/grpc/gptneox/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.GPTNeoX{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/mpt/main.go b/cmd/grpc/mpt/main.go new file mode 100644 index 0000000..58456a7 --- /dev/null +++ b/cmd/grpc/mpt/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.MPT{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/replit/main.go b/cmd/grpc/replit/main.go new file mode 100644 index 0000000..aed67fb --- /dev/null +++ b/cmd/grpc/replit/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Replit{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/starcoder/main.go b/cmd/grpc/starcoder/main.go new file mode 100644 index 0000000..2847acf --- /dev/null +++ b/cmd/grpc/starcoder/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Starcoder{}); err != nil { + panic(err) + } +} diff --git a/pkg/grpc/llm/ggml/starcoder.go b/pkg/grpc/llm/ggml/starcoder.go deleted file mode 100644 index e69de29..0000000 diff --git a/pkg/grpc/llm/transformers/dolly.go b/pkg/grpc/llm/transformers/dolly.go new file mode 100644 index 0000000..28a44a7 --- /dev/null +++ b/pkg/grpc/llm/transformers/dolly.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type Dolly struct { + dolly *transformers.Dolly +} + +func (llm *Dolly) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewDolly(opts.Model) + llm.dolly = model + return err +} + +func (llm *Dolly) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) { + return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/gpt2.go b/pkg/grpc/llm/transformers/gpt2.go new file mode 100644 index 0000000..0eaf787 --- /dev/null +++ b/pkg/grpc/llm/transformers/gpt2.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type GPT2 struct { + gpt2 *transformers.GPT2 +} + +func (llm *GPT2) Load(opts *pb.ModelOptions) error { + model, err := transformers.New(opts.Model) + llm.gpt2 = model + return err +} + +func (llm *GPT2) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) { + return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/gptj.go b/pkg/grpc/llm/transformers/gptj.go new file mode 100644 index 0000000..a7138ef --- /dev/null +++ b/pkg/grpc/llm/transformers/gptj.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type GPTJ struct { + gptj *transformers.GPTJ +} + +func (llm *GPTJ) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewGPTJ(opts.Model) + llm.gptj = model + return err +} + +func (llm *GPTJ) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) { + return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/gptneox.go b/pkg/grpc/llm/transformers/gptneox.go new file mode 100644 index 0000000..2edf4ba --- /dev/null +++ b/pkg/grpc/llm/transformers/gptneox.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type GPTNeoX struct { + gptneox *transformers.GPTNeoX +} + +func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewGPTNeoX(opts.Model) + llm.gptneox = model + return err +} + +func (llm *GPTNeoX) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) { + return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/mpt.go b/pkg/grpc/llm/transformers/mpt.go new file mode 100644 index 0000000..ab88418 --- /dev/null +++ b/pkg/grpc/llm/transformers/mpt.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type MPT struct { + mpt *transformers.MPT +} + +func (llm *MPT) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewMPT(opts.Model) + llm.mpt = model + return err +} + +func (llm *MPT) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) { + return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/predict.go b/pkg/grpc/llm/transformers/predict.go new file mode 100644 index 0000000..861d119 --- /dev/null +++ b/pkg/grpc/llm/transformers/predict.go @@ -0,0 +1,26 @@ +package transformers + +import ( + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +func buildPredictOptions(opts *pb.PredictOptions) []transformers.PredictOption { + predictOptions := []transformers.PredictOption{ + transformers.SetTemperature(float64(opts.Temperature)), + transformers.SetTopP(float64(opts.TopP)), + transformers.SetTopK(int(opts.TopK)), + transformers.SetTokens(int(opts.Tokens)), + transformers.SetThreads(int(opts.Threads)), + } + + if opts.Batch != 0 { + predictOptions = append(predictOptions, transformers.SetBatch(int(opts.Batch))) + } + + if opts.Seed != 0 { + predictOptions = append(predictOptions, transformers.SetSeed(int(opts.Seed))) + } + + return predictOptions +} diff --git a/pkg/grpc/llm/transformers/replit.go b/pkg/grpc/llm/transformers/replit.go new file mode 100644 index 0000000..ca1d66f --- /dev/null +++ b/pkg/grpc/llm/transformers/replit.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type Replit struct { + replit *transformers.Replit +} + +func (llm *Replit) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewReplit(opts.Model) + llm.replit = model + return err +} + +func (llm *Replit) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) { + return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/grpc/llm/transformers/starcoder.go b/pkg/grpc/llm/transformers/starcoder.go new file mode 100644 index 0000000..6e1a94b --- /dev/null +++ b/pkg/grpc/llm/transformers/starcoder.go @@ -0,0 +1,42 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type Starcoder struct { + starcoder *transformers.Starcoder +} + +func (llm *Starcoder) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewStarcoder(opts.Model) + llm.starcoder = model + return err +} + +func (llm *Starcoder) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return nil, fmt.Errorf("not implemented") +} + +func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) { + return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) { + go func() { + res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() +} diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 3a0c5ea..44a0638 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -16,7 +16,6 @@ import ( "github.com/go-skynet/LocalAI/pkg/tts" bloomz "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" - transformers "github.com/go-skynet/go-ggml-transformers.cpp" "github.com/hashicorp/go-multierror" "github.com/hpcloud/tail" "github.com/phayes/freeport" @@ -55,7 +54,6 @@ var autoLoadBackends []string = []string{ LlamaBackend, Gpt4All, RwkvBackend, - //GGLLMFalconBackend, WhisperBackend, BertEmbeddingsBackend, GPTNeoXBackend, @@ -69,40 +67,6 @@ var autoLoadBackends []string = []string{ BloomzBackend, } -var starCoder = func(modelFile string) (interface{}, error) { - return transformers.NewStarcoder(modelFile) -} - -var mpt = func(modelFile string) (interface{}, error) { - return transformers.NewMPT(modelFile) -} - -var dolly = func(modelFile string) (interface{}, error) { - return transformers.NewDolly(modelFile) -} - -// func ggllmFalcon(opts ...ggllm.ModelOption) func(string) (interface{}, error) { -// return func(s string) (interface{}, error) { -// return ggllm.New(s, opts...) -// } -// } - -var gptNeoX = func(modelFile string) (interface{}, error) { - return transformers.NewGPTNeoX(modelFile) -} - -var replit = func(modelFile string) (interface{}, error) { - return transformers.NewReplit(modelFile) -} - -var gptJ = func(modelFile string) (interface{}, error) { - return transformers.NewGPTJ(modelFile) -} - -var falcon = func(modelFile string) (interface{}, error) { - return transformers.NewFalcon(modelFile) -} - var bertEmbeddings = func(modelFile string) (interface{}, error) { return bert.New(modelFile) } @@ -111,10 +75,6 @@ var bloomzLM = func(modelFile string) (interface{}, error) { return bloomz.New(modelFile) } -var transformersLM = func(modelFile string) (interface{}, error) { - return transformers.New(modelFile) -} - var stableDiffusion = func(assetDir string) (interface{}, error) { return stablediffusion.New(assetDir) } @@ -261,34 +221,32 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err err log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) switch strings.ToLower(o.backendString) { case LlamaBackend: - // return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...)) return ml.LoadModel(o.modelFile, ml.grpcModel(LlamaBackend, o)) case BloomzBackend: return ml.LoadModel(o.modelFile, bloomzLM) case GPTJBackend: - return ml.LoadModel(o.modelFile, gptJ) + return ml.LoadModel(o.modelFile, ml.grpcModel(GPTJBackend, o)) case DollyBackend: - return ml.LoadModel(o.modelFile, dolly) + return ml.LoadModel(o.modelFile, ml.grpcModel(DollyBackend, o)) case MPTBackend: - return ml.LoadModel(o.modelFile, mpt) + return ml.LoadModel(o.modelFile, ml.grpcModel(MPTBackend, o)) case Gpt2Backend: - return ml.LoadModel(o.modelFile, transformersLM) + return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt2Backend, o)) case FalconBackend: return ml.LoadModel(o.modelFile, ml.grpcModel(FalconBackend, o)) case GPTNeoXBackend: - return ml.LoadModel(o.modelFile, gptNeoX) + return ml.LoadModel(o.modelFile, ml.grpcModel(GPTNeoXBackend, o)) case ReplitBackend: - return ml.LoadModel(o.modelFile, replit) + return ml.LoadModel(o.modelFile, ml.grpcModel(ReplitBackend, o)) case StableDiffusionBackend: return ml.LoadModel(o.modelFile, stableDiffusion) case PiperBackend: return ml.LoadModel(o.modelFile, piperTTS(filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data"))) case StarcoderBackend: - return ml.LoadModel(o.modelFile, starCoder) + return ml.LoadModel(o.modelFile, ml.grpcModel(StarcoderBackend, o)) case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All: o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all") return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt4All, o)) - // return ml.LoadModel(o.modelFile, gpt4allLM(gpt4all.SetThreads(int(o.threads)), gpt4all.SetLibrarySearchPath(filepath.Join(o.assetDir, "backend-assets", "gpt4all")))) case BertEmbeddingsBackend: return ml.LoadModel(o.modelFile, bertEmbeddings) From 5dcfdbe51da5b8c9159a358ab1694c0e4f68f437 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 05/12] feat: various refactorings Signed-off-by: Ettore Di Giacinto --- api/api.go | 108 ++-- api/api_test.go | 12 +- api/backend/embeddings.go | 107 ++++ api/backend/image.go | 56 ++ api/backend/llm.go | 160 ++++++ api/backend/lock.go | 22 + api/backend/options.go | 98 ++++ api/config.go | 401 ------------- api/config/config.go | 209 +++++++ api/{ => config}/config_test.go | 24 +- api/config/prediction.go | 37 ++ api/{ => localai}/gallery.go | 21 +- api/{ => localai}/localai.go | 21 +- api/openai.go | 973 -------------------------------- api/openai/api.go | 105 ++++ api/openai/chat.go | 320 +++++++++++ api/openai/completion.go | 159 ++++++ api/openai/edit.go | 67 +++ api/openai/embeddings.go | 70 +++ api/openai/image.go | 158 ++++++ api/openai/inference.go | 36 ++ api/openai/list.go | 37 ++ api/openai/request.go | 234 ++++++++ api/openai/transcription.go | 91 +++ api/{ => options}/options.go | 84 +-- api/prediction.go | 415 -------------- main.go | 35 +- pkg/grpc/llm/falcon/falcon.go | 3 + 28 files changed, 2130 insertions(+), 1933 deletions(-) create mode 100644 api/backend/embeddings.go create mode 100644 api/backend/image.go create mode 100644 api/backend/llm.go create mode 100644 api/backend/lock.go create mode 100644 api/backend/options.go delete mode 100644 api/config.go create mode 100644 api/config/config.go rename api/{ => config}/config_test.go (62%) create mode 100644 api/config/prediction.go rename api/{ => localai}/gallery.go (86%) rename api/{ => localai}/localai.go (68%) delete mode 100644 api/openai.go create mode 100644 api/openai/api.go create mode 100644 api/openai/chat.go create mode 100644 api/openai/completion.go create mode 100644 api/openai/edit.go create mode 100644 api/openai/embeddings.go create mode 100644 api/openai/image.go create mode 100644 api/openai/inference.go create mode 100644 api/openai/list.go create mode 100644 api/openai/request.go create mode 100644 api/openai/transcription.go rename api/{ => options}/options.go (60%) delete mode 100644 api/prediction.go diff --git a/api/api.go b/api/api.go index 1438f1f..5d4f4c9 100644 --- a/api/api.go +++ b/api/api.go @@ -3,8 +3,13 @@ package api import ( "errors" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/localai" + "github.com/go-skynet/LocalAI/api/openai" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/pkg/assets" + "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/logger" @@ -13,18 +18,18 @@ import ( "github.com/rs/zerolog/log" ) -func App(opts ...AppOption) (*fiber.App, error) { - options := newOptions(opts...) +func App(opts ...options.AppOption) (*fiber.App, error) { + options := options.NewOptions(opts...) zerolog.SetGlobalLevel(zerolog.InfoLevel) - if options.debug { + if options.Debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) } // Return errors as JSON responses app := fiber.New(fiber.Config{ - BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: options.disableMessage, + BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: options.DisableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -38,44 +43,44 @@ func App(opts ...AppOption) (*fiber.App, error) { // Send custom error page return ctx.Status(code).JSON( - ErrorResponse{ - Error: &APIError{Message: err.Error(), Code: code}, + openai.ErrorResponse{ + Error: &openai.APIError{Message: err.Error(), Code: code}, }, ) }, }) - if options.debug { + if options.Debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) } - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.threads, options.loader.ModelPath) + log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - cm := NewConfigMerger() - if err := cm.LoadConfigs(options.loader.ModelPath); err != nil { + cm := config.NewConfigLoader() + if err := cm.LoadConfigs(options.Loader.ModelPath); err != nil { log.Error().Msgf("error loading config files: %s", err.Error()) } - if options.configFile != "" { - if err := cm.LoadConfigFile(options.configFile); err != nil { + if options.ConfigFile != "" { + if err := cm.LoadConfigFile(options.ConfigFile); err != nil { log.Error().Msgf("error loading config file: %s", err.Error()) } } - if options.debug { + if options.Debug { for _, v := range cm.ListConfigs() { cfg, _ := cm.GetConfig(v) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) } } - if options.assetsDestination != "" { + if options.AssetsDestination != "" { // Extract files from the embedded FS - err := assets.ExtractFiles(options.backendAssets, options.assetsDestination) - log.Debug().Msgf("Extracting backend assets files to %s", options.assetsDestination) + err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) + log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) if err != nil { log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) } @@ -84,31 +89,32 @@ func App(opts ...AppOption) (*fiber.App, error) { // Default middleware config app.Use(recover.New()) - if options.preloadJSONModels != "" { - if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil { + if options.PreloadJSONModels != "" { + if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil { return nil, err } } - if options.preloadModelsFromPath != "" { - if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil { + if options.PreloadModelsFromPath != "" { + if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cm, options.Galleries); err != nil { return nil, err } } - if options.cors { - if options.corsAllowOrigins == "" { - app.Use(cors.New()) + if options.CORS { + var c func(ctx *fiber.Ctx) error + if options.CORSAllowOrigins == "" { + c = cors.New() } else { - app.Use(cors.New(cors.Config{ - AllowOrigins: options.corsAllowOrigins, - })) + c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) } + + app.Use(c) } // LocalAI API endpoints - applier := newGalleryApplier(options.loader.ModelPath) - applier.start(options.context, cm) + galleryService := localai.NewGalleryService(options.Loader.ModelPath) + galleryService.Start(options.Context, cm) app.Get("/version", func(c *fiber.Ctx) error { return c.JSON(struct { @@ -116,43 +122,43 @@ func App(opts ...AppOption) (*fiber.App, error) { }{Version: internal.PrintableVersion()}) }) - app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries)) - app.Get("/models/available", listModelFromGallery(options.galleries, options.loader.ModelPath)) - app.Get("/models/jobs/:uuid", getOpStatus(applier)) + app.Post("/models/apply", localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries)) + app.Get("/models/available", localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath)) + app.Get("/models/jobs/:uuid", localai.GetOpStatusEndpoint(galleryService)) // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", chatEndpoint(cm, options)) - app.Post("/chat/completions", chatEndpoint(cm, options)) + app.Post("/v1/chat/completions", openai.ChatEndpoint(cm, options)) + app.Post("/chat/completions", openai.ChatEndpoint(cm, options)) // edit - app.Post("/v1/edits", editEndpoint(cm, options)) - app.Post("/edits", editEndpoint(cm, options)) + app.Post("/v1/edits", openai.EditEndpoint(cm, options)) + app.Post("/edits", openai.EditEndpoint(cm, options)) // completion - app.Post("/v1/completions", completionEndpoint(cm, options)) - app.Post("/completions", completionEndpoint(cm, options)) - app.Post("/v1/engines/:model/completions", completionEndpoint(cm, options)) + app.Post("/v1/completions", openai.CompletionEndpoint(cm, options)) + app.Post("/completions", openai.CompletionEndpoint(cm, options)) + app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cm, options)) // embeddings - app.Post("/v1/embeddings", embeddingsEndpoint(cm, options)) - app.Post("/embeddings", embeddingsEndpoint(cm, options)) - app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options)) + app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cm, options)) + app.Post("/embeddings", openai.EmbeddingsEndpoint(cm, options)) + app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cm, options)) // audio - app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options)) - app.Post("/tts", ttsEndpoint(cm, options)) + app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cm, options)) + app.Post("/tts", localai.TTSEndpoint(cm, options)) // images - app.Post("/v1/images/generations", imageEndpoint(cm, options)) + app.Post("/v1/images/generations", openai.ImageEndpoint(cm, options)) - if options.imageDir != "" { - app.Static("/generated-images", options.imageDir) + if options.ImageDir != "" { + app.Static("/generated-images", options.ImageDir) } - if options.audioDir != "" { - app.Static("/generated-audio", options.audioDir) + if options.AudioDir != "" { + app.Static("/generated-audio", options.AudioDir) } ok := func(c *fiber.Ctx) error { @@ -164,8 +170,8 @@ func App(opts ...AppOption) (*fiber.App, error) { app.Get("/readyz", ok) // models - app.Get("/v1/models", listModels(options.loader, cm)) - app.Get("/models", listModels(options.loader, cm)) + app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm)) + app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm)) return app, nil } diff --git a/api/api_test.go b/api/api_test.go index 43aa30b..a69e60d 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -13,6 +13,7 @@ import ( "runtime" . "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" @@ -154,9 +155,10 @@ var _ = Describe("API test", func() { }, } - app, err = App(WithContext(c), - WithGalleries(galleries), - WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir)) + app, err = App( + options.WithContext(c), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -342,7 +344,7 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(WithContext(c), WithModelLoader(modelLoader)) + app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -462,7 +464,7 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE"))) + app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader), options.WithConfigFile(os.Getenv("CONFIG_FILE"))) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go new file mode 100644 index 0000000..cb77b6f --- /dev/null +++ b/api/backend/embeddings.go @@ -0,0 +1,107 @@ +package backend + +import ( + "context" + "fmt" + "sync" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc" + model "github.com/go-skynet/LocalAI/pkg/model" + bert "github.com/go-skynet/go-bert.cpp" +) + +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { + if !c.Embeddings { + return nil, fmt.Errorf("endpoint disabled for this model by API configuration") + } + + modelFile := c.Model + + grpcOpts := gRPCModelOpts(c) + + var inferenceModel interface{} + var err error + + opts := []model.Option{ + model.WithLoadGRPCOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.AssetsDestination), + model.WithModelFile(modelFile), + } + + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) + } else { + opts = append(opts, model.WithBackendString(c.Backend)) + inferenceModel, err = loader.BackendLoader(opts...) + } + if err != nil { + return nil, err + } + + var fn func() ([]float32, error) + switch model := inferenceModel.(type) { + case *grpc.Client: + fn = func() ([]float32, error) { + predictOptions := gRPCPredictOpts(c, loader.ModelPath) + if len(tokens) > 0 { + embeds := []int32{} + + for _, t := range tokens { + embeds = append(embeds, int32(t)) + } + predictOptions.EmbeddingTokens = embeds + + res, err := model.Embeddings(context.TODO(), predictOptions) + if err != nil { + return nil, err + } + + return res.Embeddings, nil + } + predictOptions.Embeddings = s + + res, err := model.Embeddings(context.TODO(), predictOptions) + if err != nil { + return nil, err + } + + return res.Embeddings, nil + } + + // bert embeddings + case *bert.Bert: + fn = func() ([]float32, error) { + if len(tokens) > 0 { + return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) + } + return model.Embeddings(s, bert.SetThreads(c.Threads)) + } + default: + fn = func() ([]float32, error) { + return nil, fmt.Errorf("embeddings not supported by the backend") + } + } + + return func() ([]float32, error) { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + l := Lock(modelFile) + defer l.Unlock() + + embeds, err := fn() + if err != nil { + return embeds, err + } + // Remove trailing 0s + for i := len(embeds) - 1; i >= 0; i-- { + if embeds[i] == 0.0 { + embeds = embeds[:i] + } else { + break + } + } + return embeds, nil + }, nil +} diff --git a/api/backend/image.go b/api/backend/image.go new file mode 100644 index 0000000..47ae842 --- /dev/null +++ b/api/backend/image.go @@ -0,0 +1,56 @@ +package backend + +import ( + "fmt" + "sync" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/stablediffusion" +) + +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { + if c.Backend != model.StableDiffusionBackend { + return nil, fmt.Errorf("endpoint only working with stablediffusion models") + } + + inferenceModel, err := loader.BackendLoader( + model.WithBackendString(c.Backend), + model.WithAssetDir(o.AssetsDestination), + model.WithThreads(uint32(c.Threads)), + model.WithModelFile(c.ImageGenerationAssets), + ) + if err != nil { + return nil, err + } + + var fn func() error + switch model := inferenceModel.(type) { + case *stablediffusion.StableDiffusion: + fn = func() error { + return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) + } + + default: + fn = func() error { + return fmt.Errorf("creation of images not supported by the backend") + } + } + + return func() error { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[c.Backend] + if !ok { + m := &sync.Mutex{} + mutexes[c.Backend] = m + l = m + } + mutexMap.Unlock() + l.Lock() + defer l.Unlock() + + return fn() + }, nil +} diff --git a/api/backend/llm.go b/api/backend/llm.go new file mode 100644 index 0000000..d2f8ef6 --- /dev/null +++ b/api/backend/llm.go @@ -0,0 +1,160 @@ +package backend + +import ( + "context" + "regexp" + "strings" + "sync" + + "github.com/donomii/go-rwkv.cpp" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc" + "github.com/go-skynet/LocalAI/pkg/langchain" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/bloomz.cpp" +) + +func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { + supportStreams := false + modelFile := c.Model + + grpcOpts := gRPCModelOpts(c) + + var inferenceModel interface{} + var err error + + opts := []model.Option{ + model.WithLoadGRPCOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), // GPT4all uses this + model.WithAssetDir(o.AssetsDestination), + model.WithModelFile(modelFile), + } + + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) + } else { + opts = append(opts, model.WithBackendString(c.Backend)) + inferenceModel, err = loader.BackendLoader(opts...) + } + if err != nil { + return nil, err + } + + var fn func() (string, error) + + switch model := inferenceModel.(type) { + case *rwkv.RwkvState: + supportStreams = true + + fn = func() (string, error) { + stopWord := "\n" + if len(c.StopWords) > 0 { + stopWord = c.StopWords[0] + } + + if err := model.ProcessInput(s); err != nil { + return "", err + } + + response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) + + return response, nil + } + case *bloomz.Bloomz: + fn = func() (string, error) { + // Generate the prediction using the language model + predictOptions := []bloomz.PredictOption{ + bloomz.SetTemperature(c.Temperature), + bloomz.SetTopP(c.TopP), + bloomz.SetTopK(c.TopK), + bloomz.SetTokens(c.Maxtokens), + bloomz.SetThreads(c.Threads), + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) + } + + return model.Predict( + s, + predictOptions..., + ) + } + + case *grpc.Client: + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + supportStreams = true + fn = func() (string, error) { + + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + if tokenCallback != nil { + ss := "" + err := model.PredictStream(context.TODO(), opts, func(s string) { + tokenCallback(s) + ss += s + }) + return ss, err + } else { + reply, err := model.Predict(context.TODO(), opts) + return reply.Message, err + } + } + case *langchain.HuggingFace: + fn = func() (string, error) { + + // Generate the prediction using the language model + predictOptions := []langchain.PredictOption{ + langchain.SetModel(c.Model), + langchain.SetMaxTokens(c.Maxtokens), + langchain.SetTemperature(c.Temperature), + langchain.SetStopWords(c.StopWords), + } + + pred, er := model.PredictHuggingFace(s, predictOptions...) + if er != nil { + return "", er + } + return pred.Completion, nil + } + } + + return func() (string, error) { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + l := Lock(modelFile) + defer l.Unlock() + + res, err := fn() + if tokenCallback != nil && !supportStreams { + tokenCallback(res) + } + return res, err + }, nil +} + +var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) +var mu sync.Mutex = sync.Mutex{} + +func Finetune(config config.Config, input, prediction string) string { + if config.Echo { + prediction = input + prediction + } + + for _, c := range config.Cutstrings { + mu.Lock() + reg, ok := cutstrings[c] + if !ok { + cutstrings[c] = regexp.MustCompile(c) + reg = cutstrings[c] + } + mu.Unlock() + prediction = reg.ReplaceAllString(prediction, "") + } + + for _, c := range config.TrimSpace { + prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) + } + return prediction + +} diff --git a/api/backend/lock.go b/api/backend/lock.go new file mode 100644 index 0000000..6b4f577 --- /dev/null +++ b/api/backend/lock.go @@ -0,0 +1,22 @@ +package backend + +import "sync" + +// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 +var mutexMap sync.Mutex +var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) + +func Lock(s string) *sync.Mutex { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[s] + if !ok { + m := &sync.Mutex{} + mutexes[s] = m + l = m + } + mutexMap.Unlock() + l.Lock() + + return l +} diff --git a/api/backend/options.go b/api/backend/options.go new file mode 100644 index 0000000..f19dbae --- /dev/null +++ b/api/backend/options.go @@ -0,0 +1,98 @@ +package backend + +import ( + "os" + "path/filepath" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/pkg/langchain" + "github.com/go-skynet/bloomz.cpp" +) + +func langchainOptions(c config.Config) []langchain.PredictOption { + return []langchain.PredictOption{ + langchain.SetModel(c.Model), + langchain.SetMaxTokens(c.Maxtokens), + langchain.SetTemperature(c.Temperature), + langchain.SetStopWords(c.StopWords), + } +} + +func bloomzOptions(c config.Config) []bloomz.PredictOption { + // Generate the prediction using the language model + predictOptions := []bloomz.PredictOption{ + bloomz.SetTemperature(c.Temperature), + bloomz.SetTopP(c.TopP), + bloomz.SetTopK(c.TopK), + bloomz.SetTokens(c.Maxtokens), + bloomz.SetThreads(c.Threads), + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) + } + return predictOptions +} +func gRPCModelOpts(c config.Config) *pb.ModelOptions { + b := 512 + if c.Batch != 0 { + b = c.Batch + } + return &pb.ModelOptions{ + ContextSize: int32(c.ContextSize), + Seed: int32(c.Seed), + NBatch: int32(b), + F16Memory: c.F16, + MLock: c.MMlock, + NUMA: c.NUMA, + Embeddings: c.Embeddings, + LowVRAM: c.LowVRAM, + NGPULayers: int32(c.NGPULayers), + MMap: c.MMap, + MainGPU: c.MainGPU, + Threads: int32(c.Threads), + TensorSplit: c.TensorSplit, + } +} + +func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { + promptCachePath := "" + if c.PromptCachePath != "" { + p := filepath.Join(modelPath, c.PromptCachePath) + os.MkdirAll(filepath.Dir(p), 0755) + promptCachePath = p + } + return &pb.PredictOptions{ + Temperature: float32(c.Temperature), + TopP: float32(c.TopP), + TopK: int32(c.TopK), + Tokens: int32(c.Maxtokens), + Threads: int32(c.Threads), + PromptCacheAll: c.PromptCacheAll, + PromptCacheRO: c.PromptCacheRO, + PromptCachePath: promptCachePath, + F16KV: c.F16, + DebugMode: c.Debug, + Grammar: c.Grammar, + + Mirostat: int32(c.Mirostat), + MirostatETA: float32(c.MirostatETA), + MirostatTAU: float32(c.MirostatTAU), + Debug: c.Debug, + StopPrompts: c.StopWords, + Repeat: int32(c.RepeatPenalty), + NKeep: int32(c.Keep), + Batch: int32(c.Batch), + IgnoreEOS: c.IgnoreEOS, + Seed: int32(c.Seed), + FrequencyPenalty: float32(c.FrequencyPenalty), + MLock: c.MMlock, + MMap: c.MMap, + MainGPU: c.MainGPU, + TensorSplit: c.TensorSplit, + TailFreeSamplingZ: float32(c.TFZ), + TypicalP: float32(c.TypicalP), + } +} diff --git a/api/config.go b/api/config.go deleted file mode 100644 index 57fe0d1..0000000 --- a/api/config.go +++ /dev/null @@ -1,401 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "io/fs" - "os" - "path/filepath" - "strings" - "sync" - - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - "gopkg.in/yaml.v3" -) - -type Config struct { - OpenAIRequest `yaml:"parameters"` - Name string `yaml:"name"` - StopWords []string `yaml:"stopwords"` - Cutstrings []string `yaml:"cutstrings"` - TrimSpace []string `yaml:"trimspace"` - ContextSize int `yaml:"context_size"` - F16 bool `yaml:"f16"` - NUMA bool `yaml:"numa"` - Threads int `yaml:"threads"` - Debug bool `yaml:"debug"` - Roles map[string]string `yaml:"roles"` - Embeddings bool `yaml:"embeddings"` - Backend string `yaml:"backend"` - TemplateConfig TemplateConfig `yaml:"template"` - MirostatETA float64 `yaml:"mirostat_eta"` - MirostatTAU float64 `yaml:"mirostat_tau"` - Mirostat int `yaml:"mirostat"` - NGPULayers int `yaml:"gpu_layers"` - MMap bool `yaml:"mmap"` - MMlock bool `yaml:"mmlock"` - LowVRAM bool `yaml:"low_vram"` - - TensorSplit string `yaml:"tensor_split"` - MainGPU string `yaml:"main_gpu"` - ImageGenerationAssets string `yaml:"asset_dir"` - - PromptCachePath string `yaml:"prompt_cache_path"` - PromptCacheAll bool `yaml:"prompt_cache_all"` - PromptCacheRO bool `yaml:"prompt_cache_ro"` - - Grammar string `yaml:"grammar"` - - FunctionsConfig Functions `yaml:"function"` - - PromptStrings, InputStrings []string - InputToken [][]int - functionCallString, functionCallNameString string -} - -type Functions struct { - DisableNoAction bool `yaml:"disable_no_action"` - NoActionFunctionName string `yaml:"no_action_function_name"` - NoActionDescriptionName string `yaml:"no_action_description_name"` -} - -type TemplateConfig struct { - Completion string `yaml:"completion"` - Functions string `yaml:"function"` - Chat string `yaml:"chat"` - Edit string `yaml:"edit"` -} - -type ConfigMerger struct { - configs map[string]Config - sync.Mutex -} - -func defaultConfig(modelFile string) *Config { - return &Config{ - OpenAIRequest: defaultRequest(modelFile), - } -} - -func NewConfigMerger() *ConfigMerger { - return &ConfigMerger{ - configs: make(map[string]Config), - } -} -func ReadConfigFile(file string) ([]*Config, error) { - c := &[]*Config{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - return *c, nil -} - -func ReadConfig(file string) (*Config, error) { - c := &Config{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - return c, nil -} - -func (cm *ConfigMerger) LoadConfigFile(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadConfigFile(file) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) - } - - for _, cc := range c { - cm.configs[cc.Name] = *cc - } - return nil -} - -func (cm *ConfigMerger) LoadConfig(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadConfig(file) - if err != nil { - return fmt.Errorf("cannot read config file: %w", err) - } - - cm.configs[c.Name] = *c - return nil -} - -func (cm *ConfigMerger) GetConfig(m string) (Config, bool) { - cm.Lock() - defer cm.Unlock() - v, exists := cm.configs[m] - return v, exists -} - -func (cm *ConfigMerger) ListConfigs() []string { - cm.Lock() - defer cm.Unlock() - var res []string - for k := range cm.configs { - res = append(res, k) - } - return res -} - -func (cm *ConfigMerger) LoadConfigs(path string) error { - cm.Lock() - defer cm.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err - } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err - } - files = append(files, info) - } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") { - continue - } - c, err := ReadConfig(filepath.Join(path, file.Name())) - if err == nil { - cm.configs[c.Name] = *c - } - } - - return nil -} - -func updateConfig(config *Config, input *OpenAIRequest) { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != 0 { - config.TopK = input.TopK - } - if input.TopP != 0 { - config.TopP = input.TopP - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != 0 { - config.Temperature = input.Temperature - } - - if input.Maxtokens != 0 { - config.Maxtokens = input.Maxtokens - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.F16 { - config.F16 = input.F16 - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != 0 { - config.Seed = input.Seed - } - - if input.Mirostat != 0 { - config.Mirostat = input.Mirostat - } - - if input.MirostatETA != 0 { - config.MirostatETA = input.MirostatETA - } - - if input.MirostatTAU != 0 { - config.MirostatTAU = input.MirostatTAU - } - - if input.TypicalP != 0 { - config.TypicalP = input.TypicalP - } - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) - } - config.InputToken = append(config.InputToken, tokens) - } - } - } - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.functionCallString = fnc - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.functionCallNameString = name - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } -} -func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { - input := new(OpenAIRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, err - } - - modelFile := input.Model - - if c.Params("model") != "" { - modelFile = c.Params("model") - } - - received, _ := json.Marshal(input) - - log.Debug().Msgf("Request received: %s", string(received)) - - // Set model from bearer token, if available - bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) - - // If no model was specified, take the first available - if modelFile == "" && !bearerExists && randomModel { - models, _ := loader.ListModels() - if len(models) > 0 { - modelFile = models[0] - log.Debug().Msgf("No model specified, using: %s", modelFile) - } else { - log.Debug().Msgf("No model specified, returning error") - return "", nil, fmt.Errorf("no model specified") - } - } - - // If a model is found in bearer token takes precedence - if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) - modelFile = bearer - } - return modelFile, input, nil -} - -func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { - // Load a config file if present after the model name - modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") - - var config *Config - - defaults := func() { - config = defaultConfig(modelFile) - config.ContextSize = ctx - config.Threads = threads - config.F16 = f16 - config.Debug = debug - } - - cfg, exists := cm.GetConfig(modelFile) - if !exists { - if _, err := os.Stat(modelConfig); err == nil { - if err := cm.LoadConfig(modelConfig); err != nil { - return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfg, exists = cm.GetConfig(modelFile) - if exists { - config = &cfg - } else { - defaults() - } - } else { - defaults() - } - } else { - config = &cfg - } - - // Set the parameters for the language model prediction - updateConfig(config, input) - - // Don't allow 0 as setting - if config.Threads == 0 { - if threads != 0 { - config.Threads = threads - } else { - config.Threads = 4 - } - } - - // Enforce debug flag if passed from CLI - if debug { - config.Debug = true - } - - return config, input, nil -} diff --git a/api/config/config.go b/api/config/config.go new file mode 100644 index 0000000..9df8d3e --- /dev/null +++ b/api/config/config.go @@ -0,0 +1,209 @@ +package api_config + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + + "gopkg.in/yaml.v3" +) + +type Config struct { + PredictionOptions `yaml:"parameters"` + Name string `yaml:"name"` + StopWords []string `yaml:"stopwords"` + Cutstrings []string `yaml:"cutstrings"` + TrimSpace []string `yaml:"trimspace"` + ContextSize int `yaml:"context_size"` + F16 bool `yaml:"f16"` + NUMA bool `yaml:"numa"` + Threads int `yaml:"threads"` + Debug bool `yaml:"debug"` + Roles map[string]string `yaml:"roles"` + Embeddings bool `yaml:"embeddings"` + Backend string `yaml:"backend"` + TemplateConfig TemplateConfig `yaml:"template"` + MirostatETA float64 `yaml:"mirostat_eta"` + MirostatTAU float64 `yaml:"mirostat_tau"` + Mirostat int `yaml:"mirostat"` + NGPULayers int `yaml:"gpu_layers"` + MMap bool `yaml:"mmap"` + MMlock bool `yaml:"mmlock"` + LowVRAM bool `yaml:"low_vram"` + + TensorSplit string `yaml:"tensor_split"` + MainGPU string `yaml:"main_gpu"` + ImageGenerationAssets string `yaml:"asset_dir"` + + PromptCachePath string `yaml:"prompt_cache_path"` + PromptCacheAll bool `yaml:"prompt_cache_all"` + PromptCacheRO bool `yaml:"prompt_cache_ro"` + + Grammar string `yaml:"grammar"` + + PromptStrings, InputStrings []string + InputToken [][]int + functionCallString, functionCallNameString string + + FunctionsConfig Functions `yaml:"function"` +} + +type Functions struct { + DisableNoAction bool `yaml:"disable_no_action"` + NoActionFunctionName string `yaml:"no_action_function_name"` + NoActionDescriptionName string `yaml:"no_action_description_name"` +} + +type TemplateConfig struct { + Completion string `yaml:"completion"` + Functions string `yaml:"function"` + Chat string `yaml:"chat"` + Edit string `yaml:"edit"` +} + +type ConfigLoader struct { + configs map[string]Config + sync.Mutex +} + +func (c *Config) SetFunctionCallString(s string) { + c.functionCallString = s +} + +func (c *Config) SetFunctionCallNameString(s string) { + c.functionCallNameString = s +} + +func (c *Config) ShouldUseFunctions() bool { + return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) +} + +func (c *Config) ShouldCallSpecificFunction() bool { + return len(c.functionCallNameString) > 0 +} + +func (c *Config) FunctionToCall() string { + return c.functionCallNameString +} + +func defaultPredictOptions(modelFile string) PredictionOptions { + return PredictionOptions{ + TopP: 0.7, + TopK: 80, + Maxtokens: 512, + Temperature: 0.9, + Model: modelFile, + } +} + +func DefaultConfig(modelFile string) *Config { + return &Config{ + PredictionOptions: defaultPredictOptions(modelFile), + } +} + +func NewConfigLoader() *ConfigLoader { + return &ConfigLoader{ + configs: make(map[string]Config), + } +} +func ReadConfigFile(file string) ([]*Config, error) { + c := &[]*Config{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + return *c, nil +} + +func ReadConfig(file string) (*Config, error) { + c := &Config{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + return c, nil +} + +func (cm *ConfigLoader) LoadConfigFile(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadConfigFile(file) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + cm.configs[cc.Name] = *cc + } + return nil +} + +func (cm *ConfigLoader) LoadConfig(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadConfig(file) + if err != nil { + return fmt.Errorf("cannot read config file: %w", err) + } + + cm.configs[c.Name] = *c + return nil +} + +func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { + cm.Lock() + defer cm.Unlock() + v, exists := cm.configs[m] + return v, exists +} + +func (cm *ConfigLoader) ListConfigs() []string { + cm.Lock() + defer cm.Unlock() + var res []string + for k := range cm.configs { + res = append(res, k) + } + return res +} + +func (cm *ConfigLoader) LoadConfigs(path string) error { + cm.Lock() + defer cm.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, info) + } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") { + continue + } + c, err := ReadConfig(filepath.Join(path, file.Name())) + if err == nil { + cm.configs[c.Name] = *c + } + } + + return nil +} diff --git a/api/config_test.go b/api/config/config_test.go similarity index 62% rename from api/config_test.go rename to api/config/config_test.go index 626b90b..4b00d58 100644 --- a/api/config_test.go +++ b/api/config/config_test.go @@ -1,8 +1,10 @@ -package api +package api_config_test import ( "os" + . "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -26,29 +28,29 @@ var _ = Describe("Test cases for config related functions", func() { }) It("Test LoadConfigs", func() { - cm := NewConfigMerger() - options := newOptions() + cm := NewConfigLoader() + opts := options.NewOptions() modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH")) - WithModelLoader(modelLoader)(options) + options.WithModelLoader(modelLoader)(opts) - err := cm.LoadConfigs(options.loader.ModelPath) + err := cm.LoadConfigs(opts.Loader.ModelPath) Expect(err).To(BeNil()) - Expect(cm.configs).ToNot(BeNil()) + Expect(cm.ListConfigs()).ToNot(BeNil()) // config should includes gpt4all models's api.config - Expect(cm.configs).To(HaveKey("gpt4all")) + Expect(cm.ListConfigs()).To(ContainElements("gpt4all")) // config should includes gpt2 models's api.config - Expect(cm.configs).To(HaveKey("gpt4all-2")) + Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2")) // config should includes text-embedding-ada-002 models's api.config - Expect(cm.configs).To(HaveKey("text-embedding-ada-002")) + Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002")) // config should includes rwkv_test models's api.config - Expect(cm.configs).To(HaveKey("rwkv_test")) + Expect(cm.ListConfigs()).To(ContainElements("rwkv_test")) // config should includes whisper-1 models's api.config - Expect(cm.configs).To(HaveKey("whisper-1")) + Expect(cm.ListConfigs()).To(ContainElements("whisper-1")) }) }) }) diff --git a/api/config/prediction.go b/api/config/prediction.go new file mode 100644 index 0000000..59f4fcb --- /dev/null +++ b/api/config/prediction.go @@ -0,0 +1,37 @@ +package api_config + +type PredictionOptions struct { + + // Also part of the OpenAI official spec + Model string `json:"model" yaml:"model"` + + // Also part of the OpenAI official spec + Language string `json:"language"` + + // Also part of the OpenAI official spec. use it for returning multiple results + N int `json:"n"` + + // Common options between all the API calls, part of the OpenAI spec + TopP float64 `json:"top_p" yaml:"top_p"` + TopK int `json:"top_k" yaml:"top_k"` + Temperature float64 `json:"temperature" yaml:"temperature"` + Maxtokens int `json:"max_tokens" yaml:"max_tokens"` + Echo bool `json:"echo"` + + // Custom parameters - not present in the OpenAI API + Batch int `json:"batch" yaml:"batch"` + F16 bool `json:"f16" yaml:"f16"` + IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` + RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` + Keep int `json:"n_keep" yaml:"n_keep"` + + MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` + MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` + Mirostat int `json:"mirostat" yaml:"mirostat"` + + FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` + TFZ float64 `json:"tfz" yaml:"tfz"` + + TypicalP float64 `json:"typical_p" yaml:"typical_p"` + Seed int `json:"seed" yaml:"seed"` +} diff --git a/api/gallery.go b/api/localai/gallery.go similarity index 86% rename from api/gallery.go rename to api/localai/gallery.go index 1c0cec9..feae294 100644 --- a/api/gallery.go +++ b/api/localai/gallery.go @@ -1,4 +1,4 @@ -package api +package localai import ( "context" @@ -9,6 +9,7 @@ import ( json "github.com/json-iterator/go" + config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/gofiber/fiber/v2" "github.com/google/uuid" @@ -38,7 +39,7 @@ type galleryApplier struct { statuses map[string]*galleryOpStatus } -func newGalleryApplier(modelPath string) *galleryApplier { +func NewGalleryService(modelPath string) *galleryApplier { return &galleryApplier{ modelPath: modelPath, C: make(chan galleryOp), @@ -47,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier { } // prepareModel applies a -func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) error { +func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error { config, err := gallery.GetGalleryConfigFromURL(req.URL) if err != nil { @@ -72,7 +73,7 @@ func (g *galleryApplier) getStatus(s string) *galleryOpStatus { return g.statuses[s] } -func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { +func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { go func() { for { select { @@ -148,7 +149,7 @@ type galleryModel struct { ID string `json:"id"` } -func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error { +func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { dat, err := os.ReadFile(s) if err != nil { return err @@ -156,7 +157,7 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gal return ApplyGalleryFromString(modelPath, string(dat), cm, galleries) } -func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error { +func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { var requests []galleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { @@ -174,7 +175,9 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []g return err } -func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { +/// Endpoints + +func GetOpStatusEndpoint(g *galleryApplier) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { status := g.getStatus(c.Params("uuid")) @@ -191,7 +194,7 @@ type GalleryModel struct { gallery.GalleryModel } -func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error { +func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(GalleryModel) // Get input data from the request body @@ -216,7 +219,7 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal } } -func listModelFromGallery(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error { +func ListModelFromGalleryEndpoint(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { log.Debug().Msgf("Listing models from galleries: %+v", galleries) diff --git a/api/localai.go b/api/localai/localai.go similarity index 68% rename from api/localai.go rename to api/localai/localai.go index 66eda5a..f79e889 100644 --- a/api/localai.go +++ b/api/localai/localai.go @@ -1,10 +1,13 @@ -package api +package localai import ( "fmt" "os" "path/filepath" + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/api/options" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/tts" "github.com/go-skynet/LocalAI/pkg/utils" @@ -32,7 +35,7 @@ func generateUniqueFileName(dir, baseName, ext string) string { } } -func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { +func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(TTSRequest) @@ -41,10 +44,10 @@ func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return err } - piperModel, err := o.loader.BackendLoader( + piperModel, err := o.Loader.BackendLoader( model.WithBackendString(model.PiperBackend), model.WithModelFile(input.Model), - model.WithAssetDir(o.assetsDestination)) + model.WithAssetDir(o.AssetsDestination)) if err != nil { return err } @@ -58,16 +61,16 @@ func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return fmt.Errorf("loader returned non-piper object %+v", w) } - if err := os.MkdirAll(o.audioDir, 0755); err != nil { + if err := os.MkdirAll(o.AudioDir, 0755); err != nil { return err } - fileName := generateUniqueFileName(o.audioDir, "piper", ".wav") - filePath := filepath.Join(o.audioDir, fileName) + fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") + filePath := filepath.Join(o.AudioDir, fileName) - modelPath := filepath.Join(o.loader.ModelPath, input.Model) + modelPath := filepath.Join(o.Loader.ModelPath, input.Model) - if err := utils.VerifyPath(modelPath, o.loader.ModelPath); err != nil { + if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { return err } diff --git a/api/openai.go b/api/openai.go deleted file mode 100644 index c39b1cc..0000000 --- a/api/openai.go +++ /dev/null @@ -1,973 +0,0 @@ -package api - -import ( - "bufio" - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - - "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - "github.com/go-skynet/LocalAI/pkg/grammar" - model "github.com/go-skynet/LocalAI/pkg/model" - whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" -) - -// APIError provides error information returned by the OpenAI API. -type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` -} - -type ErrorResponse struct { - Error *APIError `json:"error,omitempty"` -} - -type OpenAIUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Item struct { - Embedding []float32 `json:"embedding"` - Index int `json:"index"` - Object string `json:"object,omitempty"` - - // Images - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` -} - -type OpenAIResponse struct { - Created int `json:"created,omitempty"` - Object string `json:"object,omitempty"` - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Choices []Choice `json:"choices,omitempty"` - Data []Item `json:"data,omitempty"` - - Usage OpenAIUsage `json:"usage"` -} - -type Choice struct { - Index int `json:"index,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - Message *Message `json:"message,omitempty"` - Delta *Message `json:"delta,omitempty"` - Text string `json:"text,omitempty"` -} - -type Message struct { - // The message role - Role string `json:"role,omitempty" yaml:"role"` - // The message content - Content *string `json:"content" yaml:"content"` - // A result of a function call - FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` -} - -type OpenAIModel struct { - ID string `json:"id"` - Object string `json:"object"` -} - -type OpenAIRequest struct { - Model string `json:"model" yaml:"model"` - - // whisper - File string `json:"file" validate:"required"` - Language string `json:"language"` - //whisper/image - ResponseFormat string `json:"response_format"` - // image - Size string `json:"size"` - // Prompt is read only by completion/image API calls - Prompt interface{} `json:"prompt" yaml:"prompt"` - - // Edit endpoint - Instruction string `json:"instruction" yaml:"instruction"` - Input interface{} `json:"input" yaml:"input"` - - Stop interface{} `json:"stop" yaml:"stop"` - - // Messages is read only by chat/completion API calls - Messages []Message `json:"messages" yaml:"messages"` - - // A list of available functions to call - Functions []grammar.Function `json:"functions" yaml:"functions"` - FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object - - Stream bool `json:"stream"` - Echo bool `json:"echo"` - // Common options between all the API calls - TopP float64 `json:"top_p" yaml:"top_p"` - TopK int `json:"top_k" yaml:"top_k"` - Temperature float64 `json:"temperature" yaml:"temperature"` - Maxtokens int `json:"max_tokens" yaml:"max_tokens"` - - N int `json:"n"` - - // Custom parameters - not present in the OpenAI API - Batch int `json:"batch" yaml:"batch"` - F16 bool `json:"f16" yaml:"f16"` - IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` - RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` - Keep int `json:"n_keep" yaml:"n_keep"` - - MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` - MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` - Mirostat int `json:"mirostat" yaml:"mirostat"` - - FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` - TFZ float64 `json:"tfz" yaml:"tfz"` - - Seed int `json:"seed" yaml:"seed"` - - // Image (not supported by OpenAI) - Mode int `json:"mode"` - Step int `json:"step"` - - // A grammar to constrain the LLM output - Grammar string `json:"grammar" yaml:"grammar"` - // A grammar object - JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` - - TypicalP float64 `json:"typical_p" yaml:"typical_p"` -} - -func defaultRequest(modelFile string) OpenAIRequest { - return OpenAIRequest{ - TopP: 0.7, - TopK: 80, - Maxtokens: 512, - Temperature: 0.9, - Model: modelFile, - } -} - -// https://platform.openai.com/docs/api-reference/completions -func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { - resp := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - Index: 0, - Text: s, - }, - }, - Object: "text_completion", - } - log.Debug().Msgf("Sending goroutine: %s", s) - - responses <- resp - return true - }) - close(responses) - } - - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("`input`: %+v", input) - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - if input.Stream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - //c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } - - templateFile := config.Model - - if config.TemplateConfig.Completion != "" { - templateFile = config.TemplateConfig.Completion - } - - if input.Stream { - if len(config.PromptStrings) > 1 { - return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") - } - - predInput := config.PromptStrings[0] - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - }{ - Input: predInput, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } - - responses := make(chan OpenAIResponse) - - go process(predInput, input, config, o.loader, responses) - - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - - log.Debug().Msgf("Sending chunk: %s", buf.String()) - fmt.Fprintf(w, "data: %v\n", buf.String()) - w.Flush() - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - Index: 0, - FinishReason: "stop", - }, - }, - Object: "text_completion", - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - var result []Choice - for _, i := range config.PromptStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - }{ - Input: i, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - - r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) - }, nil) - if err != nil { - return err - } - - result = append(result, r...) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "text_completion", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/embeddings -func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - items := []Item{} - - for i, s := range config.InputToken { - // get the model function to call for the result - embedFn, err := ModelEmbedding("", s, o.loader, *config, o) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - for i, s := range config.InputStrings { - // get the model function to call for the result - embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Data: items, - Object: "list", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -func isEOS(s string) bool { - if s == "<|endoftext|>" { - return true - } - - return false -} -func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - - process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - initialMessage := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { - resp := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, - Object: "chat.completion.chunk", - } - log.Debug().Msgf("Sending goroutine: %s", s) - - if s != "" && !isEOS(s) { - responses <- resp - } - return true - }) - close(responses) - } - return func(c *fiber.Ctx) error { - processFunctions := false - funcs := grammar.Functions{} - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - log.Debug().Msgf("Configuration read: %+v", config) - - // Allow the user to set custom actions via config file - // to be "embedded" in each model - noActionName := "answer" - noActionDescription := "use this action to answer without performing any action" - - if config.FunctionsConfig.NoActionFunctionName != "" { - noActionName = config.FunctionsConfig.NoActionFunctionName - } - if config.FunctionsConfig.NoActionDescriptionName != "" { - noActionDescription = config.FunctionsConfig.NoActionDescriptionName - } - - // process functions if we have any defined or if we have a function call string - if len(input.Functions) > 0 && - ((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) { - log.Debug().Msgf("Response needs to process functions") - - processFunctions = true - - noActionGrammar := grammar.Function{ - Name: noActionName, - Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "The message to reply the user with", - }}, - }, - } - - // Append the no action function - funcs = append(funcs, input.Functions...) - if !config.FunctionsConfig.DisableNoAction { - funcs = append(funcs, noActionGrammar) - } - - // Force picking one of the functions by the request - if config.functionCallNameString != "" { - funcs = funcs.Select(config.functionCallNameString) - } - - // Update input grammar - jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("") - } else if input.JSONFunctionGrammarObject != nil { - config.Grammar = input.JSONFunctionGrammarObject.Grammar("") - } - - // functions are not supported in stream mode (yet?) - toStream := input.Stream && !processFunctions - - log.Debug().Msgf("Parameters: %+v", config) - - var predInput string - - mess := []string{} - for _, i := range input.Messages { - var content string - role := i.Role - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if i.FunctionCall != nil && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := config.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := config.Roles[role] - contentExists := i.Content != nil && *i.Content != "" - if r != "" { - if contentExists { - content = fmt.Sprint(r, " ", *i.Content) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - } else { - if contentExists { - content = fmt.Sprint(*i.Content) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - } - - mess = append(mess, content) - } - - predInput = strings.Join(mess, "\n") - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - if toStream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } - - templateFile := config.Model - - if config.TemplateConfig.Chat != "" && !processFunctions { - templateFile = config.TemplateConfig.Chat - } - - if config.TemplateConfig.Functions != "" && processFunctions { - templateFile = config.TemplateConfig.Functions - } - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - Functions []grammar.Function - }{ - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - - log.Debug().Msgf("Prompt (after templating): %s", predInput) - if processFunctions { - log.Debug().Msgf("Grammar: %+v", config.Grammar) - } - - if toStream { - responses := make(chan OpenAIResponse) - - go process(predInput, input, config, o.loader, responses) - - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - - log.Debug().Msgf("Sending chunk: %s", buf.String()) - fmt.Fprintf(w, "data: %v\n", buf.String()) - w.Flush() - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - FinishReason: "stop", - Index: 0, - Delta: &Message{}, - }}, - Object: "chat.completion.chunk", - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) { - if processFunctions { - // As we have to change the result before processing, we can't stream the answer (yet?) - ss := map[string]interface{}{} - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name := ss["function"] - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - d, _ := json.Marshal(args) - - ss["arguments"] = string(d) - ss["name"] = func_name - - // if do nothing, reply with a message - if func_name == noActionName { - log.Debug().Msgf("nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(d), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = Finetune(*config, predInput, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) - return - } - } - } - - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU) another computation - config.Grammar = "" - predFunc, err := ModelInference(predInput, o.loader, *config, o, nil) - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - prediction, err := predFunc() - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - prediction = Finetune(*config, predInput, prediction) - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) - } else { - // otherwise reply with the function call - *c = append(*c, Choice{ - FinishReason: "function_call", - Message: &Message{Role: "assistant", FunctionCall: ss}, - }) - } - - return - } - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) - }, nil) - if err != nil { - return err - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "chat.completion", - } - respData, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", respData) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - templateFile := config.Model - - if config.TemplateConfig.Edit != "" { - templateFile = config.TemplateConfig.Edit - } - - var result []Choice - for _, i := range config.InputStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - Instruction string - }{Input: i}) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - - r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) - }, nil) - if err != nil { - return err - } - - result = append(result, r...) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "edit", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/images/create - -/* -* - - curl http://localhost:8080/v1/images/generations \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "A cute baby sea otter", - "n": 1, - "size": "512x512" - }' - -* -*/ -func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - m, input, err := readInput(c, o.loader, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - if m == "" { - m = model.StableDiffusionBackend - } - log.Debug().Msgf("Loading model: %+v", m) - - config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - // XXX: Only stablediffusion is supported for now - if config.Backend == "" { - config.Backend = model.StableDiffusionBackend - } - - sizeParts := strings.Split(input.Size, "x") - if len(sizeParts) != 2 { - return fmt.Errorf("Invalid value for 'size'") - } - width, err := strconv.Atoi(sizeParts[0]) - if err != nil { - return fmt.Errorf("Invalid value for 'size'") - } - height, err := strconv.Atoi(sizeParts[1]) - if err != nil { - return fmt.Errorf("Invalid value for 'size'") - } - - b64JSON := false - if input.ResponseFormat == "b64_json" { - b64JSON = true - } - - var result []Item - for _, i := range config.PromptStrings { - n := input.N - if input.N == 0 { - n = 1 - } - for j := 0; j < n; j++ { - prompts := strings.Split(i, "|") - positive_prompt := prompts[0] - negative_prompt := "" - if len(prompts) > 1 { - negative_prompt = prompts[1] - } - - mode := 0 - step := 15 - - if input.Mode != 0 { - mode = input.Mode - } - - if input.Step != 0 { - step = input.Step - } - - tempDir := "" - if !b64JSON { - tempDir = o.imageDir - } - // Create a temporary file - outputFile, err := ioutil.TempFile(tempDir, "b64") - if err != nil { - return err - } - outputFile.Close() - output := outputFile.Name() + ".png" - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - return err - } - - baseURL := c.BaseURL() - - fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o) - if err != nil { - return err - } - if err := fn(); err != nil { - return err - } - - item := &Item{} - - if b64JSON { - defer os.RemoveAll(output) - data, err := os.ReadFile(output) - if err != nil { - return err - } - item.B64JSON = base64.StdEncoding.EncodeToString(data) - } else { - base := filepath.Base(output) - item.URL = baseURL + "/generated-images/" + base - } - - result = append(result, *item) - } - } - - resp := &OpenAIResponse{ - Data: result, - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/audio/create -func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - m, input, err := readInput(c, o.loader, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - // retrieve the file data from the request - file, err := c.FormFile("file") - if err != nil { - return err - } - f, err := file.Open() - if err != nil { - return err - } - defer f.Close() - - dir, err := os.MkdirTemp("", "whisper") - - if err != nil { - return err - } - defer os.RemoveAll(dir) - - dst := filepath.Join(dir, path.Base(file.Filename)) - dstFile, err := os.Create(dst) - if err != nil { - return err - } - - if _, err := io.Copy(dstFile, f); err != nil { - log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) - return err - } - - log.Debug().Msgf("Audio file copied to: %+v", dst) - - whisperModel, err := o.loader.BackendLoader( - model.WithBackendString(model.WhisperBackend), - model.WithModelFile(config.Model), - model.WithThreads(uint32(config.Threads)), - model.WithAssetDir(o.assetsDestination)) - if err != nil { - return err - } - - if whisperModel == nil { - return fmt.Errorf("could not load whisper model") - } - - w, ok := whisperModel.(whisper.Model) - if !ok { - return fmt.Errorf("loader returned non-whisper object") - } - - tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) - if err != nil { - return err - } - - log.Debug().Msgf("Trascribed: %+v", tr) - // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(tr) - } -} - -func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - models, err := loader.ListModels() - if err != nil { - return err - } - var mm map[string]interface{} = map[string]interface{}{} - - dataModels := []OpenAIModel{} - for _, m := range models { - mm[m] = nil - dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) - } - - for _, k := range cm.ListConfigs() { - if _, exists := mm[k]; !exists { - dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) - } - } - - return c.JSON(struct { - Object string `json:"object"` - Data []OpenAIModel `json:"data"` - }{ - Object: "list", - Data: dataModels, - }) - } -} diff --git a/api/openai/api.go b/api/openai/api.go new file mode 100644 index 0000000..6d7ce5e --- /dev/null +++ b/api/openai/api.go @@ -0,0 +1,105 @@ +package openai + +import ( + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/pkg/grammar" +) + +// APIError provides error information returned by the OpenAI API. +type APIError struct { + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` +} + +type ErrorResponse struct { + Error *APIError `json:"error,omitempty"` +} + +type OpenAIUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Item struct { + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + Object string `json:"object,omitempty"` + + // Images + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` +} + +type OpenAIResponse struct { + Created int `json:"created,omitempty"` + Object string `json:"object,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Choices []Choice `json:"choices,omitempty"` + Data []Item `json:"data,omitempty"` + + Usage OpenAIUsage `json:"usage"` +} + +type Choice struct { + Index int `json:"index,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Message *Message `json:"message,omitempty"` + Delta *Message `json:"delta,omitempty"` + Text string `json:"text,omitempty"` +} + +type Message struct { + // The message role + Role string `json:"role,omitempty" yaml:"role"` + // The message content + Content *string `json:"content" yaml:"content"` + // A result of a function call + FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` +} + +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` +} + +type OpenAIRequest struct { + config.PredictionOptions + + // whisper + File string `json:"file" validate:"required"` + //whisper/image + ResponseFormat string `json:"response_format"` + // image + Size string `json:"size"` + // Prompt is read only by completion/image API calls + Prompt interface{} `json:"prompt" yaml:"prompt"` + + // Edit endpoint + Instruction string `json:"instruction" yaml:"instruction"` + Input interface{} `json:"input" yaml:"input"` + + Stop interface{} `json:"stop" yaml:"stop"` + + // Messages is read only by chat/completion API calls + Messages []Message `json:"messages" yaml:"messages"` + + // A list of available functions to call + Functions []grammar.Function `json:"functions" yaml:"functions"` + FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object + + Stream bool `json:"stream"` + + // Image (not supported by OpenAI) + Mode int `json:"mode"` + Step int `json:"step"` + + // A grammar to constrain the LLM output + Grammar string `json:"grammar" yaml:"grammar"` + + JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` +} diff --git a/api/openai/chat.go b/api/openai/chat.go new file mode 100644 index 0000000..30f6e01 --- /dev/null +++ b/api/openai/chat.go @@ -0,0 +1,320 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "strings" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grammar" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { + initialMessage := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { + resp := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, + Object: "chat.completion.chunk", + } + + responses <- resp + return true + }) + close(responses) + } + return func(c *fiber.Ctx) error { + processFunctions := false + funcs := grammar.Functions{} + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + log.Debug().Msgf("Configuration read: %+v", config) + + // Allow the user to set custom actions via config file + // to be "embedded" in each model + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + + if config.FunctionsConfig.NoActionFunctionName != "" { + noActionName = config.FunctionsConfig.NoActionFunctionName + } + if config.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = config.FunctionsConfig.NoActionDescriptionName + } + + // process functions if we have any defined or if we have a function call string + if len(input.Functions) > 0 && config.ShouldUseFunctions() { + log.Debug().Msgf("Response needs to process functions") + + processFunctions = true + + noActionGrammar := grammar.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } + + // Append the no action function + funcs = append(funcs, input.Functions...) + if !config.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Force picking one of the functions by the request + if config.FunctionToCall() != "" { + funcs = funcs.Select(config.FunctionToCall()) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + config.Grammar = jsStruct.Grammar("") + } else if input.JSONFunctionGrammarObject != nil { + config.Grammar = input.JSONFunctionGrammarObject.Grammar("") + } + + // functions are not supported in stream mode (yet?) + toStream := input.Stream && !processFunctions + + log.Debug().Msgf("Parameters: %+v", config) + + var predInput string + + mess := []string{} + for _, i := range input.Messages { + var content string + role := i.Role + // if function call, we might want to customize the role so we can display better that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request + if i.FunctionCall != nil && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := config.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := config.Roles[role] + contentExists := i.Content != nil && *i.Content != "" + if r != "" { + if contentExists { + content = fmt.Sprint(r, " ", *i.Content) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + } else { + if contentExists { + content = fmt.Sprint(*i.Content) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + } + + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + log.Debug().Msgf("Prompt (before templating): %s", predInput) + + if toStream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + // c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + templateFile := config.Model + + if config.TemplateConfig.Chat != "" && !processFunctions { + templateFile = config.TemplateConfig.Chat + } + + if config.TemplateConfig.Functions != "" && processFunctions { + templateFile = config.TemplateConfig.Functions + } + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + Functions []grammar.Function + }{ + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + + log.Debug().Msgf("Prompt (after templating): %s", predInput) + if processFunctions { + log.Debug().Msgf("Grammar: %+v", config.Grammar) + } + + if toStream { + responses := make(chan OpenAIResponse) + + go process(predInput, input, config, o.Loader, responses) + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + for ev := range responses { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) + w.Flush() + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{ + { + FinishReason: "stop", + Index: 0, + Delta: &Message{}, + }}, + Object: "chat.completion.chunk", + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + if processFunctions { + // As we have to change the result before processing, we can't stream the answer (yet?) + ss := map[string]interface{}{} + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + ss["arguments"] = string(d) + ss["name"] = func_name + + // if do nothing, reply with a message + if func_name == noActionName { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(d), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = backend.Finetune(*config, predInput, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) + return + } + } + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU) another computation + config.Grammar = "" + predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil) + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction, err := predFunc() + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction = backend.Finetune(*config, predInput, prediction) + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) + } else { + // otherwise reply with the function call + *c = append(*c, Choice{ + FinishReason: "function_call", + Message: &Message{Role: "assistant", FunctionCall: ss}, + }) + } + + return + } + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) + }, nil) + if err != nil { + return err + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/completion.go b/api/openai/completion.go new file mode 100644 index 0000000..d17fd60 --- /dev/null +++ b/api/openai/completion.go @@ -0,0 +1,159 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +// https://platform.openai.com/docs/api-reference/completions +func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { + ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { + resp := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{ + { + Index: 0, + Text: s, + }, + }, + Object: "text_completion", + } + log.Debug().Msgf("Sending goroutine: %s", s) + + responses <- resp + return true + }) + close(responses) + } + + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("`input`: %+v", input) + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + if input.Stream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + //c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + templateFile := config.Model + + if config.TemplateConfig.Completion != "" { + templateFile = config.TemplateConfig.Completion + } + + if input.Stream { + if len(config.PromptStrings) > 1 { + return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") + } + + predInput := config.PromptStrings[0] + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + }{ + Input: predInput, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } + + responses := make(chan OpenAIResponse) + + go process(predInput, input, config, o.Loader, responses) + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + for ev := range responses { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) + w.Flush() + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + Object: "text_completion", + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + var result []Choice + for _, i := range config.PromptStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + }{ + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + + r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + *c = append(*c, Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + result = append(result, r...) + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "text_completion", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/edit.go b/api/openai/edit.go new file mode 100644 index 0000000..d988d6d --- /dev/null +++ b/api/openai/edit.go @@ -0,0 +1,67 @@ +package openai + +import ( + "encoding/json" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + templateFile := config.Model + + if config.TemplateConfig.Edit != "" { + templateFile = config.TemplateConfig.Edit + } + + var result []Choice + for _, i := range config.InputStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + Instruction string + }{Input: i}) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + + r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + *c = append(*c, Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + result = append(result, r...) + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "edit", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/embeddings.go b/api/openai/embeddings.go new file mode 100644 index 0000000..248ae5c --- /dev/null +++ b/api/openai/embeddings.go @@ -0,0 +1,70 @@ +package openai + +import ( + "encoding/json" + "fmt" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/embeddings +func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + items := []Item{} + + for i, s := range config.InputToken { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + for i, s := range config.InputStrings { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Data: items, + Object: "list", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/image.go b/api/openai/image.go new file mode 100644 index 0000000..bca54c1 --- /dev/null +++ b/api/openai/image.go @@ -0,0 +1,158 @@ +package openai + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/images/create + +/* +* + + curl http://localhost:8080/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A cute baby sea otter", + "n": 1, + "size": "512x512" + }' + +* +*/ +func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + m, input, err := readInput(c, o.Loader, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + if m == "" { + m = model.StableDiffusionBackend + } + log.Debug().Msgf("Loading model: %+v", m) + + config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + // XXX: Only stablediffusion is supported for now + if config.Backend == "" { + config.Backend = model.StableDiffusionBackend + } + + sizeParts := strings.Split(input.Size, "x") + if len(sizeParts) != 2 { + return fmt.Errorf("Invalid value for 'size'") + } + width, err := strconv.Atoi(sizeParts[0]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + + b64JSON := false + if input.ResponseFormat == "b64_json" { + b64JSON = true + } + + var result []Item + for _, i := range config.PromptStrings { + n := input.N + if input.N == 0 { + n = 1 + } + for j := 0; j < n; j++ { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := 15 + + if input.Mode != 0 { + mode = input.Mode + } + + if input.Step != 0 { + step = input.Step + } + + tempDir := "" + if !b64JSON { + tempDir = o.ImageDir + } + // Create a temporary file + outputFile, err := ioutil.TempFile(tempDir, "b64") + if err != nil { + return err + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return err + } + + baseURL := c.BaseURL() + + fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o) + if err != nil { + return err + } + if err := fn(); err != nil { + return err + } + + item := &Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = baseURL + "/generated-images/" + base + } + + result = append(result, *item) + } + } + + resp := &OpenAIResponse{ + Data: result, + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/inference.go b/api/openai/inference.go new file mode 100644 index 0000000..a9991fa --- /dev/null +++ b/api/openai/inference.go @@ -0,0 +1,36 @@ +package openai + +import ( + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { + result := []Choice{} + + if n == 0 { + n = 1 + } + + // get the model function to call for the result + predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback) + if err != nil { + return result, err + } + + for i := 0; i < n; i++ { + prediction, err := predFunc() + if err != nil { + return result, err + } + + prediction = backend.Finetune(*config, predInput, prediction) + cb(prediction, &result) + + //result = append(result, Choice{Text: prediction}) + + } + return result, err +} diff --git a/api/openai/list.go b/api/openai/list.go new file mode 100644 index 0000000..0cd7f3a --- /dev/null +++ b/api/openai/list.go @@ -0,0 +1,37 @@ +package openai + +import ( + config "github.com/go-skynet/LocalAI/api/config" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" +) + +func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + models, err := loader.ListModels() + if err != nil { + return err + } + var mm map[string]interface{} = map[string]interface{}{} + + dataModels := []OpenAIModel{} + for _, m := range models { + mm[m] = nil + dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) + } + + for _, k := range cm.ListConfigs() { + if _, exists := mm[k]; !exists { + dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) + } + } + + return c.JSON(struct { + Object string `json:"object"` + Data []OpenAIModel `json:"data"` + }{ + Object: "list", + Data: dataModels, + }) + } +} diff --git a/api/openai/request.go b/api/openai/request.go new file mode 100644 index 0000000..84dbaa8 --- /dev/null +++ b/api/openai/request.go @@ -0,0 +1,234 @@ +package openai + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + config "github.com/go-skynet/LocalAI/api/config" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { + input := new(OpenAIRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return "", nil, err + } + + modelFile := input.Model + + if c.Params("model") != "" { + modelFile = c.Params("model") + } + + received, _ := json.Marshal(input) + + log.Debug().Msgf("Request received: %s", string(received)) + + // Set model from bearer token, if available + bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) + + // If no model was specified, take the first available + if modelFile == "" && !bearerExists && randomModel { + models, _ := loader.ListModels() + if len(models) > 0 { + modelFile = models[0] + log.Debug().Msgf("No model specified, using: %s", modelFile) + } else { + log.Debug().Msgf("No model specified, returning error") + return "", nil, fmt.Errorf("no model specified") + } + } + + // If a model is found in bearer token takes precedence + if bearerExists { + log.Debug().Msgf("Using model from bearer token: %s", bearer) + modelFile = bearer + } + return modelFile, input, nil +} + +func updateConfig(config *config.Config, input *OpenAIRequest) { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != 0 { + config.TopK = input.TopK + } + if input.TopP != 0 { + config.TopP = input.TopP + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != 0 { + config.Temperature = input.Temperature + } + + if input.Maxtokens != 0 { + config.Maxtokens = input.Maxtokens + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.F16 { + config.F16 = input.F16 + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != 0 { + config.Seed = input.Seed + } + + if input.Mirostat != 0 { + config.Mirostat = input.Mirostat + } + + if input.MirostatETA != 0 { + config.MirostatETA = input.MirostatETA + } + + if input.MirostatTAU != 0 { + config.MirostatTAU = input.MirostatTAU + } + + if input.TypicalP != 0 { + config.TypicalP = input.TypicalP + } + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []interface{}: + tokens := []int{} + for _, ii := range i { + tokens = append(tokens, int(ii.(float64))) + } + config.InputToken = append(config.InputToken, tokens) + } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if !e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } +} + +func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) { + // Load a config file if present after the model name + modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") + + var cfg *config.Config + + defaults := func() { + cfg = config.DefaultConfig(modelFile) + cfg.ContextSize = ctx + cfg.Threads = threads + cfg.F16 = f16 + cfg.Debug = debug + } + + cfgExisting, exists := cm.GetConfig(modelFile) + if !exists { + if _, err := os.Stat(modelConfig); err == nil { + if err := cm.LoadConfig(modelConfig); err != nil { + return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + cfgExisting, exists = cm.GetConfig(modelFile) + if exists { + cfg = &cfgExisting + } else { + defaults() + } + } else { + defaults() + } + } else { + cfg = &cfgExisting + } + + // Set the parameters for the language model prediction + updateConfig(cfg, input) + + // Don't allow 0 as setting + if cfg.Threads == 0 { + if threads != 0 { + cfg.Threads = threads + } else { + cfg.Threads = 4 + } + } + + // Enforce debug flag if passed from CLI + if debug { + cfg.Debug = true + } + + return cfg, input, nil +} diff --git a/api/openai/transcription.go b/api/openai/transcription.go new file mode 100644 index 0000000..279f320 --- /dev/null +++ b/api/openai/transcription.go @@ -0,0 +1,91 @@ +package openai + +import ( + "fmt" + "io" + "net/http" + "os" + "path" + "path/filepath" + + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" + + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/audio/create +func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + m, input, err := readInput(c, o.Loader, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + // retrieve the file data from the request + file, err := c.FormFile("file") + if err != nil { + return err + } + f, err := file.Open() + if err != nil { + return err + } + defer f.Close() + + dir, err := os.MkdirTemp("", "whisper") + + if err != nil { + return err + } + defer os.RemoveAll(dir) + + dst := filepath.Join(dir, path.Base(file.Filename)) + dstFile, err := os.Create(dst) + if err != nil { + return err + } + + if _, err := io.Copy(dstFile, f); err != nil { + log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) + return err + } + + log.Debug().Msgf("Audio file copied to: %+v", dst) + + whisperModel, err := o.Loader.BackendLoader( + model.WithBackendString(model.WhisperBackend), + model.WithModelFile(config.Model), + model.WithThreads(uint32(config.Threads)), + model.WithAssetDir(o.AssetsDestination)) + if err != nil { + return err + } + + if whisperModel == nil { + return fmt.Errorf("could not load whisper model") + } + + w, ok := whisperModel.(whisper.Model) + if !ok { + return fmt.Errorf("loader returned non-whisper object") + } + + tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) + if err != nil { + return err + } + + log.Debug().Msgf("Trascribed: %+v", tr) + // TODO: handle different outputs here + return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) + } +} diff --git a/api/options.go b/api/options/options.go similarity index 60% rename from api/options.go rename to api/options/options.go index 923288a..06029b0 100644 --- a/api/options.go +++ b/api/options/options.go @@ -1,4 +1,4 @@ -package api +package options import ( "context" @@ -11,35 +11,35 @@ import ( ) type Option struct { - context context.Context - configFile string - loader *model.ModelLoader - uploadLimitMB, threads, ctxSize int - f16 bool - debug, disableMessage bool - imageDir string - audioDir string - cors bool - preloadJSONModels string - preloadModelsFromPath string - corsAllowOrigins string + Context context.Context + ConfigFile string + Loader *model.ModelLoader + UploadLimitMB, Threads, ContextSize int + F16 bool + Debug, DisableMessage bool + ImageDir string + AudioDir string + CORS bool + PreloadJSONModels string + PreloadModelsFromPath string + CORSAllowOrigins string - galleries []gallery.Gallery + Galleries []gallery.Gallery - backendAssets embed.FS - assetsDestination string + BackendAssets embed.FS + AssetsDestination string } type AppOption func(*Option) -func newOptions(o ...AppOption) *Option { +func NewOptions(o ...AppOption) *Option { opt := &Option{ - context: context.Background(), - uploadLimitMB: 15, - threads: 1, - ctxSize: 512, - debug: true, - disableMessage: true, + Context: context.Background(), + UploadLimitMB: 15, + Threads: 1, + ContextSize: 512, + Debug: true, + DisableMessage: true, } for _, oo := range o { oo(opt) @@ -49,25 +49,25 @@ func newOptions(o ...AppOption) *Option { func WithCors(b bool) AppOption { return func(o *Option) { - o.cors = b + o.CORS = b } } func WithCorsAllowOrigins(b string) AppOption { return func(o *Option) { - o.corsAllowOrigins = b + o.CORSAllowOrigins = b } } func WithBackendAssetsOutput(out string) AppOption { return func(o *Option) { - o.assetsDestination = out + o.AssetsDestination = out } } func WithBackendAssets(f embed.FS) AppOption { return func(o *Option) { - o.backendAssets = f + o.BackendAssets = f } } @@ -81,89 +81,89 @@ func WithStringGalleries(galls string) AppOption { if err := json.Unmarshal([]byte(galls), &galleries); err != nil { log.Error().Msgf("failed loading galleries: %s", err.Error()) } - o.galleries = append(o.galleries, galleries...) + o.Galleries = append(o.Galleries, galleries...) } } func WithGalleries(galleries []gallery.Gallery) AppOption { return func(o *Option) { - o.galleries = append(o.galleries, galleries...) + o.Galleries = append(o.Galleries, galleries...) } } func WithContext(ctx context.Context) AppOption { return func(o *Option) { - o.context = ctx + o.Context = ctx } } func WithYAMLConfigPreload(configFile string) AppOption { return func(o *Option) { - o.preloadModelsFromPath = configFile + o.PreloadModelsFromPath = configFile } } func WithJSONStringPreload(configFile string) AppOption { return func(o *Option) { - o.preloadJSONModels = configFile + o.PreloadJSONModels = configFile } } func WithConfigFile(configFile string) AppOption { return func(o *Option) { - o.configFile = configFile + o.ConfigFile = configFile } } func WithModelLoader(loader *model.ModelLoader) AppOption { return func(o *Option) { - o.loader = loader + o.Loader = loader } } func WithUploadLimitMB(limit int) AppOption { return func(o *Option) { - o.uploadLimitMB = limit + o.UploadLimitMB = limit } } func WithThreads(threads int) AppOption { return func(o *Option) { - o.threads = threads + o.Threads = threads } } func WithContextSize(ctxSize int) AppOption { return func(o *Option) { - o.ctxSize = ctxSize + o.ContextSize = ctxSize } } func WithF16(f16 bool) AppOption { return func(o *Option) { - o.f16 = f16 + o.F16 = f16 } } func WithDebug(debug bool) AppOption { return func(o *Option) { - o.debug = debug + o.Debug = debug } } func WithDisableMessage(disableMessage bool) AppOption { return func(o *Option) { - o.disableMessage = disableMessage + o.DisableMessage = disableMessage } } func WithAudioDir(audioDir string) AppOption { return func(o *Option) { - o.audioDir = audioDir + o.AudioDir = audioDir } } func WithImageDir(imageDir string) AppOption { return func(o *Option) { - o.imageDir = imageDir + o.ImageDir = imageDir } } diff --git a/api/prediction.go b/api/prediction.go deleted file mode 100644 index 4a9c1c8..0000000 --- a/api/prediction.go +++ /dev/null @@ -1,415 +0,0 @@ -package api - -import ( - "context" - "fmt" - "os" - "path/filepath" - "regexp" - "strings" - "sync" - - "github.com/donomii/go-rwkv.cpp" - "github.com/go-skynet/LocalAI/pkg/grpc" - pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/go-skynet/LocalAI/pkg/langchain" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/stablediffusion" - "github.com/go-skynet/bloomz.cpp" - bert "github.com/go-skynet/go-bert.cpp" -) - -// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 -var mutexMap sync.Mutex -var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) - -func gRPCModelOpts(c Config) *pb.ModelOptions { - b := 512 - if c.Batch != 0 { - b = c.Batch - } - return &pb.ModelOptions{ - ContextSize: int32(c.ContextSize), - Seed: int32(c.Seed), - NBatch: int32(b), - F16Memory: c.F16, - MLock: c.MMlock, - NUMA: c.NUMA, - Embeddings: c.Embeddings, - LowVRAM: c.LowVRAM, - NGPULayers: int32(c.NGPULayers), - MMap: c.MMap, - MainGPU: c.MainGPU, - Threads: int32(c.Threads), - TensorSplit: c.TensorSplit, - } -} - -func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { - promptCachePath := "" - if c.PromptCachePath != "" { - p := filepath.Join(modelPath, c.PromptCachePath) - os.MkdirAll(filepath.Dir(p), 0755) - promptCachePath = p - } - return &pb.PredictOptions{ - Temperature: float32(c.Temperature), - TopP: float32(c.TopP), - TopK: int32(c.TopK), - Tokens: int32(c.Maxtokens), - Threads: int32(c.Threads), - PromptCacheAll: c.PromptCacheAll, - PromptCacheRO: c.PromptCacheRO, - PromptCachePath: promptCachePath, - F16KV: c.F16, - DebugMode: c.Debug, - Grammar: c.Grammar, - - Mirostat: int32(c.Mirostat), - MirostatETA: float32(c.MirostatETA), - MirostatTAU: float32(c.MirostatTAU), - Debug: c.Debug, - StopPrompts: c.StopWords, - Repeat: int32(c.RepeatPenalty), - NKeep: int32(c.Keep), - Batch: int32(c.Batch), - IgnoreEOS: c.IgnoreEOS, - Seed: int32(c.Seed), - FrequencyPenalty: float32(c.FrequencyPenalty), - MLock: c.MMlock, - MMap: c.MMap, - MainGPU: c.MainGPU, - TensorSplit: c.TensorSplit, - TailFreeSamplingZ: float32(c.TFZ), - TypicalP: float32(c.TypicalP), - } -} - -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { - if c.Backend != model.StableDiffusionBackend { - return nil, fmt.Errorf("endpoint only working with stablediffusion models") - } - - inferenceModel, err := loader.BackendLoader( - model.WithBackendString(c.Backend), - model.WithAssetDir(o.assetsDestination), - model.WithThreads(uint32(c.Threads)), - model.WithModelFile(c.ImageGenerationAssets), - ) - if err != nil { - return nil, err - } - - var fn func() error - switch model := inferenceModel.(type) { - case *stablediffusion.StableDiffusion: - fn = func() error { - return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) - } - - default: - fn = func() error { - return fmt.Errorf("creation of images not supported by the backend") - } - } - - return func() error { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[c.Backend] - if !ok { - m := &sync.Mutex{} - mutexes[c.Backend] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - return fn() - }, nil -} - -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) { - if !c.Embeddings { - return nil, fmt.Errorf("endpoint disabled for this model by API configuration") - } - - modelFile := c.Model - - grpcOpts := gRPCModelOpts(c) - - var inferenceModel interface{} - var err error - - opts := []model.Option{ - model.WithLoadGRPCOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), - model.WithAssetDir(o.assetsDestination), - model.WithModelFile(modelFile), - } - - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - opts = append(opts, model.WithBackendString(c.Backend)) - inferenceModel, err = loader.BackendLoader(opts...) - } - if err != nil { - return nil, err - } - - var fn func() ([]float32, error) - switch model := inferenceModel.(type) { - case *grpc.Client: - fn = func() ([]float32, error) { - predictOptions := gRPCPredictOpts(c, loader.ModelPath) - if len(tokens) > 0 { - embeds := []int32{} - - for _, t := range tokens { - embeds = append(embeds, int32(t)) - } - predictOptions.EmbeddingTokens = embeds - - res, err := model.Embeddings(context.TODO(), predictOptions) - if err != nil { - return nil, err - } - - return res.Embeddings, nil - } - predictOptions.Embeddings = s - - res, err := model.Embeddings(context.TODO(), predictOptions) - if err != nil { - return nil, err - } - - return res.Embeddings, nil - } - - // bert embeddings - case *bert.Bert: - fn = func() ([]float32, error) { - if len(tokens) > 0 { - return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) - } - return model.Embeddings(s, bert.SetThreads(c.Threads)) - } - default: - fn = func() ([]float32, error) { - return nil, fmt.Errorf("embeddings not supported by the backend") - } - } - - return func() ([]float32, error) { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - embeds, err := fn() - if err != nil { - return embeds, err - } - // Remove trailing 0s - for i := len(embeds) - 1; i >= 0; i-- { - if embeds[i] == 0.0 { - embeds = embeds[:i] - } else { - break - } - } - return embeds, nil - }, nil -} - -func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) { - supportStreams := false - modelFile := c.Model - - grpcOpts := gRPCModelOpts(c) - - var inferenceModel interface{} - var err error - - opts := []model.Option{ - model.WithLoadGRPCOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), // GPT4all uses this - model.WithAssetDir(o.assetsDestination), - model.WithModelFile(modelFile), - } - - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(opts...) - } else { - opts = append(opts, model.WithBackendString(c.Backend)) - inferenceModel, err = loader.BackendLoader(opts...) - } - if err != nil { - return nil, err - } - - var fn func() (string, error) - - switch model := inferenceModel.(type) { - case *rwkv.RwkvState: - supportStreams = true - - fn = func() (string, error) { - stopWord := "\n" - if len(c.StopWords) > 0 { - stopWord = c.StopWords[0] - } - - if err := model.ProcessInput(s); err != nil { - return "", err - } - - response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) - - return response, nil - } - case *bloomz.Bloomz: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []bloomz.PredictOption{ - bloomz.SetTemperature(c.Temperature), - bloomz.SetTopP(c.TopP), - bloomz.SetTopK(c.TopK), - bloomz.SetTokens(c.Maxtokens), - bloomz.SetThreads(c.Threads), - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - - case *grpc.Client: - // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported - supportStreams = true - fn = func() (string, error) { - - opts := gRPCPredictOpts(c, loader.ModelPath) - opts.Prompt = s - if tokenCallback != nil { - ss := "" - err := model.PredictStream(context.TODO(), opts, func(s string) { - tokenCallback(s) - ss += s - }) - return ss, err - } else { - reply, err := model.Predict(context.TODO(), opts) - return reply.Message, err - } - } - case *langchain.HuggingFace: - fn = func() (string, error) { - - // Generate the prediction using the language model - predictOptions := []langchain.PredictOption{ - langchain.SetModel(c.Model), - langchain.SetMaxTokens(c.Maxtokens), - langchain.SetTemperature(c.Temperature), - langchain.SetStopWords(c.StopWords), - } - - pred, er := model.PredictHuggingFace(s, predictOptions...) - if er != nil { - return "", er - } - return pred.Completion, nil - } - } - - return func() (string, error) { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - res, err := fn() - if tokenCallback != nil && !supportStreams { - tokenCallback(res) - } - return res, err - }, nil -} - -func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { - result := []Choice{} - - n := input.N - - if input.N == 0 { - n = 1 - } - - // get the model function to call for the result - predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback) - if err != nil { - return result, err - } - - for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return result, err - } - - prediction = Finetune(*config, predInput, prediction) - cb(prediction, &result) - - //result = append(result, Choice{Text: prediction}) - - } - return result, err -} - -var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) -var mu sync.Mutex = sync.Mutex{} - -func Finetune(config Config, input, prediction string) string { - if config.Echo { - prediction = input + prediction - } - - for _, c := range config.Cutstrings { - mu.Lock() - reg, ok := cutstrings[c] - if !ok { - cutstrings[c] = regexp.MustCompile(c) - reg = cutstrings[c] - } - mu.Unlock() - prediction = reg.ReplaceAllString(prediction, "") - } - - for _, c := range config.TrimSpace { - prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) - } - return prediction - -} diff --git a/main.go b/main.go index fc1dea0..ec38afe 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "path/filepath" api "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/internal" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog" @@ -129,23 +130,23 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { app, err := api.App( - api.WithConfigFile(ctx.String("config-file")), - api.WithJSONStringPreload(ctx.String("preload-models")), - api.WithYAMLConfigPreload(ctx.String("preload-models-config")), - api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), - api.WithContextSize(ctx.Int("context-size")), - api.WithDebug(ctx.Bool("debug")), - api.WithImageDir(ctx.String("image-path")), - api.WithAudioDir(ctx.String("audio-path")), - api.WithF16(ctx.Bool("f16")), - api.WithStringGalleries(ctx.String("galleries")), - api.WithDisableMessage(false), - api.WithCors(ctx.Bool("cors")), - api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), - api.WithThreads(ctx.Int("threads")), - api.WithBackendAssets(backendAssets), - api.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - api.WithUploadLimitMB(ctx.Int("upload-limit"))) + options.WithConfigFile(ctx.String("config-file")), + options.WithJSONStringPreload(ctx.String("preload-models")), + options.WithYAMLConfigPreload(ctx.String("preload-models-config")), + options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), + options.WithContextSize(ctx.Int("context-size")), + options.WithDebug(ctx.Bool("debug")), + options.WithImageDir(ctx.String("image-path")), + options.WithAudioDir(ctx.String("audio-path")), + options.WithF16(ctx.Bool("f16")), + options.WithStringGalleries(ctx.String("galleries")), + options.WithDisableMessage(false), + options.WithCors(ctx.Bool("cors")), + options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), + options.WithThreads(ctx.Int("threads")), + options.WithBackendAssets(backendAssets), + options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), + options.WithUploadLimitMB(ctx.Int("upload-limit"))) if err != nil { return err } diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/grpc/llm/falcon/falcon.go index 5d8cf75..0a7a533 100644 --- a/pkg/grpc/llm/falcon/falcon.go +++ b/pkg/grpc/llm/falcon/falcon.go @@ -126,6 +126,9 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { predictOptions := buildPredictOptions(opts) predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { + if token == "<|endoftext|>" { + return true + } results <- token return true })) From 1d0ed95a54032fb5f071be21d702b5fa8c2b9d6d Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 06/12] feat: move other backends to grpc This finally makes everything more consistent Signed-off-by: Ettore Di Giacinto --- .gitignore | 4 +- Makefile | 149 +- api/api.go | 7 + api/api_test.go | 233 ++- api/backend/embeddings.go | 28 +- api/backend/image.go | 28 +- api/backend/llm.go | 118 +- api/backend/options.go | 26 - api/localai/localai.go | 17 +- api/openai/transcription.go | 18 +- cmd/grpc/bert-embeddings/main.go | 22 + cmd/grpc/bloomz/main.go | 23 + cmd/grpc/falcon-ggml/main.go | 23 + cmd/grpc/langchain-huggingface/main.go | 23 + cmd/grpc/piper/main.go | 23 + cmd/grpc/rwkv/main.go | 23 + cmd/grpc/stablediffusion/main.go | 23 + cmd/grpc/whisper/main.go | 23 + main.go | 9 + pkg/grpc/base/base.go | 42 + pkg/grpc/client.go | 61 +- pkg/grpc/image/stablediffusion.go | 33 + pkg/grpc/interface.go | 6 +- pkg/grpc/llm/bert/bert.go | 33 + pkg/grpc/llm/bloomz/bloomz.go | 59 + pkg/grpc/llm/falcon/falcon.go | 11 +- pkg/grpc/llm/gpt4all/gpt4all.go | 9 +- pkg/grpc/llm/langchain/langchain.go | 58 + pkg/grpc/llm/llama/llama.go | 7 +- pkg/grpc/llm/rwkv/rwkv.go | 71 + pkg/grpc/llm/transformers/dolly.go | 11 +- pkg/grpc/llm/transformers/falcon.go | 43 + pkg/grpc/llm/transformers/gpt2.go | 10 +- pkg/grpc/llm/transformers/gptj.go | 10 +- pkg/grpc/llm/transformers/gptneox.go | 10 +- pkg/grpc/llm/transformers/mpt.go | 10 +- pkg/grpc/llm/transformers/replit.go | 10 +- pkg/grpc/llm/transformers/starcoder.go | 11 +- pkg/grpc/proto/backend.pb.go | 1458 +++++++++++++++++ .../proto/{llmserver.proto => backend.proto} | 49 +- pkg/grpc/proto/backend_grpc.pb.go | 385 +++++ pkg/grpc/proto/llmserver.pb.go | 969 ----------- pkg/grpc/proto/llmserver_grpc.pb.go | 277 ---- pkg/grpc/server.go | 47 +- pkg/grpc/transcribe/whisper.go | 27 + pkg/grpc/tts/piper.go | 44 + pkg/grpc/whisper/api/api.go | 16 + pkg/{ => grpc}/whisper/whisper.go | 23 +- pkg/model/initializers.go | 171 +- pkg/model/loader.go | 34 +- pkg/model/options.go | 16 +- pkg/tts/generate.go | 12 - pkg/tts/generate_unsupported.go | 10 - pkg/tts/piper.go | 20 - 54 files changed, 3171 insertions(+), 1712 deletions(-) create mode 100644 cmd/grpc/bert-embeddings/main.go create mode 100644 cmd/grpc/bloomz/main.go create mode 100644 cmd/grpc/falcon-ggml/main.go create mode 100644 cmd/grpc/langchain-huggingface/main.go create mode 100644 cmd/grpc/piper/main.go create mode 100644 cmd/grpc/rwkv/main.go create mode 100644 cmd/grpc/stablediffusion/main.go create mode 100644 cmd/grpc/whisper/main.go create mode 100644 pkg/grpc/base/base.go create mode 100644 pkg/grpc/image/stablediffusion.go create mode 100644 pkg/grpc/llm/bert/bert.go create mode 100644 pkg/grpc/llm/bloomz/bloomz.go create mode 100644 pkg/grpc/llm/langchain/langchain.go create mode 100644 pkg/grpc/llm/rwkv/rwkv.go create mode 100644 pkg/grpc/llm/transformers/falcon.go create mode 100644 pkg/grpc/proto/backend.pb.go rename pkg/grpc/proto/{llmserver.proto => backend.proto} (67%) create mode 100644 pkg/grpc/proto/backend_grpc.pb.go delete mode 100644 pkg/grpc/proto/llmserver.pb.go delete mode 100644 pkg/grpc/proto/llmserver_grpc.pb.go create mode 100644 pkg/grpc/transcribe/whisper.go create mode 100644 pkg/grpc/tts/piper.go create mode 100644 pkg/grpc/whisper/api/api.go rename pkg/{ => grpc}/whisper/whisper.go (78%) delete mode 100644 pkg/tts/generate.go delete mode 100644 pkg/tts/generate_unsupported.go delete mode 100644 pkg/tts/piper.go diff --git a/.gitignore b/.gitignore index a40bf19..7b35ba9 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ go-llama go-stable-diffusion go-piper go-ggllm -piper +/piper *.a get-sources @@ -13,7 +13,7 @@ go-ggml-transformers go-gpt2 go-rwkv whisper.cpp -bloomz +/bloomz go-bert # LocalAI build binary diff --git a/Makefile b/Makefile index 610cc6f..9596bcb 100644 --- a/Makefile +++ b/Makefile @@ -67,9 +67,6 @@ WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) -C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz -LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz - ifeq ($(BUILD_TYPE),openblas) CGO_LDFLAGS+=-lopenblas endif @@ -95,11 +92,17 @@ endif ifeq ($(findstring stablediffusion,$(GO_TAGS)),stablediffusion) OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a + OPTIONAL_GRPC+=backend-assets/grpc/stablediffusion endif ifeq ($(findstring tts,$(GO_TAGS)),tts) OPTIONAL_TARGETS+=go-piper/libpiper_binding.a OPTIONAL_TARGETS+=backend-assets/espeak-ng-data + OPTIONAL_GRPC+=backend-assets/grpc/piper +# die if ESPEAK_DATA is not set +ifndef ESPEAK_DATA +$(error ESPEAK_DATA is not set. Espeak data is required for tts) +endif endif .PHONY: all test build vendor @@ -128,9 +131,6 @@ go-piper: go-bert: git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp go-bert cd go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1 - @find ./go-bert -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} + - @find ./go-bert -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} + - @find ./go-bert -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} + ## stable diffusion go-stable-diffusion: @@ -144,9 +144,6 @@ go-stable-diffusion/libstablediffusion.a: go-rwkv: git clone --recurse-submodules $(RWKV_REPO) go-rwkv cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1 - @find ./go-rwkv -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} + - @find ./go-rwkv -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} + - @find ./go-rwkv -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} + go-rwkv/librwkv.a: go-rwkv cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. @@ -154,13 +151,7 @@ go-rwkv/librwkv.a: go-rwkv ## bloomz bloomz: git clone --recurse-submodules https://github.com/go-skynet/bloomz.cpp bloomz - @find ./bloomz -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + - @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + - @find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + - @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} + - @find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} + - @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_bloomz_replace/g' {} + - @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_bloomz_replace/g' {} + + cd bloomz && git checkout -b build $(BLOOMZ_VERSION) && git submodule update --init --recursive --depth 1 bloomz/libbloomz.a: bloomz cd bloomz && make libbloomz.a @@ -179,6 +170,7 @@ backend-assets/espeak-ng-data: ifdef ESPEAK_DATA @cp -rf $(ESPEAK_DATA)/. backend-assets/espeak-ng-data else + @echo "ESPEAK_DATA not set, skipping tts. Note that this will break the tts functionality." @touch backend-assets/espeak-ng-data/keep endif @@ -196,9 +188,6 @@ go-ggml-transformers/libtransformers.a: go-ggml-transformers whisper.cpp: git clone https://github.com/ggerganov/whisper.cpp.git cd whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1 - @find ./whisper.cpp -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} + - @find ./whisper.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} + - @find ./whisper.cpp -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} + whisper.cpp/libwhisper.a: whisper.cpp cd whisper.cpp && make libwhisper.a @@ -249,7 +238,7 @@ rebuild: ## Rebuilds the project $(MAKE) -C go-ggllm clean $(MAKE) build -prepare: prepare-sources grpcs go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a $(OPTIONAL_TARGETS) +prepare: prepare-sources grpcs go-bert/libgobert.a go-ggml-transformers/libtransformers.a whisper.cpp/libwhisper.a $(OPTIONAL_TARGETS) touch $@ clean: ## Remove build related file @@ -277,7 +266,7 @@ build: prepare ## Build the project $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) $(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET}) - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./ + CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./ ifeq ($(BUILD_TYPE),metal) cp go-llama/build/bin/ggml-metal.metal . endif @@ -286,12 +275,9 @@ dist: build mkdir -p release cp $(BINARY_NAME) release/$(BINARY_NAME)-$(BUILD_ID)-$(OS)-$(ARCH) -generic-build: ## Build the project using generic - BUILD_TYPE="generic" $(MAKE) build - ## Run run: prepare ## run local-ai - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) run ./ + CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./ test-models/testmodel: mkdir test-models @@ -304,12 +290,42 @@ test-models/testmodel: wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json cp tests/models_fixtures/* test-models -test: prepare test-models/testmodel +prepare-test: grpcs cp -r backend-assets api cp tests/models_fixtures/* test-models - C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg - C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg - C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg + +test: prepare test-models/testmodel grpcs + @echo 'Running tests' + export GO_TAGS="tts stablediffusion" + $(MAKE) prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg + $(MAKE) test-gpt4all + $(MAKE) test-llama + $(MAKE) test-tts + $(MAKE) test-stablediffusion + +test-gpt4all: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg + +test-llama: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg + +test-tts: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r ./api ./pkg + +test-stablediffusion: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r ./api ./pkg + +test-container: + docker build --target requirements -t local-ai-test-container . + docker run --name localai-tests -e GO_TAGS=$(GO_TAGS) -ti -v $(abspath ./):/build local-ai-test-container make test + docker rm localai-tests + docker rmi local-ai-test-container ## Help: help: ## Show this help. @@ -325,51 +341,82 @@ help: ## Show this help. protogen: protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \ - pkg/grpc/proto/llmserver.proto + pkg/grpc/proto/backend.proto ## GRPC backend-assets/grpc: mkdir -p backend-assets/grpc -falcon-grpc: backend-assets/grpc go-ggllm/libggllm.a +backend-assets/grpc/falcon: backend-assets/grpc go-ggllm/libggllm.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/ -llama-grpc: backend-assets/grpc go-llama/libbinding.a +backend-assets/grpc/llama: backend-assets/grpc go-llama/libbinding.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/ -gpt4all-grpc: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a +backend-assets/grpc/gpt4all: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/ -dolly-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a +backend-assets/grpc/dolly: backend-assets/grpc go-ggml-transformers/libtransformers.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/ -gpt2-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a +backend-assets/grpc/gpt2: backend-assets/grpc go-ggml-transformers/libtransformers.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/ -gptj-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a +backend-assets/grpc/gptj: backend-assets/grpc go-ggml-transformers/libtransformers.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/ -gptneox-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a +backend-assets/grpc/gptneox: backend-assets/grpc go-ggml-transformers/libtransformers.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/ -mpt-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a +backend-assets/grpc/mpt: backend-assets/grpc go-ggml-transformers/libtransformers.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/ -replit-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a +backend-assets/grpc/replit: backend-assets/grpc go-ggml-transformers/libtransformers.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/ -starcoder-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a +backend-assets/grpc/falcon-ggml: backend-assets/grpc go-ggml-transformers/libtransformers.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ - $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon-ggml ./cmd/grpc/falcon-ggml/ + +backend-assets/grpc/starcoder: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/ + +backend-assets/grpc/rwkv: backend-assets/grpc go-rwkv/librwkv.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-rwkv LIBRARY_PATH=$(shell pwd)/go-rwkv \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/rwkv ./cmd/grpc/rwkv/ + +backend-assets/grpc/bloomz: backend-assets/grpc bloomz/libbloomz.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/bloomz LIBRARY_PATH=$(shell pwd)/bloomz \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bloomz ./cmd/grpc/bloomz/ + +backend-assets/grpc/bert-embeddings: backend-assets/grpc go-bert/libgobert.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-bert LIBRARY_PATH=$(shell pwd)/go-bert \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bert-embeddings ./cmd/grpc/bert-embeddings/ + +backend-assets/grpc/langchain-huggingface: backend-assets/grpc + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/langchain-huggingface ./cmd/grpc/langchain-huggingface/ + +backend-assets/grpc/stablediffusion: backend-assets/grpc go-stable-diffusion/libstablediffusion.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/ + +backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data go-piper/libpiper_binding.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-piper \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./cmd/grpc/piper/ + +backend-assets/grpc/whisper: backend-assets/grpc whisper.cpp/libwhisper.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/whisper.cpp LIBRARY_PATH=$(shell pwd)/whisper.cpp \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./cmd/grpc/whisper/ -grpcs: falcon-grpc llama-grpc gpt4all-grpc dolly-grpc gpt2-grpc gptj-grpc gptneox-grpc mpt-grpc replit-grpc starcoder-grpc \ No newline at end of file +grpcs: backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC) \ No newline at end of file diff --git a/api/api.go b/api/api.go index 5d4f4c9..8dcefa2 100644 --- a/api/api.go +++ b/api/api.go @@ -173,5 +173,12 @@ func App(opts ...options.AppOption) (*fiber.App, error) { app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm)) app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm)) + // turn off any process that was started by GRPC if the context is canceled + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + options.Loader.StopGRPC() + }() + return app, nil } diff --git a/api/api_test.go b/api/api_test.go index a69e60d..ca840b5 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -5,7 +5,9 @@ import ( "context" "embed" "encoding/json" + "errors" "fmt" + "io" "io/ioutil" "net/http" "os" @@ -24,6 +26,7 @@ import ( openaigo "github.com/otiai10/openaigo" "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" ) type modelApplyRequest struct { @@ -203,7 +206,7 @@ var _ = Describe("API test", func() { fmt.Println(response) resp = response return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) Expect(resp["message"]).ToNot(ContainSubstring("error")) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml")) @@ -245,9 +248,8 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) - fmt.Println(response) return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) @@ -270,9 +272,8 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) - fmt.Println(response) return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) @@ -299,14 +300,58 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) - fmt.Println(response) return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) + By("testing completion") resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).To(ContainSubstring("five")) + + By("testing functions") + resp2, err := client.CreateChatCompletion( + context.TODO(), + openai.ChatCompletionRequest{ + Model: "openllama_3b", + Messages: []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "What is the weather like in San Francisco (celsius)?", + }, + }, + Functions: []openai.FunctionDefinition{ + openai.FunctionDefinition{ + Name: "get_current_weather", + Description: "Get the current weather", + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celcius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp2.Choices)).To(Equal(1)) + Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) + + var res map[string]string + err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) + Expect(err).ToNot(HaveOccurred()) + Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) + Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) + Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) }) It("runs gpt4all", Label("gpt4all"), func() { @@ -326,15 +371,126 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) - fmt.Println(response) return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well")) }) + + }) + }) + + Context("Model gallery", func() { + BeforeEach(func() { + var err error + tmpdir, err = os.MkdirTemp("", "") + Expect(err).ToNot(HaveOccurred()) + + modelLoader = model.NewModelLoader(tmpdir) + c, cancel = context.WithCancel(context.Background()) + + galleries := []gallery.Gallery{ + { + Name: "model-gallery", + URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml", + }, + } + + app, err = App( + options.WithContext(c), + options.WithAudioDir(tmpdir), + options.WithImageDir(tmpdir), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), + options.WithBackendAssets(backendAssets), + options.WithBackendAssetsOutput(tmpdir), + ) + Expect(err).ToNot(HaveOccurred()) + go app.Listen("127.0.0.1:9090") + + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + + client2 = openaigo.NewClient("") + client2.BaseURL = defaultConfig.BaseURL + + // Wait for API to be ready + client = openai.NewClientWithConfig(defaultConfig) + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + + AfterEach(func() { + cancel() + app.Shutdown() + os.RemoveAll(tmpdir) + }) + It("installs and is capable to run tts", Label("tts"), func() { + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + ID: "model-gallery@voice-en-us-kathleen-low", + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + // An HTTP Post to the /tts endpoint should return a wav audio file + resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "en-us-kathleen-low.onnx"}`))) + Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) + dat, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) + + Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat))) + Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav")) + }) + It("installs and is capable to generate images", Label("stablediffusion"), func() { + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + ID: "model-gallery@stablediffusion", + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + resp, err := http.Post( + "http://127.0.0.1:9090/v1/images/generations", + "application/json", + bytes.NewBuffer([]byte(`{ + "prompt": "floating hair, portrait, ((loli)), ((one girl)), cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text", + "mode": 2, "seed":9000, + "size": "256x256", "n":2}`))) + // The response should contain an URL + Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) + dat, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred(), string(dat)) + Expect(string(dat)).To(ContainSubstring("http://127.0.0.1:9090/"), string(dat)) + Expect(string(dat)).To(ContainSubstring(".png"), string(dat)) + }) }) @@ -401,7 +557,7 @@ var _ = Describe("API test", func() { It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 11 errors occurred:")) + Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:")) }) It("transcribes audio", func() { if runtime.GOOS != "linux" { @@ -446,14 +602,67 @@ var _ = Describe("API test", func() { }) Context("backends", func() { - It("runs rwkv", func() { + It("runs rwkv completion", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices) > 0).To(BeTrue()) - Expect(resp.Choices[0].Text).To(Equal(" five.")) + Expect(resp.Choices[0].Text).To(ContainSubstring("five")) + + stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{ + Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true, + }) + Expect(err).ToNot(HaveOccurred()) + defer stream.Close() + + tokens := 0 + text := "" + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + + Expect(err).ToNot(HaveOccurred()) + text += response.Choices[0].Text + tokens++ + } + Expect(text).ToNot(BeEmpty()) + Expect(text).To(ContainSubstring("five")) + Expect(tokens).ToNot(Or(Equal(1), Equal(0))) + }) + It("runs rwkv chat completion", func() { + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + resp, err := client.CreateChatCompletion(context.TODO(), + openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices) > 0).To(BeTrue()) + Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five"))) + + stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}}) + Expect(err).ToNot(HaveOccurred()) + defer stream.Close() + + tokens := 0 + text := "" + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + + Expect(err).ToNot(HaveOccurred()) + text += response.Choices[0].Delta.Content + tokens++ + } + Expect(text).ToNot(BeEmpty()) + Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five"))) + + Expect(tokens).ToNot(Or(Equal(1), Equal(0))) }) }) }) diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go index cb77b6f..0310347 100644 --- a/api/backend/embeddings.go +++ b/api/backend/embeddings.go @@ -1,7 +1,6 @@ package backend import ( - "context" "fmt" "sync" @@ -9,7 +8,6 @@ import ( "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/grpc" model "github.com/go-skynet/LocalAI/pkg/model" - bert "github.com/go-skynet/go-bert.cpp" ) func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { @@ -25,10 +23,11 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. var err error opts := []model.Option{ - model.WithLoadGRPCOpts(grpcOpts), + model.WithLoadGRPCLLMModelOpts(grpcOpts), model.WithThreads(uint32(c.Threads)), model.WithAssetDir(o.AssetsDestination), model.WithModelFile(modelFile), + model.WithContext(o.Context), } if c.Backend == "" { @@ -54,7 +53,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. } predictOptions.EmbeddingTokens = embeds - res, err := model.Embeddings(context.TODO(), predictOptions) + res, err := model.Embeddings(o.Context, predictOptions) if err != nil { return nil, err } @@ -63,22 +62,13 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. } predictOptions.Embeddings = s - res, err := model.Embeddings(context.TODO(), predictOptions) + res, err := model.Embeddings(o.Context, predictOptions) if err != nil { return nil, err } return res.Embeddings, nil } - - // bert embeddings - case *bert.Bert: - fn = func() ([]float32, error) { - if len(tokens) > 0 { - return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) - } - return model.Embeddings(s, bert.SetThreads(c.Threads)) - } default: fn = func() ([]float32, error) { return nil, fmt.Errorf("embeddings not supported by the backend") @@ -87,7 +77,15 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config. return func() ([]float32, error) { // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - l := Lock(modelFile) + mutexMap.Lock() + l, ok := mutexes[modelFile] + if !ok { + m := &sync.Mutex{} + mutexes[modelFile] = m + l = m + } + mutexMap.Unlock() + l.Lock() defer l.Unlock() embeds, err := fn() diff --git a/api/backend/image.go b/api/backend/image.go index 47ae842..a631b3b 100644 --- a/api/backend/image.go +++ b/api/backend/image.go @@ -6,8 +6,8 @@ import ( config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/stablediffusion" ) func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { @@ -19,23 +19,27 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat model.WithBackendString(c.Backend), model.WithAssetDir(o.AssetsDestination), model.WithThreads(uint32(c.Threads)), + model.WithContext(o.Context), model.WithModelFile(c.ImageGenerationAssets), ) if err != nil { return nil, err } - var fn func() error - switch model := inferenceModel.(type) { - case *stablediffusion.StableDiffusion: - fn = func() error { - return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) - } - - default: - fn = func() error { - return fmt.Errorf("creation of images not supported by the backend") - } + fn := func() error { + _, err := inferenceModel.GenerateImage( + o.Context, + &proto.GenerateImageRequest{ + Height: int32(height), + Width: int32(width), + Mode: int32(mode), + Step: int32(step), + Seed: int32(seed), + PositivePrompt: positive_prompt, + NegativePrompt: negative_prompt, + Dst: dst, + }) + return err } return func() error { diff --git a/api/backend/llm.go b/api/backend/llm.go index d2f8ef6..8fcd6da 100644 --- a/api/backend/llm.go +++ b/api/backend/llm.go @@ -1,34 +1,30 @@ package backend import ( - "context" "regexp" "strings" "sync" - "github.com/donomii/go-rwkv.cpp" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/grpc" - "github.com/go-skynet/LocalAI/pkg/langchain" model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/bloomz.cpp" ) func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { - supportStreams := false modelFile := c.Model grpcOpts := gRPCModelOpts(c) - var inferenceModel interface{} + var inferenceModel *grpc.Client var err error opts := []model.Option{ - model.WithLoadGRPCOpts(grpcOpts), - model.WithThreads(uint32(c.Threads)), // GPT4all uses this + model.WithLoadGRPCLLMModelOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup model.WithAssetDir(o.AssetsDestination), model.WithModelFile(modelFile), + model.WithContext(o.Context), } if c.Backend == "" { @@ -41,95 +37,37 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt return nil, err } - var fn func() (string, error) - - switch model := inferenceModel.(type) { - case *rwkv.RwkvState: - supportStreams = true - - fn = func() (string, error) { - stopWord := "\n" - if len(c.StopWords) > 0 { - stopWord = c.StopWords[0] - } - - if err := model.ProcessInput(s); err != nil { - return "", err - } - - response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) - - return response, nil - } - case *bloomz.Bloomz: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []bloomz.PredictOption{ - bloomz.SetTemperature(c.Temperature), - bloomz.SetTopP(c.TopP), - bloomz.SetTopK(c.TopK), - bloomz.SetTokens(c.Maxtokens), - bloomz.SetThreads(c.Threads), - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - - case *grpc.Client: - // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported - supportStreams = true - fn = func() (string, error) { - - opts := gRPCPredictOpts(c, loader.ModelPath) - opts.Prompt = s - if tokenCallback != nil { - ss := "" - err := model.PredictStream(context.TODO(), opts, func(s string) { - tokenCallback(s) - ss += s - }) - return ss, err - } else { - reply, err := model.Predict(context.TODO(), opts) - return reply.Message, err - } - } - case *langchain.HuggingFace: - fn = func() (string, error) { - - // Generate the prediction using the language model - predictOptions := []langchain.PredictOption{ - langchain.SetModel(c.Model), - langchain.SetMaxTokens(c.Maxtokens), - langchain.SetTemperature(c.Temperature), - langchain.SetStopWords(c.StopWords), - } - - pred, er := model.PredictHuggingFace(s, predictOptions...) - if er != nil { - return "", er - } - return pred.Completion, nil + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + fn := func() (string, error) { + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + if tokenCallback != nil { + ss := "" + err := inferenceModel.PredictStream(o.Context, opts, func(s string) { + tokenCallback(s) + ss += s + }) + return ss, err + } else { + reply, err := inferenceModel.Predict(o.Context, opts) + return reply.Message, err } } return func() (string, error) { // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - l := Lock(modelFile) + mutexMap.Lock() + l, ok := mutexes[modelFile] + if !ok { + m := &sync.Mutex{} + mutexes[modelFile] = m + l = m + } + mutexMap.Unlock() + l.Lock() defer l.Unlock() - res, err := fn() - if tokenCallback != nil && !supportStreams { - tokenCallback(res) - } - return res, err + return fn() }, nil } diff --git a/api/backend/options.go b/api/backend/options.go index f19dbae..7038ffc 100644 --- a/api/backend/options.go +++ b/api/backend/options.go @@ -7,34 +7,8 @@ import ( pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" config "github.com/go-skynet/LocalAI/api/config" - "github.com/go-skynet/LocalAI/pkg/langchain" - "github.com/go-skynet/bloomz.cpp" ) -func langchainOptions(c config.Config) []langchain.PredictOption { - return []langchain.PredictOption{ - langchain.SetModel(c.Model), - langchain.SetMaxTokens(c.Maxtokens), - langchain.SetTemperature(c.Temperature), - langchain.SetStopWords(c.StopWords), - } -} - -func bloomzOptions(c config.Config) []bloomz.PredictOption { - // Generate the prediction using the language model - predictOptions := []bloomz.PredictOption{ - bloomz.SetTemperature(c.Temperature), - bloomz.SetTopP(c.TopP), - bloomz.SetTopK(c.TopK), - bloomz.SetTokens(c.Maxtokens), - bloomz.SetThreads(c.Threads), - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) - } - return predictOptions -} func gRPCModelOpts(c config.Config) *pb.ModelOptions { b := 512 if c.Batch != 0 { diff --git a/api/localai/localai.go b/api/localai/localai.go index f79e889..7c57c92 100644 --- a/api/localai/localai.go +++ b/api/localai/localai.go @@ -1,6 +1,7 @@ package localai import ( + "context" "fmt" "os" "path/filepath" @@ -8,8 +9,8 @@ import ( config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/tts" "github.com/go-skynet/LocalAI/pkg/utils" "github.com/gofiber/fiber/v2" ) @@ -47,6 +48,7 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) piperModel, err := o.Loader.BackendLoader( model.WithBackendString(model.PiperBackend), model.WithModelFile(input.Model), + model.WithContext(o.Context), model.WithAssetDir(o.AssetsDestination)) if err != nil { return err @@ -56,13 +58,8 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return fmt.Errorf("could not load piper model") } - w, ok := piperModel.(*tts.Piper) - if !ok { - return fmt.Errorf("loader returned non-piper object %+v", w) - } - if err := os.MkdirAll(o.AudioDir, 0755); err != nil { - return err + return fmt.Errorf("failed creating audio directory: %s", err) } fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") @@ -74,7 +71,11 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) return err } - if err := w.TTS(input.Input, modelPath, filePath); err != nil { + if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ + Text: input.Input, + Model: modelPath, + Dst: filePath, + }); err != nil { return err } diff --git a/api/openai/transcription.go b/api/openai/transcription.go index 279f320..346693c 100644 --- a/api/openai/transcription.go +++ b/api/openai/transcription.go @@ -1,6 +1,7 @@ package openai import ( + "context" "fmt" "io" "net/http" @@ -8,11 +9,10 @@ import ( "path" "path/filepath" - "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" model "github.com/go-skynet/LocalAI/pkg/model" - whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" @@ -64,6 +64,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe whisperModel, err := o.Loader.BackendLoader( model.WithBackendString(model.WhisperBackend), model.WithModelFile(config.Model), + model.WithContext(o.Context), model.WithThreads(uint32(config.Threads)), model.WithAssetDir(o.AssetsDestination)) if err != nil { @@ -74,18 +75,17 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe return fmt.Errorf("could not load whisper model") } - w, ok := whisperModel.(whisper.Model) - if !ok { - return fmt.Errorf("loader returned non-whisper object") - } - - tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) + tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ + Dst: dst, + Language: input.Language, + Threads: uint32(config.Threads), + }) if err != nil { return err } log.Debug().Msgf("Trascribed: %+v", tr) // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) + return c.Status(http.StatusOK).JSON(tr) } } diff --git a/cmd/grpc/bert-embeddings/main.go b/cmd/grpc/bert-embeddings/main.go new file mode 100644 index 0000000..008c30d --- /dev/null +++ b/cmd/grpc/bert-embeddings/main.go @@ -0,0 +1,22 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" + bert "github.com/go-skynet/LocalAI/pkg/grpc/llm/bert" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &bert.Embeddings{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/bloomz/main.go b/cmd/grpc/bloomz/main.go new file mode 100644 index 0000000..7348cab --- /dev/null +++ b/cmd/grpc/bloomz/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + bloomz "github.com/go-skynet/LocalAI/pkg/grpc/llm/bloomz" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &bloomz.LLM{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/falcon-ggml/main.go b/cmd/grpc/falcon-ggml/main.go new file mode 100644 index 0000000..677c660 --- /dev/null +++ b/cmd/grpc/falcon-ggml/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Falcon{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/langchain-huggingface/main.go b/cmd/grpc/langchain-huggingface/main.go new file mode 100644 index 0000000..ab96584 --- /dev/null +++ b/cmd/grpc/langchain-huggingface/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + langchain "github.com/go-skynet/LocalAI/pkg/grpc/llm/langchain" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &langchain.LLM{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/piper/main.go b/cmd/grpc/piper/main.go new file mode 100644 index 0000000..7de80e2 --- /dev/null +++ b/cmd/grpc/piper/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + tts "github.com/go-skynet/LocalAI/pkg/grpc/tts" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &tts.Piper{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/rwkv/main.go b/cmd/grpc/rwkv/main.go new file mode 100644 index 0000000..f050a7c --- /dev/null +++ b/cmd/grpc/rwkv/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + rwkv "github.com/go-skynet/LocalAI/pkg/grpc/llm/rwkv" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &rwkv.LLM{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/stablediffusion/main.go b/cmd/grpc/stablediffusion/main.go new file mode 100644 index 0000000..76b4a5a --- /dev/null +++ b/cmd/grpc/stablediffusion/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + image "github.com/go-skynet/LocalAI/pkg/grpc/image" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &image.StableDiffusion{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/whisper/main.go b/cmd/grpc/whisper/main.go new file mode 100644 index 0000000..8d4a5fe --- /dev/null +++ b/cmd/grpc/whisper/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transcribe "github.com/go-skynet/LocalAI/pkg/grpc/transcribe" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transcribe.Whisper{}); err != nil { + panic(err) + } +} diff --git a/main.go b/main.go index ec38afe..3f534b0 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,9 @@ package main import ( "os" + "os/signal" "path/filepath" + "syscall" api "github.com/go-skynet/LocalAI/api" "github.com/go-skynet/LocalAI/api/options" @@ -15,6 +17,13 @@ import ( func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + // clean up process + go func() { + c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + os.Exit(1) + }() path, err := os.Getwd() if err != nil { diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go new file mode 100644 index 0000000..a6d89f2 --- /dev/null +++ b/pkg/grpc/base/base.go @@ -0,0 +1,42 @@ +package base + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" +) + +type Base struct { +} + +func (llm *Base) Load(opts *pb.ModelOptions) error { + return fmt.Errorf("unimplemented") + +} + +func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) { + return "", fmt.Errorf("unimplemented") +} + +func (llm *Base) PredictStream(opts *pb.PredictOptions, results chan string) error { + return fmt.Errorf("unimplemented") +} + +func (llm *Base) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return []float32{}, fmt.Errorf("unimplemented") +} + +func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { + return fmt.Errorf("unimplemented") +} + +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) { + return api.Result{}, fmt.Errorf("unimplemented") +} + +func (llm *Base) TTS(*pb.TTSRequest) error { + return fmt.Errorf("unimplemented") +} diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go index 06628eb..bbc40bf 100644 --- a/pkg/grpc/client.go +++ b/pkg/grpc/client.go @@ -7,6 +7,7 @@ import ( "time" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) @@ -28,7 +29,7 @@ func (c *Client) HealthCheck(ctx context.Context) bool { return false } defer conn.Close() - client := pb.NewLLMClient(conn) + client := pb.NewBackendClient(conn) // The healthcheck call shouldn't take long time ctx, cancel := context.WithTimeout(ctx, 10*time.Second) @@ -53,7 +54,7 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ... return nil, err } defer conn.Close() - client := pb.NewLLMClient(conn) + client := pb.NewBackendClient(conn) return client.Embedding(ctx, in, opts...) } @@ -64,7 +65,7 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp return nil, err } defer conn.Close() - client := pb.NewLLMClient(conn) + client := pb.NewBackendClient(conn) return client.Predict(ctx, in, opts...) } @@ -75,7 +76,7 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp return nil, err } defer conn.Close() - client := pb.NewLLMClient(conn) + client := pb.NewBackendClient(conn) return client.LoadModel(ctx, in, opts...) } @@ -85,7 +86,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun return err } defer conn.Close() - client := pb.NewLLMClient(conn) + client := pb.NewBackendClient(conn) stream, err := client.PredictStream(ctx, in, opts...) if err != nil { @@ -107,3 +108,53 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun return nil } + +func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.GenerateImage(ctx, in, opts...) +} + +func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.TTS(ctx, in, opts...) +} + +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + res, err := client.AudioTranscription(ctx, in, opts...) + if err != nil { + return nil, err + } + tresult := &api.Result{} + for _, s := range res.Segments { + tks := []int{} + for _, t := range s.Tokens { + tks = append(tks, int(t)) + } + tresult.Segments = append(tresult.Segments, + api.Segment{ + Text: s.Text, + Id: int(s.Id), + Start: time.Duration(s.Start), + End: time.Duration(s.End), + Tokens: tks, + }) + } + tresult.Text = res.Text + return tresult, err +} diff --git a/pkg/grpc/image/stablediffusion.go b/pkg/grpc/image/stablediffusion.go new file mode 100644 index 0000000..ce0275e --- /dev/null +++ b/pkg/grpc/image/stablediffusion.go @@ -0,0 +1,33 @@ +package image + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/stablediffusion" +) + +type StableDiffusion struct { + base.Base + stablediffusion *stablediffusion.StableDiffusion +} + +func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error { + var err error + // Note: the Model here is a path to a directory containing the model files + sd.stablediffusion, err = stablediffusion.New(opts.Model) + return err +} + +func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error { + return sd.stablediffusion.GenerateImage( + int(opts.Height), + int(opts.Width), + int(opts.Mode), + int(opts.Step), + int(opts.Seed), + opts.PositivePrompt, + opts.NegativePrompt, + opts.Dst) +} diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go index 70b830f..6832a95 100644 --- a/pkg/grpc/interface.go +++ b/pkg/grpc/interface.go @@ -2,11 +2,15 @@ package grpc import ( pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" ) type LLM interface { Predict(*pb.PredictOptions) (string, error) - PredictStream(*pb.PredictOptions, chan string) + PredictStream(*pb.PredictOptions, chan string) error Load(*pb.ModelOptions) error Embeddings(*pb.PredictOptions) ([]float32, error) + GenerateImage(*pb.GenerateImageRequest) error + AudioTranscription(*pb.TranscriptRequest) (api.Result, error) + TTS(*pb.TTSRequest) error } diff --git a/pkg/grpc/llm/bert/bert.go b/pkg/grpc/llm/bert/bert.go new file mode 100644 index 0000000..7692797 --- /dev/null +++ b/pkg/grpc/llm/bert/bert.go @@ -0,0 +1,33 @@ +package bert + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + bert "github.com/go-skynet/go-bert.cpp" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" +) + +type Embeddings struct { + base.Base + bert *bert.Bert +} + +func (llm *Embeddings) Load(opts *pb.ModelOptions) error { + model, err := bert.New(opts.Model) + llm.bert = model + return err +} + +func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + if len(opts.EmbeddingTokens) > 0 { + tokens := []int{} + for _, t := range opts.EmbeddingTokens { + tokens = append(tokens, int(t)) + } + return llm.bert.TokenEmbeddings(tokens, bert.SetThreads(int(opts.Threads))) + } + + return llm.bert.Embeddings(opts.Embeddings, bert.SetThreads(int(opts.Threads))) +} diff --git a/pkg/grpc/llm/bloomz/bloomz.go b/pkg/grpc/llm/bloomz/bloomz.go new file mode 100644 index 0000000..daa2264 --- /dev/null +++ b/pkg/grpc/llm/bloomz/bloomz.go @@ -0,0 +1,59 @@ +package bloomz + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + "github.com/go-skynet/bloomz.cpp" +) + +type LLM struct { + base.Base + + bloomz *bloomz.Bloomz +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + model, err := bloomz.New(opts.Model) + llm.bloomz = model + return err +} + +func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption { + predictOptions := []bloomz.PredictOption{ + bloomz.SetTemperature(float64(opts.Temperature)), + bloomz.SetTopP(float64(opts.TopP)), + bloomz.SetTopK(int(opts.TopK)), + bloomz.SetTokens(int(opts.Tokens)), + bloomz.SetThreads(int(opts.Threads)), + } + + if opts.Seed != 0 { + predictOptions = append(predictOptions, bloomz.SetSeed(int(opts.Seed))) + } + + return predictOptions +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + go func() { + res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() + + return nil +} diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/grpc/llm/falcon/falcon.go index 0a7a533..3c0f84e 100644 --- a/pkg/grpc/llm/falcon/falcon.go +++ b/pkg/grpc/llm/falcon/falcon.go @@ -5,12 +5,15 @@ package falcon import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" ggllm "github.com/mudler/go-ggllm.cpp" ) type LLM struct { + base.Base + falcon *ggllm.Falcon } @@ -42,10 +45,6 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error { return err } -func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return nil, fmt.Errorf("not implemented") -} - func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { predictOptions := []ggllm.PredictOption{ ggllm.SetTemperature(float64(opts.Temperature)), @@ -122,7 +121,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) } -func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { predictOptions := buildPredictOptions(opts) predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { @@ -140,4 +139,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { } close(results) }() + + return nil } diff --git a/pkg/grpc/llm/gpt4all/gpt4all.go b/pkg/grpc/llm/gpt4all/gpt4all.go index 0d7dac5..e17afc1 100644 --- a/pkg/grpc/llm/gpt4all/gpt4all.go +++ b/pkg/grpc/llm/gpt4all/gpt4all.go @@ -5,11 +5,14 @@ package gpt4all import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" ) type LLM struct { + base.Base + gpt4all *gpt4all.Model } @@ -39,7 +42,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...) } -func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { predictOptions := buildPredictOptions(opts) go func() { @@ -54,8 +57,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { llm.gpt4all.SetTokenCallback(nil) close(results) }() -} -func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return []float32{}, fmt.Errorf("not implemented") + return nil } diff --git a/pkg/grpc/llm/langchain/langchain.go b/pkg/grpc/llm/langchain/langchain.go new file mode 100644 index 0000000..5d5f94b --- /dev/null +++ b/pkg/grpc/llm/langchain/langchain.go @@ -0,0 +1,58 @@ +package langchain + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/langchain" +) + +type LLM struct { + base.Base + + langchain *langchain.HuggingFace + model string +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + llm.langchain, _ = langchain.NewHuggingFace(opts.Model) + llm.model = opts.Model + return nil +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + o := []langchain.PredictOption{ + langchain.SetModel(llm.model), + langchain.SetMaxTokens(int(opts.Tokens)), + langchain.SetTemperature(float64(opts.Temperature)), + langchain.SetStopWords(opts.StopPrompts), + } + pred, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...) + if err != nil { + return "", err + } + return pred.Completion, nil +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + o := []langchain.PredictOption{ + langchain.SetModel(llm.model), + langchain.SetMaxTokens(int(opts.Tokens)), + langchain.SetTemperature(float64(opts.Temperature)), + langchain.SetStopWords(opts.StopPrompts), + } + go func() { + res, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res.Completion + close(results) + }() + + return nil +} diff --git a/pkg/grpc/llm/llama/llama.go b/pkg/grpc/llm/llama/llama.go index a31e274..82063b7 100644 --- a/pkg/grpc/llm/llama/llama.go +++ b/pkg/grpc/llm/llama/llama.go @@ -5,11 +5,14 @@ package llama import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" "github.com/go-skynet/go-llama.cpp" ) type LLM struct { + base.Base + llama *llama.LLama } @@ -133,7 +136,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) } -func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { predictOptions := buildPredictOptions(opts) predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { @@ -148,6 +151,8 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { } close(results) }() + + return nil } func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { diff --git a/pkg/grpc/llm/rwkv/rwkv.go b/pkg/grpc/llm/rwkv/rwkv.go new file mode 100644 index 0000000..f54c14b --- /dev/null +++ b/pkg/grpc/llm/rwkv/rwkv.go @@ -0,0 +1,71 @@ +package rwkv + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + "path/filepath" + + "github.com/donomii/go-rwkv.cpp" + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" +) + +const tokenizerSuffix = ".tokenizer.json" + +type LLM struct { + base.Base + + rwkv *rwkv.RwkvState +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + modelPath := filepath.Dir(opts.Model) + modelFile := filepath.Base(opts.Model) + model := rwkv.LoadFiles(opts.Model, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads())) + + if model == nil { + return fmt.Errorf("could not load model") + } + llm.rwkv = model + return nil +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + + stopWord := "\n" + if len(opts.StopPrompts) > 0 { + stopWord = opts.StopPrompts[0] + } + + if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil { + return "", err + } + + response := llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), nil) + + return response, nil +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + go func() { + + stopWord := "\n" + if len(opts.StopPrompts) > 0 { + stopWord = opts.StopPrompts[0] + } + + if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil { + fmt.Println("Error processing input: ", err) + return + } + + llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), func(s string) bool { + results <- s + return true + }) + close(results) + }() + + return nil +} diff --git a/pkg/grpc/llm/transformers/dolly.go b/pkg/grpc/llm/transformers/dolly.go index 28a44a7..d5f3093 100644 --- a/pkg/grpc/llm/transformers/dolly.go +++ b/pkg/grpc/llm/transformers/dolly.go @@ -5,12 +5,15 @@ package transformers import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type Dolly struct { + base.Base + dolly *transformers.Dolly } @@ -20,16 +23,12 @@ func (llm *Dolly) Load(opts *pb.ModelOptions) error { return err } -func (llm *Dolly) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return nil, fmt.Errorf("not implemented") -} - func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) { return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict -func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error { go func() { res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -39,4 +38,6 @@ func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) { results <- res close(results) }() + + return nil } diff --git a/pkg/grpc/llm/transformers/falcon.go b/pkg/grpc/llm/transformers/falcon.go new file mode 100644 index 0000000..982e43e --- /dev/null +++ b/pkg/grpc/llm/transformers/falcon.go @@ -0,0 +1,43 @@ +package transformers + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + transformers "github.com/go-skynet/go-ggml-transformers.cpp" +) + +type Falcon struct { + base.Base + + falcon *transformers.Falcon +} + +func (llm *Falcon) Load(opts *pb.ModelOptions) error { + model, err := transformers.NewFalcon(opts.Model) + llm.falcon = model + return err +} + +func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) { + return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error { + go func() { + res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) + + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() + + return nil +} diff --git a/pkg/grpc/llm/transformers/gpt2.go b/pkg/grpc/llm/transformers/gpt2.go index 0eaf787..85a4112 100644 --- a/pkg/grpc/llm/transformers/gpt2.go +++ b/pkg/grpc/llm/transformers/gpt2.go @@ -5,12 +5,15 @@ package transformers import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type GPT2 struct { + base.Base + gpt2 *transformers.GPT2 } @@ -20,16 +23,12 @@ func (llm *GPT2) Load(opts *pb.ModelOptions) error { return err } -func (llm *GPT2) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return nil, fmt.Errorf("not implemented") -} - func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) { return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict -func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error { go func() { res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -39,4 +38,5 @@ func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) { results <- res close(results) }() + return nil } diff --git a/pkg/grpc/llm/transformers/gptj.go b/pkg/grpc/llm/transformers/gptj.go index a7138ef..e2bc3bf 100644 --- a/pkg/grpc/llm/transformers/gptj.go +++ b/pkg/grpc/llm/transformers/gptj.go @@ -5,12 +5,15 @@ package transformers import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type GPTJ struct { + base.Base + gptj *transformers.GPTJ } @@ -20,16 +23,12 @@ func (llm *GPTJ) Load(opts *pb.ModelOptions) error { return err } -func (llm *GPTJ) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return nil, fmt.Errorf("not implemented") -} - func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) { return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict -func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error { go func() { res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -39,4 +38,5 @@ func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) { results <- res close(results) }() + return nil } diff --git a/pkg/grpc/llm/transformers/gptneox.go b/pkg/grpc/llm/transformers/gptneox.go index 2edf4ba..ca6db94 100644 --- a/pkg/grpc/llm/transformers/gptneox.go +++ b/pkg/grpc/llm/transformers/gptneox.go @@ -5,12 +5,15 @@ package transformers import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type GPTNeoX struct { + base.Base + gptneox *transformers.GPTNeoX } @@ -20,16 +23,12 @@ func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error { return err } -func (llm *GPTNeoX) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return nil, fmt.Errorf("not implemented") -} - func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) { return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict -func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error { go func() { res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -39,4 +38,5 @@ func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) results <- res close(results) }() + return nil } diff --git a/pkg/grpc/llm/transformers/mpt.go b/pkg/grpc/llm/transformers/mpt.go index ab88418..d2b9ff1 100644 --- a/pkg/grpc/llm/transformers/mpt.go +++ b/pkg/grpc/llm/transformers/mpt.go @@ -5,12 +5,15 @@ package transformers import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type MPT struct { + base.Base + mpt *transformers.MPT } @@ -20,16 +23,12 @@ func (llm *MPT) Load(opts *pb.ModelOptions) error { return err } -func (llm *MPT) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return nil, fmt.Errorf("not implemented") -} - func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) { return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict -func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error { go func() { res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -39,4 +38,5 @@ func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) { results <- res close(results) }() + return nil } diff --git a/pkg/grpc/llm/transformers/replit.go b/pkg/grpc/llm/transformers/replit.go index ca1d66f..4b26ffd 100644 --- a/pkg/grpc/llm/transformers/replit.go +++ b/pkg/grpc/llm/transformers/replit.go @@ -5,12 +5,15 @@ package transformers import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type Replit struct { + base.Base + replit *transformers.Replit } @@ -20,16 +23,12 @@ func (llm *Replit) Load(opts *pb.ModelOptions) error { return err } -func (llm *Replit) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return nil, fmt.Errorf("not implemented") -} - func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) { return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict -func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error { go func() { res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -39,4 +38,5 @@ func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) { results <- res close(results) }() + return nil } diff --git a/pkg/grpc/llm/transformers/starcoder.go b/pkg/grpc/llm/transformers/starcoder.go index 6e1a94b..7631274 100644 --- a/pkg/grpc/llm/transformers/starcoder.go +++ b/pkg/grpc/llm/transformers/starcoder.go @@ -5,12 +5,15 @@ package transformers import ( "fmt" + "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" transformers "github.com/go-skynet/go-ggml-transformers.cpp" ) type Starcoder struct { + base.Base + starcoder *transformers.Starcoder } @@ -20,16 +23,12 @@ func (llm *Starcoder) Load(opts *pb.ModelOptions) error { return err } -func (llm *Starcoder) Embeddings(opts *pb.PredictOptions) ([]float32, error) { - return nil, fmt.Errorf("not implemented") -} - func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) { return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) } // fallback to Predict -func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) { +func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error { go func() { res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) @@ -39,4 +38,6 @@ func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string results <- res close(results) }() + + return nil } diff --git a/pkg/grpc/proto/backend.pb.go b/pkg/grpc/proto/backend.pb.go new file mode 100644 index 0000000..dcf14a3 --- /dev/null +++ b/pkg/grpc/proto/backend.pb.go @@ -0,0 +1,1458 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.26.0 +// protoc v3.15.8 +// source: pkg/grpc/proto/backend.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type HealthMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *HealthMessage) Reset() { + *x = HealthMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HealthMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthMessage) ProtoMessage() {} + +func (x *HealthMessage) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthMessage.ProtoReflect.Descriptor instead. +func (*HealthMessage) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{0} +} + +// The request message containing the user's name. +type PredictOptions struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Prompt string `protobuf:"bytes,1,opt,name=Prompt,proto3" json:"Prompt,omitempty"` + Seed int32 `protobuf:"varint,2,opt,name=Seed,proto3" json:"Seed,omitempty"` + Threads int32 `protobuf:"varint,3,opt,name=Threads,proto3" json:"Threads,omitempty"` + Tokens int32 `protobuf:"varint,4,opt,name=Tokens,proto3" json:"Tokens,omitempty"` + TopK int32 `protobuf:"varint,5,opt,name=TopK,proto3" json:"TopK,omitempty"` + Repeat int32 `protobuf:"varint,6,opt,name=Repeat,proto3" json:"Repeat,omitempty"` + Batch int32 `protobuf:"varint,7,opt,name=Batch,proto3" json:"Batch,omitempty"` + NKeep int32 `protobuf:"varint,8,opt,name=NKeep,proto3" json:"NKeep,omitempty"` + Temperature float32 `protobuf:"fixed32,9,opt,name=Temperature,proto3" json:"Temperature,omitempty"` + Penalty float32 `protobuf:"fixed32,10,opt,name=Penalty,proto3" json:"Penalty,omitempty"` + F16KV bool `protobuf:"varint,11,opt,name=F16KV,proto3" json:"F16KV,omitempty"` + DebugMode bool `protobuf:"varint,12,opt,name=DebugMode,proto3" json:"DebugMode,omitempty"` + StopPrompts []string `protobuf:"bytes,13,rep,name=StopPrompts,proto3" json:"StopPrompts,omitempty"` + IgnoreEOS bool `protobuf:"varint,14,opt,name=IgnoreEOS,proto3" json:"IgnoreEOS,omitempty"` + TailFreeSamplingZ float32 `protobuf:"fixed32,15,opt,name=TailFreeSamplingZ,proto3" json:"TailFreeSamplingZ,omitempty"` + TypicalP float32 `protobuf:"fixed32,16,opt,name=TypicalP,proto3" json:"TypicalP,omitempty"` + FrequencyPenalty float32 `protobuf:"fixed32,17,opt,name=FrequencyPenalty,proto3" json:"FrequencyPenalty,omitempty"` + PresencePenalty float32 `protobuf:"fixed32,18,opt,name=PresencePenalty,proto3" json:"PresencePenalty,omitempty"` + Mirostat int32 `protobuf:"varint,19,opt,name=Mirostat,proto3" json:"Mirostat,omitempty"` + MirostatETA float32 `protobuf:"fixed32,20,opt,name=MirostatETA,proto3" json:"MirostatETA,omitempty"` + MirostatTAU float32 `protobuf:"fixed32,21,opt,name=MirostatTAU,proto3" json:"MirostatTAU,omitempty"` + PenalizeNL bool `protobuf:"varint,22,opt,name=PenalizeNL,proto3" json:"PenalizeNL,omitempty"` + LogitBias string `protobuf:"bytes,23,opt,name=LogitBias,proto3" json:"LogitBias,omitempty"` + MLock bool `protobuf:"varint,25,opt,name=MLock,proto3" json:"MLock,omitempty"` + MMap bool `protobuf:"varint,26,opt,name=MMap,proto3" json:"MMap,omitempty"` + PromptCacheAll bool `protobuf:"varint,27,opt,name=PromptCacheAll,proto3" json:"PromptCacheAll,omitempty"` + PromptCacheRO bool `protobuf:"varint,28,opt,name=PromptCacheRO,proto3" json:"PromptCacheRO,omitempty"` + Grammar string `protobuf:"bytes,29,opt,name=Grammar,proto3" json:"Grammar,omitempty"` + MainGPU string `protobuf:"bytes,30,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` + TensorSplit string `protobuf:"bytes,31,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` + TopP float32 `protobuf:"fixed32,32,opt,name=TopP,proto3" json:"TopP,omitempty"` + PromptCachePath string `protobuf:"bytes,33,opt,name=PromptCachePath,proto3" json:"PromptCachePath,omitempty"` + Debug bool `protobuf:"varint,34,opt,name=Debug,proto3" json:"Debug,omitempty"` + EmbeddingTokens []int32 `protobuf:"varint,35,rep,packed,name=EmbeddingTokens,proto3" json:"EmbeddingTokens,omitempty"` + Embeddings string `protobuf:"bytes,36,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` +} + +func (x *PredictOptions) Reset() { + *x = PredictOptions{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PredictOptions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PredictOptions) ProtoMessage() {} + +func (x *PredictOptions) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PredictOptions.ProtoReflect.Descriptor instead. +func (*PredictOptions) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{1} +} + +func (x *PredictOptions) GetPrompt() string { + if x != nil { + return x.Prompt + } + return "" +} + +func (x *PredictOptions) GetSeed() int32 { + if x != nil { + return x.Seed + } + return 0 +} + +func (x *PredictOptions) GetThreads() int32 { + if x != nil { + return x.Threads + } + return 0 +} + +func (x *PredictOptions) GetTokens() int32 { + if x != nil { + return x.Tokens + } + return 0 +} + +func (x *PredictOptions) GetTopK() int32 { + if x != nil { + return x.TopK + } + return 0 +} + +func (x *PredictOptions) GetRepeat() int32 { + if x != nil { + return x.Repeat + } + return 0 +} + +func (x *PredictOptions) GetBatch() int32 { + if x != nil { + return x.Batch + } + return 0 +} + +func (x *PredictOptions) GetNKeep() int32 { + if x != nil { + return x.NKeep + } + return 0 +} + +func (x *PredictOptions) GetTemperature() float32 { + if x != nil { + return x.Temperature + } + return 0 +} + +func (x *PredictOptions) GetPenalty() float32 { + if x != nil { + return x.Penalty + } + return 0 +} + +func (x *PredictOptions) GetF16KV() bool { + if x != nil { + return x.F16KV + } + return false +} + +func (x *PredictOptions) GetDebugMode() bool { + if x != nil { + return x.DebugMode + } + return false +} + +func (x *PredictOptions) GetStopPrompts() []string { + if x != nil { + return x.StopPrompts + } + return nil +} + +func (x *PredictOptions) GetIgnoreEOS() bool { + if x != nil { + return x.IgnoreEOS + } + return false +} + +func (x *PredictOptions) GetTailFreeSamplingZ() float32 { + if x != nil { + return x.TailFreeSamplingZ + } + return 0 +} + +func (x *PredictOptions) GetTypicalP() float32 { + if x != nil { + return x.TypicalP + } + return 0 +} + +func (x *PredictOptions) GetFrequencyPenalty() float32 { + if x != nil { + return x.FrequencyPenalty + } + return 0 +} + +func (x *PredictOptions) GetPresencePenalty() float32 { + if x != nil { + return x.PresencePenalty + } + return 0 +} + +func (x *PredictOptions) GetMirostat() int32 { + if x != nil { + return x.Mirostat + } + return 0 +} + +func (x *PredictOptions) GetMirostatETA() float32 { + if x != nil { + return x.MirostatETA + } + return 0 +} + +func (x *PredictOptions) GetMirostatTAU() float32 { + if x != nil { + return x.MirostatTAU + } + return 0 +} + +func (x *PredictOptions) GetPenalizeNL() bool { + if x != nil { + return x.PenalizeNL + } + return false +} + +func (x *PredictOptions) GetLogitBias() string { + if x != nil { + return x.LogitBias + } + return "" +} + +func (x *PredictOptions) GetMLock() bool { + if x != nil { + return x.MLock + } + return false +} + +func (x *PredictOptions) GetMMap() bool { + if x != nil { + return x.MMap + } + return false +} + +func (x *PredictOptions) GetPromptCacheAll() bool { + if x != nil { + return x.PromptCacheAll + } + return false +} + +func (x *PredictOptions) GetPromptCacheRO() bool { + if x != nil { + return x.PromptCacheRO + } + return false +} + +func (x *PredictOptions) GetGrammar() string { + if x != nil { + return x.Grammar + } + return "" +} + +func (x *PredictOptions) GetMainGPU() string { + if x != nil { + return x.MainGPU + } + return "" +} + +func (x *PredictOptions) GetTensorSplit() string { + if x != nil { + return x.TensorSplit + } + return "" +} + +func (x *PredictOptions) GetTopP() float32 { + if x != nil { + return x.TopP + } + return 0 +} + +func (x *PredictOptions) GetPromptCachePath() string { + if x != nil { + return x.PromptCachePath + } + return "" +} + +func (x *PredictOptions) GetDebug() bool { + if x != nil { + return x.Debug + } + return false +} + +func (x *PredictOptions) GetEmbeddingTokens() []int32 { + if x != nil { + return x.EmbeddingTokens + } + return nil +} + +func (x *PredictOptions) GetEmbeddings() string { + if x != nil { + return x.Embeddings + } + return "" +} + +// The response message containing the result +type Reply struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *Reply) Reset() { + *x = Reply{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Reply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Reply) ProtoMessage() {} + +func (x *Reply) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Reply.ProtoReflect.Descriptor instead. +func (*Reply) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{2} +} + +func (x *Reply) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type ModelOptions struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Model string `protobuf:"bytes,1,opt,name=Model,proto3" json:"Model,omitempty"` + ContextSize int32 `protobuf:"varint,2,opt,name=ContextSize,proto3" json:"ContextSize,omitempty"` + Seed int32 `protobuf:"varint,3,opt,name=Seed,proto3" json:"Seed,omitempty"` + NBatch int32 `protobuf:"varint,4,opt,name=NBatch,proto3" json:"NBatch,omitempty"` + F16Memory bool `protobuf:"varint,5,opt,name=F16Memory,proto3" json:"F16Memory,omitempty"` + MLock bool `protobuf:"varint,6,opt,name=MLock,proto3" json:"MLock,omitempty"` + MMap bool `protobuf:"varint,7,opt,name=MMap,proto3" json:"MMap,omitempty"` + VocabOnly bool `protobuf:"varint,8,opt,name=VocabOnly,proto3" json:"VocabOnly,omitempty"` + LowVRAM bool `protobuf:"varint,9,opt,name=LowVRAM,proto3" json:"LowVRAM,omitempty"` + Embeddings bool `protobuf:"varint,10,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` + NUMA bool `protobuf:"varint,11,opt,name=NUMA,proto3" json:"NUMA,omitempty"` + NGPULayers int32 `protobuf:"varint,12,opt,name=NGPULayers,proto3" json:"NGPULayers,omitempty"` + MainGPU string `protobuf:"bytes,13,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` + TensorSplit string `protobuf:"bytes,14,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` + Threads int32 `protobuf:"varint,15,opt,name=Threads,proto3" json:"Threads,omitempty"` + LibrarySearchPath string `protobuf:"bytes,16,opt,name=LibrarySearchPath,proto3" json:"LibrarySearchPath,omitempty"` +} + +func (x *ModelOptions) Reset() { + *x = ModelOptions{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ModelOptions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ModelOptions) ProtoMessage() {} + +func (x *ModelOptions) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ModelOptions.ProtoReflect.Descriptor instead. +func (*ModelOptions) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{3} +} + +func (x *ModelOptions) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +func (x *ModelOptions) GetContextSize() int32 { + if x != nil { + return x.ContextSize + } + return 0 +} + +func (x *ModelOptions) GetSeed() int32 { + if x != nil { + return x.Seed + } + return 0 +} + +func (x *ModelOptions) GetNBatch() int32 { + if x != nil { + return x.NBatch + } + return 0 +} + +func (x *ModelOptions) GetF16Memory() bool { + if x != nil { + return x.F16Memory + } + return false +} + +func (x *ModelOptions) GetMLock() bool { + if x != nil { + return x.MLock + } + return false +} + +func (x *ModelOptions) GetMMap() bool { + if x != nil { + return x.MMap + } + return false +} + +func (x *ModelOptions) GetVocabOnly() bool { + if x != nil { + return x.VocabOnly + } + return false +} + +func (x *ModelOptions) GetLowVRAM() bool { + if x != nil { + return x.LowVRAM + } + return false +} + +func (x *ModelOptions) GetEmbeddings() bool { + if x != nil { + return x.Embeddings + } + return false +} + +func (x *ModelOptions) GetNUMA() bool { + if x != nil { + return x.NUMA + } + return false +} + +func (x *ModelOptions) GetNGPULayers() int32 { + if x != nil { + return x.NGPULayers + } + return 0 +} + +func (x *ModelOptions) GetMainGPU() string { + if x != nil { + return x.MainGPU + } + return "" +} + +func (x *ModelOptions) GetTensorSplit() string { + if x != nil { + return x.TensorSplit + } + return "" +} + +func (x *ModelOptions) GetThreads() int32 { + if x != nil { + return x.Threads + } + return 0 +} + +func (x *ModelOptions) GetLibrarySearchPath() string { + if x != nil { + return x.LibrarySearchPath + } + return "" +} + +type Result struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` + Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` +} + +func (x *Result) Reset() { + *x = Result{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Result) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Result) ProtoMessage() {} + +func (x *Result) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Result.ProtoReflect.Descriptor instead. +func (*Result) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{4} +} + +func (x *Result) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *Result) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +type EmbeddingResult struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Embeddings []float32 `protobuf:"fixed32,1,rep,packed,name=embeddings,proto3" json:"embeddings,omitempty"` +} + +func (x *EmbeddingResult) Reset() { + *x = EmbeddingResult{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EmbeddingResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbeddingResult) ProtoMessage() {} + +func (x *EmbeddingResult) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbeddingResult.ProtoReflect.Descriptor instead. +func (*EmbeddingResult) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{5} +} + +func (x *EmbeddingResult) GetEmbeddings() []float32 { + if x != nil { + return x.Embeddings + } + return nil +} + +type TranscriptRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Dst string `protobuf:"bytes,2,opt,name=dst,proto3" json:"dst,omitempty"` + Language string `protobuf:"bytes,3,opt,name=language,proto3" json:"language,omitempty"` + Threads uint32 `protobuf:"varint,4,opt,name=threads,proto3" json:"threads,omitempty"` +} + +func (x *TranscriptRequest) Reset() { + *x = TranscriptRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TranscriptRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscriptRequest) ProtoMessage() {} + +func (x *TranscriptRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscriptRequest.ProtoReflect.Descriptor instead. +func (*TranscriptRequest) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{6} +} + +func (x *TranscriptRequest) GetDst() string { + if x != nil { + return x.Dst + } + return "" +} + +func (x *TranscriptRequest) GetLanguage() string { + if x != nil { + return x.Language + } + return "" +} + +func (x *TranscriptRequest) GetThreads() uint32 { + if x != nil { + return x.Threads + } + return 0 +} + +type TranscriptResult struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Segments []*TranscriptSegment `protobuf:"bytes,1,rep,name=segments,proto3" json:"segments,omitempty"` + Text string `protobuf:"bytes,2,opt,name=text,proto3" json:"text,omitempty"` +} + +func (x *TranscriptResult) Reset() { + *x = TranscriptResult{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TranscriptResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscriptResult) ProtoMessage() {} + +func (x *TranscriptResult) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscriptResult.ProtoReflect.Descriptor instead. +func (*TranscriptResult) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{7} +} + +func (x *TranscriptResult) GetSegments() []*TranscriptSegment { + if x != nil { + return x.Segments + } + return nil +} + +func (x *TranscriptResult) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +type TranscriptSegment struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Id int32 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + Start int64 `protobuf:"varint,2,opt,name=start,proto3" json:"start,omitempty"` + End int64 `protobuf:"varint,3,opt,name=end,proto3" json:"end,omitempty"` + Text string `protobuf:"bytes,4,opt,name=text,proto3" json:"text,omitempty"` + Tokens []int32 `protobuf:"varint,5,rep,packed,name=tokens,proto3" json:"tokens,omitempty"` +} + +func (x *TranscriptSegment) Reset() { + *x = TranscriptSegment{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TranscriptSegment) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscriptSegment) ProtoMessage() {} + +func (x *TranscriptSegment) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscriptSegment.ProtoReflect.Descriptor instead. +func (*TranscriptSegment) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{8} +} + +func (x *TranscriptSegment) GetId() int32 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *TranscriptSegment) GetStart() int64 { + if x != nil { + return x.Start + } + return 0 +} + +func (x *TranscriptSegment) GetEnd() int64 { + if x != nil { + return x.End + } + return 0 +} + +func (x *TranscriptSegment) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +func (x *TranscriptSegment) GetTokens() []int32 { + if x != nil { + return x.Tokens + } + return nil +} + +type GenerateImageRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Height int32 `protobuf:"varint,1,opt,name=height,proto3" json:"height,omitempty"` + Width int32 `protobuf:"varint,2,opt,name=width,proto3" json:"width,omitempty"` + Mode int32 `protobuf:"varint,3,opt,name=mode,proto3" json:"mode,omitempty"` + Step int32 `protobuf:"varint,4,opt,name=step,proto3" json:"step,omitempty"` + Seed int32 `protobuf:"varint,5,opt,name=seed,proto3" json:"seed,omitempty"` + PositivePrompt string `protobuf:"bytes,6,opt,name=positive_prompt,json=positivePrompt,proto3" json:"positive_prompt,omitempty"` + NegativePrompt string `protobuf:"bytes,7,opt,name=negative_prompt,json=negativePrompt,proto3" json:"negative_prompt,omitempty"` + Dst string `protobuf:"bytes,8,opt,name=dst,proto3" json:"dst,omitempty"` +} + +func (x *GenerateImageRequest) Reset() { + *x = GenerateImageRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GenerateImageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GenerateImageRequest) ProtoMessage() {} + +func (x *GenerateImageRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GenerateImageRequest.ProtoReflect.Descriptor instead. +func (*GenerateImageRequest) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{9} +} + +func (x *GenerateImageRequest) GetHeight() int32 { + if x != nil { + return x.Height + } + return 0 +} + +func (x *GenerateImageRequest) GetWidth() int32 { + if x != nil { + return x.Width + } + return 0 +} + +func (x *GenerateImageRequest) GetMode() int32 { + if x != nil { + return x.Mode + } + return 0 +} + +func (x *GenerateImageRequest) GetStep() int32 { + if x != nil { + return x.Step + } + return 0 +} + +func (x *GenerateImageRequest) GetSeed() int32 { + if x != nil { + return x.Seed + } + return 0 +} + +func (x *GenerateImageRequest) GetPositivePrompt() string { + if x != nil { + return x.PositivePrompt + } + return "" +} + +func (x *GenerateImageRequest) GetNegativePrompt() string { + if x != nil { + return x.NegativePrompt + } + return "" +} + +func (x *GenerateImageRequest) GetDst() string { + if x != nil { + return x.Dst + } + return "" +} + +type TTSRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"` + Model string `protobuf:"bytes,2,opt,name=model,proto3" json:"model,omitempty"` + Dst string `protobuf:"bytes,3,opt,name=dst,proto3" json:"dst,omitempty"` +} + +func (x *TTSRequest) Reset() { + *x = TTSRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TTSRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TTSRequest) ProtoMessage() {} + +func (x *TTSRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSRequest.ProtoReflect.Descriptor instead. +func (*TTSRequest) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{10} +} + +func (x *TTSRequest) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +func (x *TTSRequest) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +func (x *TTSRequest) GetDst() string { + if x != nil { + return x.Dst + } + return "" +} + +var File_pkg_grpc_proto_backend_proto protoreflect.FileDescriptor + +var file_pkg_grpc_proto_backend_proto_rawDesc = []byte{ + 0x0a, 0x1c, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2f, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, + 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x22, 0x0f, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xa0, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, + 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x50, + 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, + 0x64, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, + 0x73, 0x12, 0x16, 0x0a, 0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, + 0x4b, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x4b, 0x12, 0x16, 0x0a, + 0x06, 0x52, 0x65, 0x70, 0x65, 0x61, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x52, + 0x65, 0x70, 0x65, 0x61, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x4e, + 0x4b, 0x65, 0x65, 0x70, 0x18, 0x08, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x4e, 0x4b, 0x65, 0x65, + 0x70, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72, 0x65, + 0x18, 0x09, 0x20, 0x01, 0x28, 0x02, 0x52, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x0a, + 0x20, 0x01, 0x28, 0x02, 0x52, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x14, 0x0a, + 0x05, 0x46, 0x31, 0x36, 0x4b, 0x56, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x46, 0x31, + 0x36, 0x4b, 0x56, 0x12, 0x1c, 0x0a, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, 0x65, + 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, + 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x73, + 0x18, 0x0d, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, + 0x70, 0x74, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, 0x53, + 0x18, 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, + 0x53, 0x12, 0x2c, 0x0a, 0x11, 0x54, 0x61, 0x69, 0x6c, 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, + 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x5a, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x02, 0x52, 0x11, 0x54, 0x61, + 0x69, 0x6c, 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x5a, 0x12, + 0x1a, 0x0a, 0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x18, 0x10, 0x20, 0x01, 0x28, + 0x02, 0x52, 0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x12, 0x2a, 0x0a, 0x10, 0x46, + 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, + 0x11, 0x20, 0x01, 0x28, 0x02, 0x52, 0x10, 0x46, 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, + 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x65, 0x73, 0x65, + 0x6e, 0x63, 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x12, 0x20, 0x01, 0x28, 0x02, + 0x52, 0x0f, 0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x63, 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, + 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x18, 0x13, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x12, 0x20, 0x0a, + 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x18, 0x14, 0x20, 0x01, + 0x28, 0x02, 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x12, + 0x20, 0x0a, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, 0x55, 0x18, 0x15, + 0x20, 0x01, 0x28, 0x02, 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, + 0x55, 0x12, 0x1e, 0x0a, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x18, + 0x16, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, + 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x18, 0x17, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, + 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, + 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, + 0x6c, 0x12, 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, + 0x52, 0x4f, 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, + 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, + 0x61, 0x72, 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, + 0x72, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, + 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, + 0x04, 0x54, 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, + 0x50, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, + 0x50, 0x61, 0x74, 0x68, 0x18, 0x21, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, + 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, + 0x65, 0x62, 0x75, 0x67, 0x18, 0x22, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, + 0x67, 0x12, 0x28, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x23, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0f, 0x45, 0x6d, 0x62, 0x65, + 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x45, + 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x52, + 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xca, + 0x03, 0x0a, 0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, + 0x14, 0x0a, 0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x20, 0x0a, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x53, 0x69, 0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x43, 0x6f, 0x6e, 0x74, + 0x65, 0x78, 0x74, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x4e, + 0x42, 0x61, 0x74, 0x63, 0x68, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x4e, 0x42, 0x61, + 0x74, 0x63, 0x68, 0x12, 0x1c, 0x0a, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, + 0x79, 0x12, 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x1c, 0x0a, 0x09, 0x56, + 0x6f, 0x63, 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, + 0x56, 0x6f, 0x63, 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x4c, 0x6f, 0x77, + 0x56, 0x52, 0x41, 0x4d, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x4c, 0x6f, 0x77, 0x56, + 0x52, 0x41, 0x4d, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, + 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, + 0x6e, 0x67, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x18, 0x0b, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x12, 0x1e, 0x0a, 0x0a, 0x4e, 0x47, 0x50, 0x55, 0x4c, + 0x61, 0x79, 0x65, 0x72, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x4e, 0x47, 0x50, + 0x55, 0x4c, 0x61, 0x79, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, + 0x50, 0x55, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, + 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, + 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, + 0x6c, 0x69, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x18, 0x0f, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12, 0x2c, 0x0a, + 0x11, 0x4c, 0x69, 0x62, 0x72, 0x61, 0x72, 0x79, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, + 0x74, 0x68, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x11, 0x4c, 0x69, 0x62, 0x72, 0x61, 0x72, + 0x79, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x06, 0x52, + 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, + 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, 0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, + 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, + 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x02, + 0x52, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x5b, 0x0a, 0x11, + 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x64, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x6c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x12, + 0x18, 0x0a, 0x07, 0x74, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x07, 0x74, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x22, 0x5e, 0x0a, 0x10, 0x54, 0x72, 0x61, + 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x36, 0x0a, + 0x08, 0x73, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, + 0x72, 0x69, 0x70, 0x74, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x08, 0x73, 0x65, 0x67, + 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78, 0x74, 0x22, 0x77, 0x0a, 0x11, 0x54, 0x72, 0x61, + 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x0e, + 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x12, 0x14, + 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x65, 0x78, 0x74, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x05, 0x52, 0x06, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x73, 0x22, 0xe4, 0x01, 0x0a, 0x14, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, + 0x6d, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x68, + 0x65, 0x69, 0x67, 0x68, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x68, 0x65, 0x69, + 0x67, 0x68, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x77, 0x69, 0x64, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x05, 0x77, 0x69, 0x64, 0x74, 0x68, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x6f, 0x64, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x12, 0x0a, + 0x04, 0x73, 0x74, 0x65, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x73, 0x74, 0x65, + 0x70, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x65, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x04, 0x73, 0x65, 0x65, 0x64, 0x12, 0x27, 0x0a, 0x0f, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x76, + 0x65, 0x5f, 0x70, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, + 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x76, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x27, + 0x0a, 0x0f, 0x6e, 0x65, 0x67, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x67, 0x61, 0x74, 0x69, 0x76, + 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x08, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x73, 0x74, 0x22, 0x48, 0x0a, 0x0a, 0x54, 0x54, 0x53, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x65, 0x78, 0x74, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6d, + 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, + 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x64, 0x73, 0x74, 0x32, 0xeb, 0x03, 0x0a, 0x07, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x12, + 0x32, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, + 0x79, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x17, + 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, + 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x35, 0x0a, 0x09, 0x4c, 0x6f, 0x61, + 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, + 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0f, 0x2e, + 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, + 0x12, 0x3c, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, + 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, + 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x40, + 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x17, 0x2e, 0x62, 0x61, + 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x18, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x45, + 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, + 0x12, 0x41, 0x0a, 0x0d, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, + 0x65, 0x12, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x47, 0x65, 0x6e, 0x65, + 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x12, 0x41, 0x75, 0x64, 0x69, 0x6f, 0x54, 0x72, 0x61, 0x6e, + 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, + 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, + 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x03, 0x54, 0x54, 0x53, 0x12, 0x13, 0x2e, 0x62, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x54, 0x53, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, + 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, + 0x00, 0x42, 0x5a, 0x0a, 0x19, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2e, 0x6c, + 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x42, 0x0e, + 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x50, 0x01, + 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, + 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, + 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_grpc_proto_backend_proto_rawDescOnce sync.Once + file_pkg_grpc_proto_backend_proto_rawDescData = file_pkg_grpc_proto_backend_proto_rawDesc +) + +func file_pkg_grpc_proto_backend_proto_rawDescGZIP() []byte { + file_pkg_grpc_proto_backend_proto_rawDescOnce.Do(func() { + file_pkg_grpc_proto_backend_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_proto_backend_proto_rawDescData) + }) + return file_pkg_grpc_proto_backend_proto_rawDescData +} + +var file_pkg_grpc_proto_backend_proto_msgTypes = make([]protoimpl.MessageInfo, 11) +var file_pkg_grpc_proto_backend_proto_goTypes = []interface{}{ + (*HealthMessage)(nil), // 0: backend.HealthMessage + (*PredictOptions)(nil), // 1: backend.PredictOptions + (*Reply)(nil), // 2: backend.Reply + (*ModelOptions)(nil), // 3: backend.ModelOptions + (*Result)(nil), // 4: backend.Result + (*EmbeddingResult)(nil), // 5: backend.EmbeddingResult + (*TranscriptRequest)(nil), // 6: backend.TranscriptRequest + (*TranscriptResult)(nil), // 7: backend.TranscriptResult + (*TranscriptSegment)(nil), // 8: backend.TranscriptSegment + (*GenerateImageRequest)(nil), // 9: backend.GenerateImageRequest + (*TTSRequest)(nil), // 10: backend.TTSRequest +} +var file_pkg_grpc_proto_backend_proto_depIdxs = []int32{ + 8, // 0: backend.TranscriptResult.segments:type_name -> backend.TranscriptSegment + 0, // 1: backend.Backend.Health:input_type -> backend.HealthMessage + 1, // 2: backend.Backend.Predict:input_type -> backend.PredictOptions + 3, // 3: backend.Backend.LoadModel:input_type -> backend.ModelOptions + 1, // 4: backend.Backend.PredictStream:input_type -> backend.PredictOptions + 1, // 5: backend.Backend.Embedding:input_type -> backend.PredictOptions + 9, // 6: backend.Backend.GenerateImage:input_type -> backend.GenerateImageRequest + 6, // 7: backend.Backend.AudioTranscription:input_type -> backend.TranscriptRequest + 10, // 8: backend.Backend.TTS:input_type -> backend.TTSRequest + 2, // 9: backend.Backend.Health:output_type -> backend.Reply + 2, // 10: backend.Backend.Predict:output_type -> backend.Reply + 4, // 11: backend.Backend.LoadModel:output_type -> backend.Result + 2, // 12: backend.Backend.PredictStream:output_type -> backend.Reply + 5, // 13: backend.Backend.Embedding:output_type -> backend.EmbeddingResult + 4, // 14: backend.Backend.GenerateImage:output_type -> backend.Result + 7, // 15: backend.Backend.AudioTranscription:output_type -> backend.TranscriptResult + 4, // 16: backend.Backend.TTS:output_type -> backend.Result + 9, // [9:17] is the sub-list for method output_type + 1, // [1:9] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_pkg_grpc_proto_backend_proto_init() } +func file_pkg_grpc_proto_backend_proto_init() { + if File_pkg_grpc_proto_backend_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_pkg_grpc_proto_backend_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HealthMessage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PredictOptions); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Reply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ModelOptions); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Result); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EmbeddingResult); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TranscriptRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TranscriptResult); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TranscriptSegment); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GenerateImageRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TTSRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_grpc_proto_backend_proto_rawDesc, + NumEnums: 0, + NumMessages: 11, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_pkg_grpc_proto_backend_proto_goTypes, + DependencyIndexes: file_pkg_grpc_proto_backend_proto_depIdxs, + MessageInfos: file_pkg_grpc_proto_backend_proto_msgTypes, + }.Build() + File_pkg_grpc_proto_backend_proto = out.File + file_pkg_grpc_proto_backend_proto_rawDesc = nil + file_pkg_grpc_proto_backend_proto_goTypes = nil + file_pkg_grpc_proto_backend_proto_depIdxs = nil +} diff --git a/pkg/grpc/proto/llmserver.proto b/pkg/grpc/proto/backend.proto similarity index 67% rename from pkg/grpc/proto/llmserver.proto rename to pkg/grpc/proto/backend.proto index 32fe0ff..7e0bdb7 100644 --- a/pkg/grpc/proto/llmserver.proto +++ b/pkg/grpc/proto/backend.proto @@ -2,17 +2,20 @@ syntax = "proto3"; option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto"; option java_multiple_files = true; -option java_package = "io.skynet.localai.llmserver"; -option java_outer_classname = "LLMServer"; +option java_package = "io.skynet.localai.backend"; +option java_outer_classname = "LocalAIBackend"; -package llm; +package backend; -service LLM { +service Backend { rpc Health(HealthMessage) returns (Reply) {} rpc Predict(PredictOptions) returns (Reply) {} rpc LoadModel(ModelOptions) returns (Result) {} rpc PredictStream(PredictOptions) returns (stream Reply) {} rpc Embedding(PredictOptions) returns (EmbeddingResult) {} + rpc GenerateImage(GenerateImageRequest) returns (Result) {} + rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} + rpc TTS(TTSRequest) returns (Result) {} } message HealthMessage {} @@ -87,4 +90,40 @@ message Result { message EmbeddingResult { repeated float embeddings = 1; -} \ No newline at end of file +} + +message TranscriptRequest { + string dst = 2; + string language = 3; + uint32 threads = 4; +} + +message TranscriptResult { + repeated TranscriptSegment segments = 1; + string text = 2; +} + +message TranscriptSegment { + int32 id = 1; + int64 start = 2; + int64 end = 3; + string text = 4; + repeated int32 tokens = 5; +} + +message GenerateImageRequest { + int32 height = 1; + int32 width = 2; + int32 mode = 3; + int32 step = 4; + int32 seed = 5; + string positive_prompt = 6; + string negative_prompt = 7; + string dst = 8; +} + +message TTSRequest { + string text = 1; + string model = 2; + string dst = 3; +} diff --git a/pkg/grpc/proto/backend_grpc.pb.go b/pkg/grpc/proto/backend_grpc.pb.go new file mode 100644 index 0000000..b9d7dd8 --- /dev/null +++ b/pkg/grpc/proto/backend_grpc.pb.go @@ -0,0 +1,385 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.15.8 +// source: pkg/grpc/proto/backend.proto + +package proto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// BackendClient is the client API for Backend service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +type BackendClient interface { + Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) + Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) + LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) + PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) + Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) + GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) + AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) + TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) +} + +type backendClient struct { + cc grpc.ClientConnInterface +} + +func NewBackendClient(cc grpc.ClientConnInterface) BackendClient { + return &backendClient{cc} +} + +func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { + out := new(Reply) + err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { + out := new(Reply) + err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { + out := new(Result) + err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) + if err != nil { + return nil, err + } + x := &backendPredictStreamClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type Backend_PredictStreamClient interface { + Recv() (*Reply, error) + grpc.ClientStream +} + +type backendPredictStreamClient struct { + grpc.ClientStream +} + +func (x *backendPredictStreamClient) Recv() (*Reply, error) { + m := new(Reply) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { + out := new(EmbeddingResult) + err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { + out := new(Result) + err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { + out := new(TranscriptResult) + err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { + out := new(Result) + err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// BackendServer is the server API for Backend service. +// All implementations must embed UnimplementedBackendServer +// for forward compatibility +type BackendServer interface { + Health(context.Context, *HealthMessage) (*Reply, error) + Predict(context.Context, *PredictOptions) (*Reply, error) + LoadModel(context.Context, *ModelOptions) (*Result, error) + PredictStream(*PredictOptions, Backend_PredictStreamServer) error + Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) + GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) + AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) + TTS(context.Context, *TTSRequest) (*Result, error) + mustEmbedUnimplementedBackendServer() +} + +// UnimplementedBackendServer must be embedded to have forward compatible implementations. +type UnimplementedBackendServer struct { +} + +func (UnimplementedBackendServer) Health(context.Context, *HealthMessage) (*Reply, error) { + return nil, status.Errorf(codes.Unimplemented, "method Health not implemented") +} +func (UnimplementedBackendServer) Predict(context.Context, *PredictOptions) (*Reply, error) { + return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented") +} +func (UnimplementedBackendServer) LoadModel(context.Context, *ModelOptions) (*Result, error) { + return nil, status.Errorf(codes.Unimplemented, "method LoadModel not implemented") +} +func (UnimplementedBackendServer) PredictStream(*PredictOptions, Backend_PredictStreamServer) error { + return status.Errorf(codes.Unimplemented, "method PredictStream not implemented") +} +func (UnimplementedBackendServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) { + return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented") +} +func (UnimplementedBackendServer) GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) { + return nil, status.Errorf(codes.Unimplemented, "method GenerateImage not implemented") +} +func (UnimplementedBackendServer) AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) { + return nil, status.Errorf(codes.Unimplemented, "method AudioTranscription not implemented") +} +func (UnimplementedBackendServer) TTS(context.Context, *TTSRequest) (*Result, error) { + return nil, status.Errorf(codes.Unimplemented, "method TTS not implemented") +} +func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {} + +// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to BackendServer will +// result in compilation errors. +type UnsafeBackendServer interface { + mustEmbedUnimplementedBackendServer() +} + +func RegisterBackendServer(s grpc.ServiceRegistrar, srv BackendServer) { + s.RegisterService(&Backend_ServiceDesc, srv) +} + +func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HealthMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).Health(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/Health", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PredictOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).Predict(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/Predict", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ModelOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).LoadModel(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/LoadModel", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_PredictStream_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(PredictOptions) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(BackendServer).PredictStream(m, &backendPredictStreamServer{stream}) +} + +type Backend_PredictStreamServer interface { + Send(*Reply) error + grpc.ServerStream +} + +type backendPredictStreamServer struct { + grpc.ServerStream +} + +func (x *backendPredictStreamServer) Send(m *Reply) error { + return x.ServerStream.SendMsg(m) +} + +func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PredictOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).Embedding(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/Embedding", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GenerateImageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).GenerateImage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/GenerateImage", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TranscriptRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).AudioTranscription(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/AudioTranscription", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TTSRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).TTS(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/TTS", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Backend_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "backend.Backend", + HandlerType: (*BackendServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Health", + Handler: _Backend_Health_Handler, + }, + { + MethodName: "Predict", + Handler: _Backend_Predict_Handler, + }, + { + MethodName: "LoadModel", + Handler: _Backend_LoadModel_Handler, + }, + { + MethodName: "Embedding", + Handler: _Backend_Embedding_Handler, + }, + { + MethodName: "GenerateImage", + Handler: _Backend_GenerateImage_Handler, + }, + { + MethodName: "AudioTranscription", + Handler: _Backend_AudioTranscription_Handler, + }, + { + MethodName: "TTS", + Handler: _Backend_TTS_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "PredictStream", + Handler: _Backend_PredictStream_Handler, + ServerStreams: true, + }, + }, + Metadata: "pkg/grpc/proto/backend.proto", +} diff --git a/pkg/grpc/proto/llmserver.pb.go b/pkg/grpc/proto/llmserver.pb.go deleted file mode 100644 index d8bdcd2..0000000 --- a/pkg/grpc/proto/llmserver.pb.go +++ /dev/null @@ -1,969 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.26.0 -// protoc v3.15.8 -// source: pkg/grpc/proto/llmserver.proto - -package proto - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type HealthMessage struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields -} - -func (x *HealthMessage) Reset() { - *x = HealthMessage{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *HealthMessage) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*HealthMessage) ProtoMessage() {} - -func (x *HealthMessage) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use HealthMessage.ProtoReflect.Descriptor instead. -func (*HealthMessage) Descriptor() ([]byte, []int) { - return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{0} -} - -// The request message containing the user's name. -type PredictOptions struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Prompt string `protobuf:"bytes,1,opt,name=Prompt,proto3" json:"Prompt,omitempty"` - Seed int32 `protobuf:"varint,2,opt,name=Seed,proto3" json:"Seed,omitempty"` - Threads int32 `protobuf:"varint,3,opt,name=Threads,proto3" json:"Threads,omitempty"` - Tokens int32 `protobuf:"varint,4,opt,name=Tokens,proto3" json:"Tokens,omitempty"` - TopK int32 `protobuf:"varint,5,opt,name=TopK,proto3" json:"TopK,omitempty"` - Repeat int32 `protobuf:"varint,6,opt,name=Repeat,proto3" json:"Repeat,omitempty"` - Batch int32 `protobuf:"varint,7,opt,name=Batch,proto3" json:"Batch,omitempty"` - NKeep int32 `protobuf:"varint,8,opt,name=NKeep,proto3" json:"NKeep,omitempty"` - Temperature float32 `protobuf:"fixed32,9,opt,name=Temperature,proto3" json:"Temperature,omitempty"` - Penalty float32 `protobuf:"fixed32,10,opt,name=Penalty,proto3" json:"Penalty,omitempty"` - F16KV bool `protobuf:"varint,11,opt,name=F16KV,proto3" json:"F16KV,omitempty"` - DebugMode bool `protobuf:"varint,12,opt,name=DebugMode,proto3" json:"DebugMode,omitempty"` - StopPrompts []string `protobuf:"bytes,13,rep,name=StopPrompts,proto3" json:"StopPrompts,omitempty"` - IgnoreEOS bool `protobuf:"varint,14,opt,name=IgnoreEOS,proto3" json:"IgnoreEOS,omitempty"` - TailFreeSamplingZ float32 `protobuf:"fixed32,15,opt,name=TailFreeSamplingZ,proto3" json:"TailFreeSamplingZ,omitempty"` - TypicalP float32 `protobuf:"fixed32,16,opt,name=TypicalP,proto3" json:"TypicalP,omitempty"` - FrequencyPenalty float32 `protobuf:"fixed32,17,opt,name=FrequencyPenalty,proto3" json:"FrequencyPenalty,omitempty"` - PresencePenalty float32 `protobuf:"fixed32,18,opt,name=PresencePenalty,proto3" json:"PresencePenalty,omitempty"` - Mirostat int32 `protobuf:"varint,19,opt,name=Mirostat,proto3" json:"Mirostat,omitempty"` - MirostatETA float32 `protobuf:"fixed32,20,opt,name=MirostatETA,proto3" json:"MirostatETA,omitempty"` - MirostatTAU float32 `protobuf:"fixed32,21,opt,name=MirostatTAU,proto3" json:"MirostatTAU,omitempty"` - PenalizeNL bool `protobuf:"varint,22,opt,name=PenalizeNL,proto3" json:"PenalizeNL,omitempty"` - LogitBias string `protobuf:"bytes,23,opt,name=LogitBias,proto3" json:"LogitBias,omitempty"` - MLock bool `protobuf:"varint,25,opt,name=MLock,proto3" json:"MLock,omitempty"` - MMap bool `protobuf:"varint,26,opt,name=MMap,proto3" json:"MMap,omitempty"` - PromptCacheAll bool `protobuf:"varint,27,opt,name=PromptCacheAll,proto3" json:"PromptCacheAll,omitempty"` - PromptCacheRO bool `protobuf:"varint,28,opt,name=PromptCacheRO,proto3" json:"PromptCacheRO,omitempty"` - Grammar string `protobuf:"bytes,29,opt,name=Grammar,proto3" json:"Grammar,omitempty"` - MainGPU string `protobuf:"bytes,30,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` - TensorSplit string `protobuf:"bytes,31,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` - TopP float32 `protobuf:"fixed32,32,opt,name=TopP,proto3" json:"TopP,omitempty"` - PromptCachePath string `protobuf:"bytes,33,opt,name=PromptCachePath,proto3" json:"PromptCachePath,omitempty"` - Debug bool `protobuf:"varint,34,opt,name=Debug,proto3" json:"Debug,omitempty"` - EmbeddingTokens []int32 `protobuf:"varint,35,rep,packed,name=EmbeddingTokens,proto3" json:"EmbeddingTokens,omitempty"` - Embeddings string `protobuf:"bytes,36,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` -} - -func (x *PredictOptions) Reset() { - *x = PredictOptions{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *PredictOptions) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*PredictOptions) ProtoMessage() {} - -func (x *PredictOptions) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use PredictOptions.ProtoReflect.Descriptor instead. -func (*PredictOptions) Descriptor() ([]byte, []int) { - return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{1} -} - -func (x *PredictOptions) GetPrompt() string { - if x != nil { - return x.Prompt - } - return "" -} - -func (x *PredictOptions) GetSeed() int32 { - if x != nil { - return x.Seed - } - return 0 -} - -func (x *PredictOptions) GetThreads() int32 { - if x != nil { - return x.Threads - } - return 0 -} - -func (x *PredictOptions) GetTokens() int32 { - if x != nil { - return x.Tokens - } - return 0 -} - -func (x *PredictOptions) GetTopK() int32 { - if x != nil { - return x.TopK - } - return 0 -} - -func (x *PredictOptions) GetRepeat() int32 { - if x != nil { - return x.Repeat - } - return 0 -} - -func (x *PredictOptions) GetBatch() int32 { - if x != nil { - return x.Batch - } - return 0 -} - -func (x *PredictOptions) GetNKeep() int32 { - if x != nil { - return x.NKeep - } - return 0 -} - -func (x *PredictOptions) GetTemperature() float32 { - if x != nil { - return x.Temperature - } - return 0 -} - -func (x *PredictOptions) GetPenalty() float32 { - if x != nil { - return x.Penalty - } - return 0 -} - -func (x *PredictOptions) GetF16KV() bool { - if x != nil { - return x.F16KV - } - return false -} - -func (x *PredictOptions) GetDebugMode() bool { - if x != nil { - return x.DebugMode - } - return false -} - -func (x *PredictOptions) GetStopPrompts() []string { - if x != nil { - return x.StopPrompts - } - return nil -} - -func (x *PredictOptions) GetIgnoreEOS() bool { - if x != nil { - return x.IgnoreEOS - } - return false -} - -func (x *PredictOptions) GetTailFreeSamplingZ() float32 { - if x != nil { - return x.TailFreeSamplingZ - } - return 0 -} - -func (x *PredictOptions) GetTypicalP() float32 { - if x != nil { - return x.TypicalP - } - return 0 -} - -func (x *PredictOptions) GetFrequencyPenalty() float32 { - if x != nil { - return x.FrequencyPenalty - } - return 0 -} - -func (x *PredictOptions) GetPresencePenalty() float32 { - if x != nil { - return x.PresencePenalty - } - return 0 -} - -func (x *PredictOptions) GetMirostat() int32 { - if x != nil { - return x.Mirostat - } - return 0 -} - -func (x *PredictOptions) GetMirostatETA() float32 { - if x != nil { - return x.MirostatETA - } - return 0 -} - -func (x *PredictOptions) GetMirostatTAU() float32 { - if x != nil { - return x.MirostatTAU - } - return 0 -} - -func (x *PredictOptions) GetPenalizeNL() bool { - if x != nil { - return x.PenalizeNL - } - return false -} - -func (x *PredictOptions) GetLogitBias() string { - if x != nil { - return x.LogitBias - } - return "" -} - -func (x *PredictOptions) GetMLock() bool { - if x != nil { - return x.MLock - } - return false -} - -func (x *PredictOptions) GetMMap() bool { - if x != nil { - return x.MMap - } - return false -} - -func (x *PredictOptions) GetPromptCacheAll() bool { - if x != nil { - return x.PromptCacheAll - } - return false -} - -func (x *PredictOptions) GetPromptCacheRO() bool { - if x != nil { - return x.PromptCacheRO - } - return false -} - -func (x *PredictOptions) GetGrammar() string { - if x != nil { - return x.Grammar - } - return "" -} - -func (x *PredictOptions) GetMainGPU() string { - if x != nil { - return x.MainGPU - } - return "" -} - -func (x *PredictOptions) GetTensorSplit() string { - if x != nil { - return x.TensorSplit - } - return "" -} - -func (x *PredictOptions) GetTopP() float32 { - if x != nil { - return x.TopP - } - return 0 -} - -func (x *PredictOptions) GetPromptCachePath() string { - if x != nil { - return x.PromptCachePath - } - return "" -} - -func (x *PredictOptions) GetDebug() bool { - if x != nil { - return x.Debug - } - return false -} - -func (x *PredictOptions) GetEmbeddingTokens() []int32 { - if x != nil { - return x.EmbeddingTokens - } - return nil -} - -func (x *PredictOptions) GetEmbeddings() string { - if x != nil { - return x.Embeddings - } - return "" -} - -// The response message containing the result -type Reply struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` -} - -func (x *Reply) Reset() { - *x = Reply{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *Reply) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Reply) ProtoMessage() {} - -func (x *Reply) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Reply.ProtoReflect.Descriptor instead. -func (*Reply) Descriptor() ([]byte, []int) { - return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{2} -} - -func (x *Reply) GetMessage() string { - if x != nil { - return x.Message - } - return "" -} - -type ModelOptions struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Model string `protobuf:"bytes,1,opt,name=Model,proto3" json:"Model,omitempty"` - ContextSize int32 `protobuf:"varint,2,opt,name=ContextSize,proto3" json:"ContextSize,omitempty"` - Seed int32 `protobuf:"varint,3,opt,name=Seed,proto3" json:"Seed,omitempty"` - NBatch int32 `protobuf:"varint,4,opt,name=NBatch,proto3" json:"NBatch,omitempty"` - F16Memory bool `protobuf:"varint,5,opt,name=F16Memory,proto3" json:"F16Memory,omitempty"` - MLock bool `protobuf:"varint,6,opt,name=MLock,proto3" json:"MLock,omitempty"` - MMap bool `protobuf:"varint,7,opt,name=MMap,proto3" json:"MMap,omitempty"` - VocabOnly bool `protobuf:"varint,8,opt,name=VocabOnly,proto3" json:"VocabOnly,omitempty"` - LowVRAM bool `protobuf:"varint,9,opt,name=LowVRAM,proto3" json:"LowVRAM,omitempty"` - Embeddings bool `protobuf:"varint,10,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` - NUMA bool `protobuf:"varint,11,opt,name=NUMA,proto3" json:"NUMA,omitempty"` - NGPULayers int32 `protobuf:"varint,12,opt,name=NGPULayers,proto3" json:"NGPULayers,omitempty"` - MainGPU string `protobuf:"bytes,13,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` - TensorSplit string `protobuf:"bytes,14,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` - Threads int32 `protobuf:"varint,15,opt,name=Threads,proto3" json:"Threads,omitempty"` - LibrarySearchPath string `protobuf:"bytes,16,opt,name=LibrarySearchPath,proto3" json:"LibrarySearchPath,omitempty"` -} - -func (x *ModelOptions) Reset() { - *x = ModelOptions{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *ModelOptions) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*ModelOptions) ProtoMessage() {} - -func (x *ModelOptions) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use ModelOptions.ProtoReflect.Descriptor instead. -func (*ModelOptions) Descriptor() ([]byte, []int) { - return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{3} -} - -func (x *ModelOptions) GetModel() string { - if x != nil { - return x.Model - } - return "" -} - -func (x *ModelOptions) GetContextSize() int32 { - if x != nil { - return x.ContextSize - } - return 0 -} - -func (x *ModelOptions) GetSeed() int32 { - if x != nil { - return x.Seed - } - return 0 -} - -func (x *ModelOptions) GetNBatch() int32 { - if x != nil { - return x.NBatch - } - return 0 -} - -func (x *ModelOptions) GetF16Memory() bool { - if x != nil { - return x.F16Memory - } - return false -} - -func (x *ModelOptions) GetMLock() bool { - if x != nil { - return x.MLock - } - return false -} - -func (x *ModelOptions) GetMMap() bool { - if x != nil { - return x.MMap - } - return false -} - -func (x *ModelOptions) GetVocabOnly() bool { - if x != nil { - return x.VocabOnly - } - return false -} - -func (x *ModelOptions) GetLowVRAM() bool { - if x != nil { - return x.LowVRAM - } - return false -} - -func (x *ModelOptions) GetEmbeddings() bool { - if x != nil { - return x.Embeddings - } - return false -} - -func (x *ModelOptions) GetNUMA() bool { - if x != nil { - return x.NUMA - } - return false -} - -func (x *ModelOptions) GetNGPULayers() int32 { - if x != nil { - return x.NGPULayers - } - return 0 -} - -func (x *ModelOptions) GetMainGPU() string { - if x != nil { - return x.MainGPU - } - return "" -} - -func (x *ModelOptions) GetTensorSplit() string { - if x != nil { - return x.TensorSplit - } - return "" -} - -func (x *ModelOptions) GetThreads() int32 { - if x != nil { - return x.Threads - } - return 0 -} - -func (x *ModelOptions) GetLibrarySearchPath() string { - if x != nil { - return x.LibrarySearchPath - } - return "" -} - -type Result struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` - Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` -} - -func (x *Result) Reset() { - *x = Result{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *Result) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Result) ProtoMessage() {} - -func (x *Result) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Result.ProtoReflect.Descriptor instead. -func (*Result) Descriptor() ([]byte, []int) { - return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{4} -} - -func (x *Result) GetMessage() string { - if x != nil { - return x.Message - } - return "" -} - -func (x *Result) GetSuccess() bool { - if x != nil { - return x.Success - } - return false -} - -type EmbeddingResult struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Embeddings []float32 `protobuf:"fixed32,1,rep,packed,name=embeddings,proto3" json:"embeddings,omitempty"` -} - -func (x *EmbeddingResult) Reset() { - *x = EmbeddingResult{} - if protoimpl.UnsafeEnabled { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *EmbeddingResult) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*EmbeddingResult) ProtoMessage() {} - -func (x *EmbeddingResult) ProtoReflect() protoreflect.Message { - mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use EmbeddingResult.ProtoReflect.Descriptor instead. -func (*EmbeddingResult) Descriptor() ([]byte, []int) { - return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{5} -} - -func (x *EmbeddingResult) GetEmbeddings() []float32 { - if x != nil { - return x.Embeddings - } - return nil -} - -var File_pkg_grpc_proto_llmserver_proto protoreflect.FileDescriptor - -var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ - 0x0a, 0x1e, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x2f, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x12, 0x03, 0x6c, 0x6c, 0x6d, 0x22, 0x0f, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xa0, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69, - 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x72, 0x6f, - 0x6d, 0x70, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x72, 0x6f, 0x6d, 0x70, - 0x74, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, - 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, - 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12, - 0x16, 0x0a, 0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, - 0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, 0x4b, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x4b, 0x12, 0x16, 0x0a, 0x06, 0x52, - 0x65, 0x70, 0x65, 0x61, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x52, 0x65, 0x70, - 0x65, 0x61, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x18, 0x07, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x4b, 0x65, - 0x65, 0x70, 0x18, 0x08, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x4e, 0x4b, 0x65, 0x65, 0x70, 0x12, - 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x09, - 0x20, 0x01, 0x28, 0x02, 0x52, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72, - 0x65, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x0a, 0x20, 0x01, - 0x28, 0x02, 0x52, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x46, - 0x31, 0x36, 0x4b, 0x56, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x46, 0x31, 0x36, 0x4b, - 0x56, 0x12, 0x1c, 0x0a, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x18, 0x0c, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x12, - 0x20, 0x0a, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x73, 0x18, 0x0d, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, - 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, 0x53, 0x18, 0x0e, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, 0x53, 0x12, - 0x2c, 0x0a, 0x11, 0x54, 0x61, 0x69, 0x6c, 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c, - 0x69, 0x6e, 0x67, 0x5a, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x02, 0x52, 0x11, 0x54, 0x61, 0x69, 0x6c, - 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x5a, 0x12, 0x1a, 0x0a, - 0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x18, 0x10, 0x20, 0x01, 0x28, 0x02, 0x52, - 0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x12, 0x2a, 0x0a, 0x10, 0x46, 0x72, 0x65, - 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x11, 0x20, - 0x01, 0x28, 0x02, 0x52, 0x10, 0x46, 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x50, 0x65, - 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x63, - 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x12, 0x20, 0x01, 0x28, 0x02, 0x52, 0x0f, - 0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x63, 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, - 0x1a, 0x0a, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x18, 0x13, 0x20, 0x01, 0x28, - 0x05, 0x52, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x12, 0x20, 0x0a, 0x0b, 0x4d, - 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x18, 0x14, 0x20, 0x01, 0x28, 0x02, - 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x12, 0x20, 0x0a, - 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, 0x55, 0x18, 0x15, 0x20, 0x01, - 0x28, 0x02, 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, 0x55, 0x12, - 0x1e, 0x0a, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x18, 0x16, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x12, - 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x18, 0x17, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x14, 0x0a, - 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c, - 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, - 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12, - 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, - 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, - 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, - 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12, - 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, - 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, - 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, - 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12, - 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, - 0x74, 0x68, 0x18, 0x21, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, - 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62, - 0x75, 0x67, 0x18, 0x22, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x12, - 0x28, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, - 0x6e, 0x73, 0x18, 0x23, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, - 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62, - 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x45, - 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70, - 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xca, 0x03, 0x0a, - 0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x14, 0x0a, - 0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4d, 0x6f, - 0x64, 0x65, 0x6c, 0x12, 0x20, 0x0a, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x53, 0x69, - 0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, - 0x74, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x42, 0x61, - 0x74, 0x63, 0x68, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x4e, 0x42, 0x61, 0x74, 0x63, - 0x68, 0x12, 0x1c, 0x0a, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x12, - 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, - 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x1c, 0x0a, 0x09, 0x56, 0x6f, 0x63, - 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x56, 0x6f, - 0x63, 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x4c, 0x6f, 0x77, 0x56, 0x52, - 0x41, 0x4d, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x4c, 0x6f, 0x77, 0x56, 0x52, 0x41, - 0x4d, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, - 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, - 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x12, 0x1e, 0x0a, 0x0a, 0x4e, 0x47, 0x50, 0x55, 0x4c, 0x61, 0x79, - 0x65, 0x72, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x4e, 0x47, 0x50, 0x55, 0x4c, - 0x61, 0x79, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, - 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, - 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x0e, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, - 0x74, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x18, 0x0f, 0x20, 0x01, - 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12, 0x2c, 0x0a, 0x11, 0x4c, - 0x69, 0x62, 0x72, 0x61, 0x72, 0x79, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, 0x74, 0x68, - 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x11, 0x4c, 0x69, 0x62, 0x72, 0x61, 0x72, 0x79, 0x53, - 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x06, 0x52, 0x65, 0x73, - 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, - 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, - 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, 0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, - 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x6d, - 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x02, 0x52, 0x0a, - 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x32, 0xfe, 0x01, 0x0a, 0x03, 0x4c, - 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x12, 0x2e, 0x6c, - 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2c, - 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, - 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, - 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x09, - 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x11, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, - 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0b, 0x2e, 0x6c, - 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x0d, 0x50, - 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x6c, - 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, - 0x01, 0x12, 0x38, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x13, - 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x1a, 0x14, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, - 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x42, 0x57, 0x0a, 0x1b, 0x69, - 0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, - 0x2e, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, - 0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} - -var ( - file_pkg_grpc_proto_llmserver_proto_rawDescOnce sync.Once - file_pkg_grpc_proto_llmserver_proto_rawDescData = file_pkg_grpc_proto_llmserver_proto_rawDesc -) - -func file_pkg_grpc_proto_llmserver_proto_rawDescGZIP() []byte { - file_pkg_grpc_proto_llmserver_proto_rawDescOnce.Do(func() { - file_pkg_grpc_proto_llmserver_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_proto_llmserver_proto_rawDescData) - }) - return file_pkg_grpc_proto_llmserver_proto_rawDescData -} - -var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 6) -var file_pkg_grpc_proto_llmserver_proto_goTypes = []interface{}{ - (*HealthMessage)(nil), // 0: llm.HealthMessage - (*PredictOptions)(nil), // 1: llm.PredictOptions - (*Reply)(nil), // 2: llm.Reply - (*ModelOptions)(nil), // 3: llm.ModelOptions - (*Result)(nil), // 4: llm.Result - (*EmbeddingResult)(nil), // 5: llm.EmbeddingResult -} -var file_pkg_grpc_proto_llmserver_proto_depIdxs = []int32{ - 0, // 0: llm.LLM.Health:input_type -> llm.HealthMessage - 1, // 1: llm.LLM.Predict:input_type -> llm.PredictOptions - 3, // 2: llm.LLM.LoadModel:input_type -> llm.ModelOptions - 1, // 3: llm.LLM.PredictStream:input_type -> llm.PredictOptions - 1, // 4: llm.LLM.Embedding:input_type -> llm.PredictOptions - 2, // 5: llm.LLM.Health:output_type -> llm.Reply - 2, // 6: llm.LLM.Predict:output_type -> llm.Reply - 4, // 7: llm.LLM.LoadModel:output_type -> llm.Result - 2, // 8: llm.LLM.PredictStream:output_type -> llm.Reply - 5, // 9: llm.LLM.Embedding:output_type -> llm.EmbeddingResult - 5, // [5:10] is the sub-list for method output_type - 0, // [0:5] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name -} - -func init() { file_pkg_grpc_proto_llmserver_proto_init() } -func file_pkg_grpc_proto_llmserver_proto_init() { - if File_pkg_grpc_proto_llmserver_proto != nil { - return - } - if !protoimpl.UnsafeEnabled { - file_pkg_grpc_proto_llmserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*HealthMessage); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_pkg_grpc_proto_llmserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PredictOptions); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_pkg_grpc_proto_llmserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Reply); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_pkg_grpc_proto_llmserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ModelOptions); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_pkg_grpc_proto_llmserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Result); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_pkg_grpc_proto_llmserver_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*EmbeddingResult); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_pkg_grpc_proto_llmserver_proto_rawDesc, - NumEnums: 0, - NumMessages: 6, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_pkg_grpc_proto_llmserver_proto_goTypes, - DependencyIndexes: file_pkg_grpc_proto_llmserver_proto_depIdxs, - MessageInfos: file_pkg_grpc_proto_llmserver_proto_msgTypes, - }.Build() - File_pkg_grpc_proto_llmserver_proto = out.File - file_pkg_grpc_proto_llmserver_proto_rawDesc = nil - file_pkg_grpc_proto_llmserver_proto_goTypes = nil - file_pkg_grpc_proto_llmserver_proto_depIdxs = nil -} diff --git a/pkg/grpc/proto/llmserver_grpc.pb.go b/pkg/grpc/proto/llmserver_grpc.pb.go deleted file mode 100644 index c028218..0000000 --- a/pkg/grpc/proto/llmserver_grpc.pb.go +++ /dev/null @@ -1,277 +0,0 @@ -// Code generated by protoc-gen-go-grpc. DO NOT EDIT. -// versions: -// - protoc-gen-go-grpc v1.2.0 -// - protoc v3.15.8 -// source: pkg/grpc/proto/llmserver.proto - -package proto - -import ( - context "context" - grpc "google.golang.org/grpc" - codes "google.golang.org/grpc/codes" - status "google.golang.org/grpc/status" -) - -// This is a compile-time assertion to ensure that this generated file -// is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 - -// LLMClient is the client API for LLM service. -// -// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. -type LLMClient interface { - Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) - Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) - LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) - PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error) - Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) -} - -type lLMClient struct { - cc grpc.ClientConnInterface -} - -func NewLLMClient(cc grpc.ClientConnInterface) LLMClient { - return &lLMClient{cc} -} - -func (c *lLMClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { - out := new(Reply) - err := c.cc.Invoke(ctx, "/llm.LLM/Health", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *lLMClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { - out := new(Reply) - err := c.cc.Invoke(ctx, "/llm.LLM/Predict", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *lLMClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { - out := new(Result) - err := c.cc.Invoke(ctx, "/llm.LLM/LoadModel", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *lLMClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error) { - stream, err := c.cc.NewStream(ctx, &LLM_ServiceDesc.Streams[0], "/llm.LLM/PredictStream", opts...) - if err != nil { - return nil, err - } - x := &lLMPredictStreamClient{stream} - if err := x.ClientStream.SendMsg(in); err != nil { - return nil, err - } - if err := x.ClientStream.CloseSend(); err != nil { - return nil, err - } - return x, nil -} - -type LLM_PredictStreamClient interface { - Recv() (*Reply, error) - grpc.ClientStream -} - -type lLMPredictStreamClient struct { - grpc.ClientStream -} - -func (x *lLMPredictStreamClient) Recv() (*Reply, error) { - m := new(Reply) - if err := x.ClientStream.RecvMsg(m); err != nil { - return nil, err - } - return m, nil -} - -func (c *lLMClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { - out := new(EmbeddingResult) - err := c.cc.Invoke(ctx, "/llm.LLM/Embedding", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -// LLMServer is the server API for LLM service. -// All implementations must embed UnimplementedLLMServer -// for forward compatibility -type LLMServer interface { - Health(context.Context, *HealthMessage) (*Reply, error) - Predict(context.Context, *PredictOptions) (*Reply, error) - LoadModel(context.Context, *ModelOptions) (*Result, error) - PredictStream(*PredictOptions, LLM_PredictStreamServer) error - Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) - mustEmbedUnimplementedLLMServer() -} - -// UnimplementedLLMServer must be embedded to have forward compatible implementations. -type UnimplementedLLMServer struct { -} - -func (UnimplementedLLMServer) Health(context.Context, *HealthMessage) (*Reply, error) { - return nil, status.Errorf(codes.Unimplemented, "method Health not implemented") -} -func (UnimplementedLLMServer) Predict(context.Context, *PredictOptions) (*Reply, error) { - return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented") -} -func (UnimplementedLLMServer) LoadModel(context.Context, *ModelOptions) (*Result, error) { - return nil, status.Errorf(codes.Unimplemented, "method LoadModel not implemented") -} -func (UnimplementedLLMServer) PredictStream(*PredictOptions, LLM_PredictStreamServer) error { - return status.Errorf(codes.Unimplemented, "method PredictStream not implemented") -} -func (UnimplementedLLMServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) { - return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented") -} -func (UnimplementedLLMServer) mustEmbedUnimplementedLLMServer() {} - -// UnsafeLLMServer may be embedded to opt out of forward compatibility for this service. -// Use of this interface is not recommended, as added methods to LLMServer will -// result in compilation errors. -type UnsafeLLMServer interface { - mustEmbedUnimplementedLLMServer() -} - -func RegisterLLMServer(s grpc.ServiceRegistrar, srv LLMServer) { - s.RegisterService(&LLM_ServiceDesc, srv) -} - -func _LLM_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(HealthMessage) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(LLMServer).Health(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/llm.LLM/Health", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(LLMServer).Health(ctx, req.(*HealthMessage)) - } - return interceptor(ctx, in, info, handler) -} - -func _LLM_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(PredictOptions) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(LLMServer).Predict(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/llm.LLM/Predict", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(LLMServer).Predict(ctx, req.(*PredictOptions)) - } - return interceptor(ctx, in, info, handler) -} - -func _LLM_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ModelOptions) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(LLMServer).LoadModel(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/llm.LLM/LoadModel", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(LLMServer).LoadModel(ctx, req.(*ModelOptions)) - } - return interceptor(ctx, in, info, handler) -} - -func _LLM_PredictStream_Handler(srv interface{}, stream grpc.ServerStream) error { - m := new(PredictOptions) - if err := stream.RecvMsg(m); err != nil { - return err - } - return srv.(LLMServer).PredictStream(m, &lLMPredictStreamServer{stream}) -} - -type LLM_PredictStreamServer interface { - Send(*Reply) error - grpc.ServerStream -} - -type lLMPredictStreamServer struct { - grpc.ServerStream -} - -func (x *lLMPredictStreamServer) Send(m *Reply) error { - return x.ServerStream.SendMsg(m) -} - -func _LLM_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(PredictOptions) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(LLMServer).Embedding(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/llm.LLM/Embedding", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(LLMServer).Embedding(ctx, req.(*PredictOptions)) - } - return interceptor(ctx, in, info, handler) -} - -// LLM_ServiceDesc is the grpc.ServiceDesc for LLM service. -// It's only intended for direct use with grpc.RegisterService, -// and not to be introspected or modified (even as a copy) -var LLM_ServiceDesc = grpc.ServiceDesc{ - ServiceName: "llm.LLM", - HandlerType: (*LLMServer)(nil), - Methods: []grpc.MethodDesc{ - { - MethodName: "Health", - Handler: _LLM_Health_Handler, - }, - { - MethodName: "Predict", - Handler: _LLM_Predict_Handler, - }, - { - MethodName: "LoadModel", - Handler: _LLM_LoadModel_Handler, - }, - { - MethodName: "Embedding", - Handler: _LLM_Embedding_Handler, - }, - }, - Streams: []grpc.StreamDesc{ - { - StreamName: "PredictStream", - Handler: _LLM_PredictStream_Handler, - ServerStreams: true, - }, - }, - Metadata: "pkg/grpc/proto/llmserver.proto", -} diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index 9e4c88a..8d7a182 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -21,7 +21,7 @@ import ( // server is used to implement helloworld.GreeterServer. type server struct { - pb.UnimplementedLLMServer + pb.UnimplementedBackendServer llm LLM } @@ -51,7 +51,48 @@ func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, return &pb.Reply{Message: result}, err } -func (s *server) PredictStream(in *pb.PredictOptions, stream pb.LLM_PredictStreamServer) error { +func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) { + err := s.llm.GenerateImage(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Image generated", Success: true}, nil +} + +func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { + err := s.llm.TTS(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Audio generated", Success: true}, nil +} + +func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { + result, err := s.llm.AudioTranscription(in) + if err != nil { + return nil, err + } + tresult := &pb.TranscriptResult{} + for _, s := range result.Segments { + tks := []int32{} + for _, t := range s.Tokens { + tks = append(tks, int32(t)) + } + tresult.Segments = append(tresult.Segments, + &pb.TranscriptSegment{ + Text: s.Text, + Id: int32(s.Id), + Start: int64(s.Start), + End: int64(s.End), + Tokens: tks, + }) + } + + tresult.Text = result.Text + return tresult, nil +} + +func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { resultChan := make(chan string) @@ -75,7 +116,7 @@ func StartServer(address string, model LLM) error { return err } s := grpc.NewServer() - pb.RegisterLLMServer(s, &server{llm: model}) + pb.RegisterBackendServer(s, &server{llm: model}) log.Printf("gRPC Server listening at %v", lis.Addr()) if err := s.Serve(lis); err != nil { return err diff --git a/pkg/grpc/transcribe/whisper.go b/pkg/grpc/transcribe/whisper.go new file mode 100644 index 0000000..c0120db --- /dev/null +++ b/pkg/grpc/transcribe/whisper.go @@ -0,0 +1,27 @@ +package transcribe + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + whisperutil "github.com/go-skynet/LocalAI/pkg/grpc/whisper" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" +) + +type Whisper struct { + base.Base + whisper whisper.Model +} + +func (sd *Whisper) Load(opts *pb.ModelOptions) error { + // Note: the Model here is a path to a directory containing the model files + w, err := whisper.New(opts.Model) + sd.whisper = w + return err +} + +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (api.Result, error) { + return whisperutil.Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) +} diff --git a/pkg/grpc/tts/piper.go b/pkg/grpc/tts/piper.go new file mode 100644 index 0000000..dbaa4b7 --- /dev/null +++ b/pkg/grpc/tts/piper.go @@ -0,0 +1,44 @@ +package tts + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "os" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + piper "github.com/mudler/go-piper" +) + +type Piper struct { + base.Base + piper *PiperB +} + +func (sd *Piper) Load(opts *pb.ModelOptions) error { + var err error + // Note: the Model here is a path to a directory containing the model files + sd.piper, err = New(opts.LibrarySearchPath) + return err +} + +func (sd *Piper) TTS(opts *pb.TTSRequest) error { + return sd.piper.TTS(opts.Text, opts.Model, opts.Dst) +} + +type PiperB struct { + assetDir string +} + +func New(assetDir string) (*PiperB, error) { + if _, err := os.Stat(assetDir); err != nil { + return nil, err + } + return &PiperB{ + assetDir: assetDir, + }, nil +} + +func (s *PiperB) TTS(text, model, dst string) error { + return piper.TextToWav(text, model, s.assetDir, "", dst) +} diff --git a/pkg/grpc/whisper/api/api.go b/pkg/grpc/whisper/api/api.go new file mode 100644 index 0000000..700d80e --- /dev/null +++ b/pkg/grpc/whisper/api/api.go @@ -0,0 +1,16 @@ +package api + +import "time" + +type Segment struct { + Id int `json:"id"` + Start time.Duration `json:"start"` + End time.Duration `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` +} + +type Result struct { + Segments []Segment `json:"segments"` + Text string `json:"text"` +} diff --git a/pkg/whisper/whisper.go b/pkg/grpc/whisper/whisper.go similarity index 78% rename from pkg/whisper/whisper.go rename to pkg/grpc/whisper/whisper.go index 63e8cc5..806e145 100644 --- a/pkg/whisper/whisper.go +++ b/pkg/grpc/whisper/whisper.go @@ -5,25 +5,12 @@ import ( "os" "os/exec" "path/filepath" - "time" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" wav "github.com/go-audio/wav" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" ) -type Segment struct { - Id int `json:"id"` - Start time.Duration `json:"start"` - End time.Duration `json:"end"` - Text string `json:"text"` - Tokens []int `json:"tokens"` -} - -type Result struct { - Segments []Segment `json:"segments"` - Text string `json:"text"` -} - func sh(c string) (string, error) { cmd := exec.Command("/bin/sh", "-c", c) cmd.Env = os.Environ() @@ -42,8 +29,8 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (Result, error) { - res := Result{} +func Transcript(model whisper.Model, audiopath, language string, threads uint) (api.Result, error) { + res := api.Result{} dir, err := os.MkdirTemp("", "whisper") if err != nil { @@ -99,11 +86,11 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) ( } var tokens []int - for _, t := range(s.Tokens) { + for _, t := range s.Tokens { tokens = append(tokens, t.Id) } - segment := Segment{Id: s.Num, Text: s.Text, Start:s.Start, End: s.End, Tokens: tokens} + segment := api.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens} res.Segments = append(res.Segments, segment) res.Text += s.Text diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 44a0638..d91131d 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -4,18 +4,13 @@ import ( "context" "fmt" "os" + "os/signal" "path/filepath" "strings" + "syscall" "time" - rwkv "github.com/donomii/go-rwkv.cpp" - whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" grpc "github.com/go-skynet/LocalAI/pkg/grpc" - "github.com/go-skynet/LocalAI/pkg/langchain" - "github.com/go-skynet/LocalAI/pkg/stablediffusion" - "github.com/go-skynet/LocalAI/pkg/tts" - bloomz "github.com/go-skynet/bloomz.cpp" - bert "github.com/go-skynet/go-bert.cpp" "github.com/hashicorp/go-multierror" "github.com/hpcloud/tail" "github.com/phayes/freeport" @@ -27,20 +22,22 @@ import ( const tokenizerSuffix = ".tokenizer.json" const ( - LlamaBackend = "llama" - BloomzBackend = "bloomz" - StarcoderBackend = "starcoder" - GPTJBackend = "gptj" - DollyBackend = "dolly" - MPTBackend = "mpt" - GPTNeoXBackend = "gptneox" - ReplitBackend = "replit" - Gpt2Backend = "gpt2" - Gpt4AllLlamaBackend = "gpt4all-llama" - Gpt4AllMptBackend = "gpt4all-mpt" - Gpt4AllJBackend = "gpt4all-j" - Gpt4All = "gpt4all" - FalconBackend = "falcon" + LlamaBackend = "llama" + BloomzBackend = "bloomz" + StarcoderBackend = "starcoder" + GPTJBackend = "gptj" + DollyBackend = "dolly" + MPTBackend = "mpt" + GPTNeoXBackend = "gptneox" + ReplitBackend = "replit" + Gpt2Backend = "gpt2" + Gpt4AllLlamaBackend = "gpt4all-llama" + Gpt4AllMptBackend = "gpt4all-mpt" + Gpt4AllJBackend = "gpt4all-j" + Gpt4All = "gpt4all" + FalconBackend = "falcon" + FalconGGMLBackend = "falcon-ggml" + BertEmbeddingsBackend = "bert-embeddings" RwkvBackend = "rwkv" WhisperBackend = "whisper" @@ -54,77 +51,39 @@ var autoLoadBackends []string = []string{ LlamaBackend, Gpt4All, RwkvBackend, + FalconBackend, WhisperBackend, - BertEmbeddingsBackend, GPTNeoXBackend, + BertEmbeddingsBackend, + FalconGGMLBackend, GPTJBackend, Gpt2Backend, DollyBackend, MPTBackend, ReplitBackend, StarcoderBackend, - FalconBackend, BloomzBackend, } -var bertEmbeddings = func(modelFile string) (interface{}, error) { - return bert.New(modelFile) -} - -var bloomzLM = func(modelFile string) (interface{}, error) { - return bloomz.New(modelFile) -} - -var stableDiffusion = func(assetDir string) (interface{}, error) { - return stablediffusion.New(assetDir) -} - -func piperTTS(assetDir string) func(s string) (interface{}, error) { - return func(s string) (interface{}, error) { - return tts.New(assetDir) - } -} - -var whisperModel = func(modelFile string) (interface{}, error) { - return whisper.New(modelFile) -} - -var lcHuggingFace = func(repoId string) (interface{}, error) { - return langchain.NewHuggingFace(repoId) -} - -// func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) { -// return func(s string) (interface{}, error) { -// return llama.New(s, opts...) -// } -// } - -// func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) { -// return func(s string) (interface{}, error) { -// return gpt4all.New(s, opts...) -// } -// } - -func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) { - return func(s string) (interface{}, error) { - log.Debug().Msgf("Loading RWKV", s, tokenFile) - - model := rwkv.LoadFiles(s, tokenFile, threads) - if model == nil { - return nil, fmt.Errorf("could not load model") - } - return model, nil +func (ml *ModelLoader) StopGRPC() { + for _, p := range ml.grpcProcesses { + p.Stop() } } // starts the grpcModelProcess for the backend, and returns a grpc client // It also loads the model -func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (interface{}, error) { - return func(s string) (interface{}, error) { +func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) { + return func(s string) (*grpc.Client, error) { log.Debug().Msgf("Loading GRPC Model", backend, *o) grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) + // Check if the file exists + if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { + return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) + } + // Make sure the process is executable if err := os.Chmod(grpcProcess, 0755); err != nil { return nil, err @@ -151,6 +110,14 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (inter return nil, err } + // clean up process + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + grpcControlProcess.Stop() + }() + go func() { t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) if err != nil { @@ -200,7 +167,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (inter log.Debug().Msgf("GRPC: Loading model with options: %+v", options) - res, err := client.LoadModel(context.TODO(), &options) + res, err := client.LoadModel(o.context, &options) if err != nil { return nil, err } @@ -212,63 +179,37 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (inter } } -func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err error) { - - //backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (model interface{}, err error) { - +func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err error) { o := NewOptions(opts...) log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) - switch strings.ToLower(o.backendString) { - case LlamaBackend: - return ml.LoadModel(o.modelFile, ml.grpcModel(LlamaBackend, o)) - case BloomzBackend: - return ml.LoadModel(o.modelFile, bloomzLM) - case GPTJBackend: - return ml.LoadModel(o.modelFile, ml.grpcModel(GPTJBackend, o)) - case DollyBackend: - return ml.LoadModel(o.modelFile, ml.grpcModel(DollyBackend, o)) - case MPTBackend: - return ml.LoadModel(o.modelFile, ml.grpcModel(MPTBackend, o)) - case Gpt2Backend: - return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt2Backend, o)) - case FalconBackend: - return ml.LoadModel(o.modelFile, ml.grpcModel(FalconBackend, o)) - case GPTNeoXBackend: - return ml.LoadModel(o.modelFile, ml.grpcModel(GPTNeoXBackend, o)) - case ReplitBackend: - return ml.LoadModel(o.modelFile, ml.grpcModel(ReplitBackend, o)) - case StableDiffusionBackend: - return ml.LoadModel(o.modelFile, stableDiffusion) - case PiperBackend: - return ml.LoadModel(o.modelFile, piperTTS(filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data"))) - case StarcoderBackend: - return ml.LoadModel(o.modelFile, ml.grpcModel(StarcoderBackend, o)) + + backend := strings.ToLower(o.backendString) + switch backend { + case LlamaBackend, GPTJBackend, DollyBackend, + MPTBackend, Gpt2Backend, FalconBackend, + GPTNeoXBackend, ReplitBackend, StarcoderBackend, BloomzBackend, + RwkvBackend, LCHuggingFaceBackend, BertEmbeddingsBackend, FalconGGMLBackend, StableDiffusionBackend, WhisperBackend: + return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o)) case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All: o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all") return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt4All, o)) - // return ml.LoadModel(o.modelFile, gpt4allLM(gpt4all.SetThreads(int(o.threads)), gpt4all.SetLibrarySearchPath(filepath.Join(o.assetDir, "backend-assets", "gpt4all")))) - case BertEmbeddingsBackend: - return ml.LoadModel(o.modelFile, bertEmbeddings) - case RwkvBackend: - return ml.LoadModel(o.modelFile, rwkvLM(filepath.Join(ml.ModelPath, o.modelFile+tokenizerSuffix), o.threads)) - case WhisperBackend: - return ml.LoadModel(o.modelFile, whisperModel) - case LCHuggingFaceBackend: - return ml.LoadModel(o.modelFile, lcHuggingFace) + case PiperBackend: + o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data") + return ml.LoadModel(o.modelFile, ml.grpcModel(PiperBackend, o)) default: return nil, fmt.Errorf("backend unsupported: %s", o.backendString) } } -func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) { +func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { o := NewOptions(opts...) log.Debug().Msgf("Loading model '%s' greedly", o.modelFile) + // Is this really needed? BackendLoader already does this ml.mu.Lock() - m, exists := ml.models[o.modelFile] - if exists { + if m := ml.checkIsLoaded(o.modelFile); m != nil { log.Debug().Msgf("Model '%s' already loaded", o.modelFile) ml.mu.Unlock() return m, nil @@ -285,7 +226,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) { model, modelerr := ml.BackendLoader( WithBackendString(b), WithModelFile(o.modelFile), - WithLoadGRPCOpts(o.gRPCOptions), + WithLoadGRPCLLMModelOpts(o.gRPCOptions), WithThreads(o.threads), WithAssetDir(o.assetDir), ) diff --git a/pkg/model/loader.go b/pkg/model/loader.go index 35f3cef..833c311 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -2,6 +2,7 @@ package model import ( "bytes" + "context" "fmt" "io/ioutil" "os" @@ -10,6 +11,7 @@ import ( "sync" "text/template" + "github.com/go-skynet/LocalAI/pkg/grpc" process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" ) @@ -18,7 +20,7 @@ type ModelLoader struct { ModelPath string mu sync.Mutex // TODO: this needs generics - models map[string]interface{} + models map[string]*grpc.Client grpcProcesses map[string]*process.Process promptsTemplates map[string]*template.Template } @@ -26,7 +28,7 @@ type ModelLoader struct { func NewModelLoader(modelPath string) *ModelLoader { return &ModelLoader{ ModelPath: modelPath, - models: make(map[string]interface{}), + models: make(map[string]*grpc.Client), promptsTemplates: make(map[string]*template.Template), grpcProcesses: make(map[string]*process.Process), } @@ -113,14 +115,14 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error { return nil } -func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (interface{}, error)) (interface{}, error) { +func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Client, error)) (*grpc.Client, error) { ml.mu.Lock() defer ml.mu.Unlock() // Check if we already have a loaded model - if m, ok := ml.models[modelName]; ok { + if model := ml.checkIsLoaded(modelName); model != nil { log.Debug().Msgf("Model already loaded in memory: %s", modelName) - return m, nil + return model, nil } // Load the model and keep it in memory for later use @@ -140,3 +142,25 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (interfac ml.models[modelName] = model return model, nil } + +func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client { + if m, ok := ml.models[s]; ok { + log.Debug().Msgf("Model already loaded in memory: %s", s) + + if !m.HealthCheck(context.Background()) { + log.Debug().Msgf("GRPC Model not responding", s) + if !ml.grpcProcesses[s].IsAlive() { + log.Debug().Msgf("GRPC Process is not responding", s) + // stop and delete the process, this forces to re-load the model and re-create again the service + ml.grpcProcesses[s].Stop() + delete(ml.grpcProcesses, s) + delete(ml.models, s) + return nil + } + } + + return m + } + + return nil +} diff --git a/pkg/model/options.go b/pkg/model/options.go index 31e54cb..298ebd4 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -1,6 +1,8 @@ package model import ( + "context" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" ) @@ -9,6 +11,7 @@ type Options struct { modelFile string threads uint32 assetDir string + context context.Context gRPCOptions *pb.ModelOptions } @@ -27,7 +30,7 @@ func WithModelFile(modelFile string) Option { } } -func WithLoadGRPCOpts(opts *pb.ModelOptions) Option { +func WithLoadGRPCLLMModelOpts(opts *pb.ModelOptions) Option { return func(o *Options) { o.gRPCOptions = opts } @@ -45,8 +48,17 @@ func WithAssetDir(assetDir string) Option { } } +func WithContext(ctx context.Context) Option { + return func(o *Options) { + o.context = ctx + } +} + func NewOptions(opts ...Option) *Options { - o := &Options{} + o := &Options{ + gRPCOptions: &pb.ModelOptions{}, + context: context.Background(), + } for _, opt := range opts { opt(o) } diff --git a/pkg/tts/generate.go b/pkg/tts/generate.go deleted file mode 100644 index e4722d4..0000000 --- a/pkg/tts/generate.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build tts -// +build tts - -package tts - -import ( - piper "github.com/mudler/go-piper" -) - -func tts(text, model, assetDir, arLib, dst string) error { - return piper.TextToWav(text, model, assetDir, arLib, dst) -} diff --git a/pkg/tts/generate_unsupported.go b/pkg/tts/generate_unsupported.go deleted file mode 100644 index 3092695..0000000 --- a/pkg/tts/generate_unsupported.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !tts -// +build !tts - -package tts - -import "fmt" - -func tts(text, model, assetDir, arLib, dst string) error { - return fmt.Errorf("this version of LocalAI was built without the tts tag") -} diff --git a/pkg/tts/piper.go b/pkg/tts/piper.go deleted file mode 100644 index b76a637..0000000 --- a/pkg/tts/piper.go +++ /dev/null @@ -1,20 +0,0 @@ -package tts - -import "os" - -type Piper struct { - assetDir string -} - -func New(assetDir string) (*Piper, error) { - if _, err := os.Stat(assetDir); err != nil { - return nil, err - } - return &Piper{ - assetDir: assetDir, - }, nil -} - -func (s *Piper) TTS(text, model, dst string) error { - return tts(text, model, s.assetDir, "", dst) -} From 189cb3a7be8f15d5697816911f79225c44736adb Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 07/12] feat: run all tests Signed-off-by: Ettore Di Giacinto --- .github/workflows/test.yml | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a18cd20..9d64c42 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,9 +26,30 @@ jobs: run: | sudo apt-get update sudo apt-get install build-essential ffmpeg + + sudo apt-get install -y ca-certificates cmake curl patch + sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2 + + sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \ + PIPER_PHONEMIZE_VERSION='1.0.0' SPDLOG_VERSION="1.11.0" \ + curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v${SPDLOG_VERSION}.tar.gz" | \ + tar -xzvf - && \ + mkdir -p "spdlog-${SPDLOG_VERSION}/build" && \ + cd "spdlog-${SPDLOG_VERSION}/build" && \ + cmake .. && \ + make -j8 && \ + sudo cmake --install . --prefix /usr && mkdir -p "lib/Linux-$(uname -m)" && \ + cd /build && \ + mkdir -p "lib/Linux-$(uname -m)/piper_phonemize" && \ + curl -L "https://github.com/rhasspy/piper-phonemize/releases/download/v${PIPER_PHONEMIZE_VERSION}/libpiper_phonemize-${TARGETARCH:-$(go env GOARCH)}${TARGETVARIANT}.tar.gz" | \ + tar -C "lib/Linux-$(uname -m)/piper_phonemize" -xzvf - && ls -liah /build/lib/Linux-$(uname -m)/piper_phonemize/ && \ + sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \ + sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \ + sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. /usr/include/ + - name: Test run: | - make test + ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test macOS-latest: runs-on: macOS-latest From 7f3de3ca4aabac33ec6211c2d9fd634c5ae452fd Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 08/12] fix: fix makefile error Signed-off-by: Ettore Di Giacinto --- Makefile | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Makefile b/Makefile index 9596bcb..639ce7f 100644 --- a/Makefile +++ b/Makefile @@ -99,10 +99,6 @@ ifeq ($(findstring tts,$(GO_TAGS)),tts) OPTIONAL_TARGETS+=go-piper/libpiper_binding.a OPTIONAL_TARGETS+=backend-assets/espeak-ng-data OPTIONAL_GRPC+=backend-assets/grpc/piper -# die if ESPEAK_DATA is not set -ifndef ESPEAK_DATA -$(error ESPEAK_DATA is not set. Espeak data is required for tts) -endif endif .PHONY: all test build vendor From 98e73ed67a6a476a490e8bf06e8ef6aacfd1605c Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 09/12] fix: CI fixes Signed-off-by: Ettore Di Giacinto --- .github/workflows/test.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9d64c42..5b8385c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,17 +31,16 @@ jobs: sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2 sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \ - PIPER_PHONEMIZE_VERSION='1.0.0' SPDLOG_VERSION="1.11.0" \ - curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v${SPDLOG_VERSION}.tar.gz" | \ + curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \ tar -xzvf - && \ - mkdir -p "spdlog-${SPDLOG_VERSION}/build" && \ - cd "spdlog-${SPDLOG_VERSION}/build" && \ + mkdir -p "spdlog-1.11.0/build" && \ + cd "spdlog-1.11.0/build" && \ cmake .. && \ make -j8 && \ sudo cmake --install . --prefix /usr && mkdir -p "lib/Linux-$(uname -m)" && \ cd /build && \ mkdir -p "lib/Linux-$(uname -m)/piper_phonemize" && \ - curl -L "https://github.com/rhasspy/piper-phonemize/releases/download/v${PIPER_PHONEMIZE_VERSION}/libpiper_phonemize-${TARGETARCH:-$(go env GOARCH)}${TARGETVARIANT}.tar.gz" | \ + curl -L "https://github.com/rhasspy/piper-phonemize/releases/download/v1.0.0/libpiper_phonemize-amd64.tar.gz" | \ tar -C "lib/Linux-$(uname -m)/piper_phonemize" -xzvf - && ls -liah /build/lib/Linux-$(uname -m)/piper_phonemize/ && \ sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \ sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \ From 26e510bf28e6850d441529440de7a83df1e235f2 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 10/12] fix: Makefile Signed-off-by: Ettore Di Giacinto --- Makefile | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 639ce7f..3ec2213 100644 --- a/Makefile +++ b/Makefile @@ -91,13 +91,13 @@ ifeq ($(STATIC),true) endif ifeq ($(findstring stablediffusion,$(GO_TAGS)),stablediffusion) - OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a +# OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a OPTIONAL_GRPC+=backend-assets/grpc/stablediffusion endif ifeq ($(findstring tts,$(GO_TAGS)),tts) - OPTIONAL_TARGETS+=go-piper/libpiper_binding.a - OPTIONAL_TARGETS+=backend-assets/espeak-ng-data +# OPTIONAL_TARGETS+=go-piper/libpiper_binding.a +# OPTIONAL_TARGETS+=backend-assets/espeak-ng-data OPTIONAL_GRPC+=backend-assets/grpc/piper endif @@ -234,7 +234,7 @@ rebuild: ## Rebuilds the project $(MAKE) -C go-ggllm clean $(MAKE) build -prepare: prepare-sources grpcs go-bert/libgobert.a go-ggml-transformers/libtransformers.a whisper.cpp/libwhisper.a $(OPTIONAL_TARGETS) +prepare: prepare-sources $(OPTIONAL_TARGETS) touch $@ clean: ## Remove build related file @@ -256,7 +256,7 @@ clean: ## Remove build related file ## Build: -build: prepare ## Build the project +build: grpcs prepare ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) @@ -415,4 +415,4 @@ backend-assets/grpc/whisper: backend-assets/grpc whisper.cpp/libwhisper.a CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/whisper.cpp LIBRARY_PATH=$(shell pwd)/whisper.cpp \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./cmd/grpc/whisper/ -grpcs: backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC) \ No newline at end of file +grpcs: prepare backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC) \ No newline at end of file From c0a91ab548616ec50bf0c8c4cd43621c62e7e847 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 11/12] fix: fix LDFLAGS for rwkv.cpp Previously the libs were added by other deps that made the linker add those as well (by chance). Signed-off-by: Ettore Di Giacinto --- Makefile | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/Makefile b/Makefile index 3ec2213..c381203 100644 --- a/Makefile +++ b/Makefile @@ -67,6 +67,15 @@ WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) +ifndef UNAME_S +UNAME_S := $(shell uname -s) +endif + +# workaround for rwkv.cpp +ifeq ($(UNAME_S),Darwin) + CGO_LDFLAGS += -lcblas -framework Accelerate +endif + ifeq ($(BUILD_TYPE),openblas) CGO_LDFLAGS+=-lopenblas endif From f193f565647832c899976cb32675cd238c83da68 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH 12/12] fix: fix copy Signed-off-by: Ettore Di Giacinto --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index c381203..ba01c59 100644 --- a/Makefile +++ b/Makefile @@ -296,7 +296,7 @@ test-models/testmodel: cp tests/models_fixtures/* test-models prepare-test: grpcs - cp -r backend-assets api + cp -rf backend-assets api cp tests/models_fixtures/* test-models test: prepare test-models/testmodel grpcs