|
|
@ -124,8 +124,11 @@ var _ = Describe("API test", func() { |
|
|
|
var c context.Context |
|
|
|
var c context.Context |
|
|
|
var cancel context.CancelFunc |
|
|
|
var cancel context.CancelFunc |
|
|
|
var tmpdir string |
|
|
|
var tmpdir string |
|
|
|
commonOpts := []options.AppOption{options.WithDebug(false), |
|
|
|
|
|
|
|
options.WithDisableMessage(true)} |
|
|
|
commonOpts := []options.AppOption{ |
|
|
|
|
|
|
|
options.WithDebug(true), |
|
|
|
|
|
|
|
options.WithDisableMessage(true), |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Context("API with ephemeral models", func() { |
|
|
|
Context("API with ephemeral models", func() { |
|
|
|
BeforeEach(func() { |
|
|
|
BeforeEach(func() { |
|
|
@ -145,7 +148,7 @@ var _ = Describe("API test", func() { |
|
|
|
Name: "bert2", |
|
|
|
Name: "bert2", |
|
|
|
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", |
|
|
|
URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml", |
|
|
|
Overrides: map[string]interface{}{"foo": "bar"}, |
|
|
|
Overrides: map[string]interface{}{"foo": "bar"}, |
|
|
|
AdditionalFiles: []gallery.File{gallery.File{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}}, |
|
|
|
AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/bert-embeddings.yaml"}}, |
|
|
|
}, |
|
|
|
}, |
|
|
|
} |
|
|
|
} |
|
|
|
out, err := yaml.Marshal(g) |
|
|
|
out, err := yaml.Marshal(g) |
|
|
@ -421,64 +424,32 @@ var _ = Describe("API test", func() { |
|
|
|
os.RemoveAll(tmpdir) |
|
|
|
os.RemoveAll(tmpdir) |
|
|
|
}) |
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
Context("API query", func() { |
|
|
|
It("calculate embeddings with huggingface", func() { |
|
|
|
BeforeEach(func() { |
|
|
|
if runtime.GOOS != "linux" { |
|
|
|
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) |
|
|
|
Skip("test supported only on linux") |
|
|
|
c, cancel = context.WithCancel(context.Background()) |
|
|
|
} |
|
|
|
|
|
|
|
resp, err := client.CreateEmbeddings( |
|
|
|
var err error |
|
|
|
context.Background(), |
|
|
|
app, err = App( |
|
|
|
openai.EmbeddingRequest{ |
|
|
|
append(commonOpts, |
|
|
|
Model: openai.AdaCodeSearchCode, |
|
|
|
options.WithDebug(true), |
|
|
|
Input: []string{"sun", "cat"}, |
|
|
|
options.WithContext(c), options.WithModelLoader(modelLoader))...) |
|
|
|
}, |
|
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
) |
|
|
|
go app.Listen("127.0.0.1:9090") |
|
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
|
|
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) |
|
|
|
defaultConfig := openai.DefaultConfig("") |
|
|
|
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) |
|
|
|
defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
client2 = openaigo.NewClient("") |
|
|
|
|
|
|
|
client2.BaseURL = defaultConfig.BaseURL |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Wait for API to be ready
|
|
|
|
|
|
|
|
client = openai.NewClientWithConfig(defaultConfig) |
|
|
|
|
|
|
|
Eventually(func() error { |
|
|
|
|
|
|
|
_, err := client.ListModels(context.TODO()) |
|
|
|
|
|
|
|
return err |
|
|
|
|
|
|
|
}, "2m").ShouldNot(HaveOccurred()) |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
AfterEach(func() { |
|
|
|
|
|
|
|
cancel() |
|
|
|
|
|
|
|
app.Shutdown() |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
It("calculate embeddings with huggingface", func() { |
|
|
|
sunEmbedding := resp.Data[0].Embedding |
|
|
|
if runtime.GOOS != "linux" { |
|
|
|
resp2, err := client.CreateEmbeddings( |
|
|
|
Skip("test supported only on linux") |
|
|
|
context.Background(), |
|
|
|
} |
|
|
|
openai.EmbeddingRequest{ |
|
|
|
resp, err := client.CreateEmbeddings( |
|
|
|
Model: openai.AdaCodeSearchCode, |
|
|
|
context.Background(), |
|
|
|
Input: []string{"sun"}, |
|
|
|
openai.EmbeddingRequest{ |
|
|
|
}, |
|
|
|
Model: openai.AdaCodeSearchCode, |
|
|
|
) |
|
|
|
Input: []string{"sun", "cat"}, |
|
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
}, |
|
|
|
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) |
|
|
|
) |
|
|
|
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) |
|
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
|
|
Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) |
|
|
|
|
|
|
|
Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sunEmbedding := resp.Data[0].Embedding |
|
|
|
|
|
|
|
resp2, err := client.CreateEmbeddings( |
|
|
|
|
|
|
|
context.Background(), |
|
|
|
|
|
|
|
openai.EmbeddingRequest{ |
|
|
|
|
|
|
|
Model: openai.AdaCodeSearchCode, |
|
|
|
|
|
|
|
Input: []string{"sun"}, |
|
|
|
|
|
|
|
}, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
Expect(err).ToNot(HaveOccurred()) |
|
|
|
|
|
|
|
Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) |
|
|
|
|
|
|
|
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) |
|
|
|
|
|
|
|
}) |
|
|
|
|
|
|
|
}) |
|
|
|
}) |
|
|
|
}) |
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|