From aa9df08809a3b4d569d8f8f9ee11d780562d13da Mon Sep 17 00:00:00 2001 From: mudler Date: Fri, 12 May 2023 14:11:00 +0200 Subject: [PATCH] Support token embeddings in bert --- Makefile | 2 +- api/prediction.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index b6471f3..795aacb 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993 -BERT_VERSION?=ec771ec715576ac050263bb7bb74bfd616a5ba13 +BERT_VERSION?=ac22f8f74aec5e31bc46242c17e7d511f127856b BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1 diff --git a/api/prediction.go b/api/prediction.go index f31ffd5..3dfb45f 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -68,7 +68,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) case *bert.Bert: fn = func() ([]float32, error) { if len(tokens) > 0 { - return nil, fmt.Errorf("embeddings endpoint for this model supports only string") + return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) } return model.Embeddings(s, bert.SetThreads(c.Threads)) }