From 58f6aab637ca67f9e49a8da9ac2ce3a9f5efdb01 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sat, 15 Jul 2023 01:19:43 +0200 Subject: [PATCH] feat: move llama to a grpc Signed-off-by: Ettore Di Giacinto --- Makefile | 9 +- api/prediction.go | 298 ++++------------------------ cmd/grpc/llama/main.go | 25 +++ pkg/grpc/client.go | 11 + pkg/grpc/interface.go | 1 + pkg/grpc/llm/falcon/falcon.go | 4 + pkg/grpc/llm/llama/llama.go | 165 +++++++++++++++ pkg/grpc/proto/llmserver.pb.go | 205 +++++++++++++------ pkg/grpc/proto/llmserver.proto | 8 +- pkg/grpc/proto/llmserver_grpc.pb.go | 36 ++++ pkg/grpc/server.go | 9 + pkg/model/initializers.go | 15 +- pkg/model/options.go | 8 - 13 files changed, 454 insertions(+), 340 deletions(-) create mode 100644 cmd/grpc/llama/main.go create mode 100644 pkg/grpc/llm/llama/llama.go diff --git a/Makefile b/Makefile index abac2b4..3514161 100644 --- a/Makefile +++ b/Makefile @@ -67,8 +67,8 @@ WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) -C_INCLUDE_PATH=$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz -LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-llama:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz +C_INCLUDE_PATH=$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz +LIBRARY_PATH=$(shell pwd)/go-piper:$(shell pwd)/go-stable-diffusion/:$(shell pwd)/gpt4all/gpt4all-bindings/golang/:$(shell pwd)/go-ggml-transformers:$(shell pwd)/go-rwkv:$(shell pwd)/whisper.cpp:$(shell pwd)/go-bert:$(shell pwd)/bloomz ifeq ($(BUILD_TYPE),openblas) CGO_LDFLAGS+=-lopenblas @@ -369,5 +369,8 @@ falcon-grpc: backend-assets/grpc CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-ggllm LIBRARY_PATH=$(shell pwd)/go-ggllm \ $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/falcon ./cmd/grpc/falcon/ +llama-grpc: backend-assets/grpc + CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(shell pwd)/go-llama LIBRARY_PATH=$(shell pwd)/go-llama \ + $(GOCMD) build -x -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/llama ./cmd/grpc/llama/ -grpcs: falcon-grpc \ No newline at end of file +grpcs: falcon-grpc llama-grpc \ No newline at end of file diff --git a/api/prediction.go b/api/prediction.go index b9b5710..970f06e 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -18,7 +18,6 @@ import ( "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp" - llama "github.com/go-skynet/go-llama.cpp" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" ) @@ -36,6 +35,11 @@ func gRPCModelOpts(c Config) *pb.ModelOptions { ContextSize: int32(c.ContextSize), Seed: int32(c.Seed), NBatch: int32(b), + F16Memory: c.F16, + MLock: c.MMlock, + NUMA: c.NUMA, + Embeddings: c.Embeddings, + LowVRAM: c.LowVRAM, NGPULayers: int32(c.NGPULayers), MMap: c.MMap, MainGPU: c.MainGPU, @@ -43,32 +47,6 @@ func gRPCModelOpts(c Config) *pb.ModelOptions { } } -// func defaultGGLLMOpts(c Config) []ggllm.ModelOption { -// ggllmOpts := 
[]ggllm.ModelOption{} -// if c.ContextSize != 0 { -// ggllmOpts = append(ggllmOpts, ggllm.SetContext(c.ContextSize)) -// } -// // F16 doesn't seem to produce good output at all! -// //if c.F16 { -// // llamaOpts = append(llamaOpts, llama.EnableF16Memory) -// //} - -// if c.NGPULayers != 0 { -// ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(c.NGPULayers)) -// } - -// ggllmOpts = append(ggllmOpts, ggllm.SetMMap(c.MMap)) -// ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(c.MainGPU)) -// ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(c.TensorSplit)) -// if c.Batch != 0 { -// ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(c.Batch)) -// } else { -// ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512)) -// } - -// return ggllmOpts -// } - func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { promptCachePath := "" if c.PromptCachePath != "" { @@ -77,14 +55,18 @@ func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { promptCachePath = p } return &pb.PredictOptions{ - Temperature: float32(c.Temperature), - TopP: float32(c.TopP), - TopK: int32(c.TopK), - Tokens: int32(c.Maxtokens), - Threads: int32(c.Threads), - PromptCacheAll: c.PromptCacheAll, - PromptCacheRO: c.PromptCacheRO, - PromptCachePath: promptCachePath, + Temperature: float32(c.Temperature), + TopP: float32(c.TopP), + TopK: int32(c.TopK), + Tokens: int32(c.Maxtokens), + Threads: int32(c.Threads), + PromptCacheAll: c.PromptCacheAll, + PromptCacheRO: c.PromptCacheRO, + PromptCachePath: promptCachePath, + F16KV: c.F16, + DebugMode: c.Debug, + Grammar: c.Grammar, + Mirostat: int32(c.Mirostat), MirostatETA: float32(c.MirostatETA), MirostatTAU: float32(c.MirostatTAU), @@ -105,200 +87,6 @@ func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { } } -// func buildGGLLMPredictOptions(c Config, modelPath string) []ggllm.PredictOption { -// // Generate the prediction using the language model -// predictOptions := []ggllm.PredictOption{ -// ggllm.SetTemperature(c.Temperature), -// ggllm.SetTopP(c.TopP), -// ggllm.SetTopK(c.TopK), -// ggllm.SetTokens(c.Maxtokens), -// ggllm.SetThreads(c.Threads), -// } - -// if c.PromptCacheAll { -// predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll) -// } - -// if c.PromptCacheRO { -// predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO) -// } - -// if c.PromptCachePath != "" { -// // Create parent directory -// p := filepath.Join(modelPath, c.PromptCachePath) -// os.MkdirAll(filepath.Dir(p), 0755) -// predictOptions = append(predictOptions, ggllm.SetPathPromptCache(p)) -// } - -// if c.Mirostat != 0 { -// predictOptions = append(predictOptions, ggllm.SetMirostat(c.Mirostat)) -// } - -// if c.MirostatETA != 0 { -// predictOptions = append(predictOptions, ggllm.SetMirostatETA(c.MirostatETA)) -// } - -// if c.MirostatTAU != 0 { -// predictOptions = append(predictOptions, ggllm.SetMirostatTAU(c.MirostatTAU)) -// } - -// if c.Debug { -// predictOptions = append(predictOptions, ggllm.Debug) -// } - -// predictOptions = append(predictOptions, ggllm.SetStopWords(c.StopWords...)) - -// if c.RepeatPenalty != 0 { -// predictOptions = append(predictOptions, ggllm.SetPenalty(c.RepeatPenalty)) -// } - -// if c.Keep != 0 { -// predictOptions = append(predictOptions, ggllm.SetNKeep(c.Keep)) -// } - -// if c.Batch != 0 { -// predictOptions = append(predictOptions, ggllm.SetBatch(c.Batch)) -// } - -// if c.IgnoreEOS { -// predictOptions = append(predictOptions, ggllm.IgnoreEOS) -// } - -// if c.Seed != 0 { -// predictOptions = append(predictOptions, 
ggllm.SetSeed(c.Seed)) -// } - -// //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) - -// predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(c.FrequencyPenalty)) -// predictOptions = append(predictOptions, ggllm.SetMlock(c.MMlock)) -// predictOptions = append(predictOptions, ggllm.SetMemoryMap(c.MMap)) -// predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(c.MainGPU)) -// predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(c.TensorSplit)) -// predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(c.TFZ)) -// predictOptions = append(predictOptions, ggllm.SetTypicalP(c.TypicalP)) - -// return predictOptions -// } - -func defaultLLamaOpts(c Config) []llama.ModelOption { - llamaOpts := []llama.ModelOption{} - if c.ContextSize != 0 { - llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize)) - } - if c.F16 { - llamaOpts = append(llamaOpts, llama.EnableF16Memory) - } - if c.Embeddings { - llamaOpts = append(llamaOpts, llama.EnableEmbeddings) - } - - if c.NGPULayers != 0 { - llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers)) - } - - llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap)) - llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU)) - llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit)) - if c.Batch != 0 { - llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch)) - } else { - llamaOpts = append(llamaOpts, llama.SetNBatch(512)) - } - - if c.NUMA { - llamaOpts = append(llamaOpts, llama.EnableNUMA) - } - - if c.LowVRAM { - llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) - } - - return llamaOpts -} - -func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption { - // Generate the prediction using the language model - predictOptions := []llama.PredictOption{ - llama.SetTemperature(c.Temperature), - llama.SetTopP(c.TopP), - llama.SetTopK(c.TopK), - llama.SetTokens(c.Maxtokens), - llama.SetThreads(c.Threads), - } - - if c.PromptCacheAll { - predictOptions = append(predictOptions, llama.EnablePromptCacheAll) - } - - if c.PromptCacheRO { - predictOptions = append(predictOptions, llama.EnablePromptCacheRO) - } - - predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar)) - - if c.PromptCachePath != "" { - // Create parent directory - p := filepath.Join(modelPath, c.PromptCachePath) - os.MkdirAll(filepath.Dir(p), 0755) - predictOptions = append(predictOptions, llama.SetPathPromptCache(p)) - } - - if c.Mirostat != 0 { - predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) - } - - if c.MirostatETA != 0 { - predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) - } - - if c.MirostatTAU != 0 { - predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) - } - - if c.Debug { - predictOptions = append(predictOptions, llama.Debug) - } - - predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) - - if c.RepeatPenalty != 0 { - predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) - } - - if c.Keep != 0 { - predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) - } - - if c.F16 { - predictOptions = append(predictOptions, llama.EnableF16KV) - } - - if c.IgnoreEOS { - predictOptions = append(predictOptions, llama.IgnoreEOS) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) - } - - //predictOptions = append(predictOptions, 
llama.SetLogitBias(c.Seed)) - - predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty)) - predictOptions = append(predictOptions, llama.SetMlock(c.MMlock)) - predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap)) - predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU)) - predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit)) - predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ)) - predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP)) - - return predictOptions -} - func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { if c.Backend != model.StableDiffusionBackend { return nil, fmt.Errorf("endpoint only working with stablediffusion models") @@ -351,14 +139,12 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, modelFile := c.Model - llamaOpts := defaultLLamaOpts(c) grpcOpts := gRPCModelOpts(c) var inferenceModel interface{} var err error opts := []model.Option{ - model.WithLlamaOpts(llamaOpts...), model.WithLoadGRPCOpts(grpcOpts), model.WithThreads(uint32(c.Threads)), model.WithAssetDir(o.assetsDestination), @@ -377,14 +163,34 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, var fn func() ([]float32, error) switch model := inferenceModel.(type) { - case *llama.LLama: + case *grpc.Client: fn = func() ([]float32, error) { - predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) + predictOptions := gRPCPredictOpts(c, loader.ModelPath) if len(tokens) > 0 { - return model.TokenEmbeddings(tokens, predictOptions...) + embeds := []int32{} + + for _, t := range tokens { + embeds = append(embeds, int32(t)) + } + predictOptions.EmbeddingTokens = embeds + + res, err := model.Embeddings(context.TODO(), predictOptions) + if err != nil { + return nil, err + } + + return res.Embeddings, nil + } + predictOptions.Embeddings = s + + res, err := model.Embeddings(context.TODO(), predictOptions) + if err != nil { + return nil, err } - return model.Embeddings(s, predictOptions...) 
+
+			return res.Embeddings, nil
 		}
+		// bert embeddings
 	case *bert.Bert:
 		fn = func() ([]float32, error) {
@@ -432,14 +238,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to
 	supportStreams := false
 	modelFile := c.Model
 
-	llamaOpts := defaultLLamaOpts(c)
 	grpcOpts := gRPCModelOpts(c)
 
 	var inferenceModel interface{}
 	var err error
 
 	opts := []model.Option{
-		model.WithLlamaOpts(llamaOpts...),
 		model.WithLoadGRPCOpts(grpcOpts),
 		model.WithThreads(uint32(c.Threads)),
 		model.WithAssetDir(o.assetsDestination),
@@ -708,26 +512,6 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, to
 			predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch))
 		}
 
-		str, er := model.Predict(
-			s,
-			predictOptions...,
-		)
-		// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
-		// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
-		// after a stream event has occurred
-		model.SetTokenCallback(nil)
-		return str, er
-	}
-	case *llama.LLama:
-		supportStreams = true
-		fn = func() (string, error) {
-
-			if tokenCallback != nil {
-				model.SetTokenCallback(tokenCallback)
-			}
-
-			predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
-
 			str, er := model.Predict(
 				s,
 				predictOptions...,
diff --git a/cmd/grpc/llama/main.go b/cmd/grpc/llama/main.go
new file mode 100644
index 0000000..d75ef48
--- /dev/null
+++ b/cmd/grpc/llama/main.go
@@ -0,0 +1,25 @@
+package main
+
+// GRPC llama server
+
+// Note: this is started internally by LocalAI and a server is allocated for each model
+
+import (
+	"flag"
+
+	llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama"
+
+	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
+)
+
+var (
+	addr = flag.String("addr", "localhost:50051", "the address to connect to")
+)
+
+func main() {
+	flag.Parse()
+
+	if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil {
+		panic(err)
+	}
+}
diff --git a/pkg/grpc/client.go b/pkg/grpc/client.go
index f63a89a..06628eb 100644
--- a/pkg/grpc/client.go
+++ b/pkg/grpc/client.go
@@ -47,6 +47,17 @@ func (c *Client) HealthCheck(ctx context.Context) bool {
 	return false
 }
 
+func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) {
+	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		return nil, err
+	}
+	defer conn.Close()
+	client := pb.NewLLMClient(conn)
+
+	return client.Embedding(ctx, in, opts...)
+}
+
 func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) {
 	conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials()))
 	if err != nil {
diff --git a/pkg/grpc/interface.go b/pkg/grpc/interface.go
index 8ac851a..70b830f 100644
--- a/pkg/grpc/interface.go
+++ b/pkg/grpc/interface.go
@@ -8,4 +8,5 @@ type LLM interface {
 	Predict(*pb.PredictOptions) (string, error)
 	PredictStream(*pb.PredictOptions, chan string)
 	Load(*pb.ModelOptions) error
+	Embeddings(*pb.PredictOptions) ([]float32, error)
 }
diff --git a/pkg/grpc/llm/falcon/falcon.go b/pkg/grpc/llm/falcon/falcon.go
index a0a53be..5d8cf75 100644
--- a/pkg/grpc/llm/falcon/falcon.go
+++ b/pkg/grpc/llm/falcon/falcon.go
@@ -42,6 +42,10 @@ func (llm *LLM) Load(opts *pb.ModelOptions) error {
 	return err
 }
 
+func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
+	return nil, fmt.Errorf("not implemented")
+}
+
 func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption {
 	predictOptions := []ggllm.PredictOption{
 		ggllm.SetTemperature(float64(opts.Temperature)),
diff --git a/pkg/grpc/llm/llama/llama.go b/pkg/grpc/llm/llama/llama.go
new file mode 100644
index 0000000..a31e274
--- /dev/null
+++ b/pkg/grpc/llm/llama/llama.go
@@ -0,0 +1,165 @@
+package llama
+
+// This is a wrapper to satisfy the GRPC service interface
+// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
+import (
+	"fmt"
+
+	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
+	"github.com/go-skynet/go-llama.cpp"
+)
+
+type LLM struct {
+	llama *llama.LLama
+}
+
+func (llm *LLM) Load(opts *pb.ModelOptions) error {
+	llamaOpts := []llama.ModelOption{}
+
+	if opts.ContextSize != 0 {
+		llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize)))
+	}
+	if opts.F16Memory {
+		llamaOpts = append(llamaOpts, llama.EnableF16Memory)
+	}
+	if opts.Embeddings {
+		llamaOpts = append(llamaOpts, llama.EnableEmbeddings)
+	}
+	if opts.NGPULayers != 0 {
+		llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers)))
+	}
+
+	llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap))
+	llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU))
+	llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit))
+	if opts.NBatch != 0 {
+		llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch)))
+	} else {
+		llamaOpts = append(llamaOpts, llama.SetNBatch(512))
+	}
+
+	if opts.NUMA {
+		llamaOpts = append(llamaOpts, llama.EnableNUMA)
+	}
+
+	if opts.LowVRAM {
+		llamaOpts = append(llamaOpts, llama.EnabelLowVRAM)
+	}
+
+	model, err := llama.New(opts.Model, llamaOpts...)
+ llm.llama = model + return err +} + +func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption { + predictOptions := []llama.PredictOption{ + llama.SetTemperature(float64(opts.Temperature)), + llama.SetTopP(float64(opts.TopP)), + llama.SetTopK(int(opts.TopK)), + llama.SetTokens(int(opts.Tokens)), + llama.SetThreads(int(opts.Threads)), + } + + if opts.PromptCacheAll { + predictOptions = append(predictOptions, llama.EnablePromptCacheAll) + } + + if opts.PromptCacheRO { + predictOptions = append(predictOptions, llama.EnablePromptCacheRO) + } + + predictOptions = append(predictOptions, llama.WithGrammar(opts.Grammar)) + + // Expected absolute path + if opts.PromptCachePath != "" { + predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath)) + } + + if opts.Mirostat != 0 { + predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat))) + } + + if opts.MirostatETA != 0 { + predictOptions = append(predictOptions, llama.SetMirostatETA(float64(opts.MirostatETA))) + } + + if opts.MirostatTAU != 0 { + predictOptions = append(predictOptions, llama.SetMirostatTAU(float64(opts.MirostatTAU))) + } + + if opts.Debug { + predictOptions = append(predictOptions, llama.Debug) + } + + predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...)) + + if opts.PresencePenalty != 0 { + predictOptions = append(predictOptions, llama.SetPenalty(float64(opts.PresencePenalty))) + } + + if opts.NKeep != 0 { + predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep))) + } + + if opts.Batch != 0 { + predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch))) + } + + if opts.F16KV { + predictOptions = append(predictOptions, llama.EnableF16KV) + } + + if opts.IgnoreEOS { + predictOptions = append(predictOptions, llama.IgnoreEOS) + } + + if opts.Seed != 0 { + predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed))) + } + + //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) + + predictOptions = append(predictOptions, llama.SetFrequencyPenalty(float64(opts.FrequencyPenalty))) + predictOptions = append(predictOptions, llama.SetMlock(opts.MLock)) + predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap)) + predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU)) + predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit)) + predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ))) + predictOptions = append(predictOptions, llama.SetTypicalP(float64(opts.TypicalP))) + return predictOptions +} + +func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { + return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) +} + +func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) { + predictOptions := buildPredictOptions(opts) + + predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { + results <- token + return true + })) + + go func() { + _, err := llm.llama.Predict(opts.Prompt, predictOptions...) + if err != nil { + fmt.Println("err: ", err) + } + close(results) + }() +} + +func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { + predictOptions := buildPredictOptions(opts) + + if len(opts.EmbeddingTokens) > 0 { + tokens := []int{} + for _, t := range opts.EmbeddingTokens { + tokens = append(tokens, int(t)) + } + return llm.llama.TokenEmbeddings(tokens, predictOptions...) 
+ } + + return llm.llama.Embeddings(opts.Embeddings, predictOptions...) +} diff --git a/pkg/grpc/proto/llmserver.pb.go b/pkg/grpc/proto/llmserver.pb.go index 067c3a1..d54c393 100644 --- a/pkg/grpc/proto/llmserver.pb.go +++ b/pkg/grpc/proto/llmserver.pb.go @@ -87,7 +87,6 @@ type PredictOptions struct { MirostatTAU float32 `protobuf:"fixed32,21,opt,name=MirostatTAU,proto3" json:"MirostatTAU,omitempty"` PenalizeNL bool `protobuf:"varint,22,opt,name=PenalizeNL,proto3" json:"PenalizeNL,omitempty"` LogitBias string `protobuf:"bytes,23,opt,name=LogitBias,proto3" json:"LogitBias,omitempty"` - PathPromptCache string `protobuf:"bytes,24,opt,name=PathPromptCache,proto3" json:"PathPromptCache,omitempty"` MLock bool `protobuf:"varint,25,opt,name=MLock,proto3" json:"MLock,omitempty"` MMap bool `protobuf:"varint,26,opt,name=MMap,proto3" json:"MMap,omitempty"` PromptCacheAll bool `protobuf:"varint,27,opt,name=PromptCacheAll,proto3" json:"PromptCacheAll,omitempty"` @@ -98,6 +97,8 @@ type PredictOptions struct { TopP float32 `protobuf:"fixed32,32,opt,name=TopP,proto3" json:"TopP,omitempty"` PromptCachePath string `protobuf:"bytes,33,opt,name=PromptCachePath,proto3" json:"PromptCachePath,omitempty"` Debug bool `protobuf:"varint,34,opt,name=Debug,proto3" json:"Debug,omitempty"` + EmbeddingTokens []int32 `protobuf:"varint,35,rep,packed,name=EmbeddingTokens,proto3" json:"EmbeddingTokens,omitempty"` + Embeddings string `protobuf:"bytes,36,opt,name=Embeddings,proto3" json:"Embeddings,omitempty"` } func (x *PredictOptions) Reset() { @@ -293,13 +294,6 @@ func (x *PredictOptions) GetLogitBias() string { return "" } -func (x *PredictOptions) GetPathPromptCache() string { - if x != nil { - return x.PathPromptCache - } - return "" -} - func (x *PredictOptions) GetMLock() bool { if x != nil { return x.MLock @@ -370,6 +364,20 @@ func (x *PredictOptions) GetDebug() bool { return false } +func (x *PredictOptions) GetEmbeddingTokens() []int32 { + if x != nil { + return x.EmbeddingTokens + } + return nil +} + +func (x *PredictOptions) GetEmbeddings() string { + if x != nil { + return x.Embeddings + } + return "" +} + // The response message containing the result type Reply struct { state protoimpl.MessageState @@ -624,13 +632,60 @@ func (x *Result) GetSuccess() bool { return false } +type EmbeddingResult struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Embeddings []float32 `protobuf:"fixed32,1,rep,packed,name=embeddings,proto3" json:"embeddings,omitempty"` +} + +func (x *EmbeddingResult) Reset() { + *x = EmbeddingResult{} + if protoimpl.UnsafeEnabled { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *EmbeddingResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EmbeddingResult) ProtoMessage() {} + +func (x *EmbeddingResult) ProtoReflect() protoreflect.Message { + mi := &file_pkg_grpc_proto_llmserver_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EmbeddingResult.ProtoReflect.Descriptor instead. 
+func (*EmbeddingResult) Descriptor() ([]byte, []int) { + return file_pkg_grpc_proto_llmserver_proto_rawDescGZIP(), []int{5} +} + +func (x *EmbeddingResult) GetEmbeddings() []float32 { + if x != nil { + return x.Embeddings + } + return nil +} + var File_pkg_grpc_proto_llmserver_proto protoreflect.FileDescriptor var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ 0x0a, 0x1e, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x03, 0x6c, 0x6c, 0x6d, 0x22, 0x0f, 0x0a, 0x0d, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x80, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0xa0, 0x08, 0x0a, 0x0e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x53, 0x65, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, @@ -673,28 +728,30 @@ var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ 0x1e, 0x0a, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x18, 0x16, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x50, 0x65, 0x6e, 0x61, 0x6c, 0x69, 0x7a, 0x65, 0x4e, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x18, 0x17, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x28, 0x0a, - 0x0f, 0x50, 0x61, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, - 0x18, 0x18, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x61, 0x74, 0x68, 0x50, 0x72, 0x6f, 0x6d, - 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, - 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, - 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, - 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, - 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, - 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12, 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, - 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, - 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, - 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, - 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, - 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, - 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, - 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12, 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, - 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x18, 0x21, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, - 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x18, 
0x22, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70, + 0x28, 0x09, 0x52, 0x09, 0x4c, 0x6f, 0x67, 0x69, 0x74, 0x42, 0x69, 0x61, 0x73, 0x12, 0x14, 0x0a, + 0x05, 0x4d, 0x4c, 0x6f, 0x63, 0x6b, 0x18, 0x19, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x4d, 0x4c, + 0x6f, 0x63, 0x6b, 0x12, 0x12, 0x0a, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x18, 0x1a, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x04, 0x4d, 0x4d, 0x61, 0x70, 0x12, 0x26, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, + 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x18, 0x1b, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0e, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x41, 0x6c, 0x6c, 0x12, + 0x24, 0x0a, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x52, 0x4f, + 0x18, 0x1c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, + 0x63, 0x68, 0x65, 0x52, 0x4f, 0x12, 0x18, 0x0a, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, + 0x18, 0x1d, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x47, 0x72, 0x61, 0x6d, 0x6d, 0x61, 0x72, 0x12, + 0x18, 0x0a, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x07, 0x4d, 0x61, 0x69, 0x6e, 0x47, 0x50, 0x55, 0x12, 0x20, 0x0a, 0x0b, 0x54, 0x65, 0x6e, + 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x18, 0x1f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, + 0x54, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x53, 0x70, 0x6c, 0x69, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x54, + 0x6f, 0x70, 0x50, 0x18, 0x20, 0x20, 0x01, 0x28, 0x02, 0x52, 0x04, 0x54, 0x6f, 0x70, 0x50, 0x12, + 0x28, 0x0a, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, + 0x74, 0x68, 0x18, 0x21, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x50, 0x72, 0x6f, 0x6d, 0x70, 0x74, + 0x43, 0x61, 0x63, 0x68, 0x65, 0x50, 0x61, 0x74, 0x68, 0x12, 0x14, 0x0a, 0x05, 0x44, 0x65, 0x62, + 0x75, 0x67, 0x18, 0x22, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x44, 0x65, 0x62, 0x75, 0x67, 0x12, + 0x28, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x73, 0x18, 0x23, 0x20, 0x03, 0x28, 0x05, 0x52, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, + 0x69, 0x6e, 0x67, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x45, 0x6d, 0x62, + 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x18, 0x24, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x45, + 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, 0x22, 0x21, 0x0a, 0x05, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x82, 0x03, 0x0a, 0x0c, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x14, 0x0a, @@ -724,26 +781,33 @@ var file_pkg_grpc_proto_llmserver_proto_rawDesc = []byte{ 0x74, 0x22, 0x3c, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x32, - 0xc4, 0x01, 0x0a, 0x03, 0x4c, 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, 0x61, 0x6c, 0x74, - 0x68, 0x12, 0x12, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, - 0x79, 0x22, 0x00, 0x12, 0x2c, 0x0a, 
0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x12, 0x13, - 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, - 0x00, 0x12, 0x2d, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x11, - 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, - 0x73, 0x1a, 0x0b, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x22, 0x00, - 0x12, 0x34, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, - 0x6d, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x22, + 0x31, 0x0a, 0x0f, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, + 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x02, 0x52, 0x0a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, + 0x67, 0x73, 0x32, 0xfe, 0x01, 0x0a, 0x03, 0x4c, 0x4c, 0x4d, 0x12, 0x2a, 0x0a, 0x06, 0x48, 0x65, + 0x61, 0x6c, 0x74, 0x68, 0x12, 0x12, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, + 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2c, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, + 0x74, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x70, - 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x42, 0x57, 0x0a, 0x1b, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, - 0x6e, 0x65, 0x74, 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x6c, 0x6c, 0x6d, 0x73, - 0x65, 0x72, 0x76, 0x65, 0x72, 0x42, 0x09, 0x4c, 0x4c, 0x4d, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x50, 0x01, 0x5a, 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, - 0x6f, 0x2d, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, - 0x2f, 0x70, 0x6b, 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6c, 0x79, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x09, 0x4c, 0x6f, 0x61, 0x64, 0x4d, 0x6f, 0x64, 0x65, + 0x6c, 0x12, 0x11, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x4f, 0x70, 0x74, + 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0b, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x52, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x0d, 0x50, 0x72, 0x65, 0x64, 0x69, 0x63, 0x74, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, 0x64, 0x69, + 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x0a, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, + 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x30, 0x01, 0x12, 0x38, 0x0a, 0x09, 0x45, 0x6d, 0x62, + 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x12, 0x13, 0x2e, 0x6c, 0x6c, 0x6d, 0x2e, 0x50, 0x72, 0x65, + 0x64, 0x69, 0x63, 0x74, 0x4f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x1a, 0x14, 0x2e, 0x6c, 0x6c, + 0x6d, 0x2e, 0x45, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x69, 0x6e, 0x67, 0x52, 0x65, 0x73, 0x75, 0x6c, + 0x74, 0x22, 0x00, 0x42, 0x57, 0x0a, 0x1b, 0x69, 0x6f, 0x2e, 0x73, 0x6b, 0x79, 0x6e, 0x65, 0x74, + 0x2e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x61, 0x69, 0x2e, 0x6c, 0x6c, 0x6d, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x42, 
0x09, 0x4c, 0x4c, 0x4d, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x50, 0x01, 0x5a, + 0x2b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x2d, 0x73, + 0x6b, 0x79, 0x6e, 0x65, 0x74, 0x2f, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x41, 0x49, 0x2f, 0x70, 0x6b, + 0x67, 0x2f, 0x67, 0x72, 0x70, 0x63, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -758,25 +822,28 @@ func file_pkg_grpc_proto_llmserver_proto_rawDescGZIP() []byte { return file_pkg_grpc_proto_llmserver_proto_rawDescData } -var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 5) +var file_pkg_grpc_proto_llmserver_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_pkg_grpc_proto_llmserver_proto_goTypes = []interface{}{ - (*HealthMessage)(nil), // 0: llm.HealthMessage - (*PredictOptions)(nil), // 1: llm.PredictOptions - (*Reply)(nil), // 2: llm.Reply - (*ModelOptions)(nil), // 3: llm.ModelOptions - (*Result)(nil), // 4: llm.Result + (*HealthMessage)(nil), // 0: llm.HealthMessage + (*PredictOptions)(nil), // 1: llm.PredictOptions + (*Reply)(nil), // 2: llm.Reply + (*ModelOptions)(nil), // 3: llm.ModelOptions + (*Result)(nil), // 4: llm.Result + (*EmbeddingResult)(nil), // 5: llm.EmbeddingResult } var file_pkg_grpc_proto_llmserver_proto_depIdxs = []int32{ 0, // 0: llm.LLM.Health:input_type -> llm.HealthMessage 1, // 1: llm.LLM.Predict:input_type -> llm.PredictOptions 3, // 2: llm.LLM.LoadModel:input_type -> llm.ModelOptions 1, // 3: llm.LLM.PredictStream:input_type -> llm.PredictOptions - 2, // 4: llm.LLM.Health:output_type -> llm.Reply - 2, // 5: llm.LLM.Predict:output_type -> llm.Reply - 4, // 6: llm.LLM.LoadModel:output_type -> llm.Result - 2, // 7: llm.LLM.PredictStream:output_type -> llm.Reply - 4, // [4:8] is the sub-list for method output_type - 0, // [0:4] is the sub-list for method input_type + 1, // 4: llm.LLM.Embedding:input_type -> llm.PredictOptions + 2, // 5: llm.LLM.Health:output_type -> llm.Reply + 2, // 6: llm.LLM.Predict:output_type -> llm.Reply + 4, // 7: llm.LLM.LoadModel:output_type -> llm.Result + 2, // 8: llm.LLM.PredictStream:output_type -> llm.Reply + 5, // 9: llm.LLM.Embedding:output_type -> llm.EmbeddingResult + 5, // [5:10] is the sub-list for method output_type + 0, // [0:5] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -848,6 +915,18 @@ func file_pkg_grpc_proto_llmserver_proto_init() { return nil } } + file_pkg_grpc_proto_llmserver_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*EmbeddingResult); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -855,7 +934,7 @@ func file_pkg_grpc_proto_llmserver_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_pkg_grpc_proto_llmserver_proto_rawDesc, NumEnums: 0, - NumMessages: 5, + NumMessages: 6, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/grpc/proto/llmserver.proto b/pkg/grpc/proto/llmserver.proto index ba20806..b6fa4cd 100644 --- a/pkg/grpc/proto/llmserver.proto +++ b/pkg/grpc/proto/llmserver.proto @@ -12,6 +12,7 @@ service LLM { rpc Predict(PredictOptions) returns (Reply) {} rpc LoadModel(ModelOptions) returns (Result) {} rpc PredictStream(PredictOptions) returns (stream Reply) {} + rpc Embedding(PredictOptions) 
returns (EmbeddingResult) {} } message HealthMessage {} @@ -41,7 +42,6 @@ message PredictOptions { float MirostatTAU = 21; bool PenalizeNL = 22; string LogitBias = 23; - string PathPromptCache = 24; bool MLock = 25; bool MMap = 26; bool PromptCacheAll = 27; @@ -52,6 +52,8 @@ message PredictOptions { float TopP = 32; string PromptCachePath = 33; bool Debug = 34; + repeated int32 EmbeddingTokens = 35; + string Embeddings = 36; } // The response message containing the result @@ -79,4 +81,8 @@ message ModelOptions { message Result { string message = 1; bool success = 2; +} + +message EmbeddingResult { + repeated float embeddings = 1; } \ No newline at end of file diff --git a/pkg/grpc/proto/llmserver_grpc.pb.go b/pkg/grpc/proto/llmserver_grpc.pb.go index 6cfd981..c028218 100644 --- a/pkg/grpc/proto/llmserver_grpc.pb.go +++ b/pkg/grpc/proto/llmserver_grpc.pb.go @@ -26,6 +26,7 @@ type LLMClient interface { Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (LLM_PredictStreamClient, error) + Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) } type lLMClient struct { @@ -95,6 +96,15 @@ func (x *lLMPredictStreamClient) Recv() (*Reply, error) { return m, nil } +func (c *lLMClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { + out := new(EmbeddingResult) + err := c.cc.Invoke(ctx, "/llm.LLM/Embedding", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // LLMServer is the server API for LLM service. // All implementations must embed UnimplementedLLMServer // for forward compatibility @@ -103,6 +113,7 @@ type LLMServer interface { Predict(context.Context, *PredictOptions) (*Reply, error) LoadModel(context.Context, *ModelOptions) (*Result, error) PredictStream(*PredictOptions, LLM_PredictStreamServer) error + Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) mustEmbedUnimplementedLLMServer() } @@ -122,6 +133,9 @@ func (UnimplementedLLMServer) LoadModel(context.Context, *ModelOptions) (*Result func (UnimplementedLLMServer) PredictStream(*PredictOptions, LLM_PredictStreamServer) error { return status.Errorf(codes.Unimplemented, "method PredictStream not implemented") } +func (UnimplementedLLMServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) { + return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented") +} func (UnimplementedLLMServer) mustEmbedUnimplementedLLMServer() {} // UnsafeLLMServer may be embedded to opt out of forward compatibility for this service. 
@@ -210,6 +224,24 @@ func (x *lLMPredictStreamServer) Send(m *Reply) error { return x.ServerStream.SendMsg(m) } +func _LLM_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PredictOptions) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(LLMServer).Embedding(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/llm.LLM/Embedding", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(LLMServer).Embedding(ctx, req.(*PredictOptions)) + } + return interceptor(ctx, in, info, handler) +} + // LLM_ServiceDesc is the grpc.ServiceDesc for LLM service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -229,6 +261,10 @@ var LLM_ServiceDesc = grpc.ServiceDesc{ MethodName: "LoadModel", Handler: _LLM_LoadModel_Handler, }, + { + MethodName: "Embedding", + Handler: _LLM_Embedding_Handler, + }, }, Streams: []grpc.StreamDesc{ { diff --git a/pkg/grpc/server.go b/pkg/grpc/server.go index d449593..9e4c88a 100644 --- a/pkg/grpc/server.go +++ b/pkg/grpc/server.go @@ -29,6 +29,15 @@ func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, e return &pb.Reply{Message: "OK"}, nil } +func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { + embeds, err := s.llm.Embeddings(in) + if err != nil { + return nil, err + } + + return &pb.EmbeddingResult{Embeddings: embeds}, nil +} + func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { err := s.llm.Load(in) if err != nil { diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index 5dba7ce..1acde4c 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -17,7 +17,6 @@ import ( bloomz "github.com/go-skynet/bloomz.cpp" bert "github.com/go-skynet/go-bert.cpp" transformers "github.com/go-skynet/go-ggml-transformers.cpp" - llama "github.com/go-skynet/go-llama.cpp" "github.com/hashicorp/go-multierror" "github.com/hpcloud/tail" gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" @@ -135,11 +134,11 @@ var lcHuggingFace = func(repoId string) (interface{}, error) { return langchain.NewHuggingFace(repoId) } -func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) { - return func(s string) (interface{}, error) { - return llama.New(s, opts...) - } -} +// func llamaLM(opts ...llama.ModelOption) func(string) (interface{}, error) { +// return func(s string) (interface{}, error) { +// return llama.New(s, opts...) 
+// } +// } func gpt4allLM(opts ...gpt4all.ModelOption) func(string) (interface{}, error) { return func(s string) (interface{}, error) { @@ -263,7 +262,8 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (model interface{}, err err log.Debug().Msgf("Loading model %s from %s", o.backendString, o.modelFile) switch strings.ToLower(o.backendString) { case LlamaBackend: - return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...)) + // return ml.LoadModel(o.modelFile, llamaLM(o.llamaOpts...)) + return ml.LoadModel(o.modelFile, ml.grpcModel(LlamaBackend, o)) case BloomzBackend: return ml.LoadModel(o.modelFile, bloomzLM) case GPTJBackend: @@ -325,7 +325,6 @@ func (ml *ModelLoader) GreedyLoader(opts ...Option) (interface{}, error) { model, modelerr := ml.BackendLoader( WithBackendString(b), WithModelFile(o.modelFile), - WithLlamaOpts(o.llamaOpts...), WithLoadGRPCOpts(o.gRPCOptions), WithThreads(o.threads), WithAssetDir(o.assetDir), diff --git a/pkg/model/options.go b/pkg/model/options.go index 3716330..31e54cb 100644 --- a/pkg/model/options.go +++ b/pkg/model/options.go @@ -2,13 +2,11 @@ package model import ( pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - llama "github.com/go-skynet/go-llama.cpp" ) type Options struct { backendString string modelFile string - llamaOpts []llama.ModelOption threads uint32 assetDir string @@ -35,12 +33,6 @@ func WithLoadGRPCOpts(opts *pb.ModelOptions) Option { } } -func WithLlamaOpts(opts ...llama.ModelOption) Option { - return func(o *Options) { - o.llamaOpts = append(o.llamaOpts, opts...) - } -} - func WithThreads(threads uint32) Option { return func(o *Options) { o.threads = threads