diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a18cd20..5b8385c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,9 +26,29 @@ jobs: run: | sudo apt-get update sudo apt-get install build-essential ffmpeg + + sudo apt-get install -y ca-certificates cmake curl patch + sudo apt-get install -y libopencv-dev && sudo ln -s /usr/include/opencv4/opencv2 /usr/include/opencv2 + + sudo mkdir /build && sudo chmod -R 777 /build && cd /build && \ + curl -L "https://github.com/gabime/spdlog/archive/refs/tags/v1.11.0.tar.gz" | \ + tar -xzvf - && \ + mkdir -p "spdlog-1.11.0/build" && \ + cd "spdlog-1.11.0/build" && \ + cmake .. && \ + make -j8 && \ + sudo cmake --install . --prefix /usr && mkdir -p "lib/Linux-$(uname -m)" && \ + cd /build && \ + mkdir -p "lib/Linux-$(uname -m)/piper_phonemize" && \ + curl -L "https://github.com/rhasspy/piper-phonemize/releases/download/v1.0.0/libpiper_phonemize-amd64.tar.gz" | \ + tar -C "lib/Linux-$(uname -m)/piper_phonemize" -xzvf - && ls -liah /build/lib/Linux-$(uname -m)/piper_phonemize/ && \ + sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /lib64/ && \ + sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/lib/. /usr/lib/ && \ + sudo cp -rfv /build/lib/Linux-$(uname -m)/piper_phonemize/include/. 
/usr/include/ + - name: Test run: | - make test + ESPEAK_DATA="/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data" GO_TAGS="tts stablediffusion" make test macOS-latest: runs-on: macOS-latest diff --git a/.gitignore b/.gitignore index 8ad9f22..7b35ba9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,19 @@ # go-llama build artifacts go-llama -gpt4all +/gpt4all go-stable-diffusion +go-piper +go-ggllm +/piper + +*.a +get-sources + go-ggml-transformers go-gpt2 go-rwkv whisper.cpp -bloomz +/bloomz go-bert # LocalAI build binary @@ -29,4 +36,4 @@ release/ # Generated during build backend-assets/ -/ggml-metal.metal \ No newline at end of file +/ggml-metal.metal diff --git a/Makefile b/Makefile index d885b94..ba01c59 100644 --- a/Makefile +++ b/Makefile @@ -41,6 +41,9 @@ BLOOMZ_VERSION?=1834e77b83faafe912ad4092ccf7f77937349e2f # stablediffusion version STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632 +# Go-ggllm +GOGGLLM_VERSION?=862477d16eefb0805261c19c9b0d053e3b2b684b + export BUILD_TYPE?= CGO_LDFLAGS?= CUDA_LIBPATH?=/usr/local/cuda/lib64/ @@ -64,8 +67,14 @@ WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) -C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz -LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz +ifndef UNAME_S +UNAME_S := $(shell uname -s) +endif + +# workaround for rwkv.cpp +ifeq ($(UNAME_S),Darwin) + CGO_LDFLAGS += -lcblas -framework Accelerate +endif ifeq ($(BUILD_TYPE),openblas) CGO_LDFLAGS+=-lopenblas @@ -91,12 +100,14 @@ ifeq ($(STATIC),true) endif ifeq ($(findstring 
stablediffusion,$(GO_TAGS)),stablediffusion) - OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a +# OPTIONAL_TARGETS+=go-stable-diffusion/libstablediffusion.a + OPTIONAL_GRPC+=backend-assets/grpc/stablediffusion endif ifeq ($(findstring tts,$(GO_TAGS)),tts) - OPTIONAL_TARGETS+=go-piper/libpiper_binding.a - OPTIONAL_TARGETS+=backend-assets/espeak-ng-data +# OPTIONAL_TARGETS+=go-piper/libpiper_binding.a +# OPTIONAL_TARGETS+=backend-assets/espeak-ng-data + OPTIONAL_GRPC+=backend-assets/grpc/piper endif .PHONY: all test build vendor @@ -107,24 +118,14 @@ all: help gpt4all: git clone --recurse-submodules $(GPT4ALL_REPO) gpt4all cd gpt4all && git checkout -b build $(GPT4ALL_VERSION) && git submodule update --init --recursive --depth 1 - # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml.. - @find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.m" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.c" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.cpp" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} + - @find ./gpt4all -type f -name "*.h" -exec sed -i'' -e 's/llama_/llama_gpt4all_/g' {} + - @find ./gpt4all/gpt4all-backend -type f -name "llama_util.h" -execdir mv {} "llama_gpt4all_util.h" \; - @find ./gpt4all -type f -name "*.cmake" -exec sed -i'' -e 's/llama_util/llama_gpt4all_util/g' {} + - @find ./gpt4all -type f -name "*.txt" -exec sed -i'' -e 's/llama_util/llama_gpt4all_util/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.cpp" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.go" -exec sed -i'' -e 
's/load_model/load_gpt4all_model/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/load_model/load_gpt4all_model/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/set_numa_thread_affinity/gpt4all_set_numa_thread_affinity/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.c" -exec sed -i'' -e 's/set_numa_thread_affinity/gpt4all__set_numa_thread_affinity/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.c" -exec sed -i'' -e 's/clear_numa_thread_affinity/gpt4all__clear_numa_thread_affinity/g' {} + - @find ./gpt4all/gpt4all-bindings/golang -type f -name "*.h" -exec sed -i'' -e 's/clear_numa_thread_affinity/gpt4all__clear_numa_thread_affinity/g' {} + + +## go-ggllm +go-ggllm: + git clone --recurse-submodules https://github.com/mudler/go-ggllm.cpp go-ggllm + cd go-ggllm && git checkout -b build $(GOGGLLM_VERSION) && git submodule update --init --recursive --depth 1 + +go-ggllm/libggllm.a: go-ggllm + $(MAKE) -C go-ggllm BUILD_TYPE=$(BUILD_TYPE) libggllm.a ## go-piper go-piper: @@ -135,9 +136,6 @@ go-piper: go-bert: git clone --recurse-submodules https://github.com/go-skynet/go-bert.cpp go-bert cd go-bert && git checkout -b build $(BERT_VERSION) && git submodule update --init --recursive --depth 1 - @find ./go-bert -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} + - @find ./go-bert -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} + - @find ./go-bert -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bert_/g' {} + ## stable diffusion go-stable-diffusion: @@ -151,9 +149,6 @@ go-stable-diffusion/libstablediffusion.a: go-rwkv: git clone --recurse-submodules $(RWKV_REPO) go-rwkv cd go-rwkv && git checkout -b build $(RWKV_VERSION) && git submodule update --init --recursive --depth 1 - @find ./go-rwkv -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} + - @find ./go-rwkv -type f -name "*.cpp" -exec sed -i'' -e 
's/ggml_/ggml_rwkv_/g' {} + - @find ./go-rwkv -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_rwkv_/g' {} + go-rwkv/librwkv.a: go-rwkv cd go-rwkv && cd rwkv.cpp && cmake . -DRWKV_BUILD_SHARED_LIBRARY=OFF && cmake --build . && cp librwkv.a .. @@ -161,13 +156,7 @@ go-rwkv/librwkv.a: go-rwkv ## bloomz bloomz: git clone --recurse-submodules https://github.com/go-skynet/bloomz.cpp bloomz - @find ./bloomz -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + - @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + - @find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_bloomz_/g' {} + - @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} + - @find ./bloomz -type f -name "*.h" -exec sed -i'' -e 's/gpt_/gpt_bloomz_/g' {} + - @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/void replace/void json_bloomz_replace/g' {} + - @find ./bloomz -type f -name "*.cpp" -exec sed -i'' -e 's/::replace/::json_bloomz_replace/g' {} + + cd bloomz && git checkout -b build $(BLOOMZ_VERSION) && git submodule update --init --recursive --depth 1 bloomz/libbloomz.a: bloomz cd bloomz && make libbloomz.a @@ -186,6 +175,7 @@ backend-assets/espeak-ng-data: ifdef ESPEAK_DATA @cp -rf $(ESPEAK_DATA)/. backend-assets/espeak-ng-data else + @echo "ESPEAK_DATA not set, skipping tts. Note that this will break the tts functionality." @touch backend-assets/espeak-ng-data/keep endif @@ -196,21 +186,6 @@ gpt4all/gpt4all-bindings/golang/libgpt4all.a: gpt4all go-ggml-transformers: git clone --recurse-submodules https://github.com/go-skynet/go-ggml-transformers.cpp go-ggml-transformers cd go-ggml-transformers && git checkout -b build $(GOGPT2_VERSION) && git submodule update --init --recursive --depth 1 - # This is hackish, but needed as both go-llama and go-gpt4allj have their own version of ggml.. 
- @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_print_usage/gpt2_print_usage/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_params_parse/gpt2_params_parse/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/gpt_random_prompt/gpt2_random_prompt/g' {} + - @find ./go-ggml-transformers -type f -name "*.cpp" -exec sed -i'' -e 's/json_/json_gpt2_/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/set_numa_thread_affinity/transformers_set_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/set_numa_thread_affinity/transformers_set_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.c" -exec sed -i'' -e 's/clear_numa_thread_affinity/transformers_clear_numa_thread_affinity/g' {} + - @find ./go-ggml-transformers -type f -name "*.h" -exec sed -i'' -e 's/clear_numa_thread_affinity/transformers_clear_numa_thread_affinity/g' {} + go-ggml-transformers/libtransformers.a: go-ggml-transformers $(MAKE) -C go-ggml-transformers libtransformers.a @@ -218,9 +193,6 @@ go-ggml-transformers/libtransformers.a: go-ggml-transformers whisper.cpp: git clone https://github.com/ggerganov/whisper.cpp.git cd whisper.cpp && git checkout -b build $(WHISPER_CPP_VERSION) && git 
submodule update --init --recursive --depth 1 - @find ./whisper.cpp -type f -name "*.c" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} + - @find ./whisper.cpp -type f -name "*.cpp" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} + - @find ./whisper.cpp -type f -name "*.h" -exec sed -i'' -e 's/ggml_/ggml_whisper_/g' {} + whisper.cpp/libwhisper.a: whisper.cpp cd whisper.cpp && make libwhisper.a @@ -238,7 +210,7 @@ go-llama/libbinding.a: go-llama go-piper/libpiper_binding.a: $(MAKE) -C go-piper libpiper_binding.a example/main -get-sources: go-llama go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion +get-sources: go-llama go-ggllm go-ggml-transformers gpt4all go-piper go-rwkv whisper.cpp go-bert bloomz go-stable-diffusion touch $@ replace: @@ -251,6 +223,7 @@ replace: $(GOCMD) mod edit -replace github.com/go-skynet/bloomz.cpp=$(shell pwd)/bloomz $(GOCMD) mod edit -replace github.com/mudler/go-stable-diffusion=$(shell pwd)/go-stable-diffusion $(GOCMD) mod edit -replace github.com/mudler/go-piper=$(shell pwd)/go-piper + $(GOCMD) mod edit -replace github.com/mudler/go-ggllm.cpp=$(shell pwd)/go-ggllm prepare-sources: get-sources replace $(GOCMD) mod download @@ -267,9 +240,10 @@ rebuild: ## Rebuilds the project $(MAKE) -C go-bert clean $(MAKE) -C bloomz clean $(MAKE) -C go-piper clean + $(MAKE) -C go-ggllm clean $(MAKE) build -prepare: prepare-sources backend-assets/gpt4all $(OPTIONAL_TARGETS) go-llama/libbinding.a go-bert/libgobert.a go-ggml-transformers/libtransformers.a go-rwkv/librwkv.a whisper.cpp/libwhisper.a bloomz/libbloomz.a ## Prepares for building +prepare: prepare-sources $(OPTIONAL_TARGETS) touch $@ clean: ## Remove build related file @@ -285,18 +259,19 @@ clean: ## Remove build related file rm -rf ./bloomz rm -rf ./whisper.cpp rm -rf ./go-piper + rm -rf ./go-ggllm rm -rf $(BINARY_NAME) rm -rf release/ ## Build: -build: prepare ## Build the project +build: grpcs prepare ## Build the project $(info ${GREEN}I local-ai 
build info:${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) $(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET}) - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./ + CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./ ifeq ($(BUILD_TYPE),metal) cp go-llama/build/bin/ggml-metal.metal . endif @@ -305,12 +280,9 @@ dist: build mkdir -p release cp $(BINARY_NAME) release/$(BINARY_NAME)-$(BUILD_ID)-$(OS)-$(ARCH) -generic-build: ## Build the project using generic - BUILD_TYPE="generic" $(MAKE) build - ## Run run: prepare ## run local-ai - CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} $(GOCMD) run ./ + CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./ test-models/testmodel: mkdir test-models @@ -323,12 +295,42 @@ test-models/testmodel: wget https://raw.githubusercontent.com/saharNooby/rwkv.cpp/5eb8f09c146ea8124633ab041d9ea0b1f1db4459/rwkv/20B_tokenizer.json -O test-models/rwkv.tokenizer.json cp tests/models_fixtures/* test-models -test: prepare test-models/testmodel - cp -r backend-assets api +prepare-test: grpcs + cp -rf backend-assets api cp tests/models_fixtures/* test-models - C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg - C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo 
--label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg - C_INCLUDE_PATH=${C_INCLUDE_PATH} LIBRARY_PATH=${LIBRARY_PATH} TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg + +test: prepare test-models/testmodel grpcs + @echo 'Running tests' + export GO_TAGS="tts stablediffusion" + $(MAKE) prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="!gpt4all && !llama" --flake-attempts 5 -v -r ./api ./pkg + $(MAKE) test-gpt4all + $(MAKE) test-llama + $(MAKE) test-tts + $(MAKE) test-stablediffusion + +test-gpt4all: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="gpt4all" --flake-attempts 5 -v -r ./api ./pkg + +test-llama: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama" --flake-attempts 5 -v -r ./api ./pkg + +test-tts: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts 1 -v -r ./api ./pkg + +test-stablediffusion: prepare-test + TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models \ + 
$(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts 1 -v -r ./api ./pkg + +test-container: + docker build --target requirements -t local-ai-test-container . + docker run --name localai-tests -e GO_TAGS=$(GO_TAGS) -ti -v $(abspath ./):/build local-ai-test-container make test + docker rm localai-tests + docker rmi local-ai-test-container ## Help: help: ## Show this help. @@ -341,3 +343,85 @@ help: ## Show this help. if (/^[a-zA-Z_-]+:.*?##.*$$/) {printf " ${YELLOW}%-20s${GREEN}%s${RESET}\n", $$1, $$2} \ else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \ }' $(MAKEFILE_LIST) + +protogen: + protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \ + pkg/grpc/proto/backend.proto + +## GRPC + +backend-assets/grpc: + mkdir -p backend-assets/grpc + +backend-assets/grpc/falcon: backend-assets/grpc go-ggllm/libggllm.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/ + +backend-assets/grpc/llama: backend-assets/grpc go-llama/libbinding.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/ + +backend-assets/grpc/gpt4all: backend-assets/grpc backend-assets/gpt4all gpt4all/gpt4all-bindings/golang/libgpt4all.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ LIBRARY_PATH=$(shell pwd)/gpt4all/gpt4all-bindings/golang/ \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt4all ./cmd/grpc/gpt4all/ + +backend-assets/grpc/dolly: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell 
pwd)/go-ggml-transformers \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/dolly ./cmd/grpc/dolly/ + +backend-assets/grpc/gpt2: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gpt2 ./cmd/grpc/gpt2/ + +backend-assets/grpc/gptj: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptj ./cmd/grpc/gptj/ + +backend-assets/grpc/gptneox: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/gptneox ./cmd/grpc/gptneox/ + +backend-assets/grpc/mpt: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/mpt ./cmd/grpc/mpt/ + +backend-assets/grpc/replit: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/replit ./cmd/grpc/replit/ + +backend-assets/grpc/falcon-ggml: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build 
-ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon-ggml ./cmd/grpc/falcon-ggml/ + +backend-assets/grpc/starcoder: backend-assets/grpc go-ggml-transformers/libtransformers.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggml-transformers LIBRARY_PATH=$(shell pwd)/go-ggml-transformers \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/starcoder ./cmd/grpc/starcoder/ + +backend-assets/grpc/rwkv: backend-assets/grpc go-rwkv/librwkv.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-rwkv LIBRARY_PATH=$(shell pwd)/go-rwkv \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/rwkv ./cmd/grpc/rwkv/ + +backend-assets/grpc/bloomz: backend-assets/grpc bloomz/libbloomz.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/bloomz LIBRARY_PATH=$(shell pwd)/bloomz \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bloomz ./cmd/grpc/bloomz/ + +backend-assets/grpc/bert-embeddings: backend-assets/grpc go-bert/libgobert.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-bert LIBRARY_PATH=$(shell pwd)/go-bert \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bert-embeddings ./cmd/grpc/bert-embeddings/ + +backend-assets/grpc/langchain-huggingface: backend-assets/grpc + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/langchain-huggingface ./cmd/grpc/langchain-huggingface/ + +backend-assets/grpc/stablediffusion: backend-assets/grpc go-stable-diffusion/libstablediffusion.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/ LIBRARY_PATH=$(shell pwd)/go-stable-diffusion/ \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/stablediffusion ./cmd/grpc/stablediffusion/ + +backend-assets/grpc/piper: backend-assets/grpc backend-assets/espeak-ng-data go-piper/libpiper_binding.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" 
LIBRARY_PATH=$(shell pwd)/go-piper \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./cmd/grpc/piper/ + +backend-assets/grpc/whisper: backend-assets/grpc whisper.cpp/libwhisper.a + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/whisper.cpp LIBRARY_PATH=$(shell pwd)/whisper.cpp \ + $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./cmd/grpc/whisper/ + +grpcs: prepare backend-assets/grpc/langchain-huggingface backend-assets/grpc/falcon-ggml backend-assets/grpc/bert-embeddings backend-assets/grpc/falcon backend-assets/grpc/bloomz backend-assets/grpc/llama backend-assets/grpc/gpt4all backend-assets/grpc/dolly backend-assets/grpc/gpt2 backend-assets/grpc/gptj backend-assets/grpc/gptneox backend-assets/grpc/mpt backend-assets/grpc/replit backend-assets/grpc/starcoder backend-assets/grpc/rwkv backend-assets/grpc/whisper $(OPTIONAL_GRPC) \ No newline at end of file diff --git a/api/api.go b/api/api.go index 543e756..8dcefa2 100644 --- a/api/api.go +++ b/api/api.go @@ -3,8 +3,13 @@ package api import ( "errors" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/localai" + "github.com/go-skynet/LocalAI/api/openai" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/internal" "github.com/go-skynet/LocalAI/pkg/assets" + "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/logger" @@ -13,18 +18,18 @@ import ( "github.com/rs/zerolog/log" ) -func App(opts ...AppOption) (*fiber.App, error) { - options := newOptions(opts...) +func App(opts ...options.AppOption) (*fiber.App, error) { + options := options.NewOptions(opts...) 
zerolog.SetGlobalLevel(zerolog.InfoLevel) - if options.debug { + if options.Debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) } // Return errors as JSON responses app := fiber.New(fiber.Config{ - BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: options.disableMessage, + BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: options.DisableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -38,43 +43,44 @@ func App(opts ...AppOption) (*fiber.App, error) { // Send custom error page return ctx.Status(code).JSON( - ErrorResponse{ - Error: &APIError{Message: err.Error(), Code: code}, + openai.ErrorResponse{ + Error: &openai.APIError{Message: err.Error(), Code: code}, }, ) }, }) - if options.debug { + if options.Debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) } - log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.threads, options.loader.ModelPath) + log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath) log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion()) - cm := NewConfigMerger() - if err := cm.LoadConfigs(options.loader.ModelPath); err != nil { + cm := config.NewConfigLoader() + if err := cm.LoadConfigs(options.Loader.ModelPath); err != nil { log.Error().Msgf("error loading config files: %s", err.Error()) } - if options.configFile != "" { - if err := cm.LoadConfigFile(options.configFile); err != nil { + if options.ConfigFile != "" { + if err := cm.LoadConfigFile(options.ConfigFile); err != nil { log.Error().Msgf("error loading config file: %s", err.Error()) } } - if options.debug { + if options.Debug { for _, v := range cm.ListConfigs() { cfg, _ := cm.GetConfig(v) log.Debug().Msgf("Model: %s (config: %+v)", v, 
cfg) } } - if options.assetsDestination != "" { + if options.AssetsDestination != "" { // Extract files from the embedded FS - err := assets.ExtractFiles(options.backendAssets, options.assetsDestination) + err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination) + log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination) if err != nil { log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err) } @@ -83,31 +89,32 @@ func App(opts ...AppOption) (*fiber.App, error) { // Default middleware config app.Use(recover.New()) - if options.preloadJSONModels != "" { - if err := ApplyGalleryFromString(options.loader.ModelPath, options.preloadJSONModels, cm, options.galleries); err != nil { + if options.PreloadJSONModels != "" { + if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil { return nil, err } } - if options.preloadModelsFromPath != "" { - if err := ApplyGalleryFromFile(options.loader.ModelPath, options.preloadModelsFromPath, cm, options.galleries); err != nil { + if options.PreloadModelsFromPath != "" { + if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cm, options.Galleries); err != nil { return nil, err } } - if options.cors { - if options.corsAllowOrigins == "" { - app.Use(cors.New()) + if options.CORS { + var c func(ctx *fiber.Ctx) error + if options.CORSAllowOrigins == "" { + c = cors.New() } else { - app.Use(cors.New(cors.Config{ - AllowOrigins: options.corsAllowOrigins, - })) + c = cors.New(cors.Config{AllowOrigins: options.CORSAllowOrigins}) } + + app.Use(c) } // LocalAI API endpoints - applier := newGalleryApplier(options.loader.ModelPath) - applier.start(options.context, cm) + galleryService := localai.NewGalleryService(options.Loader.ModelPath) + galleryService.Start(options.Context, cm) app.Get("/version", func(c 
*fiber.Ctx) error { return c.JSON(struct { @@ -115,43 +122,43 @@ func App(opts ...AppOption) (*fiber.App, error) { }{Version: internal.PrintableVersion()}) }) - app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C, options.galleries)) - app.Get("/models/available", listModelFromGallery(options.galleries, options.loader.ModelPath)) - app.Get("/models/jobs/:uuid", getOpStatus(applier)) + app.Post("/models/apply", localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries)) + app.Get("/models/available", localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath)) + app.Get("/models/jobs/:uuid", localai.GetOpStatusEndpoint(galleryService)) // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", chatEndpoint(cm, options)) - app.Post("/chat/completions", chatEndpoint(cm, options)) + app.Post("/v1/chat/completions", openai.ChatEndpoint(cm, options)) + app.Post("/chat/completions", openai.ChatEndpoint(cm, options)) // edit - app.Post("/v1/edits", editEndpoint(cm, options)) - app.Post("/edits", editEndpoint(cm, options)) + app.Post("/v1/edits", openai.EditEndpoint(cm, options)) + app.Post("/edits", openai.EditEndpoint(cm, options)) // completion - app.Post("/v1/completions", completionEndpoint(cm, options)) - app.Post("/completions", completionEndpoint(cm, options)) - app.Post("/v1/engines/:model/completions", completionEndpoint(cm, options)) + app.Post("/v1/completions", openai.CompletionEndpoint(cm, options)) + app.Post("/completions", openai.CompletionEndpoint(cm, options)) + app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cm, options)) // embeddings - app.Post("/v1/embeddings", embeddingsEndpoint(cm, options)) - app.Post("/embeddings", embeddingsEndpoint(cm, options)) - app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options)) + app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cm, options)) + 
app.Post("/embeddings", openai.EmbeddingsEndpoint(cm, options)) + app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cm, options)) // audio - app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options)) - app.Post("/tts", ttsEndpoint(cm, options)) + app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cm, options)) + app.Post("/tts", localai.TTSEndpoint(cm, options)) // images - app.Post("/v1/images/generations", imageEndpoint(cm, options)) + app.Post("/v1/images/generations", openai.ImageEndpoint(cm, options)) - if options.imageDir != "" { - app.Static("/generated-images", options.imageDir) + if options.ImageDir != "" { + app.Static("/generated-images", options.ImageDir) } - if options.audioDir != "" { - app.Static("/generated-audio", options.audioDir) + if options.AudioDir != "" { + app.Static("/generated-audio", options.AudioDir) } ok := func(c *fiber.Ctx) error { @@ -163,8 +170,15 @@ func App(opts ...AppOption) (*fiber.App, error) { app.Get("/readyz", ok) // models - app.Get("/v1/models", listModels(options.loader, cm)) - app.Get("/models", listModels(options.loader, cm)) + app.Get("/v1/models", openai.ListModelsEndpoint(options.Loader, cm)) + app.Get("/models", openai.ListModelsEndpoint(options.Loader, cm)) + + // turn off any process that was started by GRPC if the context is canceled + go func() { + <-options.Context.Done() + log.Debug().Msgf("Context canceled, shutting down") + options.Loader.StopGRPC() + }() return app, nil } diff --git a/api/api_test.go b/api/api_test.go index 43aa30b..ca840b5 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -5,7 +5,9 @@ import ( "context" "embed" "encoding/json" + "errors" "fmt" + "io" "io/ioutil" "net/http" "os" @@ -13,6 +15,7 @@ import ( "runtime" . 
"github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/utils" @@ -23,6 +26,7 @@ import ( openaigo "github.com/otiai10/openaigo" "github.com/sashabaranov/go-openai" + "github.com/sashabaranov/go-openai/jsonschema" ) type modelApplyRequest struct { @@ -154,9 +158,10 @@ var _ = Describe("API test", func() { }, } - app, err = App(WithContext(c), - WithGalleries(galleries), - WithModelLoader(modelLoader), WithBackendAssets(backendAssets), WithBackendAssetsOutput(tmpdir)) + app, err = App( + options.WithContext(c), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), options.WithBackendAssets(backendAssets), options.WithBackendAssetsOutput(tmpdir)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -201,7 +206,7 @@ var _ = Describe("API test", func() { fmt.Println(response) resp = response return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) Expect(resp["message"]).ToNot(ContainSubstring("error")) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert2.yaml")) @@ -243,9 +248,8 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) - fmt.Println(response) return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) @@ -268,9 +272,8 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) - fmt.Println(response) return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) dat, err := os.ReadFile(filepath.Join(tmpdir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) @@ -297,14 +300,58 @@ var _ = 
Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) - fmt.Println(response) return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) + By("testing completion") resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "openllama_3b", Prompt: "Count up to five: one, two, three, four, "}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).To(ContainSubstring("five")) + + By("testing functions") + resp2, err := client.CreateChatCompletion( + context.TODO(), + openai.ChatCompletionRequest{ + Model: "openllama_3b", + Messages: []openai.ChatCompletionMessage{ + { + Role: "user", + Content: "What is the weather like in San Francisco (celsius)?", + }, + }, + Functions: []openai.FunctionDefinition{ + openai.FunctionDefinition{ + Name: "get_current_weather", + Description: "Get the current weather", + Parameters: jsonschema.Definition{ + Type: jsonschema.Object, + Properties: map[string]jsonschema.Definition{ + "location": { + Type: jsonschema.String, + Description: "The city and state, e.g. 
San Francisco, CA", + }, + "unit": { + Type: jsonschema.String, + Enum: []string{"celcius", "fahrenheit"}, + }, + }, + Required: []string{"location"}, + }, + }, + }, + }) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp2.Choices)).To(Equal(1)) + Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) + Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) + + var res map[string]string + err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) + Expect(err).ToNot(HaveOccurred()) + Expect(res["location"]).To(Equal("San Francisco"), fmt.Sprint(res)) + Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) + Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) }) It("runs gpt4all", Label("gpt4all"), func() { @@ -324,15 +371,126 @@ var _ = Describe("API test", func() { Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) - fmt.Println(response) return response["processed"].(bool) - }, "360s").Should(Equal(true)) + }, "360s", "10s").Should(Equal(true)) resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "gpt4all-j", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: "How are you?"}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).To(ContainSubstring("well")) }) + + }) + }) + + Context("Model gallery", func() { + BeforeEach(func() { + var err error + tmpdir, err = os.MkdirTemp("", "") + Expect(err).ToNot(HaveOccurred()) + + modelLoader = model.NewModelLoader(tmpdir) + c, cancel = context.WithCancel(context.Background()) + + galleries := []gallery.Gallery{ + { + Name: "model-gallery", + URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/index.yaml", + }, + } + + app, err = App( + 
options.WithContext(c), + options.WithAudioDir(tmpdir), + options.WithImageDir(tmpdir), + options.WithGalleries(galleries), + options.WithModelLoader(modelLoader), + options.WithBackendAssets(backendAssets), + options.WithBackendAssetsOutput(tmpdir), + ) + Expect(err).ToNot(HaveOccurred()) + go app.Listen("127.0.0.1:9090") + + defaultConfig := openai.DefaultConfig("") + defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" + + client2 = openaigo.NewClient("") + client2.BaseURL = defaultConfig.BaseURL + + // Wait for API to be ready + client = openai.NewClientWithConfig(defaultConfig) + Eventually(func() error { + _, err := client.ListModels(context.TODO()) + return err + }, "2m").ShouldNot(HaveOccurred()) + }) + + AfterEach(func() { + cancel() + app.Shutdown() + os.RemoveAll(tmpdir) + }) + It("installs and is capable to run tts", Label("tts"), func() { + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + ID: "model-gallery@voice-en-us-kathleen-low", + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + // An HTTP Post to the /tts endpoint should return a wav audio file + resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "en-us-kathleen-low.onnx"}`))) + Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) + dat, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) + + Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat))) + Expect(resp.Header.Get("Content-Type")).To(Equal("audio/x-wav")) + }) + It("installs and is capable to generate images", Label("stablediffusion"), func() { 
+ if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + + response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ + ID: "model-gallery@stablediffusion", + }) + + Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) + + uuid := response["uuid"].(string) + + Eventually(func() bool { + response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) + fmt.Println(response) + return response["processed"].(bool) + }, "360s", "10s").Should(Equal(true)) + + resp, err := http.Post( + "http://127.0.0.1:9090/v1/images/generations", + "application/json", + bytes.NewBuffer([]byte(`{ + "prompt": "floating hair, portrait, ((loli)), ((one girl)), cute face, hidden hands, asymmetrical bangs, beautiful detailed eyes, eye shadow, hair ornament, ribbons, bowties, buttons, pleated skirt, (((masterpiece))), ((best quality)), colorful|((part of the head)), ((((mutated hands and fingers)))), deformed, blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, blurry, floating limbs, disconnected limbs, malformed hands, blur, out of focus, long neck, long body, Octane renderer, lowres, bad anatomy, bad hands, text", + "mode": 2, "seed":9000, + "size": "256x256", "n":2}`))) + // The response should contain an URL + Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) + dat, err := io.ReadAll(resp.Body) + Expect(err).ToNot(HaveOccurred(), string(dat)) + Expect(string(dat)).To(ContainSubstring("http://127.0.0.1:9090/"), string(dat)) + Expect(string(dat)).To(ContainSubstring(".png"), string(dat)) + }) }) @@ -342,7 +500,7 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(WithContext(c), WithModelLoader(modelLoader)) + app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader)) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") @@ -399,7 +557,7 @@ var _ = 
Describe("API test", func() { It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: "abcdedfghikl"}) Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 11 errors occurred:")) + Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:")) }) It("transcribes audio", func() { if runtime.GOOS != "linux" { @@ -444,14 +602,67 @@ var _ = Describe("API test", func() { }) Context("backends", func() { - It("runs rwkv", func() { + It("runs rwkv completion", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices) > 0).To(BeTrue()) - Expect(resp.Choices[0].Text).To(Equal(" five.")) + Expect(resp.Choices[0].Text).To(ContainSubstring("five")) + + stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{ + Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true, + }) + Expect(err).ToNot(HaveOccurred()) + defer stream.Close() + + tokens := 0 + text := "" + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + + Expect(err).ToNot(HaveOccurred()) + text += response.Choices[0].Text + tokens++ + } + Expect(text).ToNot(BeEmpty()) + Expect(text).To(ContainSubstring("five")) + Expect(tokens).ToNot(Or(Equal(1), Equal(0))) + }) + It("runs rwkv chat completion", func() { + if runtime.GOOS != "linux" { + Skip("test supported only on linux") + } + resp, err := client.CreateChatCompletion(context.TODO(), + openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can 
you count up to five?", Role: "user"}}}) + Expect(err).ToNot(HaveOccurred()) + Expect(len(resp.Choices) > 0).To(BeTrue()) + Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("Sure"), ContainSubstring("five"))) + + stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}}) + Expect(err).ToNot(HaveOccurred()) + defer stream.Close() + + tokens := 0 + text := "" + for { + response, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + + Expect(err).ToNot(HaveOccurred()) + text += response.Choices[0].Delta.Content + tokens++ + } + Expect(text).ToNot(BeEmpty()) + Expect(text).To(Or(ContainSubstring("Sure"), ContainSubstring("five"))) + + Expect(tokens).ToNot(Or(Equal(1), Equal(0))) }) }) }) @@ -462,7 +673,7 @@ var _ = Describe("API test", func() { c, cancel = context.WithCancel(context.Background()) var err error - app, err = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE"))) + app, err = App(options.WithContext(c), options.WithModelLoader(modelLoader), options.WithConfigFile(os.Getenv("CONFIG_FILE"))) Expect(err).ToNot(HaveOccurred()) go app.Listen("127.0.0.1:9090") diff --git a/api/backend/embeddings.go b/api/backend/embeddings.go new file mode 100644 index 0000000..0310347 --- /dev/null +++ b/api/backend/embeddings.go @@ -0,0 +1,105 @@ +package backend + +import ( + "fmt" + "sync" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { + if !c.Embeddings { + return nil, fmt.Errorf("endpoint disabled for this model by API configuration") + } + + modelFile := c.Model + + grpcOpts 
:= gRPCModelOpts(c) + + var inferenceModel interface{} + var err error + + opts := []model.Option{ + model.WithLoadGRPCLLMModelOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), + model.WithAssetDir(o.AssetsDestination), + model.WithModelFile(modelFile), + model.WithContext(o.Context), + } + + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) + } else { + opts = append(opts, model.WithBackendString(c.Backend)) + inferenceModel, err = loader.BackendLoader(opts...) + } + if err != nil { + return nil, err + } + + var fn func() ([]float32, error) + switch model := inferenceModel.(type) { + case *grpc.Client: + fn = func() ([]float32, error) { + predictOptions := gRPCPredictOpts(c, loader.ModelPath) + if len(tokens) > 0 { + embeds := []int32{} + + for _, t := range tokens { + embeds = append(embeds, int32(t)) + } + predictOptions.EmbeddingTokens = embeds + + res, err := model.Embeddings(o.Context, predictOptions) + if err != nil { + return nil, err + } + + return res.Embeddings, nil + } + predictOptions.Embeddings = s + + res, err := model.Embeddings(o.Context, predictOptions) + if err != nil { + return nil, err + } + + return res.Embeddings, nil + } + default: + fn = func() ([]float32, error) { + return nil, fmt.Errorf("embeddings not supported by the backend") + } + } + + return func() ([]float32, error) { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[modelFile] + if !ok { + m := &sync.Mutex{} + mutexes[modelFile] = m + l = m + } + mutexMap.Unlock() + l.Lock() + defer l.Unlock() + + embeds, err := fn() + if err != nil { + return embeds, err + } + // Remove trailing 0s + for i := len(embeds) - 1; i >= 0; i-- { + if embeds[i] == 0.0 { + embeds = embeds[:i] + } else { + break + } + } + return embeds, nil + }, nil +} diff --git a/api/backend/image.go b/api/backend/image.go new file mode 100644 index 0000000..a631b3b --- /dev/null +++ b/api/backend/image.go @@ -0,0 
+1,60 @@ +package backend + +import ( + "fmt" + "sync" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { + if c.Backend != model.StableDiffusionBackend { + return nil, fmt.Errorf("endpoint only working with stablediffusion models") + } + + inferenceModel, err := loader.BackendLoader( + model.WithBackendString(c.Backend), + model.WithAssetDir(o.AssetsDestination), + model.WithThreads(uint32(c.Threads)), + model.WithContext(o.Context), + model.WithModelFile(c.ImageGenerationAssets), + ) + if err != nil { + return nil, err + } + + fn := func() error { + _, err := inferenceModel.GenerateImage( + o.Context, + &proto.GenerateImageRequest{ + Height: int32(height), + Width: int32(width), + Mode: int32(mode), + Step: int32(step), + Seed: int32(seed), + PositivePrompt: positive_prompt, + NegativePrompt: negative_prompt, + Dst: dst, + }) + return err + } + + return func() error { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[c.Backend] + if !ok { + m := &sync.Mutex{} + mutexes[c.Backend] = m + l = m + } + mutexMap.Unlock() + l.Lock() + defer l.Unlock() + + return fn() + }, nil +} diff --git a/api/backend/llm.go b/api/backend/llm.go new file mode 100644 index 0000000..8fcd6da --- /dev/null +++ b/api/backend/llm.go @@ -0,0 +1,98 @@ +package backend + +import ( + "regexp" + "strings" + "sync" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc" + model "github.com/go-skynet/LocalAI/pkg/model" +) + +func ModelInference(s string, loader *model.ModelLoader, c config.Config, o 
*options.Option, tokenCallback func(string) bool) (func() (string, error), error) { + modelFile := c.Model + + grpcOpts := gRPCModelOpts(c) + + var inferenceModel *grpc.Client + var err error + + opts := []model.Option{ + model.WithLoadGRPCLLMModelOpts(grpcOpts), + model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup + model.WithAssetDir(o.AssetsDestination), + model.WithModelFile(modelFile), + model.WithContext(o.Context), + } + + if c.Backend == "" { + inferenceModel, err = loader.GreedyLoader(opts...) + } else { + opts = append(opts, model.WithBackendString(c.Backend)) + inferenceModel, err = loader.BackendLoader(opts...) + } + if err != nil { + return nil, err + } + + // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported + fn := func() (string, error) { + opts := gRPCPredictOpts(c, loader.ModelPath) + opts.Prompt = s + if tokenCallback != nil { + ss := "" + err := inferenceModel.PredictStream(o.Context, opts, func(s string) { + tokenCallback(s) + ss += s + }) + return ss, err + } else { + reply, err := inferenceModel.Predict(o.Context, opts) + return reply.Message, err + } + } + + return func() (string, error) { + // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 + mutexMap.Lock() + l, ok := mutexes[modelFile] + if !ok { + m := &sync.Mutex{} + mutexes[modelFile] = m + l = m + } + mutexMap.Unlock() + l.Lock() + defer l.Unlock() + + return fn() + }, nil +} + +var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) +var mu sync.Mutex = sync.Mutex{} + +func Finetune(config config.Config, input, prediction string) string { + if config.Echo { + prediction = input + prediction + } + + for _, c := range config.Cutstrings { + mu.Lock() + reg, ok := cutstrings[c] + if !ok { + cutstrings[c] = regexp.MustCompile(c) + reg = cutstrings[c] + } + mu.Unlock() + prediction = reg.ReplaceAllString(prediction, "") + } + + for _, c := range 
// A per-key mutex is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var (
	// mutexMap guards every access to mutexes.
	mutexMap sync.Mutex
	// mutexes holds one lazily created mutex per key handed to Lock.
	mutexes = make(map[string]*sync.Mutex)
)

// Lock acquires the mutex registered under s, creating it on first use, and
// returns it while it is still held. The caller is responsible for calling
// Unlock on the returned mutex.
func Lock(s string) *sync.Mutex {
	mutexMap.Lock()
	held, known := mutexes[s]
	if !known {
		held = new(sync.Mutex)
		mutexes[s] = held
	}
	mutexMap.Unlock()

	held.Lock()
	return held
}
c.PromptCacheAll, + PromptCacheRO: c.PromptCacheRO, + PromptCachePath: promptCachePath, + F16KV: c.F16, + DebugMode: c.Debug, + Grammar: c.Grammar, + + Mirostat: int32(c.Mirostat), + MirostatETA: float32(c.MirostatETA), + MirostatTAU: float32(c.MirostatTAU), + Debug: c.Debug, + StopPrompts: c.StopWords, + Repeat: int32(c.RepeatPenalty), + NKeep: int32(c.Keep), + Batch: int32(c.Batch), + IgnoreEOS: c.IgnoreEOS, + Seed: int32(c.Seed), + FrequencyPenalty: float32(c.FrequencyPenalty), + MLock: c.MMlock, + MMap: c.MMap, + MainGPU: c.MainGPU, + TensorSplit: c.TensorSplit, + TailFreeSamplingZ: float32(c.TFZ), + TypicalP: float32(c.TypicalP), + } +} diff --git a/api/config.go b/api/config.go deleted file mode 100644 index 57fe0d1..0000000 --- a/api/config.go +++ /dev/null @@ -1,401 +0,0 @@ -package api - -import ( - "encoding/json" - "fmt" - "io/fs" - "os" - "path/filepath" - "strings" - "sync" - - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - "gopkg.in/yaml.v3" -) - -type Config struct { - OpenAIRequest `yaml:"parameters"` - Name string `yaml:"name"` - StopWords []string `yaml:"stopwords"` - Cutstrings []string `yaml:"cutstrings"` - TrimSpace []string `yaml:"trimspace"` - ContextSize int `yaml:"context_size"` - F16 bool `yaml:"f16"` - NUMA bool `yaml:"numa"` - Threads int `yaml:"threads"` - Debug bool `yaml:"debug"` - Roles map[string]string `yaml:"roles"` - Embeddings bool `yaml:"embeddings"` - Backend string `yaml:"backend"` - TemplateConfig TemplateConfig `yaml:"template"` - MirostatETA float64 `yaml:"mirostat_eta"` - MirostatTAU float64 `yaml:"mirostat_tau"` - Mirostat int `yaml:"mirostat"` - NGPULayers int `yaml:"gpu_layers"` - MMap bool `yaml:"mmap"` - MMlock bool `yaml:"mmlock"` - LowVRAM bool `yaml:"low_vram"` - - TensorSplit string `yaml:"tensor_split"` - MainGPU string `yaml:"main_gpu"` - ImageGenerationAssets string `yaml:"asset_dir"` - - PromptCachePath string `yaml:"prompt_cache_path"` - 
PromptCacheAll bool `yaml:"prompt_cache_all"` - PromptCacheRO bool `yaml:"prompt_cache_ro"` - - Grammar string `yaml:"grammar"` - - FunctionsConfig Functions `yaml:"function"` - - PromptStrings, InputStrings []string - InputToken [][]int - functionCallString, functionCallNameString string -} - -type Functions struct { - DisableNoAction bool `yaml:"disable_no_action"` - NoActionFunctionName string `yaml:"no_action_function_name"` - NoActionDescriptionName string `yaml:"no_action_description_name"` -} - -type TemplateConfig struct { - Completion string `yaml:"completion"` - Functions string `yaml:"function"` - Chat string `yaml:"chat"` - Edit string `yaml:"edit"` -} - -type ConfigMerger struct { - configs map[string]Config - sync.Mutex -} - -func defaultConfig(modelFile string) *Config { - return &Config{ - OpenAIRequest: defaultRequest(modelFile), - } -} - -func NewConfigMerger() *ConfigMerger { - return &ConfigMerger{ - configs: make(map[string]Config), - } -} -func ReadConfigFile(file string) ([]*Config, error) { - c := &[]*Config{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - return *c, nil -} - -func ReadConfig(file string) (*Config, error) { - c := &Config{} - f, err := os.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) - } - if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) - } - - return c, nil -} - -func (cm *ConfigMerger) LoadConfigFile(file string) error { - cm.Lock() - defer cm.Unlock() - c, err := ReadConfigFile(file) - if err != nil { - return fmt.Errorf("cannot load config file: %w", err) - } - - for _, cc := range c { - cm.configs[cc.Name] = *cc - } - return nil -} - -func (cm *ConfigMerger) LoadConfig(file string) error { - cm.Lock() - defer cm.Unlock() - c, 
err := ReadConfig(file) - if err != nil { - return fmt.Errorf("cannot read config file: %w", err) - } - - cm.configs[c.Name] = *c - return nil -} - -func (cm *ConfigMerger) GetConfig(m string) (Config, bool) { - cm.Lock() - defer cm.Unlock() - v, exists := cm.configs[m] - return v, exists -} - -func (cm *ConfigMerger) ListConfigs() []string { - cm.Lock() - defer cm.Unlock() - var res []string - for k := range cm.configs { - res = append(res, k) - } - return res -} - -func (cm *ConfigMerger) LoadConfigs(path string) error { - cm.Lock() - defer cm.Unlock() - entries, err := os.ReadDir(path) - if err != nil { - return err - } - files := make([]fs.FileInfo, 0, len(entries)) - for _, entry := range entries { - info, err := entry.Info() - if err != nil { - return err - } - files = append(files, info) - } - for _, file := range files { - // Skip templates, YAML and .keep files - if !strings.Contains(file.Name(), ".yaml") { - continue - } - c, err := ReadConfig(filepath.Join(path, file.Name())) - if err == nil { - cm.configs[c.Name] = *c - } - } - - return nil -} - -func updateConfig(config *Config, input *OpenAIRequest) { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != 0 { - config.TopK = input.TopK - } - if input.TopP != 0 { - config.TopP = input.TopP - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != 0 { - config.Temperature = input.Temperature - } - - if input.Maxtokens != 0 { - config.Maxtokens = input.Maxtokens - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.F16 { - 
config.F16 = input.F16 - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != 0 { - config.Seed = input.Seed - } - - if input.Mirostat != 0 { - config.Mirostat = input.Mirostat - } - - if input.MirostatETA != 0 { - config.MirostatETA = input.MirostatETA - } - - if input.MirostatTAU != 0 { - config.MirostatTAU = input.MirostatTAU - } - - if input.TypicalP != 0 { - config.TypicalP = input.TypicalP - } - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) - } - config.InputToken = append(config.InputToken, tokens) - } - } - } - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.functionCallString = fnc - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.functionCallNameString = name - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } -} -func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { - input := new(OpenAIRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, err - } - - modelFile := input.Model - - if c.Params("model") != "" { - modelFile = c.Params("model") - } - - received, _ := json.Marshal(input) - - log.Debug().Msgf("Request received: %s", string(received)) - - // Set model from bearer 
token, if available - bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) - - // If no model was specified, take the first available - if modelFile == "" && !bearerExists && randomModel { - models, _ := loader.ListModels() - if len(models) > 0 { - modelFile = models[0] - log.Debug().Msgf("No model specified, using: %s", modelFile) - } else { - log.Debug().Msgf("No model specified, returning error") - return "", nil, fmt.Errorf("no model specified") - } - } - - // If a model is found in bearer token takes precedence - if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) - modelFile = bearer - } - return modelFile, input, nil -} - -func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { - // Load a config file if present after the model name - modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") - - var config *Config - - defaults := func() { - config = defaultConfig(modelFile) - config.ContextSize = ctx - config.Threads = threads - config.F16 = f16 - config.Debug = debug - } - - cfg, exists := cm.GetConfig(modelFile) - if !exists { - if _, err := os.Stat(modelConfig); err == nil { - if err := cm.LoadConfig(modelConfig); err != nil { - return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - cfg, exists = cm.GetConfig(modelFile) - if exists { - config = &cfg - } else { - defaults() - } - } else { - defaults() - } - } else { - config = &cfg - } - - // Set the parameters for the language model prediction - updateConfig(config, input) - - // Don't allow 0 as setting - if config.Threads == 0 { - if threads != 0 { - config.Threads = threads - } else { - config.Threads = 4 - } - } - - // Enforce debug flag if passed from CLI - if debug { - config.Debug = true - } - - return config, input, nil -} 
// Config is the per-model configuration as read from a YAML file.
// The embedded PredictionOptions are read from the `parameters` key; every
// other tagged field maps to the YAML key named in its tag.
type Config struct {
	PredictionOptions `yaml:"parameters"`
	Name              string            `yaml:"name"`
	StopWords         []string          `yaml:"stopwords"`
	Cutstrings        []string          `yaml:"cutstrings"`
	TrimSpace         []string          `yaml:"trimspace"`
	ContextSize       int               `yaml:"context_size"`
	F16               bool              `yaml:"f16"`
	NUMA              bool              `yaml:"numa"`
	Threads           int               `yaml:"threads"`
	Debug             bool              `yaml:"debug"`
	Roles             map[string]string `yaml:"roles"`
	Embeddings        bool              `yaml:"embeddings"`
	Backend           string            `yaml:"backend"`
	TemplateConfig    TemplateConfig    `yaml:"template"`
	MirostatETA       float64           `yaml:"mirostat_eta"`
	MirostatTAU       float64           `yaml:"mirostat_tau"`
	Mirostat          int               `yaml:"mirostat"`
	NGPULayers        int               `yaml:"gpu_layers"`
	MMap              bool              `yaml:"mmap"`
	MMlock            bool              `yaml:"mmlock"`
	LowVRAM           bool              `yaml:"low_vram"`

	TensorSplit           string `yaml:"tensor_split"`
	MainGPU               string `yaml:"main_gpu"`
	ImageGenerationAssets string `yaml:"asset_dir"`

	// Prompt cache settings.
	PromptCachePath string `yaml:"prompt_cache_path"`
	PromptCacheAll  bool   `yaml:"prompt_cache_all"`
	PromptCacheRO   bool   `yaml:"prompt_cache_ro"`

	Grammar string `yaml:"grammar"`

	// NOTE(review): these carry no yaml tag; presumably they are filled from
	// the incoming request rather than from the config file — confirm.
	PromptStrings, InputStrings                []string
	InputToken                                 [][]int
	// Unexported function-call state; written through SetFunctionCallString
	// and SetFunctionCallNameString.
	functionCallString, functionCallNameString string

	FunctionsConfig Functions `yaml:"function"`
}
*Config) SetFunctionCallString(s string) { + c.functionCallString = s +} + +func (c *Config) SetFunctionCallNameString(s string) { + c.functionCallNameString = s +} + +func (c *Config) ShouldUseFunctions() bool { + return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) +} + +func (c *Config) ShouldCallSpecificFunction() bool { + return len(c.functionCallNameString) > 0 +} + +func (c *Config) FunctionToCall() string { + return c.functionCallNameString +} + +func defaultPredictOptions(modelFile string) PredictionOptions { + return PredictionOptions{ + TopP: 0.7, + TopK: 80, + Maxtokens: 512, + Temperature: 0.9, + Model: modelFile, + } +} + +func DefaultConfig(modelFile string) *Config { + return &Config{ + PredictionOptions: defaultPredictOptions(modelFile), + } +} + +func NewConfigLoader() *ConfigLoader { + return &ConfigLoader{ + configs: make(map[string]Config), + } +} +func ReadConfigFile(file string) ([]*Config, error) { + c := &[]*Config{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + return *c, nil +} + +func ReadConfig(file string) (*Config, error) { + c := &Config{} + f, err := os.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("cannot read config file: %w", err) + } + if err := yaml.Unmarshal(f, c); err != nil { + return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + } + + return c, nil +} + +func (cm *ConfigLoader) LoadConfigFile(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadConfigFile(file) + if err != nil { + return fmt.Errorf("cannot load config file: %w", err) + } + + for _, cc := range c { + cm.configs[cc.Name] = *cc + } + return nil +} + +func (cm *ConfigLoader) LoadConfig(file string) error { + cm.Lock() + defer cm.Unlock() + c, err := ReadConfig(file) + if err != nil { 
+ return fmt.Errorf("cannot read config file: %w", err) + } + + cm.configs[c.Name] = *c + return nil +} + +func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { + cm.Lock() + defer cm.Unlock() + v, exists := cm.configs[m] + return v, exists +} + +func (cm *ConfigLoader) ListConfigs() []string { + cm.Lock() + defer cm.Unlock() + var res []string + for k := range cm.configs { + res = append(res, k) + } + return res +} + +func (cm *ConfigLoader) LoadConfigs(path string) error { + cm.Lock() + defer cm.Unlock() + entries, err := os.ReadDir(path) + if err != nil { + return err + } + files := make([]fs.FileInfo, 0, len(entries)) + for _, entry := range entries { + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, info) + } + for _, file := range files { + // Skip templates, YAML and .keep files + if !strings.Contains(file.Name(), ".yaml") { + continue + } + c, err := ReadConfig(filepath.Join(path, file.Name())) + if err == nil { + cm.configs[c.Name] = *c + } + } + + return nil +} diff --git a/api/config_test.go b/api/config/config_test.go similarity index 62% rename from api/config_test.go rename to api/config/config_test.go index 626b90b..4b00d58 100644 --- a/api/config_test.go +++ b/api/config/config_test.go @@ -1,8 +1,10 @@ -package api +package api_config_test import ( "os" + . "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" @@ -26,29 +28,29 @@ var _ = Describe("Test cases for config related functions", func() { }) It("Test LoadConfigs", func() { - cm := NewConfigMerger() - options := newOptions() + cm := NewConfigLoader() + opts := options.NewOptions() modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH")) - WithModelLoader(modelLoader)(options) + options.WithModelLoader(modelLoader)(opts) - err := cm.LoadConfigs(options.loader.ModelPath) + err := cm.LoadConfigs(opts.Loader.ModelPath) Expect(err).To(BeNil()) - Expect(cm.configs).ToNot(BeNil()) + Expect(cm.ListConfigs()).ToNot(BeNil()) // config should includes gpt4all models's api.config - Expect(cm.configs).To(HaveKey("gpt4all")) + Expect(cm.ListConfigs()).To(ContainElements("gpt4all")) // config should includes gpt2 models's api.config - Expect(cm.configs).To(HaveKey("gpt4all-2")) + Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2")) // config should includes text-embedding-ada-002 models's api.config - Expect(cm.configs).To(HaveKey("text-embedding-ada-002")) + Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002")) // config should includes rwkv_test models's api.config - Expect(cm.configs).To(HaveKey("rwkv_test")) + Expect(cm.ListConfigs()).To(ContainElements("rwkv_test")) // config should includes whisper-1 models's api.config - Expect(cm.configs).To(HaveKey("whisper-1")) + Expect(cm.ListConfigs()).To(ContainElements("whisper-1")) }) }) }) diff --git a/api/config/prediction.go b/api/config/prediction.go new file mode 100644 index 0000000..59f4fcb --- /dev/null +++ b/api/config/prediction.go @@ -0,0 +1,37 @@ +package api_config + +type PredictionOptions struct { + + // Also part of the OpenAI official spec + Model string `json:"model" yaml:"model"` + + // Also part of the OpenAI official spec + Language string `json:"language"` + + // Also part of the OpenAI official spec. 
use it for returning multiple results + N int `json:"n"` + + // Common options between all the API calls, part of the OpenAI spec + TopP float64 `json:"top_p" yaml:"top_p"` + TopK int `json:"top_k" yaml:"top_k"` + Temperature float64 `json:"temperature" yaml:"temperature"` + Maxtokens int `json:"max_tokens" yaml:"max_tokens"` + Echo bool `json:"echo"` + + // Custom parameters - not present in the OpenAI API + Batch int `json:"batch" yaml:"batch"` + F16 bool `json:"f16" yaml:"f16"` + IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` + RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` + Keep int `json:"n_keep" yaml:"n_keep"` + + MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` + MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` + Mirostat int `json:"mirostat" yaml:"mirostat"` + + FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` + TFZ float64 `json:"tfz" yaml:"tfz"` + + TypicalP float64 `json:"typical_p" yaml:"typical_p"` + Seed int `json:"seed" yaml:"seed"` +} diff --git a/api/localai.go b/api/localai.go deleted file mode 100644 index b719689..0000000 --- a/api/localai.go +++ /dev/null @@ -1,78 +0,0 @@ -package api - -import ( - "fmt" - "os" - "path/filepath" - - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/tts" - "github.com/go-skynet/LocalAI/pkg/utils" - llama "github.com/go-skynet/go-llama.cpp" - "github.com/gofiber/fiber/v2" -) - -type TTSRequest struct { - Model string `json:"model" yaml:"model"` - Input string `json:"input" yaml:"input"` -} - -func generateUniqueFileName(dir, baseName, ext string) string { - counter := 1 - fileName := baseName + ext - - for { - filePath := filepath.Join(dir, fileName) - _, err := os.Stat(filePath) - if os.IsNotExist(err) { - return fileName - } - - counter++ - fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) - } -} - -func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c 
*fiber.Ctx) error { - - input := new(TTSRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - - piperModel, err := o.loader.BackendLoader(model.PiperBackend, input.Model, []llama.ModelOption{}, uint32(0), o.assetsDestination) - if err != nil { - return err - } - - if piperModel == nil { - return fmt.Errorf("could not load piper model") - } - - w, ok := piperModel.(*tts.Piper) - if !ok { - return fmt.Errorf("loader returned non-piper object %+v", w) - } - - if err := os.MkdirAll(o.audioDir, 0755); err != nil { - return err - } - - fileName := generateUniqueFileName(o.audioDir, "piper", ".wav") - filePath := filepath.Join(o.audioDir, fileName) - - modelPath := filepath.Join(o.loader.ModelPath, input.Model) - - if err := utils.VerifyPath(modelPath, o.loader.ModelPath); err != nil { - return err - } - - if err := w.TTS(input.Input, modelPath, filePath); err != nil { - return err - } - - return c.Download(filePath) - } -} diff --git a/api/gallery.go b/api/localai/gallery.go similarity index 86% rename from api/gallery.go rename to api/localai/gallery.go index 1c0cec9..feae294 100644 --- a/api/gallery.go +++ b/api/localai/gallery.go @@ -1,4 +1,4 @@ -package api +package localai import ( "context" @@ -9,6 +9,7 @@ import ( json "github.com/json-iterator/go" + config "github.com/go-skynet/LocalAI/api/config" "github.com/go-skynet/LocalAI/pkg/gallery" "github.com/gofiber/fiber/v2" "github.com/google/uuid" @@ -38,7 +39,7 @@ type galleryApplier struct { statuses map[string]*galleryOpStatus } -func newGalleryApplier(modelPath string) *galleryApplier { +func NewGalleryService(modelPath string) *galleryApplier { return &galleryApplier{ modelPath: modelPath, C: make(chan galleryOp), @@ -47,7 +48,7 @@ func newGalleryApplier(modelPath string) *galleryApplier { } // prepareModel applies a -func prepareModel(modelPath string, req gallery.GalleryModel, cm *ConfigMerger, downloadStatus func(string, string, string, float64)) 
error { +func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error { config, err := gallery.GetGalleryConfigFromURL(req.URL) if err != nil { @@ -72,7 +73,7 @@ func (g *galleryApplier) getStatus(s string) *galleryOpStatus { return g.statuses[s] } -func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { +func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { go func() { for { select { @@ -148,7 +149,7 @@ type galleryModel struct { ID string `json:"id"` } -func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error { +func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { dat, err := os.ReadFile(s) if err != nil { return err @@ -156,7 +157,7 @@ func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger, galleries []gal return ApplyGalleryFromString(modelPath, string(dat), cm, galleries) } -func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []gallery.Gallery) error { +func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { var requests []galleryModel err := json.Unmarshal([]byte(s), &requests) if err != nil { @@ -174,7 +175,9 @@ func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger, galleries []g return err } -func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { +/// Endpoints + +func GetOpStatusEndpoint(g *galleryApplier) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { status := g.getStatus(c.Params("uuid")) @@ -191,7 +194,7 @@ type GalleryModel struct { gallery.GalleryModel } -func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error { +func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan galleryOp, galleries []gallery.Gallery) func(c 
*fiber.Ctx) error { return func(c *fiber.Ctx) error { input := new(GalleryModel) // Get input data from the request body @@ -216,7 +219,7 @@ func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp, gal } } -func listModelFromGallery(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error { +func ListModelFromGalleryEndpoint(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { log.Debug().Msgf("Listing models from galleries: %+v", galleries) diff --git a/api/localai/localai.go b/api/localai/localai.go new file mode 100644 index 0000000..7c57c92 --- /dev/null +++ b/api/localai/localai.go @@ -0,0 +1,84 @@ +package localai + +import ( + "context" + "fmt" + "os" + "path/filepath" + + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grpc/proto" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/go-skynet/LocalAI/pkg/utils" + "github.com/gofiber/fiber/v2" +) + +type TTSRequest struct { + Model string `json:"model" yaml:"model"` + Input string `json:"input" yaml:"input"` +} + +func generateUniqueFileName(dir, baseName, ext string) string { + counter := 1 + fileName := baseName + ext + + for { + filePath := filepath.Join(dir, fileName) + _, err := os.Stat(filePath) + if os.IsNotExist(err) { + return fileName + } + + counter++ + fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) + } +} + +func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + + input := new(TTSRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return err + } + + piperModel, err := o.Loader.BackendLoader( + model.WithBackendString(model.PiperBackend), + model.WithModelFile(input.Model), + model.WithContext(o.Context), + model.WithAssetDir(o.AssetsDestination)) + if err != nil { + return err + } + + if 
piperModel == nil { + return fmt.Errorf("could not load piper model") + } + + if err := os.MkdirAll(o.AudioDir, 0755); err != nil { + return fmt.Errorf("failed creating audio directory: %s", err) + } + + fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") + filePath := filepath.Join(o.AudioDir, fileName) + + modelPath := filepath.Join(o.Loader.ModelPath, input.Model) + + if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { + return err + } + + if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ + Text: input.Input, + Model: modelPath, + Dst: filePath, + }); err != nil { + return err + } + + return c.Download(filePath) + } +} diff --git a/api/openai.go b/api/openai.go deleted file mode 100644 index 77d2c8e..0000000 --- a/api/openai.go +++ /dev/null @@ -1,961 +0,0 @@ -package api - -import ( - "bufio" - "bytes" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "io/ioutil" - "net/http" - "os" - "path" - "path/filepath" - "strconv" - "strings" - - "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - "github.com/go-skynet/LocalAI/pkg/grammar" - model "github.com/go-skynet/LocalAI/pkg/model" - whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" - llama "github.com/go-skynet/go-llama.cpp" - "github.com/gofiber/fiber/v2" - "github.com/rs/zerolog/log" - "github.com/valyala/fasthttp" -) - -// APIError provides error information returned by the OpenAI API. 
-type APIError struct { - Code any `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` -} - -type ErrorResponse struct { - Error *APIError `json:"error,omitempty"` -} - -type OpenAIUsage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type Item struct { - Embedding []float32 `json:"embedding"` - Index int `json:"index"` - Object string `json:"object,omitempty"` - - // Images - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` -} - -type OpenAIResponse struct { - Created int `json:"created,omitempty"` - Object string `json:"object,omitempty"` - ID string `json:"id,omitempty"` - Model string `json:"model,omitempty"` - Choices []Choice `json:"choices,omitempty"` - Data []Item `json:"data,omitempty"` - - Usage OpenAIUsage `json:"usage"` -} - -type Choice struct { - Index int `json:"index,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - Message *Message `json:"message,omitempty"` - Delta *Message `json:"delta,omitempty"` - Text string `json:"text,omitempty"` -} - -type Message struct { - // The message role - Role string `json:"role,omitempty" yaml:"role"` - // The message content - Content *string `json:"content" yaml:"content"` - // A result of a function call - FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` -} - -type OpenAIModel struct { - ID string `json:"id"` - Object string `json:"object"` -} - -type OpenAIRequest struct { - Model string `json:"model" yaml:"model"` - - // whisper - File string `json:"file" validate:"required"` - Language string `json:"language"` - //whisper/image - ResponseFormat string `json:"response_format"` - // image - Size string `json:"size"` - // Prompt is read only by completion/image API calls - Prompt interface{} `json:"prompt" yaml:"prompt"` - - // Edit endpoint - 
Instruction string `json:"instruction" yaml:"instruction"` - Input interface{} `json:"input" yaml:"input"` - - Stop interface{} `json:"stop" yaml:"stop"` - - // Messages is read only by chat/completion API calls - Messages []Message `json:"messages" yaml:"messages"` - - // A list of available functions to call - Functions []grammar.Function `json:"functions" yaml:"functions"` - FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object - - Stream bool `json:"stream"` - Echo bool `json:"echo"` - // Common options between all the API calls - TopP float64 `json:"top_p" yaml:"top_p"` - TopK int `json:"top_k" yaml:"top_k"` - Temperature float64 `json:"temperature" yaml:"temperature"` - Maxtokens int `json:"max_tokens" yaml:"max_tokens"` - - N int `json:"n"` - - // Custom parameters - not present in the OpenAI API - Batch int `json:"batch" yaml:"batch"` - F16 bool `json:"f16" yaml:"f16"` - IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` - RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` - Keep int `json:"n_keep" yaml:"n_keep"` - - MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` - MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` - Mirostat int `json:"mirostat" yaml:"mirostat"` - - FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` - TFZ float64 `json:"tfz" yaml:"tfz"` - - Seed int `json:"seed" yaml:"seed"` - - // Image (not supported by OpenAI) - Mode int `json:"mode"` - Step int `json:"step"` - - // A grammar to constrain the LLM output - Grammar string `json:"grammar" yaml:"grammar"` - // A grammar object - JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` - - TypicalP float64 `json:"typical_p" yaml:"typical_p"` -} - -func defaultRequest(modelFile string) OpenAIRequest { - return OpenAIRequest{ - TopP: 0.7, - TopK: 80, - Maxtokens: 512, - Temperature: 0.9, - Model: modelFile, - } -} 
- -// https://platform.openai.com/docs/api-reference/completions -func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { - resp := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - Index: 0, - Text: s, - }, - }, - Object: "text_completion", - } - log.Debug().Msgf("Sending goroutine: %s", s) - - responses <- resp - return true - }) - close(responses) - } - - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("`input`: %+v", input) - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - if input.Stream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - //c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } - - templateFile := config.Model - - if config.TemplateConfig.Completion != "" { - templateFile = config.TemplateConfig.Completion - } - - if input.Stream { - if len(config.PromptStrings) > 1 { - return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") - } - - predInput := config.PromptStrings[0] - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input 
string - }{ - Input: predInput, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } - - responses := make(chan OpenAIResponse) - - go process(predInput, input, config, o.loader, responses) - - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - - log.Debug().Msgf("Sending chunk: %s", buf.String()) - fmt.Fprintf(w, "data: %v\n", buf.String()) - w.Flush() - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - Index: 0, - FinishReason: "stop", - }, - }, - Object: "text_completion", - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - var result []Choice - for _, i := range config.PromptStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - }{ - Input: i, - }) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - - r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) - }, nil) - if err != nil { - return err - } - - result = append(result, r...) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: result, - Object: "text_completion", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/embeddings -func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - items := []Item{} - - for i, s := range config.InputToken { - // get the model function to call for the result - embedFn, err := ModelEmbedding("", s, o.loader, *config, o) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - for i, s := range config.InputStrings { - // get the model function to call for the result - embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o) - if err != nil { - return err - } - - embeddings, err := embedFn() - if err != nil { - return err - } - items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Data: items, - Object: "list", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - - process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { - initialMessage := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, - Object: "chat.completion.chunk", - } - responses <- initialMessage - - ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { - resp := OpenAIResponse{ - Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, - Object: "chat.completion.chunk", - } - log.Debug().Msgf("Sending goroutine: %s", s) - - responses <- resp - return true - }) - close(responses) - } - return func(c *fiber.Ctx) error { - processFunctions := false - funcs := grammar.Functions{} - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - log.Debug().Msgf("Configuration read: %+v", config) - - // Allow the user to set custom actions via config file - // to be "embedded" in each model - noActionName := "answer" - noActionDescription := "use this action to answer without performing any action" - - if config.FunctionsConfig.NoActionFunctionName != "" { - noActionName = config.FunctionsConfig.NoActionFunctionName - } - if config.FunctionsConfig.NoActionDescriptionName != "" { - noActionDescription = 
config.FunctionsConfig.NoActionDescriptionName - } - - // process functions if we have any defined or if we have a function call string - if len(input.Functions) > 0 && - ((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) { - log.Debug().Msgf("Response needs to process functions") - - processFunctions = true - - noActionGrammar := grammar.Function{ - Name: noActionName, - Description: noActionDescription, - Parameters: map[string]interface{}{ - "properties": map[string]interface{}{ - "message": map[string]interface{}{ - "type": "string", - "description": "The message to reply the user with", - }}, - }, - } - - // Append the no action function - funcs = append(funcs, input.Functions...) - if !config.FunctionsConfig.DisableNoAction { - funcs = append(funcs, noActionGrammar) - } - - // Force picking one of the functions by the request - if config.functionCallNameString != "" { - funcs = funcs.Select(config.functionCallNameString) - } - - // Update input grammar - jsStruct := funcs.ToJSONStructure() - config.Grammar = jsStruct.Grammar("") - } else if input.JSONFunctionGrammarObject != nil { - config.Grammar = input.JSONFunctionGrammarObject.Grammar("") - } - - // functions are not supported in stream mode (yet?) 
- toStream := input.Stream && !processFunctions - - log.Debug().Msgf("Parameters: %+v", config) - - var predInput string - - mess := []string{} - for _, i := range input.Messages { - var content string - role := i.Role - // if function call, we might want to customize the role so we can display better that the "assistant called a json action" - // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request - if i.FunctionCall != nil && i.Role == "assistant" { - roleFn := "assistant_function_call" - r := config.Roles[roleFn] - if r != "" { - role = roleFn - } - } - r := config.Roles[role] - contentExists := i.Content != nil && *i.Content != "" - if r != "" { - if contentExists { - content = fmt.Sprint(r, " ", *i.Content) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + fmt.Sprint(r, " ", string(j)) - } else { - content = fmt.Sprint(r, " ", string(j)) - } - } - } - } else { - if contentExists { - content = fmt.Sprint(*i.Content) - } - if i.FunctionCall != nil { - j, err := json.Marshal(i.FunctionCall) - if err == nil { - if contentExists { - content += "\n" + string(j) - } else { - content = string(j) - } - } - } - } - - mess = append(mess, content) - } - - predInput = strings.Join(mess, "\n") - log.Debug().Msgf("Prompt (before templating): %s", predInput) - - if toStream { - log.Debug().Msgf("Stream request received") - c.Context().SetContentType("text/event-stream") - //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) - // c.Set("Content-Type", "text/event-stream") - c.Set("Cache-Control", "no-cache") - c.Set("Connection", "keep-alive") - c.Set("Transfer-Encoding", "chunked") - } - - templateFile := config.Model - - if config.TemplateConfig.Chat != "" && !processFunctions { - templateFile = config.TemplateConfig.Chat - } - - if config.TemplateConfig.Functions != "" && processFunctions { - templateFile = 
config.TemplateConfig.Functions - } - - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - Functions []grammar.Function - }{ - Input: predInput, - Functions: funcs, - }) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } else { - log.Debug().Msgf("Template failed loading: %s", err.Error()) - } - - log.Debug().Msgf("Prompt (after templating): %s", predInput) - if processFunctions { - log.Debug().Msgf("Grammar: %+v", config.Grammar) - } - - if toStream { - responses := make(chan OpenAIResponse) - - go process(predInput, input, config, o.loader, responses) - - c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { - - for ev := range responses { - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.Encode(ev) - - log.Debug().Msgf("Sending chunk: %s", buf.String()) - fmt.Fprintf(w, "data: %v\n", buf.String()) - w.Flush() - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{ - { - FinishReason: "stop", - Index: 0, - Delta: &Message{}, - }}, - Object: "chat.completion.chunk", - } - respData, _ := json.Marshal(resp) - - w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) - w.WriteString("data: [DONE]\n\n") - w.Flush() - })) - return nil - } - - result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) { - if processFunctions { - // As we have to change the result before processing, we can't stream the answer (yet?) 
- ss := map[string]interface{}{} - json.Unmarshal([]byte(s), &ss) - log.Debug().Msgf("Function return: %s %+v", s, ss) - - // The grammar defines the function name as "function", while OpenAI returns "name" - func_name := ss["function"] - // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object - args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) - d, _ := json.Marshal(args) - - ss["arguments"] = string(d) - ss["name"] = func_name - - // if do nothing, reply with a message - if func_name == noActionName { - log.Debug().Msgf("nothing to do, computing a reply") - - // If there is a message that the LLM already sends as part of the JSON reply, use it - arguments := map[string]interface{}{} - json.Unmarshal([]byte(d), &arguments) - m, exists := arguments["message"] - if exists { - switch message := m.(type) { - case string: - if message != "" { - log.Debug().Msgf("Reply received from LLM: %s", message) - message = Finetune(*config, predInput, message) - log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) - return - } - } - } - - log.Debug().Msgf("No action received from LLM, without a message, computing a reply") - // Otherwise ask the LLM to understand the JSON output and the context, and return a message - // Note: This costs (in term of CPU) another computation - config.Grammar = "" - predFunc, err := ModelInference(predInput, o.loader, *config, o, nil) - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - prediction, err := predFunc() - if err != nil { - log.Error().Msgf("inference error: %s", err.Error()) - return - } - - prediction = Finetune(*config, predInput, prediction) - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) - } else { - // otherwise reply with the function call - 
*c = append(*c, Choice{ - FinishReason: "function_call", - Message: &Message{Role: "assistant", FunctionCall: ss}, - }) - } - - return - } - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) - }, nil) - if err != nil { - return err - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: result, - Object: "chat.completion", - } - respData, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", respData) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - model, input, err := readInput(c, o.loader, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - templateFile := config.Model - - if config.TemplateConfig.Edit != "" { - templateFile = config.TemplateConfig.Edit - } - - var result []Choice - for _, i := range config.InputStrings { - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - Instruction string - }{Input: i}) - if err == nil { - i = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", i) - } - - r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) - }, nil) - if err != nil { - return err - } - - result = append(result, r...) - } - - resp := &OpenAIResponse{ - Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
- Choices: result, - Object: "edit", - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/images/create - -/* -* - - curl http://localhost:8080/v1/images/generations \ - -H "Content-Type: application/json" \ - -d '{ - "prompt": "A cute baby sea otter", - "n": 1, - "size": "512x512" - }' - -* -*/ -func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - m, input, err := readInput(c, o.loader, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - if m == "" { - m = model.StableDiffusionBackend - } - log.Debug().Msgf("Loading model: %+v", m) - - config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - log.Debug().Msgf("Parameter Config: %+v", config) - - // XXX: Only stablediffusion is supported for now - if config.Backend == "" { - config.Backend = model.StableDiffusionBackend - } - - sizeParts := strings.Split(input.Size, "x") - if len(sizeParts) != 2 { - return fmt.Errorf("Invalid value for 'size'") - } - width, err := strconv.Atoi(sizeParts[0]) - if err != nil { - return fmt.Errorf("Invalid value for 'size'") - } - height, err := strconv.Atoi(sizeParts[1]) - if err != nil { - return fmt.Errorf("Invalid value for 'size'") - } - - b64JSON := false - if input.ResponseFormat == "b64_json" { - b64JSON = true - } - - var result []Item - for _, i := range config.PromptStrings { - n := input.N - if input.N == 0 { - n = 1 - } - for j := 0; j < n; j++ { - prompts := strings.Split(i, "|") - positive_prompt := prompts[0] - negative_prompt := "" - if len(prompts) > 1 { - negative_prompt = prompts[1] - } - - mode := 0 - step := 15 - - if input.Mode != 0 { - mode = input.Mode - } - - if 
input.Step != 0 { - step = input.Step - } - - tempDir := "" - if !b64JSON { - tempDir = o.imageDir - } - // Create a temporary file - outputFile, err := ioutil.TempFile(tempDir, "b64") - if err != nil { - return err - } - outputFile.Close() - output := outputFile.Name() + ".png" - // Rename the temporary file - err = os.Rename(outputFile.Name(), output) - if err != nil { - return err - } - - baseURL := c.BaseURL() - - fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o) - if err != nil { - return err - } - if err := fn(); err != nil { - return err - } - - item := &Item{} - - if b64JSON { - defer os.RemoveAll(output) - data, err := os.ReadFile(output) - if err != nil { - return err - } - item.B64JSON = base64.StdEncoding.EncodeToString(data) - } else { - base := filepath.Base(output) - item.URL = baseURL + "/generated-images/" + base - } - - result = append(result, *item) - } - } - - resp := &OpenAIResponse{ - Data: result, - } - - jsonResult, _ := json.Marshal(resp) - log.Debug().Msgf("Response: %s", jsonResult) - - // Return the prediction in the response body - return c.JSON(resp) - } -} - -// https://platform.openai.com/docs/api-reference/audio/create -func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - m, input, err := readInput(c, o.loader, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - // retrieve the file data from the request - file, err := c.FormFile("file") - if err != nil { - return err - } - f, err := file.Open() - if err != nil { - return err - } - defer f.Close() - - dir, err := os.MkdirTemp("", "whisper") - - if err != nil { - return err - } - defer os.RemoveAll(dir) - 
- dst := filepath.Join(dir, path.Base(file.Filename)) - dstFile, err := os.Create(dst) - if err != nil { - return err - } - - if _, err := io.Copy(dstFile, f); err != nil { - log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) - return err - } - - log.Debug().Msgf("Audio file copied to: %+v", dst) - - whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads), o.assetsDestination) - if err != nil { - return err - } - - if whisperModel == nil { - return fmt.Errorf("could not load whisper model") - } - - w, ok := whisperModel.(whisper.Model) - if !ok { - return fmt.Errorf("loader returned non-whisper object") - } - - tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) - if err != nil { - return err - } - - log.Debug().Msgf("Trascribed: %+v", tr) - // TODO: handle different outputs here - return c.Status(http.StatusOK).JSON(tr) - } -} - -func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - models, err := loader.ListModels() - if err != nil { - return err - } - var mm map[string]interface{} = map[string]interface{}{} - - dataModels := []OpenAIModel{} - for _, m := range models { - mm[m] = nil - dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) - } - - for _, k := range cm.ListConfigs() { - if _, exists := mm[k]; !exists { - dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) - } - } - - return c.JSON(struct { - Object string `json:"object"` - Data []OpenAIModel `json:"data"` - }{ - Object: "list", - Data: dataModels, - }) - } -} diff --git a/api/openai/api.go b/api/openai/api.go new file mode 100644 index 0000000..6d7ce5e --- /dev/null +++ b/api/openai/api.go @@ -0,0 +1,105 @@ +package openai + +import ( + config "github.com/go-skynet/LocalAI/api/config" + + "github.com/go-skynet/LocalAI/pkg/grammar" +) + +// APIError provides 
error information returned by the OpenAI API. +type APIError struct { + Code any `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` +} + +type ErrorResponse struct { + Error *APIError `json:"error,omitempty"` +} + +type OpenAIUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Item struct { + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + Object string `json:"object,omitempty"` + + // Images + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` +} + +type OpenAIResponse struct { + Created int `json:"created,omitempty"` + Object string `json:"object,omitempty"` + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Choices []Choice `json:"choices,omitempty"` + Data []Item `json:"data,omitempty"` + + Usage OpenAIUsage `json:"usage"` +} + +type Choice struct { + Index int `json:"index,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Message *Message `json:"message,omitempty"` + Delta *Message `json:"delta,omitempty"` + Text string `json:"text,omitempty"` +} + +type Message struct { + // The message role + Role string `json:"role,omitempty" yaml:"role"` + // The message content + Content *string `json:"content" yaml:"content"` + // A result of a function call + FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` +} + +type OpenAIModel struct { + ID string `json:"id"` + Object string `json:"object"` +} + +type OpenAIRequest struct { + config.PredictionOptions + + // whisper + File string `json:"file" validate:"required"` + //whisper/image + ResponseFormat string `json:"response_format"` + // image + Size string `json:"size"` + // Prompt is read only by completion/image API calls + Prompt interface{} `json:"prompt" yaml:"prompt"` + + // Edit endpoint + Instruction 
string `json:"instruction" yaml:"instruction"` + Input interface{} `json:"input" yaml:"input"` + + Stop interface{} `json:"stop" yaml:"stop"` + + // Messages is read only by chat/completion API calls + Messages []Message `json:"messages" yaml:"messages"` + + // A list of available functions to call + Functions []grammar.Function `json:"functions" yaml:"functions"` + FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object + + Stream bool `json:"stream"` + + // Image (not supported by OpenAI) + Mode int `json:"mode"` + Step int `json:"step"` + + // A grammar to constrain the LLM output + Grammar string `json:"grammar" yaml:"grammar"` + + JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` +} diff --git a/api/openai/chat.go b/api/openai/chat.go new file mode 100644 index 0000000..30f6e01 --- /dev/null +++ b/api/openai/chat.go @@ -0,0 +1,320 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "strings" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/go-skynet/LocalAI/pkg/grammar" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { + initialMessage := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + + ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { + resp := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, + Object: "chat.completion.chunk", + } + + responses <- resp + return true + }) + close(responses) + } + return func(c *fiber.Ctx) error { + processFunctions := false + funcs := grammar.Functions{} + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + log.Debug().Msgf("Configuration read: %+v", config) + + // Allow the user to set custom actions via config file + // to be "embedded" in each model + noActionName := "answer" + noActionDescription := "use this action to answer without performing any action" + + if config.FunctionsConfig.NoActionFunctionName != "" { + noActionName = config.FunctionsConfig.NoActionFunctionName + } + if config.FunctionsConfig.NoActionDescriptionName != "" { + noActionDescription = config.FunctionsConfig.NoActionDescriptionName + } + + // process functions if we have any defined or if we have a function call string + if len(input.Functions) > 0 && config.ShouldUseFunctions() { + log.Debug().Msgf("Response needs to process functions") + + processFunctions = true + + noActionGrammar := grammar.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, 
+ } + + // Append the no action function + funcs = append(funcs, input.Functions...) + if !config.FunctionsConfig.DisableNoAction { + funcs = append(funcs, noActionGrammar) + } + + // Force picking one of the functions by the request + if config.FunctionToCall() != "" { + funcs = funcs.Select(config.FunctionToCall()) + } + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + config.Grammar = jsStruct.Grammar("") + } else if input.JSONFunctionGrammarObject != nil { + config.Grammar = input.JSONFunctionGrammarObject.Grammar("") + } + + // functions are not supported in stream mode (yet?) + toStream := input.Stream && !processFunctions + + log.Debug().Msgf("Parameters: %+v", config) + + var predInput string + + mess := []string{} + for _, i := range input.Messages { + var content string + role := i.Role + // if function call, we might want to customize the role so we can display better that the "assistant called a json action" + // if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request + if i.FunctionCall != nil && i.Role == "assistant" { + roleFn := "assistant_function_call" + r := config.Roles[roleFn] + if r != "" { + role = roleFn + } + } + r := config.Roles[role] + contentExists := i.Content != nil && *i.Content != "" + if r != "" { + if contentExists { + content = fmt.Sprint(r, " ", *i.Content) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + fmt.Sprint(r, " ", string(j)) + } else { + content = fmt.Sprint(r, " ", string(j)) + } + } + } + } else { + if contentExists { + content = fmt.Sprint(*i.Content) + } + if i.FunctionCall != nil { + j, err := json.Marshal(i.FunctionCall) + if err == nil { + if contentExists { + content += "\n" + string(j) + } else { + content = string(j) + } + } + } + } + + mess = append(mess, content) + } + + predInput = strings.Join(mess, "\n") + log.Debug().Msgf("Prompt (before 
templating): %s", predInput) + + if toStream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + // c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + templateFile := config.Model + + if config.TemplateConfig.Chat != "" && !processFunctions { + templateFile = config.TemplateConfig.Chat + } + + if config.TemplateConfig.Functions != "" && processFunctions { + templateFile = config.TemplateConfig.Functions + } + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + Functions []grammar.Function + }{ + Input: predInput, + Functions: funcs, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) + } + + log.Debug().Msgf("Prompt (after templating): %s", predInput) + if processFunctions { + log.Debug().Msgf("Grammar: %+v", config.Grammar) + } + + if toStream { + responses := make(chan OpenAIResponse) + + go process(predInput, input, config, o.Loader, responses) + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { + + for ev := range responses { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) + w.Flush() + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: []Choice{ + { + FinishReason: "stop", + Index: 0, + Delta: &Message{}, + }}, + Object: "chat.completion.chunk", + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + if processFunctions { + // As we have to change the result before processing, we can't stream the answer (yet?) + ss := map[string]interface{}{} + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + ss["arguments"] = string(d) + ss["name"] = func_name + + // if do nothing, reply with a message + if func_name == noActionName { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(d), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = backend.Finetune(*config, predInput, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) + return + } + } + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU) 
another computation + config.Grammar = "" + predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil) + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction, err := predFunc() + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction = backend.Finetune(*config, predInput, prediction) + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) + } else { + // otherwise reply with the function call + *c = append(*c, Choice{ + FinishReason: "function_call", + Message: &Message{Role: "assistant", FunctionCall: ss}, + }) + } + + return + } + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) + }, nil) + if err != nil { + return err + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: result, + Object: "chat.completion", + } + respData, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", respData) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/completion.go b/api/openai/completion.go new file mode 100644 index 0000000..d17fd60 --- /dev/null +++ b/api/openai/completion.go @@ -0,0 +1,159 @@ +package openai + +import ( + "bufio" + "bytes" + "encoding/json" + "errors" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" + "github.com/valyala/fasthttp" +) + +// https://platform.openai.com/docs/api-reference/completions +func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { + ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, 
func(s string) bool { + resp := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{ + { + Index: 0, + Text: s, + }, + }, + Object: "text_completion", + } + log.Debug().Msgf("Sending goroutine: %s", s) + + responses <- resp + return true + }) + close(responses) + } + + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("`input`: %+v", input) + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + if input.Stream { + log.Debug().Msgf("Stream request received") + c.Context().SetContentType("text/event-stream") + //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) + //c.Set("Content-Type", "text/event-stream") + c.Set("Cache-Control", "no-cache") + c.Set("Connection", "keep-alive") + c.Set("Transfer-Encoding", "chunked") + } + + templateFile := config.Model + + if config.TemplateConfig.Completion != "" { + templateFile = config.TemplateConfig.Completion + } + + if input.Stream { + if len(config.PromptStrings) > 1 { + return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") + } + + predInput := config.PromptStrings[0] + + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + }{ + Input: predInput, + }) + if err == nil { + predInput = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", predInput) + } + + responses := make(chan OpenAIResponse) + + go process(predInput, input, config, o.Loader, responses) + + c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { 
+ + for ev := range responses { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.Encode(ev) + + log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) + w.Flush() + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{ + { + Index: 0, + FinishReason: "stop", + }, + }, + Object: "text_completion", + } + respData, _ := json.Marshal(resp) + + w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") + w.Flush() + })) + return nil + } + + var result []Choice + for _, i := range config.PromptStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + }{ + Input: i, + }) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + + r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + *c = append(*c, Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + result = append(result, r...) + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: result, + Object: "text_completion", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/edit.go b/api/openai/edit.go new file mode 100644 index 0000000..d988d6d --- /dev/null +++ b/api/openai/edit.go @@ -0,0 +1,67 @@ +package openai + +import ( + "encoding/json" + "fmt" + + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + templateFile := config.Model + + if config.TemplateConfig.Edit != "" { + templateFile = config.TemplateConfig.Edit + } + + var result []Choice + for _, i := range config.InputStrings { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { + Input string + Instruction string + }{Input: i}) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } + + r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { + *c = append(*c, Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + result = append(result, r...) + } + + resp := &OpenAIResponse{ + Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
+ Choices: result, + Object: "edit", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/embeddings.go b/api/openai/embeddings.go new file mode 100644 index 0000000..248ae5c --- /dev/null +++ b/api/openai/embeddings.go @@ -0,0 +1,70 @@ +package openai + +import ( + "encoding/json" + "fmt" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/embeddings +func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + model, input, err := readInput(c, o.Loader, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + items := []Item{} + + for i, s := range config.InputToken { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + for i, s := range config.InputStrings { + // get the model function to call for the result + embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) + if err != nil { + return err + } + + embeddings, err := embedFn() + if err != nil { + return err + } + items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) + } + + resp := &OpenAIResponse{ + 
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. + Data: items, + Object: "list", + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/image.go b/api/openai/image.go new file mode 100644 index 0000000..bca54c1 --- /dev/null +++ b/api/openai/image.go @@ -0,0 +1,158 @@ +package openai + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/go-skynet/LocalAI/api/backend" + config "github.com/go-skynet/LocalAI/api/config" + "github.com/go-skynet/LocalAI/api/options" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +// https://platform.openai.com/docs/api-reference/images/create + +/* +* + + curl http://localhost:8080/v1/images/generations \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A cute baby sea otter", + "n": 1, + "size": "512x512" + }' + +* +*/ +func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { + return func(c *fiber.Ctx) error { + m, input, err := readInput(c, o.Loader, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + if m == "" { + m = model.StableDiffusionBackend + } + log.Debug().Msgf("Loading model: %+v", m) + + config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + + log.Debug().Msgf("Parameter Config: %+v", config) + + // XXX: Only stablediffusion is supported for now + if config.Backend == "" { + config.Backend = model.StableDiffusionBackend + } + + sizeParts := strings.Split(input.Size, "x") + if len(sizeParts) != 2 { + return fmt.Errorf("Invalid value for 'size'") + } + width, err := 
strconv.Atoi(sizeParts[0]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + height, err := strconv.Atoi(sizeParts[1]) + if err != nil { + return fmt.Errorf("Invalid value for 'size'") + } + + b64JSON := false + if input.ResponseFormat == "b64_json" { + b64JSON = true + } + + var result []Item + for _, i := range config.PromptStrings { + n := input.N + if input.N == 0 { + n = 1 + } + for j := 0; j < n; j++ { + prompts := strings.Split(i, "|") + positive_prompt := prompts[0] + negative_prompt := "" + if len(prompts) > 1 { + negative_prompt = prompts[1] + } + + mode := 0 + step := 15 + + if input.Mode != 0 { + mode = input.Mode + } + + if input.Step != 0 { + step = input.Step + } + + tempDir := "" + if !b64JSON { + tempDir = o.ImageDir + } + // Create a temporary file + outputFile, err := ioutil.TempFile(tempDir, "b64") + if err != nil { + return err + } + outputFile.Close() + output := outputFile.Name() + ".png" + // Rename the temporary file + err = os.Rename(outputFile.Name(), output) + if err != nil { + return err + } + + baseURL := c.BaseURL() + + fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o) + if err != nil { + return err + } + if err := fn(); err != nil { + return err + } + + item := &Item{} + + if b64JSON { + defer os.RemoveAll(output) + data, err := os.ReadFile(output) + if err != nil { + return err + } + item.B64JSON = base64.StdEncoding.EncodeToString(data) + } else { + base := filepath.Base(output) + item.URL = baseURL + "/generated-images/" + base + } + + result = append(result, *item) + } + } + + resp := &OpenAIResponse{ + Data: result, + } + + jsonResult, _ := json.Marshal(resp) + log.Debug().Msgf("Response: %s", jsonResult) + + // Return the prediction in the response body + return c.JSON(resp) + } +} diff --git a/api/openai/inference.go b/api/openai/inference.go new file mode 100644 index 0000000..a9991fa --- /dev/null +++ 
b/api/openai/inference.go
@@ -0,0 +1,36 @@
+package openai
+
+import (
+	"github.com/go-skynet/LocalAI/api/backend"
+	config "github.com/go-skynet/LocalAI/api/config"
+	"github.com/go-skynet/LocalAI/api/options"
+	model "github.com/go-skynet/LocalAI/pkg/model"
+)
+
+// ComputeChoices runs the model inference for predInput n times (once when n
+// is 0). Every prediction is post-processed with backend.Finetune and handed
+// to cb together with a pointer to the choices slice; whatever cb stored in
+// that slice is returned. The first failing prediction aborts the loop and is
+// returned along with the choices collected so far.
+func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
+	choices := []Choice{}
+
+	// Resolve the inference function once; it is invoked per requested choice.
+	infer, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback)
+	if err != nil {
+		return choices, err
+	}
+
+	if n == 0 {
+		n = 1
+	}
+
+	for remaining := n; remaining > 0; remaining-- {
+		text, err := infer()
+		if err != nil {
+			return choices, err
+		}
+
+		cb(backend.Finetune(*config, predInput, text), &choices)
+	}
+
+	return choices, nil
+}
diff --git a/api/openai/list.go b/api/openai/list.go
new file mode 100644
index 0000000..0cd7f3a
--- /dev/null
+++ b/api/openai/list.go
@@ -0,0 +1,37 @@
+package openai
+
+import (
+	config "github.com/go-skynet/LocalAI/api/config"
+	model "github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/gofiber/fiber/v2"
+)
+
+// ListModelsEndpoint returns the handler that lists the models reported by the
+// loader followed by the configured model names the loader did not report.
+func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+		names, err := loader.ListModels()
+		if err != nil {
+			return err
+		}
+
+		// seen prevents a configured model from being listed twice.
+		seen := make(map[string]struct{}, len(names))
+		dataModels := []OpenAIModel{}
+		for _, name := range names {
+			seen[name] = struct{}{}
+			dataModels = append(dataModels, OpenAIModel{ID: name, Object: "model"})
+		}
+
+		for _, name := range cm.ListConfigs() {
+			if _, listed := seen[name]; listed {
+				continue
+			}
+			dataModels = append(dataModels, OpenAIModel{ID: name, Object: "model"})
+		}
+
+		return c.JSON(struct {
+			Object string        `json:"object"`
+			Data   []OpenAIModel `json:"data"`
+		}{
+			Object: "list",
+
Data: dataModels,
+		})
+	}
+}
diff --git a/api/openai/request.go b/api/openai/request.go
new file mode 100644
index 0000000..84dbaa8
--- /dev/null
+++ b/api/openai/request.go
@@ -0,0 +1,234 @@
+package openai
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+
+	config "github.com/go-skynet/LocalAI/api/config"
+	model "github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/gofiber/fiber/v2"
+	"github.com/rs/zerolog/log"
+)
+
+// readInput parses the request body into an OpenAIRequest and decides which
+// model the request addresses. Precedence, highest first:
+//
+//  1. the bearer token of the "authorization" header, when it names a model
+//     that exists in the loader's model path;
+//  2. the "model" route parameter;
+//  3. the "model" field of the request body;
+//  4. the first model reported by the loader, only when randomModel is true.
+//
+// It returns the model name (possibly empty when randomModel is false), the
+// parsed request, and an error when the body cannot be parsed or when
+// randomModel is true but no model is available.
+func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
+	input := new(OpenAIRequest)
+	// Get input data from the request body
+	if err := c.BodyParser(input); err != nil {
+		return "", nil, err
+	}
+
+	modelFile := input.Model
+
+	if c.Params("model") != "" {
+		modelFile = c.Params("model")
+	}
+
+	// Marshal error ignored on purpose: the value is only used for debug logging.
+	received, _ := json.Marshal(input)
+
+	log.Debug().Msgf("Request received: %s", string(received))
+
+	// Set model from bearer token, if available.
+	// TrimPrefix removes the literal "Bearer " scheme only; TrimLeft would treat
+	// the argument as a character set and also eat leading 'B', 'e', 'a', 'r'
+	// and space characters of the token itself.
+	bearer := strings.TrimPrefix(c.Get("authorization"), "Bearer ")
+	bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
+
+	// If no model was specified, take the first available
+	if modelFile == "" && !bearerExists && randomModel {
+		models, err := loader.ListModels()
+		if err != nil {
+			// Not fatal by itself: an empty list is reported as "no model specified" below.
+			log.Debug().Msgf("Listing models failed: %s", err.Error())
+		}
+		if len(models) == 0 {
+			log.Debug().Msgf("No model specified, returning error")
+			return "", nil, fmt.Errorf("no model specified")
+		}
+		modelFile = models[0]
+		log.Debug().Msgf("No model specified, using: %s", modelFile)
+	}
+
+	// If a model is found in bearer token takes precedence
+	if bearerExists {
+		log.Debug().Msgf("Using model from bearer token: %s", bearer)
+		modelFile = bearer
+	}
+	return modelFile, input, nil
+}
+
+// updateConfig overrides fields of config with the values that are explicitly
+// set (non-zero / non-empty) in the request.
+func updateConfig(config *config.Config, input *OpenAIRequest) {
+	if input.Echo {
+		config.Echo = input.Echo
+	}
+	if input.TopK != 0 {
+		config.TopK = input.TopK
+	}
+	if input.TopP != 0 {
+		config.TopP = input.TopP
+	}
+
+	if input.Grammar != "" {
+		config.Grammar = input.Grammar
+	}
+
+	if input.Temperature != 0 {
+		config.Temperature = input.Temperature
+	}
+
+	if input.Maxtokens != 0 {
+		config.Maxtokens = input.Maxtokens
+	}
+
+	// "stop" may be a single string or a list; non-string list entries are skipped.
+	switch stop := input.Stop.(type) {
+	case string:
+		if stop != "" {
+			config.StopWords = append(config.StopWords, stop)
+		}
+	case []interface{}:
+		for _, pp := range stop {
+			if s, ok := pp.(string); ok {
+				config.StopWords = append(config.StopWords, s)
+			}
+		}
+	}
+
+	if input.RepeatPenalty != 0 {
+		config.RepeatPenalty = input.RepeatPenalty
+	}
+
+	if input.Keep != 0 {
+		config.Keep = input.Keep
+	}
+
+	if input.Batch != 0 {
+		config.Batch = input.Batch
+	}
+
+	if input.F16 {
+		config.F16 = input.F16
+	}
+
+	if input.IgnoreEOS {
+		config.IgnoreEOS = input.IgnoreEOS
+	}
+
+	if input.Seed != 0 {
+		config.Seed = input.Seed
+	}
+
+	if input.Mirostat != 0 {
+		config.Mirostat = input.Mirostat
+	}
+
+	if input.MirostatETA != 0 {
+		config.MirostatETA = input.MirostatETA
+	}
+
+	if input.MirostatTAU != 0 {
+		config.MirostatTAU = input.MirostatTAU
+	}
+
+	if input.TypicalP != 0 {
+		config.TypicalP = input.TypicalP
+	}
+
+	// "input" may be a string, a list of strings, or a list of token lists.
+	switch inputs := input.Input.(type) {
+	case string:
+		if inputs != "" {
+			config.InputStrings = append(config.InputStrings, inputs)
+		}
+	case []interface{}:
+		for _, pp := range inputs {
+			switch i := pp.(type) {
+			case string:
+				config.InputStrings = append(config.InputStrings, i)
+			case []interface{}:
+				tokens := make([]int, 0, len(i))
+				for _, ii := range i {
+					// The token values are only accepted as float64 here. The
+					// checked assertion skips any other element instead of
+					// panicking on malformed request data.
+					if f, ok := ii.(float64); ok {
+						tokens = append(tokens, int(f))
+					}
+				}
+				config.InputToken = append(config.InputToken, tokens)
+			}
+		}
+	}
+
+	// Can be either a string or an object
+	switch fnc := input.FunctionCall.(type) {
+	case string:
+		if fnc != "" {
+			config.SetFunctionCallString(fnc)
+		}
+	case map[string]interface{}:
+		// name stays empty when the key is missing or not a string. The checked
+		// assertion on a missing key yields ("", false) without panicking.
+		name, _ := fnc["name"].(string)
+		config.SetFunctionCallNameString(name)
+	}
+
+	// "prompt" may be a single string or a list; non-string list entries are skipped.
+	switch p := input.Prompt.(type) {
+	case string:
+		config.PromptStrings = append(config.PromptStrings, p)
+	case []interface{}:
+
for _, pp := range p {
+			if s, ok := pp.(string); ok {
+				config.PromptStrings = append(config.PromptStrings, s)
+			}
+		}
+	}
+}
+
+// readConfig resolves the configuration for modelFile and applies the request
+// overrides to it.
+//
+// Lookup order: a configuration already registered in cm; otherwise the file
+// "<loader.ModelPath>/<modelFile>.yaml", which is loaded into cm when it exists;
+// otherwise config.DefaultConfig(modelFile) seeded with ctx, threads, f16 and
+// debug. After updateConfig has applied the request, a thread count of 0 is
+// replaced by threads (or 4 when threads is 0 as well), and debug, when true,
+// forces cfg.Debug on.
+//
+// It returns the configuration and the unchanged input, or an error that wraps
+// the failure of loading the YAML file.
+//
+// NOTE(review): the handlers visible in this package pass a request-derived
+// name as modelFile and it is joined into the path without sanitising -
+// confirm that "../" segments are rejected elsewhere.
+func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) {
+	// Load a config file if present after the model name
+	modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")
+
+	cfgExisting, exists := cm.GetConfig(modelFile)
+	if !exists {
+		// Any Stat error is treated as "no config file": fall through to the defaults.
+		if _, err := os.Stat(modelConfig); err == nil {
+			if err := cm.LoadConfig(modelConfig); err != nil {
+				return nil, nil, fmt.Errorf("failed loading model config (%s) %w", modelConfig, err)
+			}
+			// The loaded file may not declare a config under this name.
+			cfgExisting, exists = cm.GetConfig(modelFile)
+		}
+	}
+
+	var cfg *config.Config
+	if exists {
+		// cfgExisting is a local copy, so assigning fields below does not touch
+		// the value registered in cm.
+		cfg = &cfgExisting
+	} else {
+		cfg = config.DefaultConfig(modelFile)
+		cfg.ContextSize = ctx
+		cfg.Threads = threads
+		cfg.F16 = f16
+		cfg.Debug = debug
+	}
+
+	// Set the parameters for the language model prediction
+	updateConfig(cfg, input)
+
+	// Don't allow 0 as setting
+	if cfg.Threads == 0 {
+		if threads != 0 {
+			cfg.Threads = threads
+		} else {
+			cfg.Threads = 4
+		}
+	}
+
+	// Enforce debug flag if passed from CLI
+	if debug {
+		cfg.Debug = true
+	}
+
+	return cfg, input, nil
+}
diff --git a/api/openai/transcription.go b/api/openai/transcription.go
new file mode 100644
index 0000000..346693c
--- /dev/null
+++ b/api/openai/transcription.go
@@ -0,0 +1,91 @@
+package openai
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"path"
+	"path/filepath"
+
+	config "github.com/go-skynet/LocalAI/api/config"
+	"github.com/go-skynet/LocalAI/api/options"
+	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	model "github.com/go-skynet/LocalAI/pkg/model"
+
+	"github.com/gofiber/fiber/v2"
+	"github.com/rs/zerolog/log"
+)
+
+//
https://platform.openai.com/docs/api-reference/audio/create
+//
+// TranscriptEndpoint returns the fiber handler for audio transcription. The
+// handler copies the uploaded "file" form field into a fresh temporary
+// directory, loads the whisper backend for the configured model, asks it to
+// transcribe the copied file and answers with the backend result as JSON. The
+// temporary directory is removed when the handler returns.
+//
+// The handler returns an error when the request cannot be read, the upload is
+// missing or cannot be stored, or the backend fails to load or transcribe.
+func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
+		m, input, err := readInput(c, o.Loader, false)
+		if err != nil {
+			return fmt.Errorf("failed reading parameters from request:%w", err)
+		}
+
+		cfg, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16)
+		if err != nil {
+			return fmt.Errorf("failed reading parameters from request:%w", err)
+		}
+		// retrieve the file data from the request
+		file, err := c.FormFile("file")
+		if err != nil {
+			return err
+		}
+		f, err := file.Open()
+		if err != nil {
+			return err
+		}
+		defer f.Close()
+
+		dir, err := os.MkdirTemp("", "whisper")
+		if err != nil {
+			return err
+		}
+		defer os.RemoveAll(dir)
+
+		// Only the base name of the client-supplied file name is used, so the
+		// copy always lands inside dir.
+		dst := filepath.Join(dir, path.Base(file.Filename))
+		dstFile, err := os.Create(dst)
+		if err != nil {
+			return err
+		}
+
+		// Close the copy before the backend reads it: this releases the file
+		// descriptor and surfaces write errors that are only reported on close.
+		_, err = io.Copy(dstFile, f)
+		if closeErr := dstFile.Close(); err == nil {
+			err = closeErr
+		}
+		if err != nil {
+			log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err)
+			return err
+		}
+
+		log.Debug().Msgf("Audio file copied to: %+v", dst)
+
+		whisperModel, err := o.Loader.BackendLoader(
+			model.WithBackendString(model.WhisperBackend),
+			model.WithModelFile(cfg.Model),
+			model.WithContext(o.Context),
+			model.WithThreads(uint32(cfg.Threads)),
+			model.WithAssetDir(o.AssetsDestination))
+		if err != nil {
+			return err
+		}
+
+		if whisperModel == nil {
+			return fmt.Errorf("could not load whisper model")
+		}
+
+		tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
+			Dst:      dst,
+			Language: input.Language,
+			Threads:  uint32(cfg.Threads),
+		})
+		if err != nil {
+			return err
+		}
+
+		log.Debug().Msgf("Trascribed: %+v", tr)
+		// TODO: handle different outputs here
+		return c.Status(http.StatusOK).JSON(tr)
+	}
+}
diff --git a/api/options.go b/api/options/options.go
similarity index 60%
rename from api/options.go
rename to
api/options/options.go index 923288a..06029b0 100644 --- a/api/options.go +++ b/api/options/options.go @@ -1,4 +1,4 @@ -package api +package options import ( "context" @@ -11,35 +11,35 @@ import ( ) type Option struct { - context context.Context - configFile string - loader *model.ModelLoader - uploadLimitMB, threads, ctxSize int - f16 bool - debug, disableMessage bool - imageDir string - audioDir string - cors bool - preloadJSONModels string - preloadModelsFromPath string - corsAllowOrigins string + Context context.Context + ConfigFile string + Loader *model.ModelLoader + UploadLimitMB, Threads, ContextSize int + F16 bool + Debug, DisableMessage bool + ImageDir string + AudioDir string + CORS bool + PreloadJSONModels string + PreloadModelsFromPath string + CORSAllowOrigins string - galleries []gallery.Gallery + Galleries []gallery.Gallery - backendAssets embed.FS - assetsDestination string + BackendAssets embed.FS + AssetsDestination string } type AppOption func(*Option) -func newOptions(o ...AppOption) *Option { +func NewOptions(o ...AppOption) *Option { opt := &Option{ - context: context.Background(), - uploadLimitMB: 15, - threads: 1, - ctxSize: 512, - debug: true, - disableMessage: true, + Context: context.Background(), + UploadLimitMB: 15, + Threads: 1, + ContextSize: 512, + Debug: true, + DisableMessage: true, } for _, oo := range o { oo(opt) @@ -49,25 +49,25 @@ func newOptions(o ...AppOption) *Option { func WithCors(b bool) AppOption { return func(o *Option) { - o.cors = b + o.CORS = b } } func WithCorsAllowOrigins(b string) AppOption { return func(o *Option) { - o.corsAllowOrigins = b + o.CORSAllowOrigins = b } } func WithBackendAssetsOutput(out string) AppOption { return func(o *Option) { - o.assetsDestination = out + o.AssetsDestination = out } } func WithBackendAssets(f embed.FS) AppOption { return func(o *Option) { - o.backendAssets = f + o.BackendAssets = f } } @@ -81,89 +81,89 @@ func WithStringGalleries(galls string) AppOption { if err := 
json.Unmarshal([]byte(galls), &galleries); err != nil { log.Error().Msgf("failed loading galleries: %s", err.Error()) } - o.galleries = append(o.galleries, galleries...) + o.Galleries = append(o.Galleries, galleries...) } } func WithGalleries(galleries []gallery.Gallery) AppOption { return func(o *Option) { - o.galleries = append(o.galleries, galleries...) + o.Galleries = append(o.Galleries, galleries...) } } func WithContext(ctx context.Context) AppOption { return func(o *Option) { - o.context = ctx + o.Context = ctx } } func WithYAMLConfigPreload(configFile string) AppOption { return func(o *Option) { - o.preloadModelsFromPath = configFile + o.PreloadModelsFromPath = configFile } } func WithJSONStringPreload(configFile string) AppOption { return func(o *Option) { - o.preloadJSONModels = configFile + o.PreloadJSONModels = configFile } } func WithConfigFile(configFile string) AppOption { return func(o *Option) { - o.configFile = configFile + o.ConfigFile = configFile } } func WithModelLoader(loader *model.ModelLoader) AppOption { return func(o *Option) { - o.loader = loader + o.Loader = loader } } func WithUploadLimitMB(limit int) AppOption { return func(o *Option) { - o.uploadLimitMB = limit + o.UploadLimitMB = limit } } func WithThreads(threads int) AppOption { return func(o *Option) { - o.threads = threads + o.Threads = threads } } func WithContextSize(ctxSize int) AppOption { return func(o *Option) { - o.ctxSize = ctxSize + o.ContextSize = ctxSize } } func WithF16(f16 bool) AppOption { return func(o *Option) { - o.f16 = f16 + o.F16 = f16 } } func WithDebug(debug bool) AppOption { return func(o *Option) { - o.debug = debug + o.Debug = debug } } func WithDisableMessage(disableMessage bool) AppOption { return func(o *Option) { - o.disableMessage = disableMessage + o.DisableMessage = disableMessage } } func WithAudioDir(audioDir string) AppOption { return func(o *Option) { - o.audioDir = audioDir + o.AudioDir = audioDir } } func WithImageDir(imageDir string) 
AppOption { return func(o *Option) { - o.imageDir = imageDir + o.ImageDir = imageDir } } diff --git a/api/prediction.go b/api/prediction.go deleted file mode 100644 index 7daa730..0000000 --- a/api/prediction.go +++ /dev/null @@ -1,649 +0,0 @@ -package api - -import ( - "fmt" - "os" - "path/filepath" - "regexp" - "strings" - "sync" - - "github.com/donomii/go-rwkv.cpp" - "github.com/go-skynet/LocalAI/pkg/langchain" - model "github.com/go-skynet/LocalAI/pkg/model" - "github.com/go-skynet/LocalAI/pkg/stablediffusion" - "github.com/go-skynet/bloomz.cpp" - bert "github.com/go-skynet/go-bert.cpp" - transformers "github.com/go-skynet/go-ggml-transformers.cpp" - llama "github.com/go-skynet/go-llama.cpp" - gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" -) - -// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 -var mutexMap sync.Mutex -var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) - -func defaultLLamaOpts(c Config) []llama.ModelOption { - llamaOpts := []llama.ModelOption{} - if c.ContextSize != 0 { - llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize)) - } - if c.F16 { - llamaOpts = append(llamaOpts, llama.EnableF16Memory) - } - if c.Embeddings { - llamaOpts = append(llamaOpts, llama.EnableEmbeddings) - } - - if c.NGPULayers != 0 { - llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers)) - } - - llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap)) - llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU)) - llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit)) - if c.Batch != 0 { - llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch)) - } else { - llamaOpts = append(llamaOpts, llama.SetNBatch(512)) - } - - if c.NUMA { - llamaOpts = append(llamaOpts, llama.EnableNUMA) - } - - if c.LowVRAM { - llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) - } - - return llamaOpts -} - -func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, 
loader *model.ModelLoader, c Config, o *Option) (func() error, error) { - if c.Backend != model.StableDiffusionBackend { - return nil, fmt.Errorf("endpoint only working with stablediffusion models") - } - inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads), o.assetsDestination) - if err != nil { - return nil, err - } - - var fn func() error - switch model := inferenceModel.(type) { - case *stablediffusion.StableDiffusion: - fn = func() error { - return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) - } - - default: - fn = func() error { - return fmt.Errorf("creation of images not supported by the backend") - } - } - - return func() error { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[c.Backend] - if !ok { - m := &sync.Mutex{} - mutexes[c.Backend] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - return fn() - }, nil -} - -func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) { - if !c.Embeddings { - return nil, fmt.Errorf("endpoint disabled for this model by API configuration") - } - - modelFile := c.Model - - llamaOpts := defaultLLamaOpts(c) - - var inferenceModel interface{} - var err error - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) - } else { - inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) - } - if err != nil { - return nil, err - } - - var fn func() ([]float32, error) - switch model := inferenceModel.(type) { - case *llama.LLama: - fn = func() ([]float32, error) { - predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) - if len(tokens) > 0 { - return model.TokenEmbeddings(tokens, predictOptions...) 
- } - return model.Embeddings(s, predictOptions...) - } - // bert embeddings - case *bert.Bert: - fn = func() ([]float32, error) { - if len(tokens) > 0 { - return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) - } - return model.Embeddings(s, bert.SetThreads(c.Threads)) - } - default: - fn = func() ([]float32, error) { - return nil, fmt.Errorf("embeddings not supported by the backend") - } - } - - return func() ([]float32, error) { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - embeds, err := fn() - if err != nil { - return embeds, err - } - // Remove trailing 0s - for i := len(embeds) - 1; i >= 0; i-- { - if embeds[i] == 0.0 { - embeds = embeds[:i] - } else { - break - } - } - return embeds, nil - }, nil -} - -func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption { - // Generate the prediction using the language model - predictOptions := []llama.PredictOption{ - llama.SetTemperature(c.Temperature), - llama.SetTopP(c.TopP), - llama.SetTopK(c.TopK), - llama.SetTokens(c.Maxtokens), - llama.SetThreads(c.Threads), - } - - if c.PromptCacheAll { - predictOptions = append(predictOptions, llama.EnablePromptCacheAll) - } - - if c.PromptCacheRO { - predictOptions = append(predictOptions, llama.EnablePromptCacheRO) - } - - predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar)) - - if c.PromptCachePath != "" { - // Create parent directory - p := filepath.Join(modelPath, c.PromptCachePath) - os.MkdirAll(filepath.Dir(p), 0755) - predictOptions = append(predictOptions, llama.SetPathPromptCache(p)) - } - - if c.Mirostat != 0 { - predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) - } - - if c.MirostatETA != 0 { - predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) - } - - if 
c.MirostatTAU != 0 { - predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) - } - - if c.Debug { - predictOptions = append(predictOptions, llama.Debug) - } - - predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) - - if c.RepeatPenalty != 0 { - predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) - } - - if c.Keep != 0 { - predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) - } - - if c.F16 { - predictOptions = append(predictOptions, llama.EnableF16KV) - } - - if c.IgnoreEOS { - predictOptions = append(predictOptions, llama.IgnoreEOS) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) - } - - //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) - - predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty)) - predictOptions = append(predictOptions, llama.SetMlock(c.MMlock)) - predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap)) - predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU)) - predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit)) - predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ)) - predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP)) - - return predictOptions -} - -func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) { - supportStreams := false - modelFile := c.Model - - llamaOpts := defaultLLamaOpts(c) - - var inferenceModel interface{} - var err error - if c.Backend == "" { - inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) - } else { - inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), 
o.assetsDestination) - } - if err != nil { - return nil, err - } - - var fn func() (string, error) - - switch model := inferenceModel.(type) { - case *rwkv.RwkvState: - supportStreams = true - - fn = func() (string, error) { - stopWord := "\n" - if len(c.StopWords) > 0 { - stopWord = c.StopWords[0] - } - - if err := model.ProcessInput(s); err != nil { - return "", err - } - - response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) - - return response, nil - } - case *transformers.GPTNeoX: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Replit: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Starcoder: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - 
transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.MPT: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *bloomz.Bloomz: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []bloomz.PredictOption{ - bloomz.SetTemperature(c.Temperature), - bloomz.SetTopP(c.TopP), - bloomz.SetTopK(c.TopK), - bloomz.SetTokens(c.Maxtokens), - bloomz.SetThreads(c.Threads), - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Falcon: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - 
return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.GPTJ: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.Dolly: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *transformers.GPT2: - fn = func() (string, error) { - // Generate the prediction using the language model - predictOptions := []transformers.PredictOption{ - transformers.SetTemperature(c.Temperature), - transformers.SetTopP(c.TopP), - transformers.SetTopK(c.TopK), - transformers.SetTokens(c.Maxtokens), - transformers.SetThreads(c.Threads), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) - } - - return model.Predict( - s, - predictOptions..., - ) - } - case *gpt4all.Model: - supportStreams = true - - fn = func() (string, error) { - if tokenCallback != nil { 
- model.SetTokenCallback(tokenCallback) - } - - // Generate the prediction using the language model - predictOptions := []gpt4all.PredictOption{ - gpt4all.SetTemperature(c.Temperature), - gpt4all.SetTopP(c.TopP), - gpt4all.SetTopK(c.TopK), - gpt4all.SetTokens(c.Maxtokens), - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch)) - } - - str, er := model.Predict( - s, - predictOptions..., - ) - // Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels) - // For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}} - // after a stream event has occurred - model.SetTokenCallback(nil) - return str, er - } - case *llama.LLama: - supportStreams = true - fn = func() (string, error) { - - if tokenCallback != nil { - model.SetTokenCallback(tokenCallback) - } - - predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) - - str, er := model.Predict( - s, - predictOptions..., - ) - // Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels) - // For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}} - // after a stream event has occurred - model.SetTokenCallback(nil) - return str, er - } - case *langchain.HuggingFace: - fn = func() (string, error) { - - // Generate the prediction using the language model - predictOptions := []langchain.PredictOption{ - langchain.SetModel(c.Model), - langchain.SetMaxTokens(c.Maxtokens), - langchain.SetTemperature(c.Temperature), - langchain.SetStopWords(c.StopWords), - } - - pred, er := model.PredictHuggingFace(s, predictOptions...) 
- if er != nil { - return "", er - } - return pred.Completion, nil - } - } - - return func() (string, error) { - // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 - mutexMap.Lock() - l, ok := mutexes[modelFile] - if !ok { - m := &sync.Mutex{} - mutexes[modelFile] = m - l = m - } - mutexMap.Unlock() - l.Lock() - defer l.Unlock() - - res, err := fn() - if tokenCallback != nil && !supportStreams { - tokenCallback(res) - } - return res, err - }, nil -} - -func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { - result := []Choice{} - - n := input.N - - if input.N == 0 { - n = 1 - } - - // get the model function to call for the result - predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback) - if err != nil { - return result, err - } - - for i := 0; i < n; i++ { - prediction, err := predFunc() - if err != nil { - return result, err - } - - prediction = Finetune(*config, predInput, prediction) - cb(prediction, &result) - - //result = append(result, Choice{Text: prediction}) - - } - return result, err -} - -var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) -var mu sync.Mutex = sync.Mutex{} - -func Finetune(config Config, input, prediction string) string { - if config.Echo { - prediction = input + prediction - } - - for _, c := range config.Cutstrings { - mu.Lock() - reg, ok := cutstrings[c] - if !ok { - cutstrings[c] = regexp.MustCompile(c) - reg = cutstrings[c] - } - mu.Unlock() - prediction = reg.ReplaceAllString(prediction, "") - } - - for _, c := range config.TrimSpace { - prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) - } - return prediction - -} diff --git a/cmd/grpc/bert-embeddings/main.go b/cmd/grpc/bert-embeddings/main.go new file mode 100644 index 0000000..008c30d --- /dev/null +++ b/cmd/grpc/bert-embeddings/main.go @@ -0,0 
+1,22 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+	bert "github.com/go-skynet/LocalAI/pkg/grpc/llm/bert"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the bert
+// embeddings backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &bert.Embeddings{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/bloomz/main.go b/cmd/grpc/bloomz/main.go
new file mode 100644
index 0000000..7348cab
--- /dev/null
+++ b/cmd/grpc/bloomz/main.go
@@ -0,0 +1,23 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	bloomz "github.com/go-skynet/LocalAI/pkg/grpc/llm/bloomz"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the bloomz
+// backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &bloomz.LLM{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/dolly/main.go b/cmd/grpc/dolly/main.go
new file mode 100644
index 0000000..43bba92
--- /dev/null
+++ b/cmd/grpc/dolly/main.go
@@ -0,0 +1,23 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the
+// transformers Dolly backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &transformers.Dolly{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/falcon-ggml/main.go b/cmd/grpc/falcon-ggml/main.go
new file mode 100644
index 0000000..677c660
--- /dev/null
+++ b/cmd/grpc/falcon-ggml/main.go
@@ -0,0 +1,23 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the
+// transformers Falcon backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &transformers.Falcon{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/falcon/main.go b/cmd/grpc/falcon/main.go
new file mode 100644
index 0000000..9ccead4
--- /dev/null
+++ b/cmd/grpc/falcon/main.go
@@ -0,0 +1,25 @@
+package main
+
+// GRPC Falcon server
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	falcon "github.com/go-skynet/LocalAI/pkg/grpc/llm/falcon"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the falcon
+// backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &falcon.LLM{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/gpt2/main.go b/cmd/grpc/gpt2/main.go
new file mode 100644
index 0000000..d9fe275
--- /dev/null
+++ b/cmd/grpc/gpt2/main.go
@@ -0,0 +1,23 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the
+// transformers GPT2 backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &transformers.GPT2{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/gpt4all/main.go b/cmd/grpc/gpt4all/main.go
new file mode 100644
index 0000000..a784d40
--- /dev/null
+++ b/cmd/grpc/gpt4all/main.go
@@ -0,0 +1,23 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	gpt4all "github.com/go-skynet/LocalAI/pkg/grpc/llm/gpt4all"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the gpt4all
+// backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &gpt4all.LLM{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/gptj/main.go b/cmd/grpc/gptj/main.go
new file mode 100644
index 0000000..27d8210
--- /dev/null
+++ b/cmd/grpc/gptj/main.go
@@ -0,0 +1,23 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the
+// transformers GPTJ backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &transformers.GPTJ{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/gptneox/main.go b/cmd/grpc/gptneox/main.go
new file mode 100644
index 0000000..3d005ca
--- /dev/null
+++ b/cmd/grpc/gptneox/main.go
@@ -0,0 +1,23 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	// addr is the -addr flag value that main hands to grpc.StartServer.
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+// main parses the command-line flags and starts the gRPC server for the
+// transformers GPTNeoX backend; it panics when grpc.StartServer returns an error.
+func main() {
+	flag.Parse()
+
+	err := grpc.StartServer(*addr, &transformers.GPTNeoX{})
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/cmd/grpc/langchain-huggingface/main.go b/cmd/grpc/langchain-huggingface/main.go
new file mode 100644
index 0000000..ab96584
--- /dev/null
+++ b/cmd/grpc/langchain-huggingface/main.go
@@ -0,0 +1,23 @@
+package main
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	langchain "github.com/go-skynet/LocalAI/pkg/grpc/llm/langchain"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+
addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &langchain.LLM{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/llama/main.go b/cmd/grpc/llama/main.go new file mode 100644 index 0000000..d75ef48 --- /dev/null +++ b/cmd/grpc/llama/main.go @@ -0,0 +1,25 @@ +package main + +// GRPC llama server + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/mpt/main.go b/cmd/grpc/mpt/main.go new file mode 100644 index 0000000..58456a7 --- /dev/null +++ b/cmd/grpc/mpt/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.MPT{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/piper/main.go b/cmd/grpc/piper/main.go new file mode 100644 index 0000000..7de80e2 --- /dev/null +++ b/cmd/grpc/piper/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + tts "github.com/go-skynet/LocalAI/pkg/grpc/tts" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err :=
grpc.StartServer(*addr, &tts.Piper{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/replit/main.go b/cmd/grpc/replit/main.go new file mode 100644 index 0000000..aed67fb --- /dev/null +++ b/cmd/grpc/replit/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Replit{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/rwkv/main.go b/cmd/grpc/rwkv/main.go new file mode 100644 index 0000000..f050a7c --- /dev/null +++ b/cmd/grpc/rwkv/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + rwkv "github.com/go-skynet/LocalAI/pkg/grpc/llm/rwkv" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &rwkv.LLM{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/stablediffusion/main.go b/cmd/grpc/stablediffusion/main.go new file mode 100644 index 0000000..76b4a5a --- /dev/null +++ b/cmd/grpc/stablediffusion/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + image "github.com/go-skynet/LocalAI/pkg/grpc/image" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &image.StableDiffusion{}); err != nil { + panic(err) + } +} diff --git 
a/cmd/grpc/starcoder/main.go b/cmd/grpc/starcoder/main.go new file mode 100644 index 0000000..2847acf --- /dev/null +++ b/cmd/grpc/starcoder/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transformers.Starcoder{}); err != nil { + panic(err) + } +} diff --git a/cmd/grpc/whisper/main.go b/cmd/grpc/whisper/main.go new file mode 100644 index 0000000..8d4a5fe --- /dev/null +++ b/cmd/grpc/whisper/main.go @@ -0,0 +1,23 @@ +package main + +// Note: this is started internally by LocalAI and a server is allocated for each model + +import ( + "flag" + + transcribe "github.com/go-skynet/LocalAI/pkg/grpc/transcribe" + + grpc "github.com/go-skynet/LocalAI/pkg/grpc" +) + +var ( + addr = flag.String("addr", "localhost:50051", "the address to connect to") +) + +func main() { + flag.Parse() + + if err := grpc.StartServer(*addr, &transcribe.Whisper{}); err != nil { + panic(err) + } +} diff --git a/go.mod b/go.mod index 0f65978..1d6268c 100644 --- a/go.mod +++ b/go.mod @@ -13,20 +13,25 @@ require ( github.com/gofiber/fiber/v2 v2.47.0 github.com/google/uuid v1.3.0 github.com/hashicorp/go-multierror v1.1.1 + github.com/hpcloud/tail v1.0.0 github.com/imdario/mergo v0.3.16 github.com/json-iterator/go v1.1.12 github.com/mholt/archiver/v3 v3.5.1 + github.com/mudler/go-ggllm.cpp v0.0.0-20230708215552-a6504d5bc137 + github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230708212935-d611d107479f github.com/onsi/ginkgo/v2 v2.11.0 github.com/onsi/gomega v1.27.8 
github.com/otiai10/openaigo v1.5.2 + github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 github.com/rs/zerolog v1.29.1 github.com/sashabaranov/go-openai v1.13.0 - github.com/swaggo/swag v1.16.1 github.com/tmc/langchaingo v0.0.0-20230709010448-a875e6bc0c54 github.com/urfave/cli/v2 v2.25.7 github.com/valyala/fasthttp v1.48.0 + google.golang.org/grpc v1.56.2 + google.golang.org/protobuf v1.30.0 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -34,8 +39,10 @@ require ( require ( github.com/dlclark/regexp2 v1.8.1 // indirect github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect + github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.2 // indirect github.com/klauspost/pgzip v1.2.5 // indirect + github.com/kr/text v0.2.0 // indirect github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/nwaples/rardecode v1.1.0 // indirect @@ -43,33 +50,27 @@ require ( github.com/pkoukk/tiktoken-go v0.1.2 // indirect github.com/ulikunitz/xz v0.5.9 // indirect github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect + google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/fsnotify.v1 v1.4.7 // indirect + gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect ) require ( - github.com/KyleBanks/depth v1.2.1 // indirect - github.com/PuerkitoBio/purell v1.1.1 // indirect - github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 // indirect github.com/andybalholm/brotli v1.0.5 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/go-audio/audio v1.0.0 // indirect github.com/go-audio/riff v1.0.0 // indirect github.com/go-logr/logr v1.2.4 // indirect - github.com/go-openapi/jsonpointer v0.19.5 // indirect - github.com/go-openapi/jsonreference v0.19.6 // indirect - github.com/go-openapi/spec v0.20.4 // indirect - 
github.com/go-openapi/swag v0.22.3 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/hashicorp/errwrap v1.0.0 // indirect - github.com/josharian/intern v1.0.0 // indirect github.com/klauspost/compress v1.16.3 // indirect - github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-runewidth v0.0.14 // indirect github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 - github.com/otiai10/mint v1.6.1 // indirect github.com/philhofer/fwd v1.1.2 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 81f81e7..2906f50 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,3 @@ -github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= -github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= -github.com/PuerkitoBio/purell v1.1.1 h1:WEQqlqaGbrPkxLJWfBwQmfEAE1Z7ONdDLqrN38tNFfI= -github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= -github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578 h1:d+Bc7a5rLufV/sSk/8dngufqelfh6jnri85riMAaF/M= -github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= @@ -19,13 +13,12 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0= github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674 h1:G70Yf/QOCEL1v24idWnGd6rJsbqiGkJAJnMaWaolzEg= -github.com/donomii/go-rwkv.cpp v0.0.0-20230619005719-f5a8c4539674/go.mod h1:gWy7FIWioqYmYxkaoFyBnaKApeZVrUkHhv9EV9pz4dM= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY= github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s= github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27 h1:boeMTUUBtnLU8JElZJHXrsUzROJar9/t6vGOFjkrhhI= -github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230606002726-57543c169e27/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= +github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e h1:KtbU2JR3lJuXFASHG2+sVLucfMPBjWKUUKByX6C81mQ= github.com/ggerganov/whisper.cpp/bindings/go v0.0.0-20230628193450-85ed71aaec8e/go.mod h1:QIjZ9OktHFG7p+/m3sMvrAJKKdWrr1fZIK0rM6HZlyo= github.com/go-audio/audio v1.0.0 h1:zS9vebldgbQqktK4H0lUqWrG8P0NxCJVqcj7ZpNnwd4= @@ -36,47 +29,28 @@ github.com/go-audio/wav v1.1.0 h1:jQgLtbqBzY7G+BM8fXF7AHUk1uHUviWS4X39d5rsL2g= github.com/go-audio/wav v1.1.0/go.mod h1:mpe9qfwbScEbkd8uybLuIpTgHyrISw/OTuvjUW2iGtE= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= 
-github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= -github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= -github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= -github.com/go-openapi/jsonreference v0.19.6 h1:UBIxjkht+AWIgYzCDSv2GN+E/togfwXUJFRTWhl2Jjs= -github.com/go-openapi/jsonreference v0.19.6/go.mod h1:diGHMEHg2IqXZGKxqyvWdfWU/aim5Dprw5bqpKkTvns= -github.com/go-openapi/spec v0.20.4 h1:O8hJrt0UMnhHcluhIdUgCLRWyM2x7QkBXRvOs7m+O1M= -github.com/go-openapi/spec v0.20.4/go.mod h1:faYFR1CvsJZ0mNsmsphTMSoRrNV3TEDoAM7FOEWeq8I= -github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk= -github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyrCM= -github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= -github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= -github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa h1:gxr68r/6EWroay4iI81jxqGCDbKotY4+CiwdUkBz2NQ= -github.com/go-skynet/bloomz.cpp v0.0.0-20230529155654-1834e77b83fa/go.mod h1:wc0fJ9V04yiYTfgKvE5RUUSRQ5Kzi0Bo4I+U3nNOUuA= -github.com/go-skynet/go-bert.cpp v0.0.0-20230607105116-6069103f54b9 h1:wRGbDwNwPmSzoXVw/HLzXY4blpRvPWg7QW2OA0WKezA= -github.com/go-skynet/go-bert.cpp v0.0.0-20230607105116-6069103f54b9/go.mod h1:pXKCpYYXujMeAvgJHU6WoMfvYbr84563+J8+Ebkyr5U= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230617123349-32b9223ccdb1 h1:jVGgzDSfpjD/0jl/ChpGI+O4EHSAeeU6DK7IyhH8PK8= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230617123349-32b9223ccdb1/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230620192816-a459d2726792 h1:rozZ9gWGzq0ZhBsNCWqfLTRCebaxwTsxLMnflwe6rDU= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230620192816-a459d2726792/go.mod 
h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230626202628-8e31841dcddc h1:SrNxH4U8W6cqurbxpXxm9rzifeDsCgecRT73kT0BRq0= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230626202628-8e31841dcddc/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230630204211-3fec197a1dc4 h1:LScGc8yWTS9wbS2RTOq6s+waeHElLIQDJg2SUCwrO3E= -github.com/go-skynet/go-ggml-transformers.cpp v0.0.0-20230630204211-3fec197a1dc4/go.mod h1:31j1odgFXP8hDSUVfH0zErKI5aYVP18ddYnPkwCso2A= -github.com/go-skynet/go-llama.cpp v0.0.0-20230616223721-7ad833b67070 h1:T771FjB1yQw8j4P5x4ayFrUPNTglzxRIqDjaNkMVIME= -github.com/go-skynet/go-llama.cpp v0.0.0-20230616223721-7ad833b67070/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230626215901-f104111358e8 h1:Knh5QUvI/68erb/yWtrVa/3hvoQdENF2dH0hL2HNPrI= -github.com/go-skynet/go-llama.cpp v0.0.0-20230626215901-f104111358e8/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230627195533-582753605210 h1:9bm+vsiR3UI7xlU0G0cMU2Swq78RysoFVkSONvrujF8= -github.com/go-skynet/go-llama.cpp v0.0.0-20230627195533-582753605210/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230628194133-42ba44838369 h1:lSX1NWzRvRS2MlACvyvVVUnqXhKiuMAoN3DO5TbCe8M= -github.com/go-skynet/go-llama.cpp v0.0.0-20230628194133-42ba44838369/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= -github.com/go-skynet/go-llama.cpp v0.0.0-20230703203849-ffa57fbc3a12 h1:cfGZiZana0gPD0i8nmyOGTUQGb4N8PYqaBqhhukREPc= -github.com/go-skynet/go-llama.cpp v0.0.0-20230703203849-ffa57fbc3a12/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40= +github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 
h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofiber/fiber/v2 v2.47.0 h1:EN5lHVCc+Pyqh5OEsk8fzRiifgwpbrP0rulQ4iNf3fs= github.com/gofiber/fiber/v2 v2.47.0/go.mod h1:mbFMVN1lQuzziTkkakgtKKdjfsXSw9BKR5lmcNksUoU= +github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= +github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= +github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= +github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= +github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= +github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.2 h1:aeE13tS0IiQgFjYdoL8qN3K1N2bXXtI6Vi51/y7BpMw= github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod 
h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -89,11 +63,11 @@ github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/U github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= +github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= -github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= @@ -103,16 +77,12 @@ github.com/klauspost/compress v1.16.3/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQs github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/pgzip v1.2.5 h1:qnWYvvKqedOF2ulHpMG72XQol4ILEJ8k2wwRl/Km8oE= github.com/klauspost/pgzip v1.2.5/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1 
h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= -github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= -github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= @@ -128,33 +98,29 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760 h1:OFVkSxR7CRSRSNm5dvpMRZwmSwWa8EMMnHbc84fW5tU= -github.com/mudler/go-piper v0.0.0-20230621222733-56b8a81b4760/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= -github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af h1:XFq6OUqsWQam0OrEr05okXsJK/TQur3zoZTHbiZD3Ks= 
-github.com/mudler/go-stable-diffusion v0.0.0-20230605122230-d89260f598af/go.mod h1:8ufRkpz/S/9ahkaxzZ5i4WMgO9w4InEhuRoT7vK5Rnw= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= -github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230620230702-09ae04cee90c h1:axNtjd5k6Xs4Ck7B7VRRQu6q5lQzTsjdWmaJkDADopU= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230620230702-09ae04cee90c/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230628182915-a67f8132e165 h1:zcnIdoSeLueTDxUD2A1qnyaSp8uh0Ay7OgHeBwpxSeg= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230628182915-a67f8132e165/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230708212935-d611d107479f h1:FtXRIjsBvoBQ5xmA26QbzyG4RjV2U5lOpUgP4npITOM= -github.com/nomic-ai/gpt4all/gpt4all-bindings/golang v0.0.0-20230708212935-d611d107479f/go.mod h1:4T3CHXyrt+7FQHXaxULZfPjHbD8/99WuDDJa0YVZARI= +github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d h1:/lAg9vPAAU+s35cDMCx1IyeMn+4OYfCBPqi08Q8vXDg= +github.com/mudler/go-processmanager v0.0.0-20220724164624-c45b5c61312d/go.mod h1:HGGAOJhipApckwNV8ZTliRJqxctUv3xRY+zbQEwuytc= github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ= github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0= +github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= +github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod 
h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU= github.com/onsi/ginkgo/v2 v2.11.0/go.mod h1:ZhrRA5XmEE3x3rhlzamx/JJvujdZoJ2uvgI7kR0iZvM= +github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= +github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= +github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc= github.com/onsi/gomega v1.27.8/go.mod h1:2J8vzI/s+2shY9XHRApDkdgPo1TKT7P2u6fXeJKFnNQ= -github.com/otiai10/mint v1.5.1 h1:XaPLeE+9vGbuyEHem1JNk3bYc7KKqyI/na0/mLd/Kks= -github.com/otiai10/mint v1.5.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= -github.com/otiai10/mint v1.6.1/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= -github.com/otiai10/openaigo v1.2.0 h1:Whq+uvgqw8NdIsVdixtBKCAI6OdfCJiGPlhUnYJQ6Ag= -github.com/otiai10/openaigo v1.2.0/go.mod h1:792bx6AWTS61weDi2EzKpHHnTF4eDMAlJ5GvAk/mgPg= -github.com/otiai10/openaigo v1.4.0 h1:BeacKb2Q5bVejjOKHFJxL2WFYal3QxwkrKtKuoU5LNU= -github.com/otiai10/openaigo v1.4.0/go.mod h1:kIaXc3V+Xy5JLplcBxehVyGYDtufHp3PFPy04jOwOAI= +github.com/otiai10/mint v1.6.1 h1:kgbTJmOpp/0ce7hk3H8jiSuR0MXmpwWRfqUdKww17qg= github.com/otiai10/openaigo v1.5.2 h1:YnNDisZmA4syArF3IxMCIrfgZOq30PLV219gPY7n2z8= github.com/otiai10/openaigo v1.5.2/go.mod h1:kIaXc3V+Xy5JLplcBxehVyGYDtufHp3PFPy04jOwOAI= +github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= +github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= 
github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw= github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0= @@ -172,8 +138,6 @@ github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc= -github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sashabaranov/go-openai v1.13.0 h1:EAusFfnhaMaaUspUZ2+MbB/ZcVeD4epJmTOlZ+8AcAE= github.com/sashabaranov/go-openai v1.13.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4= @@ -181,26 +145,14 @@ github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94/go.mod h1:90zrgN3 github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee h1:8Iv5m6xEo1NR1AvpV+7XmhI4r39LGNzwUL4YpMuL5vk= github.com/savsgio/gotils v0.0.0-20230208104028-c358bd845dee/go.mod h1:qwtSXrKuJh/zsFQ12yEE89xfCrGKK63Rr7ctU/uCo4g= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.2 
h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/swaggo/swag v1.16.1 h1:fTNRhKstPKxcnoKsytm4sahr8FaYzUcT7i1/3nd/fBg= -github.com/swaggo/swag v1.16.1/go.mod h1:9/LMvHycG3NFHfR6LwvikHv5iFvmPADQ359cKikGxto= github.com/tinylib/msgp v1.1.6/go.mod h1:75BAfg2hauQhs3qedfdDZmWAPcFMAvJE5b9rGOMufyw= github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= -github.com/tmc/langchaingo v0.0.0-20230616220619-1b3da4433944 h1:EE9fvNENTdRc/yI/1zAs7VFbmDk6JZ7EbBIFl+TsCm0= -github.com/tmc/langchaingo v0.0.0-20230616220619-1b3da4433944/go.mod h1:6l1WoyqVDwkv7cFlY3gfcTv8yVowVyuutKv8PGlQCWI= -github.com/tmc/langchaingo v0.0.0-20230625081011-4d9d55dbcaba h1:NpAI9C0y9T4jwP7XFShwYJKGf/ggyCgZEtL/7lLRPwE= -github.com/tmc/langchaingo v0.0.0-20230625081011-4d9d55dbcaba/go.mod h1:tz9cjA9BW8/lWx/T5njr3ZLHK/dfPyr/0ICSMThmY2g= -github.com/tmc/langchaingo v0.0.0-20230625234550-7ea734523e39 h1:SpOEFXx5xXLypFnwNRQj7yOC3rMvSylGA5BQW/FAwYc= -github.com/tmc/langchaingo v0.0.0-20230625234550-7ea734523e39/go.mod h1:tz9cjA9BW8/lWx/T5njr3ZLHK/dfPyr/0ICSMThmY2g= -github.com/tmc/langchaingo v0.0.0-20230627220614-633853b5ac3b h1:xUxtya/3KRDn1rcCVZucp2KhjdqSZat9j0hOshSVh2Q= -github.com/tmc/langchaingo v0.0.0-20230627220614-633853b5ac3b/go.mod h1:F1k7uRBLM8jMMEPV3dVtWVNc+W91nxOBRKbJWM/LwpM= -github.com/tmc/langchaingo v0.0.0-20230628165432-e510561c17f9 h1:BooyHg3f058lrPcTLdfC7HTfjO5OGZAgwciQJ5e85l0= -github.com/tmc/langchaingo v0.0.0-20230628165432-e510561c17f9/go.mod h1:F1k7uRBLM8jMMEPV3dVtWVNc+W91nxOBRKbJWM/LwpM= github.com/tmc/langchaingo v0.0.0-20230709010448-a875e6bc0c54 h1:MZSC3/pdBzkoPG49uTRvtEepOQKdbdgaT1aLtaEwxx4= github.com/tmc/langchaingo v0.0.0-20230709010448-a875e6bc0c54/go.mod h1:RsMJqgUynOtr2jWNhUF41R3j6SDkKq9c8UfE0nJYBb4= github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= @@ -228,25 +180,34 @@ golang.org/x/mod v0.3.0/go.mod 
h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210420072515-93ed5bcd2bfe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -270,6 +231,7 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools 
v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201022035929-9cf592e881e9/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.4.0/go.mod h1:UE5sM2OK9E/d67R0ANs2xJizIymRP5gJU295PvKXxjQ= golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM= @@ -278,15 +240,33 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1 h1:KpwkzHKEF7B9Zxg18WzOa7djJ+Ha5DzthMyZYQfEn2A= +google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU= +google.golang.org/grpc v1.56.2 h1:fVRFRnXvU+x6C4IlHZewvJOVHoOv1TUuQyoRsYnB4bI= +google.golang.org/grpc v1.56.2/go.mod h1:I9bI3vqKfayGqPUAwGdOSu7kt6oIJLixfffKrpXqQ9s= +google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= +google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= +google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= +google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= +google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= +google.golang.org/protobuf v1.23.0/go.mod 
h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU= -gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= +gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= +gopkg.in/op/go-logging.v1 v1.0.0-20160211212156-b2cb9fa56473/go.mod h1:N1eN2tsCx0Ydtgjl4cqmbRCsY4/+z4cYDeqwZTk6zog= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index fc1dea0..3f534b0 100644 --- a/main.go +++ b/main.go @@ -2,9 +2,12 @@ package main import ( "os" + "os/signal" "path/filepath" + "syscall" api "github.com/go-skynet/LocalAI/api" + "github.com/go-skynet/LocalAI/api/options" "github.com/go-skynet/LocalAI/internal" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/rs/zerolog" @@ -14,6 +17,13 @@ import ( func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + // clean up process + go func() { + c := make(chan os.Signal, 1) // we need to reserve to buffer size 1, so the notifier are not blocked + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + os.Exit(1) + }() path, err := os.Getwd() if err != nil { @@ -129,23 +139,23 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit Copyright: "Ettore Di Giacinto", Action: func(ctx *cli.Context) error { app, err := api.App( - api.WithConfigFile(ctx.String("config-file")), - api.WithJSONStringPreload(ctx.String("preload-models")), - api.WithYAMLConfigPreload(ctx.String("preload-models-config")), - api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), - api.WithContextSize(ctx.Int("context-size")), - api.WithDebug(ctx.Bool("debug")), - api.WithImageDir(ctx.String("image-path")), - api.WithAudioDir(ctx.String("audio-path")), - api.WithF16(ctx.Bool("f16")), - api.WithStringGalleries(ctx.String("galleries")), - api.WithDisableMessage(false), - api.WithCors(ctx.Bool("cors")), - api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), - api.WithThreads(ctx.Int("threads")), - api.WithBackendAssets(backendAssets), - 
api.WithBackendAssetsOutput(ctx.String("backend-assets-path")), - api.WithUploadLimitMB(ctx.Int("upload-limit"))) + options.WithConfigFile(ctx.String("config-file")), + options.WithJSONStringPreload(ctx.String("preload-models")), + options.WithYAMLConfigPreload(ctx.String("preload-models-config")), + options.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), + options.WithContextSize(ctx.Int("context-size")), + options.WithDebug(ctx.Bool("debug")), + options.WithImageDir(ctx.String("image-path")), + options.WithAudioDir(ctx.String("audio-path")), + options.WithF16(ctx.Bool("f16")), + options.WithStringGalleries(ctx.String("galleries")), + options.WithDisableMessage(false), + options.WithCors(ctx.Bool("cors")), + options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), + options.WithThreads(ctx.Int("threads")), + options.WithBackendAssets(backendAssets), + options.WithBackendAssetsOutput(ctx.String("backend-assets-path")), + options.WithUploadLimitMB(ctx.Int("upload-limit"))) if err != nil { return err } diff --git a/pkg/grpc/base/base.go b/pkg/grpc/base/base.go new file mode 100644 index 0000000..a6d89f2 --- /dev/null +++ b/pkg/grpc/base/base.go @@ -0,0 +1,42 @@ +package base + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" +) + +type Base struct { +} + +func (llm *Base) Load(opts *pb.ModelOptions) error { + return fmt.Errorf("unimplemented") + +} + +func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) { + return "", fmt.Errorf("unimplemented") +} + +func (llm *Base) PredictStream(opts *pb.PredictOptions, results chan string) error { + return fmt.Errorf("unimplemented") +} + +func (llm *Base) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + return []float32{}, 
fmt.Errorf("unimplemented") +} + +func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { + return fmt.Errorf("unimplemented") +} + +func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) { + return api.Result{}, fmt.Errorf("unimplemented") +} + +func (llm *Base) TTS(*pb.TTSRequest) error { + return fmt.Errorf("unimplemented") +} diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go new file mode 100644 index 0000000..bbc40bf --- /dev/null +++ b/pkg/grpc/client.go @@ -0,0 +1,160 @@ +package grpc + +import ( + "context" + "fmt" + "io" + "time" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +type Client struct { + address string +} + +func NewClient(address string) *Client { + return &Client{ + address: address, + } +} + +func (c *Client) HealthCheck(ctx context.Context) bool { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + fmt.Println(err) + return false + } + defer conn.Close() + client := pb.NewBackendClient(conn) + + // The healthcheck call shouldn't take long time + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + res, err := client.Health(ctx, &pb.HealthMessage{}) + if err != nil { + fmt.Println(err) + + return false + } + + if res.Message == "OK" { + return true + } + return false +} + +func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + + return client.Embedding(ctx, in, opts...) 
+} + +func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + + return client.Predict(ctx, in, opts...) +} + +func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.LoadModel(ctx, in, opts...) +} + +func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s string), opts ...grpc.CallOption) error { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + + stream, err := client.PredictStream(ctx, in, opts...) + if err != nil { + return err + } + + for { + feature, err := stream.Recv() + if err == io.EOF { + break + } + if err != nil { + fmt.Println("Error", err) + + return err + } + f(feature.GetMessage()) + } + + return nil +} + +func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.GenerateImage(ctx, in, opts...) 
+} + +func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + return client.TTS(ctx, in, opts...) +} + +func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) { + conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, err + } + defer conn.Close() + client := pb.NewBackendClient(conn) + res, err := client.AudioTranscription(ctx, in, opts...) + if err != nil { + return nil, err + } + tresult := &api.Result{} + for _, s := range res.Segments { + tks := []int{} + for _, t := range s.Tokens { + tks = append(tks, int(t)) + } + tresult.Segments = append(tresult.Segments, + api.Segment{ + Text: s.Text, + Id: int(s.Id), + Start: time.Duration(s.Start), + End: time.Duration(s.End), + Tokens: tks, + }) + } + tresult.Text = res.Text + return tresult, err +} diff --git a/pkg/grpc/image/stablediffusion.go b/pkg/grpc/image/stablediffusion.go new file mode 100644 index 0000000..ce0275e --- /dev/null +++ b/pkg/grpc/image/stablediffusion.go @@ -0,0 +1,33 @@ +package image + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/stablediffusion" +) + +type StableDiffusion struct { + base.Base + stablediffusion *stablediffusion.StableDiffusion +} + +func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error { + var err error + // Note: the Model here is a path to a directory containing the model files + sd.stablediffusion, err = 
stablediffusion.New(opts.Model) + return err +} + +func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error { + return sd.stablediffusion.GenerateImage( + int(opts.Height), + int(opts.Width), + int(opts.Mode), + int(opts.Step), + int(opts.Seed), + opts.PositivePrompt, + opts.NegativePrompt, + opts.Dst) +} diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go new file mode 100644 index 0000000..6832a95 --- /dev/null +++ b/pkg/grpc/interface.go @@ -0,0 +1,16 @@ +package grpc + +import ( + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" +) + +type LLM interface { + Predict(*pb.PredictOptions) (string, error) + PredictStream(*pb.PredictOptions, chan string) error + Load(*pb.ModelOptions) error + Embeddings(*pb.PredictOptions) ([]float32, error) + GenerateImage(*pb.GenerateImageRequest) error + AudioTranscription(*pb.TranscriptRequest) (api.Result, error) + TTS(*pb.TTSRequest) error +} diff --git a/pkg/grpc/llm/bert/bert.go b/pkg/grpc/llm/bert/bert.go new file mode 100644 index 0000000..7692797 --- /dev/null +++ b/pkg/grpc/llm/bert/bert.go @@ -0,0 +1,33 @@ +package bert + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + bert "github.com/go-skynet/go-bert.cpp" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" +) + +type Embeddings struct { + base.Base + bert *bert.Bert +} + +func (llm *Embeddings) Load(opts *pb.ModelOptions) error { + model, err := bert.New(opts.Model) + llm.bert = model + return err +} + +func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + if len(opts.EmbeddingTokens) > 0 { + tokens := []int{} + for _, t := range opts.EmbeddingTokens { + tokens = append(tokens, int(t)) + } + return llm.bert.TokenEmbeddings(tokens, bert.SetThreads(int(opts.Threads))) + } + 
+ return llm.bert.Embeddings(opts.Embeddings, bert.SetThreads(int(opts.Threads))) +} diff --git a/pkg/grpc/llm/bloomz/bloomz.go b/pkg/grpc/llm/bloomz/bloomz.go new file mode 100644 index 0000000..daa2264 --- /dev/null +++ b/pkg/grpc/llm/bloomz/bloomz.go @@ -0,0 +1,59 @@ +package bloomz + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + "github.com/go-skynet/bloomz.cpp" +) + +type LLM struct { + base.Base + + bloomz *bloomz.Bloomz +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + model, err := bloomz.New(opts.Model) + llm.bloomz = model + return err +} + +func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption { + predictOptions := []bloomz.PredictOption{ + bloomz.SetTemperature(float64(opts.Temperature)), + bloomz.SetTopP(float64(opts.TopP)), + bloomz.SetTopK(int(opts.TopK)), + bloomz.SetTokens(int(opts.Tokens)), + bloomz.SetThreads(int(opts.Threads)), + } + + if opts.Seed != 0 { + predictOptions = append(predictOptions, bloomz.SetSeed(int(opts.Seed))) + } + + return predictOptions +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +// fallback to Predict +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + go func() { + res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) 
+ + if err != nil { + fmt.Println("err: ", err) + } + results <- res + close(results) + }() + + return nil +} diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/grpc/llm/falcon/falcon.go new file mode 100644 index 0000000..3c0f84e --- /dev/null +++ b/pkg/grpc/llm/falcon/falcon.go @@ -0,0 +1,144 @@ +package falcon + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + + ggllm "github.com/mudler/go-ggllm.cpp" +) + +type LLM struct { + base.Base + + falcon *ggllm.Falcon +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + ggllmOpts := []ggllm.ModelOption{} + if opts.ContextSize != 0 { + ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize))) + } + // F16 doesn't seem to produce good output at all! + //if c.F16 { + // llamaOpts = append(llamaOpts, llama.EnableF16Memory) + //} + + if opts.NGPULayers != 0 { + ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(int(opts.NGPULayers))) + } + + ggllmOpts = append(ggllmOpts, ggllm.SetMMap(opts.MMap)) + ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(opts.MainGPU)) + ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(opts.TensorSplit)) + if opts.NBatch != 0 { + ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(int(opts.NBatch))) + } else { + ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512)) + } + + model, err := ggllm.New(opts.Model, ggllmOpts...) 
+ llm.falcon = model + return err +} + +func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { + predictOptions := []ggllm.PredictOption{ + ggllm.SetTemperature(float64(opts.Temperature)), + ggllm.SetTopP(float64(opts.TopP)), + ggllm.SetTopK(int(opts.TopK)), + ggllm.SetTokens(int(opts.Tokens)), + ggllm.SetThreads(int(opts.Threads)), + } + + if opts.PromptCacheAll { + predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll) + } + + if opts.PromptCacheRO { + predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO) + } + + // Expected absolute path + if opts.PromptCachePath != "" { + predictOptions = append(predictOptions, ggllm.SetPathPromptCache(opts.PromptCachePath)) + } + + if opts.Mirostat != 0 { + predictOptions = append(predictOptions, ggllm.SetMirostat(int(opts.Mirostat))) + } + + if opts.MirostatETA != 0 { + predictOptions = append(predictOptions, ggllm.SetMirostatETA(float64(opts.MirostatETA))) + } + + if opts.MirostatTAU != 0 { + predictOptions = append(predictOptions, ggllm.SetMirostatTAU(float64(opts.MirostatTAU))) + } + + if opts.Debug { + predictOptions = append(predictOptions, ggllm.Debug) + } + + predictOptions = append(predictOptions, ggllm.SetStopWords(opts.StopPrompts...)) + + if opts.PresencePenalty != 0 { + predictOptions = append(predictOptions, ggllm.SetPenalty(float64(opts.PresencePenalty))) + } + + if opts.NKeep != 0 { + predictOptions = append(predictOptions, ggllm.SetNKeep(int(opts.NKeep))) + } + + if opts.Batch != 0 { + predictOptions = append(predictOptions, ggllm.SetBatch(int(opts.Batch))) + } + + if opts.IgnoreEOS { + predictOptions = append(predictOptions, ggllm.IgnoreEOS) + } + + if opts.Seed != 0 { + predictOptions = append(predictOptions, ggllm.SetSeed(int(opts.Seed))) + } + + //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) + + predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(float64(opts.FrequencyPenalty))) + predictOptions = append(predictOptions, 
ggllm.SetMlock(opts.MLock)) + predictOptions = append(predictOptions, ggllm.SetMemoryMap(opts.MMap)) + predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(opts.MainGPU)) + predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(opts.TensorSplit)) + predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ))) + predictOptions = append(predictOptions, ggllm.SetTypicalP(float64(opts.TypicalP))) + return predictOptions +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + predictOptions := buildPredictOptions(opts) + + predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { + if token == "<|endoftext|>" { + return true + } + results <- token + return true + })) + + go func() { + _, err := llm.falcon.Predict(opts.Prompt, predictOptions...) 
+ if err != nil { + fmt.Println("err: ", err) + } + close(results) + }() + + return nil +} diff --git a/pkg/grpc/llm/gpt4all/gpt4all.go b/pkg/grpc/llm/gpt4all/gpt4all.go new file mode 100644 index 0000000..e17afc1 --- /dev/null +++ b/pkg/grpc/llm/gpt4all/gpt4all.go @@ -0,0 +1,62 @@ +package gpt4all + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" +) + +type LLM struct { + base.Base + + gpt4all *gpt4all.Model +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + model, err := gpt4all.New(opts.Model, + gpt4all.SetThreads(int(opts.Threads)), + gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath)) + llm.gpt4all = model + return err +} + +func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption { + predictOptions := []gpt4all.PredictOption{ + gpt4all.SetTemperature(float64(opts.Temperature)), + gpt4all.SetTopP(float64(opts.TopP)), + gpt4all.SetTopK(int(opts.TopK)), + gpt4all.SetTokens(int(opts.Tokens)), + } + + if opts.Batch != 0 { + predictOptions = append(predictOptions, gpt4all.SetBatch(int(opts.Batch))) + } + return predictOptions +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + predictOptions := buildPredictOptions(opts) + + go func() { + llm.gpt4all.SetTokenCallback(func(token string) bool { + results <- token + return true + }) + _, err := llm.gpt4all.Predict(opts.Prompt, predictOptions...) 
+ if err != nil { + fmt.Println("err: ", err) + } + llm.gpt4all.SetTokenCallback(nil) + close(results) + }() + + return nil +} diff --git a/pkg/grpc/llm/langchain/langchain.go b/pkg/grpc/llm/langchain/langchain.go new file mode 100644 index 0000000..5d5f94b --- /dev/null +++ b/pkg/grpc/llm/langchain/langchain.go @@ -0,0 +1,58 @@ +package langchain + +// This is a wrapper to statisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "fmt" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "github.com/go-skynet/LocalAI/pkg/langchain" +) + +type LLM struct { + base.Base + + langchain *langchain.HuggingFace + model string +} + +func (llm *LLM) Load(opts *pb.ModelOptions) error { + llm.langchain, _ = langchain.NewHuggingFace(opts.Model) + llm.model = opts.Model + return nil +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + o := []langchain.PredictOption{ + langchain.SetModel(llm.model), + langchain.SetMaxTokens(int(opts.Tokens)), + langchain.SetTemperature(float64(opts.Temperature)), + langchain.SetStopWords(opts.StopPrompts), + } + pred, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...) + if err != nil { + return "", err + } + return pred.Completion, nil +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { + o := []langchain.PredictOption{ + langchain.SetModel(llm.model), + langchain.SetMaxTokens(int(opts.Tokens)), + langchain.SetTemperature(float64(opts.Temperature)), + langchain.SetStopWords(opts.StopPrompts), + } + go func() { + res, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...) 
+
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- res.Completion
+		close(results)
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/llama/llama.go b/pkg/grpc/llm/llama/llama.go
new file mode 100644
index 0000000..82063b7
--- /dev/null
+++ b/pkg/grpc/llm/llama/llama.go
@@ -0,0 +1,186 @@
+package llama
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"github.com/go-skynet/go-llama.cpp"
+)
+
+// LLM adapts a go-llama.cpp model to the backend gRPC service interface.
+type LLM struct {
+	base.Base
+
+	llama *llama.LLama
+}
+
+// Load translates opts into go-llama.cpp model options, loads opts.Model and
+// stores the result on llm. The error from llama.New is returned unchanged.
+func (llm *LLM) Load(opts *pb.ModelOptions) error {
+	llamaOpts := []llama.ModelOption{}
+
+	if opts.ContextSize != 0 {
+		llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize)))
+	}
+	if opts.F16Memory {
+		llamaOpts = append(llamaOpts, llama.EnableF16Memory)
+	}
+	if opts.Embeddings {
+		llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
+	}
+	if opts.NGPULayers != 0 {
+		llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
+	}
+
+	llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap))
+	llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU))
+	llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit))
+	if opts.NBatch != 0 {
+		llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch)))
+	} else {
+		// No batch size requested: use 512.
+		llamaOpts = append(llamaOpts, llama.SetNBatch(512))
+	}
+
+	if opts.NUMA {
+		llamaOpts = append(llamaOpts, llama.EnableNUMA)
+	}
+
+	if opts.LowVRAM {
+		llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
+	}
+
+	model, err := llama.New(opts.Model, llamaOpts...)
+	llm.llama = model
+	return err
+}
+
+// buildPredictOptions converts the request into go-llama.cpp predict options.
+// Fields guarded by an if below are forwarded only when set (non-zero,
+// non-empty or true); every other field is forwarded unconditionally.
+func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption {
+	predictOptions := []llama.PredictOption{
+		llama.SetTemperature(float64(opts.Temperature)),
+		llama.SetTopP(float64(opts.TopP)),
+		llama.SetTopK(int(opts.TopK)),
+		llama.SetTokens(int(opts.Tokens)),
+		llama.SetThreads(int(opts.Threads)),
+	}
+
+	if opts.PromptCacheAll {
+		predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
+	}
+
+	if opts.PromptCacheRO {
+		predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
+	}
+
+	predictOptions = append(predictOptions, llama.WithGrammar(opts.Grammar))
+
+	// Expected absolute path
+	if opts.PromptCachePath != "" {
+		predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath))
+	}
+
+	if opts.Mirostat != 0 {
+		predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat)))
+	}
+
+	if opts.MirostatETA != 0 {
+		predictOptions = append(predictOptions, llama.SetMirostatETA(float64(opts.MirostatETA)))
+	}
+
+	if opts.MirostatTAU != 0 {
+		predictOptions = append(predictOptions, llama.SetMirostatTAU(float64(opts.MirostatTAU)))
+	}
+
+	if opts.Debug {
+		predictOptions = append(predictOptions, llama.Debug)
+	}
+
+	predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...))
+
+	// NOTE(review): PresencePenalty is what feeds SetPenalty here, while
+	// opts.Penalty is never read in this function — confirm this is intended.
+	if opts.PresencePenalty != 0 {
+		predictOptions = append(predictOptions, llama.SetPenalty(float64(opts.PresencePenalty)))
+	}
+
+	if opts.NKeep != 0 {
+		predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep)))
+	}
+
+	if opts.Batch != 0 {
+		predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch)))
+	}
+
+	if opts.F16KV {
+		predictOptions = append(predictOptions, llama.EnableF16KV)
+	}
+
+	if opts.IgnoreEOS {
+		predictOptions = append(predictOptions, llama.IgnoreEOS)
+	}
+
+	if opts.Seed != 0 {
+		predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed)))
+	}
+
+	// NOTE: opts.LogitBias is not forwarded; no option is set for it here.
+
+	predictOptions = append(predictOptions, llama.SetFrequencyPenalty(float64(opts.FrequencyPenalty)))
+	predictOptions = append(predictOptions, llama.SetMlock(opts.MLock))
+	predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap))
+	predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU))
+	predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit))
+	predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ)))
+	predictOptions = append(predictOptions, llama.SetTypicalP(float64(opts.TypicalP)))
+	return predictOptions
+}
+
+// Predict runs a blocking prediction for opts.Prompt and returns what llama.Predict yields.
+func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream starts the prediction in a background goroutine and returns
+// immediately. Each token reported by the token callback is sent on results,
+// which is closed once Predict returns; a Predict error is only printed.
+func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	predictOptions := buildPredictOptions(opts)
+
+	predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool {
+		results <- token
+		return true
+	}))
+
+	go func() {
+		defer close(results)
+
+		if _, err := llm.llama.Predict(opts.Prompt, predictOptions...); err != nil {
+			fmt.Println("err: ", err)
+		}
+	}()
+
+	return nil
+}
+
+// Embeddings returns the embedding vector for the request. When
+// opts.EmbeddingTokens is non-empty the tokens are embedded directly;
+// otherwise the text in opts.Embeddings is embedded.
+func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
+	predictOptions := buildPredictOptions(opts)
+
+	if len(opts.EmbeddingTokens) > 0 {
+		tokens := make([]int, 0, len(opts.EmbeddingTokens))
+		for _, t := range opts.EmbeddingTokens {
+			tokens = append(tokens, int(t))
+		}
+		return llm.llama.TokenEmbeddings(tokens, predictOptions...)
+	}
+
+	return llm.llama.Embeddings(opts.Embeddings, predictOptions...)
+}
diff --git a/pkg/grpc/llm/rwkv/rwkv.go b/pkg/grpc/llm/rwkv/rwkv.go
new file mode 100644
index 0000000..f54c14b
--- /dev/null
+++ b/pkg/grpc/llm/rwkv/rwkv.go
@@ -0,0 +1,78 @@
+package rwkv
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+	"path/filepath"
+
+	"github.com/donomii/go-rwkv.cpp"
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+)
+
+// tokenizerSuffix is appended to the model file name to locate its tokenizer.
+const tokenizerSuffix = ".tokenizer.json"
+
+// LLM adapts a go-rwkv.cpp model to the backend gRPC service interface.
+type LLM struct {
+	base.Base
+
+	rwkv *rwkv.RwkvState
+}
+
+// Load reads the model at opts.Model together with the tokenizer file that
+// sits next to it (same name plus tokenizerSuffix). It returns an error when
+// rwkv.LoadFiles yields no model.
+func (llm *LLM) Load(opts *pb.ModelOptions) error {
+	modelPath := filepath.Dir(opts.Model)
+	modelFile := filepath.Base(opts.Model)
+	model := rwkv.LoadFiles(opts.Model, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads()))
+
+	if model == nil {
+		return fmt.Errorf("could not load model")
+	}
+	llm.rwkv = model
+	return nil
+}
+
+// stopWord returns the first entry of opts.StopPrompts, or "\n" when none is set.
+func stopWord(opts *pb.PredictOptions) string {
+	if len(opts.StopPrompts) > 0 {
+		return opts.StopPrompts[0]
+	}
+	return "\n"
+}
+
+// Predict feeds opts.Prompt to the model and returns the generated response.
+func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) {
+	if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
+		return "", err
+	}
+
+	response := llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord(opts), float32(opts.Temperature), float32(opts.TopP), nil)
+
+	return response, nil
+}
+
+// PredictStream generates in a background goroutine, sending every string the
+// generation callback receives to results. results is always closed when the
+// goroutine ends, including when the prompt cannot be processed, so receivers
+// never block forever on a failed request.
+func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil {
+			fmt.Println("Error processing input: ", err)
+			return
+		}
+
+		llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord(opts), float32(opts.Temperature), float32(opts.TopP), func(s string) bool {
+			results <- s
+			return true
+		})
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/transformers/dolly.go
b/pkg/grpc/llm/transformers/dolly.go
new file mode 100644
index 0000000..d5f3093
--- /dev/null
+++ b/pkg/grpc/llm/transformers/dolly.go
@@ -0,0 +1,49 @@
+package transformers
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// Dolly exposes a go-ggml-transformers Dolly model through the backend gRPC service interface.
+type Dolly struct {
+	base.Base
+
+	dolly *transformers.Dolly
+}
+
+// Load opens the model file at opts.Model; the handle returned by
+// transformers.NewDolly is stored even when an error is reported.
+func (llm *Dolly) Load(opts *pb.ModelOptions) error {
+	var err error
+	llm.dolly, err = transformers.NewDolly(opts.Model)
+	return err
+}
+
+// Predict completes opts.Prompt in one blocking call and returns the generated text.
+func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream falls back to Predict: the whole completion is delivered as a
+// single message on results, which is then closed. A prediction error is only
+// printed; whatever Predict returned is still sent.
+func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		text, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...)
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- text
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/transformers/falcon.go b/pkg/grpc/llm/transformers/falcon.go
new file mode 100644
index 0000000..982e43e
--- /dev/null
+++ b/pkg/grpc/llm/transformers/falcon.go
@@ -0,0 +1,49 @@
+package transformers
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// Falcon exposes a go-ggml-transformers Falcon model through the backend gRPC service interface.
+type Falcon struct {
+	base.Base
+
+	falcon *transformers.Falcon
+}
+
+// Load opens the model file at opts.Model; the handle returned by
+// transformers.NewFalcon is stored even when an error is reported.
+func (llm *Falcon) Load(opts *pb.ModelOptions) error {
+	var err error
+	llm.falcon, err = transformers.NewFalcon(opts.Model)
+	return err
+}
+
+// Predict completes opts.Prompt in one blocking call and returns the generated text.
+func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream falls back to Predict: the whole completion is delivered as a
+// single message on results, which is then closed. A prediction error is only
+// printed; whatever Predict returned is still sent.
+func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		text, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...)
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- text
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/transformers/gpt2.go b/pkg/grpc/llm/transformers/gpt2.go
new file mode 100644
index 0000000..85a4112
--- /dev/null
+++ b/pkg/grpc/llm/transformers/gpt2.go
@@ -0,0 +1,49 @@
+package transformers
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// GPT2 exposes a go-ggml-transformers GPT2 model through the backend gRPC service interface.
+type GPT2 struct {
+	base.Base
+
+	gpt2 *transformers.GPT2
+}
+
+// Load opens the model file at opts.Model; the handle returned by
+// transformers.New is stored even when an error is reported.
+func (llm *GPT2) Load(opts *pb.ModelOptions) error {
+	var err error
+	llm.gpt2, err = transformers.New(opts.Model)
+	return err
+}
+
+// Predict completes opts.Prompt in one blocking call and returns the generated text.
+func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream falls back to Predict: the whole completion is delivered as a
+// single message on results, which is then closed. A prediction error is only
+// printed; whatever Predict returned is still sent.
+func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		text, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...)
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- text
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/transformers/gptj.go b/pkg/grpc/llm/transformers/gptj.go
new file mode 100644
index 0000000..e2bc3bf
--- /dev/null
+++ b/pkg/grpc/llm/transformers/gptj.go
@@ -0,0 +1,49 @@
+package transformers
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// GPTJ exposes a go-ggml-transformers GPTJ model through the backend gRPC service interface.
+type GPTJ struct {
+	base.Base
+
+	gptj *transformers.GPTJ
+}
+
+// Load opens the model file at opts.Model; the handle returned by
+// transformers.NewGPTJ is stored even when an error is reported.
+func (llm *GPTJ) Load(opts *pb.ModelOptions) error {
+	var err error
+	llm.gptj, err = transformers.NewGPTJ(opts.Model)
+	return err
+}
+
+// Predict completes opts.Prompt in one blocking call and returns the generated text.
+func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream falls back to Predict: the whole completion is delivered as a
+// single message on results, which is then closed. A prediction error is only
+// printed; whatever Predict returned is still sent.
+func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		text, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...)
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- text
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/transformers/gptneox.go b/pkg/grpc/llm/transformers/gptneox.go
new file mode 100644
index 0000000..ca6db94
--- /dev/null
+++ b/pkg/grpc/llm/transformers/gptneox.go
@@ -0,0 +1,49 @@
+package transformers
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// GPTNeoX exposes a go-ggml-transformers GPTNeoX model through the backend gRPC service interface.
+type GPTNeoX struct {
+	base.Base
+
+	gptneox *transformers.GPTNeoX
+}
+
+// Load opens the model file at opts.Model; the handle returned by
+// transformers.NewGPTNeoX is stored even when an error is reported.
+func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error {
+	var err error
+	llm.gptneox, err = transformers.NewGPTNeoX(opts.Model)
+	return err
+}
+
+// Predict completes opts.Prompt in one blocking call and returns the generated text.
+func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream falls back to Predict: the whole completion is delivered as a
+// single message on results, which is then closed. A prediction error is only
+// printed; whatever Predict returned is still sent.
+func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		text, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...)
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- text
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/transformers/mpt.go b/pkg/grpc/llm/transformers/mpt.go
new file mode 100644
index 0000000..d2b9ff1
--- /dev/null
+++ b/pkg/grpc/llm/transformers/mpt.go
@@ -0,0 +1,49 @@
+package transformers
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// MPT exposes a go-ggml-transformers MPT model through the backend gRPC service interface.
+type MPT struct {
+	base.Base
+
+	mpt *transformers.MPT
+}
+
+// Load opens the model file at opts.Model; the handle returned by
+// transformers.NewMPT is stored even when an error is reported.
+func (llm *MPT) Load(opts *pb.ModelOptions) error {
+	var err error
+	llm.mpt, err = transformers.NewMPT(opts.Model)
+	return err
+}
+
+// Predict completes opts.Prompt in one blocking call and returns the generated text.
+func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream falls back to Predict: the whole completion is delivered as a
+// single message on results, which is then closed. A prediction error is only
+// printed; whatever Predict returned is still sent.
+func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		text, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...)
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- text
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/transformers/predict.go b/pkg/grpc/llm/transformers/predict.go
new file mode 100644
index 0000000..861d119
--- /dev/null
+++ b/pkg/grpc/llm/transformers/predict.go
@@ -0,0 +1,29 @@
+package transformers
+
+import (
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// buildPredictOptions maps the request onto go-ggml-transformers predict
+// options. Temperature, TopP, TopK, Tokens and Threads are always passed
+// through; Batch and Seed are added only when they are non-zero.
+func buildPredictOptions(opts *pb.PredictOptions) []transformers.PredictOption {
+	po := []transformers.PredictOption{
+		transformers.SetTemperature(float64(opts.Temperature)),
+		transformers.SetTopP(float64(opts.TopP)),
+		transformers.SetTopK(int(opts.TopK)),
+		transformers.SetTokens(int(opts.Tokens)),
+		transformers.SetThreads(int(opts.Threads)),
+	}
+
+	if batch := int(opts.Batch); batch != 0 {
+		po = append(po, transformers.SetBatch(batch))
+	}
+
+	if seed := int(opts.Seed); seed != 0 {
+		po = append(po, transformers.SetSeed(seed))
+	}
+
+	return po
+}
diff --git a/pkg/grpc/llm/transformers/replit.go b/pkg/grpc/llm/transformers/replit.go
new file mode 100644
index 0000000..4b26ffd
--- /dev/null
+++ b/pkg/grpc/llm/transformers/replit.go
@@ -0,0 +1,49 @@
+package transformers
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// Replit exposes a go-ggml-transformers Replit model through the backend gRPC service interface.
+type Replit struct {
+	base.Base
+
+	replit *transformers.Replit
+}
+
+// Load opens the model file at opts.Model; the handle returned by
+// transformers.NewReplit is stored even when an error is reported.
+func (llm *Replit) Load(opts *pb.ModelOptions) error {
+	var err error
+	llm.replit, err = transformers.NewReplit(opts.Model)
+	return err
+}
+
+// Predict completes opts.Prompt in one blocking call and returns the generated text.
+func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream falls back to Predict: the whole completion is delivered as a
+// single message on results, which is then closed. A prediction error is only
+// printed; whatever Predict returned is still sent.
+func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		text, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...)
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- text
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/llm/transformers/starcoder.go b/pkg/grpc/llm/transformers/starcoder.go
new file mode 100644
index 0000000..7631274
--- /dev/null
+++ b/pkg/grpc/llm/transformers/starcoder.go
@@ -0,0 +1,49 @@
+package transformers
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	"github.com/go-skynet/LocalAI/pkg/grpc/base"
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+
+	transformers "github.com/go-skynet/go-ggml-transformers.cpp"
+)
+
+// Starcoder exposes a go-ggml-transformers Starcoder model through the backend gRPC service interface.
+type Starcoder struct {
+	base.Base
+
+	starcoder *transformers.Starcoder
+}
+
+// Load opens the model file at opts.Model; the handle returned by
+// transformers.NewStarcoder is stored even when an error is reported.
+func (llm *Starcoder) Load(opts *pb.ModelOptions) error {
+	var err error
+	llm.starcoder, err = transformers.NewStarcoder(opts.Model)
+	return err
+}
+
+// Predict completes opts.Prompt in one blocking call and returns the generated text.
+func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) {
+	return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
+}
+
+// PredictStream falls back to Predict: the whole completion is delivered as a
+// single message on results, which is then closed. A prediction error is only
+// printed; whatever Predict returned is still sent.
+func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error {
+	go func() {
+		defer close(results)
+
+		text, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...)
+		if err != nil {
+			fmt.Println("err: ", err)
+		}
+		results <- text
+	}()
+
+	return nil
+}
diff --git a/pkg/grpc/proto/backend.pb.go b/pkg/grpc/proto/backend.pb.go
new file mode 100644
index 0000000..dcf14a3
--- /dev/null
+++ b/pkg/grpc/proto/backend.pb.go
@@ -0,0 +1,1458 @@
+// Code generated by protoc-gen-go. DO NOT EDIT.
+// versions: +// protoc-gen-go v1.26.0 +// protoc v3.15.8 +// source: pkg/grpc/proto/backend.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type HealthMessage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *HealthMessage) Reset() { + *x = HealthMessage{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HealthMessage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthMessage) ProtoMessage() {} + +func (x *HealthMessage) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthMessage.ProtoReflect.Descriptor instead. +func (*HealthMessage) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{0} +} + +// The request message containing the user's name. 
+type PredictOptions struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Prompt string `protobuf:"bytes,1,opt,name=Prompt,proto3" json:"Prompt,omitempty"` + Seed int32 `protobuf:"varint,2,opt,name=Seed,proto3" json:"Seed,omitempty"` + Threads int32 `protobuf:"varint,3,opt,name=Threads,proto3" json:"Threads,omitempty"` + Tokens int32 `protobuf:"varint,4,opt,name=Tokens,proto3" json:"Tokens,omitempty"` + TopK int32 `protobuf:"varint,5,opt,name=TopK,proto3" json:"TopK,omitempty"` + Repeat int32 `protobuf:"varint,6,opt,name=Repeat,proto3" json:"Repeat,omitempty"` + Batch int32 `protobuf:"varint,7,opt,name=Batch,proto3" json:"Batch,omitempty"` + NKeep int32 `protobuf:"varint,8,opt,name=NKeep,proto3" json:"NKeep,omitempty"` + Temperature float32 `protobuf:"fixed32,9,opt,name=Temperature,proto3" json:"Temperature,omitempty"` + Penalty float32 `protobuf:"fixed32,10,opt,name=Penalty,proto3" json:"Penalty,omitempty"` + F16KV bool `protobuf:"varint,11,opt,name=F16KV,proto3" json:"F16KV,omitempty"` + DebugMode bool `protobuf:"varint,12,opt,name=DebugMode,proto3" json:"DebugMode,omitempty"` + StopPrompts []string `protobuf:"bytes,13,rep,name=StopPrompts,proto3" json:"StopPrompts,omitempty"` + IgnoreEOS bool `protobuf:"varint,14,opt,name=IgnoreEOS,proto3" json:"IgnoreEOS,omitempty"` + TailFreeSamplingZ float32 `protobuf:"fixed32,15,opt,name=TailFreeSamplingZ,proto3" json:"TailFreeSamplingZ,omitempty"` + TypicalP float32 `protobuf:"fixed32,16,opt,name=TypicalP,proto3" json:"TypicalP,omitempty"` + FrequencyPenalty float32 `protobuf:"fixed32,17,opt,name=FrequencyPenalty,proto3" json:"FrequencyPenalty,omitempty"` + PresencePenalty float32 `protobuf:"fixed32,18,opt,name=PresencePenalty,proto3" json:"PresencePenalty,omitempty"` + Mirostat int32 `protobuf:"varint,19,opt,name=Mirostat,proto3" json:"Mirostat,omitempty"` + MirostatETA float32 `protobuf:"fixed32,20,opt,name=MirostatETA,proto3" json:"MirostatETA,omitempty"` + 
MirostatTAU float32 `protobuf:"fixed32,21,opt,name=MirostatTAU,proto3" json:"MirostatTAU,omitempty"` + PenalizeNL bool `protobuf:"varint,22,opt,name=PenalizeNL,proto3" json:"PenalizeNL,omitempty"` + LogitBias string `protobuf:"bytes,23,opt,name=LogitBias,proto3" json:"LogitBias,omitempty"` + MLock bool `protobuf:"varint,25,opt,name=MLock,proto3" json:"MLock,omitempty"` + MMap bool `protobuf:"varint,26,opt,name=MMap,proto3" json:"MMap,omitempty"` + PromptCacheAll bool `protobuf:"varint,27,opt,name=PromptCacheAll,proto3" json:"PromptCacheAll,omitempty"` + PromptCacheRO bool `protobuf:"varint,28,opt,name=PromptCacheRO,proto3" json:"PromptCacheRO,omitempty"` + Grammar string `protobuf:"bytes,29,opt,name=Grammar,proto3" json:"Grammar,omitempty"` + MainGPU string `protobuf:"bytes,30,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` + TensorSplit string `protobuf:"bytes,31,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` + TopP float32 `protobuf:"fixed32,32,opt,name=TopP,proto3" json:"TopP,omitempty"` + PromptCachePath string `protobuf:"bytes,33,opt,name=PromptCachePath,proto3" json:"PromptCachePath,omitempty"` + Debug bool `protobuf:"varint,34,opt,name=Debug,proto3" json:"Debug,omitempty"` + EmbeddingTokens []int32 `protobuf:"varint,35,rep,packed,name=EmbeddingTokens,proto3" json:"EmbeddingTokens,omitempty"` + Embeddings string `protobuf:"bytes,36,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` +} + +func (x *PredictOptions) Reset() { + *x = PredictOptions{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *PredictOptions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PredictOptions) ProtoMessage() {} + +func (x *PredictOptions) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PredictOptions.ProtoReflect.Descriptor instead. +func (*PredictOptions) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{1} +} + +func (x *PredictOptions) GetPrompt() string { + if x != nil { + return x.Prompt + } + return "" +} + +func (x *PredictOptions) GetSeed() int32 { + if x != nil { + return x.Seed + } + return 0 +} + +func (x *PredictOptions) GetThreads() int32 { + if x != nil { + return x.Threads + } + return 0 +} + +func (x *PredictOptions) GetTokens() int32 { + if x != nil { + return x.Tokens + } + return 0 +} + +func (x *PredictOptions) GetTopK() int32 { + if x != nil { + return x.TopK + } + return 0 +} + +func (x *PredictOptions) GetRepeat() int32 { + if x != nil { + return x.Repeat + } + return 0 +} + +func (x *PredictOptions) GetBatch() int32 { + if x != nil { + return x.Batch + } + return 0 +} + +func (x *PredictOptions) GetNKeep() int32 { + if x != nil { + return x.NKeep + } + return 0 +} + +func (x *PredictOptions) GetTemperature() float32 { + if x != nil { + return x.Temperature + } + return 0 +} + +func (x *PredictOptions) GetPenalty() float32 { + if x != nil { + return x.Penalty + } + return 0 +} + +func (x *PredictOptions) GetF16KV() bool { + if x != nil { + return x.F16KV + } + return false +} + +func (x *PredictOptions) GetDebugMode() bool { + if x != nil { + return x.DebugMode + } + return false +} + +func (x *PredictOptions) GetStopPrompts() []string { + if x != nil { + return x.StopPrompts + } + return nil +} + +func (x *PredictOptions) GetIgnoreEOS() bool { + if x != nil { + return x.IgnoreEOS + } + return false +} + +func (x *PredictOptions) GetTailFreeSamplingZ() float32 { + if x != nil { + return x.TailFreeSamplingZ + } + return 0 +} + +func (x *PredictOptions) GetTypicalP() float32 { + if x != nil { + return 
x.TypicalP + } + return 0 +} + +func (x *PredictOptions) GetFrequencyPenalty() float32 { + if x != nil { + return x.FrequencyPenalty + } + return 0 +} + +func (x *PredictOptions) GetPresencePenalty() float32 { + if x != nil { + return x.PresencePenalty + } + return 0 +} + +func (x *PredictOptions) GetMirostat() int32 { + if x != nil { + return x.Mirostat + } + return 0 +} + +func (x *PredictOptions) GetMirostatETA() float32 { + if x != nil { + return x.MirostatETA + } + return 0 +} + +func (x *PredictOptions) GetMirostatTAU() float32 { + if x != nil { + return x.MirostatTAU + } + return 0 +} + +func (x *PredictOptions) GetPenalizeNL() bool { + if x != nil { + return x.PenalizeNL + } + return false +} + +func (x *PredictOptions) GetLogitBias() string { + if x != nil { + return x.LogitBias + } + return "" +} + +func (x *PredictOptions) GetMLock() bool { + if x != nil { + return x.MLock + } + return false +} + +func (x *PredictOptions) GetMMap() bool { + if x != nil { + return x.MMap + } + return false +} + +func (x *PredictOptions) GetPromptCacheAll() bool { + if x != nil { + return x.PromptCacheAll + } + return false +} + +func (x *PredictOptions) GetPromptCacheRO() bool { + if x != nil { + return x.PromptCacheRO + } + return false +} + +func (x *PredictOptions) GetGrammar() string { + if x != nil { + return x.Grammar + } + return "" +} + +func (x *PredictOptions) GetMainGPU() string { + if x != nil { + return x.MainGPU + } + return "" +} + +func (x *PredictOptions) GetTensorSplit() string { + if x != nil { + return x.TensorSplit + } + return "" +} + +func (x *PredictOptions) GetTopP() float32 { + if x != nil { + return x.TopP + } + return 0 +} + +func (x *PredictOptions) GetPromptCachePath() string { + if x != nil { + return x.PromptCachePath + } + return "" +} + +func (x *PredictOptions) GetDebug() bool { + if x != nil { + return x.Debug + } + return false +} + +func (x *PredictOptions) GetEmbeddingTokens() []int32 { + if x != nil { + return x.EmbeddingTokens + } 
+ return nil +} + +func (x *PredictOptions) GetEmbeddings() string { + if x != nil { + return x.Embeddings + } + return "" +} + +// The response message containing the result +type Reply struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` +} + +func (x *Reply) Reset() { + *x = Reply{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Reply) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Reply) ProtoMessage() {} + +func (x *Reply) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Reply.ProtoReflect.Descriptor instead. 
+func (*Reply) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{2} +} + +func (x *Reply) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +type ModelOptions struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Model string `protobuf:"bytes,1,opt,name=Model,proto3" json:"Model,omitempty"` + ContextSize int32 `protobuf:"varint,2,opt,name=ContextSize,proto3" json:"ContextSize,omitempty"` + Seed int32 `protobuf:"varint,3,opt,name=Seed,proto3" json:"Seed,omitempty"` + NBatch int32 `protobuf:"varint,4,opt,name=NBatch,proto3" json:"NBatch,omitempty"` + F16Memory bool `protobuf:"varint,5,opt,name=F16Memory,proto3" json:"F16Memory,omitempty"` + MLock bool `protobuf:"varint,6,opt,name=MLock,proto3" json:"MLock,omitempty"` + MMap bool `protobuf:"varint,7,opt,name=MMap,proto3" json:"MMap,omitempty"` + VocabOnly bool `protobuf:"varint,8,opt,name=VocabOnly,proto3" json:"VocabOnly,omitempty"` + LowVRAM bool `protobuf:"varint,9,opt,name=LowVRAM,proto3" json:"LowVRAM,omitempty"` + Embeddings bool `protobuf:"varint,10,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` + NUMA bool `protobuf:"varint,11,opt,name=NUMA,proto3" json:"NUMA,omitempty"` + NGPULayers int32 `protobuf:"varint,12,opt,name=NGPULayers,proto3" json:"NGPULayers,omitempty"` + MainGPU string `protobuf:"bytes,13,opt,name=MainGPU,proto3" json:"MainGPU,omitempty"` + TensorSplit string `protobuf:"bytes,14,opt,name=TensorSplit,proto3" json:"TensorSplit,omitempty"` + Threads int32 `protobuf:"varint,15,opt,name=Threads,proto3" json:"Threads,omitempty"` + LibrarySearchPath string `protobuf:"bytes,16,opt,name=LibrarySearchPath,proto3" json:"LibrarySearchPath,omitempty"` +} + +func (x *ModelOptions) Reset() { + *x = ModelOptions{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + 
ms.StoreMessageInfo(mi) + } +} + +func (x *ModelOptions) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ModelOptions) ProtoMessage() {} + +func (x *ModelOptions) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ModelOptions.ProtoReflect.Descriptor instead. +func (*ModelOptions) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{3} +} + +func (x *ModelOptions) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +func (x *ModelOptions) GetContextSize() int32 { + if x != nil { + return x.ContextSize + } + return 0 +} + +func (x *ModelOptions) GetSeed() int32 { + if x != nil { + return x.Seed + } + return 0 +} + +func (x *ModelOptions) GetNBatch() int32 { + if x != nil { + return x.NBatch + } + return 0 +} + +func (x *ModelOptions) GetF16Memory() bool { + if x != nil { + return x.F16Memory + } + return false +} + +func (x *ModelOptions) GetMLock() bool { + if x != nil { + return x.MLock + } + return false +} + +func (x *ModelOptions) GetMMap() bool { + if x != nil { + return x.MMap + } + return false +} + +func (x *ModelOptions) GetVocabOnly() bool { + if x != nil { + return x.VocabOnly + } + return false +} + +func (x *ModelOptions) GetLowVRAM() bool { + if x != nil { + return x.LowVRAM + } + return false +} + +func (x *ModelOptions) GetEmbeddings() bool { + if x != nil { + return x.Embeddings + } + return false +} + +func (x *ModelOptions) GetNUMA() bool { + if x != nil { + return x.NUMA + } + return false +} + +func (x *ModelOptions) GetNGPULayers() int32 { + if x != nil { + return x.NGPULayers + } + return 0 +} + +func (x *ModelOptions) GetMainGPU() string { + if x != nil { + return x.MainGPU + } + return 
"" +} + +func (x *ModelOptions) GetTensorSplit() string { + if x != nil { + return x.TensorSplit + } + return "" +} + +func (x *ModelOptions) GetThreads() int32 { + if x != nil { + return x.Threads + } + return 0 +} + +func (x *ModelOptions) GetLibrarySearchPath() string { + if x != nil { + return x.LibrarySearchPath + } + return "" +} + +type Result struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Message string `protobuf:"bytes,1,opt,name=message,proto3" json:"message,omitempty"` + Success bool `protobuf:"varint,2,opt,name=success,proto3" json:"success,omitempty"` +} + +func (x *Result) Reset() { + *x = Result{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Result) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Result) ProtoMessage() {} + +func (x *Result) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Result.ProtoReflect.Descriptor instead. 
+func (*Result) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{4} +} + +func (x *Result) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *Result) GetSuccess() bool { + if x != nil { + return x.Success + } + return false +} + +type EmbeddingResult struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Embeddings []float32 `protobuf:"fixed32,1,rep,packed,name=embeddings,proto3" json:"embeddings,omitempty"` +} + +func (x *EmbeddingResult) Reset() { + *x = EmbeddingResult{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EmbeddingResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbeddingResult) ProtoMessage() {} + +func (x *EmbeddingResult) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbeddingResult.ProtoReflect.Descriptor instead. 
+func (*EmbeddingResult) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{5} +} + +func (x *EmbeddingResult) GetEmbeddings() []float32 { + if x != nil { + return x.Embeddings + } + return nil +} + +type TranscriptRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Dst string `protobuf:"bytes,2,opt,name=dst,proto3" json:"dst,omitempty"` + Language string `protobuf:"bytes,3,opt,name=language,proto3" json:"language,omitempty"` + Threads uint32 `protobuf:"varint,4,opt,name=threads,proto3" json:"threads,omitempty"` +} + +func (x *TranscriptRequest) Reset() { + *x = TranscriptRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TranscriptRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscriptRequest) ProtoMessage() {} + +func (x *TranscriptRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[6] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscriptRequest.ProtoReflect.Descriptor instead. 
+func (*TranscriptRequest) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{6} +} + +func (x *TranscriptRequest) GetDst() string { + if x != nil { + return x.Dst + } + return "" +} + +func (x *TranscriptRequest) GetLanguage() string { + if x != nil { + return x.Language + } + return "" +} + +func (x *TranscriptRequest) GetThreads() uint32 { + if x != nil { + return x.Threads + } + return 0 +} + +type TranscriptResult struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Segments []*TranscriptSegment `protobuf:"bytes,1,rep,name=segments,proto3" json:"segments,omitempty"` + Text string `protobuf:"bytes,2,opt,name=text,proto3" json:"text,omitempty"` +} + +func (x *TranscriptResult) Reset() { + *x = TranscriptResult{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TranscriptResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscriptResult) ProtoMessage() {} + +func (x *TranscriptResult) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[7] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscriptResult.ProtoReflect.Descriptor instead. 
+func (*TranscriptResult) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{7} +} + +func (x *TranscriptResult) GetSegments() []*TranscriptSegment { + if x != nil { + return x.Segments + } + return nil +} + +func (x *TranscriptResult) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +type TranscriptSegment struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Id int32 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + Start int64 `protobuf:"varint,2,opt,name=start,proto3" json:"start,omitempty"` + End int64 `protobuf:"varint,3,opt,name=end,proto3" json:"end,omitempty"` + Text string `protobuf:"bytes,4,opt,name=text,proto3" json:"text,omitempty"` + Tokens []int32 `protobuf:"varint,5,rep,packed,name=tokens,proto3" json:"tokens,omitempty"` +} + +func (x *TranscriptSegment) Reset() { + *x = TranscriptSegment{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TranscriptSegment) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TranscriptSegment) ProtoMessage() {} + +func (x *TranscriptSegment) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[8] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TranscriptSegment.ProtoReflect.Descriptor instead. 
+func (*TranscriptSegment) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{8} +} + +func (x *TranscriptSegment) GetId() int32 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *TranscriptSegment) GetStart() int64 { + if x != nil { + return x.Start + } + return 0 +} + +func (x *TranscriptSegment) GetEnd() int64 { + if x != nil { + return x.End + } + return 0 +} + +func (x *TranscriptSegment) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +func (x *TranscriptSegment) GetTokens() []int32 { + if x != nil { + return x.Tokens + } + return nil +} + +type GenerateImageRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Height int32 `protobuf:"varint,1,opt,name=height,proto3" json:"height,omitempty"` + Width int32 `protobuf:"varint,2,opt,name=width,proto3" json:"width,omitempty"` + Mode int32 `protobuf:"varint,3,opt,name=mode,proto3" json:"mode,omitempty"` + Step int32 `protobuf:"varint,4,opt,name=step,proto3" json:"step,omitempty"` + Seed int32 `protobuf:"varint,5,opt,name=seed,proto3" json:"seed,omitempty"` + PositivePrompt string `protobuf:"bytes,6,opt,name=positive_prompt,json=positivePrompt,proto3" json:"positive_prompt,omitempty"` + NegativePrompt string `protobuf:"bytes,7,opt,name=negative_prompt,json=negativePrompt,proto3" json:"negative_prompt,omitempty"` + Dst string `protobuf:"bytes,8,opt,name=dst,proto3" json:"dst,omitempty"` +} + +func (x *GenerateImageRequest) Reset() { + *x = GenerateImageRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *GenerateImageRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GenerateImageRequest) ProtoMessage() {} + +func (x *GenerateImageRequest) ProtoReflect() protoreflect.Message { + mi := 
&file_pkg_grpc_proto_backend_proto_msgTypes[9] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GenerateImageRequest.ProtoReflect.Descriptor instead. +func (*GenerateImageRequest) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{9} +} + +func (x *GenerateImageRequest) GetHeight() int32 { + if x != nil { + return x.Height + } + return 0 +} + +func (x *GenerateImageRequest) GetWidth() int32 { + if x != nil { + return x.Width + } + return 0 +} + +func (x *GenerateImageRequest) GetMode() int32 { + if x != nil { + return x.Mode + } + return 0 +} + +func (x *GenerateImageRequest) GetStep() int32 { + if x != nil { + return x.Step + } + return 0 +} + +func (x *GenerateImageRequest) GetSeed() int32 { + if x != nil { + return x.Seed + } + return 0 +} + +func (x *GenerateImageRequest) GetPositivePrompt() string { + if x != nil { + return x.PositivePrompt + } + return "" +} + +func (x *GenerateImageRequest) GetNegativePrompt() string { + if x != nil { + return x.NegativePrompt + } + return "" +} + +func (x *GenerateImageRequest) GetDst() string { + if x != nil { + return x.Dst + } + return "" +} + +type TTSRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Text string `protobuf:"bytes,1,opt,name=text,proto3" json:"text,omitempty"` + Model string `protobuf:"bytes,2,opt,name=model,proto3" json:"model,omitempty"` + Dst string `protobuf:"bytes,3,opt,name=dst,proto3" json:"dst,omitempty"` +} + +func (x *TTSRequest) Reset() { + *x = TTSRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TTSRequest) String() string { + return 
protoimpl.X.MessageStringOf(x) +} + +func (*TTSRequest) ProtoMessage() {} + +func (x *TTSRequest) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_backend_proto_msgTypes[10] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TTSRequest.ProtoReflect.Descriptor instead. +func (*TTSRequest) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_backend_proto_rawDescGZIP(), []int{10} +} + +func (x *TTSRequest) GetText() string { + if x != nil { + return x.Text + } + return "" +} + +func (x *TTSRequest) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +func (x *TTSRequest) GetDst() string { + if x != nil { + return x.Dst + } + return "" +} + +var File_pkg_grpc_proto_backend_proto protoreflect.FileDescriptor + +var file_pkg_grpc_proto_backend_proto_rawDesc = []byte{ + 0x0a, 0x1c, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x2f, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x07, + 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x22, 0x0f, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xa0, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, + 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x50, + 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, + 0x64, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, + 0x73, 0x12, 0x16, 0x0a, 0x06, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x06, 
0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, + 0x4b, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x4b, 0x12, 0x16, 0x0a, + 0x06, 0x52, 0x65, 0x70, 0x65, 0x61, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x52, + 0x65, 0x70, 0x65, 0x61, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x18, 0x07, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x4e, + 0x4b, 0x65, 0x65, 0x70, 0x18, 0x08, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x4e, 0x4b, 0x65, 0x65, + 0x70, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, 0x75, 0x72, 0x65, + 0x18, 0x09, 0x20, 0x01, 0x28, 0x02, 0x52, 0x0b, 0x54, 0x65, 0x6d, 0x70, 0x65, 0x72, 0x61, 0x74, + 0x75, 0x72, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x0a, + 0x20, 0x01, 0x28, 0x02, 0x52, 0x07, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x14, 0x0a, + 0x05, 0x46, 0x31, 0x36, 0x4b, 0x56, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x46, 0x31, + 0x36, 0x4b, 0x56, 0x12, 0x1c, 0x0a, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, 0x65, + 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x44, 0x65, 0x62, 0x75, 0x67, 0x4d, 0x6f, 0x64, + 0x65, 0x12, 0x20, 0x0a, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x73, + 0x18, 0x0d, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x53, 0x74, 0x6f, 0x70, 0x50, 0x72, 0x6f, 0x6d, + 0x70, 0x74, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, 0x53, + 0x18, 0x0e, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x49, 0x67, 0x6e, 0x6f, 0x72, 0x65, 0x45, 0x4f, + 0x53, 0x12, 0x2c, 0x0a, 0x11, 0x54, 0x61, 0x69, 0x6c, 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, + 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x5a, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x02, 0x52, 0x11, 0x54, 0x61, + 0x69, 0x6c, 0x46, 0x72, 0x65, 0x65, 0x53, 0x61, 0x6d, 0x70, 0x6c, 0x69, 0x6e, 0x67, 0x5a, 0x12, + 0x1a, 0x0a, 0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 
0x6c, 0x50, 0x18, 0x10, 0x20, 0x01, 0x28, + 0x02, 0x52, 0x08, 0x54, 0x79, 0x70, 0x69, 0x63, 0x61, 0x6c, 0x50, 0x12, 0x2a, 0x0a, 0x10, 0x46, + 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, + 0x11, 0x20, 0x01, 0x28, 0x02, 0x52, 0x10, 0x46, 0x72, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x63, 0x79, + 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x65, 0x73, 0x65, + 0x6e, 0x63, 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, 0x79, 0x18, 0x12, 0x20, 0x01, 0x28, 0x02, + 0x52, 0x0f, 0x50, 0x72, 0x65, 0x73, 0x65, 0x6e, 0x63, 0x65, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x74, + 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x18, 0x13, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x08, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x12, 0x20, 0x0a, + 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x18, 0x14, 0x20, 0x01, + 0x28, 0x02, 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x45, 0x54, 0x41, 0x12, + 0x20, 0x0a, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, 0x55, 0x18, 0x15, + 0x20, 0x01, 0x28, 0x02, 0x52, 0x0b, 0x4d, 0x69, 0x72, 0x6f, 0x73, 0x74, 0x61, 0x74, 0x54, 0x41, + 0x55, 0x12, 0x1e, 0x0a, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x18, + 0x16, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, + 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x18, 0x17, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, + 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, + 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, + 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 
0x28, + 0x08, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, + 0x6c, 0x12, 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, + 0x52, 0x4f, 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, + 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, + 0x61, 0x72, 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, + 0x72, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, + 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, + 0x04, 0x54, 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, + 0x50, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, + 0x50, 0x61, 0x74, 0x68, 0x18, 0x21, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, + 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, + 0x65, 0x62, 0x75, 0x67, 0x18, 0x22, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, + 0x67, 0x12, 0x28, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, + 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x23, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0f, 0x45, 0x6d, 0x62, 0x65, + 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x45, + 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x52, + 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 
0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xca, + 0x03, 0x0a, 0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, + 0x14, 0x0a, 0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x20, 0x0a, 0x0b, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x78, 0x74, + 0x53, 0x69, 0x7a, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0b, 0x43, 0x6f, 0x6e, 0x74, + 0x65, 0x78, 0x74, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x53, 0x65, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x4e, + 0x42, 0x61, 0x74, 0x63, 0x68, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x4e, 0x42, 0x61, + 0x74, 0x63, 0x68, 0x12, 0x1c, 0x0a, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, 0x79, + 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x46, 0x31, 0x36, 0x4d, 0x65, 0x6d, 0x6f, 0x72, + 0x79, 0x12, 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, + 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x1c, 0x0a, 0x09, 0x56, + 0x6f, 0x63, 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, + 0x56, 0x6f, 0x63, 0x61, 0x62, 0x4f, 0x6e, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x4c, 0x6f, 0x77, + 0x56, 0x52, 0x41, 0x4d, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x4c, 0x6f, 0x77, 0x56, + 0x52, 0x41, 0x4d, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, + 0x73, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, + 0x6e, 0x67, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x18, 0x0b, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x04, 0x4e, 0x55, 0x4d, 0x41, 0x12, 0x1e, 0x0a, 0x0a, 0x4e, 0x47, 0x50, 0x55, 0x4c, + 0x61, 0x79, 0x65, 0x72, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x05, 
0x52, 0x0a, 0x4e, 0x47, 0x50, + 0x55, 0x4c, 0x61, 0x79, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, + 0x50, 0x55, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, + 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, + 0x18, 0x0e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, + 0x6c, 0x69, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x18, 0x0f, + 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x54, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x12, 0x2c, 0x0a, + 0x11, 0x4c, 0x69, 0x62, 0x72, 0x61, 0x72, 0x79, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, + 0x74, 0x68, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x52, 0x11, 0x4c, 0x69, 0x62, 0x72, 0x61, 0x72, + 0x79, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x50, 0x61, 0x74, 0x68, 0x22, 0x3c, 0x0a, 0x06, 0x52, + 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, + 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, 0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, + 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, + 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x02, + 0x52, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x5b, 0x0a, 0x11, + 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x64, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x6c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x12, + 0x18, 
0x0a, 0x07, 0x74, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, + 0x52, 0x07, 0x74, 0x68, 0x72, 0x65, 0x61, 0x64, 0x73, 0x22, 0x5e, 0x0a, 0x10, 0x54, 0x72, 0x61, + 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x36, 0x0a, + 0x08, 0x73, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, + 0x72, 0x69, 0x70, 0x74, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x52, 0x08, 0x73, 0x65, 0x67, + 0x6d, 0x65, 0x6e, 0x74, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x65, 0x78, 0x74, 0x18, 0x02, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78, 0x74, 0x22, 0x77, 0x0a, 0x11, 0x54, 0x72, 0x61, + 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x53, 0x65, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x0e, + 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x02, 0x69, 0x64, 0x12, 0x14, + 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x65, 0x78, 0x74, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x6f, + 0x6b, 0x65, 0x6e, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x05, 0x52, 0x06, 0x74, 0x6f, 0x6b, 0x65, + 0x6e, 0x73, 0x22, 0xe4, 0x01, 0x0a, 0x14, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, + 0x6d, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x68, + 0x65, 0x69, 0x67, 0x68, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x06, 0x68, 0x65, 0x69, + 0x67, 0x68, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x77, 0x69, 0x64, 0x74, 0x68, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x05, 0x52, 0x05, 0x77, 0x69, 0x64, 0x74, 0x68, 0x12, 0x12, 0x0a, 0x04, 0x6d, 0x6f, 0x64, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x05, 
0x52, 0x04, 0x6d, 0x6f, 0x64, 0x65, 0x12, 0x12, 0x0a, + 0x04, 0x73, 0x74, 0x65, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x05, 0x52, 0x04, 0x73, 0x74, 0x65, + 0x70, 0x12, 0x12, 0x0a, 0x04, 0x73, 0x65, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x04, 0x73, 0x65, 0x65, 0x64, 0x12, 0x27, 0x0a, 0x0f, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x76, + 0x65, 0x5f, 0x70, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, + 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x76, 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x27, + 0x0a, 0x0f, 0x6e, 0x65, 0x67, 0x61, 0x74, 0x69, 0x76, 0x65, 0x5f, 0x70, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0e, 0x6e, 0x65, 0x67, 0x61, 0x74, 0x69, 0x76, + 0x65, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x08, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, 0x73, 0x74, 0x22, 0x48, 0x0a, 0x0a, 0x54, 0x54, 0x53, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x65, 0x78, 0x74, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x65, 0x78, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6d, + 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, + 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, + 0x64, 0x73, 0x74, 0x32, 0xeb, 0x03, 0x0a, 0x07, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x12, + 0x32, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x12, 0x16, 0x2e, 0x62, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, + 0x79, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x17, + 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, + 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, 
0x6b, 0x65, 0x6e, + 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x35, 0x0a, 0x09, 0x4c, 0x6f, 0x61, + 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, + 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0f, 0x2e, + 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, + 0x12, 0x3c, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x12, 0x17, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, + 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0e, 0x2e, 0x62, 0x61, 0x63, + 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x40, + 0x0a, 0x09, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x17, 0x2e, 0x62, 0x61, + 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x18, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x45, + 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, + 0x12, 0x41, 0x0a, 0x0d, 0x47, 0x65, 0x6e, 0x65, 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, + 0x65, 0x12, 0x1d, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x47, 0x65, 0x6e, 0x65, + 0x72, 0x61, 0x74, 0x65, 0x49, 0x6d, 0x61, 0x67, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x0f, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x12, 0x41, 0x75, 0x64, 0x69, 0x6f, 0x54, 0x72, 0x61, 0x6e, + 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1a, 0x2e, 0x62, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, + 0x54, 0x72, 0x61, 
0x6e, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, + 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x03, 0x54, 0x54, 0x53, 0x12, 0x13, 0x2e, 0x62, 0x61, 0x63, 0x6b, + 0x65, 0x6e, 0x64, 0x2e, 0x54, 0x54, 0x53, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, + 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, + 0x00, 0x42, 0x5a, 0x0a, 0x19, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2e, 0x6c, + 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x62, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x42, 0x0e, + 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x42, 0x61, 0x63, 0x6b, 0x65, 0x6e, 0x64, 0x50, 0x01, + 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, + 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, + 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_pkg_grpc_proto_backend_proto_rawDescOnce sync.Once + file_pkg_grpc_proto_backend_proto_rawDescData = file_pkg_grpc_proto_backend_proto_rawDesc +) + +func file_pkg_grpc_proto_backend_proto_rawDescGZIP() []byte { + file_pkg_grpc_proto_backend_proto_rawDescOnce.Do(func() { + file_pkg_grpc_proto_backend_proto_rawDescData = protoimpl.X.CompressGZIP(file_pkg_grpc_proto_backend_proto_rawDescData) + }) + return file_pkg_grpc_proto_backend_proto_rawDescData +} + +var file_pkg_grpc_proto_backend_proto_msgTypes = make([]protoimpl.MessageInfo, 11) +var file_pkg_grpc_proto_backend_proto_goTypes = []interface{}{ + (*HealthMessage)(nil), // 0: backend.HealthMessage + (*PredictOptions)(nil), // 1: backend.PredictOptions + (*Reply)(nil), // 2: backend.Reply + (*ModelOptions)(nil), // 3: backend.ModelOptions + (*Result)(nil), // 4: backend.Result + (*EmbeddingResult)(nil), // 5: backend.EmbeddingResult + (*TranscriptRequest)(nil), // 6: backend.TranscriptRequest + (*TranscriptResult)(nil), 
// 7: backend.TranscriptResult + (*TranscriptSegment)(nil), // 8: backend.TranscriptSegment + (*GenerateImageRequest)(nil), // 9: backend.GenerateImageRequest + (*TTSRequest)(nil), // 10: backend.TTSRequest +} +var file_pkg_grpc_proto_backend_proto_depIdxs = []int32{ + 8, // 0: backend.TranscriptResult.segments:type_name -> backend.TranscriptSegment + 0, // 1: backend.Backend.Health:input_type -> backend.HealthMessage + 1, // 2: backend.Backend.Predict:input_type -> backend.PredictOptions + 3, // 3: backend.Backend.LoadModel:input_type -> backend.ModelOptions + 1, // 4: backend.Backend.PredictStream:input_type -> backend.PredictOptions + 1, // 5: backend.Backend.Embedding:input_type -> backend.PredictOptions + 9, // 6: backend.Backend.GenerateImage:input_type -> backend.GenerateImageRequest + 6, // 7: backend.Backend.AudioTranscription:input_type -> backend.TranscriptRequest + 10, // 8: backend.Backend.TTS:input_type -> backend.TTSRequest + 2, // 9: backend.Backend.Health:output_type -> backend.Reply + 2, // 10: backend.Backend.Predict:output_type -> backend.Reply + 4, // 11: backend.Backend.LoadModel:output_type -> backend.Result + 2, // 12: backend.Backend.PredictStream:output_type -> backend.Reply + 5, // 13: backend.Backend.Embedding:output_type -> backend.EmbeddingResult + 4, // 14: backend.Backend.GenerateImage:output_type -> backend.Result + 7, // 15: backend.Backend.AudioTranscription:output_type -> backend.TranscriptResult + 4, // 16: backend.Backend.TTS:output_type -> backend.Result + 9, // [9:17] is the sub-list for method output_type + 1, // [1:9] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_pkg_grpc_proto_backend_proto_init() } +func file_pkg_grpc_proto_backend_proto_init() { + if File_pkg_grpc_proto_backend_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + 
file_pkg_grpc_proto_backend_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HealthMessage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*PredictOptions); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Reply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ModelOptions); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Result); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EmbeddingResult); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TranscriptRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TranscriptResult); i { + case 
0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TranscriptSegment); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*GenerateImageRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_pkg_grpc_proto_backend_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TTSRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_pkg_grpc_proto_backend_proto_rawDesc, + NumEnums: 0, + NumMessages: 11, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_pkg_grpc_proto_backend_proto_goTypes, + DependencyIndexes: file_pkg_grpc_proto_backend_proto_depIdxs, + MessageInfos: file_pkg_grpc_proto_backend_proto_msgTypes, + }.Build() + File_pkg_grpc_proto_backend_proto = out.File + file_pkg_grpc_proto_backend_proto_rawDesc = nil + file_pkg_grpc_proto_backend_proto_goTypes = nil + file_pkg_grpc_proto_backend_proto_depIdxs = nil +} diff --git a/pkg/grpc/proto/backend.proto b/pkg/grpc/proto/backend.proto new file mode 100644 index 0000000..7e0bdb7 --- /dev/null +++ b/pkg/grpc/proto/backend.proto @@ -0,0 +1,129 @@ +syntax = "proto3"; + +option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto"; +option java_multiple_files = true; +option java_package = "io.skynet.localai.backend"; +option java_outer_classname = "LocalAIBackend"; 
+ +package backend; + +service Backend { + rpc Health(HealthMessage) returns (Reply) {} + rpc Predict(PredictOptions) returns (Reply) {} + rpc LoadModel(ModelOptions) returns (Result) {} + rpc PredictStream(PredictOptions) returns (stream Reply) {} + rpc Embedding(PredictOptions) returns (EmbeddingResult) {} + rpc GenerateImage(GenerateImageRequest) returns (Result) {} + rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} + rpc TTS(TTSRequest) returns (Result) {} +} + +message HealthMessage {} + +// The request message containing the prompt and the inference options. +message PredictOptions { + string Prompt = 1; + int32 Seed = 2; + int32 Threads = 3; + int32 Tokens = 4; + int32 TopK = 5; + int32 Repeat = 6; + int32 Batch = 7; + int32 NKeep = 8; + float Temperature = 9; + float Penalty = 10; + bool F16KV = 11; + bool DebugMode = 12; + repeated string StopPrompts = 13; + bool IgnoreEOS = 14; + float TailFreeSamplingZ = 15; + float TypicalP = 16; + float FrequencyPenalty = 17; + float PresencePenalty = 18; + int32 Mirostat = 19; + float MirostatETA = 20; + float MirostatTAU = 21; + bool PenalizeNL = 22; + string LogitBias = 23; + bool MLock = 25; + bool MMap = 26; + bool PromptCacheAll = 27; + bool PromptCacheRO = 28; + string Grammar = 29; + string MainGPU = 30; + string TensorSplit = 31; + float TopP = 32; + string PromptCachePath = 33; + bool Debug = 34; + repeated int32 EmbeddingTokens = 35; + string Embeddings = 36; +} + +// The response message containing the result +message Reply { + string message = 1; +} + +message ModelOptions { + string Model = 1; + int32 ContextSize = 2; + int32 Seed = 3; + int32 NBatch = 4; + bool F16Memory = 5; + bool MLock = 6; + bool MMap = 7; + bool VocabOnly = 8; + bool LowVRAM = 9; + bool Embeddings = 10; + bool NUMA = 11; + int32 NGPULayers = 12; + string MainGPU = 13; + string TensorSplit = 14; + int32 Threads = 15; + string LibrarySearchPath = 16; +} + +message Result { + string message = 1; + bool success = 2; +} + +message 
EmbeddingResult { + repeated float embeddings = 1; +} + +message TranscriptRequest { + string dst = 2; + string language = 3; + uint32 threads = 4; +} + +message TranscriptResult { + repeated TranscriptSegment segments = 1; + string text = 2; +} + +message TranscriptSegment { + int32 id = 1; + int64 start = 2; + int64 end = 3; + string text = 4; + repeated int32 tokens = 5; +} + +message GenerateImageRequest { + int32 height = 1; + int32 width = 2; + int32 mode = 3; + int32 step = 4; + int32 seed = 5; + string positive_prompt = 6; + string negative_prompt = 7; + string dst = 8; +} + +message TTSRequest { + string text = 1; + string model = 2; + string dst = 3; +} diff --git a/pkg/grpc/proto/backend_grpc.pb.go b/pkg/grpc/proto/backend_grpc.pb.go new file mode 100644 index 0000000..b9d7dd8 --- /dev/null +++ b/pkg/grpc/proto/backend_grpc.pb.go @@ -0,0 +1,385 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.2.0 +// - protoc v3.15.8 +// source: pkg/grpc/proto/backend.proto + +package proto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.32.0 or later. +const _ = grpc.SupportPackageIsVersion7 + +// BackendClient is the client API for Backend service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. 
+type BackendClient interface { + Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) + Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) + LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) + PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) + Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) + GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) + AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) + TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) +} + +type backendClient struct { + cc grpc.ClientConnInterface +} + +func NewBackendClient(cc grpc.ClientConnInterface) BackendClient { + return &backendClient{cc} +} + +func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { + out := new(Reply) + err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { + out := new(Reply) + err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { + out := new(Result) + err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { + stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) + if err != nil { + return nil, err + } + x := &backendPredictStreamClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type Backend_PredictStreamClient interface { + Recv() (*Reply, error) + grpc.ClientStream +} + +type backendPredictStreamClient struct { + grpc.ClientStream +} + +func (x *backendPredictStreamClient) Recv() (*Reply, error) { + m := new(Reply) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + +func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { + out := new(EmbeddingResult) + err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { + out := new(Result) + err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { + out := new(TranscriptResult) + err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { + out := new(Result) + err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +// BackendServer is the server API for Backend service. +// All implementations must embed UnimplementedBackendServer +// for forward compatibility +type BackendServer interface { + Health(context.Context, *HealthMessage) (*Reply, error) + Predict(context.Context, *PredictOptions) (*Reply, error) + LoadModel(context.Context, *ModelOptions) (*Result, error) + PredictStream(*PredictOptions, Backend_PredictStreamServer) error + Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) + GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) + AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) + TTS(context.Context, *TTSRequest) (*Result, error) + mustEmbedUnimplementedBackendServer() +} + +// UnimplementedBackendServer must be embedded to have forward compatible implementations. 
+type UnimplementedBackendServer struct { +} + +func (UnimplementedBackendServer) Health(context.Context, *HealthMessage) (*Reply, error) { + return nil, status.Errorf(codes.Unimplemented, "method Health not implemented") +} +func (UnimplementedBackendServer) Predict(context.Context, *PredictOptions) (*Reply, error) { + return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented") +} +func (UnimplementedBackendServer) LoadModel(context.Context, *ModelOptions) (*Result, error) { + return nil, status.Errorf(codes.Unimplemented, "method LoadModel not implemented") +} +func (UnimplementedBackendServer) PredictStream(*PredictOptions, Backend_PredictStreamServer) error { + return status.Errorf(codes.Unimplemented, "method PredictStream not implemented") +} +func (UnimplementedBackendServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) { + return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented") +} +func (UnimplementedBackendServer) GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) { + return nil, status.Errorf(codes.Unimplemented, "method GenerateImage not implemented") +} +func (UnimplementedBackendServer) AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) { + return nil, status.Errorf(codes.Unimplemented, "method AudioTranscription not implemented") +} +func (UnimplementedBackendServer) TTS(context.Context, *TTSRequest) (*Result, error) { + return nil, status.Errorf(codes.Unimplemented, "method TTS not implemented") +} +func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {} + +// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to BackendServer will +// result in compilation errors. 
+type UnsafeBackendServer interface { + mustEmbedUnimplementedBackendServer() +} + +func RegisterBackendServer(s grpc.ServiceRegistrar, srv BackendServer) { + s.RegisterService(&Backend_ServiceDesc, srv) +} + +func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HealthMessage) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).Health(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/Health", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PredictOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).Predict(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/Predict", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ModelOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).LoadModel(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/LoadModel", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) + } 
+ return interceptor(ctx, in, info, handler) +} + +func _Backend_PredictStream_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(PredictOptions) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(BackendServer).PredictStream(m, &backendPredictStreamServer{stream}) +} + +type Backend_PredictStreamServer interface { + Send(*Reply) error + grpc.ServerStream +} + +type backendPredictStreamServer struct { + grpc.ServerStream +} + +func (x *backendPredictStreamServer) Send(m *Reply) error { + return x.ServerStream.SendMsg(m) +} + +func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PredictOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).Embedding(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/Embedding", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GenerateImageRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).GenerateImage(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/GenerateImage", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) 
(interface{}, error) { + in := new(TranscriptRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).AudioTranscription(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/AudioTranscription", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TTSRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(BackendServer).TTS(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/backend.Backend/TTS", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service. 
+// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var Backend_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "backend.Backend", + HandlerType: (*BackendServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Health", + Handler: _Backend_Health_Handler, + }, + { + MethodName: "Predict", + Handler: _Backend_Predict_Handler, + }, + { + MethodName: "LoadModel", + Handler: _Backend_LoadModel_Handler, + }, + { + MethodName: "Embedding", + Handler: _Backend_Embedding_Handler, + }, + { + MethodName: "GenerateImage", + Handler: _Backend_GenerateImage_Handler, + }, + { + MethodName: "AudioTranscription", + Handler: _Backend_AudioTranscription_Handler, + }, + { + MethodName: "TTS", + Handler: _Backend_TTS_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "PredictStream", + Handler: _Backend_PredictStream_Handler, + ServerStreams: true, + }, + }, + Metadata: "pkg/grpc/proto/backend.proto", +} diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go new file mode 100644 index 0000000..8d7a182 --- /dev/null +++ b/pkg/grpc/server.go @@ -0,0 +1,126 @@ +package grpc + +import ( + "context" + "fmt" + "log" + "net" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + "google.golang.org/grpc" +) + +// A GRPC Server that allows running LLM inference. +// It is used by the LLMServices to expose the LLM functionalities that are called by the client. +// The GRPC Service is general, trying to encompass all the possible LLM options models. +// What can actually be done depends on the concrete backend implementation. +// +// The server is implemented as a GRPC service, with the following methods: +// - Predict: to run the inference with options +// - PredictStream: to run the inference with options and stream the results + +// server implements pb.BackendServer by delegating each call to the wrapped LLM. 
+type server struct { + pb.UnimplementedBackendServer + llm LLM +} + +func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) { + return &pb.Reply{Message: "OK"}, nil +} + +func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { + embeds, err := s.llm.Embeddings(in) + if err != nil { + return nil, err + } + + return &pb.EmbeddingResult{Embeddings: embeds}, nil +} + +func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { + err := s.llm.Load(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Loading succeeded", Success: true}, nil +} + +func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { + result, err := s.llm.Predict(in) + return &pb.Reply{Message: result}, err +} + +func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) { + err := s.llm.GenerateImage(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Image generated", Success: true}, nil +} + +func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { + err := s.llm.TTS(in) + if err != nil { + return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err + } + return &pb.Result{Message: "Audio generated", Success: true}, nil +} + +func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { + result, err := s.llm.AudioTranscription(in) + if err != nil { + return nil, err + } + tresult := &pb.TranscriptResult{} + for _, s := range result.Segments { + tks := []int32{} + for _, t := range s.Tokens { + tks = append(tks, int32(t)) + } + tresult.Segments = append(tresult.Segments, + 
&pb.TranscriptSegment{ + Text: s.Text, + Id: int32(s.Id), + Start: int64(s.Start), + End: int64(s.End), + Tokens: tks, + }) + } + + tresult.Text = result.Text + return tresult, nil +} + +func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { + + resultChan := make(chan string) + + done := make(chan bool) + go func() { + for result := range resultChan { + stream.Send(&pb.Reply{Message: result}) + } + done <- true + }() + + s.llm.PredictStream(in, resultChan) + <-done + + return nil +} + +func StartServer(address string, model LLM) error { + lis, err := net.Listen("tcp", address) + if err != nil { + return err + } + s := grpc.NewServer() + pb.RegisterBackendServer(s, &server{llm: model}) + log.Printf("gRPC Server listening at %v", lis.Addr()) + if err := s.Serve(lis); err != nil { + return err + } + + return nil +} diff --git a/pkg/grpc/transcribe/whisper.go b/pkg/grpc/transcribe/whisper.go new file mode 100644 index 0000000..c0120db --- /dev/null +++ b/pkg/grpc/transcribe/whisper.go @@ -0,0 +1,27 @@ +package transcribe + +// This is a wrapper to satisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + whisperutil "github.com/go-skynet/LocalAI/pkg/grpc/whisper" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" +) + +type Whisper struct { + base.Base + whisper whisper.Model +} + +func (sd *Whisper) Load(opts *pb.ModelOptions) error { + // Note: the Model here is a path to a directory containing the model files + w, err := whisper.New(opts.Model) + sd.whisper = w + return err +} + +func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (api.Result, error) { + return whisperutil.Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) 
+} diff --git a/pkg/grpc/tts/piper.go b/pkg/grpc/tts/piper.go new file mode 100644 index 0000000..dbaa4b7 --- /dev/null +++ b/pkg/grpc/tts/piper.go @@ -0,0 +1,44 @@ +package tts + +// This is a wrapper to satisfy the GRPC service interface +// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) +import ( + "os" + + "github.com/go-skynet/LocalAI/pkg/grpc/base" + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" + piper "github.com/mudler/go-piper" +) + +type Piper struct { + base.Base + piper *PiperB +} + +func (sd *Piper) Load(opts *pb.ModelOptions) error { + var err error + // Note: LibrarySearchPath (not Model) is used here; it is the asset directory handed to New + sd.piper, err = New(opts.LibrarySearchPath) + return err +} + +func (sd *Piper) TTS(opts *pb.TTSRequest) error { + return sd.piper.TTS(opts.Text, opts.Model, opts.Dst) +} + +type PiperB struct { + assetDir string +} + +func New(assetDir string) (*PiperB, error) { + if _, err := os.Stat(assetDir); err != nil { + return nil, err + } + return &PiperB{ + assetDir: assetDir, + }, nil +} + +func (s *PiperB) TTS(text, model, dst string) error { + return piper.TextToWav(text, model, s.assetDir, "", dst) +} diff --git a/pkg/grpc/whisper/api/api.go b/pkg/grpc/whisper/api/api.go new file mode 100644 index 0000000..700d80e --- /dev/null +++ b/pkg/grpc/whisper/api/api.go @@ -0,0 +1,16 @@ +package api + +import "time" + +type Segment struct { + Id int `json:"id"` + Start time.Duration `json:"start"` + End time.Duration `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` +} + +type Result struct { + Segments []Segment `json:"segments"` + Text string `json:"text"` +} diff --git a/pkg/whisper/whisper.go b/pkg/grpc/whisper/whisper.go similarity index 78% rename from pkg/whisper/whisper.go rename to pkg/grpc/whisper/whisper.go index 63e8cc5..806e145 100644 --- a/pkg/whisper/whisper.go +++ b/pkg/grpc/whisper/whisper.go @@ -5,25 +5,12 @@ import ( "os" 
"os/exec" "path/filepath" - "time" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" wav "github.com/go-audio/wav" + "github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" ) -type Segment struct { - Id int `json:"id"` - Start time.Duration `json:"start"` - End time.Duration `json:"end"` - Text string `json:"text"` - Tokens []int `json:"tokens"` -} - -type Result struct { - Segments []Segment `json:"segments"` - Text string `json:"text"` -} - func sh(c string) (string, error) { cmd := exec.Command("/bin/sh", "-c", c) cmd.Env = os.Environ() @@ -42,8 +29,8 @@ func audioToWav(src, dst string) error { return nil } -func Transcript(model whisper.Model, audiopath, language string, threads uint) (Result, error) { - res := Result{} +func Transcript(model whisper.Model, audiopath, language string, threads uint) (api.Result, error) { + res := api.Result{} dir, err := os.MkdirTemp("", "whisper") if err != nil { @@ -99,11 +86,11 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) ( } var tokens []int - for _, t := range(s.Tokens) { + for _, t := range s.Tokens { tokens = append(tokens, t.Id) } - segment := Segment{Id: s.Num, Text: s.Text, Start:s.Start, End: s.End, Tokens: tokens} + segment := api.Segment{Id: s.Num, Text: s.Text, Start: s.Start, End: s.End, Tokens: tokens} res.Segments = append(res.Segments, segment) res.Text += s.Text diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 3849f85..d91131d 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -1,197 +1,216 @@ package model import ( + "context" "fmt" + "os" + "os/signal" "path/filepath" "strings" + "syscall" + "time" - rwkv "github.com/donomii/go-rwkv.cpp" - whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" - "github.com/go-skynet/LocalAI/pkg/langchain" - "github.com/go-skynet/LocalAI/pkg/stablediffusion" - "github.com/go-skynet/LocalAI/pkg/tts" - bloomz "github.com/go-skynet/bloomz.cpp" - bert 
"github.com/go-skynet/go-bert.cpp" - transformers "github.com/go-skynet/go-ggml-transformers.cpp" - llama "github.com/go-skynet/go-llama.cpp" + grpc "github.com/go-skynet/LocalAI/pkg/grpc" "github.com/hashicorp/go-multierror" - gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" + "github.com/hpcloud/tail" + "github.com/phayes/freeport" "github.com/rs/zerolog/log" + + process "github.com/mudler/go-processmanager" ) const tokenizerSuffix = ".tokenizer.json" const ( - LlamaBackend = "llama" - BloomzBackend = "bloomz" - StarcoderBackend = "starcoder" - GPTJBackend = "gptj" - DollyBackend = "dolly" - MPTBackend = "mpt" - GPTNeoXBackend = "gptneox" - ReplitBackend = "replit" - Gpt2Backend = "gpt2" - Gpt4AllLlamaBackend = "gpt4all-llama" - Gpt4AllMptBackend = "gpt4all-mpt" - Gpt4AllJBackend = "gpt4all-j" - Gpt4All = "gpt4all" - FalconBackend = "falcon" + LlamaBackend = "llama" + BloomzBackend = "bloomz" + StarcoderBackend = "starcoder" + GPTJBackend = "gptj" + DollyBackend = "dolly" + MPTBackend = "mpt" + GPTNeoXBackend = "gptneox" + ReplitBackend = "replit" + Gpt2Backend = "gpt2" + Gpt4AllLlamaBackend = "gpt4all-llama" + Gpt4AllMptBackend = "gpt4all-mpt" + Gpt4AllJBackend = "gpt4all-j" + Gpt4All = "gpt4all" + FalconBackend = "falcon" + FalconGGMLBackend = "falcon-ggml" + BertEmbeddingsBackend = "bert-embeddings" RwkvBackend = "rwkv" WhisperBackend = "whisper" StableDiffusionBackend = "stablediffusion" PiperBackend = "piper" LCHuggingFaceBackend = "langchain-huggingface" + //GGLLMFalconBackend = "falcon" ) var autoLoadBackends []string = []string{ LlamaBackend, Gpt4All, RwkvBackend, - GPTNeoXBackend, + FalconBackend, WhisperBackend, + GPTNeoXBackend, BertEmbeddingsBackend, + FalconGGMLBackend, GPTJBackend, Gpt2Backend, DollyBackend, - FalconBackend, MPTBackend, ReplitBackend, StarcoderBackend, BloomzBackend, } -var starCoder = func(modelFile string) (interface{}, error) { - return transformers.NewStarcoder(modelFile) +func (ml *ModelLoader) StopGRPC() { + for 
_, p := range ml.grpcProcesses { + p.Stop() + } } -var mpt = func(modelFile string) (interface{}, error) { - return transformers.NewMPT(modelFile) -} +// starts the grpcModelProcess for the backend, and returns a grpc client +// It also loads the model +func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string) (*grpc.Client, error) { + return func(s string) (*grpc.Client, error) { + log.Debug().Msgf("Loading GRPC Model", backend, *o) -var dolly = func(modelFile string) (interface{}, error) { - return transformers.NewDolly(modelFile) -} + grpcProcess := filepath.Join(o.assetDir, "backend-assets", "grpc", backend) -var gptNeoX = func(modelFile string) (interface{}, error) { - return transformers.NewGPTNeoX(modelFile) -} + // Check if the file exists + if _, err := os.Stat(grpcProcess); os.IsNotExist(err) { + return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess) + } -var replit = func(modelFile string) (interface{}, error) { - return transformers.NewReplit(modelFile) -} + // Make sure the process is executable + if err := os.Chmod(grpcProcess, 0755); err != nil { + return nil, err + } -var gptJ = func(modelFile string) (interface{}, error) { - return transformers.NewGPTJ(modelFile) -} + log.Debug().Msgf("Loading GRPC Process", grpcProcess) + port, err := freeport.GetFreePort() + if err != nil { + return nil, err + } -var falcon = func(modelFile string) (interface{}, error) { - return transformers.NewFalcon(modelFile) -} + serverAddress := fmt.Sprintf("localhost:%d", port) -var bertEmbeddings = func(modelFile string) (interface{}, error) { - return bert.New(modelFile) -} + log.Debug().Msgf("GRPC Service for '%s' (%s) will be running at: '%s'", backend, o.modelFile, serverAddress) -var bloomzLM = func(modelFile string) (interface{}, error) { - return bloomz.New(modelFile) -} + grpcControlProcess := process.New( + process.WithTemporaryStateDir(), + 
process.WithName(grpcProcess), + process.WithArgs("--addr", serverAddress)) -var transformersLM = func(modelFile string) (interface{}, error) { - return transformers.New(modelFile) -} + ml.grpcProcesses[o.modelFile] = grpcControlProcess -var stableDiffusion = func(assetDir string) (interface{}, error) { - return stablediffusion.New(assetDir) -} + if err := grpcControlProcess.Run(); err != nil { + return nil, err + } -func piperTTS(assetDir string) func(s string) (interface{}, error) { - return func(s string) (interface{}, error) { - return tts.New(assetDir) - } -} + // clean up process + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c + grpcControlProcess.Stop() + }() + + go func() { + t, err := tail.TailFile(grpcControlProcess.StderrPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stderr") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + } + }() + go func() { + t, err := tail.TailFile(grpcControlProcess.StdoutPath(), tail.Config{Follow: true}) + if err != nil { + log.Debug().Msgf("Could not tail stdout") + } + for line := range t.Lines { + log.Debug().Msgf("GRPC(%s): stderr %s", strings.Join([]string{backend, o.modelFile, serverAddress}, "-"), line.Text) + } + }() + + log.Debug().Msgf("GRPC Service Started") + + client := grpc.NewClient(serverAddress) + + // Wait for the service to start up + ready := false + for i := 0; i < 10; i++ { + if client.HealthCheck(context.Background()) { + log.Debug().Msgf("GRPC Service Ready") + ready = true + break + } + time.Sleep(1 * time.Second) + } -var whisperModel = func(modelFile string) (interface{}, error) { - return whisper.New(modelFile) -} + if !ready { + log.Debug().Msgf("GRPC Service NOT ready") + log.Debug().Msgf("Alive: ", grpcControlProcess.IsAlive()) + log.Debug().Msgf(fmt.Sprintf("GRPC Service Exitcode:")) -var 
lcHuggingFace = func(repoId string) (interface{}, error) { - return langchain.NewHuggingFace(repoId) -} + log.Debug().Msgf(grpcControlProcess.ExitCode()) -func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) { - return func(s string) (interface{}, error) { - return llama.New(s, opts...) - } -} + return nil, fmt.Errorf("grpc service not ready") + } -func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) { - return func(s string) (interface{}, error) { - return gpt4all.New(s, opts...) - } -} + options := *o.gRPCOptions + options.Model = s -func rwkvLM(tokenFile string, threads uint32) func(string) (interface{}, error) { - return func(s string) (interface{}, error) { - log.Debug().Msgf("Loading RWKV", s, tokenFile) + log.Debug().Msgf("GRPC: Loading model with options: %+v", options) - model := rwkv.LoadFiles(s, tokenFile, threads) - if model == nil { - return nil, fmt.Errorf("could not load model") + res, err := client.LoadModel(o.context, &options) + if err != nil { + return nil, err + } + if !res.Success { + return nil, fmt.Errorf("could not load model: %s", res.Message) } - return model, nil + + return client, nil } } -func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (model interface{}, err error) { - log.Debug().Msgf("Loading model %s from %s", backendString, modelFile) - switch strings.ToLower(backendString) { - case LlamaBackend: - return ml.LoadModel(modelFile, llamaLM(llamaOpts...)) - case BloomzBackend: - return ml.LoadModel(modelFile, bloomzLM) - case GPTJBackend: - return ml.LoadModel(modelFile, gptJ) - case DollyBackend: - return ml.LoadModel(modelFile, dolly) - case MPTBackend: - return ml.LoadModel(modelFile, mpt) - case Gpt2Backend: - return ml.LoadModel(modelFile, transformersLM) - case FalconBackend: - return ml.LoadModel(modelFile, falcon) - case GPTNeoXBackend: - return ml.LoadModel(modelFile, gptNeoX) - case 
ReplitBackend: - return ml.LoadModel(modelFile, replit) - case StableDiffusionBackend: - return ml.LoadModel(modelFile, stableDiffusion) - case PiperBackend: - return ml.LoadModel(modelFile, piperTTS(filepath.Join(assetDir, "backend-assets", "espeak-ng-data"))) - case StarcoderBackend: - return ml.LoadModel(modelFile, starCoder) +func (ml *ModelLoader) BackendLoader(opts ...Option) (model *grpc.Client, err error) { + o := NewOptions(opts...) + + log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) + + backend := strings.ToLower(o.backendString) + switch backend { + case LlamaBackend, GPTJBackend, DollyBackend, + MPTBackend, Gpt2Backend, FalconBackend, + GPTNeoXBackend, ReplitBackend, StarcoderBackend, BloomzBackend, + RwkvBackend, LCHuggingFaceBackend, BertEmbeddingsBackend, FalconGGMLBackend, StableDiffusionBackend, WhisperBackend: + return ml.LoadModel(o.modelFile, ml.grpcModel(backend, o)) case Gpt4AllLlamaBackend, Gpt4AllMptBackend, Gpt4AllJBackend, Gpt4All: - return ml.LoadModel(modelFile, gpt4allLM(gpt4all.SetThreads(int(threads)), gpt4all.SetLibrarySearchPath(filepath.Join(assetDir, "backend-assets", "gpt4all")))) - case BertEmbeddingsBackend: - return ml.LoadModel(modelFile, bertEmbeddings) - case RwkvBackend: - return ml.LoadModel(modelFile, rwkvLM(filepath.Join(ml.ModelPath, modelFile+tokenizerSuffix), threads)) - case WhisperBackend: - return ml.LoadModel(modelFile, whisperModel) - case LCHuggingFaceBackend: - return ml.LoadModel(modelFile, lcHuggingFace) + o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "gpt4all") + return ml.LoadModel(o.modelFile, ml.grpcModel(Gpt4All, o)) + case PiperBackend: + o.gRPCOptions.LibrarySearchPath = filepath.Join(o.assetDir, "backend-assets", "espeak-ng-data") + return ml.LoadModel(o.modelFile, ml.grpcModel(PiperBackend, o)) default: - return nil, fmt.Errorf("backend unsupported: %s", backendString) + return nil, fmt.Errorf("backend unsupported: %s", o.backendString) } } 
-func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32, assetDir string) (interface{}, error) { - log.Debug().Msgf("Loading model '%s' greedly", modelFile) +func (ml *ModelLoader) GreedyLoader(opts ...Option) (*grpc.Client, error) { + o := NewOptions(opts...) + log.Debug().Msgf("Loading model '%s' greedly", o.modelFile) + + // Is this really needed? BackendLoader already does this ml.mu.Lock() - m, exists := ml.models[modelFile] - if exists { - log.Debug().Msgf("Model '%s' already loaded", modelFile) + if m := ml.checkIsLoaded(o.modelFile); m != nil { + log.Debug().Msgf("Model '%s' already loaded", o.modelFile) ml.mu.Unlock() return m, nil } @@ -203,7 +222,14 @@ func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOpt continue } log.Debug().Msgf("[%s] Attempting to load", b) - model, modelerr := ml.BackendLoader(b, modelFile, llamaOpts, threads, assetDir) + + model, modelerr := ml.BackendLoader( + WithBackendString(b), + WithModelFile(o.modelFile), + WithLoadGRPCLLMModelOpts(o.gRPCOptions), + WithThreads(o.threads), + WithAssetDir(o.assetDir), + ) if modelerr == nil && model != nil { log.Debug().Msgf("[%s] Loads OK", b) return model, nil diff --git a/pkg/model/loader.go b/pkg/model/loader.go index ddc7b6e..833c311 100644 --- a/pkg/model/loader.go +++ b/pkg/model/loader.go @@ -2,6 +2,7 @@ package model import ( "bytes" + "context" "fmt" "io/ioutil" "os" @@ -10,6 +11,8 @@ import ( "sync" "text/template" + "github.com/go-skynet/LocalAI/pkg/grpc" + process "github.com/mudler/go-processmanager" "github.com/rs/zerolog/log" ) @@ -17,15 +20,17 @@ type ModelLoader struct { ModelPath string mu sync.Mutex // TODO: this needs generics - models map[string]interface{} + models map[string]*grpc.Client + grpcProcesses map[string]*process.Process promptsTemplates map[string]*template.Template } func NewModelLoader(modelPath string) *ModelLoader { return &ModelLoader{ ModelPath: modelPath, - models: 
make(map[string]interface{}), + models: make(map[string]*grpc.Client), promptsTemplates: make(map[string]*template.Template), + grpcProcesses: make(map[string]*process.Process), } } @@ -110,14 +115,14 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error { return nil } -func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (interface{}, error)) (interface{}, error) { +func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (*grpc.Client, error)) (*grpc.Client, error) { ml.mu.Lock() defer ml.mu.Unlock() // Check if we already have a loaded model - if m, ok := ml.models[modelName]; ok { + if model := ml.checkIsLoaded(modelName); model != nil { log.Debug().Msgf("Model already loaded in memory: %s", modelName) - return m, nil + return model, nil } // Load the model and keep it in memory for later use @@ -137,3 +142,25 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string) (interfac ml.models[modelName] = model return model, nil } + +func (ml *ModelLoader) checkIsLoaded(s string) *grpc.Client { + if m, ok := ml.models[s]; ok { + log.Debug().Msgf("Model already loaded in memory: %s", s) + + if !m.HealthCheck(context.Background()) { + log.Debug().Msgf("GRPC Model not responding", s) + if !ml.grpcProcesses[s].IsAlive() { + log.Debug().Msgf("GRPC Process is not responding", s) + // stop and delete the process, this forces to re-load the model and re-create again the service + ml.grpcProcesses[s].Stop() + delete(ml.grpcProcesses, s) + delete(ml.models, s) + return nil + } + } + + return m + } + + return nil +} diff --git a/pkg/model/options.go b/pkg/model/options.go new file mode 100644 index 0000000..298ebd4 --- /dev/null +++ b/pkg/model/options.go @@ -0,0 +1,66 @@ +package model + +import ( + "context" + + pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" +) + +type Options struct { + backendString string + modelFile string + threads uint32 + assetDir string + context context.Context + + 
gRPCOptions *pb.ModelOptions +} + +type Option func(*Options) + +func WithBackendString(backend string) Option { + return func(o *Options) { + o.backendString = backend + } +} + +func WithModelFile(modelFile string) Option { + return func(o *Options) { + o.modelFile = modelFile + } +} + +func WithLoadGRPCLLMModelOpts(opts *pb.ModelOptions) Option { + return func(o *Options) { + o.gRPCOptions = opts + } +} + +func WithThreads(threads uint32) Option { + return func(o *Options) { + o.threads = threads + } +} + +func WithAssetDir(assetDir string) Option { + return func(o *Options) { + o.assetDir = assetDir + } +} + +func WithContext(ctx context.Context) Option { + return func(o *Options) { + o.context = ctx + } +} + +func NewOptions(opts ...Option) *Options { + o := &Options{ + gRPCOptions: &pb.ModelOptions{}, + context: context.Background(), + } + for _, opt := range opts { + opt(o) + } + return o +} diff --git a/pkg/tts/generate.go b/pkg/tts/generate.go deleted file mode 100644 index e4722d4..0000000 --- a/pkg/tts/generate.go +++ /dev/null @@ -1,12 +0,0 @@ -//go:build tts -// +build tts - -package tts - -import ( - piper "github.com/mudler/go-piper" -) - -func tts(text, model, assetDir, arLib, dst string) error { - return piper.TextToWav(text, model, assetDir, arLib, dst) -} diff --git a/pkg/tts/generate_unsupported.go b/pkg/tts/generate_unsupported.go deleted file mode 100644 index 3092695..0000000 --- a/pkg/tts/generate_unsupported.go +++ /dev/null @@ -1,10 +0,0 @@ -//go:build !tts -// +build !tts - -package tts - -import "fmt" - -func tts(text, model, assetDir, arLib, dst string) error { - return fmt.Errorf("this version of LocalAI was built without the tts tag") -} diff --git a/pkg/tts/piper.go b/pkg/tts/piper.go deleted file mode 100644 index b76a637..0000000 --- a/pkg/tts/piper.go +++ /dev/null @@ -1,20 +0,0 @@ -package tts - -import "os" - -type Piper struct { - assetDir string -} - -func New(assetDir string) (*Piper, error) { - if _, err := os.Stat(assetDir); 
err != nil { - return nil, err - } - return &Piper{ - assetDir: assetDir, - }, nil -} - -func (s *Piper) TTS(text, model, dst string) error { - return tts(text, model, s.assetDir, "", dst) -}