diff --git a/Makefile b/Makefile
index 8b3535f..419837e 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,7 @@
 GOTEST=$(GOCMD) test
 GOVET=$(GOCMD) vet
 BINARY_NAME=local-ai
-GOLLAMA_VERSION?=cf9b522db63898dcc5eb86e37c979ab85cbd583e
+GOLLAMA_VERSION?=b4e97a42d0c10ada6b529b0ec17b05c72435aeab
 GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
 GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
diff --git a/api/config.go b/api/config.go
index b032d15..fea30ba 100644
--- a/api/config.go
+++ b/api/config.go
@@ -33,6 +33,7 @@ type Config struct {
 	Mirostat int `yaml:"mirostat"`
 
 	PromptStrings, InputStrings []string
+	InputToken                  [][]int
 }
 
 type TemplateConfig struct {
@@ -186,8 +187,15 @@ func updateConfig(config *Config, input *OpenAIRequest) {
 		}
 	case []interface{}:
 		for _, pp := range inputs {
-			if s, ok := pp.(string); ok {
-				config.InputStrings = append(config.InputStrings, s)
+			switch i := pp.(type) {
+			case string:
+				config.InputStrings = append(config.InputStrings, i)
+			case []interface{}:
+				tokens := []int{}
+				for _, ii := range i {
+					tokens = append(tokens, int(ii.(float64)))
+				}
+				config.InputToken = append(config.InputToken, tokens)
 			}
 		}
 	}
diff --git a/api/openai.go b/api/openai.go
index 213607b..c472b68 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -177,10 +177,23 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
 		log.Debug().Msgf("Parameter Config: %+v", config)
 		items := []Item{}
 
-		for i, s := range config.InputStrings {
+		for i, s := range config.InputToken {
+			// get the model function to call for the result
+			embedFn, err := ModelEmbedding("", s, loader, *config)
+			if err != nil {
+				return err
+			}
+			embeddings, err := embedFn()
+			if err != nil {
+				return err
+			}
+			items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
+		}
+
+		for i, s := range config.InputStrings {
 			// get the model function to call for the result
-			embedFn, err := ModelEmbedding(s, loader, *config)
+			embedFn, err := ModelEmbedding(s, []int{}, loader, *config)
 			if err != nil {
 				return err
 			}
 
diff --git a/api/prediction.go b/api/prediction.go
index 4bfb687..95d111f 100644
--- a/api/prediction.go
+++ b/api/prediction.go
@@ -32,7 +32,7 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
 	return llamaOpts
 }
 
-func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
+func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
 	if !c.Embeddings {
 		return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
 	}
@@ -57,6 +57,9 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
 	case *llama.LLama:
 		fn = func() ([]float32, error) {
 			predictOptions := buildLLamaPredictOptions(c)
+			if len(tokens) > 0 {
+				return model.TokenEmbeddings(tokens, predictOptions...)
+			}
 			return model.Embeddings(s, predictOptions...)
 		}
 	default:
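
For reference, here is a minimal standalone sketch of the parsing behavior the api/config.go hunk introduces: an OpenAI-style `input` field may now mix plain strings and pre-tokenized arrays, and the type switch routes them into InputStrings and InputToken respectively. The request body and token IDs below are made up for illustration; only the type switch mirrors the diff.

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Hypothetical request body: "input" carries a plain string and a
	// pre-tokenized prompt side by side.
	body := []byte(`{"input": ["a sentence to embed", [464, 2068, 7586]]}`)

	var req struct {
		Input interface{} `json:"input"`
	}
	if err := json.Unmarshal(body, &req); err != nil {
		panic(err)
	}

	var inputStrings []string
	var inputToken [][]int

	// Mirrors the type switch in updateConfig: JSON numbers decode as
	// float64, hence the conversion back to int token IDs.
	if inputs, ok := req.Input.([]interface{}); ok {
		for _, pp := range inputs {
			switch i := pp.(type) {
			case string:
				inputStrings = append(inputStrings, i)
			case []interface{}:
				tokens := []int{}
				for _, ii := range i {
					tokens = append(tokens, int(ii.(float64)))
				}
				inputToken = append(inputToken, tokens)
			}
		}
	}

	fmt.Println(inputStrings) // [a sentence to embed]
	fmt.Println(inputToken)   // [[464 2068 7586]]
}
```

With the inputs split this way, each entry in InputToken reaches the llama backend through model.TokenEmbeddings, while plain strings keep flowing through model.Embeddings as before.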