feat: add experimental support for embeddings as arrays (#207)

Branch: token_berts
Author: Ettore Di Giacinto (committed via GitHub, 2 years ago)
Parent: bc03c492a0
Commit: 89dfa0f5fc
Changed files:
  Makefile           (2)
  api/config.go      (12)
  api/openai.go      (17)
  api/prediction.go  (5)
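With this change, the OpenAI-compatible embeddings endpoint accepts not only a string or a list of strings as "input", but also arrays of token IDs. A minimal client sketch, assuming the server listens on localhost:8080; the endpoint path, model name, and token IDs below are placeholders, not values from this commit:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// "input" carries a token-ID array instead of text (hypothetical IDs).
	body, _ := json.Marshal(map[string]interface{}{
		"model": "bert-embeddings", // placeholder model name
		"input": [][]int{{15496, 995}},
	})
	resp, err := http.Post("http://localhost:8080/v1/embeddings",
		"application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}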

Makefile
@@ -3,7 +3,7 @@ GOTEST=$(GOCMD) test
 GOVET=$(GOCMD) vet
 BINARY_NAME=local-ai
-GOLLAMA_VERSION?=cf9b522db63898dcc5eb86e37c979ab85cbd583e
+GOLLAMA_VERSION?=b4e97a42d0c10ada6b529b0ec17b05c72435aeab
 GOGPT4ALLJ_VERSION?=1f7bff57f66cb7062e40d0ac3abd2217815e5109
 GOGPT2_VERSION?=245a5bfe6708ab80dc5c733dcdbfbe3cfd2acdaa
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
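The go-llama pin moves forward here, presumably to pull in the token-embedding binding (model.TokenEmbeddings) used in api/prediction.go below; the other backend pins are unchanged.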

api/config.go
@@ -33,6 +33,7 @@ type Config struct {
 	Mirostat int `yaml:"mirostat"`
 	PromptStrings, InputStrings []string
+	InputToken [][]int
 }

 type TemplateConfig struct {
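The new InputToken field mirrors InputStrings: one []int per token array supplied in the request's "input", populated by the parsing change in the next hunk.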
@@ -186,8 +187,15 @@ func updateConfig(config *Config, input *OpenAIRequest) {
 		}
 	case []interface{}:
 		for _, pp := range inputs {
-			if s, ok := pp.(string); ok {
-				config.InputStrings = append(config.InputStrings, s)
+			switch i := pp.(type) {
+			case string:
+				config.InputStrings = append(config.InputStrings, i)
+			case []interface{}:
+				tokens := []int{}
+				for _, ii := range i {
+					tokens = append(tokens, int(ii.(float64)))
+				}
+				config.InputToken = append(config.InputToken, tokens)
 			}
 		}
 	}
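The double type switch above is a consequence of decoding into interface{}: encoding/json turns every JSON number into float64, so token IDs arrive as float64 values and must be cast back to int. A standalone sketch of that round trip (token IDs are hypothetical):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	var input interface{}
	// Mixed input: one plain string and one token array.
	_ = json.Unmarshal([]byte(`["hello", [15496, 995]]`), &input)

	for _, pp := range input.([]interface{}) {
		switch i := pp.(type) {
		case string:
			fmt.Println("string input:", i)
		case []interface{}:
			tokens := []int{}
			for _, ii := range i {
				tokens = append(tokens, int(ii.(float64))) // JSON numbers decode as float64
			}
			fmt.Println("token input:", tokens)
		}
	}
}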

api/openai.go
@@ -177,10 +177,23 @@ func embeddingsEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
 	log.Debug().Msgf("Parameter Config: %+v", config)

 	items := []Item{}

+	for i, s := range config.InputToken {
+		// get the model function to call for the result
+		embedFn, err := ModelEmbedding("", s, loader, *config)
+		if err != nil {
+			return err
+		}
+		embeddings, err := embedFn()
+		if err != nil {
+			return err
+		}
+		items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
+	}
+
 	for i, s := range config.InputStrings {
 		// get the model function to call for the result
-		embedFn, err := ModelEmbedding(s, loader, *config)
+		embedFn, err := ModelEmbedding(s, []int{}, loader, *config)
 		if err != nil {
 			return err
 		}
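Token inputs are processed first, passing the token slice alongside an empty string; string inputs pass []int{} as the token slice. Note that both loops start their Index at zero, so a request mixing token arrays and strings would produce items with duplicate Index values; presumably acceptable while the feature is experimental.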

api/prediction.go
@@ -32,7 +32,7 @@ func defaultLLamaOpts(c Config) []llama.ModelOption {
 	return llamaOpts
 }

-func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
+func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
 	if !c.Embeddings {
 		return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
 	}
@@ -57,6 +57,9 @@ func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) {
 	case *llama.LLama:
 		fn = func() ([]float32, error) {
 			predictOptions := buildLLamaPredictOptions(c)
+			if len(tokens) > 0 {
+				return model.TokenEmbeddings(tokens, predictOptions...)
+			}
 			return model.Embeddings(s, predictOptions...)
 		}
 	default:
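Inside ModelEmbedding, the llama backend now dispatches on whether a token slice was supplied: a non-empty tokens slice takes the new model.TokenEmbeddings path (a binding presumably provided by the go-llama bump above), otherwise the existing string path is kept. A self-contained sketch of the same precedence rule, with illustrative stand-ins for the binding calls rather than the real go-llama API:

package main

import "fmt"

// Stand-ins for the bindings; signatures and return values are illustrative only.
func tokenEmbeddings(tokens []int) []float32 { return []float32{float32(len(tokens))} }
func stringEmbeddings(s string) []float32    { return []float32{float32(len(s))} }

// embed mirrors the dispatch added above: tokens win when present.
func embed(s string, tokens []int) []float32 {
	if len(tokens) > 0 {
		return tokenEmbeddings(tokens)
	}
	return stringEmbeddings(s)
}

func main() {
	fmt.Println(embed("hello world", nil))    // string path
	fmt.Println(embed("", []int{15496, 995})) // token path
}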
