Support token embeddings in bert

token_berts
mudler 2 years ago
parent 98d5c2a830
commit aa9df08809
  1. 2
      Makefile
  2. 2
      api/prediction.go

@ -10,7 +10,7 @@ GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47 RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47
WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993
BERT_VERSION?=ec771ec715576ac050263bb7bb74bfd616a5ba13 BERT_VERSION?=ac22f8f74aec5e31bc46242c17e7d511f127856b
BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1 BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1

@ -68,7 +68,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
case *bert.Bert: case *bert.Bert:
fn = func() ([]float32, error) { fn = func() ([]float32, error) {
if len(tokens) > 0 { if len(tokens) > 0 {
return nil, fmt.Errorf("embeddings endpoint for this model supports only string") return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
} }
return model.Embeddings(s, bert.SetThreads(c.Threads)) return model.Embeddings(s, bert.SetThreads(c.Threads))
} }

Loading…
Cancel
Save