Compare commits

...

5 Commits

  1. 2
      Makefile
  2. 10
      api/api_test.go
  3. 2
      api/prediction.go
  4. 4
      tests/fixtures/whisper.yaml

@ -10,7 +10,7 @@ GOGPT2_VERSION?=92421a8cf61ed6e03babd9067af292b094cb1307
RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47 RWKV_VERSION?=07166da10cb2a9e8854395a4f210464dcea76e47
WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993 WHISPER_CPP_VERSION?=bf2449dfae35a46b2cd92ab22661ce81a48d4993
BERT_VERSION?=ec771ec715576ac050263bb7bb74bfd616a5ba13 BERT_VERSION?=ac22f8f74aec5e31bc46242c17e7d511f127856b
BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1 BLOOMZ_VERSION?=e9366e82abdfe70565644fbfae9651976714efd1

@ -4,6 +4,7 @@ import (
"context" "context"
"os" "os"
"path/filepath" "path/filepath"
"runtime"
. "github.com/go-skynet/LocalAI/api" . "github.com/go-skynet/LocalAI/api"
"github.com/go-skynet/LocalAI/pkg/model" "github.com/go-skynet/LocalAI/pkg/model"
@ -46,7 +47,7 @@ var _ = Describe("API test", func() {
It("returns the models list", func() { It("returns the models list", func() {
models, err := client.ListModels(context.TODO()) models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(4)) Expect(len(models.Models)).To(Equal(5))
Expect(models.Models[0].ID).To(Equal("testmodel")) Expect(models.Models[0].ID).To(Equal("testmodel"))
}) })
It("can generate completions", func() { It("can generate completions", func() {
@ -82,7 +83,10 @@ var _ = Describe("API test", func() {
Expect(err).To(HaveOccurred()) Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:")) Expect(err.Error()).To(ContainSubstring("error, status code: 500, message: could not load model - all backends returned error: 12 errors occurred:"))
}) })
PIt("transcribes audio", func() { It("transcribes audio", func() {
if runtime.GOOS != "linux" {
Skip("test supported only on linux")
}
resp, err := client.CreateTranscription( resp, err := client.CreateTranscription(
context.Background(), context.Background(),
openai.AudioRequest{ openai.AudioRequest{
@ -119,7 +123,7 @@ var _ = Describe("API test", func() {
models, err := client.ListModels(context.TODO()) models, err := client.ListModels(context.TODO())
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(len(models.Models)).To(Equal(6)) Expect(len(models.Models)).To(Equal(7))
Expect(models.Models[0].ID).To(Equal("testmodel")) Expect(models.Models[0].ID).To(Equal("testmodel"))
}) })
It("can generate chat completions from config file", func() { It("can generate chat completions from config file", func() {

@ -68,7 +68,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
case *bert.Bert: case *bert.Bert:
fn = func() ([]float32, error) { fn = func() ([]float32, error) {
if len(tokens) > 0 { if len(tokens) > 0 {
return nil, fmt.Errorf("embeddings endpoint for this model supports only string") return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads))
} }
return model.Embeddings(s, bert.SetThreads(c.Threads)) return model.Embeddings(s, bert.SetThreads(c.Threads))
} }

@ -0,0 +1,4 @@
name: whisper-1
backend: whisper
parameters:
model: whisper-en
Loading…
Cancel
Save