Compare commits
265 Commits
examples_u
...
master
Author | SHA1 | Date |
---|---|---|
gregandev | 77800c1636 | 1 year ago |
ci-robbot [bot] | 5ee186b8e5 | 1 year ago |
Ettore Di Giacinto | 94817b557c | 1 year ago |
Ettore Di Giacinto | 26e1496075 | 1 year ago |
Ettore Di Giacinto | 92fca8ae74 | 1 year ago |
Stepan | 7fa5b8401d | 1 year ago |
Ettore Di Giacinto | 0eac0402e1 | 1 year ago |
Ettore Di Giacinto | c71c729bc2 | 1 year ago |
Ettore Di Giacinto | e459f114cd | 1 year ago |
Ettore Di Giacinto | 982a7e86a8 | 1 year ago |
Ettore Di Giacinto | 94916749c5 | 1 year ago |
Ettore Di Giacinto | 5ce5f87a26 | 1 year ago |
Ettore Di Giacinto | 1d2ae46ddc | 1 year ago |
ci-robbot [bot] | 71ac331f90 | 1 year ago |
Ettore Di Giacinto | 47cc95fc9f | 1 year ago |
Ettore Di Giacinto | 3feb632eb4 | 1 year ago |
Ettore Di Giacinto | 236497e331 | 1 year ago |
ci-robbot [bot] | a38dc497b2 | 1 year ago |
ci-robbot [bot] | 28ed52fa94 | 1 year ago |
Enzo Einhorn | e995b95c94 | 1 year ago |
Ettore Di Giacinto | 8379cce209 | 1 year ago |
ci-robbot [bot] | 3c6b798522 | 1 year ago |
ci-robbot [bot] | c18770a61a | 1 year ago |
Ettore Di Giacinto | 6352448b72 | 1 year ago |
renovate[bot] | fb6cce487f | 1 year ago |
renovate[bot] | 3079cc4167 | 1 year ago |
ci-robbot [bot] | 27ef8b1eb7 | 1 year ago |
ci-robbot [bot] | c00435d72b | 1 year ago |
Ettore Di Giacinto | d0e67cce75 | 1 year ago |
renovate[bot] | 6ec315e540 | 1 year ago |
renovate[bot] | cf4e6f909c | 1 year ago |
renovate[bot] | b3a99166fd | 1 year ago |
renovate[bot] | 107008331e | 1 year ago |
ci-robbot [bot] | accd9f9044 | 1 year ago |
Ettore Di Giacinto | 17294ae5e5 | 1 year ago |
renovate[bot] | 3c3a9b765a | 1 year ago |
renovate[bot] | 526c5bcdad | 1 year ago |
renovate[bot] | a1bbe75d43 | 1 year ago |
renovate[bot] | 572a311639 | 1 year ago |
Ettore Di Giacinto | cb5d6f6e3a | 1 year ago |
Ettore Di Giacinto | e3cabb555d | 1 year ago |
Ettore Di Giacinto | f193f56564 | 1 year ago |
Ettore Di Giacinto | c0a91ab548 | 1 year ago |
Ettore Di Giacinto | 26e510bf28 | 1 year ago |
Ettore Di Giacinto | 98e73ed67a | 1 year ago |
Ettore Di Giacinto | 7f3de3ca4a | 1 year ago |
Ettore Di Giacinto | 189cb3a7be | 1 year ago |
Ettore Di Giacinto | 1d0ed95a54 | 1 year ago |
Ettore Di Giacinto | 5dcfdbe51d | 1 year ago |
Ettore Di Giacinto | f2f1d7fe72 | 1 year ago |
Ettore Di Giacinto | ae533cadef | 1 year ago |
Ettore Di Giacinto | 58f6aab637 | 1 year ago |
Ettore Di Giacinto | b816009db0 | 1 year ago |
ci-robbot [bot] | a84dee1be1 | 1 year ago |
renovate[bot] | 30e4ddbf10 | 1 year ago |
Ettore Di Giacinto | 296a5b6707 | 1 year ago |
renovate[bot] | b0520dcb59 | 1 year ago |
renovate[bot] | f42967ed86 | 1 year ago |
renovate[bot] | 966675c8e3 | 1 year ago |
renovate[bot] | f68df1624b | 1 year ago |
renovate[bot] | 42cade808b | 1 year ago |
Ettore Di Giacinto | d59211982b | 1 year ago |
Ettore Di Giacinto | 7aaa10680d | 1 year ago |
mudler | dcf35dd25f | 1 year ago |
mudler | e70322676c | 1 year ago |
mudler | b3f43ab938 | 1 year ago |
mudler | bbc4468908 | 1 year ago |
mudler | 4de7f55f2f | 1 year ago |
mudler | def23e4ee2 | 1 year ago |
mudler | 55befe396a | 1 year ago |
mudler | 483fddccf9 | 1 year ago |
mudler | c4495ad8f2 | 1 year ago |
mudler | 05aed255db | 1 year ago |
mudler | 0f1326b2bd | 1 year ago |
mudler | 1668489b00 | 1 year ago |
mudler | 7dd292cbb3 | 1 year ago |
mudler | c0578031b5 | 1 year ago |
mudler | a5b64b6a41 | 1 year ago |
mudler | b722e7eb7e | 1 year ago |
mudler | 6d19a8bdb5 | 1 year ago |
mudler | f09ddd2983 | 1 year ago |
Luis López | a6839fd238 | 1 year ago |
Ettore Di Giacinto | f3063f98d3 | 1 year ago |
Ettore Di Giacinto | 70674d3c58 | 1 year ago |
ci-robbot [bot] | 3829aba869 | 1 year ago |
Ettore Di Giacinto | 92614b91d7 | 1 year ago |
Ettore Di Giacinto | bf5acf646e | 1 year ago |
renovate[bot] | 0780be022c | 1 year ago |
renovate[bot] | c756b5d054 | 1 year ago |
ci-robbot [bot] | e3db6496d7 | 1 year ago |
renovate[bot] | 1f1c95c618 | 1 year ago |
renovate[bot] | 5ea032cf81 | 1 year ago |
ci-robbot [bot] | 1e6542a5ca | 1 year ago |
ci-robbot [bot] | 218e7bc8df | 1 year ago |
Ettore Di Giacinto | a06e467a1a | 1 year ago |
mudler | 730645b3c6 | 1 year ago |
mudler | 3dd632fd5a | 1 year ago |
renovate[bot] | 365d4d3756 | 1 year ago |
renovate[bot] | d22053a5e6 | 1 year ago |
renovate[bot] | e3ac561d30 | 1 year ago |
ci-robbot [bot] | 69367a7948 | 1 year ago |
ci-robbot [bot] | 85a38a8122 | 1 year ago |
Ettore Di Giacinto | d2cf1954fc | 1 year ago |
renovate[bot] | 70712e3445 | 1 year ago |
ci-robbot [bot] | 85eea1189e | 1 year ago |
ci-robbot [bot] | ed2344ab9b | 1 year ago |
Samuel Maynard | 935bd51510 | 1 year ago |
Ettore Di Giacinto | 3593cb0c87 | 1 year ago |
Samuel Maynard | e130b208ab | 1 year ago |
Ettore Di Giacinto | 02136531a3 | 1 year ago |
Ettore Di Giacinto | d3a486a4f8 | 1 year ago |
Ettore Di Giacinto | 2b957df56c | 1 year ago |
Matthew Koski | c2dec387aa | 1 year ago |
ci-robbot [bot] | a1ed6fbd96 | 1 year ago |
renovate[bot] | ad81e37672 | 1 year ago |
Ettore Di Giacinto | 78f3c3da48 | 1 year ago |
mudler | d18f85df46 | 1 year ago |
Ettore Di Giacinto | 6213da330a | 1 year ago |
renovate[bot] | 53f8d73101 | 1 year ago |
renovate[bot] | 2cfc9a2706 | 1 year ago |
ci-robbot [bot] | 0ba94bf33f | 1 year ago |
renovate[bot] | 06570d1e41 | 1 year ago |
ci-robbot [bot] | be1667c387 | 1 year ago |
ci-robbot [bot] | eb39d908d0 | 1 year ago |
Ettore Di Giacinto | 60db5957d3 | 1 year ago |
Ettore Di Giacinto | 2a45a99737 | 1 year ago |
renovate[bot] | 91a67d5ee0 | 1 year ago |
ci-robbot [bot] | 55cf9d5792 | 1 year ago |
Ettore Di Giacinto | a7bb029d23 | 1 year ago |
ci-robbot [bot] | cc31c58235 | 1 year ago |
renovate[bot] | 4e831307a8 | 1 year ago |
ci-robbot [bot] | 445067f6ad | 1 year ago |
ci-robbot [bot] | 11bfd0de76 | 1 year ago |
mudler | dc7b8ad23b | 1 year ago |
Ettore Di Giacinto | 2f5feb4841 | 1 year ago |
renovate[bot] | 4e3c319e83 | 1 year ago |
ci-robbot [bot] | d0025a7483 | 1 year ago |
ci-robbot [bot] | db0b29be51 | 1 year ago |
Ettore Di Giacinto | 7da07e8af9 | 1 year ago |
renovate[bot] | 6da892758b | 1 year ago |
renovate[bot] | 5e88930475 | 1 year ago |
renovate[bot] | 97b02f9765 | 1 year ago |
renovate[bot] | 7ee1b10dfb | 1 year ago |
renovate[bot] | 3932c15823 | 1 year ago |
renovate[bot] | 618fd1d417 | 1 year ago |
renovate[bot] | 151a6cf4c2 | 1 year ago |
ci-robbot [bot] | 1766de814c | 1 year ago |
ci-robbot [bot] | 0b351d6da2 | 1 year ago |
renovate[bot] | 6623ce9942 | 1 year ago |
renovate[bot] | 1dbc190fa6 | 1 year ago |
renovate[bot] | 46b9445fa6 | 1 year ago |
Ettore Di Giacinto | d3d3187e51 | 1 year ago |
Ettore Di Giacinto | 6c94f3cd67 | 1 year ago |
Ettore Di Giacinto | 295f3030a9 | 1 year ago |
renovate[bot] | 1ba88258a9 | 1 year ago |
Ettore Di Giacinto | 10ddd72b58 | 1 year ago |
Ettore Di Giacinto | 1b7990d5d9 | 1 year ago |
renovate[bot] | 9f50b8024d | 1 year ago |
Samuel Maynard | 7b9dcb05d4 | 1 year ago |
Ettore Di Giacinto | e37361985c | 1 year ago |
ci-robbot [bot] | 467e88d305 | 1 year ago |
renovate[bot] | fe4a8fbc74 | 1 year ago |
renovate[bot] | 2328bbaea1 | 1 year ago |
renovate[bot] | 4cc834adcd | 1 year ago |
renovate[bot] | 5e49ff5072 | 1 year ago |
ci-robbot [bot] | f98680a18a | 1 year ago |
Ettore Di Giacinto | 2880221bb3 | 1 year ago |
Samuel Maynard | 27887c74d8 | 1 year ago |
ci-robbot [bot] | 6306885fe7 | 1 year ago |
Ettore Di Giacinto | 2a11f16c0f | 1 year ago |
Ettore Di Giacinto | 2297504fb3 | 1 year ago |
ci-robbot [bot] | 897ac6e4e5 | 1 year ago |
renovate[bot] | f20c12a1c0 | 1 year ago |
renovate[bot] | 5dea31385c | 1 year ago |
renovate[bot] | 58f0f63926 | 1 year ago |
renovate[bot] | ed2bf48a6d | 1 year ago |
ci-robbot [bot] | e6c8ebb65c | 1 year ago |
renovate[bot] | 119733892e | 1 year ago |
ci-robbot [bot] | 437f563128 | 1 year ago |
renovate[bot] | ecad2261c8 | 1 year ago |
renovate[bot] | 182323a7fb | 1 year ago |
renovate[bot] | 30d06f9b12 | 1 year ago |
ci-robbot [bot] | 6bb562272d | 1 year ago |
Ettore Di Giacinto | 3b3164b039 | 1 year ago |
renovate[bot] | 6f0bdbd01c | 1 year ago |
renovate[bot] | ce2a1799ab | 1 year ago |
renovate[bot] | d088bd3034 | 1 year ago |
ci-robbot [bot] | 806e4c3a63 | 1 year ago |
renovate[bot] | 8532ce2002 | 1 year ago |
Ettore Di Giacinto | 84946e9275 | 1 year ago |
Ettore Di Giacinto | c9bbba4872 | 1 year ago |
Ettore Di Giacinto | ea9a651573 | 1 year ago |
Ettore Di Giacinto | 5abbb134d9 | 1 year ago |
renovate[bot] | 694dd4ad9e | 1 year ago |
renovate[bot] | 4af48e548a | 1 year ago |
Ettore Di Giacinto | 079dc197c7 | 1 year ago |
renovate[bot] | 77613169da | 1 year ago |
ci-robbot [bot] | 2630e251ce | 1 year ago |
ci-robbot [bot] | 0909a0637e | 1 year ago |
Ettore Di Giacinto | d62aef2016 | 1 year ago |
ci-robbot [bot] | 25e9483add | 1 year ago |
renovate[bot] | c1be2bdeeb | 1 year ago |
renovate[bot] | 49a2b30350 | 1 year ago |
renovate[bot] | 472cd0fc2f | 1 year ago |
renovate[bot] | dc9c43b6dd | 1 year ago |
renovate[bot] | e1e23a6302 | 1 year ago |
ci-robbot [bot] | 2e916abe15 | 1 year ago |
renovate[bot] | 3ebdb9b67e | 1 year ago |
renovate[bot] | 01f5046caf | 1 year ago |
renovate[bot] | ac17d544e0 | 1 year ago |
Ettore Di Giacinto | b447a2a719 | 1 year ago |
Ettore Di Giacinto | ec4fd1d219 | 1 year ago |
Ettore Di Giacinto | b503725dc7 | 1 year ago |
ci-robbot [bot] | e873fc7b71 | 1 year ago |
renovate[bot] | 3070e9503a | 1 year ago |
Ettore Di Giacinto | d9130def39 | 1 year ago |
renovate[bot] | cdf0a6e766 | 1 year ago |
renovate[bot] | a0e0ac887f | 1 year ago |
Ettore Di Giacinto | 4ddc956462 | 1 year ago |
renovate[bot] | 203fd7b2e8 | 1 year ago |
Ettore Di Giacinto | 1bb85377e4 | 1 year ago |
renovate[bot] | 3892fafc2d | 1 year ago |
renovate[bot] | 8a34679a13 | 1 year ago |
ci-robbot [bot] | b64c1d8ac1 | 1 year ago |
Ettore Di Giacinto | 8fb86c13bc | 1 year ago |
ci-robbot [bot] | 05edf59c91 | 1 year ago |
ci-robbot [bot] | b9f1f85433 | 1 year ago |
renovate[bot] | f8e2e76698 | 1 year ago |
ci-robbot [bot] | 29856f7527 | 1 year ago |
Sébastien Prud'homme | aa6cdf16c8 | 1 year ago |
Samuel Maynard | 96794851b3 | 1 year ago |
renovate[bot] | 51a1a721b3 | 1 year ago |
renovate[bot] | 695f3e5758 | 1 year ago |
Ettore Di Giacinto | e875c1f64a | 1 year ago |
Ettore Di Giacinto | 19f92d7d55 | 1 year ago |
Ettore Di Giacinto | 5a8dd40918 | 1 year ago |
renovate[bot] | 1b766ab89c | 1 year ago |
ci-robbot [bot] | a63d6f6364 | 1 year ago |
ci-robbot [bot] | 4422ca2235 | 1 year ago |
Ettore Di Giacinto | 78ad4813df | 1 year ago |
renovate[bot] | 42d753846e | 1 year ago |
ci-robbot [bot] | 5c018c0437 | 1 year ago |
renovate[bot] | 07cee3f6ef | 1 year ago |
ci-robbot [bot] | c5cb2ff268 | 1 year ago |
Aisuko | c8a4a4f4e9 | 1 year ago |
Pavel Zloi | 3ba07a5928 | 1 year ago |
renovate[bot] | 7282668da1 | 1 year ago |
renovate[bot] | 451e803444 | 1 year ago |
Ettore Di Giacinto | d70c55231b | 1 year ago |
ci-robbot [bot] | 275c124701 | 1 year ago |
ci-robbot [bot] | 87a6bbd251 | 1 year ago |
renovate[bot] | 8fd4c7afcc | 1 year ago |
Sébastien Prud'homme | eee3f83d98 | 1 year ago |
renovate[bot] | 28ee180283 | 1 year ago |
renovate[bot] | 432b0223f1 | 1 year ago |
renovate[bot] | 16050a32c7 | 1 year ago |
renovate[bot] | 898ca62b55 | 1 year ago |
ci-robbot [bot] | 5623a7c331 | 1 year ago |
ci-robbot [bot] | 9e3ca6d1a3 | 1 year ago |
ci-robbot [bot] | fa58965bbc | 1 year ago |
renovate[bot] | b8ef9028f1 | 1 year ago |
ci-robbot [bot] | f711d35377 | 1 year ago |
ci-robbot [bot] | abd3c62194 | 1 year ago |
Ettore Di Giacinto | 2f3c3b1867 | 1 year ago |
Ettore Di Giacinto | 11af09faf3 | 1 year ago |
@ -0,0 +1,5 @@ |
||||
# These are supported funding model platforms |
||||
|
||||
github: [mudler] |
||||
custom: |
||||
- https://www.buymeacoffee.com/mudler |
@ -0,0 +1,109 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"fmt" |
||||
"sync" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
) |
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { |
||||
if !c.Embeddings { |
||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
||||
} |
||||
|
||||
modelFile := c.Model |
||||
|
||||
grpcOpts := gRPCModelOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
|
||||
opts := []model.Option{ |
||||
model.WithLoadGRPCLLMModelOpts(grpcOpts), |
||||
model.WithThreads(uint32(c.Threads)), |
||||
model.WithAssetDir(o.AssetsDestination), |
||||
model.WithModelFile(modelFile), |
||||
model.WithContext(o.Context), |
||||
} |
||||
|
||||
for k, v := range o.ExternalGRPCBackends { |
||||
opts = append(opts, model.WithExternalBackend(k, v)) |
||||
} |
||||
|
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(opts...) |
||||
} else { |
||||
opts = append(opts, model.WithBackendString(c.Backend)) |
||||
inferenceModel, err = loader.BackendLoader(opts...) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() ([]float32, error) |
||||
switch model := inferenceModel.(type) { |
||||
case *grpc.Client: |
||||
fn = func() ([]float32, error) { |
||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath) |
||||
if len(tokens) > 0 { |
||||
embeds := []int32{} |
||||
|
||||
for _, t := range tokens { |
||||
embeds = append(embeds, int32(t)) |
||||
} |
||||
predictOptions.EmbeddingTokens = embeds |
||||
|
||||
res, err := model.Embeddings(o.Context, predictOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return res.Embeddings, nil |
||||
} |
||||
predictOptions.Embeddings = s |
||||
|
||||
res, err := model.Embeddings(o.Context, predictOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return res.Embeddings, nil |
||||
} |
||||
default: |
||||
fn = func() ([]float32, error) { |
||||
return nil, fmt.Errorf("embeddings not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() ([]float32, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[modelFile] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[modelFile] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
embeds, err := fn() |
||||
if err != nil { |
||||
return embeds, err |
||||
} |
||||
// Remove trailing 0s
|
||||
for i := len(embeds) - 1; i >= 0; i-- { |
||||
if embeds[i] == 0.0 { |
||||
embeds = embeds[:i] |
||||
} else { |
||||
break |
||||
} |
||||
} |
||||
return embeds, nil |
||||
}, nil |
||||
} |
@ -0,0 +1,68 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"fmt" |
||||
"sync" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
) |
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { |
||||
if c.Backend != model.StableDiffusionBackend { |
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
||||
} |
||||
|
||||
opts := []model.Option{ |
||||
model.WithBackendString(c.Backend), |
||||
model.WithAssetDir(o.AssetsDestination), |
||||
model.WithThreads(uint32(c.Threads)), |
||||
model.WithContext(o.Context), |
||||
model.WithModelFile(c.ImageGenerationAssets), |
||||
} |
||||
|
||||
for k, v := range o.ExternalGRPCBackends { |
||||
opts = append(opts, model.WithExternalBackend(k, v)) |
||||
} |
||||
|
||||
inferenceModel, err := loader.BackendLoader( |
||||
opts..., |
||||
) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
fn := func() error { |
||||
_, err := inferenceModel.GenerateImage( |
||||
o.Context, |
||||
&proto.GenerateImageRequest{ |
||||
Height: int32(height), |
||||
Width: int32(width), |
||||
Mode: int32(mode), |
||||
Step: int32(step), |
||||
Seed: int32(seed), |
||||
PositivePrompt: positive_prompt, |
||||
NegativePrompt: negative_prompt, |
||||
Dst: dst, |
||||
}) |
||||
return err |
||||
} |
||||
|
||||
return func() error { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[c.Backend] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[c.Backend] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
return fn() |
||||
}, nil |
||||
} |
@ -0,0 +1,124 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"os" |
||||
"regexp" |
||||
"strings" |
||||
"sync" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/gallery" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/go-skynet/LocalAI/pkg/utils" |
||||
) |
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { |
||||
modelFile := c.Model |
||||
|
||||
grpcOpts := gRPCModelOpts(c) |
||||
|
||||
var inferenceModel *grpc.Client |
||||
var err error |
||||
|
||||
opts := []model.Option{ |
||||
model.WithLoadGRPCLLMModelOpts(grpcOpts), |
||||
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
|
||||
model.WithAssetDir(o.AssetsDestination), |
||||
model.WithModelFile(modelFile), |
||||
model.WithContext(o.Context), |
||||
} |
||||
|
||||
for k, v := range o.ExternalGRPCBackends { |
||||
opts = append(opts, model.WithExternalBackend(k, v)) |
||||
} |
||||
|
||||
if c.Backend != "" { |
||||
opts = append(opts, model.WithBackendString(c.Backend)) |
||||
} |
||||
|
||||
// Check if the modelFile exists, if it doesn't try to load it from the gallery
|
||||
if o.AutoloadGalleries { // experimental
|
||||
if _, err := os.Stat(modelFile); os.IsNotExist(err) { |
||||
utils.ResetDownloadTimers() |
||||
// if we failed to load the model, we try to download it
|
||||
err := gallery.InstallModelFromGalleryByName(o.Galleries, modelFile, loader.ModelPath, gallery.GalleryModel{}, utils.DisplayDownloadFunction) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
} |
||||
|
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(opts...) |
||||
} else { |
||||
inferenceModel, err = loader.BackendLoader(opts...) |
||||
} |
||||
|
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||
fn := func() (string, error) { |
||||
opts := gRPCPredictOpts(c, loader.ModelPath) |
||||
opts.Prompt = s |
||||
if tokenCallback != nil { |
||||
ss := "" |
||||
err := inferenceModel.PredictStream(o.Context, opts, func(s string) { |
||||
tokenCallback(s) |
||||
ss += s |
||||
}) |
||||
return ss, err |
||||
} else { |
||||
reply, err := inferenceModel.Predict(o.Context, opts) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return reply.Message, err |
||||
} |
||||
} |
||||
|
||||
return func() (string, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[modelFile] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[modelFile] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
return fn() |
||||
}, nil |
||||
} |
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
||||
var mu sync.Mutex = sync.Mutex{} |
||||
|
||||
func Finetune(config config.Config, input, prediction string) string { |
||||
if config.Echo { |
||||
prediction = input + prediction |
||||
} |
||||
|
||||
for _, c := range config.Cutstrings { |
||||
mu.Lock() |
||||
reg, ok := cutstrings[c] |
||||
if !ok { |
||||
cutstrings[c] = regexp.MustCompile(c) |
||||
reg = cutstrings[c] |
||||
} |
||||
mu.Unlock() |
||||
prediction = reg.ReplaceAllString(prediction, "") |
||||
} |
||||
|
||||
for _, c := range config.TrimSpace { |
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
||||
} |
||||
return prediction |
||||
|
||||
} |
@ -0,0 +1,22 @@ |
||||
package backend |
||||
|
||||
import "sync" |
||||
|
||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutexMap sync.Mutex |
||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
||||
|
||||
func Lock(s string) *sync.Mutex { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[s] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[s] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
|
||||
return l |
||||
} |
@ -0,0 +1,72 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"os" |
||||
"path/filepath" |
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
) |
||||
|
||||
func gRPCModelOpts(c config.Config) *pb.ModelOptions { |
||||
b := 512 |
||||
if c.Batch != 0 { |
||||
b = c.Batch |
||||
} |
||||
return &pb.ModelOptions{ |
||||
ContextSize: int32(c.ContextSize), |
||||
Seed: int32(c.Seed), |
||||
NBatch: int32(b), |
||||
F16Memory: c.F16, |
||||
MLock: c.MMlock, |
||||
NUMA: c.NUMA, |
||||
Embeddings: c.Embeddings, |
||||
LowVRAM: c.LowVRAM, |
||||
NGPULayers: int32(c.NGPULayers), |
||||
MMap: c.MMap, |
||||
MainGPU: c.MainGPU, |
||||
Threads: int32(c.Threads), |
||||
TensorSplit: c.TensorSplit, |
||||
} |
||||
} |
||||
|
||||
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { |
||||
promptCachePath := "" |
||||
if c.PromptCachePath != "" { |
||||
p := filepath.Join(modelPath, c.PromptCachePath) |
||||
os.MkdirAll(filepath.Dir(p), 0755) |
||||
promptCachePath = p |
||||
} |
||||
return &pb.PredictOptions{ |
||||
Temperature: float32(c.Temperature), |
||||
TopP: float32(c.TopP), |
||||
TopK: int32(c.TopK), |
||||
Tokens: int32(c.Maxtokens), |
||||
Threads: int32(c.Threads), |
||||
PromptCacheAll: c.PromptCacheAll, |
||||
PromptCacheRO: c.PromptCacheRO, |
||||
PromptCachePath: promptCachePath, |
||||
F16KV: c.F16, |
||||
DebugMode: c.Debug, |
||||
Grammar: c.Grammar, |
||||
|
||||
Mirostat: int32(c.Mirostat), |
||||
MirostatETA: float32(c.MirostatETA), |
||||
MirostatTAU: float32(c.MirostatTAU), |
||||
Debug: c.Debug, |
||||
StopPrompts: c.StopWords, |
||||
Repeat: int32(c.RepeatPenalty), |
||||
NKeep: int32(c.Keep), |
||||
Batch: int32(c.Batch), |
||||
IgnoreEOS: c.IgnoreEOS, |
||||
Seed: int32(c.Seed), |
||||
FrequencyPenalty: float32(c.FrequencyPenalty), |
||||
MLock: c.MMlock, |
||||
MMap: c.MMap, |
||||
MainGPU: c.MainGPU, |
||||
TensorSplit: c.TensorSplit, |
||||
TailFreeSamplingZ: float32(c.TFZ), |
||||
TypicalP: float32(c.TypicalP), |
||||
} |
||||
} |
@ -0,0 +1,42 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
) |
||||
|
||||
func ModelTranscription(audio, language string, loader *model.ModelLoader, c config.Config, o *options.Option) (*api.Result, error) { |
||||
opts := []model.Option{ |
||||
model.WithBackendString(model.WhisperBackend), |
||||
model.WithModelFile(c.Model), |
||||
model.WithContext(o.Context), |
||||
model.WithThreads(uint32(c.Threads)), |
||||
model.WithAssetDir(o.AssetsDestination), |
||||
} |
||||
|
||||
for k, v := range o.ExternalGRPCBackends { |
||||
opts = append(opts, model.WithExternalBackend(k, v)) |
||||
} |
||||
|
||||
whisperModel, err := o.Loader.BackendLoader(opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if whisperModel == nil { |
||||
return nil, fmt.Errorf("could not load whisper model") |
||||
} |
||||
|
||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ |
||||
Dst: audio, |
||||
Language: language, |
||||
Threads: uint32(c.Threads), |
||||
}) |
||||
} |
@ -0,0 +1,72 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/go-skynet/LocalAI/pkg/utils" |
||||
) |
||||
|
||||
func generateUniqueFileName(dir, baseName, ext string) string { |
||||
counter := 1 |
||||
fileName := baseName + ext |
||||
|
||||
for { |
||||
filePath := filepath.Join(dir, fileName) |
||||
_, err := os.Stat(filePath) |
||||
if os.IsNotExist(err) { |
||||
return fileName |
||||
} |
||||
|
||||
counter++ |
||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) |
||||
} |
||||
} |
||||
|
||||
func ModelTTS(text, modelFile string, loader *model.ModelLoader, o *options.Option) (string, *proto.Result, error) { |
||||
opts := []model.Option{ |
||||
model.WithBackendString(model.PiperBackend), |
||||
model.WithModelFile(modelFile), |
||||
model.WithContext(o.Context), |
||||
model.WithAssetDir(o.AssetsDestination), |
||||
} |
||||
|
||||
for k, v := range o.ExternalGRPCBackends { |
||||
opts = append(opts, model.WithExternalBackend(k, v)) |
||||
} |
||||
|
||||
piperModel, err := o.Loader.BackendLoader(opts...) |
||||
if err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
if piperModel == nil { |
||||
return "", nil, fmt.Errorf("could not load piper model") |
||||
} |
||||
|
||||
if err := os.MkdirAll(o.AudioDir, 0755); err != nil { |
||||
return "", nil, fmt.Errorf("failed creating audio directory: %s", err) |
||||
} |
||||
|
||||
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") |
||||
filePath := filepath.Join(o.AudioDir, fileName) |
||||
|
||||
modelPath := filepath.Join(o.Loader.ModelPath, modelFile) |
||||
|
||||
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
res, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ |
||||
Text: text, |
||||
Model: modelPath, |
||||
Dst: filePath, |
||||
}) |
||||
|
||||
return filePath, res, err |
||||
} |
@ -1,333 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
"sync" |
||||
|
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"gopkg.in/yaml.v3" |
||||
) |
||||
|
||||
type Config struct { |
||||
OpenAIRequest `yaml:"parameters"` |
||||
Name string `yaml:"name"` |
||||
StopWords []string `yaml:"stopwords"` |
||||
Cutstrings []string `yaml:"cutstrings"` |
||||
TrimSpace []string `yaml:"trimspace"` |
||||
ContextSize int `yaml:"context_size"` |
||||
F16 bool `yaml:"f16"` |
||||
Threads int `yaml:"threads"` |
||||
Debug bool `yaml:"debug"` |
||||
Roles map[string]string `yaml:"roles"` |
||||
Embeddings bool `yaml:"embeddings"` |
||||
Backend string `yaml:"backend"` |
||||
TemplateConfig TemplateConfig `yaml:"template"` |
||||
MirostatETA float64 `yaml:"mirostat_eta"` |
||||
MirostatTAU float64 `yaml:"mirostat_tau"` |
||||
Mirostat int `yaml:"mirostat"` |
||||
NGPULayers int `yaml:"gpu_layers"` |
||||
ImageGenerationAssets string `yaml:"asset_dir"` |
||||
|
||||
PromptCachePath string `yaml:"prompt_cache_path"` |
||||
PromptCacheAll bool `yaml:"prompt_cache_all"` |
||||
|
||||
PromptStrings, InputStrings []string |
||||
InputToken [][]int |
||||
} |
||||
|
||||
type TemplateConfig struct { |
||||
Completion string `yaml:"completion"` |
||||
Chat string `yaml:"chat"` |
||||
Edit string `yaml:"edit"` |
||||
} |
||||
|
||||
type ConfigMerger struct { |
||||
configs map[string]Config |
||||
sync.Mutex |
||||
} |
||||
|
||||
func NewConfigMerger() *ConfigMerger { |
||||
return &ConfigMerger{ |
||||
configs: make(map[string]Config), |
||||
} |
||||
} |
||||
func ReadConfigFile(file string) ([]*Config, error) { |
||||
c := &[]*Config{} |
||||
f, err := os.ReadFile(file) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
if err := yaml.Unmarshal(f, c); err != nil { |
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||
} |
||||
|
||||
return *c, nil |
||||
} |
||||
|
||||
func ReadConfig(file string) (*Config, error) { |
||||
c := &Config{} |
||||
f, err := os.ReadFile(file) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
if err := yaml.Unmarshal(f, c); err != nil { |
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||
} |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
func (cm ConfigMerger) LoadConfigFile(file string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
c, err := ReadConfigFile(file) |
||||
if err != nil { |
||||
return fmt.Errorf("cannot load config file: %w", err) |
||||
} |
||||
|
||||
for _, cc := range c { |
||||
cm.configs[cc.Name] = *cc |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (cm ConfigMerger) LoadConfig(file string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
c, err := ReadConfig(file) |
||||
if err != nil { |
||||
return fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
|
||||
cm.configs[c.Name] = *c |
||||
return nil |
||||
} |
||||
|
||||
func (cm ConfigMerger) GetConfig(m string) (Config, bool) { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
v, exists := cm.configs[m] |
||||
return v, exists |
||||
} |
||||
|
||||
func (cm ConfigMerger) ListConfigs() []string { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
var res []string |
||||
for k := range cm.configs { |
||||
res = append(res, k) |
||||
} |
||||
return res |
||||
} |
||||
|
||||
func (cm ConfigMerger) LoadConfigs(path string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
files, err := ioutil.ReadDir(path) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
for _, file := range files { |
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") { |
||||
continue |
||||
} |
||||
c, err := ReadConfig(filepath.Join(path, file.Name())) |
||||
if err == nil { |
||||
cm.configs[c.Name] = *c |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func updateConfig(config *Config, input *OpenAIRequest) { |
||||
if input.Echo { |
||||
config.Echo = input.Echo |
||||
} |
||||
if input.TopK != 0 { |
||||
config.TopK = input.TopK |
||||
} |
||||
if input.TopP != 0 { |
||||
config.TopP = input.TopP |
||||
} |
||||
|
||||
if input.Temperature != 0 { |
||||
config.Temperature = input.Temperature |
||||
} |
||||
|
||||
if input.Maxtokens != 0 { |
||||
config.Maxtokens = input.Maxtokens |
||||
} |
||||
|
||||
switch stop := input.Stop.(type) { |
||||
case string: |
||||
if stop != "" { |
||||
config.StopWords = append(config.StopWords, stop) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range stop { |
||||
if s, ok := pp.(string); ok { |
||||
config.StopWords = append(config.StopWords, s) |
||||
} |
||||
} |
||||
} |
||||
|
||||
if input.RepeatPenalty != 0 { |
||||
config.RepeatPenalty = input.RepeatPenalty |
||||
} |
||||
|
||||
if input.Keep != 0 { |
||||
config.Keep = input.Keep |
||||
} |
||||
|
||||
if input.Batch != 0 { |
||||
config.Batch = input.Batch |
||||
} |
||||
|
||||
if input.F16 { |
||||
config.F16 = input.F16 |
||||
} |
||||
|
||||
if input.IgnoreEOS { |
||||
config.IgnoreEOS = input.IgnoreEOS |
||||
} |
||||
|
||||
if input.Seed != 0 { |
||||
config.Seed = input.Seed |
||||
} |
||||
|
||||
if input.Mirostat != 0 { |
||||
config.Mirostat = input.Mirostat |
||||
} |
||||
|
||||
if input.MirostatETA != 0 { |
||||
config.MirostatETA = input.MirostatETA |
||||
} |
||||
|
||||
if input.MirostatTAU != 0 { |
||||
config.MirostatTAU = input.MirostatTAU |
||||
} |
||||
|
||||
switch inputs := input.Input.(type) { |
||||
case string: |
||||
if inputs != "" { |
||||
config.InputStrings = append(config.InputStrings, inputs) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range inputs { |
||||
switch i := pp.(type) { |
||||
case string: |
||||
config.InputStrings = append(config.InputStrings, i) |
||||
case []interface{}: |
||||
tokens := []int{} |
||||
for _, ii := range i { |
||||
tokens = append(tokens, int(ii.(float64))) |
||||
} |
||||
config.InputToken = append(config.InputToken, tokens) |
||||
} |
||||
} |
||||
} |
||||
|
||||
switch p := input.Prompt.(type) { |
||||
case string: |
||||
config.PromptStrings = append(config.PromptStrings, p) |
||||
case []interface{}: |
||||
for _, pp := range p { |
||||
if s, ok := pp.(string); ok { |
||||
config.PromptStrings = append(config.PromptStrings, s) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
||||
input := new(OpenAIRequest) |
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
modelFile := input.Model |
||||
|
||||
if c.Params("model") != "" { |
||||
modelFile = c.Params("model") |
||||
} |
||||
|
||||
received, _ := json.Marshal(input) |
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received)) |
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel { |
||||
models, _ := loader.ListModels() |
||||
if len(models) > 0 { |
||||
modelFile = models[0] |
||||
log.Debug().Msgf("No model specified, using: %s", modelFile) |
||||
} else { |
||||
log.Debug().Msgf("No model specified, returning error") |
||||
return "", nil, fmt.Errorf("no model specified") |
||||
} |
||||
} |
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists { |
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
||||
modelFile = bearer |
||||
} |
||||
return modelFile, input, nil |
||||
} |
||||
|
||||
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { |
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
||||
if _, err := os.Stat(modelConfig); err == nil { |
||||
if err := cm.LoadConfig(modelConfig); err != nil { |
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
||||
} |
||||
} |
||||
|
||||
var config *Config |
||||
cfg, exists := cm.GetConfig(modelFile) |
||||
if !exists { |
||||
config = &Config{ |
||||
OpenAIRequest: defaultRequest(modelFile), |
||||
ContextSize: ctx, |
||||
Threads: threads, |
||||
F16: f16, |
||||
Debug: debug, |
||||
} |
||||
} else { |
||||
config = &cfg |
||||
} |
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(config, input) |
||||
|
||||
// Don't allow 0 as setting
|
||||
if config.Threads == 0 { |
||||
if threads != 0 { |
||||
config.Threads = threads |
||||
} else { |
||||
config.Threads = 4 |
||||
} |
||||
} |
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug { |
||||
config.Debug = true |
||||
} |
||||
|
||||
return config, input, nil |
||||
} |
@ -0,0 +1,209 @@ |
||||
package api_config |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io/fs" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"gopkg.in/yaml.v3" |
||||
) |
||||
|
||||
type Config struct { |
||||
PredictionOptions `yaml:"parameters"` |
||||
Name string `yaml:"name"` |
||||
StopWords []string `yaml:"stopwords"` |
||||
Cutstrings []string `yaml:"cutstrings"` |
||||
TrimSpace []string `yaml:"trimspace"` |
||||
ContextSize int `yaml:"context_size"` |
||||
F16 bool `yaml:"f16"` |
||||
NUMA bool `yaml:"numa"` |
||||
Threads int `yaml:"threads"` |
||||
Debug bool `yaml:"debug"` |
||||
Roles map[string]string `yaml:"roles"` |
||||
Embeddings bool `yaml:"embeddings"` |
||||
Backend string `yaml:"backend"` |
||||
TemplateConfig TemplateConfig `yaml:"template"` |
||||
MirostatETA float64 `yaml:"mirostat_eta"` |
||||
MirostatTAU float64 `yaml:"mirostat_tau"` |
||||
Mirostat int `yaml:"mirostat"` |
||||
NGPULayers int `yaml:"gpu_layers"` |
||||
MMap bool `yaml:"mmap"` |
||||
MMlock bool `yaml:"mmlock"` |
||||
LowVRAM bool `yaml:"low_vram"` |
||||
|
||||
TensorSplit string `yaml:"tensor_split"` |
||||
MainGPU string `yaml:"main_gpu"` |
||||
ImageGenerationAssets string `yaml:"asset_dir"` |
||||
|
||||
PromptCachePath string `yaml:"prompt_cache_path"` |
||||
PromptCacheAll bool `yaml:"prompt_cache_all"` |
||||
PromptCacheRO bool `yaml:"prompt_cache_ro"` |
||||
|
||||
Grammar string `yaml:"grammar"` |
||||
|
||||
PromptStrings, InputStrings []string |
||||
InputToken [][]int |
||||
functionCallString, functionCallNameString string |
||||
|
||||
FunctionsConfig Functions `yaml:"function"` |
||||
} |
||||
|
||||
type Functions struct { |
||||
DisableNoAction bool `yaml:"disable_no_action"` |
||||
NoActionFunctionName string `yaml:"no_action_function_name"` |
||||
NoActionDescriptionName string `yaml:"no_action_description_name"` |
||||
} |
||||
|
||||
type TemplateConfig struct { |
||||
Completion string `yaml:"completion"` |
||||
Functions string `yaml:"function"` |
||||
Chat string `yaml:"chat"` |
||||
Edit string `yaml:"edit"` |
||||
} |
||||
|
||||
type ConfigLoader struct { |
||||
configs map[string]Config |
||||
sync.Mutex |
||||
} |
||||
|
||||
func (c *Config) SetFunctionCallString(s string) { |
||||
c.functionCallString = s |
||||
} |
||||
|
||||
func (c *Config) SetFunctionCallNameString(s string) { |
||||
c.functionCallNameString = s |
||||
} |
||||
|
||||
func (c *Config) ShouldUseFunctions() bool { |
||||
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) |
||||
} |
||||
|
||||
func (c *Config) ShouldCallSpecificFunction() bool { |
||||
return len(c.functionCallNameString) > 0 |
||||
} |
||||
|
||||
func (c *Config) FunctionToCall() string { |
||||
return c.functionCallNameString |
||||
} |
||||
|
||||
func defaultPredictOptions(modelFile string) PredictionOptions { |
||||
return PredictionOptions{ |
||||
TopP: 0.7, |
||||
TopK: 80, |
||||
Maxtokens: 512, |
||||
Temperature: 0.9, |
||||
Model: modelFile, |
||||
} |
||||
} |
||||
|
||||
func DefaultConfig(modelFile string) *Config { |
||||
return &Config{ |
||||
PredictionOptions: defaultPredictOptions(modelFile), |
||||
} |
||||
} |
||||
|
||||
func NewConfigLoader() *ConfigLoader { |
||||
return &ConfigLoader{ |
||||
configs: make(map[string]Config), |
||||
} |
||||
} |
||||
func ReadConfigFile(file string) ([]*Config, error) { |
||||
c := &[]*Config{} |
||||
f, err := os.ReadFile(file) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
if err := yaml.Unmarshal(f, c); err != nil { |
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||
} |
||||
|
||||
return *c, nil |
||||
} |
||||
|
||||
func ReadConfig(file string) (*Config, error) { |
||||
c := &Config{} |
||||
f, err := os.ReadFile(file) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
if err := yaml.Unmarshal(f, c); err != nil { |
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||
} |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
func (cm *ConfigLoader) LoadConfigFile(file string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
c, err := ReadConfigFile(file) |
||||
if err != nil { |
||||
return fmt.Errorf("cannot load config file: %w", err) |
||||
} |
||||
|
||||
for _, cc := range c { |
||||
cm.configs[cc.Name] = *cc |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (cm *ConfigLoader) LoadConfig(file string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
c, err := ReadConfig(file) |
||||
if err != nil { |
||||
return fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
|
||||
cm.configs[c.Name] = *c |
||||
return nil |
||||
} |
||||
|
||||
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
v, exists := cm.configs[m] |
||||
return v, exists |
||||
} |
||||
|
||||
func (cm *ConfigLoader) ListConfigs() []string { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
var res []string |
||||
for k := range cm.configs { |
||||
res = append(res, k) |
||||
} |
||||
return res |
||||
} |
||||
|
||||
func (cm *ConfigLoader) LoadConfigs(path string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
entries, err := os.ReadDir(path) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
files := make([]fs.FileInfo, 0, len(entries)) |
||||
for _, entry := range entries { |
||||
info, err := entry.Info() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
files = append(files, info) |
||||
} |
||||
for _, file := range files { |
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") { |
||||
continue |
||||
} |
||||
c, err := ReadConfig(filepath.Join(path, file.Name())) |
||||
if err == nil { |
||||
cm.configs[c.Name] = *c |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,56 @@ |
||||
package api_config_test |
||||
|
||||
import ( |
||||
"os" |
||||
|
||||
. "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/model" |
||||
. "github.com/onsi/ginkgo/v2" |
||||
. "github.com/onsi/gomega" |
||||
) |
||||
|
||||
var _ = Describe("Test cases for config related functions", func() { |
||||
|
||||
var ( |
||||
configFile string |
||||
) |
||||
|
||||
Context("Test Read configuration functions", func() { |
||||
configFile = os.Getenv("CONFIG_FILE") |
||||
It("Test ReadConfigFile", func() { |
||||
config, err := ReadConfigFile(configFile) |
||||
Expect(err).To(BeNil()) |
||||
Expect(config).ToNot(BeNil()) |
||||
// two configs in config.yaml
|
||||
Expect(config[0].Name).To(Equal("list1")) |
||||
Expect(config[1].Name).To(Equal("list2")) |
||||
}) |
||||
|
||||
It("Test LoadConfigs", func() { |
||||
cm := NewConfigLoader() |
||||
opts := options.NewOptions() |
||||
modelLoader := model.NewModelLoader(os.Getenv("MODELS_PATH")) |
||||
options.WithModelLoader(modelLoader)(opts) |
||||
|
||||
err := cm.LoadConfigs(opts.Loader.ModelPath) |
||||
Expect(err).To(BeNil()) |
||||
Expect(cm.ListConfigs()).ToNot(BeNil()) |
||||
|
||||
// config should includes gpt4all models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all")) |
||||
|
||||
// config should includes gpt2 models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("gpt4all-2")) |
||||
|
||||
// config should includes text-embedding-ada-002 models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("text-embedding-ada-002")) |
||||
|
||||
// config should includes rwkv_test models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("rwkv_test")) |
||||
|
||||
// config should includes whisper-1 models's api.config
|
||||
Expect(cm.ListConfigs()).To(ContainElements("whisper-1")) |
||||
}) |
||||
}) |
||||
}) |
@ -0,0 +1,37 @@ |
||||
package api_config |
||||
|
||||
type PredictionOptions struct { |
||||
|
||||
// Also part of the OpenAI official spec
|
||||
Model string `json:"model" yaml:"model"` |
||||
|
||||
// Also part of the OpenAI official spec
|
||||
Language string `json:"language"` |
||||
|
||||
// Also part of the OpenAI official spec. use it for returning multiple results
|
||||
N int `json:"n"` |
||||
|
||||
// Common options between all the API calls, part of the OpenAI spec
|
||||
TopP float64 `json:"top_p" yaml:"top_p"` |
||||
TopK int `json:"top_k" yaml:"top_k"` |
||||
Temperature float64 `json:"temperature" yaml:"temperature"` |
||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"` |
||||
Echo bool `json:"echo"` |
||||
|
||||
// Custom parameters - not present in the OpenAI API
|
||||
Batch int `json:"batch" yaml:"batch"` |
||||
F16 bool `json:"f16" yaml:"f16"` |
||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` |
||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` |
||||
Keep int `json:"n_keep" yaml:"n_keep"` |
||||
|
||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` |
||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` |
||||
Mirostat int `json:"mirostat" yaml:"mirostat"` |
||||
|
||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` |
||||
TFZ float64 `json:"tfz" yaml:"tfz"` |
||||
|
||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"` |
||||
Seed int `json:"seed" yaml:"seed"` |
||||
} |
@ -1,27 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"os" |
||||
|
||||
. "github.com/onsi/ginkgo/v2" |
||||
. "github.com/onsi/gomega" |
||||
) |
||||
|
||||
var _ = Describe("Test cases for config related functions", func() { |
||||
|
||||
var ( |
||||
configFile string |
||||
) |
||||
|
||||
Context("Test Read configuration functions", func() { |
||||
configFile = os.Getenv("CONFIG_FILE") |
||||
It("Test ReadConfigFile", func() { |
||||
config, err := ReadConfigFile(configFile) |
||||
Expect(err).To(BeNil()) |
||||
Expect(config).ToNot(BeNil()) |
||||
// two configs in config.yaml
|
||||
Expect(len(config)).To(Equal(2)) |
||||
}) |
||||
|
||||
}) |
||||
}) |
@ -1,233 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"net/http" |
||||
"net/url" |
||||
"os" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/google/uuid" |
||||
"gopkg.in/yaml.v3" |
||||
) |
||||
|
||||
type galleryOp struct { |
||||
req ApplyGalleryModelRequest |
||||
id string |
||||
} |
||||
|
||||
type galleryOpStatus struct { |
||||
Error error `json:"error"` |
||||
Processed bool `json:"processed"` |
||||
Message string `json:"message"` |
||||
} |
||||
|
||||
type galleryApplier struct { |
||||
modelPath string |
||||
sync.Mutex |
||||
C chan galleryOp |
||||
statuses map[string]*galleryOpStatus |
||||
} |
||||
|
||||
func newGalleryApplier(modelPath string) *galleryApplier { |
||||
return &galleryApplier{ |
||||
modelPath: modelPath, |
||||
C: make(chan galleryOp), |
||||
statuses: make(map[string]*galleryOpStatus), |
||||
} |
||||
} |
||||
|
||||
func applyGallery(modelPath string, req ApplyGalleryModelRequest, cm *ConfigMerger) error { |
||||
url, err := req.DecodeURL() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Send a GET request to the URL
|
||||
response, err := http.Get(url) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer response.Body.Close() |
||||
|
||||
// Read the response body
|
||||
body, err := ioutil.ReadAll(response.Body) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Unmarshal YAML data into a Config struct
|
||||
var config gallery.Config |
||||
err = yaml.Unmarshal(body, &config) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...) |
||||
|
||||
if err := gallery.Apply(modelPath, req.Name, &config, req.Overrides); err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Reload models
|
||||
return cm.LoadConfigs(modelPath) |
||||
} |
||||
|
||||
func (g *galleryApplier) updatestatus(s string, op *galleryOpStatus) { |
||||
g.Lock() |
||||
defer g.Unlock() |
||||
g.statuses[s] = op |
||||
} |
||||
|
||||
func (g *galleryApplier) getstatus(s string) *galleryOpStatus { |
||||
g.Lock() |
||||
defer g.Unlock() |
||||
|
||||
return g.statuses[s] |
||||
} |
||||
|
||||
func (g *galleryApplier) start(c context.Context, cm *ConfigMerger) { |
||||
go func() { |
||||
for { |
||||
select { |
||||
case <-c.Done(): |
||||
return |
||||
case op := <-g.C: |
||||
g.updatestatus(op.id, &galleryOpStatus{Message: "processing"}) |
||||
|
||||
updateError := func(e error) { |
||||
g.updatestatus(op.id, &galleryOpStatus{Error: e, Processed: true}) |
||||
} |
||||
|
||||
if err := applyGallery(g.modelPath, op.req, cm); err != nil { |
||||
updateError(err) |
||||
continue |
||||
} |
||||
|
||||
g.updatestatus(op.id, &galleryOpStatus{Processed: true, Message: "completed"}) |
||||
} |
||||
} |
||||
}() |
||||
} |
||||
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *ConfigMerger) error { |
||||
dat, err := os.ReadFile(s) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var requests []ApplyGalleryModelRequest |
||||
err = json.Unmarshal(dat, &requests) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
for _, r := range requests { |
||||
if err := applyGallery(modelPath, r, cm); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
func ApplyGalleryFromString(modelPath, s string, cm *ConfigMerger) error { |
||||
var requests []ApplyGalleryModelRequest |
||||
err := json.Unmarshal([]byte(s), &requests) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
for _, r := range requests { |
||||
if err := applyGallery(modelPath, r, cm); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// endpoints
|
||||
|
||||
type ApplyGalleryModelRequest struct { |
||||
URL string `json:"url"` |
||||
Name string `json:"name"` |
||||
Overrides map[string]interface{} `json:"overrides"` |
||||
AdditionalFiles []gallery.File `json:"files"` |
||||
} |
||||
|
||||
const ( |
||||
githubURI = "github:" |
||||
) |
||||
|
||||
func (request ApplyGalleryModelRequest) DecodeURL() (string, error) { |
||||
input := request.URL |
||||
var rawURL string |
||||
|
||||
if strings.HasPrefix(input, githubURI) { |
||||
parts := strings.Split(input, ":") |
||||
repoParts := strings.Split(parts[1], "@") |
||||
branch := "main" |
||||
|
||||
if len(repoParts) > 1 { |
||||
branch = repoParts[1] |
||||
} |
||||
|
||||
repoPath := strings.Split(repoParts[0], "/") |
||||
org := repoPath[0] |
||||
project := repoPath[1] |
||||
projectPath := strings.Join(repoPath[2:], "/") |
||||
|
||||
rawURL = fmt.Sprintf("https://raw.githubusercontent.com/%s/%s/%s/%s", org, project, branch, projectPath) |
||||
} else if strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://") { |
||||
// Handle regular URLs
|
||||
u, err := url.Parse(input) |
||||
if err != nil { |
||||
return "", fmt.Errorf("invalid URL: %w", err) |
||||
} |
||||
rawURL = u.String() |
||||
} else { |
||||
return "", fmt.Errorf("invalid URL format") |
||||
} |
||||
|
||||
return rawURL, nil |
||||
} |
||||
|
||||
func getOpStatus(g *galleryApplier) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
|
||||
status := g.getstatus(c.Params("uuid")) |
||||
if status == nil { |
||||
return fmt.Errorf("could not find any status for ID") |
||||
} |
||||
|
||||
return c.JSON(status) |
||||
} |
||||
} |
||||
|
||||
func applyModelGallery(modelPath string, cm *ConfigMerger, g chan galleryOp) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
input := new(ApplyGalleryModelRequest) |
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil { |
||||
return err |
||||
} |
||||
|
||||
uuid, err := uuid.NewUUID() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
g <- galleryOp{ |
||||
req: *input, |
||||
id: uuid.String(), |
||||
} |
||||
return c.JSON(struct { |
||||
ID string `json:"uuid"` |
||||
StatusURL string `json:"status"` |
||||
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) |
||||
} |
||||
} |
@ -1,30 +0,0 @@ |
||||
package api_test |
||||
|
||||
import ( |
||||
. "github.com/go-skynet/LocalAI/api" |
||||
. "github.com/onsi/ginkgo/v2" |
||||
. "github.com/onsi/gomega" |
||||
) |
||||
|
||||
var _ = Describe("Gallery API tests", func() { |
||||
Context("requests", func() { |
||||
It("parses github with a branch", func() { |
||||
req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main"} |
||||
str, err := req.DecodeURL() |
||||
Expect(err).ToNot(HaveOccurred()) |
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) |
||||
}) |
||||
It("parses github without a branch", func() { |
||||
req := ApplyGalleryModelRequest{URL: "github:go-skynet/model-gallery/gpt4all-j.yaml"} |
||||
str, err := req.DecodeURL() |
||||
Expect(err).ToNot(HaveOccurred()) |
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) |
||||
}) |
||||
It("parses URLS", func() { |
||||
req := ApplyGalleryModelRequest{URL: "https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml"} |
||||
str, err := req.DecodeURL() |
||||
Expect(err).ToNot(HaveOccurred()) |
||||
Expect(str).To(Equal("https://raw.githubusercontent.com/go-skynet/model-gallery/main/gpt4all-j.yaml")) |
||||
}) |
||||
}) |
||||
}) |
@ -0,0 +1,224 @@ |
||||
package localai |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"os" |
||||
"strings" |
||||
"sync" |
||||
|
||||
json "github.com/json-iterator/go" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/pkg/gallery" |
||||
"github.com/go-skynet/LocalAI/pkg/utils" |
||||
|
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/google/uuid" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
type galleryOp struct { |
||||
req gallery.GalleryModel |
||||
id string |
||||
galleries []gallery.Gallery |
||||
galleryName string |
||||
} |
||||
|
||||
type galleryOpStatus struct { |
||||
Error error `json:"error"` |
||||
Processed bool `json:"processed"` |
||||
Message string `json:"message"` |
||||
Progress float64 `json:"progress"` |
||||
TotalFileSize string `json:"file_size"` |
||||
DownloadedFileSize string `json:"downloaded_size"` |
||||
} |
||||
|
||||
type galleryApplier struct { |
||||
modelPath string |
||||
sync.Mutex |
||||
C chan galleryOp |
||||
statuses map[string]*galleryOpStatus |
||||
} |
||||
|
||||
func NewGalleryService(modelPath string) *galleryApplier { |
||||
return &galleryApplier{ |
||||
modelPath: modelPath, |
||||
C: make(chan galleryOp), |
||||
statuses: make(map[string]*galleryOpStatus), |
||||
} |
||||
} |
||||
|
||||
// prepareModel applies a
|
||||
func prepareModel(modelPath string, req gallery.GalleryModel, cm *config.ConfigLoader, downloadStatus func(string, string, string, float64)) error { |
||||
|
||||
config, err := gallery.GetGalleryConfigFromURL(req.URL) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
config.Files = append(config.Files, req.AdditionalFiles...) |
||||
|
||||
return gallery.InstallModel(modelPath, req.Name, &config, req.Overrides, downloadStatus) |
||||
} |
||||
|
||||
func (g *galleryApplier) updateStatus(s string, op *galleryOpStatus) { |
||||
g.Lock() |
||||
defer g.Unlock() |
||||
g.statuses[s] = op |
||||
} |
||||
|
||||
func (g *galleryApplier) getStatus(s string) *galleryOpStatus { |
||||
g.Lock() |
||||
defer g.Unlock() |
||||
|
||||
return g.statuses[s] |
||||
} |
||||
|
||||
func (g *galleryApplier) Start(c context.Context, cm *config.ConfigLoader) { |
||||
go func() { |
||||
for { |
||||
select { |
||||
case <-c.Done(): |
||||
return |
||||
case op := <-g.C: |
||||
utils.ResetDownloadTimers() |
||||
|
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: 0}) |
||||
|
||||
// updates the status with an error
|
||||
updateError := func(e error) { |
||||
g.updateStatus(op.id, &galleryOpStatus{Error: e, Processed: true, Message: "error: " + e.Error()}) |
||||
} |
||||
|
||||
// displayDownload displays the download progress
|
||||
progressCallback := func(fileName string, current string, total string, percentage float64) { |
||||
g.updateStatus(op.id, &galleryOpStatus{Message: "processing", Progress: percentage, TotalFileSize: total, DownloadedFileSize: current}) |
||||
utils.DisplayDownloadFunction(fileName, current, total, percentage) |
||||
} |
||||
|
||||
var err error |
||||
// if the request contains a gallery name, we apply the gallery from the gallery list
|
||||
if op.galleryName != "" { |
||||
if strings.Contains(op.galleryName, "@") { |
||||
err = gallery.InstallModelFromGallery(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) |
||||
} else { |
||||
err = gallery.InstallModelFromGalleryByName(op.galleries, op.galleryName, g.modelPath, op.req, progressCallback) |
||||
} |
||||
} else { |
||||
err = prepareModel(g.modelPath, op.req, cm, progressCallback) |
||||
} |
||||
|
||||
if err != nil { |
||||
updateError(err) |
||||
continue |
||||
} |
||||
|
||||
// Reload models
|
||||
err = cm.LoadConfigs(g.modelPath) |
||||
if err != nil { |
||||
updateError(err) |
||||
continue |
||||
} |
||||
|
||||
g.updateStatus(op.id, &galleryOpStatus{Processed: true, Message: "completed", Progress: 100}) |
||||
} |
||||
} |
||||
}() |
||||
} |
||||
|
||||
type galleryModel struct { |
||||
gallery.GalleryModel |
||||
ID string `json:"id"` |
||||
} |
||||
|
||||
func ApplyGalleryFromFile(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { |
||||
dat, err := os.ReadFile(s) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return ApplyGalleryFromString(modelPath, string(dat), cm, galleries) |
||||
} |
||||
|
||||
func ApplyGalleryFromString(modelPath, s string, cm *config.ConfigLoader, galleries []gallery.Gallery) error { |
||||
var requests []galleryModel |
||||
err := json.Unmarshal([]byte(s), &requests) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
for _, r := range requests { |
||||
utils.ResetDownloadTimers() |
||||
if r.ID == "" { |
||||
err = prepareModel(modelPath, r.GalleryModel, cm, utils.DisplayDownloadFunction) |
||||
} else { |
||||
err = gallery.InstallModelFromGallery(galleries, r.ID, modelPath, r.GalleryModel, utils.DisplayDownloadFunction) |
||||
} |
||||
} |
||||
|
||||
return err |
||||
} |
||||
|
||||
/// Endpoints
|
||||
|
||||
func GetOpStatusEndpoint(g *galleryApplier) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
|
||||
status := g.getStatus(c.Params("uuid")) |
||||
if status == nil { |
||||
return fmt.Errorf("could not find any status for ID") |
||||
} |
||||
|
||||
return c.JSON(status) |
||||
} |
||||
} |
||||
|
||||
type GalleryModel struct { |
||||
ID string `json:"id"` |
||||
gallery.GalleryModel |
||||
} |
||||
|
||||
func ApplyModelGalleryEndpoint(modelPath string, cm *config.ConfigLoader, g chan galleryOp, galleries []gallery.Gallery) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
input := new(GalleryModel) |
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil { |
||||
return err |
||||
} |
||||
|
||||
uuid, err := uuid.NewUUID() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
g <- galleryOp{ |
||||
req: input.GalleryModel, |
||||
id: uuid.String(), |
||||
galleryName: input.ID, |
||||
galleries: galleries, |
||||
} |
||||
return c.JSON(struct { |
||||
ID string `json:"uuid"` |
||||
StatusURL string `json:"status"` |
||||
}{ID: uuid.String(), StatusURL: c.BaseURL() + "/models/jobs/" + uuid.String()}) |
||||
} |
||||
} |
||||
|
||||
func ListModelFromGalleryEndpoint(galleries []gallery.Gallery, basePath string) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
log.Debug().Msgf("Listing models from galleries: %+v", galleries) |
||||
|
||||
models, err := gallery.AvailableGalleryModels(galleries, basePath) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
log.Debug().Msgf("Models found from galleries: %+v", models) |
||||
for _, m := range models { |
||||
log.Debug().Msgf("Model found from galleries: %+v", m) |
||||
} |
||||
dat, err := json.Marshal(models) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return c.Send(dat) |
||||
} |
||||
} |
@ -0,0 +1,31 @@ |
||||
package localai |
||||
|
||||
import ( |
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/gofiber/fiber/v2" |
||||
) |
||||
|
||||
type TTSRequest struct { |
||||
Model string `json:"model" yaml:"model"` |
||||
Input string `json:"input" yaml:"input"` |
||||
} |
||||
|
||||
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
|
||||
input := new(TTSRequest) |
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil { |
||||
return err |
||||
} |
||||
|
||||
filePath, _, err := backend.ModelTTS(input.Input, input.Model, o.Loader, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
return c.Download(filePath) |
||||
} |
||||
} |
@ -1,678 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"net/http" |
||||
"os" |
||||
"path" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" |
||||
llama "github.com/go-skynet/go-llama.cpp" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"github.com/valyala/fasthttp" |
||||
) |
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
type APIError struct { |
||||
Code any `json:"code,omitempty"` |
||||
Message string `json:"message"` |
||||
Param *string `json:"param,omitempty"` |
||||
Type string `json:"type"` |
||||
} |
||||
|
||||
type ErrorResponse struct { |
||||
Error *APIError `json:"error,omitempty"` |
||||
} |
||||
|
||||
type OpenAIUsage struct { |
||||
PromptTokens int `json:"prompt_tokens"` |
||||
CompletionTokens int `json:"completion_tokens"` |
||||
TotalTokens int `json:"total_tokens"` |
||||
} |
||||
|
||||
type Item struct { |
||||
Embedding []float32 `json:"embedding"` |
||||
Index int `json:"index"` |
||||
Object string `json:"object,omitempty"` |
||||
|
||||
// Images
|
||||
URL string `json:"url,omitempty"` |
||||
B64JSON string `json:"b64_json,omitempty"` |
||||
} |
||||
|
||||
type OpenAIResponse struct { |
||||
Created int `json:"created,omitempty"` |
||||
Object string `json:"object,omitempty"` |
||||
ID string `json:"id,omitempty"` |
||||
Model string `json:"model,omitempty"` |
||||
Choices []Choice `json:"choices,omitempty"` |
||||
Data []Item `json:"data,omitempty"` |
||||
|
||||
Usage OpenAIUsage `json:"usage"` |
||||
} |
||||
|
||||
type Choice struct { |
||||
Index int `json:"index,omitempty"` |
||||
FinishReason string `json:"finish_reason,omitempty"` |
||||
Message *Message `json:"message,omitempty"` |
||||
Delta *Message `json:"delta,omitempty"` |
||||
Text string `json:"text,omitempty"` |
||||
} |
||||
|
||||
type Message struct { |
||||
Role string `json:"role,omitempty" yaml:"role"` |
||||
Content string `json:"content,omitempty" yaml:"content"` |
||||
} |
||||
|
||||
type OpenAIModel struct { |
||||
ID string `json:"id"` |
||||
Object string `json:"object"` |
||||
} |
||||
|
||||
type OpenAIRequest struct { |
||||
Model string `json:"model" yaml:"model"` |
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"` |
||||
Language string `json:"language"` |
||||
//whisper/image
|
||||
ResponseFormat string `json:"response_format"` |
||||
// image
|
||||
Size string `json:"size"` |
||||
// Prompt is read only by completion/image API calls
|
||||
Prompt interface{} `json:"prompt" yaml:"prompt"` |
||||
|
||||
// Edit endpoint
|
||||
Instruction string `json:"instruction" yaml:"instruction"` |
||||
Input interface{} `json:"input" yaml:"input"` |
||||
|
||||
Stop interface{} `json:"stop" yaml:"stop"` |
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages" yaml:"messages"` |
||||
|
||||
Stream bool `json:"stream"` |
||||
Echo bool `json:"echo"` |
||||
// Common options between all the API calls
|
||||
TopP float64 `json:"top_p" yaml:"top_p"` |
||||
TopK int `json:"top_k" yaml:"top_k"` |
||||
Temperature float64 `json:"temperature" yaml:"temperature"` |
||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"` |
||||
|
||||
N int `json:"n"` |
||||
|
||||
// Custom parameters - not present in the OpenAI API
|
||||
Batch int `json:"batch" yaml:"batch"` |
||||
F16 bool `json:"f16" yaml:"f16"` |
||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` |
||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` |
||||
Keep int `json:"n_keep" yaml:"n_keep"` |
||||
|
||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` |
||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` |
||||
Mirostat int `json:"mirostat" yaml:"mirostat"` |
||||
|
||||
Seed int `json:"seed" yaml:"seed"` |
||||
|
||||
// Image (not supported by OpenAI)
|
||||
Mode int `json:"mode"` |
||||
Step int `json:"step"` |
||||
} |
||||
|
||||
func defaultRequest(modelFile string) OpenAIRequest { |
||||
return OpenAIRequest{ |
||||
TopP: 0.7, |
||||
TopK: 80, |
||||
Maxtokens: 512, |
||||
Temperature: 0.9, |
||||
Model: modelFile, |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
|
||||
model, input, err := readInput(c, o.loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Completion != "" { |
||||
templateFile = config.TemplateConfig.Completion |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.PromptStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{Input: i}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "text_completion", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
items := []Item{} |
||||
|
||||
for i, s := range config.InputToken { |
||||
// get the model function to call for the result
|
||||
embedFn, err := ModelEmbedding("", s, o.loader, *config) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
for i, s := range config.InputStrings { |
||||
// get the model function to call for the result
|
||||
embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items, |
||||
Object: "list", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
|
||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
initialMessage := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
responses <- initialMessage |
||||
|
||||
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Content: s}}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
log.Debug().Msgf("Sending goroutine: %s", s) |
||||
|
||||
responses <- resp |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
var predInput string |
||||
|
||||
mess := []string{} |
||||
for _, i := range input.Messages { |
||||
var content string |
||||
r := config.Roles[i.Role] |
||||
if r != "" { |
||||
content = fmt.Sprint(r, " ", i.Content) |
||||
} else { |
||||
content = i.Content |
||||
} |
||||
|
||||
mess = append(mess, content) |
||||
} |
||||
|
||||
predInput = strings.Join(mess, "\n") |
||||
|
||||
if input.Stream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Chat != "" { |
||||
templateFile = config.TemplateConfig.Chat |
||||
} |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{Input: predInput}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} |
||||
|
||||
if input.Stream { |
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{FinishReason: "stop"}}, |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
result, err := ComputeChoices(predInput, input, config, o.loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "chat.completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", respData) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Edit != "" { |
||||
templateFile = config.TemplateConfig.Edit |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.InputStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Instruction string |
||||
}{Input: i}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "edit", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/* |
||||
* |
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{ |
||||
"prompt": "A cute baby sea otter", |
||||
"n": 1, |
||||
"size": "512x512" |
||||
}' |
||||
|
||||
* |
||||
*/ |
||||
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
if m == "" { |
||||
m = model.StableDiffusionBackend |
||||
} |
||||
log.Debug().Msgf("Loading model: %+v", m) |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
// XXX: Only stablediffusion is supported for now
|
||||
if config.Backend == "" { |
||||
config.Backend = model.StableDiffusionBackend |
||||
} |
||||
|
||||
sizeParts := strings.Split(input.Size, "x") |
||||
if len(sizeParts) != 2 { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
width, err := strconv.Atoi(sizeParts[0]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
height, err := strconv.Atoi(sizeParts[1]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
|
||||
b64JSON := false |
||||
if input.ResponseFormat == "b64_json" { |
||||
b64JSON = true |
||||
} |
||||
|
||||
var result []Item |
||||
for _, i := range config.PromptStrings { |
||||
n := input.N |
||||
if input.N == 0 { |
||||
n = 1 |
||||
} |
||||
for j := 0; j < n; j++ { |
||||
prompts := strings.Split(i, "|") |
||||
positive_prompt := prompts[0] |
||||
negative_prompt := "" |
||||
if len(prompts) > 1 { |
||||
negative_prompt = prompts[1] |
||||
} |
||||
|
||||
mode := 0 |
||||
step := 15 |
||||
|
||||
if input.Mode != 0 { |
||||
mode = input.Mode |
||||
} |
||||
|
||||
if input.Step != 0 { |
||||
step = input.Step |
||||
} |
||||
|
||||
tempDir := "" |
||||
if !b64JSON { |
||||
tempDir = o.imageDir |
||||
} |
||||
// Create a temporary file
|
||||
outputFile, err := ioutil.TempFile(tempDir, "b64") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
outputFile.Close() |
||||
output := outputFile.Name() + ".png" |
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
baseURL := c.BaseURL() |
||||
|
||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err := fn(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
item := &Item{} |
||||
|
||||
if b64JSON { |
||||
defer os.RemoveAll(output) |
||||
data, err := os.ReadFile(output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
||||
} else { |
||||
base := filepath.Base(output) |
||||
item.URL = baseURL + "/generated-images/" + base |
||||
} |
||||
|
||||
result = append(result, *item) |
||||
} |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Data: result, |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
f, err := file.Open() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer f.Close() |
||||
|
||||
dir, err := os.MkdirTemp("", "whisper") |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer os.RemoveAll(dir) |
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename)) |
||||
dstFile, err := os.Create(dst) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil { |
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst) |
||||
|
||||
whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if whisperModel == nil { |
||||
return fmt.Errorf("could not load whisper model") |
||||
} |
||||
|
||||
w, ok := whisperModel.(whisper.Model) |
||||
if !ok { |
||||
return fmt.Errorf("loader returned non-whisper object") |
||||
} |
||||
|
||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr) |
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) |
||||
} |
||||
} |
||||
|
||||
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
models, err := loader.ListModels() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var mm map[string]interface{} = map[string]interface{}{} |
||||
|
||||
dataModels := []OpenAIModel{} |
||||
for _, m := range models { |
||||
mm[m] = nil |
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
||||
} |
||||
|
||||
for _, k := range cm.ListConfigs() { |
||||
if _, exists := mm[k]; !exists { |
||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
||||
} |
||||
} |
||||
|
||||
return c.JSON(struct { |
||||
Object string `json:"object"` |
||||
Data []OpenAIModel `json:"data"` |
||||
}{ |
||||
Object: "list", |
||||
Data: dataModels, |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,105 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||
) |
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
type APIError struct { |
||||
Code any `json:"code,omitempty"` |
||||
Message string `json:"message"` |
||||
Param *string `json:"param,omitempty"` |
||||
Type string `json:"type"` |
||||
} |
||||
|
||||
type ErrorResponse struct { |
||||
Error *APIError `json:"error,omitempty"` |
||||
} |
||||
|
||||
type OpenAIUsage struct { |
||||
PromptTokens int `json:"prompt_tokens"` |
||||
CompletionTokens int `json:"completion_tokens"` |
||||
TotalTokens int `json:"total_tokens"` |
||||
} |
||||
|
||||
type Item struct { |
||||
Embedding []float32 `json:"embedding"` |
||||
Index int `json:"index"` |
||||
Object string `json:"object,omitempty"` |
||||
|
||||
// Images
|
||||
URL string `json:"url,omitempty"` |
||||
B64JSON string `json:"b64_json,omitempty"` |
||||
} |
||||
|
||||
type OpenAIResponse struct { |
||||
Created int `json:"created,omitempty"` |
||||
Object string `json:"object,omitempty"` |
||||
ID string `json:"id,omitempty"` |
||||
Model string `json:"model,omitempty"` |
||||
Choices []Choice `json:"choices,omitempty"` |
||||
Data []Item `json:"data,omitempty"` |
||||
|
||||
Usage OpenAIUsage `json:"usage"` |
||||
} |
||||
|
||||
type Choice struct { |
||||
Index int `json:"index"` |
||||
FinishReason string `json:"finish_reason,omitempty"` |
||||
Message *Message `json:"message,omitempty"` |
||||
Delta *Message `json:"delta,omitempty"` |
||||
Text string `json:"text,omitempty"` |
||||
} |
||||
|
||||
type Message struct { |
||||
// The message role
|
||||
Role string `json:"role,omitempty" yaml:"role"` |
||||
// The message content
|
||||
Content *string `json:"content" yaml:"content"` |
||||
// A result of a function call
|
||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` |
||||
} |
||||
|
||||
type OpenAIModel struct { |
||||
ID string `json:"id"` |
||||
Object string `json:"object"` |
||||
} |
||||
|
||||
type OpenAIRequest struct { |
||||
config.PredictionOptions |
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"` |
||||
//whisper/image
|
||||
ResponseFormat string `json:"response_format"` |
||||
// image
|
||||
Size string `json:"size"` |
||||
// Prompt is read only by completion/image API calls
|
||||
Prompt interface{} `json:"prompt" yaml:"prompt"` |
||||
|
||||
// Edit endpoint
|
||||
Instruction string `json:"instruction" yaml:"instruction"` |
||||
Input interface{} `json:"input" yaml:"input"` |
||||
|
||||
Stop interface{} `json:"stop" yaml:"stop"` |
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages" yaml:"messages"` |
||||
|
||||
// A list of available functions to call
|
||||
Functions []grammar.Function `json:"functions" yaml:"functions"` |
||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||
|
||||
Stream bool `json:"stream"` |
||||
|
||||
// Image (not supported by OpenAI)
|
||||
Mode int `json:"mode"` |
||||
Step int `json:"step"` |
||||
|
||||
// A grammar to constrain the LLM output
|
||||
Grammar string `json:"grammar" yaml:"grammar"` |
||||
|
||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` |
||||
} |
@ -0,0 +1,322 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/json" |
||||
"fmt" |
||||
"strings" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"github.com/valyala/fasthttp" |
||||
) |
||||
|
||||
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
emptyMessage := "" |
||||
|
||||
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
initialMessage := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Role: "assistant", Content: &emptyMessage}}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
responses <- initialMessage |
||||
|
||||
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
|
||||
responses <- resp |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
return func(c *fiber.Ctx) error { |
||||
processFunctions := false |
||||
funcs := grammar.Functions{} |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
log.Debug().Msgf("Configuration read: %+v", config) |
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer" |
||||
noActionDescription := "use this action to answer without performing any action" |
||||
|
||||
if config.FunctionsConfig.NoActionFunctionName != "" { |
||||
noActionName = config.FunctionsConfig.NoActionFunctionName |
||||
} |
||||
if config.FunctionsConfig.NoActionDescriptionName != "" { |
||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName |
||||
} |
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() { |
||||
log.Debug().Msgf("Response needs to process functions") |
||||
|
||||
processFunctions = true |
||||
|
||||
noActionGrammar := grammar.Function{ |
||||
Name: noActionName, |
||||
Description: noActionDescription, |
||||
Parameters: map[string]interface{}{ |
||||
"properties": map[string]interface{}{ |
||||
"message": map[string]interface{}{ |
||||
"type": "string", |
||||
"description": "The message to reply the user with", |
||||
}}, |
||||
}, |
||||
} |
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...) |
||||
if !config.FunctionsConfig.DisableNoAction { |
||||
funcs = append(funcs, noActionGrammar) |
||||
} |
||||
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" { |
||||
funcs = funcs.Select(config.FunctionToCall()) |
||||
} |
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure() |
||||
config.Grammar = jsStruct.Grammar("") |
||||
} else if input.JSONFunctionGrammarObject != nil { |
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") |
||||
} |
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream && !processFunctions |
||||
|
||||
log.Debug().Msgf("Parameters: %+v", config) |
||||
|
||||
var predInput string |
||||
|
||||
mess := []string{} |
||||
for _, i := range input.Messages { |
||||
var content string |
||||
role := i.Role |
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" { |
||||
roleFn := "assistant_function_call" |
||||
r := config.Roles[roleFn] |
||||
if r != "" { |
||||
role = roleFn |
||||
} |
||||
} |
||||
r := config.Roles[role] |
||||
contentExists := i.Content != nil && *i.Content != "" |
||||
if r != "" { |
||||
if contentExists { |
||||
content = fmt.Sprint(r, " ", *i.Content) |
||||
} |
||||
if i.FunctionCall != nil { |
||||
j, err := json.Marshal(i.FunctionCall) |
||||
if err == nil { |
||||
if contentExists { |
||||
content += "\n" + fmt.Sprint(r, " ", string(j)) |
||||
} else { |
||||
content = fmt.Sprint(r, " ", string(j)) |
||||
} |
||||
} |
||||
} |
||||
} else { |
||||
if contentExists { |
||||
content = fmt.Sprint(*i.Content) |
||||
} |
||||
if i.FunctionCall != nil { |
||||
j, err := json.Marshal(i.FunctionCall) |
||||
if err == nil { |
||||
if contentExists { |
||||
content += "\n" + string(j) |
||||
} else { |
||||
content = string(j) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
mess = append(mess, content) |
||||
} |
||||
|
||||
predInput = strings.Join(mess, "\n") |
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput) |
||||
|
||||
if toStream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions { |
||||
templateFile = config.TemplateConfig.Chat |
||||
} |
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions { |
||||
templateFile = config.TemplateConfig.Functions |
||||
} |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Functions []grammar.Function |
||||
}{ |
||||
Input: predInput, |
||||
Functions: funcs, |
||||
}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} else { |
||||
log.Debug().Msgf("Template failed loading: %s", err.Error()) |
||||
} |
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput) |
||||
if processFunctions { |
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar) |
||||
} |
||||
|
||||
if toStream { |
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.Loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
FinishReason: "stop", |
||||
Index: 0, |
||||
Delta: &Message{Content: &emptyMessage}, |
||||
}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
if processFunctions { |
||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||
ss := map[string]interface{}{} |
||||
json.Unmarshal([]byte(s), &ss) |
||||
log.Debug().Msgf("Function return: %s %+v", s, ss) |
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name := ss["function"] |
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
d, _ := json.Marshal(args) |
||||
|
||||
ss["arguments"] = string(d) |
||||
ss["name"] = func_name |
||||
|
||||
// if do nothing, reply with a message
|
||||
if func_name == noActionName { |
||||
log.Debug().Msgf("nothing to do, computing a reply") |
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{} |
||||
json.Unmarshal([]byte(d), &arguments) |
||||
m, exists := arguments["message"] |
||||
if exists { |
||||
switch message := m.(type) { |
||||
case string: |
||||
if message != "" { |
||||
log.Debug().Msgf("Reply received from LLM: %s", message) |
||||
message = backend.Finetune(*config, predInput, message) |
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) |
||||
|
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply") |
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in term of CPU) another computation
|
||||
config.Grammar = "" |
||||
predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil) |
||||
if err != nil { |
||||
log.Error().Msgf("inference error: %s", err.Error()) |
||||
return |
||||
} |
||||
|
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
log.Error().Msgf("inference error: %s", err.Error()) |
||||
return |
||||
} |
||||
|
||||
prediction = backend.Finetune(*config, predInput, prediction) |
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) |
||||
} else { |
||||
// otherwise reply with the function call
|
||||
*c = append(*c, Choice{ |
||||
FinishReason: "function_call", |
||||
Message: &Message{Role: "assistant", FunctionCall: ss}, |
||||
}) |
||||
} |
||||
|
||||
return |
||||
} |
||||
*c = append(*c, Choice{FinishReason: "stop", Index: 0, Message: &Message{Role: "assistant", Content: &s}}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "chat.completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", respData) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,159 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"github.com/valyala/fasthttp" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
Index: 0, |
||||
Text: s, |
||||
}, |
||||
}, |
||||
Object: "text_completion", |
||||
} |
||||
log.Debug().Msgf("Sending goroutine: %s", s) |
||||
|
||||
responses <- resp |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
|
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("`input`: %+v", input) |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
if input.Stream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Completion != "" { |
||||
templateFile = config.TemplateConfig.Completion |
||||
} |
||||
|
||||
if input.Stream { |
||||
if len(config.PromptStrings) > 1 { |
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") |
||||
} |
||||
|
||||
predInput := config.PromptStrings[0] |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{ |
||||
Input: predInput, |
||||
}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} |
||||
|
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.Loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
Index: 0, |
||||
FinishReason: "stop", |
||||
}, |
||||
}, |
||||
Object: "text_completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
var result []Choice |
||||
for k, i := range config.PromptStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{ |
||||
Input: i, |
||||
}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s, FinishReason: "stop", Index: k}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "text_completion", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,67 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Edit != "" { |
||||
templateFile = config.TemplateConfig.Edit |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.InputStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Instruction string |
||||
}{Input: i}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "edit", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,70 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
items := []Item{} |
||||
|
||||
for i, s := range config.InputToken { |
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
for i, s := range config.InputStrings { |
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items, |
||||
Object: "list", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,158 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/* |
||||
* |
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{ |
||||
"prompt": "A cute baby sea otter", |
||||
"n": 1, |
||||
"size": "512x512" |
||||
}' |
||||
|
||||
* |
||||
*/ |
||||
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.Loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
if m == "" { |
||||
m = model.StableDiffusionBackend |
||||
} |
||||
log.Debug().Msgf("Loading model: %+v", m) |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
// XXX: Only stablediffusion is supported for now
|
||||
if config.Backend == "" { |
||||
config.Backend = model.StableDiffusionBackend |
||||
} |
||||
|
||||
sizeParts := strings.Split(input.Size, "x") |
||||
if len(sizeParts) != 2 { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
width, err := strconv.Atoi(sizeParts[0]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
height, err := strconv.Atoi(sizeParts[1]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
|
||||
b64JSON := false |
||||
if input.ResponseFormat == "b64_json" { |
||||
b64JSON = true |
||||
} |
||||
|
||||
var result []Item |
||||
for _, i := range config.PromptStrings { |
||||
n := input.N |
||||
if input.N == 0 { |
||||
n = 1 |
||||
} |
||||
for j := 0; j < n; j++ { |
||||
prompts := strings.Split(i, "|") |
||||
positive_prompt := prompts[0] |
||||
negative_prompt := "" |
||||
if len(prompts) > 1 { |
||||
negative_prompt = prompts[1] |
||||
} |
||||
|
||||
mode := 0 |
||||
step := 15 |
||||
|
||||
if input.Mode != 0 { |
||||
mode = input.Mode |
||||
} |
||||
|
||||
if input.Step != 0 { |
||||
step = input.Step |
||||
} |
||||
|
||||
tempDir := "" |
||||
if !b64JSON { |
||||
tempDir = o.ImageDir |
||||
} |
||||
// Create a temporary file
|
||||
outputFile, err := ioutil.TempFile(tempDir, "b64") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
outputFile.Close() |
||||
output := outputFile.Name() + ".png" |
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
baseURL := c.BaseURL() |
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err := fn(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
item := &Item{} |
||||
|
||||
if b64JSON { |
||||
defer os.RemoveAll(output) |
||||
data, err := os.ReadFile(output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
||||
} else { |
||||
base := filepath.Base(output) |
||||
item.URL = baseURL + "/generated-images/" + base |
||||
} |
||||
|
||||
result = append(result, *item) |
||||
} |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Data: result, |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,36 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
) |
||||
|
||||
func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
||||
result := []Choice{} |
||||
|
||||
if n == 0 { |
||||
n = 1 |
||||
} |
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback) |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
for i := 0; i < n; i++ { |
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
prediction = backend.Finetune(*config, predInput, prediction) |
||||
cb(prediction, &result) |
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
} |
||||
return result, err |
||||
} |
@ -0,0 +1,37 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
) |
||||
|
||||
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
models, err := loader.ListModels() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var mm map[string]interface{} = map[string]interface{}{} |
||||
|
||||
dataModels := []OpenAIModel{} |
||||
for _, m := range models { |
||||
mm[m] = nil |
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
||||
} |
||||
|
||||
for _, k := range cm.ListConfigs() { |
||||
if _, exists := mm[k]; !exists { |
||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
||||
} |
||||
} |
||||
|
||||
return c.JSON(struct { |
||||
Object string `json:"object"` |
||||
Data []OpenAIModel `json:"data"` |
||||
}{ |
||||
Object: "list", |
||||
Data: dataModels, |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,234 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
||||
input := new(OpenAIRequest) |
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
modelFile := input.Model |
||||
|
||||
if c.Params("model") != "" { |
||||
modelFile = c.Params("model") |
||||
} |
||||
|
||||
received, _ := json.Marshal(input) |
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received)) |
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel { |
||||
models, _ := loader.ListModels() |
||||
if len(models) > 0 { |
||||
modelFile = models[0] |
||||
log.Debug().Msgf("No model specified, using: %s", modelFile) |
||||
} else { |
||||
log.Debug().Msgf("No model specified, returning error") |
||||
return "", nil, fmt.Errorf("no model specified") |
||||
} |
||||
} |
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists { |
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
||||
modelFile = bearer |
||||
} |
||||
return modelFile, input, nil |
||||
} |
||||
|
||||
func updateConfig(config *config.Config, input *OpenAIRequest) { |
||||
if input.Echo { |
||||
config.Echo = input.Echo |
||||
} |
||||
if input.TopK != 0 { |
||||
config.TopK = input.TopK |
||||
} |
||||
if input.TopP != 0 { |
||||
config.TopP = input.TopP |
||||
} |
||||
|
||||
if input.Grammar != "" { |
||||
config.Grammar = input.Grammar |
||||
} |
||||
|
||||
if input.Temperature != 0 { |
||||
config.Temperature = input.Temperature |
||||
} |
||||
|
||||
if input.Maxtokens != 0 { |
||||
config.Maxtokens = input.Maxtokens |
||||
} |
||||
|
||||
switch stop := input.Stop.(type) { |
||||
case string: |
||||
if stop != "" { |
||||
config.StopWords = append(config.StopWords, stop) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range stop { |
||||
if s, ok := pp.(string); ok { |
||||
config.StopWords = append(config.StopWords, s) |
||||
} |
||||
} |
||||
} |
||||
|
||||
if input.RepeatPenalty != 0 { |
||||
config.RepeatPenalty = input.RepeatPenalty |
||||
} |
||||
|
||||
if input.Keep != 0 { |
||||
config.Keep = input.Keep |
||||
} |
||||
|
||||
if input.Batch != 0 { |
||||
config.Batch = input.Batch |
||||
} |
||||
|
||||
if input.F16 { |
||||
config.F16 = input.F16 |
||||
} |
||||
|
||||
if input.IgnoreEOS { |
||||
config.IgnoreEOS = input.IgnoreEOS |
||||
} |
||||
|
||||
if input.Seed != 0 { |
||||
config.Seed = input.Seed |
||||
} |
||||
|
||||
if input.Mirostat != 0 { |
||||
config.Mirostat = input.Mirostat |
||||
} |
||||
|
||||
if input.MirostatETA != 0 { |
||||
config.MirostatETA = input.MirostatETA |
||||
} |
||||
|
||||
if input.MirostatTAU != 0 { |
||||
config.MirostatTAU = input.MirostatTAU |
||||
} |
||||
|
||||
if input.TypicalP != 0 { |
||||
config.TypicalP = input.TypicalP |
||||
} |
||||
|
||||
switch inputs := input.Input.(type) { |
||||
case string: |
||||
if inputs != "" { |
||||
config.InputStrings = append(config.InputStrings, inputs) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range inputs { |
||||
switch i := pp.(type) { |
||||
case string: |
||||
config.InputStrings = append(config.InputStrings, i) |
||||
case []interface{}: |
||||
tokens := []int{} |
||||
for _, ii := range i { |
||||
tokens = append(tokens, int(ii.(float64))) |
||||
} |
||||
config.InputToken = append(config.InputToken, tokens) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) { |
||||
case string: |
||||
if fnc != "" { |
||||
config.SetFunctionCallString(fnc) |
||||
} |
||||
case map[string]interface{}: |
||||
var name string |
||||
n, exists := fnc["name"] |
||||
if exists { |
||||
nn, e := n.(string) |
||||
if !e { |
||||
name = nn |
||||
} |
||||
} |
||||
config.SetFunctionCallNameString(name) |
||||
} |
||||
|
||||
switch p := input.Prompt.(type) { |
||||
case string: |
||||
config.PromptStrings = append(config.PromptStrings, p) |
||||
case []interface{}: |
||||
for _, pp := range p { |
||||
if s, ok := pp.(string); ok { |
||||
config.PromptStrings = append(config.PromptStrings, s) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) { |
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
||||
|
||||
var cfg *config.Config |
||||
|
||||
defaults := func() { |
||||
cfg = config.DefaultConfig(modelFile) |
||||
cfg.ContextSize = ctx |
||||
cfg.Threads = threads |
||||
cfg.F16 = f16 |
||||
cfg.Debug = debug |
||||
} |
||||
|
||||
cfgExisting, exists := cm.GetConfig(modelFile) |
||||
if !exists { |
||||
if _, err := os.Stat(modelConfig); err == nil { |
||||
if err := cm.LoadConfig(modelConfig); err != nil { |
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
||||
} |
||||
cfgExisting, exists = cm.GetConfig(modelFile) |
||||
if exists { |
||||
cfg = &cfgExisting |
||||
} else { |
||||
defaults() |
||||
} |
||||
} else { |
||||
defaults() |
||||
} |
||||
} else { |
||||
cfg = &cfgExisting |
||||
} |
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(cfg, input) |
||||
|
||||
// Don't allow 0 as setting
|
||||
if cfg.Threads == 0 { |
||||
if threads != 0 { |
||||
cfg.Threads = threads |
||||
} else { |
||||
cfg.Threads = 4 |
||||
} |
||||
} |
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug { |
||||
cfg.Debug = true |
||||
} |
||||
|
||||
return cfg, input, nil |
||||
} |
@ -0,0 +1,71 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"net/http" |
||||
"os" |
||||
"path" |
||||
"path/filepath" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
|
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.Loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
f, err := file.Open() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer f.Close() |
||||
|
||||
dir, err := os.MkdirTemp("", "whisper") |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer os.RemoveAll(dir) |
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename)) |
||||
dstFile, err := os.Create(dst) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil { |
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst) |
||||
|
||||
tr, err := backend.ModelTranscription(dst, input.Language, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr) |
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr) |
||||
} |
||||
} |
@ -1,121 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
) |
||||
|
||||
type Option struct { |
||||
context context.Context |
||||
configFile string |
||||
loader *model.ModelLoader |
||||
uploadLimitMB, threads, ctxSize int |
||||
f16 bool |
||||
debug, disableMessage bool |
||||
imageDir string |
||||
cors bool |
||||
preloadJSONModels string |
||||
preloadModelsFromPath string |
||||
corsAllowOrigins string |
||||
} |
||||
|
||||
type AppOption func(*Option) |
||||
|
||||
func newOptions(o ...AppOption) *Option { |
||||
opt := &Option{ |
||||
context: context.Background(), |
||||
uploadLimitMB: 15, |
||||
threads: 1, |
||||
ctxSize: 512, |
||||
debug: true, |
||||
disableMessage: true, |
||||
} |
||||
for _, oo := range o { |
||||
oo(opt) |
||||
} |
||||
return opt |
||||
} |
||||
|
||||
func WithCors(b bool) AppOption { |
||||
return func(o *Option) { |
||||
o.cors = b |
||||
} |
||||
} |
||||
|
||||
func WithCorsAllowOrigins(b string) AppOption { |
||||
return func(o *Option) { |
||||
o.corsAllowOrigins = b |
||||
} |
||||
} |
||||
|
||||
func WithContext(ctx context.Context) AppOption { |
||||
return func(o *Option) { |
||||
o.context = ctx |
||||
} |
||||
} |
||||
|
||||
func WithYAMLConfigPreload(configFile string) AppOption { |
||||
return func(o *Option) { |
||||
o.preloadModelsFromPath = configFile |
||||
} |
||||
} |
||||
|
||||
func WithJSONStringPreload(configFile string) AppOption { |
||||
return func(o *Option) { |
||||
o.preloadJSONModels = configFile |
||||
} |
||||
} |
||||
func WithConfigFile(configFile string) AppOption { |
||||
return func(o *Option) { |
||||
o.configFile = configFile |
||||
} |
||||
} |
||||
|
||||
func WithModelLoader(loader *model.ModelLoader) AppOption { |
||||
return func(o *Option) { |
||||
o.loader = loader |
||||
} |
||||
} |
||||
|
||||
func WithUploadLimitMB(limit int) AppOption { |
||||
return func(o *Option) { |
||||
o.uploadLimitMB = limit |
||||
} |
||||
} |
||||
|
||||
func WithThreads(threads int) AppOption { |
||||
return func(o *Option) { |
||||
o.threads = threads |
||||
} |
||||
} |
||||
|
||||
func WithContextSize(ctxSize int) AppOption { |
||||
return func(o *Option) { |
||||
o.ctxSize = ctxSize |
||||
} |
||||
} |
||||
|
||||
func WithF16(f16 bool) AppOption { |
||||
return func(o *Option) { |
||||
o.f16 = f16 |
||||
} |
||||
} |
||||
|
||||
func WithDebug(debug bool) AppOption { |
||||
return func(o *Option) { |
||||
o.debug = debug |
||||
} |
||||
} |
||||
|
||||
func WithDisableMessage(disableMessage bool) AppOption { |
||||
return func(o *Option) { |
||||
o.disableMessage = disableMessage |
||||
} |
||||
} |
||||
|
||||
func WithImageDir(imageDir string) AppOption { |
||||
return func(o *Option) { |
||||
o.imageDir = imageDir |
||||
} |
||||
} |
@ -0,0 +1,186 @@ |
||||
package options |
||||
|
||||
import ( |
||||
"context" |
||||
"embed" |
||||
"encoding/json" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/gallery" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
type Option struct { |
||||
Context context.Context |
||||
ConfigFile string |
||||
Loader *model.ModelLoader |
||||
UploadLimitMB, Threads, ContextSize int |
||||
F16 bool |
||||
Debug, DisableMessage bool |
||||
ImageDir string |
||||
AudioDir string |
||||
CORS bool |
||||
PreloadJSONModels string |
||||
PreloadModelsFromPath string |
||||
CORSAllowOrigins string |
||||
|
||||
Galleries []gallery.Gallery |
||||
|
||||
BackendAssets embed.FS |
||||
AssetsDestination string |
||||
|
||||
ExternalGRPCBackends map[string]string |
||||
|
||||
AutoloadGalleries bool |
||||
} |
||||
|
||||
type AppOption func(*Option) |
||||
|
||||
func NewOptions(o ...AppOption) *Option { |
||||
opt := &Option{ |
||||
Context: context.Background(), |
||||
UploadLimitMB: 15, |
||||
Threads: 1, |
||||
ContextSize: 512, |
||||
Debug: true, |
||||
DisableMessage: true, |
||||
} |
||||
for _, oo := range o { |
||||
oo(opt) |
||||
} |
||||
return opt |
||||
} |
||||
|
||||
func WithCors(b bool) AppOption { |
||||
return func(o *Option) { |
||||
o.CORS = b |
||||
} |
||||
} |
||||
|
||||
var EnableGalleriesAutoload = func(o *Option) { |
||||
o.AutoloadGalleries = true |
||||
} |
||||
|
||||
func WithExternalBackend(name string, uri string) AppOption { |
||||
return func(o *Option) { |
||||
if o.ExternalGRPCBackends == nil { |
||||
o.ExternalGRPCBackends = make(map[string]string) |
||||
} |
||||
o.ExternalGRPCBackends[name] = uri |
||||
} |
||||
} |
||||
|
||||
func WithCorsAllowOrigins(b string) AppOption { |
||||
return func(o *Option) { |
||||
o.CORSAllowOrigins = b |
||||
} |
||||
} |
||||
|
||||
func WithBackendAssetsOutput(out string) AppOption { |
||||
return func(o *Option) { |
||||
o.AssetsDestination = out |
||||
} |
||||
} |
||||
|
||||
func WithBackendAssets(f embed.FS) AppOption { |
||||
return func(o *Option) { |
||||
o.BackendAssets = f |
||||
} |
||||
} |
||||
|
||||
func WithStringGalleries(galls string) AppOption { |
||||
return func(o *Option) { |
||||
if galls == "" { |
||||
log.Debug().Msgf("no galleries to load") |
||||
return |
||||
} |
||||
var galleries []gallery.Gallery |
||||
if err := json.Unmarshal([]byte(galls), &galleries); err != nil { |
||||
log.Error().Msgf("failed loading galleries: %s", err.Error()) |
||||
} |
||||
o.Galleries = append(o.Galleries, galleries...) |
||||
} |
||||
} |
||||
|
||||
func WithGalleries(galleries []gallery.Gallery) AppOption { |
||||
return func(o *Option) { |
||||
o.Galleries = append(o.Galleries, galleries...) |
||||
} |
||||
} |
||||
|
||||
func WithContext(ctx context.Context) AppOption { |
||||
return func(o *Option) { |
||||
o.Context = ctx |
||||
} |
||||
} |
||||
|
||||
func WithYAMLConfigPreload(configFile string) AppOption { |
||||
return func(o *Option) { |
||||
o.PreloadModelsFromPath = configFile |
||||
} |
||||
} |
||||
|
||||
func WithJSONStringPreload(configFile string) AppOption { |
||||
return func(o *Option) { |
||||
o.PreloadJSONModels = configFile |
||||
} |
||||
} |
||||
func WithConfigFile(configFile string) AppOption { |
||||
return func(o *Option) { |
||||
o.ConfigFile = configFile |
||||
} |
||||
} |
||||
|
||||
func WithModelLoader(loader *model.ModelLoader) AppOption { |
||||
return func(o *Option) { |
||||
o.Loader = loader |
||||
} |
||||
} |
||||
|
||||
func WithUploadLimitMB(limit int) AppOption { |
||||
return func(o *Option) { |
||||
o.UploadLimitMB = limit |
||||
} |
||||
} |
||||
|
||||
func WithThreads(threads int) AppOption { |
||||
return func(o *Option) { |
||||
o.Threads = threads |
||||
} |
||||
} |
||||
|
||||
func WithContextSize(ctxSize int) AppOption { |
||||
return func(o *Option) { |
||||
o.ContextSize = ctxSize |
||||
} |
||||
} |
||||
|
||||
func WithF16(f16 bool) AppOption { |
||||
return func(o *Option) { |
||||
o.F16 = f16 |
||||
} |
||||
} |
||||
|
||||
func WithDebug(debug bool) AppOption { |
||||
return func(o *Option) { |
||||
o.Debug = debug |
||||
} |
||||
} |
||||
|
||||
func WithDisableMessage(disableMessage bool) AppOption { |
||||
return func(o *Option) { |
||||
o.DisableMessage = disableMessage |
||||
} |
||||
} |
||||
|
||||
func WithAudioDir(audioDir string) AppOption { |
||||
return func(o *Option) { |
||||
o.AudioDir = audioDir |
||||
} |
||||
} |
||||
|
||||
func WithImageDir(imageDir string) AppOption { |
||||
return func(o *Option) { |
||||
o.ImageDir = imageDir |
||||
} |
||||
} |
@ -1,574 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"regexp" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"github.com/donomii/go-rwkv.cpp" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
||||
"github.com/go-skynet/bloomz.cpp" |
||||
bert "github.com/go-skynet/go-bert.cpp" |
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
llama "github.com/go-skynet/go-llama.cpp" |
||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" |
||||
) |
||||
|
||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutexMap sync.Mutex |
||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
||||
|
||||
func defaultLLamaOpts(c Config) []llama.ModelOption { |
||||
llamaOpts := []llama.ModelOption{} |
||||
if c.ContextSize != 0 { |
||||
llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize)) |
||||
} |
||||
if c.F16 { |
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory) |
||||
} |
||||
if c.Embeddings { |
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings) |
||||
} |
||||
|
||||
if c.NGPULayers != 0 { |
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers)) |
||||
} |
||||
|
||||
return llamaOpts |
||||
} |
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config) (func() error, error) { |
||||
if c.Backend != model.StableDiffusionBackend { |
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
||||
} |
||||
inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() error |
||||
switch model := inferenceModel.(type) { |
||||
case *stablediffusion.StableDiffusion: |
||||
fn = func() error { |
||||
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) |
||||
} |
||||
|
||||
default: |
||||
fn = func() error { |
||||
return fmt.Errorf("creation of images not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() error { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[c.Backend] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[c.Backend] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
return fn() |
||||
}, nil |
||||
} |
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config) (func() ([]float32, error), error) { |
||||
if !c.Embeddings { |
||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
||||
} |
||||
|
||||
modelFile := c.Model |
||||
|
||||
llamaOpts := defaultLLamaOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads)) |
||||
} else { |
||||
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads)) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() ([]float32, error) |
||||
switch model := inferenceModel.(type) { |
||||
case *llama.LLama: |
||||
fn = func() ([]float32, error) { |
||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) |
||||
if len(tokens) > 0 { |
||||
return model.TokenEmbeddings(tokens, predictOptions...) |
||||
} |
||||
return model.Embeddings(s, predictOptions...) |
||||
} |
||||
// bert embeddings
|
||||
case *bert.Bert: |
||||
fn = func() ([]float32, error) { |
||||
if len(tokens) > 0 { |
||||
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) |
||||
} |
||||
return model.Embeddings(s, bert.SetThreads(c.Threads)) |
||||
} |
||||
default: |
||||
fn = func() ([]float32, error) { |
||||
return nil, fmt.Errorf("embeddings not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() ([]float32, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[modelFile] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[modelFile] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
embeds, err := fn() |
||||
if err != nil { |
||||
return embeds, err |
||||
} |
||||
// Remove trailing 0s
|
||||
for i := len(embeds) - 1; i >= 0; i-- { |
||||
if embeds[i] == 0.0 { |
||||
embeds = embeds[:i] |
||||
} else { |
||||
break |
||||
} |
||||
} |
||||
return embeds, nil |
||||
}, nil |
||||
} |
||||
|
||||
func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []llama.PredictOption{ |
||||
llama.SetTemperature(c.Temperature), |
||||
llama.SetTopP(c.TopP), |
||||
llama.SetTopK(c.TopK), |
||||
llama.SetTokens(c.Maxtokens), |
||||
llama.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.PromptCacheAll { |
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll) |
||||
} |
||||
|
||||
if c.PromptCachePath != "" { |
||||
// Create parent directory
|
||||
p := filepath.Join(modelPath, c.PromptCachePath) |
||||
os.MkdirAll(filepath.Dir(p), 0755) |
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(p)) |
||||
} |
||||
|
||||
if c.Mirostat != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) |
||||
} |
||||
|
||||
if c.MirostatETA != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) |
||||
} |
||||
|
||||
if c.MirostatTAU != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) |
||||
} |
||||
|
||||
if c.Debug { |
||||
predictOptions = append(predictOptions, llama.Debug) |
||||
} |
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) |
||||
|
||||
if c.RepeatPenalty != 0 { |
||||
predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) |
||||
} |
||||
|
||||
if c.Keep != 0 { |
||||
predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.F16 { |
||||
predictOptions = append(predictOptions, llama.EnableF16KV) |
||||
} |
||||
|
||||
if c.IgnoreEOS { |
||||
predictOptions = append(predictOptions, llama.IgnoreEOS) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return predictOptions |
||||
} |
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback func(string) bool) (func() (string, error), error) { |
||||
supportStreams := false |
||||
modelFile := c.Model |
||||
|
||||
llamaOpts := defaultLLamaOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads)) |
||||
} else { |
||||
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads)) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() (string, error) |
||||
|
||||
switch model := inferenceModel.(type) { |
||||
case *rwkv.RwkvState: |
||||
supportStreams = true |
||||
|
||||
fn = func() (string, error) { |
||||
stopWord := "\n" |
||||
if len(c.StopWords) > 0 { |
||||
stopWord = c.StopWords[0] |
||||
} |
||||
|
||||
if err := model.ProcessInput(s); err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) |
||||
|
||||
return response, nil |
||||
} |
||||
case *transformers.GPTNeoX: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.Replit: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.Starcoder: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.MPT: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *bloomz.Bloomz: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []bloomz.PredictOption{ |
||||
bloomz.SetTemperature(c.Temperature), |
||||
bloomz.SetTopP(c.TopP), |
||||
bloomz.SetTopK(c.TopK), |
||||
bloomz.SetTokens(c.Maxtokens), |
||||
bloomz.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.GPTJ: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.Dolly: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.GPT2: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *gpt4all.Model: |
||||
supportStreams = true |
||||
|
||||
fn = func() (string, error) { |
||||
if tokenCallback != nil { |
||||
model.SetTokenCallback(tokenCallback) |
||||
} |
||||
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []gpt4all.PredictOption{ |
||||
gpt4all.SetTemperature(c.Temperature), |
||||
gpt4all.SetTopP(c.TopP), |
||||
gpt4all.SetTopK(c.TopK), |
||||
gpt4all.SetTokens(c.Maxtokens), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
str, er := model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
||||
// after a stream event has occurred
|
||||
model.SetTokenCallback(nil) |
||||
return str, er |
||||
} |
||||
case *llama.LLama: |
||||
supportStreams = true |
||||
fn = func() (string, error) { |
||||
|
||||
if tokenCallback != nil { |
||||
model.SetTokenCallback(tokenCallback) |
||||
} |
||||
|
||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) |
||||
|
||||
str, er := model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
||||
// after a stream event has occurred
|
||||
model.SetTokenCallback(nil) |
||||
return str, er |
||||
} |
||||
} |
||||
|
||||
return func() (string, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[modelFile] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[modelFile] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
res, err := fn() |
||||
if tokenCallback != nil && !supportStreams { |
||||
tokenCallback(res) |
||||
} |
||||
return res, err |
||||
}, nil |
||||
} |
||||
|
||||
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
||||
result := []Choice{} |
||||
|
||||
n := input.N |
||||
|
||||
if input.N == 0 { |
||||
n = 1 |
||||
} |
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := ModelInference(predInput, loader, *config, tokenCallback) |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
for i := 0; i < n; i++ { |
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
prediction = Finetune(*config, predInput, prediction) |
||||
cb(prediction, &result) |
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
} |
||||
return result, err |
||||
} |
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
||||
var mu sync.Mutex = sync.Mutex{} |
||||
|
||||
func Finetune(config Config, input, prediction string) string { |
||||
if config.Echo { |
||||
prediction = input + prediction |
||||
} |
||||
|
||||
for _, c := range config.Cutstrings { |
||||
mu.Lock() |
||||
reg, ok := cutstrings[c] |
||||
if !ok { |
||||
cutstrings[c] = regexp.MustCompile(c) |
||||
reg = cutstrings[c] |
||||
} |
||||
mu.Unlock() |
||||
prediction = reg.ReplaceAllString(prediction, "") |
||||
} |
||||
|
||||
for _, c := range config.TrimSpace { |
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
||||
} |
||||
return prediction |
||||
|
||||
} |
@ -0,0 +1,6 @@ |
||||
package main |
||||
|
||||
import "embed" |
||||
|
||||
//go:embed backend-assets/*
|
||||
var backendAssets embed.FS |
@ -0,0 +1,22 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
bert "github.com/go-skynet/LocalAI/pkg/grpc/llm/bert" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &bert.Embeddings{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
bloomz "github.com/go-skynet/LocalAI/pkg/grpc/llm/bloomz" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &bloomz.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Dolly{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Falcon{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,25 @@ |
||||
package main |
||||
|
||||
// GRPC Falcon server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
falcon "github.com/go-skynet/LocalAI/pkg/grpc/llm/falcon" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &falcon.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPT2{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
gpt4all "github.com/go-skynet/LocalAI/pkg/grpc/llm/gpt4all" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &gpt4all.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPTJ{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPTNeoX{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
langchain "github.com/go-skynet/LocalAI/pkg/grpc/llm/langchain" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &langchain.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,25 @@ |
||||
package main |
||||
|
||||
// GRPC Falcon server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama-grammar" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,25 @@ |
||||
package main |
||||
|
||||
// GRPC Falcon server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.MPT{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
tts "github.com/go-skynet/LocalAI/pkg/grpc/tts" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &tts.Piper{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Replit{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
rwkv "github.com/go-skynet/LocalAI/pkg/grpc/llm/rwkv" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &rwkv.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
image "github.com/go-skynet/LocalAI/pkg/grpc/image" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &image.StableDiffusion{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Starcoder{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transcribe "github.com/go-skynet/LocalAI/pkg/grpc/transcribe" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transcribe.Whisper{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,48 @@ |
||||
# chatbot-ui |
||||
|
||||
Example of integration with [mckaywrigley/chatbot-ui](https://github.com/mckaywrigley/chatbot-ui). |
||||
|
||||
![Screenshot from 2023-04-26 23-59-55](https://user-images.githubusercontent.com/2420543/234715439-98d12e03-d3ce-4f94-ab54-2b256808e05e.png) |
||||
|
||||
## Setup |
||||
|
||||
```bash |
||||
# Clone LocalAI |
||||
git clone https://github.com/go-skynet/LocalAI |
||||
|
||||
cd LocalAI/examples/chatbot-ui |
||||
|
||||
# (optional) Checkout a specific LocalAI tag |
||||
# git checkout -b build <TAG> |
||||
|
||||
# Download gpt4all-j to models/ |
||||
wget https://gpt4all.io/models/ggml-gpt4all-j.bin -O models/ggml-gpt4all-j |
||||
|
||||
# start with docker-compose |
||||
docker-compose up -d --pull always |
||||
# or you can build the images with: |
||||
# docker-compose up -d --build |
||||
``` |
||||
|
||||
## Pointing chatbot-ui to a separately managed LocalAI service |
||||
|
||||
If you want to use the [chatbot-ui example](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui) with an externally managed LocalAI service, you can alter the `docker-compose` file so that it looks like the below. You will notice the file is smaller, because we have removed the section that would normally start the LocalAI service. Take care to update the IP address (or FQDN) that the chatbot-ui service tries to access (marked `<<LOCALAI_IP>>` below): |
||||
``` |
||||
version: '3.6' |
||||
|
||||
services: |
||||
chatgpt: |
||||
image: ghcr.io/mckaywrigley/chatbot-ui:main |
||||
ports: |
||||
- 3000:3000 |
||||
environment: |
||||
- 'OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXX' |
||||
- 'OPENAI_API_HOST=http://<<LOCALAI_IP>>:8080' |
||||
``` |
||||
|
||||
Once you've edited the Dockerfile, you can start it with `docker compose up`, then browse to `http://localhost:3000`. |
||||
|
||||
## Accessing chatbot-ui |
||||
|
||||
Open http://localhost:3000 for the Web UI. |
||||
|
@ -0,0 +1,24 @@ |
||||
version: '3.6' |
||||
|
||||
services: |
||||
api: |
||||
image: quay.io/go-skynet/local-ai:latest |
||||
build: |
||||
context: ../../ |
||||
dockerfile: Dockerfile |
||||
ports: |
||||
- 8080:8080 |
||||
environment: |
||||
- DEBUG=true |
||||
- MODELS_PATH=/models |
||||
volumes: |
||||
- ./models:/models:cached |
||||
command: ["/usr/bin/local-ai" ] |
||||
|
||||
chatgpt: |
||||
image: ghcr.io/mckaywrigley/chatbot-ui:main |
||||
ports: |
||||
- 3000:3000 |
||||
environment: |
||||
- 'OPENAI_API_KEY=sk-XXXXXXXXXXXXXXXXXXXX' |
||||
- 'OPENAI_API_HOST=http://api:8080' |
@ -0,0 +1,30 @@ |
||||
# flowise |
||||
|
||||
Example of integration with [FlowiseAI/Flowise](https://github.com/FlowiseAI/Flowise). |
||||
|
||||
![Screenshot from 2023-05-30 18-01-03](https://github.com/go-skynet/LocalAI/assets/2420543/02458782-0549-4131-971c-95ee56ec1af8) |
||||
|
||||
You can check a demo video in the Flowise PR: https://github.com/FlowiseAI/Flowise/pull/123 |
||||
|
||||
## Run |
||||
|
||||
In this example LocalAI will download the gpt4all model and set it up as "gpt-3.5-turbo". See the `docker-compose.yaml` |
||||
```bash |
||||
# Clone LocalAI |
||||
git clone https://github.com/go-skynet/LocalAI |
||||
|
||||
cd LocalAI/examples/flowise |
||||
|
||||
# start with docker-compose |
||||
docker-compose up --pull always |
||||
|
||||
``` |
||||
|
||||
## Accessing flowise |
||||
|
||||
Open http://localhost:3000. |
||||
|
||||
## Using LocalAI |
||||
|
||||
Search for LocalAI in the integration, and use the `http://api:8080/` as URL. |
||||
|
@ -0,0 +1,37 @@ |
||||
version: '3.6' |
||||
|
||||
services: |
||||
api: |
||||
image: quay.io/go-skynet/local-ai:latest |
||||
# As initially LocalAI will download the models defined in PRELOAD_MODELS |
||||
# you might need to tweak the healthcheck values here according to your network connection. |
||||
# Here we give a timespan of 20m to download all the required files. |
||||
healthcheck: |
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/readyz"] |
||||
interval: 1m |
||||
timeout: 20m |
||||
retries: 20 |
||||
build: |
||||
context: ../../ |
||||
dockerfile: Dockerfile |
||||
ports: |
||||
- 8080:8080 |
||||
environment: |
||||
- DEBUG=true |
||||
- MODELS_PATH=/models |
||||
# You can preload different models here as well. |
||||
# See: https://github.com/go-skynet/model-gallery |
||||
- 'PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/gpt4all-j.yaml", "name": "gpt-3.5-turbo"}]' |
||||
volumes: |
||||
- ./models:/models:cached |
||||
command: ["/usr/bin/local-ai" ] |
||||
flowise: |
||||
depends_on: |
||||
api: |
||||
condition: service_healthy |
||||
image: flowiseai/flowise |
||||
ports: |
||||
- 3000:3000 |
||||
volumes: |
||||
- ~/.flowise:/root/.flowise |
||||
command: /bin/sh -c "sleep 3; flowise start" |
@ -0,0 +1,9 @@ |
||||
OPENAI_API_KEY=sk---anystringhere |
||||
OPENAI_API_BASE=http://api:8080/v1 |
||||
# Models to preload at start |
||||
# Here we configure gpt4all as gpt-3.5-turbo and bert as embeddings |
||||
PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/openllama-7b-open-instruct.yaml", "name": "gpt-3.5-turbo"}] |
||||
|
||||
## Change the default number of threads |
||||
#THREADS=14 |
||||
|
@ -0,0 +1,5 @@ |
||||
FROM python:3.10-bullseye |
||||
COPY . /app |
||||
WORKDIR /app |
||||
RUN pip install --no-cache-dir -r requirements.txt |
||||
ENTRYPOINT [ "python", "./functions-openai.py" ]; |
@ -0,0 +1,18 @@ |
||||
# LocalAI functions |
||||
|
||||
Example of using LocalAI functions, see the [OpenAI](https://openai.com/blog/function-calling-and-other-api-updates) blog post. |
||||
|
||||
## Run |
||||
|
||||
```bash |
||||
# Clone LocalAI |
||||
git clone https://github.com/go-skynet/LocalAI |
||||
|
||||
cd LocalAI/examples/functions |
||||
|
||||
docker-compose run --rm functions |
||||
``` |
||||
|
||||
Note: The example automatically downloads the `openllama` model as it is under a permissive license. |
||||
|
||||
See the `.env` configuration file to set a different model with the [model-gallery](https://github.com/go-skynet/model-gallery) by editing `PRELOAD_MODELS`. |
@ -0,0 +1,23 @@ |
||||
version: "3.9" |
||||
services: |
||||
api: |
||||
image: quay.io/go-skynet/local-ai:master |
||||
ports: |
||||
- 8080:8080 |
||||
env_file: |
||||
- .env |
||||
environment: |
||||
- DEBUG=true |
||||
- MODELS_PATH=/models |
||||
volumes: |
||||
- ./models:/models:cached |
||||
command: ["/usr/bin/local-ai" ] |
||||
functions: |
||||
build: |
||||
context: . |
||||
dockerfile: Dockerfile |
||||
depends_on: |
||||
api: |
||||
condition: service_healthy |
||||
env_file: |
||||
- .env |
@ -0,0 +1,76 @@ |
||||
import openai |
||||
import json |
||||
|
||||
# Example dummy function hard coded to return the same weather |
||||
# In production, this could be your backend API or an external API |
||||
def get_current_weather(location, unit="fahrenheit"): |
||||
"""Get the current weather in a given location""" |
||||
weather_info = { |
||||
"location": location, |
||||
"temperature": "72", |
||||
"unit": unit, |
||||
"forecast": ["sunny", "windy"], |
||||
} |
||||
return json.dumps(weather_info) |
||||
|
||||
|
||||
def run_conversation(): |
||||
# Step 1: send the conversation and available functions to GPT |
||||
messages = [{"role": "user", "content": "What's the weather like in Boston?"}] |
||||
functions = [ |
||||
{ |
||||
"name": "get_current_weather", |
||||
"description": "Get the current weather in a given location", |
||||
"parameters": { |
||||
"type": "object", |
||||
"properties": { |
||||
"location": { |
||||
"type": "string", |
||||
"description": "The city and state, e.g. San Francisco, CA", |
||||
}, |
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, |
||||
}, |
||||
"required": ["location"], |
||||
}, |
||||
} |
||||
] |
||||
response = openai.ChatCompletion.create( |
||||
model="gpt-3.5-turbo", |
||||
messages=messages, |
||||
functions=functions, |
||||
function_call="auto", # auto is default, but we'll be explicit |
||||
) |
||||
response_message = response["choices"][0]["message"] |
||||
|
||||
# Step 2: check if GPT wanted to call a function |
||||
if response_message.get("function_call"): |
||||
# Step 3: call the function |
||||
# Note: the JSON response may not always be valid; be sure to handle errors |
||||
available_functions = { |
||||
"get_current_weather": get_current_weather, |
||||
} # only one function in this example, but you can have multiple |
||||
function_name = response_message["function_call"]["name"] |
||||
fuction_to_call = available_functions[function_name] |
||||
function_args = json.loads(response_message["function_call"]["arguments"]) |
||||
function_response = fuction_to_call( |
||||
location=function_args.get("location"), |
||||
unit=function_args.get("unit"), |
||||
) |
||||
|
||||
# Step 4: send the info on the function call and function response to GPT |
||||
messages.append(response_message) # extend conversation with assistant's reply |
||||
messages.append( |
||||
{ |
||||
"role": "function", |
||||
"name": function_name, |
||||
"content": function_response, |
||||
} |
||||
) # extend conversation with function response |
||||
second_response = openai.ChatCompletion.create( |
||||
model="gpt-3.5-turbo", |
||||
messages=messages, |
||||
) # get a new response from GPT where it can see the function response |
||||
return second_response |
||||
|
||||
|
||||
print(run_conversation()) |
@ -0,0 +1,2 @@ |
||||
langchain==0.0.234 |
||||
openai==0.27.8 |
@ -0,0 +1,70 @@ |
||||
# k8sgpt example |
||||
|
||||
This example show how to use LocalAI with k8sgpt |
||||
|
||||
![Screenshot from 2023-06-19 23-58-47](https://github.com/go-skynet/go-ggml-transformers.cpp/assets/2420543/cab87409-ee68-44ae-8d53-41627fb49509) |
||||
|
||||
## Create the cluster locally with Kind (optional) |
||||
|
||||
If you want to test this locally without a remote Kubernetes cluster, you can use kind. |
||||
|
||||
Install [kind](https://kind.sigs.k8s.io/) and create a cluster: |
||||
|
||||
``` |
||||
kind create cluster |
||||
``` |
||||
|
||||
## Setup LocalAI |
||||
|
||||
We will use [helm](https://helm.sh/docs/intro/install/): |
||||
|
||||
``` |
||||
helm repo add go-skynet https://go-skynet.github.io/helm-charts/ |
||||
helm repo update |
||||
|
||||
# Clone LocalAI |
||||
git clone https://github.com/go-skynet/LocalAI |
||||
|
||||
cd LocalAI/examples/k8sgpt |
||||
|
||||
# modify values.yaml preload_models with the models you want to install. |
||||
# CHANGE the URL to a model in huggingface. |
||||
helm install local-ai go-skynet/local-ai --create-namespace --namespace local-ai --values values.yaml |
||||
``` |
||||
|
||||
## Setup K8sGPT |
||||
|
||||
``` |
||||
# Install k8sgpt |
||||
helm repo add k8sgpt https://charts.k8sgpt.ai/ |
||||
helm repo update |
||||
helm install release k8sgpt/k8sgpt-operator -n k8sgpt-operator-system --create-namespace |
||||
``` |
||||
|
||||
Apply the k8sgpt-operator configuration: |
||||
|
||||
``` |
||||
kubectl apply -f - << EOF |
||||
apiVersion: core.k8sgpt.ai/v1alpha1 |
||||
kind: K8sGPT |
||||
metadata: |
||||
name: k8sgpt-local-ai |
||||
namespace: default |
||||
spec: |
||||
backend: localai |
||||
baseUrl: http://local-ai.local-ai.svc.cluster.local:8080/v1 |
||||
noCache: false |
||||
model: gpt-3.5-turbo |
||||
noCache: false |
||||
version: v0.3.0 |
||||
enableAI: true |
||||
EOF |
||||
``` |
||||
|
||||
## Test |
||||
|
||||
Apply a broken pod: |
||||
|
||||
``` |
||||
kubectl apply -f broken-pod.yaml |
||||
``` |
@ -0,0 +1,14 @@ |
||||
apiVersion: v1 |
||||
kind: Pod |
||||
metadata: |
||||
name: broken-pod |
||||
spec: |
||||
containers: |
||||
- name: broken-pod |
||||
image: nginx:1.a.b.c |
||||
livenessProbe: |
||||
httpGet: |
||||
path: / |
||||
port: 90 |
||||
initialDelaySeconds: 3 |
||||
periodSeconds: 3 |
@ -0,0 +1,95 @@ |
||||
replicaCount: 1 |
||||
|
||||
deployment: |
||||
# https://quay.io/repository/go-skynet/local-ai?tab=tags |
||||
image: quay.io/go-skynet/local-ai:latest |
||||
env: |
||||
threads: 4 |
||||
debug: "true" |
||||
context_size: 512 |
||||
preload_models: '[{ "url": "github:go-skynet/model-gallery/wizard.yaml", "name": "gpt-3.5-turbo", "overrides": { "parameters": { "model": "WizardLM-7B-uncensored.ggmlv3.q5_1" }},"files": [ { "uri": "https://huggingface.co//WizardLM-7B-uncensored-GGML/resolve/main/WizardLM-7B-uncensored.ggmlv3.q5_1.bin", "sha256": "d92a509d83a8ea5e08ba4c2dbaf08f29015932dc2accd627ce0665ac72c2bb2b", "filename": "WizardLM-7B-uncensored.ggmlv3.q5_1" }]}]' |
||||
modelsPath: "/models" |
||||
|
||||
resources: |
||||
{} |
||||
# We usually recommend not to specify default resources and to leave this as a conscious |
||||
# choice for the user. This also increases chances charts run on environments with little |
||||
# resources, such as Minikube. If you do want to specify resources, uncomment the following |
||||
# lines, adjust them as necessary, and remove the curly braces after 'resources:'. |
||||
# limits: |
||||
# cpu: 100m |
||||
# memory: 128Mi |
||||
# requests: |
||||
# cpu: 100m |
||||
# memory: 128Mi |
||||
|
||||
# Prompt templates to include |
||||
# Note: the keys of this map will be the names of the prompt template files |
||||
promptTemplates: |
||||
{} |
||||
# ggml-gpt4all-j.tmpl: | |
||||
# The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. |
||||
# ### Prompt: |
||||
# {{.Input}} |
||||
# ### Response: |
||||
|
||||
# Models to download at runtime |
||||
models: |
||||
# Whether to force download models even if they already exist |
||||
forceDownload: false |
||||
|
||||
# The list of URLs to download models from |
||||
# Note: the name of the file will be the name of the loaded model |
||||
list: |
||||
#- url: "https://gpt4all.io/models/ggml-gpt4all-j.bin" |
||||
# basicAuth: base64EncodedCredentials |
||||
|
||||
# Persistent storage for models and prompt templates. |
||||
# PVC and HostPath are mutually exclusive. If both are enabled, |
||||
# PVC configuration takes precedence. If neither are enabled, ephemeral |
||||
# storage is used. |
||||
persistence: |
||||
pvc: |
||||
enabled: false |
||||
size: 6Gi |
||||
accessModes: |
||||
- ReadWriteOnce |
||||
|
||||
annotations: {} |
||||
|
||||
# Optional |
||||
storageClass: ~ |
||||
|
||||
hostPath: |
||||
enabled: false |
||||
path: "/models" |
||||
|
||||
service: |
||||
type: ClusterIP |
||||
port: 8080 |
||||
annotations: {} |
||||
# If using an AWS load balancer, you'll need to override the default 60s load balancer idle timeout |
||||
# service.beta.kubernetes.io/aws-load-balancer-connection-idle-timeout: "1200" |
||||
|
||||
ingress: |
||||
enabled: false |
||||
className: "" |
||||
annotations: |
||||
{} |
||||
# kubernetes.io/ingress.class: nginx |
||||
# kubernetes.io/tls-acme: "true" |
||||
hosts: |
||||
- host: chart-example.local |
||||
paths: |
||||
- path: / |
||||
pathType: ImplementationSpecific |
||||
tls: [] |
||||
# - secretName: chart-example-tls |
||||
# hosts: |
||||
# - chart-example.local |
||||
|
||||
nodeSelector: {} |
||||
|
||||
tolerations: [] |
||||
|
||||
affinity: {} |
@ -0,0 +1,68 @@ |
||||
# Data query example |
||||
|
||||
Example of integration with HuggingFace Inference API with help of [langchaingo](https://github.com/tmc/langchaingo). |
||||
|
||||
## Setup |
||||
|
||||
Download the LocalAI and start the API: |
||||
|
||||
```bash |
||||
# Clone LocalAI |
||||
git clone https://github.com/go-skynet/LocalAI |
||||
|
||||
cd LocalAI/examples/langchain-huggingface |
||||
|
||||
docker-compose up -d |
||||
``` |
||||
|
||||
Node: Ensure you've set `HUGGINGFACEHUB_API_TOKEN` environment variable, you can generate it |
||||
on [Settings / Access Tokens](https://huggingface.co/settings/tokens) page of HuggingFace site. |
||||
|
||||
This is an example `.env` file for LocalAI: |
||||
|
||||
```ini |
||||
MODELS_PATH=/models |
||||
CONTEXT_SIZE=512 |
||||
HUGGINGFACEHUB_API_TOKEN=hg_123456 |
||||
``` |
||||
|
||||
## Using remote models |
||||
|
||||
Now you can use any remote models available via HuggingFace API, for example let's enable using of |
||||
[gpt2](https://huggingface.co/gpt2) model in `gpt-3.5-turbo.yaml` config: |
||||
|
||||
```yml |
||||
name: gpt-3.5-turbo |
||||
parameters: |
||||
model: gpt2 |
||||
top_k: 80 |
||||
temperature: 0.2 |
||||
top_p: 0.7 |
||||
context_size: 1024 |
||||
backend: "langchain-huggingface" |
||||
stopwords: |
||||
- "HUMAN:" |
||||
- "GPT:" |
||||
roles: |
||||
user: " " |
||||
system: " " |
||||
template: |
||||
completion: completion |
||||
chat: gpt4all |
||||
``` |
||||
|
||||
Here is you can see in field `parameters.model` equal `gpt2` and `backend` equal `langchain-huggingface`. |
||||
|
||||
## How to use |
||||
|
||||
```shell |
||||
# Now API is accessible at localhost:8080 |
||||
curl http://localhost:8080/v1/models |
||||
# {"object":"list","data":[{"id":"gpt-3.5-turbo","object":"model"}]} |
||||
|
||||
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{ |
||||
"model": "gpt-3.5-turbo", |
||||
"prompt": "A long time ago in a galaxy far, far away", |
||||
"temperature": 0.7 |
||||
}' |
||||
``` |
@ -0,0 +1,15 @@ |
||||
version: '3.6' |
||||
|
||||
services: |
||||
api: |
||||
image: quay.io/go-skynet/local-ai:latest |
||||
build: |
||||
context: ../../ |
||||
dockerfile: Dockerfile |
||||
ports: |
||||
- 8080:8080 |
||||
env_file: |
||||
- ../../.env |
||||
volumes: |
||||
- ./models:/models:cached |
||||
command: ["/usr/bin/local-ai"] |
@ -0,0 +1 @@ |
||||
{{.Input}} |
@ -0,0 +1,17 @@ |
||||
name: gpt-3.5-turbo |
||||
parameters: |
||||
model: gpt2 |
||||
top_k: 80 |
||||
temperature: 0.2 |
||||
top_p: 0.7 |
||||
context_size: 1024 |
||||
backend: "langchain-huggingface" |
||||
stopwords: |
||||
- "HUMAN:" |
||||
- "GPT:" |
||||
roles: |
||||
user: " " |
||||
system: " " |
||||
template: |
||||
completion: completion |
||||
chat: gpt4all |
@ -0,0 +1,4 @@ |
||||
The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response. |
||||
### Prompt: |
||||
{{.Input}} |
||||
### Response: |
@ -1 +0,0 @@ |
||||
../chatbot-ui/models |
@ -0,0 +1,48 @@ |
||||
# Create an app-level token with connections:write scope |
||||
SLACK_APP_TOKEN=xapp-1-... |
||||
# Install the app into your workspace to grab this token |
||||
SLACK_BOT_TOKEN=xoxb-... |
||||
|
||||
# Set this to a random string, it doesn't matter, however if present the python library complains |
||||
OPENAI_API_KEY=sk-foo-bar-baz |
||||
|
||||
# Optional: gpt-3.5-turbo and gpt-4 are currently supported (default: gpt-3.5-turbo) |
||||
OPENAI_MODEL=gpt-3.5-turbo |
||||
# Optional: You can adjust the timeout seconds for OpenAI calls (default: 30) |
||||
OPENAI_TIMEOUT_SECONDS=560 |
||||
|
||||
MEMORY_DIR=/tmp/memory_dir |
||||
|
||||
OPENAI_API_BASE=http://api:8080/v1 |
||||
|
||||
EMBEDDINGS_MODEL_NAME=all-MiniLM-L6-v2 |
||||
|
||||
## Repository and sitemap to index in the vector database on start |
||||
SITEMAP="https://kairos.io/sitemap.xml" |
||||
|
||||
# Optional repository names. |
||||
# REPOSITORIES="foo,bar" |
||||
# # Define clone URL for "foo" |
||||
# foo_CLONE_URL="http://github.com.." |
||||
# bar_CLONE_URL="..." |
||||
# # Define branch for foo |
||||
# foo_BRANCH="master" |
||||
# Optional token if scraping issues |
||||
# GITHUB_PERSONAL_ACCESS_TOKEN="" |
||||
# ISSUE_REPOSITORIES="go-skynet/LocalAI,foo/bar,..." |
||||
|
||||
# Optional: When the string is "true", this app translates ChatGPT prompts into a user's preferred language (default: true) |
||||
USE_SLACK_LANGUAGE=true |
||||
# Optional: Adjust the app's logging level (default: DEBUG) |
||||
SLACK_APP_LOG_LEVEL=INFO |
||||
# Optional: When the string is "true", translate between OpenAI markdown and Slack mrkdwn format (default: false) |
||||
TRANSLATE_MARKDOWN=true |
||||
|
||||
|
||||
### LocalAI |
||||
|
||||
DEBUG=true |
||||
MODELS_PATH=/models |
||||
IMAGE_PATH=/tmp |
||||
# See: https://github.com/go-skynet/model-gallery |
||||
PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/gpt4all-j.yaml", "name": "gpt-3.5-turbo"}] |
@ -0,0 +1,23 @@ |
||||
## Slack QA Bot |
||||
|
||||
This example uses https://github.com/spectrocloud-labs/Slack-QA-bot to deploy a slack bot that can answer to your documentation! |
||||
|
||||
- Create a new Slack app using the manifest-dev.yml file |
||||
- Install the app into your Slack workspace |
||||
- Retrieve your slack keys and edit `.env` |
||||
- Start the app |
||||
|
||||
```bash |
||||
# Clone LocalAI |
||||
git clone https://github.com/go-skynet/LocalAI |
||||
|
||||
cd LocalAI/examples/slack-qa-bot |
||||
|
||||
cp -rfv .env.example .env |
||||
|
||||
# Edit .env and add slackbot api keys, or repository settings to scan |
||||
vim .env |
||||
|
||||
# run the bot |
||||
docker-compose up |
||||
``` |
@ -0,0 +1,97 @@ |
||||
apiVersion: v1 |
||||
kind: Namespace |
||||
metadata: |
||||
name: slack-bot |
||||
--- |
||||
apiVersion: v1 |
||||
kind: PersistentVolumeClaim |
||||
metadata: |
||||
name: knowledgebase |
||||
namespace: slack-bot |
||||
labels: |
||||
app: localai-qabot |
||||
spec: |
||||
accessModes: |
||||
- ReadWriteOnce |
||||
resources: |
||||
requests: |
||||
storage: 5Gi |
||||
--- |
||||
apiVersion: apps/v1 |
||||
kind: Deployment |
||||
metadata: |
||||
name: localai-qabot |
||||
namespace: slack-bot |
||||
labels: |
||||
app: localai-qabot |
||||
spec: |
||||
selector: |
||||
matchLabels: |
||||
app: localai-qabot |
||||
replicas: 1 |
||||
template: |
||||
metadata: |
||||
labels: |
||||
app: localai-qabot |
||||
name: localai-qabot |
||||
spec: |
||||
containers: |
||||
- name: localai-qabot-slack |
||||
env: |
||||
- name: OPENAI_API_KEY |
||||
value: "x" |
||||
- name: SLACK_APP_TOKEN |
||||
value: "xapp-1-" |
||||
- name: SLACK_BOT_TOKEN |
||||
value: "xoxb-" |
||||
- name: OPENAI_MODEL |
||||
value: "gpt-3.5-turbo" |
||||
- name: OPENAI_TIMEOUT_SECONDS |
||||
value: "400" |
||||
- name: OPENAI_SYSTEM_TEXT |
||||
value: "" |
||||
- name: MEMORY_DIR |
||||
value: "/memory" |
||||
- name: TRANSLATE_MARKDOWN |
||||
value: "true" |
||||
- name: OPENAI_API_BASE |
||||
value: "http://local-ai.default.svc.cluster.local:8080" |
||||
- name: REPOSITORIES |
||||
value: "KAIROS,AGENT,SDK,OSBUILDER,PACKAGES,IMMUCORE" |
||||
- name: KAIROS_CLONE_URL |
||||
value: "https://github.com/kairos-io/kairos" |
||||
- name: KAIROS_BRANCH |
||||
value: "master" |
||||
- name: AGENT_CLONE_URL |
||||
value: "https://github.com/kairos-io/kairos-agent" |
||||
- name: AGENT_BRANCH |
||||
value: "main" |
||||
- name: SDK_CLONE_URL |
||||
value: "https://github.com/kairos-io/kairos-sdk" |
||||
- name: SDK_BRANCH |
||||
value: "main" |
||||
- name: OSBUILDER_CLONE_URL |
||||
value: "https://github.com/kairos-io/osbuilder" |
||||
- name: OSBUILDER_BRANCH |
||||
value: "master" |
||||
- name: PACKAGES_CLONE_URL |
||||
value: "https://github.com/kairos-io/packages" |
||||
- name: PACKAGES_BRANCH |
||||
value: "main" |
||||
- name: IMMUCORE_CLONE_URL |
||||
value: "https://github.com/kairos-io/immucore" |
||||
- name: IMMUCORE_BRANCH |
||||
value: "master" |
||||
- name: GITHUB_PERSONAL_ACCESS_TOKEN |
||||
value: "" |
||||
- name: ISSUE_REPOSITORIES |
||||
value: "kairos-io/kairos" |
||||
image: quay.io/spectrocloud-labs/slack-qa-local-bot:qa |
||||
imagePullPolicy: Always |
||||
volumeMounts: |
||||
- mountPath: "/memory" |
||||
name: knowledgebase |
||||
volumes: |
||||
- name: knowledgebase |
||||
persistentVolumeClaim: |
||||
claimName: knowledgebase |
@ -0,0 +1,30 @@ |
||||
version: "3" |
||||
|
||||
services: |
||||
api: |
||||
image: quay.io/go-skynet/local-ai:latest |
||||
# As initially LocalAI will download the models defined in PRELOAD_MODELS |
||||
# you might need to tweak the healthcheck values here according to your network connection. |
||||
# Here we give a timespan of 20m to download all the required files. |
||||
healthcheck: |
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/readyz"] |
||||
interval: 1m |
||||
timeout: 20m |
||||
retries: 20 |
||||
ports: |
||||
- 8080:8080 |
||||
env_file: |
||||
- .env |
||||
volumes: |
||||
- ./models:/models:cached |
||||
command: ["/usr/bin/local-ai" ] |
||||
|
||||
slackbot: |
||||
image: quay.io/spectrocloud-labs/slack-qa-local-bot:qa |
||||
container_name: slackbot |
||||
restart: always |
||||
env_file: |
||||
- .env |
||||
depends_on: |
||||
api: |
||||
condition: service_healthy |
@ -0,0 +1,30 @@ |
||||
## Telegram bot |
||||
|
||||
![Screenshot from 2023-06-09 00-36-26](https://github.com/go-skynet/LocalAI/assets/2420543/e98b4305-fa2d-41cf-9d2f-1bb2d75ca902) |
||||
|
||||
This example uses a fork of [chatgpt-telegram-bot](https://github.com/karfly/chatgpt_telegram_bot) to deploy a telegram bot with LocalAI instead of OpenAI. |
||||
|
||||
```bash |
||||
# Clone LocalAI |
||||
git clone https://github.com/go-skynet/LocalAI |
||||
|
||||
cd LocalAI/examples/telegram-bot |
||||
|
||||
git clone https://github.com/mudler/chatgpt_telegram_bot |
||||
|
||||
cp -rf docker-compose.yml chatgpt_telegram_bot |
||||
|
||||
cd chatgpt_telegram_bot |
||||
|
||||
mv config/config.example.yml config/config.yml |
||||
mv config/config.example.env config/config.env |
||||
|
||||
# Edit config/config.yml to set the telegram bot token |
||||
vim config/config.yml |
||||
|
||||
# run the bot |
||||
docker-compose --env-file config/config.env up --build |
||||
``` |
||||
|
||||
Note: LocalAI is configured to download `gpt4all-j` in place of `gpt-3.5-turbo` and `stablediffusion` for image generation at the first start. Download size is >6GB, if your network connection is slow, adapt the `docker-compose.yml` file healthcheck section accordingly (replace `20m`, for instance with `1h`, etc.). |
||||
To configure models manually, comment the `PRELOAD_MODELS` environment variable in the `docker-compose.yml` file and see for instance the [chatbot-ui-manual example](https://github.com/go-skynet/LocalAI/tree/master/examples/chatbot-ui-manual) `model` directory. |
@ -0,0 +1,38 @@ |
||||
version: "3" |
||||
|
||||
services: |
||||
api: |
||||
image: quay.io/go-skynet/local-ai:v1.18.0-ffmpeg |
||||
# As initially LocalAI will download the models defined in PRELOAD_MODELS |
||||
# you might need to tweak the healthcheck values here according to your network connection. |
||||
# Here we give a timespan of 20m to download all the required files. |
||||
healthcheck: |
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/readyz"] |
||||
interval: 1m |
||||
timeout: 20m |
||||
retries: 20 |
||||
ports: |
||||
- 8080:8080 |
||||
environment: |
||||
- DEBUG=true |
||||
- MODELS_PATH=/models |
||||
- IMAGE_PATH=/tmp |
||||
# You can preload different models here as well. |
||||
# See: https://github.com/go-skynet/model-gallery |
||||
- 'PRELOAD_MODELS=[{"url": "github:go-skynet/model-gallery/gpt4all-j.yaml", "name": "gpt-3.5-turbo"}, {"url": "github:go-skynet/model-gallery/stablediffusion.yaml"}, {"url": "github:go-skynet/model-gallery/whisper-base.yaml", "name": "whisper-1"}]' |
||||
volumes: |
||||
- ./models:/models:cached |
||||
command: ["/usr/bin/local-ai"] |
||||
chatgpt_telegram_bot: |
||||
container_name: chatgpt_telegram_bot |
||||
command: python3 bot/bot.py |
||||
restart: always |
||||
environment: |
||||
- OPENAI_API_KEY=sk---anystringhere |
||||
- OPENAI_API_BASE=http://api:8080/v1 |
||||
build: |
||||
context: "." |
||||
dockerfile: Dockerfile |
||||
depends_on: |
||||
api: |
||||
condition: service_healthy |
@ -0,0 +1,49 @@ |
||||
# -*- coding: utf-8 -*- |
||||
# Generated by the protocol buffer compiler. DO NOT EDIT! |
||||
# source: backend.proto |
||||
"""Generated protocol buffer code.""" |
||||
from google.protobuf import descriptor as _descriptor |
||||
from google.protobuf import descriptor_pool as _descriptor_pool |
||||
from google.protobuf import symbol_database as _symbol_database |
||||
from google.protobuf.internal import builder as _builder |
||||
# @@protoc_insertion_point(imports) |
||||
|
||||
_sym_db = _symbol_database.Default() |
||||
|
||||
|
||||
|
||||
|
||||
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rbackend.proto\x12\x07\x62\x61\x63kend\"\x0f\n\rHealthMessage\"\xa4\x05\n\x0ePredictOptions\x12\x0e\n\x06Prompt\x18\x01 \x01(\t\x12\x0c\n\x04Seed\x18\x02 \x01(\x05\x12\x0f\n\x07Threads\x18\x03 \x01(\x05\x12\x0e\n\x06Tokens\x18\x04 \x01(\x05\x12\x0c\n\x04TopK\x18\x05 \x01(\x05\x12\x0e\n\x06Repeat\x18\x06 \x01(\x05\x12\r\n\x05\x42\x61tch\x18\x07 \x01(\x05\x12\r\n\x05NKeep\x18\x08 \x01(\x05\x12\x13\n\x0bTemperature\x18\t \x01(\x02\x12\x0f\n\x07Penalty\x18\n \x01(\x02\x12\r\n\x05\x46\x31\x36KV\x18\x0b \x01(\x08\x12\x11\n\tDebugMode\x18\x0c \x01(\x08\x12\x13\n\x0bStopPrompts\x18\r \x03(\t\x12\x11\n\tIgnoreEOS\x18\x0e \x01(\x08\x12\x19\n\x11TailFreeSamplingZ\x18\x0f \x01(\x02\x12\x10\n\x08TypicalP\x18\x10 \x01(\x02\x12\x18\n\x10\x46requencyPenalty\x18\x11 \x01(\x02\x12\x17\n\x0fPresencePenalty\x18\x12 \x01(\x02\x12\x10\n\x08Mirostat\x18\x13 \x01(\x05\x12\x13\n\x0bMirostatETA\x18\x14 \x01(\x02\x12\x13\n\x0bMirostatTAU\x18\x15 \x01(\x02\x12\x12\n\nPenalizeNL\x18\x16 \x01(\x08\x12\x11\n\tLogitBias\x18\x17 \x01(\t\x12\r\n\x05MLock\x18\x19 \x01(\x08\x12\x0c\n\x04MMap\x18\x1a \x01(\x08\x12\x16\n\x0ePromptCacheAll\x18\x1b \x01(\x08\x12\x15\n\rPromptCacheRO\x18\x1c \x01(\x08\x12\x0f\n\x07Grammar\x18\x1d \x01(\t\x12\x0f\n\x07MainGPU\x18\x1e \x01(\t\x12\x13\n\x0bTensorSplit\x18\x1f \x01(\t\x12\x0c\n\x04TopP\x18 \x01(\x02\x12\x17\n\x0fPromptCachePath\x18! \x01(\t\x12\r\n\x05\x44\x65\x62ug\x18\" \x01(\x08\x12\x17\n\x0f\x45mbeddingTokens\x18# \x03(\x05\x12\x12\n\nEmbeddings\x18$ \x01(\t\"\x18\n\x05Reply\x12\x0f\n\x07message\x18\x01 \x01(\t\"\xac\x02\n\x0cModelOptions\x12\r\n\x05Model\x18\x01 \x01(\t\x12\x13\n\x0b\x43ontextSize\x18\x02 \x01(\x05\x12\x0c\n\x04Seed\x18\x03 \x01(\x05\x12\x0e\n\x06NBatch\x18\x04 \x01(\x05\x12\x11\n\tF16Memory\x18\x05 \x01(\x08\x12\r\n\x05MLock\x18\x06 \x01(\x08\x12\x0c\n\x04MMap\x18\x07 \x01(\x08\x12\x11\n\tVocabOnly\x18\x08 \x01(\x08\x12\x0f\n\x07LowVRAM\x18\t \x01(\x08\x12\x12\n\nEmbeddings\x18\n \x01(\x08\x12\x0c\n\x04NUMA\x18\x0b \x01(\x08\x12\x12\n\nNGPULayers\x18\x0c \x01(\x05\x12\x0f\n\x07MainGPU\x18\r \x01(\t\x12\x13\n\x0bTensorSplit\x18\x0e \x01(\t\x12\x0f\n\x07Threads\x18\x0f \x01(\x05\x12\x19\n\x11LibrarySearchPath\x18\x10 \x01(\t\"*\n\x06Result\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\"%\n\x0f\x45mbeddingResult\x12\x12\n\nembeddings\x18\x01 \x03(\x02\"C\n\x11TranscriptRequest\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0f\n\x07threads\x18\x04 \x01(\r\"N\n\x10TranscriptResult\x12,\n\x08segments\x18\x01 \x03(\x0b\x32\x1a.backend.TranscriptSegment\x12\x0c\n\x04text\x18\x02 \x01(\t\"Y\n\x11TranscriptSegment\x12\n\n\x02id\x18\x01 \x01(\x05\x12\r\n\x05start\x18\x02 \x01(\x03\x12\x0b\n\x03\x65nd\x18\x03 \x01(\x03\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x0e\n\x06tokens\x18\x05 \x03(\x05\"\x9e\x01\n\x14GenerateImageRequest\x12\x0e\n\x06height\x18\x01 \x01(\x05\x12\r\n\x05width\x18\x02 \x01(\x05\x12\x0c\n\x04mode\x18\x03 \x01(\x05\x12\x0c\n\x04step\x18\x04 \x01(\x05\x12\x0c\n\x04seed\x18\x05 \x01(\x05\x12\x17\n\x0fpositive_prompt\x18\x06 \x01(\t\x12\x17\n\x0fnegative_prompt\x18\x07 \x01(\t\x12\x0b\n\x03\x64st\x18\x08 \x01(\t\"6\n\nTTSRequest\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\r\n\x05model\x18\x02 \x01(\t\x12\x0b\n\x03\x64st\x18\x03 \x01(\t2\xeb\x03\n\x07\x42\x61\x63kend\x12\x32\n\x06Health\x12\x16.backend.HealthMessage\x1a\x0e.backend.Reply\"\x00\x12\x34\n\x07Predict\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x12\x35\n\tLoadModel\x12\x15.backend.ModelOptions\x1a\x0f.backend.Result\"\x00\x12<\n\rPredictStream\x12\x17.backend.PredictOptions\x1a\x0e.backend.Reply\"\x00\x30\x01\x12@\n\tEmbedding\x12\x17.backend.PredictOptions\x1a\x18.backend.EmbeddingResult\"\x00\x12\x41\n\rGenerateImage\x12\x1d.backend.GenerateImageRequest\x1a\x0f.backend.Result\"\x00\x12M\n\x12\x41udioTranscription\x12\x1a.backend.TranscriptRequest\x1a\x19.backend.TranscriptResult\"\x00\x12-\n\x03TTS\x12\x13.backend.TTSRequest\x1a\x0f.backend.Result\"\x00\x42Z\n\x19io.skynet.localai.backendB\x0eLocalAIBackendP\x01Z+github.com/go-skynet/LocalAI/pkg/grpc/protob\x06proto3') |
||||
|
||||
_globals = globals() |
||||
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) |
||||
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'backend_pb2', _globals) |
||||
if _descriptor._USE_C_DESCRIPTORS == False: |
||||
|
||||
DESCRIPTOR._options = None |
||||
DESCRIPTOR._serialized_options = b'\n\031io.skynet.localai.backendB\016LocalAIBackendP\001Z+github.com/go-skynet/LocalAI/pkg/grpc/proto' |
||||
_globals['_HEALTHMESSAGE']._serialized_start=26 |
||||
_globals['_HEALTHMESSAGE']._serialized_end=41 |
||||
_globals['_PREDICTOPTIONS']._serialized_start=44 |
||||
_globals['_PREDICTOPTIONS']._serialized_end=720 |
||||
_globals['_REPLY']._serialized_start=722 |
||||
_globals['_REPLY']._serialized_end=746 |
||||
_globals['_MODELOPTIONS']._serialized_start=749 |
||||
_globals['_MODELOPTIONS']._serialized_end=1049 |
||||
_globals['_RESULT']._serialized_start=1051 |
||||
_globals['_RESULT']._serialized_end=1093 |
||||
_globals['_EMBEDDINGRESULT']._serialized_start=1095 |
||||
_globals['_EMBEDDINGRESULT']._serialized_end=1132 |
||||
_globals['_TRANSCRIPTREQUEST']._serialized_start=1134 |
||||
_globals['_TRANSCRIPTREQUEST']._serialized_end=1201 |
||||
_globals['_TRANSCRIPTRESULT']._serialized_start=1203 |
||||
_globals['_TRANSCRIPTRESULT']._serialized_end=1281 |
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_start=1283 |
||||
_globals['_TRANSCRIPTSEGMENT']._serialized_end=1372 |
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_start=1375 |
||||
_globals['_GENERATEIMAGEREQUEST']._serialized_end=1533 |
||||
_globals['_TTSREQUEST']._serialized_start=1535 |
||||
_globals['_TTSREQUEST']._serialized_end=1589 |
||||
_globals['_BACKEND']._serialized_start=1592 |
||||
_globals['_BACKEND']._serialized_end=2083 |
||||
# @@protoc_insertion_point(module_scope) |
@ -0,0 +1,297 @@ |
||||
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! |
||||
"""Client and server classes corresponding to protobuf-defined services.""" |
||||
import grpc |
||||
|
||||
import backend_pb2 as backend__pb2 |
||||
|
||||
|
||||
class BackendStub(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
def __init__(self, channel): |
||||
"""Constructor. |
||||
|
||||
Args: |
||||
channel: A grpc.Channel. |
||||
""" |
||||
self.Health = channel.unary_unary( |
||||
'/backend.Backend/Health', |
||||
request_serializer=backend__pb2.HealthMessage.SerializeToString, |
||||
response_deserializer=backend__pb2.Reply.FromString, |
||||
) |
||||
self.Predict = channel.unary_unary( |
||||
'/backend.Backend/Predict', |
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString, |
||||
response_deserializer=backend__pb2.Reply.FromString, |
||||
) |
||||
self.LoadModel = channel.unary_unary( |
||||
'/backend.Backend/LoadModel', |
||||
request_serializer=backend__pb2.ModelOptions.SerializeToString, |
||||
response_deserializer=backend__pb2.Result.FromString, |
||||
) |
||||
self.PredictStream = channel.unary_stream( |
||||
'/backend.Backend/PredictStream', |
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString, |
||||
response_deserializer=backend__pb2.Reply.FromString, |
||||
) |
||||
self.Embedding = channel.unary_unary( |
||||
'/backend.Backend/Embedding', |
||||
request_serializer=backend__pb2.PredictOptions.SerializeToString, |
||||
response_deserializer=backend__pb2.EmbeddingResult.FromString, |
||||
) |
||||
self.GenerateImage = channel.unary_unary( |
||||
'/backend.Backend/GenerateImage', |
||||
request_serializer=backend__pb2.GenerateImageRequest.SerializeToString, |
||||
response_deserializer=backend__pb2.Result.FromString, |
||||
) |
||||
self.AudioTranscription = channel.unary_unary( |
||||
'/backend.Backend/AudioTranscription', |
||||
request_serializer=backend__pb2.TranscriptRequest.SerializeToString, |
||||
response_deserializer=backend__pb2.TranscriptResult.FromString, |
||||
) |
||||
self.TTS = channel.unary_unary( |
||||
'/backend.Backend/TTS', |
||||
request_serializer=backend__pb2.TTSRequest.SerializeToString, |
||||
response_deserializer=backend__pb2.Result.FromString, |
||||
) |
||||
|
||||
|
||||
class BackendServicer(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
def Health(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def Predict(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def LoadModel(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def PredictStream(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def Embedding(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def GenerateImage(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def AudioTranscription(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
def TTS(self, request, context): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
context.set_code(grpc.StatusCode.UNIMPLEMENTED) |
||||
context.set_details('Method not implemented!') |
||||
raise NotImplementedError('Method not implemented!') |
||||
|
||||
|
||||
def add_BackendServicer_to_server(servicer, server): |
||||
rpc_method_handlers = { |
||||
'Health': grpc.unary_unary_rpc_method_handler( |
||||
servicer.Health, |
||||
request_deserializer=backend__pb2.HealthMessage.FromString, |
||||
response_serializer=backend__pb2.Reply.SerializeToString, |
||||
), |
||||
'Predict': grpc.unary_unary_rpc_method_handler( |
||||
servicer.Predict, |
||||
request_deserializer=backend__pb2.PredictOptions.FromString, |
||||
response_serializer=backend__pb2.Reply.SerializeToString, |
||||
), |
||||
'LoadModel': grpc.unary_unary_rpc_method_handler( |
||||
servicer.LoadModel, |
||||
request_deserializer=backend__pb2.ModelOptions.FromString, |
||||
response_serializer=backend__pb2.Result.SerializeToString, |
||||
), |
||||
'PredictStream': grpc.unary_stream_rpc_method_handler( |
||||
servicer.PredictStream, |
||||
request_deserializer=backend__pb2.PredictOptions.FromString, |
||||
response_serializer=backend__pb2.Reply.SerializeToString, |
||||
), |
||||
'Embedding': grpc.unary_unary_rpc_method_handler( |
||||
servicer.Embedding, |
||||
request_deserializer=backend__pb2.PredictOptions.FromString, |
||||
response_serializer=backend__pb2.EmbeddingResult.SerializeToString, |
||||
), |
||||
'GenerateImage': grpc.unary_unary_rpc_method_handler( |
||||
servicer.GenerateImage, |
||||
request_deserializer=backend__pb2.GenerateImageRequest.FromString, |
||||
response_serializer=backend__pb2.Result.SerializeToString, |
||||
), |
||||
'AudioTranscription': grpc.unary_unary_rpc_method_handler( |
||||
servicer.AudioTranscription, |
||||
request_deserializer=backend__pb2.TranscriptRequest.FromString, |
||||
response_serializer=backend__pb2.TranscriptResult.SerializeToString, |
||||
), |
||||
'TTS': grpc.unary_unary_rpc_method_handler( |
||||
servicer.TTS, |
||||
request_deserializer=backend__pb2.TTSRequest.FromString, |
||||
response_serializer=backend__pb2.Result.SerializeToString, |
||||
), |
||||
} |
||||
generic_handler = grpc.method_handlers_generic_handler( |
||||
'backend.Backend', rpc_method_handlers) |
||||
server.add_generic_rpc_handlers((generic_handler,)) |
||||
|
||||
|
||||
# This class is part of an EXPERIMENTAL API. |
||||
class Backend(object): |
||||
"""Missing associated documentation comment in .proto file.""" |
||||
|
||||
@staticmethod |
||||
def Health(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Health', |
||||
backend__pb2.HealthMessage.SerializeToString, |
||||
backend__pb2.Reply.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def Predict(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Predict', |
||||
backend__pb2.PredictOptions.SerializeToString, |
||||
backend__pb2.Reply.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def LoadModel(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/LoadModel', |
||||
backend__pb2.ModelOptions.SerializeToString, |
||||
backend__pb2.Result.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def PredictStream(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_stream(request, target, '/backend.Backend/PredictStream', |
||||
backend__pb2.PredictOptions.SerializeToString, |
||||
backend__pb2.Reply.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def Embedding(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Embedding', |
||||
backend__pb2.PredictOptions.SerializeToString, |
||||
backend__pb2.EmbeddingResult.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def GenerateImage(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/GenerateImage', |
||||
backend__pb2.GenerateImageRequest.SerializeToString, |
||||
backend__pb2.Result.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def AudioTranscription(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/AudioTranscription', |
||||
backend__pb2.TranscriptRequest.SerializeToString, |
||||
backend__pb2.TranscriptResult.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
||||
|
||||
@staticmethod |
||||
def TTS(request, |
||||
target, |
||||
options=(), |
||||
channel_credentials=None, |
||||
call_credentials=None, |
||||
insecure=False, |
||||
compression=None, |
||||
wait_for_ready=None, |
||||
timeout=None, |
||||
metadata=None): |
||||
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TTS', |
||||
backend__pb2.TTSRequest.SerializeToString, |
||||
backend__pb2.Result.FromString, |
||||
options, channel_credentials, |
||||
insecure, call_credentials, compression, wait_for_ready, timeout, metadata) |
@ -0,0 +1,67 @@ |
||||
#!/usr/bin/env python3 |
||||
import grpc |
||||
from concurrent import futures |
||||
import time |
||||
import backend_pb2 |
||||
import backend_pb2_grpc |
||||
import argparse |
||||
import signal |
||||
import sys |
||||
import os |
||||
from sentence_transformers import SentenceTransformer |
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 |
||||
|
||||
# Implement the BackendServicer class with the service methods |
||||
class BackendServicer(backend_pb2_grpc.BackendServicer): |
||||
def Health(self, request, context): |
||||
return backend_pb2.Reply(message="OK") |
||||
def LoadModel(self, request, context): |
||||
model_name = request.Model |
||||
model_name = os.path.basename(model_name) |
||||
try: |
||||
self.model = SentenceTransformer(model_name) |
||||
except Exception as err: |
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") |
||||
# Implement your logic here for the LoadModel service |
||||
# Replace this with your desired response |
||||
return backend_pb2.Result(message="Model loaded successfully", success=True) |
||||
def Embedding(self, request, context): |
||||
# Implement your logic here for the Embedding service |
||||
# Replace this with your desired response |
||||
print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr) |
||||
sentence_embeddings = self.model.encode(request.Embeddings) |
||||
return backend_pb2.EmbeddingResult(embeddings=sentence_embeddings) |
||||
|
||||
|
||||
def serve(address): |
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) |
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) |
||||
server.add_insecure_port(address) |
||||
server.start() |
||||
print("Server started. Listening on: " + address, file=sys.stderr) |
||||
|
||||
# Define the signal handler function |
||||
def signal_handler(sig, frame): |
||||
print("Received termination signal. Shutting down...") |
||||
server.stop(0) |
||||
sys.exit(0) |
||||
|
||||
# Set the signal handlers for SIGINT and SIGTERM |
||||
signal.signal(signal.SIGINT, signal_handler) |
||||
signal.signal(signal.SIGTERM, signal_handler) |
||||
|
||||
try: |
||||
while True: |
||||
time.sleep(_ONE_DAY_IN_SECONDS) |
||||
except KeyboardInterrupt: |
||||
server.stop(0) |
||||
|
||||
if __name__ == "__main__": |
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.") |
||||
parser.add_argument( |
||||
"--addr", default="localhost:50051", help="The address to bind the server to." |
||||
) |
||||
args = parser.parse_args() |
||||
|
||||
serve(args.addr) |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue