feat: move other backends to grpc

This finally makes everything more consistent

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
renovate/github.com-imdario-mergo-1.x
Ettore Di Giacinto 1 year ago
parent 5dcfdbe51d
commit 1d0ed95a54
  1. 4
      .gitignore
  2. 149
      Makefile
  3. 7
      api/api.go
  4. 233
      api/api_test.go
  5. 28
      api/backend/embeddings.go
  6. 28
      api/backend/image.go
  7. 118
      api/backend/llm.go
  8. 26
      api/backend/options.go
  9. 17
      api/localai/localai.go
  10. 18
      api/openai/transcription.go
  11. 22
      cmd/grpc/bert-embeddings/main.go
  12. 23
      cmd/grpc/bloomz/main.go
  13. 23
      cmd/grpc/falcon-ggml/main.go
  14. 23
      cmd/grpc/langchain-huggingface/main.go
  15. 23
      cmd/grpc/piper/main.go
  16. 23
      cmd/grpc/rwkv/main.go
  17. 23
      cmd/grpc/stablediffusion/main.go
  18. 23
      cmd/grpc/whisper/main.go
  19. 9
      main.go
  20. 42
      pkg/grpc/base/base.go
  21. 61
      pkg/grpc/client.go
  22. 33
      pkg/grpc/image/stablediffusion.go
  23. 6
      pkg/grpc/interface.go
  24. 33
      pkg/grpc/llm/bert/bert.go
  25. 59
      pkg/grpc/llm/bloomz/bloomz.go
  26. 11
      pkg/grpc/llm/falcon/falcon.go
  27. 9
      pkg/grpc/llm/gpt4all/gpt4all.go
  28. 58
      pkg/grpc/llm/langchain/langchain.go
  29. 7
      pkg/grpc/llm/llama/llama.go
  30. 71
      pkg/grpc/llm/rwkv/rwkv.go
  31. 11
      pkg/grpc/llm/transformers/dolly.go
  32. 43
      pkg/grpc/llm/transformers/falcon.go
  33. 10
      pkg/grpc/llm/transformers/gpt2.go
  34. 10
      pkg/grpc/llm/transformers/gptj.go
  35. 10
      pkg/grpc/llm/transformers/gptneox.go
  36. 10
      pkg/grpc/llm/transformers/mpt.go
  37. 10
      pkg/grpc/llm/transformers/replit.go
  38. 11
      pkg/grpc/llm/transformers/starcoder.go
  39. 1458
      pkg/grpc/proto/backend.pb.go
  40. 47
      pkg/grpc/proto/backend.proto
  41. 385
      pkg/grpc/proto/backend_grpc.pb.go
  42. 969
      pkg/grpc/proto/llmserver.pb.go
  43. 277
      pkg/grpc/proto/llmserver_grpc.pb.go
  44. 47
      pkg/grpc/server.go
  45. 27
      pkg/grpc/transcribe/whisper.go
  46. 44
      pkg/grpc/tts/piper.go
  47. 16
      pkg/grpc/whisper/api/api.go
  48. 23
      pkg/grpc/whisper/whisper.go
  49. 171
      pkg/model/initializers.go
  50. 34
      pkg/model/loader.go
  51. 16
      pkg/model/options.go
  52. 12
      pkg/tts/generate.go
  53. 10
      pkg/tts/generate_unsupported.go
  54. 20
      pkg/tts/piper.go

4
.gitignore vendored

@ -4,7 +4,7 @@ go-llama
go-stable-diffusion go-stable-diffusion
go-piper go-piper
go-ggllm go-ggllm
piper /piper
*.a *.a
get-sources get-sources
@ -13,7 +13,7 @@ go-ggml-transformers
go-gpt2 go-gpt2
go-rwkv go-rwkv
whisper.cpp whisper.cpp
bloomz /bloomz
go-bert go-bert
# LocalAI build binary # LocalAI build binary

@ -67,9 +67,6 @@ WHITE := $(shell tput -Txterm setaf 7)
CYAN := $(shell tput -Txterm setaf 6) CYAN := $(shell tput -Txterm setaf 6)
RESET := $(shell tput -Txterm sgr0) RESET := $(shell tput -Txterm sgr0)
C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz
ifeq ($(BUILD_TYPE),openblas) ifeq ($(BUILD_TYPE),openblas)
CGO_LDFLAGS+=-lopenblas CGO_LDFLAGS+=-lopenblas
endif endif
@ -95,11 +92,17 @@ endif
ifeq ($(findstring stablediffusion,$(GO_TAGS)),stablediffusion) ifeq ($(findstring stablediffusion,$(GO_TAGS)),stablediffusion)
OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a
OPTIONAL_GRPC+=backend-assets/grpc/stablediffusion
endif endif
ifeq ($(findstring tts,$(GO_TAGS)),tts) ifeq ($(findstring tts,$(GO_TAGS)),tts)
OPTIONAL_TARGETS+=go-piper/libpiper_binding.a OPTIONAL_TARGETS+=go-piper/libpiper_binding.a
OPTIONAL_TARGETS+=backend-assets/espeak-ng-data OPTIONAL_TARGETS+=backend-assets/espeak-ng-data
OPTIONAL_GRPC+=backend-assets/grpc/piper
# die if ESPEAK_DATA is not set
ifndef ESPEAK_DATA
$(error ESPEAK_DATA is not set. Espeak data is required for tts)
endif
endif endif
.PHONY: all test build vendor .PHONY: all test build vendor
@ -128,9 +131,6 @@ go-piper:
go-bert: go-bert:
git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp go-bert git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp go-bert
cd go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1 cd go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1
@find ./go-bert -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
@find ./go-bert -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
@find ./go-bert -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} +
## stable diffusion ## stable diffusion
go-stable-diffusion: go-stable-diffusion:
@ -144,9 +144,6 @@ go-stable-diffusion/libstablediffusion.a:
go-rwkv: go-rwkv:
git clone --recurse-submodules $(RWKV_REPO) go-rwkv git clone --recurse-submodules $(RWKV_REPO) go-rwkv
cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1 cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1
@find ./go-rwkv -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} +
@find ./go-rwkv -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} +
@find ./go-rwkv -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} +
go-rwkv/librwkv.a: go-rwkv go-rwkv/librwkv.a: go-rwkv
cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a ..
@ -154,13 +151,7 @@ go-rwkv/librwkv.a: go-rwkv
## bloomz ## bloomz
bloomz: bloomz:
git clone --recurse-submodules https://github.com/go-skynet/bloomz.cpp bloomz git clone --recurse-submodules https://github.com/go-skynet/bloomz.cpp bloomz
@find ./bloomz -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + cd bloomz && git checkout -b build $(BLOOMZ_VERSION) && git submodule update --init --recursive --depth 1
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} +
@find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} +
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} +
@find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} +
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_bloomz_replace/g' {} +
@find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_bloomz_replace/g' {} +
bloomz/libbloomz.a: bloomz bloomz/libbloomz.a: bloomz
cd bloomz && make libbloomz.a cd bloomz && make libbloomz.a
@ -179,6 +170,7 @@ backend-assets/espeak-ng-data:
ifdef ESPEAK_DATA ifdef ESPEAK_DATA
@cp -rf $(ESPEAK_DATA)/. backend-assets/espeak-ng-data @cp -rf $(ESPEAK_DATA)/. backend-assets/espeak-ng-data
else else
@echo "ESPEAK_DATA not set, skipping tts. Note that this will break the tts functionality."
@touch backend-assets/espeak-ng-data/keep @touch backend-assets/espeak-ng-data/keep
endif endif
@ -196,9 +188,6 @@ go-ggml-transformers/libtransformers.a: go-ggml-transformers
whisper.cpp: whisper.cpp:
git clone https://github.com/ggerganov/whisper.cpp.git git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1 cd whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git submodule update --init --recursive --depth 1
@find ./whisper.cpp -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} +
@find ./whisper.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} +
@find ./whisper.cpp -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} +
whisper.cpp/libwhisper.a: whisper.cpp whisper.cpp/libwhisper.a: whisper.cpp
cd whisper.cpp && make libwhisper.a cd whisper.cpp && make libwhisper.a
@ -249,7 +238,7 @@ rebuild: ## Rebuilds the project
$(MAKE) -C go-ggllm clean $(MAKE) -C go-ggllm clean
$(MAKE) build $(MAKE) build
prepare: prepare-sources grpcs go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a $(OPTIONAL_TARGETS) prepare: prepare-sources grpcs go-bert/libgobert.a go-ggml-transformers/libtransformers.a whisper.cpp/libwhisper.a $(OPTIONAL_TARGETS)
touch $@ touch $@
clean: ## Remove build related file clean: ## Remove build related file
@ -277,7 +266,7 @@ build: prepare ## Build the project
$(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET})
$(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET}) $(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET})
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./ CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./
ifeq ($(BUILD_TYPE),metal) ifeq ($(BUILD_TYPE),metal)
cp go-llama/build/bin/ggml-metal.metal . cp go-llama/build/bin/ggml-metal.metal .
endif endif
@ -286,12 +275,9 @@ dist: build
mkdir -p release mkdir -p release
cp $(BINARY_NAME) release/$(BINARY_NAME)-$(BUILD_ID)-$(OS)-$(ARCH) cp $(BINARY_NAME) release/$(BINARY_NAME)-$(BUILD_ID)-$(OS)-$(ARCH)
generic-build: ## Build the project using generic
BUILD_TYPE="generic" $(MAKE) build
## Run ## Run
run: prepare ## run local-ai run: prepare ## run local-ai
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) run ./ CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./
test-models/testmodel: test-models/testmodel:
mkdir test-models mkdir test-models
@ -304,12 +290,42 @@ test-models/testmodel:
wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json
cp tests/models_fixtures/* test-models cp tests/models_fixtures/* test-models
test: prepare test-models/testmodel prepare-test: grpcs
cp -r backend-assets api cp -r backend-assets api
cp tests/models_fixtures/* test-models cp tests/models_fixtures/* test-models
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg test: prepare test-models/testmodel grpcs
C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg @echo 'Running tests'
export GO_TAGS="tts stablediffusion"
$(MAKE) prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg
$(MAKE) test-gpt4all
$(MAKE) test-llama
$(MAKE) test-tts
$(MAKE) test-stablediffusion
test-gpt4all: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg
test-llama: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg
test-tts: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r ./api ./pkg
test-stablediffusion: prepare-test
TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r ./api ./pkg
test-container:
docker build --target requirements -t local-ai-test-container .
docker run --name localai-tests -e GO_TAGS=$(GO_TAGS) -ti -v $(abspath ./):/build local-ai-test-container make test
docker rm localai-tests
docker rmi local-ai-test-container
## Help: ## Help:
help: ## Show this help. help: ## Show this help.
@ -325,51 +341,82 @@ help: ## Show this help.
protogen: protogen:
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \ protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \
pkg/grpc/proto/llmserver.proto pkg/grpc/proto/backend.proto
## GRPC ## GRPC
backend-assets/grpc: backend-assets/grpc:
mkdir -p backend-assets/grpc mkdir -p backend-assets/grpc
falcon-grpc: backend-assets/grpc go-ggllm/libggllm.a backend-assets/grpc/falcon: backend-assets/grpc go-ggllm/libggllm.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/
llama-grpc: backend-assets/grpc go-llama/libbinding.a backend-assets/grpc/llama: backend-assets/grpc go-llama/libbinding.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/
gpt4all-grpc: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a backend-assets/grpc/gpt4all: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/
dolly-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a backend-assets/grpc/dolly: backend-assets/grpc go-ggml-transformers/libtransformers.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/
gpt2-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a backend-assets/grpc/gpt2: backend-assets/grpc go-ggml-transformers/libtransformers.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/
gptj-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a backend-assets/grpc/gptj: backend-assets/grpc go-ggml-transformers/libtransformers.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/
gptneox-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a backend-assets/grpc/gptneox: backend-assets/grpc go-ggml-transformers/libtransformers.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/
mpt-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a backend-assets/grpc/mpt: backend-assets/grpc go-ggml-transformers/libtransformers.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/
replit-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a backend-assets/grpc/replit: backend-assets/grpc go-ggml-transformers/libtransformers.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/
starcoder-grpc: backend-assets/grpc go-ggml-transformers/libtransformers.a backend-assets/grpc/falcon-ggml: backend-assets/grpc go-ggml-transformers/libtransformers.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
$(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon-ggml ./cmd/grpc/falcon-ggml/
backend-assets/grpc/starcoder: backend-assets/grpc go-ggml-transformers/libtransformers.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/
backend-assets/grpc/rwkv: backend-assets/grpc go-rwkv/librwkv.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-rwkv LIBRARY_PATH=$(shell pwd)/go-rwkv \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/rwkv ./cmd/grpc/rwkv/
backend-assets/grpc/bloomz: backend-assets/grpc bloomz/libbloomz.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/bloomz LIBRARY_PATH=$(shell pwd)/bloomz \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bloomz ./cmd/grpc/bloomz/
backend-assets/grpc/bert-embeddings: backend-assets/grpc go-bert/libgobert.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-bert LIBRARY_PATH=$(shell pwd)/go-bert \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bert-embeddings ./cmd/grpc/bert-embeddings/
backend-assets/grpc/langchain-huggingface: backend-assets/grpc
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/langchain-huggingface ./cmd/grpc/langchain-huggingface/
backend-assets/grpc/stablediffusion: backend-assets/grpc go-stable-diffusion/libstablediffusion.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/
backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data go-piper/libpiper_binding.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" LIBRARY_PATH=$(shell pwd)/go-piper \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./cmd/grpc/piper/
backend-assets/grpc/whisper: backend-assets/grpc whisper.cpp/libwhisper.a
CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/whisper.cpp LIBRARY_PATH=$(shell pwd)/whisper.cpp \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./cmd/grpc/whisper/
grpcs: falcon-grpc llama-grpc gpt4all-grpc dolly-grpc gpt2-grpc gptj-grpc gptneox-grpc mpt-grpc replit-grpc starcoder-grpc grpcs: backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC)

@ -173,5 +173,12 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm)) app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm))
app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm)) app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm))
// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopGRPC()
}()
return app, nil return app, nil
} }

@ -5,7 +5,9 @@ import (
"context" "context"
"embed" "embed"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"os" "os"
@ -24,6 +26,7 @@ import (
openaigo "github.com/otiai10/openaigo" openaigo "github.com/otiai10/openaigo"
"github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai"
"github.com/sashabaranov/go-openai/jsonschema"
) )
type modelApplyRequest struct { type modelApplyRequest struct {
@ -203,7 +206,7 @@ var _ = Describe("API test", func() {
fmt.Println(response) fmt.Println(response)
resp = response resp = response
return response["processed"].(bool) return response["processed"].(bool)
}, "360s").Should(Equal(true)) }, "360s", "10s").Should(Equal(true))
Expect(resp["message"]).ToNot(ContainSubstring("error")) Expect(resp["message"]).ToNot(ContainSubstring("error"))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml")) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml"))
@ -245,9 +248,8 @@ var _ = Describe("API test", func() {
Eventually(func() bool { Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool) return response["processed"].(bool)
}, "360s").Should(Equal(true)) }, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -270,9 +272,8 @@ var _ = Describe("API test", func() {
Eventually(func() bool { Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool) return response["processed"].(bool)
}, "360s").Should(Equal(true)) }, "360s", "10s").Should(Equal(true))
dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml"))
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@ -299,14 +300,58 @@ var _ = Describe("API test", func() {
Eventually(func() bool { Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool) return response["processed"].(bool)
}, "360s").Should(Equal(true)) }, "360s", "10s").Should(Equal(true))
By("testing completion")
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "}) resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1)) Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Text).To(ContainSubstring("five")) Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
By("testing functions")
resp2, err := client.CreateChatCompletion(
context.TODO(),
openai.ChatCompletionRequest{
Model: "openllama_3b",
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
Content: "What is the weather like in San Francisco (celsius)?",
},
},
Functions: []openai.FunctionDefinition{
openai.FunctionDefinition{
Name: "get_current_weather",
Description: "Get the current weather",
Parameters: jsonschema.Definition{
Type: jsonschema.Object,
Properties: map[string]jsonschema.Definition{
"location": {
Type: jsonschema.String,
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: jsonschema.String,
Enum: []string{"celcius", "fahrenheit"},
},
},
Required: []string{"location"},
},
},
},
})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp2.Choices)).To(Equal(1))
Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil())
Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name)
var res map[string]string
err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res)
Expect(err).ToNot(HaveOccurred())
Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res))
Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res))
Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason))
}) })
It("runs gpt4all", Label("gpt4all"), func() { It("runs gpt4all", Label("gpt4all"), func() {
@ -326,15 +371,126 @@ var _ = Describe("API test", func() {
Eventually(func() bool { Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool) return response["processed"].(bool)
}, "360s").Should(Equal(true)) }, "360s", "10s").Should(Equal(true))
resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}}) resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices)).To(Equal(1)) Expect(len(resp.Choices)).To(Equal(1))
Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well")) Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well"))
}) })
})
})
Context("Model gallery", func() {
BeforeEach(func() {
var err error
tmpdir, err = os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background())
galleries := []gallery.Gallery{
{
Name: "model-gallery",
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml",
},
}
app, err = App(
options.WithContext(c),
options.WithAudioDir(tmpdir),
options.WithImageDir(tmpdir),
options.WithGalleries(galleries),
options.WithModelLoader(modelLoader),
options.WithBackendAssets(backendAssets),
options.WithBackendAssetsOutput(tmpdir),
)
Expect(err).ToNot(HaveOccurred())
go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("")
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1"
client2 = openaigo.NewClient("")
client2.BaseURL = defaultConfig.BaseURL
// Wait for API to be ready
client = openai.NewClientWithConfig(defaultConfig)
Eventually(func() error {
_, err := client.ListModels(context.TODO())
return err
}, "2m").ShouldNot(HaveOccurred())
})
AfterEach(func() {
cancel()
app.Shutdown()
os.RemoveAll(tmpdir)
})
It("installs and is capable to run tts", Label("tts"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "model-gallery@voice-en-us-kathleen-low",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
// An HTTP Post to the /tts endpoint should return a wav audio file
resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "en-us-kathleen-low.onnx"}`)))
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
dat, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat)))
Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav"))
})
It("installs and is capable to generate images", Label("stablediffusion"), func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{
ID: "model-gallery@stablediffusion",
})
Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response))
uuid := response["uuid"].(string)
Eventually(func() bool {
response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid)
fmt.Println(response)
return response["processed"].(bool)
}, "360s", "10s").Should(Equal(true))
resp, err := http.Post(
"http://127.0.0.1:9090/v1/images/generations",
"application/json",
bytes.NewBuffer([]byte(`{
"prompt": "floating hair, portrait, ((loli)), ((one girl)), cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text",
"mode": 2, "seed":9000,
"size": "256x256", "n":2}`)))
// The response should contain an URL
Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp))
dat, err := io.ReadAll(resp.Body)
Expect(err).ToNot(HaveOccurred(), string(dat))
Expect(string(dat)).To(ContainSubstring("http://127.0.0.1:9090/"), string(dat))
Expect(string(dat)).To(ContainSubstring(".png"), string(dat))
}) })
}) })
@ -401,7 +557,7 @@ var _ = Describe("API test", func() {
It("returns errors", func() { It("returns errors", func() {
_, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"})
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 11 errors occurred:")) Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:"))
}) })
It("transcribes audio", func() { It("transcribes audio", func() {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
@ -446,14 +602,67 @@ var _ = Describe("API test", func() {
}) })
Context("backends", func() { Context("backends", func() {
It("runs rwkv", func() { It("runs rwkv completion", func() {
if runtime.GOOS != "linux" { if runtime.GOOS != "linux" {
Skip("test supported only on linux") Skip("test supported only on linux")
} }
resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"}) resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue()) Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(resp.Choices[0].Text).To(Equal(" five.")) Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()
tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Text
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(ContainSubstring("five"))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
})
It("runs rwkv chat completion", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateChatCompletion(context.TODO(),
openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Choices) > 0).To(BeTrue())
Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
Expect(err).ToNot(HaveOccurred())
defer stream.Close()
tokens := 0
text := ""
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
}
Expect(err).ToNot(HaveOccurred())
text += response.Choices[0].Delta.Content
tokens++
}
Expect(text).ToNot(BeEmpty())
Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five")))
Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
}) })
}) })
}) })

@ -1,7 +1,6 @@
package backend package backend
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
@ -9,7 +8,6 @@ import (
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
bert "github.com/go-skynet/go-bert.cpp"
) )
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
@ -25,10 +23,11 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
var err error var err error
opts := []model.Option{ opts := []model.Option{
model.WithLoadGRPCOpts(grpcOpts), model.WithLoadGRPCLLMModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), model.WithThreads(uint32(c.Threads)),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(o.AssetsDestination),
model.WithModelFile(modelFile), model.WithModelFile(modelFile),
model.WithContext(o.Context),
} }
if c.Backend == "" { if c.Backend == "" {
@ -54,7 +53,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
} }
predictOptions.EmbeddingTokens = embeds predictOptions.EmbeddingTokens = embeds
res, err := model.Embeddings(context.TODO(), predictOptions) res, err := model.Embeddings(o.Context, predictOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,22 +62,13 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
} }
predictOptions.Embeddings = s predictOptions.Embeddings = s
res, err := model.Embeddings(context.TODO(), predictOptions) res, err := model.Embeddings(o.Context, predictOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return res.Embeddings, nil return res.Embeddings, nil
} }
// bert embeddings
case *bert.Bert:
fn = func() ([]float32, error) {
if len(tokens) > 0 {
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
}
return model.Embeddings(s, bert.SetThreads(c.Threads))
}
default: default:
fn = func() ([]float32, error) { fn = func() ([]float32, error) {
return nil, fmt.Errorf("embeddings not supported by the backend") return nil, fmt.Errorf("embeddings not supported by the backend")
@ -87,7 +77,15 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
return func() ([]float32, error) { return func() ([]float32, error) {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
l := Lock(modelFile) mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock() defer l.Unlock()
embeds, err := fn() embeds, err := fn()

@ -6,8 +6,8 @@ import (
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
) )
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
@ -19,23 +19,27 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
model.WithBackendString(c.Backend), model.WithBackendString(c.Backend),
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(o.AssetsDestination),
model.WithThreads(uint32(c.Threads)), model.WithThreads(uint32(c.Threads)),
model.WithContext(o.Context),
model.WithModelFile(c.ImageGenerationAssets), model.WithModelFile(c.ImageGenerationAssets),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
var fn func() error fn := func() error {
switch model := inferenceModel.(type) { _, err := inferenceModel.GenerateImage(
case *stablediffusion.StableDiffusion: o.Context,
fn = func() error { &proto.GenerateImageRequest{
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) Height: int32(height),
} Width: int32(width),
Mode: int32(mode),
default: Step: int32(step),
fn = func() error { Seed: int32(seed),
return fmt.Errorf("creation of images not supported by the backend") PositivePrompt: positive_prompt,
} NegativePrompt: negative_prompt,
Dst: dst,
})
return err
} }
return func() error { return func() error {

@ -1,34 +1,30 @@
package backend package backend
import ( import (
"context"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
"github.com/donomii/go-rwkv.cpp"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc" "github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/langchain"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/bloomz.cpp"
) )
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
supportStreams := false
modelFile := c.Model modelFile := c.Model
grpcOpts := gRPCModelOpts(c) grpcOpts := gRPCModelOpts(c)
var inferenceModel interface{} var inferenceModel *grpc.Client
var err error var err error
opts := []model.Option{ opts := []model.Option{
model.WithLoadGRPCOpts(grpcOpts), model.WithLoadGRPCLLMModelOpts(grpcOpts),
model.WithThreads(uint32(c.Threads)), // GPT4all uses this model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
model.WithAssetDir(o.AssetsDestination), model.WithAssetDir(o.AssetsDestination),
model.WithModelFile(modelFile), model.WithModelFile(modelFile),
model.WithContext(o.Context),
} }
if c.Backend == "" { if c.Backend == "" {
@ -41,95 +37,37 @@ func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *opt
return nil, err return nil, err
} }
var fn func() (string, error) // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (string, error) {
switch model := inferenceModel.(type) { opts := gRPCPredictOpts(c, loader.ModelPath)
case *rwkv.RwkvState: opts.Prompt = s
supportStreams = true if tokenCallback != nil {
ss := ""
fn = func() (string, error) { err := inferenceModel.PredictStream(o.Context, opts, func(s string) {
stopWord := "\n" tokenCallback(s)
if len(c.StopWords) > 0 { ss += s
stopWord = c.StopWords[0] })
} return ss, err
} else {
if err := model.ProcessInput(s); err != nil { reply, err := inferenceModel.Predict(o.Context, opts)
return "", err return reply.Message, err
}
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
return response, nil
}
case *bloomz.Bloomz:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []bloomz.PredictOption{
bloomz.SetTemperature(c.Temperature),
bloomz.SetTopP(c.TopP),
bloomz.SetTopK(c.TopK),
bloomz.SetTokens(c.Maxtokens),
bloomz.SetThreads(c.Threads),
}
if c.Seed != 0 {
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
}
return model.Predict(
s,
predictOptions...,
)
}
case *grpc.Client:
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
supportStreams = true
fn = func() (string, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s
if tokenCallback != nil {
ss := ""
err := model.PredictStream(context.TODO(), opts, func(s string) {
tokenCallback(s)
ss += s
})
return ss, err
} else {
reply, err := model.Predict(context.TODO(), opts)
return reply.Message, err
}
}
case *langchain.HuggingFace:
fn = func() (string, error) {
// Generate the prediction using the language model
predictOptions := []langchain.PredictOption{
langchain.SetModel(c.Model),
langchain.SetMaxTokens(c.Maxtokens),
langchain.SetTemperature(c.Temperature),
langchain.SetStopWords(c.StopWords),
}
pred, er := model.PredictHuggingFace(s, predictOptions...)
if er != nil {
return "", er
}
return pred.Completion, nil
} }
} }
return func() (string, error) { return func() (string, error) {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
l := Lock(modelFile) mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock() defer l.Unlock()
res, err := fn() return fn()
if tokenCallback != nil && !supportStreams {
tokenCallback(res)
}
return res, err
}, nil }, nil
} }

@ -7,34 +7,8 @@ import (
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/pkg/langchain"
"github.com/go-skynet/bloomz.cpp"
) )
func langchainOptions(c config.Config) []langchain.PredictOption {
return []langchain.PredictOption{
langchain.SetModel(c.Model),
langchain.SetMaxTokens(c.Maxtokens),
langchain.SetTemperature(c.Temperature),
langchain.SetStopWords(c.StopWords),
}
}
func bloomzOptions(c config.Config) []bloomz.PredictOption {
// Generate the prediction using the language model
predictOptions := []bloomz.PredictOption{
bloomz.SetTemperature(c.Temperature),
bloomz.SetTopP(c.TopP),
bloomz.SetTopK(c.TopK),
bloomz.SetTokens(c.Maxtokens),
bloomz.SetThreads(c.Threads),
}
if c.Seed != 0 {
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
}
return predictOptions
}
func gRPCModelOpts(c config.Config) *pb.ModelOptions { func gRPCModelOpts(c config.Config) *pb.ModelOptions {
b := 512 b := 512
if c.Batch != 0 { if c.Batch != 0 {

@ -1,6 +1,7 @@
package localai package localai
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -8,8 +9,8 @@ import (
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/go-skynet/LocalAI/pkg/tts"
"github.com/go-skynet/LocalAI/pkg/utils" "github.com/go-skynet/LocalAI/pkg/utils"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@ -47,6 +48,7 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
piperModel, err := o.Loader.BackendLoader( piperModel, err := o.Loader.BackendLoader(
model.WithBackendString(model.PiperBackend), model.WithBackendString(model.PiperBackend),
model.WithModelFile(input.Model), model.WithModelFile(input.Model),
model.WithContext(o.Context),
model.WithAssetDir(o.AssetsDestination)) model.WithAssetDir(o.AssetsDestination))
if err != nil { if err != nil {
return err return err
@ -56,13 +58,8 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return fmt.Errorf("could not load piper model") return fmt.Errorf("could not load piper model")
} }
w, ok := piperModel.(*tts.Piper)
if !ok {
return fmt.Errorf("loader returned non-piper object %+v", w)
}
if err := os.MkdirAll(o.AudioDir, 0755); err != nil { if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
return err return fmt.Errorf("failed creating audio directory: %s", err)
} }
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
@ -74,7 +71,11 @@ func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx)
return err return err
} }
if err := w.TTS(input.Input, modelPath, filePath); err != nil { if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
Text: input.Input,
Model: modelPath,
Dst: filePath,
}); err != nil {
return err return err
} }

@ -1,6 +1,7 @@
package openai package openai
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -8,11 +9,10 @@ import (
"path" "path"
"path/filepath" "path/filepath"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
config "github.com/go-skynet/LocalAI/api/config" config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
model "github.com/go-skynet/LocalAI/pkg/model" model "github.com/go-skynet/LocalAI/pkg/model"
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -64,6 +64,7 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
whisperModel, err := o.Loader.BackendLoader( whisperModel, err := o.Loader.BackendLoader(
model.WithBackendString(model.WhisperBackend), model.WithBackendString(model.WhisperBackend),
model.WithModelFile(config.Model), model.WithModelFile(config.Model),
model.WithContext(o.Context),
model.WithThreads(uint32(config.Threads)), model.WithThreads(uint32(config.Threads)),
model.WithAssetDir(o.AssetsDestination)) model.WithAssetDir(o.AssetsDestination))
if err != nil { if err != nil {
@ -74,18 +75,17 @@ func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fibe
return fmt.Errorf("could not load whisper model") return fmt.Errorf("could not load whisper model")
} }
w, ok := whisperModel.(whisper.Model) tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
if !ok { Dst: dst,
return fmt.Errorf("loader returned non-whisper object") Language: input.Language,
} Threads: uint32(config.Threads),
})
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads))
if err != nil { if err != nil {
return err return err
} }
log.Debug().Msgf("Trascribed: %+v", tr) log.Debug().Msgf("Trascribed: %+v", tr)
// TODO: handle different outputs here // TODO: handle different outputs here
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) return c.Status(http.StatusOK).JSON(tr)
} }
} }

@ -0,0 +1,22 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
bert "github.com/go-skynet/LocalAI/pkg/grpc/llm/bert"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &bert.Embeddings{}); err != nil {
panic(err)
}
}

@ -0,0 +1,23 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
bloomz "github.com/go-skynet/LocalAI/pkg/grpc/llm/bloomz"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &bloomz.LLM{}); err != nil {
panic(err)
}
}

@ -0,0 +1,23 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &transformers.Falcon{}); err != nil {
panic(err)
}
}

@ -0,0 +1,23 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
langchain "github.com/go-skynet/LocalAI/pkg/grpc/llm/langchain"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &langchain.LLM{}); err != nil {
panic(err)
}
}

@ -0,0 +1,23 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
tts "github.com/go-skynet/LocalAI/pkg/grpc/tts"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &tts.Piper{}); err != nil {
panic(err)
}
}

@ -0,0 +1,23 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
rwkv "github.com/go-skynet/LocalAI/pkg/grpc/llm/rwkv"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &rwkv.LLM{}); err != nil {
panic(err)
}
}

@ -0,0 +1,23 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
image "github.com/go-skynet/LocalAI/pkg/grpc/image"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &image.StableDiffusion{}); err != nil {
panic(err)
}
}

@ -0,0 +1,23 @@
package main
// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"
transcribe "github.com/go-skynet/LocalAI/pkg/grpc/transcribe"
grpc "github.com/go-skynet/LocalAI/pkg/grpc"
)
var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)
func main() {
flag.Parse()
if err := grpc.StartServer(*addr, &transcribe.Whisper{}); err != nil {
panic(err)
}
}

@ -2,7 +2,9 @@ package main
import ( import (
"os" "os"
"os/signal"
"path/filepath" "path/filepath"
"syscall"
api "github.com/go-skynet/LocalAI/api" api "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/api/options"
@ -15,6 +17,13 @@ import (
func main() { func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
// clean up process
go func() {
c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
os.Exit(1)
}()
path, err := os.Getwd() path, err := os.Getwd()
if err != nil { if err != nil {

@ -0,0 +1,42 @@
package base
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
)
type Base struct {
}
func (llm *Base) Load(opts *pb.ModelOptions) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) {
return "", fmt.Errorf("unimplemented")
}
func (llm *Base) PredictStream(opts *pb.PredictOptions, results chan string) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return []float32{}, fmt.Errorf("unimplemented")
}
func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
return fmt.Errorf("unimplemented")
}
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) {
return api.Result{}, fmt.Errorf("unimplemented")
}
func (llm *Base) TTS(*pb.TTSRequest) error {
return fmt.Errorf("unimplemented")
}

@ -7,6 +7,7 @@ import (
"time" "time"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
) )
@ -28,7 +29,7 @@ func (c *Client) HealthCheck(ctx context.Context) bool {
return false return false
} }
defer conn.Close() defer conn.Close()
client := pb.NewLLMClient(conn) client := pb.NewBackendClient(conn)
// The healthcheck call shouldn't take long time // The healthcheck call shouldn't take long time
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
@ -53,7 +54,7 @@ func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...
return nil, err return nil, err
} }
defer conn.Close() defer conn.Close()
client := pb.NewLLMClient(conn) client := pb.NewBackendClient(conn)
return client.Embedding(ctx, in, opts...) return client.Embedding(ctx, in, opts...)
} }
@ -64,7 +65,7 @@ func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grp
return nil, err return nil, err
} }
defer conn.Close() defer conn.Close()
client := pb.NewLLMClient(conn) client := pb.NewBackendClient(conn)
return client.Predict(ctx, in, opts...) return client.Predict(ctx, in, opts...)
} }
@ -75,7 +76,7 @@ func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grp
return nil, err return nil, err
} }
defer conn.Close() defer conn.Close()
client := pb.NewLLMClient(conn) client := pb.NewBackendClient(conn)
return client.LoadModel(ctx, in, opts...) return client.LoadModel(ctx, in, opts...)
} }
@ -85,7 +86,7 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
return err return err
} }
defer conn.Close() defer conn.Close()
client := pb.NewLLMClient(conn) client := pb.NewBackendClient(conn)
stream, err := client.PredictStream(ctx, in, opts...) stream, err := client.PredictStream(ctx, in, opts...)
if err != nil { if err != nil {
@ -107,3 +108,53 @@ func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f fun
return nil return nil
} }
func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) {
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.GenerateImage(ctx, in, opts...)
}
func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) {
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
return client.TTS(ctx, in, opts...)
}
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) {
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, err
}
defer conn.Close()
client := pb.NewBackendClient(conn)
res, err := client.AudioTranscription(ctx, in, opts...)
if err != nil {
return nil, err
}
tresult := &api.Result{}
for _, s := range res.Segments {
tks := []int{}
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tresult.Segments = append(tresult.Segments,
api.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),
End: time.Duration(s.End),
Tokens: tks,
})
}
tresult.Text = res.Text
return tresult, err
}

@ -0,0 +1,33 @@
package image
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
)
type StableDiffusion struct {
base.Base
stablediffusion *stablediffusion.StableDiffusion
}
func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error {
var err error
// Note: the Model here is a path to a directory containing the model files
sd.stablediffusion, err = stablediffusion.New(opts.Model)
return err
}
func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error {
return sd.stablediffusion.GenerateImage(
int(opts.Height),
int(opts.Width),
int(opts.Mode),
int(opts.Step),
int(opts.Seed),
opts.PositivePrompt,
opts.NegativePrompt,
opts.Dst)
}

@ -2,11 +2,15 @@ package grpc
import ( import (
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
) )
type LLM interface { type LLM interface {
Predict(*pb.PredictOptions) (string, error) Predict(*pb.PredictOptions) (string, error)
PredictStream(*pb.PredictOptions, chan string) PredictStream(*pb.PredictOptions, chan string) error
Load(*pb.ModelOptions) error Load(*pb.ModelOptions) error
Embeddings(*pb.PredictOptions) ([]float32, error) Embeddings(*pb.PredictOptions) ([]float32, error)
GenerateImage(*pb.GenerateImageRequest) error
AudioTranscription(*pb.TranscriptRequest) (api.Result, error)
TTS(*pb.TTSRequest) error
} }

@ -0,0 +1,33 @@
package bert
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
bert "github.com/go-skynet/go-bert.cpp"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)
type Embeddings struct {
base.Base
bert *bert.Bert
}
func (llm *Embeddings) Load(opts *pb.ModelOptions) error {
model, err := bert.New(opts.Model)
llm.bert = model
return err
}
func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
if len(opts.EmbeddingTokens) > 0 {
tokens := []int{}
for _, t := range opts.EmbeddingTokens {
tokens = append(tokens, int(t))
}
return llm.bert.TokenEmbeddings(tokens, bert.SetThreads(int(opts.Threads)))
}
return llm.bert.Embeddings(opts.Embeddings, bert.SetThreads(int(opts.Threads)))
}

@ -0,0 +1,59 @@
package bloomz
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/bloomz.cpp"
)
type LLM struct {
base.Base
bloomz *bloomz.Bloomz
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
model, err := bloomz.New(opts.Model)
llm.bloomz = model
return err
}
func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption {
predictOptions := []bloomz.PredictOption{
bloomz.SetTemperature(float64(opts.Temperature)),
bloomz.SetTopP(float64(opts.TopP)),
bloomz.SetTopK(int(opts.TopK)),
bloomz.SetTokens(int(opts.Tokens)),
bloomz.SetThreads(int(opts.Threads)),
}
if opts.Seed != 0 {
predictOptions = append(predictOptions, bloomz.SetSeed(int(opts.Seed)))
}
return predictOptions
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() {
res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...)
if err != nil {
fmt.Println("err: ", err)
}
results <- res
close(results)
}()
return nil
}

@ -5,12 +5,15 @@ package falcon
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
ggllm "github.com/mudler/go-ggllm.cpp" ggllm "github.com/mudler/go-ggllm.cpp"
) )
type LLM struct { type LLM struct {
base.Base
falcon *ggllm.Falcon falcon *ggllm.Falcon
} }
@ -42,10 +45,6 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return nil, fmt.Errorf("not implemented")
}
func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption {
predictOptions := []ggllm.PredictOption{ predictOptions := []ggllm.PredictOption{
ggllm.SetTemperature(float64(opts.Temperature)), ggllm.SetTemperature(float64(opts.Temperature)),
@ -122,7 +121,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
predictOptions := buildPredictOptions(opts) predictOptions := buildPredictOptions(opts)
predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool {
@ -140,4 +139,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) {
} }
close(results) close(results)
}() }()
return nil
} }

@ -5,11 +5,14 @@ package gpt4all
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang"
) )
type LLM struct { type LLM struct {
base.Base
gpt4all *gpt4all.Model gpt4all *gpt4all.Model
} }
@ -39,7 +42,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
predictOptions := buildPredictOptions(opts) predictOptions := buildPredictOptions(opts)
go func() { go func() {
@ -54,8 +57,6 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) {
llm.gpt4all.SetTokenCallback(nil) llm.gpt4all.SetTokenCallback(nil)
close(results) close(results)
}() }()
}
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { return nil
return []float32{}, fmt.Errorf("not implemented")
} }

@ -0,0 +1,58 @@
package langchain
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/LocalAI/pkg/langchain"
)
type LLM struct {
base.Base
langchain *langchain.HuggingFace
model string
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
llm.langchain, _ = langchain.NewHuggingFace(opts.Model)
llm.model = opts.Model
return nil
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
o := []langchain.PredictOption{
langchain.SetModel(llm.model),
langchain.SetMaxTokens(int(opts.Tokens)),
langchain.SetTemperature(float64(opts.Temperature)),
langchain.SetStopWords(opts.StopPrompts),
}
pred, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...)
if err != nil {
return "", err
}
return pred.Completion, nil
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
o := []langchain.PredictOption{
langchain.SetModel(llm.model),
langchain.SetMaxTokens(int(opts.Tokens)),
langchain.SetTemperature(float64(opts.Temperature)),
langchain.SetStopWords(opts.StopPrompts),
}
go func() {
res, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...)
if err != nil {
fmt.Println("err: ", err)
}
results <- res.Completion
close(results)
}()
return nil
}

@ -5,11 +5,14 @@ package llama
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
"github.com/go-skynet/go-llama.cpp" "github.com/go-skynet/go-llama.cpp"
) )
type LLM struct { type LLM struct {
base.Base
llama *llama.LLama llama *llama.LLama
} }
@ -133,7 +136,7 @@ func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
predictOptions := buildPredictOptions(opts) predictOptions := buildPredictOptions(opts)
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
@ -148,6 +151,8 @@ func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) {
} }
close(results) close(results)
}() }()
return nil
} }
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {

@ -0,0 +1,71 @@
package rwkv
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"
"path/filepath"
"github.com/donomii/go-rwkv.cpp"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)
const tokenizerSuffix = ".tokenizer.json"
type LLM struct {
base.Base
rwkv *rwkv.RwkvState
}
func (llm *LLM) Load(opts *pb.ModelOptions) error {
modelPath := filepath.Dir(opts.Model)
modelFile := filepath.Base(opts.Model)
model := rwkv.LoadFiles(opts.Model, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads()))
if model == nil {
return fmt.Errorf("could not load model")
}
llm.rwkv = model
return nil
}
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
stopWord := "\n"
if len(opts.StopPrompts) > 0 {
stopWord = opts.StopPrompts[0]
}
if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
return "", err
}
response := llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), nil)
return response, nil
}
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() {
stopWord := "\n"
if len(opts.StopPrompts) > 0 {
stopWord = opts.StopPrompts[0]
}
if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
fmt.Println("Error processing input: ", err)
return
}
llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), func(s string) bool {
results <- s
return true
})
close(results)
}()
return nil
}

@ -5,12 +5,15 @@ package transformers
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
transformers "github.com/go-skynet/go-ggml-transformers.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp"
) )
type Dolly struct { type Dolly struct {
base.Base
dolly *transformers.Dolly dolly *transformers.Dolly
} }
@ -20,16 +23,12 @@ func (llm *Dolly) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (llm *Dolly) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return nil, fmt.Errorf("not implemented")
}
func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) { func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) {
return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
// fallback to Predict // fallback to Predict
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() { go func() {
res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -39,4 +38,6 @@ func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) {
results <- res results <- res
close(results) close(results)
}() }()
return nil
} }

@ -0,0 +1,43 @@
package transformers
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
transformers "github.com/go-skynet/go-ggml-transformers.cpp"
)
type Falcon struct {
base.Base
falcon *transformers.Falcon
}
func (llm *Falcon) Load(opts *pb.ModelOptions) error {
model, err := transformers.NewFalcon(opts.Model)
llm.falcon = model
return err
}
func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) {
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
}
// fallback to Predict
func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() {
res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
if err != nil {
fmt.Println("err: ", err)
}
results <- res
close(results)
}()
return nil
}

@ -5,12 +5,15 @@ package transformers
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
transformers "github.com/go-skynet/go-ggml-transformers.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp"
) )
type GPT2 struct { type GPT2 struct {
base.Base
gpt2 *transformers.GPT2 gpt2 *transformers.GPT2
} }
@ -20,16 +23,12 @@ func (llm *GPT2) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (llm *GPT2) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return nil, fmt.Errorf("not implemented")
}
func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) { func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) {
return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
// fallback to Predict // fallback to Predict
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() { go func() {
res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -39,4 +38,5 @@ func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) {
results <- res results <- res
close(results) close(results)
}() }()
return nil
} }

@ -5,12 +5,15 @@ package transformers
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
transformers "github.com/go-skynet/go-ggml-transformers.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp"
) )
type GPTJ struct { type GPTJ struct {
base.Base
gptj *transformers.GPTJ gptj *transformers.GPTJ
} }
@ -20,16 +23,12 @@ func (llm *GPTJ) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (llm *GPTJ) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return nil, fmt.Errorf("not implemented")
}
func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) { func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) {
return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
// fallback to Predict // fallback to Predict
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() { go func() {
res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -39,4 +38,5 @@ func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) {
results <- res results <- res
close(results) close(results)
}() }()
return nil
} }

@ -5,12 +5,15 @@ package transformers
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
transformers "github.com/go-skynet/go-ggml-transformers.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp"
) )
type GPTNeoX struct { type GPTNeoX struct {
base.Base
gptneox *transformers.GPTNeoX gptneox *transformers.GPTNeoX
} }
@ -20,16 +23,12 @@ func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (llm *GPTNeoX) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return nil, fmt.Errorf("not implemented")
}
func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) { func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) {
return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
// fallback to Predict // fallback to Predict
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() { go func() {
res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -39,4 +38,5 @@ func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string)
results <- res results <- res
close(results) close(results)
}() }()
return nil
} }

@ -5,12 +5,15 @@ package transformers
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
transformers "github.com/go-skynet/go-ggml-transformers.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp"
) )
type MPT struct { type MPT struct {
base.Base
mpt *transformers.MPT mpt *transformers.MPT
} }
@ -20,16 +23,12 @@ func (llm *MPT) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (llm *MPT) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return nil, fmt.Errorf("not implemented")
}
func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) { func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) {
return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
// fallback to Predict // fallback to Predict
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() { go func() {
res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -39,4 +38,5 @@ func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) {
results <- res results <- res
close(results) close(results)
}() }()
return nil
} }

@ -5,12 +5,15 @@ package transformers
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
transformers "github.com/go-skynet/go-ggml-transformers.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp"
) )
type Replit struct { type Replit struct {
base.Base
replit *transformers.Replit replit *transformers.Replit
} }
@ -20,16 +23,12 @@ func (llm *Replit) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (llm *Replit) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return nil, fmt.Errorf("not implemented")
}
func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) { func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) {
return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
// fallback to Predict // fallback to Predict
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() { go func() {
res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -39,4 +38,5 @@ func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) {
results <- res results <- res
close(results) close(results)
}() }()
return nil
} }

@ -5,12 +5,15 @@ package transformers
import ( import (
"fmt" "fmt"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
transformers "github.com/go-skynet/go-ggml-transformers.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp"
) )
type Starcoder struct { type Starcoder struct {
base.Base
starcoder *transformers.Starcoder starcoder *transformers.Starcoder
} }
@ -20,16 +23,12 @@ func (llm *Starcoder) Load(opts *pb.ModelOptions) error {
return err return err
} }
func (llm *Starcoder) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
return nil, fmt.Errorf("not implemented")
}
func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) { func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) {
return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
} }
// fallback to Predict // fallback to Predict
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) { func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error {
go func() { go func() {
res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
@ -39,4 +38,6 @@ func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string
results <- res results <- res
close(results) close(results)
}() }()
return nil
} }

File diff suppressed because it is too large Load Diff

@ -2,17 +2,20 @@ syntax = "proto3";
option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto"; option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto";
option java_multiple_files = true; option java_multiple_files = true;
option java_package = "io.skynet.localai.llmserver"; option java_package = "io.skynet.localai.backend";
option java_outer_classname = "LLMServer"; option java_outer_classname = "LocalAIBackend";
package llm; package backend;
service LLM { service Backend {
rpc Health(HealthMessage) returns (Reply) {} rpc Health(HealthMessage) returns (Reply) {}
rpc Predict(PredictOptions) returns (Reply) {} rpc Predict(PredictOptions) returns (Reply) {}
rpc LoadModel(ModelOptions) returns (Result) {} rpc LoadModel(ModelOptions) returns (Result) {}
rpc PredictStream(PredictOptions) returns (stream Reply) {} rpc PredictStream(PredictOptions) returns (stream Reply) {}
rpc Embedding(PredictOptions) returns (EmbeddingResult) {} rpc Embedding(PredictOptions) returns (EmbeddingResult) {}
rpc GenerateImage(GenerateImageRequest) returns (Result) {}
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {}
rpc TTS(TTSRequest) returns (Result) {}
} }
message HealthMessage {} message HealthMessage {}
@ -88,3 +91,39 @@ message Result {
message EmbeddingResult { message EmbeddingResult {
repeated float embeddings = 1; repeated float embeddings = 1;
} }
message TranscriptRequest {
string dst = 2;
string language = 3;
uint32 threads = 4;
}
message TranscriptResult {
repeated TranscriptSegment segments = 1;
string text = 2;
}
message TranscriptSegment {
int32 id = 1;
int64 start = 2;
int64 end = 3;
string text = 4;
repeated int32 tokens = 5;
}
message GenerateImageRequest {
int32 height = 1;
int32 width = 2;
int32 mode = 3;
int32 step = 4;
int32 seed = 5;
string positive_prompt = 6;
string negative_prompt = 7;
string dst = 8;
}
message TTSRequest {
string text = 1;
string model = 2;
string dst = 3;
}

@ -0,0 +1,385 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.15.8
// source: pkg/grpc/proto/backend.proto
package proto
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// BackendClient is the client API for Backend service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type BackendClient interface {
Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error)
Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error)
LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error)
PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error)
Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error)
GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error)
AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error)
TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error)
}
type backendClient struct {
cc grpc.ClientConnInterface
}
func NewBackendClient(cc grpc.ClientConnInterface) BackendClient {
return &backendClient{cc}
}
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...)
if err != nil {
return nil, err
}
x := &backendPredictStreamClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type Backend_PredictStreamClient interface {
Recv() (*Reply, error)
grpc.ClientStream
}
type backendPredictStreamClient struct {
grpc.ClientStream
}
func (x *backendPredictStreamClient) Recv() (*Reply, error) {
m := new(Reply)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
out := new(EmbeddingResult)
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) {
out := new(TranscriptResult)
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// BackendServer is the server API for Backend service.
// All implementations must embed UnimplementedBackendServer
// for forward compatibility
type BackendServer interface {
Health(context.Context, *HealthMessage) (*Reply, error)
Predict(context.Context, *PredictOptions) (*Reply, error)
LoadModel(context.Context, *ModelOptions) (*Result, error)
PredictStream(*PredictOptions, Backend_PredictStreamServer) error
Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error)
GenerateImage(context.Context, *GenerateImageRequest) (*Result, error)
AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error)
TTS(context.Context, *TTSRequest) (*Result, error)
mustEmbedUnimplementedBackendServer()
}
// UnimplementedBackendServer must be embedded to have forward compatible implementations.
type UnimplementedBackendServer struct {
}
func (UnimplementedBackendServer) Health(context.Context, *HealthMessage) (*Reply, error) {
return nil, status.Errorf(codes.Unimplemented, "method Health not implemented")
}
func (UnimplementedBackendServer) Predict(context.Context, *PredictOptions) (*Reply, error) {
return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented")
}
func (UnimplementedBackendServer) LoadModel(context.Context, *ModelOptions) (*Result, error) {
return nil, status.Errorf(codes.Unimplemented, "method LoadModel not implemented")
}
func (UnimplementedBackendServer) PredictStream(*PredictOptions, Backend_PredictStreamServer) error {
return status.Errorf(codes.Unimplemented, "method PredictStream not implemented")
}
func (UnimplementedBackendServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) {
return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented")
}
func (UnimplementedBackendServer) GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) {
return nil, status.Errorf(codes.Unimplemented, "method GenerateImage not implemented")
}
func (UnimplementedBackendServer) AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) {
return nil, status.Errorf(codes.Unimplemented, "method AudioTranscription not implemented")
}
func (UnimplementedBackendServer) TTS(context.Context, *TTSRequest) (*Result, error) {
return nil, status.Errorf(codes.Unimplemented, "method TTS not implemented")
}
func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {}
// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to BackendServer will
// result in compilation errors.
type UnsafeBackendServer interface {
mustEmbedUnimplementedBackendServer()
}
func RegisterBackendServer(s grpc.ServiceRegistrar, srv BackendServer) {
s.RegisterService(&Backend_ServiceDesc, srv)
}
func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).Health(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Health",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Health(ctx, req.(*HealthMessage))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(PredictOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).Predict(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Predict",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ModelOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).LoadModel(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/LoadModel",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_PredictStream_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(PredictOptions)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(BackendServer).PredictStream(m, &backendPredictStreamServer{stream})
}
type Backend_PredictStreamServer interface {
Send(*Reply) error
grpc.ServerStream
}
type backendPredictStreamServer struct {
grpc.ServerStream
}
func (x *backendPredictStreamServer) Send(m *Reply) error {
return x.ServerStream.SendMsg(m)
}
func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(PredictOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).Embedding(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/Embedding",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GenerateImageRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).GenerateImage(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/GenerateImage",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(TranscriptRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).AudioTranscription(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/AudioTranscription",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest))
}
return interceptor(ctx, in, info, handler)
}
func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(TTSRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(BackendServer).TTS(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/backend.Backend/TTS",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest))
}
return interceptor(ctx, in, info, handler)
}
// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var Backend_ServiceDesc = grpc.ServiceDesc{
ServiceName: "backend.Backend",
HandlerType: (*BackendServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Health",
Handler: _Backend_Health_Handler,
},
{
MethodName: "Predict",
Handler: _Backend_Predict_Handler,
},
{
MethodName: "LoadModel",
Handler: _Backend_LoadModel_Handler,
},
{
MethodName: "Embedding",
Handler: _Backend_Embedding_Handler,
},
{
MethodName: "GenerateImage",
Handler: _Backend_GenerateImage_Handler,
},
{
MethodName: "AudioTranscription",
Handler: _Backend_AudioTranscription_Handler,
},
{
MethodName: "TTS",
Handler: _Backend_TTS_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "PredictStream",
Handler: _Backend_PredictStream_Handler,
ServerStreams: true,
},
},
Metadata: "pkg/grpc/proto/backend.proto",
}

@ -1,969 +0,0 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.26.0
// protoc v3.15.8
// source: pkg/grpc/proto/llmserver.proto
package proto
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type HealthMessage struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
}
func (x *HealthMessage) Reset() {
*x = HealthMessage{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *HealthMessage) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*HealthMessage) ProtoMessage() {}
func (x *HealthMessage) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use HealthMessage.ProtoReflect.Descriptor instead.
func (*HealthMessage) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{0}
}
// The request message containing the user's name.
type PredictOptions struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Prompt string `protobuf:"bytes,1,opt,name=Prompt,proto3" json:"Prompt,omitempty"`
Seed int32 `protobuf:"varint,2,opt,name=Seed,proto3" json:"Seed,omitempty"`
Threads int32 `protobuf:"varint,3,opt,name=Threads,proto3" json:"Threads,omitempty"`
Tokens int32 `protobuf:"varint,4,opt,name=Tokens,proto3" json:"Tokens,omitempty"`
TopK int32 `protobuf:"varint,5,opt,name=TopK,proto3" json:"TopK,omitempty"`
Repeat int32 `protobuf:"varint,6,opt,name=Repeat,proto3" json:"Repeat,omitempty"`
Batch int32 `protobuf:"varint,7,opt,name=Batch,proto3" json:"Batch,omitempty"`
NKeep int32 `protobuf:"varint,8,opt,name=NKeep,proto3" json:"NKeep,omitempty"`
Temperature float32 `protobuf:"fixed32,9,opt,name=Temperature,proto3" json:"Temperature,omitempty"`
Penalty float32 `protobuf:"fixed32,10,opt,name=Penalty,proto3" json:"Penalty,omitempty"`
F16KV bool `protobuf:"varint,11,opt,name=F16KV,proto3" json:"F16KV,omitempty"`
DebugMode bool `protobuf:"varint,12,opt,name=DebugMode,proto3" json:"DebugMode,omitempty"`
StopPrompts []string `protobuf:"bytes,13,rep,name=StopPrompts,proto3" json:"StopPrompts,omitempty"`
IgnoreEOS bool `protobuf:"varint,14,opt,name=IgnoreEOS,proto3" json:"IgnoreEOS,omitempty"`
TailFreeSamplingZ float32 `protobuf:"fixed32,15,opt,name=TailFreeSamplingZ,proto3" json:"TailFreeSamplingZ,omitempty"`
TypicalP float32 `protobuf:"fixed32,16,opt,name=TypicalP,proto3" json:"TypicalP,omitempty"`
FrequencyPenalty float32 `protobuf:"fixed32,17,opt,name=FrequencyPenalty,proto3" json:"FrequencyPenalty,omitempty"`
PresencePenalty float32 `protobuf:"fixed32,18,opt,name=PresencePenalty,proto3" json:"PresencePenalty,omitempty"`
Mirostat int32 `protobuf:"varint,19,opt,name=Mirostat,proto3" json:"Mirostat,omitempty"`
MirostatETA float32 `protobuf:"fixed32,20,opt,name=MirostatETA,proto3" json:"MirostatETA,omitempty"`
MirostatTAU float32 `protobuf:"fixed32,21,opt,name=MirostatTAU,proto3" json:"MirostatTAU,omitempty"`
PenalizeNL bool `protobuf:"varint,22,opt,name=PenalizeNL,proto3" json:"PenalizeNL,omitempty"`
LogitBias string `protobuf:"bytes,23,opt,name=LogitBias,proto3" json:"LogitBias,omitempty"`
MLock bool `protobuf:"varint,25,opt,name=MLock,proto3" json:"MLock,omitempty"`
MMap bool `protobuf:"varint,26,opt,name=MMap,proto3" json:"MMap,omitempty"`
PromptCacheAll bool `protobuf:"varint,27,opt,name=PromptCacheAll,proto3" json:"PromptCacheAll,omitempty"`
PromptCacheRO bool `protobuf:"varint,28,opt,name=PromptCacheRO,proto3" json:"PromptCacheRO,omitempty"`
Grammar string `protobuf:"bytes,29,opt,name=Grammar,proto3" json:"Grammar,omitempty"`
MainGPU string `protobuf:"bytes,30,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"`
TensorSplit string `protobuf:"bytes,31,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"`
TopP float32 `protobuf:"fixed32,32,opt,name=TopP,proto3" json:"TopP,omitempty"`
PromptCachePath string `protobuf:"bytes,33,opt,name=PromptCachePath,proto3" json:"PromptCachePath,omitempty"`
Debug bool `protobuf:"varint,34,opt,name=Debug,proto3" json:"Debug,omitempty"`
EmbeddingTokens []int32 `protobuf:"varint,35,rep,packed,name=EmbeddingTokens,proto3" json:"EmbeddingTokens,omitempty"`
Embeddings string `protobuf:"bytes,36,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"`
}
func (x *PredictOptions) Reset() {
*x = PredictOptions{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *PredictOptions) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*PredictOptions) ProtoMessage() {}
func (x *PredictOptions) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use PredictOptions.ProtoReflect.Descriptor instead.
func (*PredictOptions) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{1}
}
func (x *PredictOptions) GetPrompt() string {
if x != nil {
return x.Prompt
}
return ""
}
func (x *PredictOptions) GetSeed() int32 {
if x != nil {
return x.Seed
}
return 0
}
func (x *PredictOptions) GetThreads() int32 {
if x != nil {
return x.Threads
}
return 0
}
func (x *PredictOptions) GetTokens() int32 {
if x != nil {
return x.Tokens
}
return 0
}
func (x *PredictOptions) GetTopK() int32 {
if x != nil {
return x.TopK
}
return 0
}
func (x *PredictOptions) GetRepeat() int32 {
if x != nil {
return x.Repeat
}
return 0
}
func (x *PredictOptions) GetBatch() int32 {
if x != nil {
return x.Batch
}
return 0
}
func (x *PredictOptions) GetNKeep() int32 {
if x != nil {
return x.NKeep
}
return 0
}
func (x *PredictOptions) GetTemperature() float32 {
if x != nil {
return x.Temperature
}
return 0
}
func (x *PredictOptions) GetPenalty() float32 {
if x != nil {
return x.Penalty
}
return 0
}
func (x *PredictOptions) GetF16KV() bool {
if x != nil {
return x.F16KV
}
return false
}
func (x *PredictOptions) GetDebugMode() bool {
if x != nil {
return x.DebugMode
}
return false
}
func (x *PredictOptions) GetStopPrompts() []string {
if x != nil {
return x.StopPrompts
}
return nil
}
func (x *PredictOptions) GetIgnoreEOS() bool {
if x != nil {
return x.IgnoreEOS
}
return false
}
func (x *PredictOptions) GetTailFreeSamplingZ() float32 {
if x != nil {
return x.TailFreeSamplingZ
}
return 0
}
func (x *PredictOptions) GetTypicalP() float32 {
if x != nil {
return x.TypicalP
}
return 0
}
func (x *PredictOptions) GetFrequencyPenalty() float32 {
if x != nil {
return x.FrequencyPenalty
}
return 0
}
func (x *PredictOptions) GetPresencePenalty() float32 {
if x != nil {
return x.PresencePenalty
}
return 0
}
func (x *PredictOptions) GetMirostat() int32 {
if x != nil {
return x.Mirostat
}
return 0
}
func (x *PredictOptions) GetMirostatETA() float32 {
if x != nil {
return x.MirostatETA
}
return 0
}
func (x *PredictOptions) GetMirostatTAU() float32 {
if x != nil {
return x.MirostatTAU
}
return 0
}
func (x *PredictOptions) GetPenalizeNL() bool {
if x != nil {
return x.PenalizeNL
}
return false
}
func (x *PredictOptions) GetLogitBias() string {
if x != nil {
return x.LogitBias
}
return ""
}
func (x *PredictOptions) GetMLock() bool {
if x != nil {
return x.MLock
}
return false
}
func (x *PredictOptions) GetMMap() bool {
if x != nil {
return x.MMap
}
return false
}
func (x *PredictOptions) GetPromptCacheAll() bool {
if x != nil {
return x.PromptCacheAll
}
return false
}
func (x *PredictOptions) GetPromptCacheRO() bool {
if x != nil {
return x.PromptCacheRO
}
return false
}
func (x *PredictOptions) GetGrammar() string {
if x != nil {
return x.Grammar
}
return ""
}
func (x *PredictOptions) GetMainGPU() string {
if x != nil {
return x.MainGPU
}
return ""
}
func (x *PredictOptions) GetTensorSplit() string {
if x != nil {
return x.TensorSplit
}
return ""
}
func (x *PredictOptions) GetTopP() float32 {
if x != nil {
return x.TopP
}
return 0
}
func (x *PredictOptions) GetPromptCachePath() string {
if x != nil {
return x.PromptCachePath
}
return ""
}
func (x *PredictOptions) GetDebug() bool {
if x != nil {
return x.Debug
}
return false
}
func (x *PredictOptions) GetEmbeddingTokens() []int32 {
if x != nil {
return x.EmbeddingTokens
}
return nil
}
func (x *PredictOptions) GetEmbeddings() string {
if x != nil {
return x.Embeddings
}
return ""
}
// The response message containing the result
type Reply struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
}
func (x *Reply) Reset() {
*x = Reply{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[2]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Reply) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Reply) ProtoMessage() {}
func (x *Reply) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[2]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Reply.ProtoReflect.Descriptor instead.
func (*Reply) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{2}
}
func (x *Reply) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
type ModelOptions struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Model string `protobuf:"bytes,1,opt,name=Model,proto3" json:"Model,omitempty"`
ContextSize int32 `protobuf:"varint,2,opt,name=ContextSize,proto3" json:"ContextSize,omitempty"`
Seed int32 `protobuf:"varint,3,opt,name=Seed,proto3" json:"Seed,omitempty"`
NBatch int32 `protobuf:"varint,4,opt,name=NBatch,proto3" json:"NBatch,omitempty"`
F16Memory bool `protobuf:"varint,5,opt,name=F16Memory,proto3" json:"F16Memory,omitempty"`
MLock bool `protobuf:"varint,6,opt,name=MLock,proto3" json:"MLock,omitempty"`
MMap bool `protobuf:"varint,7,opt,name=MMap,proto3" json:"MMap,omitempty"`
VocabOnly bool `protobuf:"varint,8,opt,name=VocabOnly,proto3" json:"VocabOnly,omitempty"`
LowVRAM bool `protobuf:"varint,9,opt,name=LowVRAM,proto3" json:"LowVRAM,omitempty"`
Embeddings bool `protobuf:"varint,10,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"`
NUMA bool `protobuf:"varint,11,opt,name=NUMA,proto3" json:"NUMA,omitempty"`
NGPULayers int32 `protobuf:"varint,12,opt,name=NGPULayers,proto3" json:"NGPULayers,omitempty"`
MainGPU string `protobuf:"bytes,13,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"`
TensorSplit string `protobuf:"bytes,14,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"`
Threads int32 `protobuf:"varint,15,opt,name=Threads,proto3" json:"Threads,omitempty"`
LibrarySearchPath string `protobuf:"bytes,16,opt,name=LibrarySearchPath,proto3" json:"LibrarySearchPath,omitempty"`
}
func (x *ModelOptions) Reset() {
*x = ModelOptions{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[3]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *ModelOptions) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*ModelOptions) ProtoMessage() {}
func (x *ModelOptions) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[3]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use ModelOptions.ProtoReflect.Descriptor instead.
func (*ModelOptions) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{3}
}
func (x *ModelOptions) GetModel() string {
if x != nil {
return x.Model
}
return ""
}
func (x *ModelOptions) GetContextSize() int32 {
if x != nil {
return x.ContextSize
}
return 0
}
func (x *ModelOptions) GetSeed() int32 {
if x != nil {
return x.Seed
}
return 0
}
func (x *ModelOptions) GetNBatch() int32 {
if x != nil {
return x.NBatch
}
return 0
}
func (x *ModelOptions) GetF16Memory() bool {
if x != nil {
return x.F16Memory
}
return false
}
func (x *ModelOptions) GetMLock() bool {
if x != nil {
return x.MLock
}
return false
}
func (x *ModelOptions) GetMMap() bool {
if x != nil {
return x.MMap
}
return false
}
func (x *ModelOptions) GetVocabOnly() bool {
if x != nil {
return x.VocabOnly
}
return false
}
func (x *ModelOptions) GetLowVRAM() bool {
if x != nil {
return x.LowVRAM
}
return false
}
func (x *ModelOptions) GetEmbeddings() bool {
if x != nil {
return x.Embeddings
}
return false
}
func (x *ModelOptions) GetNUMA() bool {
if x != nil {
return x.NUMA
}
return false
}
func (x *ModelOptions) GetNGPULayers() int32 {
if x != nil {
return x.NGPULayers
}
return 0
}
func (x *ModelOptions) GetMainGPU() string {
if x != nil {
return x.MainGPU
}
return ""
}
func (x *ModelOptions) GetTensorSplit() string {
if x != nil {
return x.TensorSplit
}
return ""
}
func (x *ModelOptions) GetThreads() int32 {
if x != nil {
return x.Threads
}
return 0
}
func (x *ModelOptions) GetLibrarySearchPath() string {
if x != nil {
return x.LibrarySearchPath
}
return ""
}
type Result struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"`
Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"`
}
func (x *Result) Reset() {
*x = Result{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Result) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Result) ProtoMessage() {}
func (x *Result) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Result.ProtoReflect.Descriptor instead.
func (*Result) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{4}
}
func (x *Result) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
func (x *Result) GetSuccess() bool {
if x != nil {
return x.Success
}
return false
}
type EmbeddingResult struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Embeddings []float32 `protobuf:"fixed32,1,rep,packed,name=embeddings,proto3" json:"embeddings,omitempty"`
}
func (x *EmbeddingResult) Reset() {
*x = EmbeddingResult{}
if protoimpl.UnsafeEnabled {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *EmbeddingResult) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*EmbeddingResult) ProtoMessage() {}
func (x *EmbeddingResult) ProtoReflect() protoreflect.Message {
mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use EmbeddingResult.ProtoReflect.Descriptor instead.
func (*EmbeddingResult) Descriptor() ([]byte, []int) {
return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{5}
}
func (x *EmbeddingResult) GetEmbeddings() []float32 {
if x != nil {
return x.Embeddings
}
return nil
}
var File_pkg_grpc_proto_llmserver_proto protoreflect.FileDescriptor
var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{
0x0a, 0x1e, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x2f, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x12, 0x03, 0x6c, 0x6c, 0x6d, 0x22, 0x0f, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xa0, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69,
0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x72, 0x6f,
0x6d, 0x70, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x72, 0x6f, 0x6d, 0x70,
0x74, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52,
0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73,
0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12,
0x16, 0x0a, 0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52,
0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, 0x4b, 0x18,
0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x4b, 0x12, 0x16, 0x0a, 0x06, 0x52,
0x65, 0x70, 0x65, 0x61, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x52, 0x65, 0x70,
0x65, 0x61, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x18, 0x07, 0x20, 0x01,
0x28, 0x05, 0x52, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x4b, 0x65,
0x65, 0x70, 0x18, 0x08, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x4e, 0x4b, 0x65, 0x65, 0x70, 0x12,
0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72, 0x65, 0x18, 0x09,
0x20, 0x01, 0x28, 0x02, 0x52, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72,
0x65, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x0a, 0x20, 0x01,
0x28, 0x02, 0x52, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x46,
0x31, 0x36, 0x4b, 0x56, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x46, 0x31, 0x36, 0x4b,
0x56, 0x12, 0x1c, 0x0a, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x18, 0x0c,
0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x12,
0x20, 0x0a, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x73, 0x18, 0x0d,
0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74,
0x73, 0x12, 0x1c, 0x0a, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, 0x53, 0x18, 0x0e,
0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, 0x53, 0x12,
0x2c, 0x0a, 0x11, 0x54, 0x61, 0x69, 0x6c, 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c,
0x69, 0x6e, 0x67, 0x5a, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x02, 0x52, 0x11, 0x54, 0x61, 0x69, 0x6c,
0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x5a, 0x12, 0x1a, 0x0a,
0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x18, 0x10, 0x20, 0x01, 0x28, 0x02, 0x52,
0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x12, 0x2a, 0x0a, 0x10, 0x46, 0x72, 0x65,
0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x11, 0x20,
0x01, 0x28, 0x02, 0x52, 0x10, 0x46, 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x50, 0x65,
0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x63,
0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x12, 0x20, 0x01, 0x28, 0x02, 0x52, 0x0f,
0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x63, 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12,
0x1a, 0x0a, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x18, 0x13, 0x20, 0x01, 0x28,
0x05, 0x52, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x12, 0x20, 0x0a, 0x0b, 0x4d,
0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x18, 0x14, 0x20, 0x01, 0x28, 0x02,
0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x12, 0x20, 0x0a,
0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, 0x55, 0x18, 0x15, 0x20, 0x01,
0x28, 0x02, 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, 0x55, 0x12,
0x1e, 0x0a, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x18, 0x16, 0x20,
0x01, 0x28, 0x08, 0x52, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x12,
0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x18, 0x17, 0x20, 0x01,
0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x14, 0x0a,
0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c,
0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28,
0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70,
0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52,
0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12,
0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f,
0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61,
0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72,
0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12,
0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09,
0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e,
0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b,
0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54,
0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12,
0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61,
0x74, 0x68, 0x18, 0x21, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74,
0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62,
0x75, 0x67, 0x18, 0x22, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x12,
0x28, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65,
0x6e, 0x73, 0x18, 0x23, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64,
0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62,
0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x45,
0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70,
0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20,
0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xca, 0x03, 0x0a,
0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x14, 0x0a,
0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4d, 0x6f,
0x64, 0x65, 0x6c, 0x12, 0x20, 0x0a, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, 0x53, 0x69,
0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78,
0x74, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x03, 0x20,
0x01, 0x28, 0x05, 0x52, 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x42, 0x61,
0x74, 0x63, 0x68, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x4e, 0x42, 0x61, 0x74, 0x63,
0x68, 0x12, 0x1c, 0x0a, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x18, 0x05,
0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, 0x12,
0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05,
0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x07, 0x20,
0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x1c, 0x0a, 0x09, 0x56, 0x6f, 0x63,
0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x56, 0x6f,
0x63, 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x4c, 0x6f, 0x77, 0x56, 0x52,
0x41, 0x4d, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x4c, 0x6f, 0x77, 0x56, 0x52, 0x41,
0x4d, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18,
0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67,
0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52,
0x04, 0x4e, 0x55, 0x4d, 0x41, 0x12, 0x1e, 0x0a, 0x0a, 0x4e, 0x47, 0x50, 0x55, 0x4c, 0x61, 0x79,
0x65, 0x72, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0a, 0x4e, 0x47, 0x50, 0x55, 0x4c,
0x61, 0x79, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55,
0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12,
0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x0e,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69,
0x74, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x18, 0x0f, 0x20, 0x01,
0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12, 0x2c, 0x0a, 0x11, 0x4c,
0x69, 0x62, 0x72, 0x61, 0x72, 0x79, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, 0x74, 0x68,
0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x11, 0x4c, 0x69, 0x62, 0x72, 0x61, 0x72, 0x79, 0x53,
0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x06, 0x52, 0x65, 0x73,
0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a,
0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07,
0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, 0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64,
0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x6d,
0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x02, 0x52, 0x0a,
0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x32, 0xfe, 0x01, 0x0a, 0x03, 0x4c,
0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x12, 0x2e, 0x6c,
0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65,
0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2c,
0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e,
0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a,
0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x09,
0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x11, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e,
0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0b, 0x2e, 0x6c,
0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x0d, 0x50,
0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x6c,
0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e,
0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30,
0x01, 0x12, 0x38, 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x13,
0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69,
0x6f, 0x6e, 0x73, 0x1a, 0x14, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64,
0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, 0x42, 0x57, 0x0a, 0x1b, 0x69,
0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69,
0x2e, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53,
0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f,
0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70,
0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_pkg_grpc_proto_llmserver_proto_rawDescOnce sync.Once
file_pkg_grpc_proto_llmserver_proto_rawDescData = file_pkg_grpc_proto_llmserver_proto_rawDesc
)
func file_pkg_grpc_proto_llmserver_proto_rawDescGZIP() []byte {
file_pkg_grpc_proto_llmserver_proto_rawDescOnce.Do(func() {
file_pkg_grpc_proto_llmserver_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_proto_llmserver_proto_rawDescData)
})
return file_pkg_grpc_proto_llmserver_proto_rawDescData
}
var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_pkg_grpc_proto_llmserver_proto_goTypes = []interface{}{
(*HealthMessage)(nil), // 0: llm.HealthMessage
(*PredictOptions)(nil), // 1: llm.PredictOptions
(*Reply)(nil), // 2: llm.Reply
(*ModelOptions)(nil), // 3: llm.ModelOptions
(*Result)(nil), // 4: llm.Result
(*EmbeddingResult)(nil), // 5: llm.EmbeddingResult
}
var file_pkg_grpc_proto_llmserver_proto_depIdxs = []int32{
0, // 0: llm.LLM.Health:input_type -> llm.HealthMessage
1, // 1: llm.LLM.Predict:input_type -> llm.PredictOptions
3, // 2: llm.LLM.LoadModel:input_type -> llm.ModelOptions
1, // 3: llm.LLM.PredictStream:input_type -> llm.PredictOptions
1, // 4: llm.LLM.Embedding:input_type -> llm.PredictOptions
2, // 5: llm.LLM.Health:output_type -> llm.Reply
2, // 6: llm.LLM.Predict:output_type -> llm.Reply
4, // 7: llm.LLM.LoadModel:output_type -> llm.Result
2, // 8: llm.LLM.PredictStream:output_type -> llm.Reply
5, // 9: llm.LLM.Embedding:output_type -> llm.EmbeddingResult
5, // [5:10] is the sub-list for method output_type
0, // [0:5] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_pkg_grpc_proto_llmserver_proto_init() }
func file_pkg_grpc_proto_llmserver_proto_init() {
if File_pkg_grpc_proto_llmserver_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_pkg_grpc_proto_llmserver_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*HealthMessage); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_llmserver_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*PredictOptions); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_llmserver_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Reply); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_llmserver_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*ModelOptions); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_llmserver_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Result); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_pkg_grpc_proto_llmserver_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*EmbeddingResult); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_pkg_grpc_proto_llmserver_proto_rawDesc,
NumEnums: 0,
NumMessages: 6,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_pkg_grpc_proto_llmserver_proto_goTypes,
DependencyIndexes: file_pkg_grpc_proto_llmserver_proto_depIdxs,
MessageInfos: file_pkg_grpc_proto_llmserver_proto_msgTypes,
}.Build()
File_pkg_grpc_proto_llmserver_proto = out.File
file_pkg_grpc_proto_llmserver_proto_rawDesc = nil
file_pkg_grpc_proto_llmserver_proto_goTypes = nil
file_pkg_grpc_proto_llmserver_proto_depIdxs = nil
}

@ -1,277 +0,0 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.15.8
// source: pkg/grpc/proto/llmserver.proto
package proto
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// LLMClient is the client API for LLM service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type LLMClient interface {
Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error)
Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error)
LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error)
PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error)
Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error)
}
type lLMClient struct {
cc grpc.ClientConnInterface
}
func NewLLMClient(cc grpc.ClientConnInterface) LLMClient {
return &lLMClient{cc}
}
func (c *lLMClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, "/llm.LLM/Health", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *lLMClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) {
out := new(Reply)
err := c.cc.Invoke(ctx, "/llm.LLM/Predict", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *lLMClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) {
out := new(Result)
err := c.cc.Invoke(ctx, "/llm.LLM/LoadModel", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *lLMClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error) {
stream, err := c.cc.NewStream(ctx, &LLM_ServiceDesc.Streams[0], "/llm.LLM/PredictStream", opts...)
if err != nil {
return nil, err
}
x := &lLMPredictStreamClient{stream}
if err := x.ClientStream.SendMsg(in); err != nil {
return nil, err
}
if err := x.ClientStream.CloseSend(); err != nil {
return nil, err
}
return x, nil
}
type LLM_PredictStreamClient interface {
Recv() (*Reply, error)
grpc.ClientStream
}
type lLMPredictStreamClient struct {
grpc.ClientStream
}
func (x *lLMPredictStreamClient) Recv() (*Reply, error) {
m := new(Reply)
if err := x.ClientStream.RecvMsg(m); err != nil {
return nil, err
}
return m, nil
}
func (c *lLMClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) {
out := new(EmbeddingResult)
err := c.cc.Invoke(ctx, "/llm.LLM/Embedding", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// LLMServer is the server API for LLM service.
// All implementations must embed UnimplementedLLMServer
// for forward compatibility
type LLMServer interface {
Health(context.Context, *HealthMessage) (*Reply, error)
Predict(context.Context, *PredictOptions) (*Reply, error)
LoadModel(context.Context, *ModelOptions) (*Result, error)
PredictStream(*PredictOptions, LLM_PredictStreamServer) error
Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error)
mustEmbedUnimplementedLLMServer()
}
// UnimplementedLLMServer must be embedded to have forward compatible implementations.
type UnimplementedLLMServer struct {
}
func (UnimplementedLLMServer) Health(context.Context, *HealthMessage) (*Reply, error) {
return nil, status.Errorf(codes.Unimplemented, "method Health not implemented")
}
func (UnimplementedLLMServer) Predict(context.Context, *PredictOptions) (*Reply, error) {
return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented")
}
func (UnimplementedLLMServer) LoadModel(context.Context, *ModelOptions) (*Result, error) {
return nil, status.Errorf(codes.Unimplemented, "method LoadModel not implemented")
}
func (UnimplementedLLMServer) PredictStream(*PredictOptions, LLM_PredictStreamServer) error {
return status.Errorf(codes.Unimplemented, "method PredictStream not implemented")
}
func (UnimplementedLLMServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) {
return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented")
}
func (UnimplementedLLMServer) mustEmbedUnimplementedLLMServer() {}
// UnsafeLLMServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to LLMServer will
// result in compilation errors.
type UnsafeLLMServer interface {
mustEmbedUnimplementedLLMServer()
}
func RegisterLLMServer(s grpc.ServiceRegistrar, srv LLMServer) {
s.RegisterService(&LLM_ServiceDesc, srv)
}
func _LLM_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(HealthMessage)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(LLMServer).Health(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/llm.LLM/Health",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(LLMServer).Health(ctx, req.(*HealthMessage))
}
return interceptor(ctx, in, info, handler)
}
func _LLM_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(PredictOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(LLMServer).Predict(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/llm.LLM/Predict",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(LLMServer).Predict(ctx, req.(*PredictOptions))
}
return interceptor(ctx, in, info, handler)
}
func _LLM_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ModelOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(LLMServer).LoadModel(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/llm.LLM/LoadModel",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(LLMServer).LoadModel(ctx, req.(*ModelOptions))
}
return interceptor(ctx, in, info, handler)
}
func _LLM_PredictStream_Handler(srv interface{}, stream grpc.ServerStream) error {
m := new(PredictOptions)
if err := stream.RecvMsg(m); err != nil {
return err
}
return srv.(LLMServer).PredictStream(m, &lLMPredictStreamServer{stream})
}
type LLM_PredictStreamServer interface {
Send(*Reply) error
grpc.ServerStream
}
type lLMPredictStreamServer struct {
grpc.ServerStream
}
func (x *lLMPredictStreamServer) Send(m *Reply) error {
return x.ServerStream.SendMsg(m)
}
func _LLM_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(PredictOptions)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(LLMServer).Embedding(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/llm.LLM/Embedding",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(LLMServer).Embedding(ctx, req.(*PredictOptions))
}
return interceptor(ctx, in, info, handler)
}
// LLM_ServiceDesc is the grpc.ServiceDesc for LLM service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var LLM_ServiceDesc = grpc.ServiceDesc{
ServiceName: "llm.LLM",
HandlerType: (*LLMServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Health",
Handler: _LLM_Health_Handler,
},
{
MethodName: "Predict",
Handler: _LLM_Predict_Handler,
},
{
MethodName: "LoadModel",
Handler: _LLM_LoadModel_Handler,
},
{
MethodName: "Embedding",
Handler: _LLM_Embedding_Handler,
},
},
Streams: []grpc.StreamDesc{
{
StreamName: "PredictStream",
Handler: _LLM_PredictStream_Handler,
ServerStreams: true,
},
},
Metadata: "pkg/grpc/proto/llmserver.proto",
}

@ -21,7 +21,7 @@ import (
// server is used to implement helloworld.GreeterServer. // server is used to implement helloworld.GreeterServer.
type server struct { type server struct {
pb.UnimplementedLLMServer pb.UnimplementedBackendServer
llm LLM llm LLM
} }
@ -51,7 +51,48 @@ func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply,
return &pb.Reply{Message: result}, err return &pb.Reply{Message: result}, err
} }
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.LLM_PredictStreamServer) error { func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) {
err := s.llm.GenerateImage(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Image generated", Success: true}, nil
}
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
err := s.llm.TTS(in)
if err != nil {
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
}
return &pb.Result{Message: "Audio generated", Success: true}, nil
}
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
result, err := s.llm.AudioTranscription(in)
if err != nil {
return nil, err
}
tresult := &pb.TranscriptResult{}
for _, s := range result.Segments {
tks := []int32{}
for _, t := range s.Tokens {
tks = append(tks, int32(t))
}
tresult.Segments = append(tresult.Segments,
&pb.TranscriptSegment{
Text: s.Text,
Id: int32(s.Id),
Start: int64(s.Start),
End: int64(s.End),
Tokens: tks,
})
}
tresult.Text = result.Text
return tresult, nil
}
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
resultChan := make(chan string) resultChan := make(chan string)
@ -75,7 +116,7 @@ func StartServer(address string, model LLM) error {
return err return err
} }
s := grpc.NewServer() s := grpc.NewServer()
pb.RegisterLLMServer(s, &server{llm: model}) pb.RegisterBackendServer(s, &server{llm: model})
log.Printf("gRPC Server listening at %v", lis.Addr()) log.Printf("gRPC Server listening at %v", lis.Addr())
if err := s.Serve(lis); err != nil { if err := s.Serve(lis); err != nil {
return err return err

@ -0,0 +1,27 @@
package transcribe
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
whisperutil "github.com/go-skynet/LocalAI/pkg/grpc/whisper"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
)
type Whisper struct {
base.Base
whisper whisper.Model
}
func (sd *Whisper) Load(opts *pb.ModelOptions) error {
// Note: the Model here is a path to a directory containing the model files
w, err := whisper.New(opts.Model)
sd.whisper = w
return err
}
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (api.Result, error) {
return whisperutil.Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
}

@ -0,0 +1,44 @@
package tts
// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"os"
"github.com/go-skynet/LocalAI/pkg/grpc/base"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
piper "github.com/mudler/go-piper"
)
type Piper struct {
base.Base
piper *PiperB
}
func (sd *Piper) Load(opts *pb.ModelOptions) error {
var err error
// Note: the Model here is a path to a directory containing the model files
sd.piper, err = New(opts.LibrarySearchPath)
return err
}
func (sd *Piper) TTS(opts *pb.TTSRequest) error {
return sd.piper.TTS(opts.Text, opts.Model, opts.Dst)
}
type PiperB struct {
assetDir string
}
func New(assetDir string) (*PiperB, error) {
if _, err := os.Stat(assetDir); err != nil {
return nil, err
}
return &PiperB{
assetDir: assetDir,
}, nil
}
func (s *PiperB) TTS(text, model, dst string) error {
return piper.TextToWav(text, model, s.assetDir, "", dst)
}

@ -0,0 +1,16 @@
package api
import "time"
type Segment struct {
Id int `json:"id"`
Start time.Duration `json:"start"`
End time.Duration `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
}
type Result struct {
Segments []Segment `json:"segments"`
Text string `json:"text"`
}

@ -5,25 +5,12 @@ import (
"os" "os"
"os/exec" "os/exec"
"path/filepath" "path/filepath"
"time"
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
wav "github.com/go-audio/wav" wav "github.com/go-audio/wav"
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
) )
type Segment struct {
Id int `json:"id"`
Start time.Duration `json:"start"`
End time.Duration `json:"end"`
Text string `json:"text"`
Tokens []int `json:"tokens"`
}
type Result struct {
Segments []Segment `json:"segments"`
Text string `json:"text"`
}
func sh(c string) (string, error) { func sh(c string) (string, error) {
cmd := exec.Command("/bin/sh", "-c", c) cmd := exec.Command("/bin/sh", "-c", c)
cmd.Env = os.Environ() cmd.Env = os.Environ()
@ -42,8 +29,8 @@ func audioToWav(src, dst string) error {
return nil return nil
} }
func Transcript(model whisper.Model, audiopath, language string, threads uint) (Result, error) { func Transcript(model whisper.Model, audiopath, language string, threads uint) (api.Result, error) {
res := Result{} res := api.Result{}
dir, err := os.MkdirTemp("", "whisper") dir, err := os.MkdirTemp("", "whisper")
if err != nil { if err != nil {
@ -99,11 +86,11 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
} }
var tokens []int var tokens []int
for _, t := range(s.Tokens) { for _, t := range s.Tokens {
tokens = append(tokens, t.Id) tokens = append(tokens, t.Id)
} }
segment := Segment{Id: s.Num, Text: s.Text, Start:s.Start, End: s.End, Tokens: tokens} segment := api.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens}
res.Segments = append(res.Segments, segment) res.Segments = append(res.Segments, segment)
res.Text += s.Text res.Text += s.Text

@ -4,18 +4,13 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/signal"
"path/filepath" "path/filepath"
"strings" "strings"
"syscall"
"time" "time"
rwkv "github.com/donomii/go-rwkv.cpp"
whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
grpc "github.com/go-skynet/LocalAI/pkg/grpc" grpc "github.com/go-skynet/LocalAI/pkg/grpc"
"github.com/go-skynet/LocalAI/pkg/langchain"
"github.com/go-skynet/LocalAI/pkg/stablediffusion"
"github.com/go-skynet/LocalAI/pkg/tts"
bloomz "github.com/go-skynet/bloomz.cpp"
bert "github.com/go-skynet/go-bert.cpp"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"github.com/hpcloud/tail" "github.com/hpcloud/tail"
"github.com/phayes/freeport" "github.com/phayes/freeport"
@ -27,20 +22,22 @@ import (
const tokenizerSuffix = ".tokenizer.json" const tokenizerSuffix = ".tokenizer.json"
const ( const (
LlamaBackend = "llama" LlamaBackend = "llama"
BloomzBackend = "bloomz" BloomzBackend = "bloomz"
StarcoderBackend = "starcoder" StarcoderBackend = "starcoder"
GPTJBackend = "gptj" GPTJBackend = "gptj"
DollyBackend = "dolly" DollyBackend = "dolly"
MPTBackend = "mpt" MPTBackend = "mpt"
GPTNeoXBackend = "gptneox" GPTNeoXBackend = "gptneox"
ReplitBackend = "replit" ReplitBackend = "replit"
Gpt2Backend = "gpt2" Gpt2Backend = "gpt2"
Gpt4AllLlamaBackend = "gpt4all-llama" Gpt4AllLlamaBackend = "gpt4all-llama"
Gpt4AllMptBackend = "gpt4all-mpt" Gpt4AllMptBackend = "gpt4all-mpt"
Gpt4AllJBackend = "gpt4all-j" Gpt4AllJBackend = "gpt4all-j"
Gpt4All = "gpt4all" Gpt4All = "gpt4all"
FalconBackend = "falcon" FalconBackend = "falcon"
FalconGGMLBackend = "falcon-ggml"
BertEmbeddingsBackend = "bert-embeddings" BertEmbeddingsBackend = "bert-embeddings"
RwkvBackend = "rwkv" RwkvBackend = "rwkv"
WhisperBackend = "whisper" WhisperBackend = "whisper"
@ -54,77 +51,39 @@ var autoLoadBackends []string = []string{
LlamaBackend, LlamaBackend,
Gpt4All, Gpt4All,
RwkvBackend, RwkvBackend,
FalconBackend,
WhisperBackend, WhisperBackend,
BertEmbeddingsBackend,
GPTNeoXBackend, GPTNeoXBackend,
BertEmbeddingsBackend,
FalconGGMLBackend,
GPTJBackend, GPTJBackend,
Gpt2Backend, Gpt2Backend,
DollyBackend, DollyBackend,
MPTBackend, MPTBackend,
ReplitBackend, ReplitBackend,
StarcoderBackend, StarcoderBackend,
FalconBackend,
BloomzBackend, BloomzBackend,
} }
var bertEmbeddings = func(modelFile string) (interface{}, error) { func (ml *ModelLoader) StopGRPC() {
return bert.New(modelFile) for _, p := range ml.grpcProcesses {
} p.Stop()
var bloomzLM = func(modelFile string) (interface{}, error) {
return bloomz.New(modelFile)
}
var stableDiffusion = func(assetDir string) (interface{}, error) {
return stablediffusion.New(assetDir)
}
func piperTTS(assetDir string) func(s string) (interface{}, error) {
return func(s string) (interface{}, error) {
return tts.New(assetDir)
}
}
var whisperModel = func(modelFile string) (interface{}, error) {
return whisper.New(modelFile)
}
var lcHuggingFace = func(repoId string) (interface{}, error) {
return langchain.NewHuggingFace(repoId)
}
// func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) {
// return func(s string) (interface{}, error) {
// return llama.New(s, opts...)
// }
// }
// func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) {
// return func(s string) (interface{}, error) {
// return gpt4all.New(s, opts...)
// }
// }
func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) {
return func(s string) (interface{}, error) {
log.Debug().Msgf("Loading RWKV", s, tokenFile)
model := rwkv.LoadFiles(s, tokenFile, threads)
if model == nil {
return nil, fmt.Errorf("could not load model")
}
return model, nil
} }
} }
// starts the grpcModelProcess for the backend, and returns a grpc client // starts the grpcModelProcess for the backend, and returns a grpc client
// It also loads the model // It also loads the model
func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (interface{}, error) { func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) {
return func(s string) (interface{}, error) { return func(s string) (*grpc.Client, error) {
log.Debug().Msgf("Loading GRPC Model", backend, *o) log.Debug().Msgf("Loading GRPC Model", backend, *o)
grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend)
// Check if the file exists
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
}
// Make sure the process is executable // Make sure the process is executable
if err := os.Chmod(grpcProcess, 0755); err != nil { if err := os.Chmod(grpcProcess, 0755); err != nil {
return nil, err return nil, err
@ -151,6 +110,14 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (inter
return nil, err return nil, err
} }
// clean up process
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
grpcControlProcess.Stop()
}()
go func() { go func() {
t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true})
if err != nil { if err != nil {
@ -200,7 +167,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (inter
log.Debug().Msgf("GRPC: Loading model with options: %+v", options) log.Debug().Msgf("GRPC: Loading model with options: %+v", options)
res, err := client.LoadModel(context.TODO(), &options) res, err := client.LoadModel(o.context, &options)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -212,63 +179,37 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (inter
} }
} }
func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err error) { func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err error) {
//backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (model interface{}, err error) {
o := NewOptions(opts...) o := NewOptions(opts...)
log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile)
switch strings.ToLower(o.backendString) {
case LlamaBackend: backend := strings.ToLower(o.backendString)
return ml.LoadModel(o.modelFile, ml.grpcModel(LlamaBackend, o)) switch backend {
case BloomzBackend: case LlamaBackend, GPTJBackend, DollyBackend,
return ml.LoadModel(o.modelFile, bloomzLM) MPTBackend, Gpt2Backend, FalconBackend,
case GPTJBackend: GPTNeoXBackend, ReplitBackend, StarcoderBackend, BloomzBackend,
return ml.LoadModel(o.modelFile, ml.grpcModel(GPTJBackend, o)) RwkvBackend, LCHuggingFaceBackend, BertEmbeddingsBackend, FalconGGMLBackend, StableDiffusionBackend, WhisperBackend:
case DollyBackend: return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o))
return ml.LoadModel(o.modelFile, ml.grpcModel(DollyBackend, o))
case MPTBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(MPTBackend, o))
case Gpt2Backend:
return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt2Backend, o))
case FalconBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(FalconBackend, o))
case GPTNeoXBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(GPTNeoXBackend, o))
case ReplitBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(ReplitBackend, o))
case StableDiffusionBackend:
return ml.LoadModel(o.modelFile, stableDiffusion)
case PiperBackend:
return ml.LoadModel(o.modelFile, piperTTS(filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data")))
case StarcoderBackend:
return ml.LoadModel(o.modelFile, ml.grpcModel(StarcoderBackend, o))
case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All: case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All:
o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all") o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all")
return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt4All, o)) return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt4All, o))
// return ml.LoadModel(o.modelFile, gpt4allLM(gpt4all.SetThreads(int(o.threads)), gpt4all.SetLibrarySearchPath(filepath.Join(o.assetDir, "backend-assets", "gpt4all")))) case PiperBackend:
case BertEmbeddingsBackend: o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data")
return ml.LoadModel(o.modelFile, bertEmbeddings) return ml.LoadModel(o.modelFile, ml.grpcModel(PiperBackend, o))
case RwkvBackend:
return ml.LoadModel(o.modelFile, rwkvLM(filepath.Join(ml.ModelPath, o.modelFile+tokenizerSuffix), o.threads))
case WhisperBackend:
return ml.LoadModel(o.modelFile, whisperModel)
case LCHuggingFaceBackend:
return ml.LoadModel(o.modelFile, lcHuggingFace)
default: default:
return nil, fmt.Errorf("backend unsupported: %s", o.backendString) return nil, fmt.Errorf("backend unsupported: %s", o.backendString)
} }
} }
func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) { func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) {
o := NewOptions(opts...) o := NewOptions(opts...)
log.Debug().Msgf("Loading model '%s' greedly", o.modelFile) log.Debug().Msgf("Loading model '%s' greedly", o.modelFile)
// Is this really needed? BackendLoader already does this
ml.mu.Lock() ml.mu.Lock()
m, exists := ml.models[o.modelFile] if m := ml.checkIsLoaded(o.modelFile); m != nil {
if exists {
log.Debug().Msgf("Model '%s' already loaded", o.modelFile) log.Debug().Msgf("Model '%s' already loaded", o.modelFile)
ml.mu.Unlock() ml.mu.Unlock()
return m, nil return m, nil
@ -285,7 +226,7 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) {
model, modelerr := ml.BackendLoader( model, modelerr := ml.BackendLoader(
WithBackendString(b), WithBackendString(b),
WithModelFile(o.modelFile), WithModelFile(o.modelFile),
WithLoadGRPCOpts(o.gRPCOptions), WithLoadGRPCLLMModelOpts(o.gRPCOptions),
WithThreads(o.threads), WithThreads(o.threads),
WithAssetDir(o.assetDir), WithAssetDir(o.assetDir),
) )

@ -2,6 +2,7 @@ package model
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
@ -10,6 +11,7 @@ import (
"sync" "sync"
"text/template" "text/template"
"github.com/go-skynet/LocalAI/pkg/grpc"
process "github.com/mudler/go-processmanager" process "github.com/mudler/go-processmanager"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -18,7 +20,7 @@ type ModelLoader struct {
ModelPath string ModelPath string
mu sync.Mutex mu sync.Mutex
// TODO: this needs generics // TODO: this needs generics
models map[string]interface{} models map[string]*grpc.Client
grpcProcesses map[string]*process.Process grpcProcesses map[string]*process.Process
promptsTemplates map[string]*template.Template promptsTemplates map[string]*template.Template
} }
@ -26,7 +28,7 @@ type ModelLoader struct {
func NewModelLoader(modelPath string) *ModelLoader { func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{ return &ModelLoader{
ModelPath: modelPath, ModelPath: modelPath,
models: make(map[string]interface{}), models: make(map[string]*grpc.Client),
promptsTemplates: make(map[string]*template.Template), promptsTemplates: make(map[string]*template.Template),
grpcProcesses: make(map[string]*process.Process), grpcProcesses: make(map[string]*process.Process),
} }
@ -113,14 +115,14 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
return nil return nil
} }
func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (interface{}, error)) (interface{}, error) { func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Client, error)) (*grpc.Client, error) {
ml.mu.Lock() ml.mu.Lock()
defer ml.mu.Unlock() defer ml.mu.Unlock()
// Check if we already have a loaded model // Check if we already have a loaded model
if m, ok := ml.models[modelName]; ok { if model := ml.checkIsLoaded(modelName); model != nil {
log.Debug().Msgf("Model already loaded in memory: %s", modelName) log.Debug().Msgf("Model already loaded in memory: %s", modelName)
return m, nil return model, nil
} }
// Load the model and keep it in memory for later use // Load the model and keep it in memory for later use
@ -140,3 +142,25 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (interfac
ml.models[modelName] = model ml.models[modelName] = model
return model, nil return model, nil
} }
func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client {
if m, ok := ml.models[s]; ok {
log.Debug().Msgf("Model already loaded in memory: %s", s)
if !m.HealthCheck(context.Background()) {
log.Debug().Msgf("GRPC Model not responding", s)
if !ml.grpcProcesses[s].IsAlive() {
log.Debug().Msgf("GRPC Process is not responding", s)
// stop and delete the process, this forces to re-load the model and re-create again the service
ml.grpcProcesses[s].Stop()
delete(ml.grpcProcesses, s)
delete(ml.models, s)
return nil
}
}
return m
}
return nil
}

@ -1,6 +1,8 @@
package model package model
import ( import (
"context"
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
) )
@ -9,6 +11,7 @@ type Options struct {
modelFile string modelFile string
threads uint32 threads uint32
assetDir string assetDir string
context context.Context
gRPCOptions *pb.ModelOptions gRPCOptions *pb.ModelOptions
} }
@ -27,7 +30,7 @@ func WithModelFile(modelFile string) Option {
} }
} }
func WithLoadGRPCOpts(opts *pb.ModelOptions) Option { func WithLoadGRPCLLMModelOpts(opts *pb.ModelOptions) Option {
return func(o *Options) { return func(o *Options) {
o.gRPCOptions = opts o.gRPCOptions = opts
} }
@ -45,8 +48,17 @@ func WithAssetDir(assetDir string) Option {
} }
} }
func WithContext(ctx context.Context) Option {
return func(o *Options) {
o.context = ctx
}
}
func NewOptions(opts ...Option) *Options { func NewOptions(opts ...Option) *Options {
o := &Options{} o := &Options{
gRPCOptions: &pb.ModelOptions{},
context: context.Background(),
}
for _, opt := range opts { for _, opt := range opts {
opt(o) opt(o)
} }

@ -1,12 +0,0 @@
//go:build tts
// +build tts
package tts
import (
piper "github.com/mudler/go-piper"
)
func tts(text, model, assetDir, arLib, dst string) error {
return piper.TextToWav(text, model, assetDir, arLib, dst)
}

@ -1,10 +0,0 @@
//go:build !tts
// +build !tts
package tts
import "fmt"
func tts(text, model, assetDir, arLib, dst string) error {
return fmt.Errorf("this version of LocalAI was built without the tts tag")
}

@ -1,20 +0,0 @@
package tts
import "os"
type Piper struct {
assetDir string
}
func New(assetDir string) (*Piper, error) {
if _, err := os.Stat(assetDir); err != nil {
return nil, err
}
return &Piper{
assetDir: assetDir,
}, nil
}
func (s *Piper) TTS(text, model, dst string) error {
return tts(text, model, s.assetDir, "", dst)
}
Loading…
Cancel
Save