feat: `gRPC`-based backends (#743)
commit
e3cabb555d
@ -0,0 +1,105 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"sync" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
) |
||||||
|
|
||||||
|
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { |
||||||
|
if !c.Embeddings { |
||||||
|
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
||||||
|
} |
||||||
|
|
||||||
|
modelFile := c.Model |
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(c) |
||||||
|
|
||||||
|
var inferenceModel interface{} |
||||||
|
var err error |
||||||
|
|
||||||
|
opts := []model.Option{ |
||||||
|
model.WithLoadGRPCLLMModelOpts(grpcOpts), |
||||||
|
model.WithThreads(uint32(c.Threads)), |
||||||
|
model.WithAssetDir(o.AssetsDestination), |
||||||
|
model.WithModelFile(modelFile), |
||||||
|
model.WithContext(o.Context), |
||||||
|
} |
||||||
|
|
||||||
|
if c.Backend == "" { |
||||||
|
inferenceModel, err = loader.GreedyLoader(opts...) |
||||||
|
} else { |
||||||
|
opts = append(opts, model.WithBackendString(c.Backend)) |
||||||
|
inferenceModel, err = loader.BackendLoader(opts...) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var fn func() ([]float32, error) |
||||||
|
switch model := inferenceModel.(type) { |
||||||
|
case *grpc.Client: |
||||||
|
fn = func() ([]float32, error) { |
||||||
|
predictOptions := gRPCPredictOpts(c, loader.ModelPath) |
||||||
|
if len(tokens) > 0 { |
||||||
|
embeds := []int32{} |
||||||
|
|
||||||
|
for _, t := range tokens { |
||||||
|
embeds = append(embeds, int32(t)) |
||||||
|
} |
||||||
|
predictOptions.EmbeddingTokens = embeds |
||||||
|
|
||||||
|
res, err := model.Embeddings(o.Context, predictOptions) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return res.Embeddings, nil |
||||||
|
} |
||||||
|
predictOptions.Embeddings = s |
||||||
|
|
||||||
|
res, err := model.Embeddings(o.Context, predictOptions) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return res.Embeddings, nil |
||||||
|
} |
||||||
|
default: |
||||||
|
fn = func() ([]float32, error) { |
||||||
|
return nil, fmt.Errorf("embeddings not supported by the backend") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return func() ([]float32, error) { |
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
mutexMap.Lock() |
||||||
|
l, ok := mutexes[modelFile] |
||||||
|
if !ok { |
||||||
|
m := &sync.Mutex{} |
||||||
|
mutexes[modelFile] = m |
||||||
|
l = m |
||||||
|
} |
||||||
|
mutexMap.Unlock() |
||||||
|
l.Lock() |
||||||
|
defer l.Unlock() |
||||||
|
|
||||||
|
embeds, err := fn() |
||||||
|
if err != nil { |
||||||
|
return embeds, err |
||||||
|
} |
||||||
|
// Remove trailing 0s
|
||||||
|
for i := len(embeds) - 1; i >= 0; i-- { |
||||||
|
if embeds[i] == 0.0 { |
||||||
|
embeds = embeds[:i] |
||||||
|
} else { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
return embeds, nil |
||||||
|
}, nil |
||||||
|
} |
@ -0,0 +1,60 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"sync" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
) |
||||||
|
|
||||||
|
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { |
||||||
|
if c.Backend != model.StableDiffusionBackend { |
||||||
|
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
||||||
|
} |
||||||
|
|
||||||
|
inferenceModel, err := loader.BackendLoader( |
||||||
|
model.WithBackendString(c.Backend), |
||||||
|
model.WithAssetDir(o.AssetsDestination), |
||||||
|
model.WithThreads(uint32(c.Threads)), |
||||||
|
model.WithContext(o.Context), |
||||||
|
model.WithModelFile(c.ImageGenerationAssets), |
||||||
|
) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
fn := func() error { |
||||||
|
_, err := inferenceModel.GenerateImage( |
||||||
|
o.Context, |
||||||
|
&proto.GenerateImageRequest{ |
||||||
|
Height: int32(height), |
||||||
|
Width: int32(width), |
||||||
|
Mode: int32(mode), |
||||||
|
Step: int32(step), |
||||||
|
Seed: int32(seed), |
||||||
|
PositivePrompt: positive_prompt, |
||||||
|
NegativePrompt: negative_prompt, |
||||||
|
Dst: dst, |
||||||
|
}) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
return func() error { |
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
mutexMap.Lock() |
||||||
|
l, ok := mutexes[c.Backend] |
||||||
|
if !ok { |
||||||
|
m := &sync.Mutex{} |
||||||
|
mutexes[c.Backend] = m |
||||||
|
l = m |
||||||
|
} |
||||||
|
mutexMap.Unlock() |
||||||
|
l.Lock() |
||||||
|
defer l.Unlock() |
||||||
|
|
||||||
|
return fn() |
||||||
|
}, nil |
||||||
|
} |
@ -0,0 +1,98 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import ( |
||||||
|
"regexp" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
) |
||||||
|
|
||||||
|
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { |
||||||
|
modelFile := c.Model |
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(c) |
||||||
|
|
||||||
|
var inferenceModel *grpc.Client |
||||||
|
var err error |
||||||
|
|
||||||
|
opts := []model.Option{ |
||||||
|
model.WithLoadGRPCLLMModelOpts(grpcOpts), |
||||||
|
model.WithThreads(uint32(c.Threads)), // some models uses this to allocate threads during startup
|
||||||
|
model.WithAssetDir(o.AssetsDestination), |
||||||
|
model.WithModelFile(modelFile), |
||||||
|
model.WithContext(o.Context), |
||||||
|
} |
||||||
|
|
||||||
|
if c.Backend == "" { |
||||||
|
inferenceModel, err = loader.GreedyLoader(opts...) |
||||||
|
} else { |
||||||
|
opts = append(opts, model.WithBackendString(c.Backend)) |
||||||
|
inferenceModel, err = loader.BackendLoader(opts...) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||||
|
fn := func() (string, error) { |
||||||
|
opts := gRPCPredictOpts(c, loader.ModelPath) |
||||||
|
opts.Prompt = s |
||||||
|
if tokenCallback != nil { |
||||||
|
ss := "" |
||||||
|
err := inferenceModel.PredictStream(o.Context, opts, func(s string) { |
||||||
|
tokenCallback(s) |
||||||
|
ss += s |
||||||
|
}) |
||||||
|
return ss, err |
||||||
|
} else { |
||||||
|
reply, err := inferenceModel.Predict(o.Context, opts) |
||||||
|
return reply.Message, err |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return func() (string, error) { |
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
mutexMap.Lock() |
||||||
|
l, ok := mutexes[modelFile] |
||||||
|
if !ok { |
||||||
|
m := &sync.Mutex{} |
||||||
|
mutexes[modelFile] = m |
||||||
|
l = m |
||||||
|
} |
||||||
|
mutexMap.Unlock() |
||||||
|
l.Lock() |
||||||
|
defer l.Unlock() |
||||||
|
|
||||||
|
return fn() |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
||||||
|
var mu sync.Mutex = sync.Mutex{} |
||||||
|
|
||||||
|
func Finetune(config config.Config, input, prediction string) string { |
||||||
|
if config.Echo { |
||||||
|
prediction = input + prediction |
||||||
|
} |
||||||
|
|
||||||
|
for _, c := range config.Cutstrings { |
||||||
|
mu.Lock() |
||||||
|
reg, ok := cutstrings[c] |
||||||
|
if !ok { |
||||||
|
cutstrings[c] = regexp.MustCompile(c) |
||||||
|
reg = cutstrings[c] |
||||||
|
} |
||||||
|
mu.Unlock() |
||||||
|
prediction = reg.ReplaceAllString(prediction, "") |
||||||
|
} |
||||||
|
|
||||||
|
for _, c := range config.TrimSpace { |
||||||
|
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
||||||
|
} |
||||||
|
return prediction |
||||||
|
|
||||||
|
} |
@ -0,0 +1,22 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import "sync" |
||||||
|
|
||||||
|
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
var mutexMap sync.Mutex |
||||||
|
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
||||||
|
|
||||||
|
func Lock(s string) *sync.Mutex { |
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
mutexMap.Lock() |
||||||
|
l, ok := mutexes[s] |
||||||
|
if !ok { |
||||||
|
m := &sync.Mutex{} |
||||||
|
mutexes[s] = m |
||||||
|
l = m |
||||||
|
} |
||||||
|
mutexMap.Unlock() |
||||||
|
l.Lock() |
||||||
|
|
||||||
|
return l |
||||||
|
} |
@ -0,0 +1,72 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import ( |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
) |
||||||
|
|
||||||
|
func gRPCModelOpts(c config.Config) *pb.ModelOptions { |
||||||
|
b := 512 |
||||||
|
if c.Batch != 0 { |
||||||
|
b = c.Batch |
||||||
|
} |
||||||
|
return &pb.ModelOptions{ |
||||||
|
ContextSize: int32(c.ContextSize), |
||||||
|
Seed: int32(c.Seed), |
||||||
|
NBatch: int32(b), |
||||||
|
F16Memory: c.F16, |
||||||
|
MLock: c.MMlock, |
||||||
|
NUMA: c.NUMA, |
||||||
|
Embeddings: c.Embeddings, |
||||||
|
LowVRAM: c.LowVRAM, |
||||||
|
NGPULayers: int32(c.NGPULayers), |
||||||
|
MMap: c.MMap, |
||||||
|
MainGPU: c.MainGPU, |
||||||
|
Threads: int32(c.Threads), |
||||||
|
TensorSplit: c.TensorSplit, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { |
||||||
|
promptCachePath := "" |
||||||
|
if c.PromptCachePath != "" { |
||||||
|
p := filepath.Join(modelPath, c.PromptCachePath) |
||||||
|
os.MkdirAll(filepath.Dir(p), 0755) |
||||||
|
promptCachePath = p |
||||||
|
} |
||||||
|
return &pb.PredictOptions{ |
||||||
|
Temperature: float32(c.Temperature), |
||||||
|
TopP: float32(c.TopP), |
||||||
|
TopK: int32(c.TopK), |
||||||
|
Tokens: int32(c.Maxtokens), |
||||||
|
Threads: int32(c.Threads), |
||||||
|
PromptCacheAll: c.PromptCacheAll, |
||||||
|
PromptCacheRO: c.PromptCacheRO, |
||||||
|
PromptCachePath: promptCachePath, |
||||||
|
F16KV: c.F16, |
||||||
|
DebugMode: c.Debug, |
||||||
|
Grammar: c.Grammar, |
||||||
|
|
||||||
|
Mirostat: int32(c.Mirostat), |
||||||
|
MirostatETA: float32(c.MirostatETA), |
||||||
|
MirostatTAU: float32(c.MirostatTAU), |
||||||
|
Debug: c.Debug, |
||||||
|
StopPrompts: c.StopWords, |
||||||
|
Repeat: int32(c.RepeatPenalty), |
||||||
|
NKeep: int32(c.Keep), |
||||||
|
Batch: int32(c.Batch), |
||||||
|
IgnoreEOS: c.IgnoreEOS, |
||||||
|
Seed: int32(c.Seed), |
||||||
|
FrequencyPenalty: float32(c.FrequencyPenalty), |
||||||
|
MLock: c.MMlock, |
||||||
|
MMap: c.MMap, |
||||||
|
MainGPU: c.MainGPU, |
||||||
|
TensorSplit: c.TensorSplit, |
||||||
|
TailFreeSamplingZ: float32(c.TFZ), |
||||||
|
TypicalP: float32(c.TypicalP), |
||||||
|
} |
||||||
|
} |
@ -1,401 +0,0 @@ |
|||||||
package api |
|
||||||
|
|
||||||
import ( |
|
||||||
"encoding/json" |
|
||||||
"fmt" |
|
||||||
"io/fs" |
|
||||||
"os" |
|
||||||
"path/filepath" |
|
||||||
"strings" |
|
||||||
"sync" |
|
||||||
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model" |
|
||||||
"github.com/gofiber/fiber/v2" |
|
||||||
"github.com/rs/zerolog/log" |
|
||||||
"gopkg.in/yaml.v3" |
|
||||||
) |
|
||||||
|
|
||||||
type Config struct { |
|
||||||
OpenAIRequest `yaml:"parameters"` |
|
||||||
Name string `yaml:"name"` |
|
||||||
StopWords []string `yaml:"stopwords"` |
|
||||||
Cutstrings []string `yaml:"cutstrings"` |
|
||||||
TrimSpace []string `yaml:"trimspace"` |
|
||||||
ContextSize int `yaml:"context_size"` |
|
||||||
F16 bool `yaml:"f16"` |
|
||||||
NUMA bool `yaml:"numa"` |
|
||||||
Threads int `yaml:"threads"` |
|
||||||
Debug bool `yaml:"debug"` |
|
||||||
Roles map[string]string `yaml:"roles"` |
|
||||||
Embeddings bool `yaml:"embeddings"` |
|
||||||
Backend string `yaml:"backend"` |
|
||||||
TemplateConfig TemplateConfig `yaml:"template"` |
|
||||||
MirostatETA float64 `yaml:"mirostat_eta"` |
|
||||||
MirostatTAU float64 `yaml:"mirostat_tau"` |
|
||||||
Mirostat int `yaml:"mirostat"` |
|
||||||
NGPULayers int `yaml:"gpu_layers"` |
|
||||||
MMap bool `yaml:"mmap"` |
|
||||||
MMlock bool `yaml:"mmlock"` |
|
||||||
LowVRAM bool `yaml:"low_vram"` |
|
||||||
|
|
||||||
TensorSplit string `yaml:"tensor_split"` |
|
||||||
MainGPU string `yaml:"main_gpu"` |
|
||||||
ImageGenerationAssets string `yaml:"asset_dir"` |
|
||||||
|
|
||||||
PromptCachePath string `yaml:"prompt_cache_path"` |
|
||||||
PromptCacheAll bool `yaml:"prompt_cache_all"` |
|
||||||
PromptCacheRO bool `yaml:"prompt_cache_ro"` |
|
||||||
|
|
||||||
Grammar string `yaml:"grammar"` |
|
||||||
|
|
||||||
FunctionsConfig Functions `yaml:"function"` |
|
||||||
|
|
||||||
PromptStrings, InputStrings []string |
|
||||||
InputToken [][]int |
|
||||||
functionCallString, functionCallNameString string |
|
||||||
} |
|
||||||
|
|
||||||
type Functions struct { |
|
||||||
DisableNoAction bool `yaml:"disable_no_action"` |
|
||||||
NoActionFunctionName string `yaml:"no_action_function_name"` |
|
||||||
NoActionDescriptionName string `yaml:"no_action_description_name"` |
|
||||||
} |
|
||||||
|
|
||||||
type TemplateConfig struct { |
|
||||||
Completion string `yaml:"completion"` |
|
||||||
Functions string `yaml:"function"` |
|
||||||
Chat string `yaml:"chat"` |
|
||||||
Edit string `yaml:"edit"` |
|
||||||
} |
|
||||||
|
|
||||||
type ConfigMerger struct { |
|
||||||
configs map[string]Config |
|
||||||
sync.Mutex |
|
||||||
} |
|
||||||
|
|
||||||
func defaultConfig(modelFile string) *Config { |
|
||||||
return &Config{ |
|
||||||
OpenAIRequest: defaultRequest(modelFile), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func NewConfigMerger() *ConfigMerger { |
|
||||||
return &ConfigMerger{ |
|
||||||
configs: make(map[string]Config), |
|
||||||
} |
|
||||||
} |
|
||||||
func ReadConfigFile(file string) ([]*Config, error) { |
|
||||||
c := &[]*Config{} |
|
||||||
f, err := os.ReadFile(file) |
|
||||||
if err != nil { |
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
|
||||||
} |
|
||||||
if err := yaml.Unmarshal(f, c); err != nil { |
|
||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
return *c, nil |
|
||||||
} |
|
||||||
|
|
||||||
func ReadConfig(file string) (*Config, error) { |
|
||||||
c := &Config{} |
|
||||||
f, err := os.ReadFile(file) |
|
||||||
if err != nil { |
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
|
||||||
} |
|
||||||
if err := yaml.Unmarshal(f, c); err != nil { |
|
||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
return c, nil |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfigFile(file string) error { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
c, err := ReadConfigFile(file) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("cannot load config file: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
for _, cc := range c { |
|
||||||
cm.configs[cc.Name] = *cc |
|
||||||
} |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfig(file string) error { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
c, err := ReadConfig(file) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("cannot read config file: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
cm.configs[c.Name] = *c |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) GetConfig(m string) (Config, bool) { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
v, exists := cm.configs[m] |
|
||||||
return v, exists |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) ListConfigs() []string { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
var res []string |
|
||||||
for k := range cm.configs { |
|
||||||
res = append(res, k) |
|
||||||
} |
|
||||||
return res |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfigs(path string) error { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
entries, err := os.ReadDir(path) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
files := make([]fs.FileInfo, 0, len(entries)) |
|
||||||
for _, entry := range entries { |
|
||||||
info, err := entry.Info() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
files = append(files, info) |
|
||||||
} |
|
||||||
for _, file := range files { |
|
||||||
// Skip templates, YAML and .keep files
|
|
||||||
if !strings.Contains(file.Name(), ".yaml") { |
|
||||||
continue |
|
||||||
} |
|
||||||
c, err := ReadConfig(filepath.Join(path, file.Name())) |
|
||||||
if err == nil { |
|
||||||
cm.configs[c.Name] = *c |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func updateConfig(config *Config, input *OpenAIRequest) { |
|
||||||
if input.Echo { |
|
||||||
config.Echo = input.Echo |
|
||||||
} |
|
||||||
if input.TopK != 0 { |
|
||||||
config.TopK = input.TopK |
|
||||||
} |
|
||||||
if input.TopP != 0 { |
|
||||||
config.TopP = input.TopP |
|
||||||
} |
|
||||||
|
|
||||||
if input.Grammar != "" { |
|
||||||
config.Grammar = input.Grammar |
|
||||||
} |
|
||||||
|
|
||||||
if input.Temperature != 0 { |
|
||||||
config.Temperature = input.Temperature |
|
||||||
} |
|
||||||
|
|
||||||
if input.Maxtokens != 0 { |
|
||||||
config.Maxtokens = input.Maxtokens |
|
||||||
} |
|
||||||
|
|
||||||
switch stop := input.Stop.(type) { |
|
||||||
case string: |
|
||||||
if stop != "" { |
|
||||||
config.StopWords = append(config.StopWords, stop) |
|
||||||
} |
|
||||||
case []interface{}: |
|
||||||
for _, pp := range stop { |
|
||||||
if s, ok := pp.(string); ok { |
|
||||||
config.StopWords = append(config.StopWords, s) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
if input.RepeatPenalty != 0 { |
|
||||||
config.RepeatPenalty = input.RepeatPenalty |
|
||||||
} |
|
||||||
|
|
||||||
if input.Keep != 0 { |
|
||||||
config.Keep = input.Keep |
|
||||||
} |
|
||||||
|
|
||||||
if input.Batch != 0 { |
|
||||||
config.Batch = input.Batch |
|
||||||
} |
|
||||||
|
|
||||||
if input.F16 { |
|
||||||
config.F16 = input.F16 |
|
||||||
} |
|
||||||
|
|
||||||
if input.IgnoreEOS { |
|
||||||
config.IgnoreEOS = input.IgnoreEOS |
|
||||||
} |
|
||||||
|
|
||||||
if input.Seed != 0 { |
|
||||||
config.Seed = input.Seed |
|
||||||
} |
|
||||||
|
|
||||||
if input.Mirostat != 0 { |
|
||||||
config.Mirostat = input.Mirostat |
|
||||||
} |
|
||||||
|
|
||||||
if input.MirostatETA != 0 { |
|
||||||
config.MirostatETA = input.MirostatETA |
|
||||||
} |
|
||||||
|
|
||||||
if input.MirostatTAU != 0 { |
|
||||||
config.MirostatTAU = input.MirostatTAU |
|
||||||
} |
|
||||||
|
|
||||||
if input.TypicalP != 0 { |
|
||||||
config.TypicalP = input.TypicalP |
|
||||||
} |
|
||||||
|
|
||||||
switch inputs := input.Input.(type) { |
|
||||||
case string: |
|
||||||
if inputs != "" { |
|
||||||
config.InputStrings = append(config.InputStrings, inputs) |
|
||||||
} |
|
||||||
case []interface{}: |
|
||||||
for _, pp := range inputs { |
|
||||||
switch i := pp.(type) { |
|
||||||
case string: |
|
||||||
config.InputStrings = append(config.InputStrings, i) |
|
||||||
case []interface{}: |
|
||||||
tokens := []int{} |
|
||||||
for _, ii := range i { |
|
||||||
tokens = append(tokens, int(ii.(float64))) |
|
||||||
} |
|
||||||
config.InputToken = append(config.InputToken, tokens) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
// Can be either a string or an object
|
|
||||||
switch fnc := input.FunctionCall.(type) { |
|
||||||
case string: |
|
||||||
if fnc != "" { |
|
||||||
config.functionCallString = fnc |
|
||||||
} |
|
||||||
case map[string]interface{}: |
|
||||||
var name string |
|
||||||
n, exists := fnc["name"] |
|
||||||
if exists { |
|
||||||
nn, e := n.(string) |
|
||||||
if e { |
|
||||||
name = nn |
|
||||||
} |
|
||||||
} |
|
||||||
config.functionCallNameString = name |
|
||||||
} |
|
||||||
|
|
||||||
switch p := input.Prompt.(type) { |
|
||||||
case string: |
|
||||||
config.PromptStrings = append(config.PromptStrings, p) |
|
||||||
case []interface{}: |
|
||||||
for _, pp := range p { |
|
||||||
if s, ok := pp.(string); ok { |
|
||||||
config.PromptStrings = append(config.PromptStrings, s) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
|
||||||
input := new(OpenAIRequest) |
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil { |
|
||||||
return "", nil, err |
|
||||||
} |
|
||||||
|
|
||||||
modelFile := input.Model |
|
||||||
|
|
||||||
if c.Params("model") != "" { |
|
||||||
modelFile = c.Params("model") |
|
||||||
} |
|
||||||
|
|
||||||
received, _ := json.Marshal(input) |
|
||||||
|
|
||||||
log.Debug().Msgf("Request received: %s", string(received)) |
|
||||||
|
|
||||||
// Set model from bearer token, if available
|
|
||||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
|
||||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
|
||||||
|
|
||||||
// If no model was specified, take the first available
|
|
||||||
if modelFile == "" && !bearerExists && randomModel { |
|
||||||
models, _ := loader.ListModels() |
|
||||||
if len(models) > 0 { |
|
||||||
modelFile = models[0] |
|
||||||
log.Debug().Msgf("No model specified, using: %s", modelFile) |
|
||||||
} else { |
|
||||||
log.Debug().Msgf("No model specified, returning error") |
|
||||||
return "", nil, fmt.Errorf("no model specified") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// If a model is found in bearer token takes precedence
|
|
||||||
if bearerExists { |
|
||||||
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
|
||||||
modelFile = bearer |
|
||||||
} |
|
||||||
return modelFile, input, nil |
|
||||||
} |
|
||||||
|
|
||||||
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { |
|
||||||
// Load a config file if present after the model name
|
|
||||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
|
||||||
|
|
||||||
var config *Config |
|
||||||
|
|
||||||
defaults := func() { |
|
||||||
config = defaultConfig(modelFile) |
|
||||||
config.ContextSize = ctx |
|
||||||
config.Threads = threads |
|
||||||
config.F16 = f16 |
|
||||||
config.Debug = debug |
|
||||||
} |
|
||||||
|
|
||||||
cfg, exists := cm.GetConfig(modelFile) |
|
||||||
if !exists { |
|
||||||
if _, err := os.Stat(modelConfig); err == nil { |
|
||||||
if err := cm.LoadConfig(modelConfig); err != nil { |
|
||||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
|
||||||
} |
|
||||||
cfg, exists = cm.GetConfig(modelFile) |
|
||||||
if exists { |
|
||||||
config = &cfg |
|
||||||
} else { |
|
||||||
defaults() |
|
||||||
} |
|
||||||
} else { |
|
||||||
defaults() |
|
||||||
} |
|
||||||
} else { |
|
||||||
config = &cfg |
|
||||||
} |
|
||||||
|
|
||||||
// Set the parameters for the language model prediction
|
|
||||||
updateConfig(config, input) |
|
||||||
|
|
||||||
// Don't allow 0 as setting
|
|
||||||
if config.Threads == 0 { |
|
||||||
if threads != 0 { |
|
||||||
config.Threads = threads |
|
||||||
} else { |
|
||||||
config.Threads = 4 |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// Enforce debug flag if passed from CLI
|
|
||||||
if debug { |
|
||||||
config.Debug = true |
|
||||||
} |
|
||||||
|
|
||||||
return config, input, nil |
|
||||||
} |
|
@ -0,0 +1,209 @@ |
|||||||
|
package api_config |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"io/fs" |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"gopkg.in/yaml.v3" |
||||||
|
) |
||||||
|
|
||||||
|
type Config struct { |
||||||
|
PredictionOptions `yaml:"parameters"` |
||||||
|
Name string `yaml:"name"` |
||||||
|
StopWords []string `yaml:"stopwords"` |
||||||
|
Cutstrings []string `yaml:"cutstrings"` |
||||||
|
TrimSpace []string `yaml:"trimspace"` |
||||||
|
ContextSize int `yaml:"context_size"` |
||||||
|
F16 bool `yaml:"f16"` |
||||||
|
NUMA bool `yaml:"numa"` |
||||||
|
Threads int `yaml:"threads"` |
||||||
|
Debug bool `yaml:"debug"` |
||||||
|
Roles map[string]string `yaml:"roles"` |
||||||
|
Embeddings bool `yaml:"embeddings"` |
||||||
|
Backend string `yaml:"backend"` |
||||||
|
TemplateConfig TemplateConfig `yaml:"template"` |
||||||
|
MirostatETA float64 `yaml:"mirostat_eta"` |
||||||
|
MirostatTAU float64 `yaml:"mirostat_tau"` |
||||||
|
Mirostat int `yaml:"mirostat"` |
||||||
|
NGPULayers int `yaml:"gpu_layers"` |
||||||
|
MMap bool `yaml:"mmap"` |
||||||
|
MMlock bool `yaml:"mmlock"` |
||||||
|
LowVRAM bool `yaml:"low_vram"` |
||||||
|
|
||||||
|
TensorSplit string `yaml:"tensor_split"` |
||||||
|
MainGPU string `yaml:"main_gpu"` |
||||||
|
ImageGenerationAssets string `yaml:"asset_dir"` |
||||||
|
|
||||||
|
PromptCachePath string `yaml:"prompt_cache_path"` |
||||||
|
PromptCacheAll bool `yaml:"prompt_cache_all"` |
||||||
|
PromptCacheRO bool `yaml:"prompt_cache_ro"` |
||||||
|
|
||||||
|
Grammar string `yaml:"grammar"` |
||||||
|
|
||||||
|
PromptStrings, InputStrings []string |
||||||
|
InputToken [][]int |
||||||
|
functionCallString, functionCallNameString string |
||||||
|
|
||||||
|
FunctionsConfig Functions `yaml:"function"` |
||||||
|
} |
||||||
|
|
||||||
|
type Functions struct { |
||||||
|
DisableNoAction bool `yaml:"disable_no_action"` |
||||||
|
NoActionFunctionName string `yaml:"no_action_function_name"` |
||||||
|
NoActionDescriptionName string `yaml:"no_action_description_name"` |
||||||
|
} |
||||||
|
|
||||||
|
type TemplateConfig struct { |
||||||
|
Completion string `yaml:"completion"` |
||||||
|
Functions string `yaml:"function"` |
||||||
|
Chat string `yaml:"chat"` |
||||||
|
Edit string `yaml:"edit"` |
||||||
|
} |
||||||
|
|
||||||
|
type ConfigLoader struct { |
||||||
|
configs map[string]Config |
||||||
|
sync.Mutex |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) SetFunctionCallString(s string) { |
||||||
|
c.functionCallString = s |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) SetFunctionCallNameString(s string) { |
||||||
|
c.functionCallNameString = s |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) ShouldUseFunctions() bool { |
||||||
|
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) ShouldCallSpecificFunction() bool { |
||||||
|
return len(c.functionCallNameString) > 0 |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) FunctionToCall() string { |
||||||
|
return c.functionCallNameString |
||||||
|
} |
||||||
|
|
||||||
|
func defaultPredictOptions(modelFile string) PredictionOptions { |
||||||
|
return PredictionOptions{ |
||||||
|
TopP: 0.7, |
||||||
|
TopK: 80, |
||||||
|
Maxtokens: 512, |
||||||
|
Temperature: 0.9, |
||||||
|
Model: modelFile, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func DefaultConfig(modelFile string) *Config { |
||||||
|
return &Config{ |
||||||
|
PredictionOptions: defaultPredictOptions(modelFile), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func NewConfigLoader() *ConfigLoader { |
||||||
|
return &ConfigLoader{ |
||||||
|
configs: make(map[string]Config), |
||||||
|
} |
||||||
|
} |
||||||
|
func ReadConfigFile(file string) ([]*Config, error) { |
||||||
|
c := &[]*Config{} |
||||||
|
f, err := os.ReadFile(file) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||||
|
} |
||||||
|
if err := yaml.Unmarshal(f, c); err != nil { |
||||||
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
return *c, nil |
||||||
|
} |
||||||
|
|
||||||
|
func ReadConfig(file string) (*Config, error) { |
||||||
|
c := &Config{} |
||||||
|
f, err := os.ReadFile(file) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||||
|
} |
||||||
|
if err := yaml.Unmarshal(f, c); err != nil { |
||||||
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
return c, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfigFile(file string) error { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
c, err := ReadConfigFile(file) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("cannot load config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
for _, cc := range c { |
||||||
|
cm.configs[cc.Name] = *cc |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfig(file string) error { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
c, err := ReadConfig(file) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("cannot read config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
cm.configs[c.Name] = *c |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
v, exists := cm.configs[m] |
||||||
|
return v, exists |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) ListConfigs() []string { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
var res []string |
||||||
|
for k := range cm.configs { |
||||||
|
res = append(res, k) |
||||||
|
} |
||||||
|
return res |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfigs(path string) error { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
entries, err := os.ReadDir(path) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
files := make([]fs.FileInfo, 0, len(entries)) |
||||||
|
for _, entry := range entries { |
||||||
|
info, err := entry.Info() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
files = append(files, info) |
||||||
|
} |
||||||
|
for _, file := range files { |
||||||
|
// Skip templates, YAML and .keep files
|
||||||
|
if !strings.Contains(file.Name(), ".yaml") { |
||||||
|
continue |
||||||
|
} |
||||||
|
c, err := ReadConfig(filepath.Join(path, file.Name())) |
||||||
|
if err == nil { |
||||||
|
cm.configs[c.Name] = *c |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,37 @@ |
|||||||
|
package api_config |
||||||
|
|
||||||
|
type PredictionOptions struct { |
||||||
|
|
||||||
|
// Also part of the OpenAI official spec
|
||||||
|
Model string `json:"model" yaml:"model"` |
||||||
|
|
||||||
|
// Also part of the OpenAI official spec
|
||||||
|
Language string `json:"language"` |
||||||
|
|
||||||
|
// Also part of the OpenAI official spec. use it for returning multiple results
|
||||||
|
N int `json:"n"` |
||||||
|
|
||||||
|
// Common options between all the API calls, part of the OpenAI spec
|
||||||
|
TopP float64 `json:"top_p" yaml:"top_p"` |
||||||
|
TopK int `json:"top_k" yaml:"top_k"` |
||||||
|
Temperature float64 `json:"temperature" yaml:"temperature"` |
||||||
|
Maxtokens int `json:"max_tokens" yaml:"max_tokens"` |
||||||
|
Echo bool `json:"echo"` |
||||||
|
|
||||||
|
// Custom parameters - not present in the OpenAI API
|
||||||
|
Batch int `json:"batch" yaml:"batch"` |
||||||
|
F16 bool `json:"f16" yaml:"f16"` |
||||||
|
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` |
||||||
|
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` |
||||||
|
Keep int `json:"n_keep" yaml:"n_keep"` |
||||||
|
|
||||||
|
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` |
||||||
|
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` |
||||||
|
Mirostat int `json:"mirostat" yaml:"mirostat"` |
||||||
|
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` |
||||||
|
TFZ float64 `json:"tfz" yaml:"tfz"` |
||||||
|
|
||||||
|
TypicalP float64 `json:"typical_p" yaml:"typical_p"` |
||||||
|
Seed int `json:"seed" yaml:"seed"` |
||||||
|
} |
@ -1,78 +0,0 @@ |
|||||||
package api |
|
||||||
|
|
||||||
import ( |
|
||||||
"fmt" |
|
||||||
"os" |
|
||||||
"path/filepath" |
|
||||||
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/tts" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/utils" |
|
||||||
llama "github.com/go-skynet/go-llama.cpp" |
|
||||||
"github.com/gofiber/fiber/v2" |
|
||||||
) |
|
||||||
|
|
||||||
type TTSRequest struct { |
|
||||||
Model string `json:"model" yaml:"model"` |
|
||||||
Input string `json:"input" yaml:"input"` |
|
||||||
} |
|
||||||
|
|
||||||
func generateUniqueFileName(dir, baseName, ext string) string { |
|
||||||
counter := 1 |
|
||||||
fileName := baseName + ext |
|
||||||
|
|
||||||
for { |
|
||||||
filePath := filepath.Join(dir, fileName) |
|
||||||
_, err := os.Stat(filePath) |
|
||||||
if os.IsNotExist(err) { |
|
||||||
return fileName |
|
||||||
} |
|
||||||
|
|
||||||
counter++ |
|
||||||
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
|
|
||||||
input := new(TTSRequest) |
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
piperModel, err := o.loader.BackendLoader(model.PiperBackend, input.Model, []llama.ModelOption{}, uint32(0), o.assetsDestination) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
if piperModel == nil { |
|
||||||
return fmt.Errorf("could not load piper model") |
|
||||||
} |
|
||||||
|
|
||||||
w, ok := piperModel.(*tts.Piper) |
|
||||||
if !ok { |
|
||||||
return fmt.Errorf("loader returned non-piper object %+v", w) |
|
||||||
} |
|
||||||
|
|
||||||
if err := os.MkdirAll(o.audioDir, 0755); err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
fileName := generateUniqueFileName(o.audioDir, "piper", ".wav") |
|
||||||
filePath := filepath.Join(o.audioDir, fileName) |
|
||||||
|
|
||||||
modelPath := filepath.Join(o.loader.ModelPath, input.Model) |
|
||||||
|
|
||||||
if err := utils.VerifyPath(modelPath, o.loader.ModelPath); err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
if err := w.TTS(input.Input, modelPath, filePath); err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
return c.Download(filePath) |
|
||||||
} |
|
||||||
} |
|
@ -0,0 +1,84 @@ |
|||||||
|
package localai |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/utils" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
) |
||||||
|
|
||||||
|
type TTSRequest struct { |
||||||
|
Model string `json:"model" yaml:"model"` |
||||||
|
Input string `json:"input" yaml:"input"` |
||||||
|
} |
||||||
|
|
||||||
|
func generateUniqueFileName(dir, baseName, ext string) string { |
||||||
|
counter := 1 |
||||||
|
fileName := baseName + ext |
||||||
|
|
||||||
|
for { |
||||||
|
filePath := filepath.Join(dir, fileName) |
||||||
|
_, err := os.Stat(filePath) |
||||||
|
if os.IsNotExist(err) { |
||||||
|
return fileName |
||||||
|
} |
||||||
|
|
||||||
|
counter++ |
||||||
|
fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
|
||||||
|
input := new(TTSRequest) |
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
piperModel, err := o.Loader.BackendLoader( |
||||||
|
model.WithBackendString(model.PiperBackend), |
||||||
|
model.WithModelFile(input.Model), |
||||||
|
model.WithContext(o.Context), |
||||||
|
model.WithAssetDir(o.AssetsDestination)) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if piperModel == nil { |
||||||
|
return fmt.Errorf("could not load piper model") |
||||||
|
} |
||||||
|
|
||||||
|
if err := os.MkdirAll(o.AudioDir, 0755); err != nil { |
||||||
|
return fmt.Errorf("failed creating audio directory: %s", err) |
||||||
|
} |
||||||
|
|
||||||
|
fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav") |
||||||
|
filePath := filepath.Join(o.AudioDir, fileName) |
||||||
|
|
||||||
|
modelPath := filepath.Join(o.Loader.ModelPath, input.Model) |
||||||
|
|
||||||
|
if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{ |
||||||
|
Text: input.Input, |
||||||
|
Model: modelPath, |
||||||
|
Dst: filePath, |
||||||
|
}); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
return c.Download(filePath) |
||||||
|
} |
||||||
|
} |
@ -1,961 +0,0 @@ |
|||||||
package api |
|
||||||
|
|
||||||
import ( |
|
||||||
"bufio" |
|
||||||
"bytes" |
|
||||||
"encoding/base64" |
|
||||||
"encoding/json" |
|
||||||
"errors" |
|
||||||
"fmt" |
|
||||||
"io" |
|
||||||
"io/ioutil" |
|
||||||
"net/http" |
|
||||||
"os" |
|
||||||
"path" |
|
||||||
"path/filepath" |
|
||||||
"strconv" |
|
||||||
"strings" |
|
||||||
|
|
||||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model" |
|
||||||
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" |
|
||||||
llama "github.com/go-skynet/go-llama.cpp" |
|
||||||
"github.com/gofiber/fiber/v2" |
|
||||||
"github.com/rs/zerolog/log" |
|
||||||
"github.com/valyala/fasthttp" |
|
||||||
) |
|
||||||
|
|
||||||
// APIError provides error information returned by the OpenAI API.
|
|
||||||
type APIError struct { |
|
||||||
Code any `json:"code,omitempty"` |
|
||||||
Message string `json:"message"` |
|
||||||
Param *string `json:"param,omitempty"` |
|
||||||
Type string `json:"type"` |
|
||||||
} |
|
||||||
|
|
||||||
type ErrorResponse struct { |
|
||||||
Error *APIError `json:"error,omitempty"` |
|
||||||
} |
|
||||||
|
|
||||||
type OpenAIUsage struct { |
|
||||||
PromptTokens int `json:"prompt_tokens"` |
|
||||||
CompletionTokens int `json:"completion_tokens"` |
|
||||||
TotalTokens int `json:"total_tokens"` |
|
||||||
} |
|
||||||
|
|
||||||
type Item struct { |
|
||||||
Embedding []float32 `json:"embedding"` |
|
||||||
Index int `json:"index"` |
|
||||||
Object string `json:"object,omitempty"` |
|
||||||
|
|
||||||
// Images
|
|
||||||
URL string `json:"url,omitempty"` |
|
||||||
B64JSON string `json:"b64_json,omitempty"` |
|
||||||
} |
|
||||||
|
|
||||||
type OpenAIResponse struct { |
|
||||||
Created int `json:"created,omitempty"` |
|
||||||
Object string `json:"object,omitempty"` |
|
||||||
ID string `json:"id,omitempty"` |
|
||||||
Model string `json:"model,omitempty"` |
|
||||||
Choices []Choice `json:"choices,omitempty"` |
|
||||||
Data []Item `json:"data,omitempty"` |
|
||||||
|
|
||||||
Usage OpenAIUsage `json:"usage"` |
|
||||||
} |
|
||||||
|
|
||||||
type Choice struct { |
|
||||||
Index int `json:"index,omitempty"` |
|
||||||
FinishReason string `json:"finish_reason,omitempty"` |
|
||||||
Message *Message `json:"message,omitempty"` |
|
||||||
Delta *Message `json:"delta,omitempty"` |
|
||||||
Text string `json:"text,omitempty"` |
|
||||||
} |
|
||||||
|
|
||||||
type Message struct { |
|
||||||
// The message role
|
|
||||||
Role string `json:"role,omitempty" yaml:"role"` |
|
||||||
// The message content
|
|
||||||
Content *string `json:"content" yaml:"content"` |
|
||||||
// A result of a function call
|
|
||||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` |
|
||||||
} |
|
||||||
|
|
||||||
type OpenAIModel struct { |
|
||||||
ID string `json:"id"` |
|
||||||
Object string `json:"object"` |
|
||||||
} |
|
||||||
|
|
||||||
type OpenAIRequest struct { |
|
||||||
Model string `json:"model" yaml:"model"` |
|
||||||
|
|
||||||
// whisper
|
|
||||||
File string `json:"file" validate:"required"` |
|
||||||
Language string `json:"language"` |
|
||||||
//whisper/image
|
|
||||||
ResponseFormat string `json:"response_format"` |
|
||||||
// image
|
|
||||||
Size string `json:"size"` |
|
||||||
// Prompt is read only by completion/image API calls
|
|
||||||
Prompt interface{} `json:"prompt" yaml:"prompt"` |
|
||||||
|
|
||||||
// Edit endpoint
|
|
||||||
Instruction string `json:"instruction" yaml:"instruction"` |
|
||||||
Input interface{} `json:"input" yaml:"input"` |
|
||||||
|
|
||||||
Stop interface{} `json:"stop" yaml:"stop"` |
|
||||||
|
|
||||||
// Messages is read only by chat/completion API calls
|
|
||||||
Messages []Message `json:"messages" yaml:"messages"` |
|
||||||
|
|
||||||
// A list of available functions to call
|
|
||||||
Functions []grammar.Function `json:"functions" yaml:"functions"` |
|
||||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
|
||||||
|
|
||||||
Stream bool `json:"stream"` |
|
||||||
Echo bool `json:"echo"` |
|
||||||
// Common options between all the API calls
|
|
||||||
TopP float64 `json:"top_p" yaml:"top_p"` |
|
||||||
TopK int `json:"top_k" yaml:"top_k"` |
|
||||||
Temperature float64 `json:"temperature" yaml:"temperature"` |
|
||||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"` |
|
||||||
|
|
||||||
N int `json:"n"` |
|
||||||
|
|
||||||
// Custom parameters - not present in the OpenAI API
|
|
||||||
Batch int `json:"batch" yaml:"batch"` |
|
||||||
F16 bool `json:"f16" yaml:"f16"` |
|
||||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` |
|
||||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` |
|
||||||
Keep int `json:"n_keep" yaml:"n_keep"` |
|
||||||
|
|
||||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` |
|
||||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` |
|
||||||
Mirostat int `json:"mirostat" yaml:"mirostat"` |
|
||||||
|
|
||||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` |
|
||||||
TFZ float64 `json:"tfz" yaml:"tfz"` |
|
||||||
|
|
||||||
Seed int `json:"seed" yaml:"seed"` |
|
||||||
|
|
||||||
// Image (not supported by OpenAI)
|
|
||||||
Mode int `json:"mode"` |
|
||||||
Step int `json:"step"` |
|
||||||
|
|
||||||
// A grammar to constrain the LLM output
|
|
||||||
Grammar string `json:"grammar" yaml:"grammar"` |
|
||||||
// A grammar object
|
|
||||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` |
|
||||||
|
|
||||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"` |
|
||||||
} |
|
||||||
|
|
||||||
func defaultRequest(modelFile string) OpenAIRequest { |
|
||||||
return OpenAIRequest{ |
|
||||||
TopP: 0.7, |
|
||||||
TopK: 80, |
|
||||||
Maxtokens: 512, |
|
||||||
Temperature: 0.9, |
|
||||||
Model: modelFile, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/completions
|
|
||||||
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
|
||||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
|
||||||
resp := OpenAIResponse{ |
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{ |
|
||||||
{ |
|
||||||
Index: 0, |
|
||||||
Text: s, |
|
||||||
}, |
|
||||||
}, |
|
||||||
Object: "text_completion", |
|
||||||
} |
|
||||||
log.Debug().Msgf("Sending goroutine: %s", s) |
|
||||||
|
|
||||||
responses <- resp |
|
||||||
return true |
|
||||||
}) |
|
||||||
close(responses) |
|
||||||
} |
|
||||||
|
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
model, input, err := readInput(c, o.loader, true) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("`input`: %+v", input) |
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config) |
|
||||||
|
|
||||||
if input.Stream { |
|
||||||
log.Debug().Msgf("Stream request received") |
|
||||||
c.Context().SetContentType("text/event-stream") |
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
||||||
//c.Set("Content-Type", "text/event-stream")
|
|
||||||
c.Set("Cache-Control", "no-cache") |
|
||||||
c.Set("Connection", "keep-alive") |
|
||||||
c.Set("Transfer-Encoding", "chunked") |
|
||||||
} |
|
||||||
|
|
||||||
templateFile := config.Model |
|
||||||
|
|
||||||
if config.TemplateConfig.Completion != "" { |
|
||||||
templateFile = config.TemplateConfig.Completion |
|
||||||
} |
|
||||||
|
|
||||||
if input.Stream { |
|
||||||
if len(config.PromptStrings) > 1 { |
|
||||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") |
|
||||||
} |
|
||||||
|
|
||||||
predInput := config.PromptStrings[0] |
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
|
||||||
Input string |
|
||||||
}{ |
|
||||||
Input: predInput, |
|
||||||
}) |
|
||||||
if err == nil { |
|
||||||
predInput = templatedInput |
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
|
||||||
} |
|
||||||
|
|
||||||
responses := make(chan OpenAIResponse) |
|
||||||
|
|
||||||
go process(predInput, input, config, o.loader, responses) |
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
|
||||||
|
|
||||||
for ev := range responses { |
|
||||||
var buf bytes.Buffer |
|
||||||
enc := json.NewEncoder(&buf) |
|
||||||
enc.Encode(ev) |
|
||||||
|
|
||||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
|
||||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
|
||||||
w.Flush() |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{ |
|
||||||
{ |
|
||||||
Index: 0, |
|
||||||
FinishReason: "stop", |
|
||||||
}, |
|
||||||
}, |
|
||||||
Object: "text_completion", |
|
||||||
} |
|
||||||
respData, _ := json.Marshal(resp) |
|
||||||
|
|
||||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
|
||||||
w.WriteString("data: [DONE]\n\n") |
|
||||||
w.Flush() |
|
||||||
})) |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
var result []Choice |
|
||||||
for _, i := range config.PromptStrings { |
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
|
||||||
Input string |
|
||||||
}{ |
|
||||||
Input: i, |
|
||||||
}) |
|
||||||
if err == nil { |
|
||||||
i = templatedInput |
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
|
||||||
} |
|
||||||
|
|
||||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { |
|
||||||
*c = append(*c, Choice{Text: s}) |
|
||||||
}, nil) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
result = append(result, r...) |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result, |
|
||||||
Object: "text_completion", |
|
||||||
} |
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", jsonResult) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/embeddings
|
|
||||||
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
model, input, err := readInput(c, o.loader, true) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config) |
|
||||||
items := []Item{} |
|
||||||
|
|
||||||
for i, s := range config.InputToken { |
|
||||||
// get the model function to call for the result
|
|
||||||
embedFn, err := ModelEmbedding("", s, o.loader, *config, o) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
embeddings, err := embedFn() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
|
||||||
} |
|
||||||
|
|
||||||
for i, s := range config.InputStrings { |
|
||||||
// get the model function to call for the result
|
|
||||||
embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
embeddings, err := embedFn() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Data: items, |
|
||||||
Object: "list", |
|
||||||
} |
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", jsonResult) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
|
|
||||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
|
||||||
initialMessage := OpenAIResponse{ |
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, |
|
||||||
Object: "chat.completion.chunk", |
|
||||||
} |
|
||||||
responses <- initialMessage |
|
||||||
|
|
||||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
|
||||||
resp := OpenAIResponse{ |
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, |
|
||||||
Object: "chat.completion.chunk", |
|
||||||
} |
|
||||||
log.Debug().Msgf("Sending goroutine: %s", s) |
|
||||||
|
|
||||||
responses <- resp |
|
||||||
return true |
|
||||||
}) |
|
||||||
close(responses) |
|
||||||
} |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
processFunctions := false |
|
||||||
funcs := grammar.Functions{} |
|
||||||
model, input, err := readInput(c, o.loader, true) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
log.Debug().Msgf("Configuration read: %+v", config) |
|
||||||
|
|
||||||
// Allow the user to set custom actions via config file
|
|
||||||
// to be "embedded" in each model
|
|
||||||
noActionName := "answer" |
|
||||||
noActionDescription := "use this action to answer without performing any action" |
|
||||||
|
|
||||||
if config.FunctionsConfig.NoActionFunctionName != "" { |
|
||||||
noActionName = config.FunctionsConfig.NoActionFunctionName |
|
||||||
} |
|
||||||
if config.FunctionsConfig.NoActionDescriptionName != "" { |
|
||||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName |
|
||||||
} |
|
||||||
|
|
||||||
// process functions if we have any defined or if we have a function call string
|
|
||||||
if len(input.Functions) > 0 && |
|
||||||
((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) { |
|
||||||
log.Debug().Msgf("Response needs to process functions") |
|
||||||
|
|
||||||
processFunctions = true |
|
||||||
|
|
||||||
noActionGrammar := grammar.Function{ |
|
||||||
Name: noActionName, |
|
||||||
Description: noActionDescription, |
|
||||||
Parameters: map[string]interface{}{ |
|
||||||
"properties": map[string]interface{}{ |
|
||||||
"message": map[string]interface{}{ |
|
||||||
"type": "string", |
|
||||||
"description": "The message to reply the user with", |
|
||||||
}}, |
|
||||||
}, |
|
||||||
} |
|
||||||
|
|
||||||
// Append the no action function
|
|
||||||
funcs = append(funcs, input.Functions...) |
|
||||||
if !config.FunctionsConfig.DisableNoAction { |
|
||||||
funcs = append(funcs, noActionGrammar) |
|
||||||
} |
|
||||||
|
|
||||||
// Force picking one of the functions by the request
|
|
||||||
if config.functionCallNameString != "" { |
|
||||||
funcs = funcs.Select(config.functionCallNameString) |
|
||||||
} |
|
||||||
|
|
||||||
// Update input grammar
|
|
||||||
jsStruct := funcs.ToJSONStructure() |
|
||||||
config.Grammar = jsStruct.Grammar("") |
|
||||||
} else if input.JSONFunctionGrammarObject != nil { |
|
||||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") |
|
||||||
} |
|
||||||
|
|
||||||
// functions are not supported in stream mode (yet?)
|
|
||||||
toStream := input.Stream && !processFunctions |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameters: %+v", config) |
|
||||||
|
|
||||||
var predInput string |
|
||||||
|
|
||||||
mess := []string{} |
|
||||||
for _, i := range input.Messages { |
|
||||||
var content string |
|
||||||
role := i.Role |
|
||||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
|
||||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
|
||||||
if i.FunctionCall != nil && i.Role == "assistant" { |
|
||||||
roleFn := "assistant_function_call" |
|
||||||
r := config.Roles[roleFn] |
|
||||||
if r != "" { |
|
||||||
role = roleFn |
|
||||||
} |
|
||||||
} |
|
||||||
r := config.Roles[role] |
|
||||||
contentExists := i.Content != nil && *i.Content != "" |
|
||||||
if r != "" { |
|
||||||
if contentExists { |
|
||||||
content = fmt.Sprint(r, " ", *i.Content) |
|
||||||
} |
|
||||||
if i.FunctionCall != nil { |
|
||||||
j, err := json.Marshal(i.FunctionCall) |
|
||||||
if err == nil { |
|
||||||
if contentExists { |
|
||||||
content += "\n" + fmt.Sprint(r, " ", string(j)) |
|
||||||
} else { |
|
||||||
content = fmt.Sprint(r, " ", string(j)) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} else { |
|
||||||
if contentExists { |
|
||||||
content = fmt.Sprint(*i.Content) |
|
||||||
} |
|
||||||
if i.FunctionCall != nil { |
|
||||||
j, err := json.Marshal(i.FunctionCall) |
|
||||||
if err == nil { |
|
||||||
if contentExists { |
|
||||||
content += "\n" + string(j) |
|
||||||
} else { |
|
||||||
content = string(j) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
mess = append(mess, content) |
|
||||||
} |
|
||||||
|
|
||||||
predInput = strings.Join(mess, "\n") |
|
||||||
log.Debug().Msgf("Prompt (before templating): %s", predInput) |
|
||||||
|
|
||||||
if toStream { |
|
||||||
log.Debug().Msgf("Stream request received") |
|
||||||
c.Context().SetContentType("text/event-stream") |
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
||||||
// c.Set("Content-Type", "text/event-stream")
|
|
||||||
c.Set("Cache-Control", "no-cache") |
|
||||||
c.Set("Connection", "keep-alive") |
|
||||||
c.Set("Transfer-Encoding", "chunked") |
|
||||||
} |
|
||||||
|
|
||||||
templateFile := config.Model |
|
||||||
|
|
||||||
if config.TemplateConfig.Chat != "" && !processFunctions { |
|
||||||
templateFile = config.TemplateConfig.Chat |
|
||||||
} |
|
||||||
|
|
||||||
if config.TemplateConfig.Functions != "" && processFunctions { |
|
||||||
templateFile = config.TemplateConfig.Functions |
|
||||||
} |
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
|
||||||
Input string |
|
||||||
Functions []grammar.Function |
|
||||||
}{ |
|
||||||
Input: predInput, |
|
||||||
Functions: funcs, |
|
||||||
}) |
|
||||||
if err == nil { |
|
||||||
predInput = templatedInput |
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
|
||||||
} else { |
|
||||||
log.Debug().Msgf("Template failed loading: %s", err.Error()) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Prompt (after templating): %s", predInput) |
|
||||||
if processFunctions { |
|
||||||
log.Debug().Msgf("Grammar: %+v", config.Grammar) |
|
||||||
} |
|
||||||
|
|
||||||
if toStream { |
|
||||||
responses := make(chan OpenAIResponse) |
|
||||||
|
|
||||||
go process(predInput, input, config, o.loader, responses) |
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
|
||||||
|
|
||||||
for ev := range responses { |
|
||||||
var buf bytes.Buffer |
|
||||||
enc := json.NewEncoder(&buf) |
|
||||||
enc.Encode(ev) |
|
||||||
|
|
||||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
|
||||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
|
||||||
w.Flush() |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{ |
|
||||||
{ |
|
||||||
FinishReason: "stop", |
|
||||||
Index: 0, |
|
||||||
Delta: &Message{}, |
|
||||||
}}, |
|
||||||
Object: "chat.completion.chunk", |
|
||||||
} |
|
||||||
respData, _ := json.Marshal(resp) |
|
||||||
|
|
||||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
|
||||||
w.WriteString("data: [DONE]\n\n") |
|
||||||
w.Flush() |
|
||||||
})) |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) { |
|
||||||
if processFunctions { |
|
||||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
|
||||||
ss := map[string]interface{}{} |
|
||||||
json.Unmarshal([]byte(s), &ss) |
|
||||||
log.Debug().Msgf("Function return: %s %+v", s, ss) |
|
||||||
|
|
||||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
|
||||||
func_name := ss["function"] |
|
||||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
|
||||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
|
||||||
d, _ := json.Marshal(args) |
|
||||||
|
|
||||||
ss["arguments"] = string(d) |
|
||||||
ss["name"] = func_name |
|
||||||
|
|
||||||
// if do nothing, reply with a message
|
|
||||||
if func_name == noActionName { |
|
||||||
log.Debug().Msgf("nothing to do, computing a reply") |
|
||||||
|
|
||||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
|
||||||
arguments := map[string]interface{}{} |
|
||||||
json.Unmarshal([]byte(d), &arguments) |
|
||||||
m, exists := arguments["message"] |
|
||||||
if exists { |
|
||||||
switch message := m.(type) { |
|
||||||
case string: |
|
||||||
if message != "" { |
|
||||||
log.Debug().Msgf("Reply received from LLM: %s", message) |
|
||||||
message = Finetune(*config, predInput, message) |
|
||||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) |
|
||||||
|
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) |
|
||||||
return |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply") |
|
||||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
|
||||||
// Note: This costs (in term of CPU) another computation
|
|
||||||
config.Grammar = "" |
|
||||||
predFunc, err := ModelInference(predInput, o.loader, *config, o, nil) |
|
||||||
if err != nil { |
|
||||||
log.Error().Msgf("inference error: %s", err.Error()) |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
prediction, err := predFunc() |
|
||||||
if err != nil { |
|
||||||
log.Error().Msgf("inference error: %s", err.Error()) |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
prediction = Finetune(*config, predInput, prediction) |
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) |
|
||||||
} else { |
|
||||||
// otherwise reply with the function call
|
|
||||||
*c = append(*c, Choice{ |
|
||||||
FinishReason: "function_call", |
|
||||||
Message: &Message{Role: "assistant", FunctionCall: ss}, |
|
||||||
}) |
|
||||||
} |
|
||||||
|
|
||||||
return |
|
||||||
} |
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) |
|
||||||
}, nil) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result, |
|
||||||
Object: "chat.completion", |
|
||||||
} |
|
||||||
respData, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", respData) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
model, input, err := readInput(c, o.loader, true) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config) |
|
||||||
|
|
||||||
templateFile := config.Model |
|
||||||
|
|
||||||
if config.TemplateConfig.Edit != "" { |
|
||||||
templateFile = config.TemplateConfig.Edit |
|
||||||
} |
|
||||||
|
|
||||||
var result []Choice |
|
||||||
for _, i := range config.InputStrings { |
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
|
||||||
Input string |
|
||||||
Instruction string |
|
||||||
}{Input: i}) |
|
||||||
if err == nil { |
|
||||||
i = templatedInput |
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
|
||||||
} |
|
||||||
|
|
||||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { |
|
||||||
*c = append(*c, Choice{Text: s}) |
|
||||||
}, nil) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
result = append(result, r...) |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result, |
|
||||||
Object: "edit", |
|
||||||
} |
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", jsonResult) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/images/create
|
|
||||||
|
|
||||||
/* |
|
||||||
* |
|
||||||
|
|
||||||
curl http://localhost:8080/v1/images/generations \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{ |
|
||||||
"prompt": "A cute baby sea otter", |
|
||||||
"n": 1, |
|
||||||
"size": "512x512" |
|
||||||
}' |
|
||||||
|
|
||||||
* |
|
||||||
*/ |
|
||||||
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
m, input, err := readInput(c, o.loader, false) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
if m == "" { |
|
||||||
m = model.StableDiffusionBackend |
|
||||||
} |
|
||||||
log.Debug().Msgf("Loading model: %+v", m) |
|
||||||
|
|
||||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config) |
|
||||||
|
|
||||||
// XXX: Only stablediffusion is supported for now
|
|
||||||
if config.Backend == "" { |
|
||||||
config.Backend = model.StableDiffusionBackend |
|
||||||
} |
|
||||||
|
|
||||||
sizeParts := strings.Split(input.Size, "x") |
|
||||||
if len(sizeParts) != 2 { |
|
||||||
return fmt.Errorf("Invalid value for 'size'") |
|
||||||
} |
|
||||||
width, err := strconv.Atoi(sizeParts[0]) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("Invalid value for 'size'") |
|
||||||
} |
|
||||||
height, err := strconv.Atoi(sizeParts[1]) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("Invalid value for 'size'") |
|
||||||
} |
|
||||||
|
|
||||||
b64JSON := false |
|
||||||
if input.ResponseFormat == "b64_json" { |
|
||||||
b64JSON = true |
|
||||||
} |
|
||||||
|
|
||||||
var result []Item |
|
||||||
for _, i := range config.PromptStrings { |
|
||||||
n := input.N |
|
||||||
if input.N == 0 { |
|
||||||
n = 1 |
|
||||||
} |
|
||||||
for j := 0; j < n; j++ { |
|
||||||
prompts := strings.Split(i, "|") |
|
||||||
positive_prompt := prompts[0] |
|
||||||
negative_prompt := "" |
|
||||||
if len(prompts) > 1 { |
|
||||||
negative_prompt = prompts[1] |
|
||||||
} |
|
||||||
|
|
||||||
mode := 0 |
|
||||||
step := 15 |
|
||||||
|
|
||||||
if input.Mode != 0 { |
|
||||||
mode = input.Mode |
|
||||||
} |
|
||||||
|
|
||||||
if input.Step != 0 { |
|
||||||
step = input.Step |
|
||||||
} |
|
||||||
|
|
||||||
tempDir := "" |
|
||||||
if !b64JSON { |
|
||||||
tempDir = o.imageDir |
|
||||||
} |
|
||||||
// Create a temporary file
|
|
||||||
outputFile, err := ioutil.TempFile(tempDir, "b64") |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
outputFile.Close() |
|
||||||
output := outputFile.Name() + ".png" |
|
||||||
// Rename the temporary file
|
|
||||||
err = os.Rename(outputFile.Name(), output) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
baseURL := c.BaseURL() |
|
||||||
|
|
||||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
if err := fn(); err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
item := &Item{} |
|
||||||
|
|
||||||
if b64JSON { |
|
||||||
defer os.RemoveAll(output) |
|
||||||
data, err := os.ReadFile(output) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
|
||||||
} else { |
|
||||||
base := filepath.Base(output) |
|
||||||
item.URL = baseURL + "/generated-images/" + base |
|
||||||
} |
|
||||||
|
|
||||||
result = append(result, *item) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Data: result, |
|
||||||
} |
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", jsonResult) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/audio/create
|
|
||||||
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
m, input, err := readInput(c, o.loader, false) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
// retrieve the file data from the request
|
|
||||||
file, err := c.FormFile("file") |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
f, err := file.Open() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
defer f.Close() |
|
||||||
|
|
||||||
dir, err := os.MkdirTemp("", "whisper") |
|
||||||
|
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
defer os.RemoveAll(dir) |
|
||||||
|
|
||||||
dst := filepath.Join(dir, path.Base(file.Filename)) |
|
||||||
dstFile, err := os.Create(dst) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
if _, err := io.Copy(dstFile, f); err != nil { |
|
||||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst) |
|
||||||
|
|
||||||
whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads), o.assetsDestination) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
if whisperModel == nil { |
|
||||||
return fmt.Errorf("could not load whisper model") |
|
||||||
} |
|
||||||
|
|
||||||
w, ok := whisperModel.(whisper.Model) |
|
||||||
if !ok { |
|
||||||
return fmt.Errorf("loader returned non-whisper object") |
|
||||||
} |
|
||||||
|
|
||||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Trascribed: %+v", tr) |
|
||||||
// TODO: handle different outputs here
|
|
||||||
return c.Status(http.StatusOK).JSON(tr) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
models, err := loader.ListModels() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
var mm map[string]interface{} = map[string]interface{}{} |
|
||||||
|
|
||||||
dataModels := []OpenAIModel{} |
|
||||||
for _, m := range models { |
|
||||||
mm[m] = nil |
|
||||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
|
||||||
} |
|
||||||
|
|
||||||
for _, k := range cm.ListConfigs() { |
|
||||||
if _, exists := mm[k]; !exists { |
|
||||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return c.JSON(struct { |
|
||||||
Object string `json:"object"` |
|
||||||
Data []OpenAIModel `json:"data"` |
|
||||||
}{ |
|
||||||
Object: "list", |
|
||||||
Data: dataModels, |
|
||||||
}) |
|
||||||
} |
|
||||||
} |
|
@ -0,0 +1,105 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||||
|
) |
||||||
|
|
||||||
|
// APIError provides error information returned by the OpenAI API.
|
||||||
|
type APIError struct { |
||||||
|
Code any `json:"code,omitempty"` |
||||||
|
Message string `json:"message"` |
||||||
|
Param *string `json:"param,omitempty"` |
||||||
|
Type string `json:"type"` |
||||||
|
} |
||||||
|
|
||||||
|
type ErrorResponse struct { |
||||||
|
Error *APIError `json:"error,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
type OpenAIUsage struct { |
||||||
|
PromptTokens int `json:"prompt_tokens"` |
||||||
|
CompletionTokens int `json:"completion_tokens"` |
||||||
|
TotalTokens int `json:"total_tokens"` |
||||||
|
} |
||||||
|
|
||||||
|
type Item struct { |
||||||
|
Embedding []float32 `json:"embedding"` |
||||||
|
Index int `json:"index"` |
||||||
|
Object string `json:"object,omitempty"` |
||||||
|
|
||||||
|
// Images
|
||||||
|
URL string `json:"url,omitempty"` |
||||||
|
B64JSON string `json:"b64_json,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
type OpenAIResponse struct { |
||||||
|
Created int `json:"created,omitempty"` |
||||||
|
Object string `json:"object,omitempty"` |
||||||
|
ID string `json:"id,omitempty"` |
||||||
|
Model string `json:"model,omitempty"` |
||||||
|
Choices []Choice `json:"choices,omitempty"` |
||||||
|
Data []Item `json:"data,omitempty"` |
||||||
|
|
||||||
|
Usage OpenAIUsage `json:"usage"` |
||||||
|
} |
||||||
|
|
||||||
|
type Choice struct { |
||||||
|
Index int `json:"index,omitempty"` |
||||||
|
FinishReason string `json:"finish_reason,omitempty"` |
||||||
|
Message *Message `json:"message,omitempty"` |
||||||
|
Delta *Message `json:"delta,omitempty"` |
||||||
|
Text string `json:"text,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
type Message struct { |
||||||
|
// The message role
|
||||||
|
Role string `json:"role,omitempty" yaml:"role"` |
||||||
|
// The message content
|
||||||
|
Content *string `json:"content" yaml:"content"` |
||||||
|
// A result of a function call
|
||||||
|
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
type OpenAIModel struct { |
||||||
|
ID string `json:"id"` |
||||||
|
Object string `json:"object"` |
||||||
|
} |
||||||
|
|
||||||
|
type OpenAIRequest struct { |
||||||
|
config.PredictionOptions |
||||||
|
|
||||||
|
// whisper
|
||||||
|
File string `json:"file" validate:"required"` |
||||||
|
//whisper/image
|
||||||
|
ResponseFormat string `json:"response_format"` |
||||||
|
// image
|
||||||
|
Size string `json:"size"` |
||||||
|
// Prompt is read only by completion/image API calls
|
||||||
|
Prompt interface{} `json:"prompt" yaml:"prompt"` |
||||||
|
|
||||||
|
// Edit endpoint
|
||||||
|
Instruction string `json:"instruction" yaml:"instruction"` |
||||||
|
Input interface{} `json:"input" yaml:"input"` |
||||||
|
|
||||||
|
Stop interface{} `json:"stop" yaml:"stop"` |
||||||
|
|
||||||
|
// Messages is read only by chat/completion API calls
|
||||||
|
Messages []Message `json:"messages" yaml:"messages"` |
||||||
|
|
||||||
|
// A list of available functions to call
|
||||||
|
Functions []grammar.Function `json:"functions" yaml:"functions"` |
||||||
|
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||||
|
|
||||||
|
Stream bool `json:"stream"` |
||||||
|
|
||||||
|
// Image (not supported by OpenAI)
|
||||||
|
Mode int `json:"mode"` |
||||||
|
Step int `json:"step"` |
||||||
|
|
||||||
|
// A grammar to constrain the LLM output
|
||||||
|
Grammar string `json:"grammar" yaml:"grammar"` |
||||||
|
|
||||||
|
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` |
||||||
|
} |
@ -0,0 +1,320 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"bufio" |
||||||
|
"bytes" |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
"github.com/valyala/fasthttp" |
||||||
|
) |
||||||
|
|
||||||
|
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||||
|
initialMessage := OpenAIResponse{ |
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, |
||||||
|
Object: "chat.completion.chunk", |
||||||
|
} |
||||||
|
responses <- initialMessage |
||||||
|
|
||||||
|
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||||
|
resp := OpenAIResponse{ |
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, |
||||||
|
Object: "chat.completion.chunk", |
||||||
|
} |
||||||
|
|
||||||
|
responses <- resp |
||||||
|
return true |
||||||
|
}) |
||||||
|
close(responses) |
||||||
|
} |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
processFunctions := false |
||||||
|
funcs := grammar.Functions{} |
||||||
|
model, input, err := readInput(c, o.Loader, true) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
log.Debug().Msgf("Configuration read: %+v", config) |
||||||
|
|
||||||
|
// Allow the user to set custom actions via config file
|
||||||
|
// to be "embedded" in each model
|
||||||
|
noActionName := "answer" |
||||||
|
noActionDescription := "use this action to answer without performing any action" |
||||||
|
|
||||||
|
if config.FunctionsConfig.NoActionFunctionName != "" { |
||||||
|
noActionName = config.FunctionsConfig.NoActionFunctionName |
||||||
|
} |
||||||
|
if config.FunctionsConfig.NoActionDescriptionName != "" { |
||||||
|
noActionDescription = config.FunctionsConfig.NoActionDescriptionName |
||||||
|
} |
||||||
|
|
||||||
|
// process functions if we have any defined or if we have a function call string
|
||||||
|
if len(input.Functions) > 0 && config.ShouldUseFunctions() { |
||||||
|
log.Debug().Msgf("Response needs to process functions") |
||||||
|
|
||||||
|
processFunctions = true |
||||||
|
|
||||||
|
noActionGrammar := grammar.Function{ |
||||||
|
Name: noActionName, |
||||||
|
Description: noActionDescription, |
||||||
|
Parameters: map[string]interface{}{ |
||||||
|
"properties": map[string]interface{}{ |
||||||
|
"message": map[string]interface{}{ |
||||||
|
"type": "string", |
||||||
|
"description": "The message to reply the user with", |
||||||
|
}}, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
// Append the no action function
|
||||||
|
funcs = append(funcs, input.Functions...) |
||||||
|
if !config.FunctionsConfig.DisableNoAction { |
||||||
|
funcs = append(funcs, noActionGrammar) |
||||||
|
} |
||||||
|
|
||||||
|
// Force picking one of the functions by the request
|
||||||
|
if config.FunctionToCall() != "" { |
||||||
|
funcs = funcs.Select(config.FunctionToCall()) |
||||||
|
} |
||||||
|
|
||||||
|
// Update input grammar
|
||||||
|
jsStruct := funcs.ToJSONStructure() |
||||||
|
config.Grammar = jsStruct.Grammar("") |
||||||
|
} else if input.JSONFunctionGrammarObject != nil { |
||||||
|
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") |
||||||
|
} |
||||||
|
|
||||||
|
// functions are not supported in stream mode (yet?)
|
||||||
|
toStream := input.Stream && !processFunctions |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameters: %+v", config) |
||||||
|
|
||||||
|
var predInput string |
||||||
|
|
||||||
|
mess := []string{} |
||||||
|
for _, i := range input.Messages { |
||||||
|
var content string |
||||||
|
role := i.Role |
||||||
|
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||||
|
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||||
|
if i.FunctionCall != nil && i.Role == "assistant" { |
||||||
|
roleFn := "assistant_function_call" |
||||||
|
r := config.Roles[roleFn] |
||||||
|
if r != "" { |
||||||
|
role = roleFn |
||||||
|
} |
||||||
|
} |
||||||
|
r := config.Roles[role] |
||||||
|
contentExists := i.Content != nil && *i.Content != "" |
||||||
|
if r != "" { |
||||||
|
if contentExists { |
||||||
|
content = fmt.Sprint(r, " ", *i.Content) |
||||||
|
} |
||||||
|
if i.FunctionCall != nil { |
||||||
|
j, err := json.Marshal(i.FunctionCall) |
||||||
|
if err == nil { |
||||||
|
if contentExists { |
||||||
|
content += "\n" + fmt.Sprint(r, " ", string(j)) |
||||||
|
} else { |
||||||
|
content = fmt.Sprint(r, " ", string(j)) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} else { |
||||||
|
if contentExists { |
||||||
|
content = fmt.Sprint(*i.Content) |
||||||
|
} |
||||||
|
if i.FunctionCall != nil { |
||||||
|
j, err := json.Marshal(i.FunctionCall) |
||||||
|
if err == nil { |
||||||
|
if contentExists { |
||||||
|
content += "\n" + string(j) |
||||||
|
} else { |
||||||
|
content = string(j) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
mess = append(mess, content) |
||||||
|
} |
||||||
|
|
||||||
|
predInput = strings.Join(mess, "\n") |
||||||
|
log.Debug().Msgf("Prompt (before templating): %s", predInput) |
||||||
|
|
||||||
|
if toStream { |
||||||
|
log.Debug().Msgf("Stream request received") |
||||||
|
c.Context().SetContentType("text/event-stream") |
||||||
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||||
|
// c.Set("Content-Type", "text/event-stream")
|
||||||
|
c.Set("Cache-Control", "no-cache") |
||||||
|
c.Set("Connection", "keep-alive") |
||||||
|
c.Set("Transfer-Encoding", "chunked") |
||||||
|
} |
||||||
|
|
||||||
|
templateFile := config.Model |
||||||
|
|
||||||
|
if config.TemplateConfig.Chat != "" && !processFunctions { |
||||||
|
templateFile = config.TemplateConfig.Chat |
||||||
|
} |
||||||
|
|
||||||
|
if config.TemplateConfig.Functions != "" && processFunctions { |
||||||
|
templateFile = config.TemplateConfig.Functions |
||||||
|
} |
||||||
|
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||||
|
Input string |
||||||
|
Functions []grammar.Function |
||||||
|
}{ |
||||||
|
Input: predInput, |
||||||
|
Functions: funcs, |
||||||
|
}) |
||||||
|
if err == nil { |
||||||
|
predInput = templatedInput |
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||||
|
} else { |
||||||
|
log.Debug().Msgf("Template failed loading: %s", err.Error()) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Prompt (after templating): %s", predInput) |
||||||
|
if processFunctions { |
||||||
|
log.Debug().Msgf("Grammar: %+v", config.Grammar) |
||||||
|
} |
||||||
|
|
||||||
|
if toStream { |
||||||
|
responses := make(chan OpenAIResponse) |
||||||
|
|
||||||
|
go process(predInput, input, config, o.Loader, responses) |
||||||
|
|
||||||
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||||
|
|
||||||
|
for ev := range responses { |
||||||
|
var buf bytes.Buffer |
||||||
|
enc := json.NewEncoder(&buf) |
||||||
|
enc.Encode(ev) |
||||||
|
|
||||||
|
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||||
|
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||||
|
w.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{ |
||||||
|
{ |
||||||
|
FinishReason: "stop", |
||||||
|
Index: 0, |
||||||
|
Delta: &Message{}, |
||||||
|
}}, |
||||||
|
Object: "chat.completion.chunk", |
||||||
|
} |
||||||
|
respData, _ := json.Marshal(resp) |
||||||
|
|
||||||
|
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||||
|
w.WriteString("data: [DONE]\n\n") |
||||||
|
w.Flush() |
||||||
|
})) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||||
|
if processFunctions { |
||||||
|
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||||
|
ss := map[string]interface{}{} |
||||||
|
json.Unmarshal([]byte(s), &ss) |
||||||
|
log.Debug().Msgf("Function return: %s %+v", s, ss) |
||||||
|
|
||||||
|
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||||
|
func_name := ss["function"] |
||||||
|
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||||
|
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||||
|
d, _ := json.Marshal(args) |
||||||
|
|
||||||
|
ss["arguments"] = string(d) |
||||||
|
ss["name"] = func_name |
||||||
|
|
||||||
|
// if do nothing, reply with a message
|
||||||
|
if func_name == noActionName { |
||||||
|
log.Debug().Msgf("nothing to do, computing a reply") |
||||||
|
|
||||||
|
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||||
|
arguments := map[string]interface{}{} |
||||||
|
json.Unmarshal([]byte(d), &arguments) |
||||||
|
m, exists := arguments["message"] |
||||||
|
if exists { |
||||||
|
switch message := m.(type) { |
||||||
|
case string: |
||||||
|
if message != "" { |
||||||
|
log.Debug().Msgf("Reply received from LLM: %s", message) |
||||||
|
message = backend.Finetune(*config, predInput, message) |
||||||
|
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) |
||||||
|
|
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("No action received from LLM, without a message, computing a reply") |
||||||
|
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||||
|
// Note: This costs (in term of CPU) another computation
|
||||||
|
config.Grammar = "" |
||||||
|
predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil) |
||||||
|
if err != nil { |
||||||
|
log.Error().Msgf("inference error: %s", err.Error()) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
prediction, err := predFunc() |
||||||
|
if err != nil { |
||||||
|
log.Error().Msgf("inference error: %s", err.Error()) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
prediction = backend.Finetune(*config, predInput, prediction) |
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) |
||||||
|
} else { |
||||||
|
// otherwise reply with the function call
|
||||||
|
*c = append(*c, Choice{ |
||||||
|
FinishReason: "function_call", |
||||||
|
Message: &Message{Role: "assistant", FunctionCall: ss}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
return |
||||||
|
} |
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) |
||||||
|
}, nil) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result, |
||||||
|
Object: "chat.completion", |
||||||
|
} |
||||||
|
respData, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", respData) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,159 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"bufio" |
||||||
|
"bytes" |
||||||
|
"encoding/json" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
"github.com/valyala/fasthttp" |
||||||
|
) |
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/completions
|
||||||
|
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||||
|
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||||
|
resp := OpenAIResponse{ |
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{ |
||||||
|
{ |
||||||
|
Index: 0, |
||||||
|
Text: s, |
||||||
|
}, |
||||||
|
}, |
||||||
|
Object: "text_completion", |
||||||
|
} |
||||||
|
log.Debug().Msgf("Sending goroutine: %s", s) |
||||||
|
|
||||||
|
responses <- resp |
||||||
|
return true |
||||||
|
}) |
||||||
|
close(responses) |
||||||
|
} |
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
model, input, err := readInput(c, o.Loader, true) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("`input`: %+v", input) |
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config) |
||||||
|
|
||||||
|
if input.Stream { |
||||||
|
log.Debug().Msgf("Stream request received") |
||||||
|
c.Context().SetContentType("text/event-stream") |
||||||
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||||
|
//c.Set("Content-Type", "text/event-stream")
|
||||||
|
c.Set("Cache-Control", "no-cache") |
||||||
|
c.Set("Connection", "keep-alive") |
||||||
|
c.Set("Transfer-Encoding", "chunked") |
||||||
|
} |
||||||
|
|
||||||
|
templateFile := config.Model |
||||||
|
|
||||||
|
if config.TemplateConfig.Completion != "" { |
||||||
|
templateFile = config.TemplateConfig.Completion |
||||||
|
} |
||||||
|
|
||||||
|
if input.Stream { |
||||||
|
if len(config.PromptStrings) > 1 { |
||||||
|
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") |
||||||
|
} |
||||||
|
|
||||||
|
predInput := config.PromptStrings[0] |
||||||
|
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||||
|
Input string |
||||||
|
}{ |
||||||
|
Input: predInput, |
||||||
|
}) |
||||||
|
if err == nil { |
||||||
|
predInput = templatedInput |
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||||
|
} |
||||||
|
|
||||||
|
responses := make(chan OpenAIResponse) |
||||||
|
|
||||||
|
go process(predInput, input, config, o.Loader, responses) |
||||||
|
|
||||||
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||||
|
|
||||||
|
for ev := range responses { |
||||||
|
var buf bytes.Buffer |
||||||
|
enc := json.NewEncoder(&buf) |
||||||
|
enc.Encode(ev) |
||||||
|
|
||||||
|
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||||
|
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||||
|
w.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{ |
||||||
|
{ |
||||||
|
Index: 0, |
||||||
|
FinishReason: "stop", |
||||||
|
}, |
||||||
|
}, |
||||||
|
Object: "text_completion", |
||||||
|
} |
||||||
|
respData, _ := json.Marshal(resp) |
||||||
|
|
||||||
|
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||||
|
w.WriteString("data: [DONE]\n\n") |
||||||
|
w.Flush() |
||||||
|
})) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
var result []Choice |
||||||
|
for _, i := range config.PromptStrings { |
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||||
|
Input string |
||||||
|
}{ |
||||||
|
Input: i, |
||||||
|
}) |
||||||
|
if err == nil { |
||||||
|
i = templatedInput |
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||||
|
} |
||||||
|
|
||||||
|
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||||
|
*c = append(*c, Choice{Text: s}) |
||||||
|
}, nil) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
result = append(result, r...) |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result, |
||||||
|
Object: "text_completion", |
||||||
|
} |
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", jsonResult) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,67 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
model, input, err := readInput(c, o.Loader, true) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config) |
||||||
|
|
||||||
|
templateFile := config.Model |
||||||
|
|
||||||
|
if config.TemplateConfig.Edit != "" { |
||||||
|
templateFile = config.TemplateConfig.Edit |
||||||
|
} |
||||||
|
|
||||||
|
var result []Choice |
||||||
|
for _, i := range config.InputStrings { |
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||||
|
Input string |
||||||
|
Instruction string |
||||||
|
}{Input: i}) |
||||||
|
if err == nil { |
||||||
|
i = templatedInput |
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||||
|
} |
||||||
|
|
||||||
|
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||||
|
*c = append(*c, Choice{Text: s}) |
||||||
|
}, nil) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
result = append(result, r...) |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result, |
||||||
|
Object: "edit", |
||||||
|
} |
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", jsonResult) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,70 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/embeddings
|
||||||
|
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
model, input, err := readInput(c, o.Loader, true) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config) |
||||||
|
items := []Item{} |
||||||
|
|
||||||
|
for i, s := range config.InputToken { |
||||||
|
// get the model function to call for the result
|
||||||
|
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
embeddings, err := embedFn() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||||
|
} |
||||||
|
|
||||||
|
for i, s := range config.InputStrings { |
||||||
|
// get the model function to call for the result
|
||||||
|
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
embeddings, err := embedFn() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Data: items, |
||||||
|
Object: "list", |
||||||
|
} |
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", jsonResult) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,158 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/base64" |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"io/ioutil" |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/images/create
|
||||||
|
|
||||||
|
/* |
||||||
|
* |
||||||
|
|
||||||
|
curl http://localhost:8080/v1/images/generations \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{ |
||||||
|
"prompt": "A cute baby sea otter", |
||||||
|
"n": 1, |
||||||
|
"size": "512x512" |
||||||
|
}' |
||||||
|
|
||||||
|
* |
||||||
|
*/ |
||||||
|
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
m, input, err := readInput(c, o.Loader, false) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
if m == "" { |
||||||
|
m = model.StableDiffusionBackend |
||||||
|
} |
||||||
|
log.Debug().Msgf("Loading model: %+v", m) |
||||||
|
|
||||||
|
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config) |
||||||
|
|
||||||
|
// XXX: Only stablediffusion is supported for now
|
||||||
|
if config.Backend == "" { |
||||||
|
config.Backend = model.StableDiffusionBackend |
||||||
|
} |
||||||
|
|
||||||
|
sizeParts := strings.Split(input.Size, "x") |
||||||
|
if len(sizeParts) != 2 { |
||||||
|
return fmt.Errorf("Invalid value for 'size'") |
||||||
|
} |
||||||
|
width, err := strconv.Atoi(sizeParts[0]) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("Invalid value for 'size'") |
||||||
|
} |
||||||
|
height, err := strconv.Atoi(sizeParts[1]) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("Invalid value for 'size'") |
||||||
|
} |
||||||
|
|
||||||
|
b64JSON := false |
||||||
|
if input.ResponseFormat == "b64_json" { |
||||||
|
b64JSON = true |
||||||
|
} |
||||||
|
|
||||||
|
var result []Item |
||||||
|
for _, i := range config.PromptStrings { |
||||||
|
n := input.N |
||||||
|
if input.N == 0 { |
||||||
|
n = 1 |
||||||
|
} |
||||||
|
for j := 0; j < n; j++ { |
||||||
|
prompts := strings.Split(i, "|") |
||||||
|
positive_prompt := prompts[0] |
||||||
|
negative_prompt := "" |
||||||
|
if len(prompts) > 1 { |
||||||
|
negative_prompt = prompts[1] |
||||||
|
} |
||||||
|
|
||||||
|
mode := 0 |
||||||
|
step := 15 |
||||||
|
|
||||||
|
if input.Mode != 0 { |
||||||
|
mode = input.Mode |
||||||
|
} |
||||||
|
|
||||||
|
if input.Step != 0 { |
||||||
|
step = input.Step |
||||||
|
} |
||||||
|
|
||||||
|
tempDir := "" |
||||||
|
if !b64JSON { |
||||||
|
tempDir = o.ImageDir |
||||||
|
} |
||||||
|
// Create a temporary file
|
||||||
|
outputFile, err := ioutil.TempFile(tempDir, "b64") |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
outputFile.Close() |
||||||
|
output := outputFile.Name() + ".png" |
||||||
|
// Rename the temporary file
|
||||||
|
err = os.Rename(outputFile.Name(), output) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
baseURL := c.BaseURL() |
||||||
|
|
||||||
|
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if err := fn(); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
item := &Item{} |
||||||
|
|
||||||
|
if b64JSON { |
||||||
|
defer os.RemoveAll(output) |
||||||
|
data, err := os.ReadFile(output) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
||||||
|
} else { |
||||||
|
base := filepath.Base(output) |
||||||
|
item.URL = baseURL + "/generated-images/" + base |
||||||
|
} |
||||||
|
|
||||||
|
result = append(result, *item) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Data: result, |
||||||
|
} |
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", jsonResult) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,36 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/go-skynet/LocalAI/api/backend" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
) |
||||||
|
|
||||||
|
func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
||||||
|
result := []Choice{} |
||||||
|
|
||||||
|
if n == 0 { |
||||||
|
n = 1 |
||||||
|
} |
||||||
|
|
||||||
|
// get the model function to call for the result
|
||||||
|
predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback) |
||||||
|
if err != nil { |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
for i := 0; i < n; i++ { |
||||||
|
prediction, err := predFunc() |
||||||
|
if err != nil { |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
prediction = backend.Finetune(*config, predInput, prediction) |
||||||
|
cb(prediction, &result) |
||||||
|
|
||||||
|
//result = append(result, Choice{Text: prediction})
|
||||||
|
|
||||||
|
} |
||||||
|
return result, err |
||||||
|
} |
@ -0,0 +1,37 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
) |
||||||
|
|
||||||
|
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
models, err := loader.ListModels() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
var mm map[string]interface{} = map[string]interface{}{} |
||||||
|
|
||||||
|
dataModels := []OpenAIModel{} |
||||||
|
for _, m := range models { |
||||||
|
mm[m] = nil |
||||||
|
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
||||||
|
} |
||||||
|
|
||||||
|
for _, k := range cm.ListConfigs() { |
||||||
|
if _, exists := mm[k]; !exists { |
||||||
|
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return c.JSON(struct { |
||||||
|
Object string `json:"object"` |
||||||
|
Data []OpenAIModel `json:"data"` |
||||||
|
}{ |
||||||
|
Object: "list", |
||||||
|
Data: dataModels, |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,234 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
"strings" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
||||||
|
input := new(OpenAIRequest) |
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil { |
||||||
|
return "", nil, err |
||||||
|
} |
||||||
|
|
||||||
|
modelFile := input.Model |
||||||
|
|
||||||
|
if c.Params("model") != "" { |
||||||
|
modelFile = c.Params("model") |
||||||
|
} |
||||||
|
|
||||||
|
received, _ := json.Marshal(input) |
||||||
|
|
||||||
|
log.Debug().Msgf("Request received: %s", string(received)) |
||||||
|
|
||||||
|
// Set model from bearer token, if available
|
||||||
|
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
||||||
|
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
||||||
|
|
||||||
|
// If no model was specified, take the first available
|
||||||
|
if modelFile == "" && !bearerExists && randomModel { |
||||||
|
models, _ := loader.ListModels() |
||||||
|
if len(models) > 0 { |
||||||
|
modelFile = models[0] |
||||||
|
log.Debug().Msgf("No model specified, using: %s", modelFile) |
||||||
|
} else { |
||||||
|
log.Debug().Msgf("No model specified, returning error") |
||||||
|
return "", nil, fmt.Errorf("no model specified") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// If a model is found in bearer token takes precedence
|
||||||
|
if bearerExists { |
||||||
|
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
||||||
|
modelFile = bearer |
||||||
|
} |
||||||
|
return modelFile, input, nil |
||||||
|
} |
||||||
|
|
||||||
|
func updateConfig(config *config.Config, input *OpenAIRequest) { |
||||||
|
if input.Echo { |
||||||
|
config.Echo = input.Echo |
||||||
|
} |
||||||
|
if input.TopK != 0 { |
||||||
|
config.TopK = input.TopK |
||||||
|
} |
||||||
|
if input.TopP != 0 { |
||||||
|
config.TopP = input.TopP |
||||||
|
} |
||||||
|
|
||||||
|
if input.Grammar != "" { |
||||||
|
config.Grammar = input.Grammar |
||||||
|
} |
||||||
|
|
||||||
|
if input.Temperature != 0 { |
||||||
|
config.Temperature = input.Temperature |
||||||
|
} |
||||||
|
|
||||||
|
if input.Maxtokens != 0 { |
||||||
|
config.Maxtokens = input.Maxtokens |
||||||
|
} |
||||||
|
|
||||||
|
switch stop := input.Stop.(type) { |
||||||
|
case string: |
||||||
|
if stop != "" { |
||||||
|
config.StopWords = append(config.StopWords, stop) |
||||||
|
} |
||||||
|
case []interface{}: |
||||||
|
for _, pp := range stop { |
||||||
|
if s, ok := pp.(string); ok { |
||||||
|
config.StopWords = append(config.StopWords, s) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if input.RepeatPenalty != 0 { |
||||||
|
config.RepeatPenalty = input.RepeatPenalty |
||||||
|
} |
||||||
|
|
||||||
|
if input.Keep != 0 { |
||||||
|
config.Keep = input.Keep |
||||||
|
} |
||||||
|
|
||||||
|
if input.Batch != 0 { |
||||||
|
config.Batch = input.Batch |
||||||
|
} |
||||||
|
|
||||||
|
if input.F16 { |
||||||
|
config.F16 = input.F16 |
||||||
|
} |
||||||
|
|
||||||
|
if input.IgnoreEOS { |
||||||
|
config.IgnoreEOS = input.IgnoreEOS |
||||||
|
} |
||||||
|
|
||||||
|
if input.Seed != 0 { |
||||||
|
config.Seed = input.Seed |
||||||
|
} |
||||||
|
|
||||||
|
if input.Mirostat != 0 { |
||||||
|
config.Mirostat = input.Mirostat |
||||||
|
} |
||||||
|
|
||||||
|
if input.MirostatETA != 0 { |
||||||
|
config.MirostatETA = input.MirostatETA |
||||||
|
} |
||||||
|
|
||||||
|
if input.MirostatTAU != 0 { |
||||||
|
config.MirostatTAU = input.MirostatTAU |
||||||
|
} |
||||||
|
|
||||||
|
if input.TypicalP != 0 { |
||||||
|
config.TypicalP = input.TypicalP |
||||||
|
} |
||||||
|
|
||||||
|
switch inputs := input.Input.(type) { |
||||||
|
case string: |
||||||
|
if inputs != "" { |
||||||
|
config.InputStrings = append(config.InputStrings, inputs) |
||||||
|
} |
||||||
|
case []interface{}: |
||||||
|
for _, pp := range inputs { |
||||||
|
switch i := pp.(type) { |
||||||
|
case string: |
||||||
|
config.InputStrings = append(config.InputStrings, i) |
||||||
|
case []interface{}: |
||||||
|
tokens := []int{} |
||||||
|
for _, ii := range i { |
||||||
|
tokens = append(tokens, int(ii.(float64))) |
||||||
|
} |
||||||
|
config.InputToken = append(config.InputToken, tokens) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Can be either a string or an object
|
||||||
|
switch fnc := input.FunctionCall.(type) { |
||||||
|
case string: |
||||||
|
if fnc != "" { |
||||||
|
config.SetFunctionCallString(fnc) |
||||||
|
} |
||||||
|
case map[string]interface{}: |
||||||
|
var name string |
||||||
|
n, exists := fnc["name"] |
||||||
|
if exists { |
||||||
|
nn, e := n.(string) |
||||||
|
if !e { |
||||||
|
name = nn |
||||||
|
} |
||||||
|
} |
||||||
|
config.SetFunctionCallNameString(name) |
||||||
|
} |
||||||
|
|
||||||
|
switch p := input.Prompt.(type) { |
||||||
|
case string: |
||||||
|
config.PromptStrings = append(config.PromptStrings, p) |
||||||
|
case []interface{}: |
||||||
|
for _, pp := range p { |
||||||
|
if s, ok := pp.(string); ok { |
||||||
|
config.PromptStrings = append(config.PromptStrings, s) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) { |
||||||
|
// Load a config file if present after the model name
|
||||||
|
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
||||||
|
|
||||||
|
var cfg *config.Config |
||||||
|
|
||||||
|
defaults := func() { |
||||||
|
cfg = config.DefaultConfig(modelFile) |
||||||
|
cfg.ContextSize = ctx |
||||||
|
cfg.Threads = threads |
||||||
|
cfg.F16 = f16 |
||||||
|
cfg.Debug = debug |
||||||
|
} |
||||||
|
|
||||||
|
cfgExisting, exists := cm.GetConfig(modelFile) |
||||||
|
if !exists { |
||||||
|
if _, err := os.Stat(modelConfig); err == nil { |
||||||
|
if err := cm.LoadConfig(modelConfig); err != nil { |
||||||
|
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
||||||
|
} |
||||||
|
cfgExisting, exists = cm.GetConfig(modelFile) |
||||||
|
if exists { |
||||||
|
cfg = &cfgExisting |
||||||
|
} else { |
||||||
|
defaults() |
||||||
|
} |
||||||
|
} else { |
||||||
|
defaults() |
||||||
|
} |
||||||
|
} else { |
||||||
|
cfg = &cfgExisting |
||||||
|
} |
||||||
|
|
||||||
|
// Set the parameters for the language model prediction
|
||||||
|
updateConfig(cfg, input) |
||||||
|
|
||||||
|
// Don't allow 0 as setting
|
||||||
|
if cfg.Threads == 0 { |
||||||
|
if threads != 0 { |
||||||
|
cfg.Threads = threads |
||||||
|
} else { |
||||||
|
cfg.Threads = 4 |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Enforce debug flag if passed from CLI
|
||||||
|
if debug { |
||||||
|
cfg.Debug = true |
||||||
|
} |
||||||
|
|
||||||
|
return cfg, input, nil |
||||||
|
} |
@ -0,0 +1,91 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net/http" |
||||||
|
"os" |
||||||
|
"path" |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/audio/create
|
||||||
|
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
m, input, err := readInput(c, o.Loader, false) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
// retrieve the file data from the request
|
||||||
|
file, err := c.FormFile("file") |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
f, err := file.Open() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
defer f.Close() |
||||||
|
|
||||||
|
dir, err := os.MkdirTemp("", "whisper") |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
defer os.RemoveAll(dir) |
||||||
|
|
||||||
|
dst := filepath.Join(dir, path.Base(file.Filename)) |
||||||
|
dstFile, err := os.Create(dst) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if _, err := io.Copy(dstFile, f); err != nil { |
||||||
|
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Audio file copied to: %+v", dst) |
||||||
|
|
||||||
|
whisperModel, err := o.Loader.BackendLoader( |
||||||
|
model.WithBackendString(model.WhisperBackend), |
||||||
|
model.WithModelFile(config.Model), |
||||||
|
model.WithContext(o.Context), |
||||||
|
model.WithThreads(uint32(config.Threads)), |
||||||
|
model.WithAssetDir(o.AssetsDestination)) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if whisperModel == nil { |
||||||
|
return fmt.Errorf("could not load whisper model") |
||||||
|
} |
||||||
|
|
||||||
|
tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ |
||||||
|
Dst: dst, |
||||||
|
Language: input.Language, |
||||||
|
Threads: uint32(config.Threads), |
||||||
|
}) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Trascribed: %+v", tr) |
||||||
|
// TODO: handle different outputs here
|
||||||
|
return c.Status(http.StatusOK).JSON(tr) |
||||||
|
} |
||||||
|
} |
@ -1,649 +0,0 @@ |
|||||||
package api |
|
||||||
|
|
||||||
import ( |
|
||||||
"fmt" |
|
||||||
"os" |
|
||||||
"path/filepath" |
|
||||||
"regexp" |
|
||||||
"strings" |
|
||||||
"sync" |
|
||||||
|
|
||||||
"github.com/donomii/go-rwkv.cpp" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/langchain" |
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
|
||||||
"github.com/go-skynet/bloomz.cpp" |
|
||||||
bert "github.com/go-skynet/go-bert.cpp" |
|
||||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
|
||||||
llama "github.com/go-skynet/go-llama.cpp" |
|
||||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" |
|
||||||
) |
|
||||||
|
|
||||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
var mutexMap sync.Mutex |
|
||||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
|
||||||
|
|
||||||
func defaultLLamaOpts(c Config) []llama.ModelOption { |
|
||||||
llamaOpts := []llama.ModelOption{} |
|
||||||
if c.ContextSize != 0 { |
|
||||||
llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize)) |
|
||||||
} |
|
||||||
if c.F16 { |
|
||||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory) |
|
||||||
} |
|
||||||
if c.Embeddings { |
|
||||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings) |
|
||||||
} |
|
||||||
|
|
||||||
if c.NGPULayers != 0 { |
|
||||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers)) |
|
||||||
} |
|
||||||
|
|
||||||
llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap)) |
|
||||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU)) |
|
||||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit)) |
|
||||||
if c.Batch != 0 { |
|
||||||
llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch)) |
|
||||||
} else { |
|
||||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.NUMA { |
|
||||||
llamaOpts = append(llamaOpts, llama.EnableNUMA) |
|
||||||
} |
|
||||||
|
|
||||||
if c.LowVRAM { |
|
||||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) |
|
||||||
} |
|
||||||
|
|
||||||
return llamaOpts |
|
||||||
} |
|
||||||
|
|
||||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { |
|
||||||
if c.Backend != model.StableDiffusionBackend { |
|
||||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
|
||||||
} |
|
||||||
inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads), o.assetsDestination) |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
var fn func() error |
|
||||||
switch model := inferenceModel.(type) { |
|
||||||
case *stablediffusion.StableDiffusion: |
|
||||||
fn = func() error { |
|
||||||
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) |
|
||||||
} |
|
||||||
|
|
||||||
default: |
|
||||||
fn = func() error { |
|
||||||
return fmt.Errorf("creation of images not supported by the backend") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return func() error { |
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock() |
|
||||||
l, ok := mutexes[c.Backend] |
|
||||||
if !ok { |
|
||||||
m := &sync.Mutex{} |
|
||||||
mutexes[c.Backend] = m |
|
||||||
l = m |
|
||||||
} |
|
||||||
mutexMap.Unlock() |
|
||||||
l.Lock() |
|
||||||
defer l.Unlock() |
|
||||||
|
|
||||||
return fn() |
|
||||||
}, nil |
|
||||||
} |
|
||||||
|
|
||||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) { |
|
||||||
if !c.Embeddings { |
|
||||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
|
||||||
} |
|
||||||
|
|
||||||
modelFile := c.Model |
|
||||||
|
|
||||||
llamaOpts := defaultLLamaOpts(c) |
|
||||||
|
|
||||||
var inferenceModel interface{} |
|
||||||
var err error |
|
||||||
if c.Backend == "" { |
|
||||||
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) |
|
||||||
} else { |
|
||||||
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) |
|
||||||
} |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
var fn func() ([]float32, error) |
|
||||||
switch model := inferenceModel.(type) { |
|
||||||
case *llama.LLama: |
|
||||||
fn = func() ([]float32, error) { |
|
||||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) |
|
||||||
if len(tokens) > 0 { |
|
||||||
return model.TokenEmbeddings(tokens, predictOptions...) |
|
||||||
} |
|
||||||
return model.Embeddings(s, predictOptions...) |
|
||||||
} |
|
||||||
// bert embeddings
|
|
||||||
case *bert.Bert: |
|
||||||
fn = func() ([]float32, error) { |
|
||||||
if len(tokens) > 0 { |
|
||||||
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) |
|
||||||
} |
|
||||||
return model.Embeddings(s, bert.SetThreads(c.Threads)) |
|
||||||
} |
|
||||||
default: |
|
||||||
fn = func() ([]float32, error) { |
|
||||||
return nil, fmt.Errorf("embeddings not supported by the backend") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return func() ([]float32, error) { |
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock() |
|
||||||
l, ok := mutexes[modelFile] |
|
||||||
if !ok { |
|
||||||
m := &sync.Mutex{} |
|
||||||
mutexes[modelFile] = m |
|
||||||
l = m |
|
||||||
} |
|
||||||
mutexMap.Unlock() |
|
||||||
l.Lock() |
|
||||||
defer l.Unlock() |
|
||||||
|
|
||||||
embeds, err := fn() |
|
||||||
if err != nil { |
|
||||||
return embeds, err |
|
||||||
} |
|
||||||
// Remove trailing 0s
|
|
||||||
for i := len(embeds) - 1; i >= 0; i-- { |
|
||||||
if embeds[i] == 0.0 { |
|
||||||
embeds = embeds[:i] |
|
||||||
} else { |
|
||||||
break |
|
||||||
} |
|
||||||
} |
|
||||||
return embeds, nil |
|
||||||
}, nil |
|
||||||
} |
|
||||||
|
|
||||||
func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []llama.PredictOption{ |
|
||||||
llama.SetTemperature(c.Temperature), |
|
||||||
llama.SetTopP(c.TopP), |
|
||||||
llama.SetTopK(c.TopK), |
|
||||||
llama.SetTokens(c.Maxtokens), |
|
||||||
llama.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.PromptCacheAll { |
|
||||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll) |
|
||||||
} |
|
||||||
|
|
||||||
if c.PromptCacheRO { |
|
||||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO) |
|
||||||
} |
|
||||||
|
|
||||||
predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar)) |
|
||||||
|
|
||||||
if c.PromptCachePath != "" { |
|
||||||
// Create parent directory
|
|
||||||
p := filepath.Join(modelPath, c.PromptCachePath) |
|
||||||
os.MkdirAll(filepath.Dir(p), 0755) |
|
||||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(p)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Mirostat != 0 { |
|
||||||
predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.MirostatETA != 0 { |
|
||||||
predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.MirostatTAU != 0 { |
|
||||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Debug { |
|
||||||
predictOptions = append(predictOptions, llama.Debug) |
|
||||||
} |
|
||||||
|
|
||||||
predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) |
|
||||||
|
|
||||||
if c.RepeatPenalty != 0 { |
|
||||||
predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Keep != 0 { |
|
||||||
predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.F16 { |
|
||||||
predictOptions = append(predictOptions, llama.EnableF16KV) |
|
||||||
} |
|
||||||
|
|
||||||
if c.IgnoreEOS { |
|
||||||
predictOptions = append(predictOptions, llama.IgnoreEOS) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
|
||||||
|
|
||||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty)) |
|
||||||
predictOptions = append(predictOptions, llama.SetMlock(c.MMlock)) |
|
||||||
predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap)) |
|
||||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU)) |
|
||||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit)) |
|
||||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ)) |
|
||||||
predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP)) |
|
||||||
|
|
||||||
return predictOptions |
|
||||||
} |
|
||||||
|
|
||||||
func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) { |
|
||||||
supportStreams := false |
|
||||||
modelFile := c.Model |
|
||||||
|
|
||||||
llamaOpts := defaultLLamaOpts(c) |
|
||||||
|
|
||||||
var inferenceModel interface{} |
|
||||||
var err error |
|
||||||
if c.Backend == "" { |
|
||||||
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) |
|
||||||
} else { |
|
||||||
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) |
|
||||||
} |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
var fn func() (string, error) |
|
||||||
|
|
||||||
switch model := inferenceModel.(type) { |
|
||||||
case *rwkv.RwkvState: |
|
||||||
supportStreams = true |
|
||||||
|
|
||||||
fn = func() (string, error) { |
|
||||||
stopWord := "\n" |
|
||||||
if len(c.StopWords) > 0 { |
|
||||||
stopWord = c.StopWords[0] |
|
||||||
} |
|
||||||
|
|
||||||
if err := model.ProcessInput(s); err != nil { |
|
||||||
return "", err |
|
||||||
} |
|
||||||
|
|
||||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) |
|
||||||
|
|
||||||
return response, nil |
|
||||||
} |
|
||||||
case *transformers.GPTNeoX: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []transformers.PredictOption{ |
|
||||||
transformers.SetTemperature(c.Temperature), |
|
||||||
transformers.SetTopP(c.TopP), |
|
||||||
transformers.SetTopK(c.TopK), |
|
||||||
transformers.SetTokens(c.Maxtokens), |
|
||||||
transformers.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *transformers.Replit: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []transformers.PredictOption{ |
|
||||||
transformers.SetTemperature(c.Temperature), |
|
||||||
transformers.SetTopP(c.TopP), |
|
||||||
transformers.SetTopK(c.TopK), |
|
||||||
transformers.SetTokens(c.Maxtokens), |
|
||||||
transformers.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *transformers.Starcoder: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []transformers.PredictOption{ |
|
||||||
transformers.SetTemperature(c.Temperature), |
|
||||||
transformers.SetTopP(c.TopP), |
|
||||||
transformers.SetTopK(c.TopK), |
|
||||||
transformers.SetTokens(c.Maxtokens), |
|
||||||
transformers.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *transformers.MPT: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []transformers.PredictOption{ |
|
||||||
transformers.SetTemperature(c.Temperature), |
|
||||||
transformers.SetTopP(c.TopP), |
|
||||||
transformers.SetTopK(c.TopK), |
|
||||||
transformers.SetTokens(c.Maxtokens), |
|
||||||
transformers.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *bloomz.Bloomz: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []bloomz.PredictOption{ |
|
||||||
bloomz.SetTemperature(c.Temperature), |
|
||||||
bloomz.SetTopP(c.TopP), |
|
||||||
bloomz.SetTopK(c.TopK), |
|
||||||
bloomz.SetTokens(c.Maxtokens), |
|
||||||
bloomz.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *transformers.Falcon: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []transformers.PredictOption{ |
|
||||||
transformers.SetTemperature(c.Temperature), |
|
||||||
transformers.SetTopP(c.TopP), |
|
||||||
transformers.SetTopK(c.TopK), |
|
||||||
transformers.SetTokens(c.Maxtokens), |
|
||||||
transformers.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *transformers.GPTJ: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []transformers.PredictOption{ |
|
||||||
transformers.SetTemperature(c.Temperature), |
|
||||||
transformers.SetTopP(c.TopP), |
|
||||||
transformers.SetTopK(c.TopK), |
|
||||||
transformers.SetTokens(c.Maxtokens), |
|
||||||
transformers.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *transformers.Dolly: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []transformers.PredictOption{ |
|
||||||
transformers.SetTemperature(c.Temperature), |
|
||||||
transformers.SetTopP(c.TopP), |
|
||||||
transformers.SetTopK(c.TopK), |
|
||||||
transformers.SetTokens(c.Maxtokens), |
|
||||||
transformers.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *transformers.GPT2: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []transformers.PredictOption{ |
|
||||||
transformers.SetTemperature(c.Temperature), |
|
||||||
transformers.SetTopP(c.TopP), |
|
||||||
transformers.SetTopK(c.TopK), |
|
||||||
transformers.SetTokens(c.Maxtokens), |
|
||||||
transformers.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
case *gpt4all.Model: |
|
||||||
supportStreams = true |
|
||||||
|
|
||||||
fn = func() (string, error) { |
|
||||||
if tokenCallback != nil { |
|
||||||
model.SetTokenCallback(tokenCallback) |
|
||||||
} |
|
||||||
|
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []gpt4all.PredictOption{ |
|
||||||
gpt4all.SetTemperature(c.Temperature), |
|
||||||
gpt4all.SetTopP(c.TopP), |
|
||||||
gpt4all.SetTopK(c.TopK), |
|
||||||
gpt4all.SetTokens(c.Maxtokens), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Batch != 0 { |
|
||||||
predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch)) |
|
||||||
} |
|
||||||
|
|
||||||
str, er := model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
|
||||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
|
||||||
// after a stream event has occurred
|
|
||||||
model.SetTokenCallback(nil) |
|
||||||
return str, er |
|
||||||
} |
|
||||||
case *llama.LLama: |
|
||||||
supportStreams = true |
|
||||||
fn = func() (string, error) { |
|
||||||
|
|
||||||
if tokenCallback != nil { |
|
||||||
model.SetTokenCallback(tokenCallback) |
|
||||||
} |
|
||||||
|
|
||||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) |
|
||||||
|
|
||||||
str, er := model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
|
||||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
|
||||||
// after a stream event has occurred
|
|
||||||
model.SetTokenCallback(nil) |
|
||||||
return str, er |
|
||||||
} |
|
||||||
case *langchain.HuggingFace: |
|
||||||
fn = func() (string, error) { |
|
||||||
|
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []langchain.PredictOption{ |
|
||||||
langchain.SetModel(c.Model), |
|
||||||
langchain.SetMaxTokens(c.Maxtokens), |
|
||||||
langchain.SetTemperature(c.Temperature), |
|
||||||
langchain.SetStopWords(c.StopWords), |
|
||||||
} |
|
||||||
|
|
||||||
pred, er := model.PredictHuggingFace(s, predictOptions...) |
|
||||||
if er != nil { |
|
||||||
return "", er |
|
||||||
} |
|
||||||
return pred.Completion, nil |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return func() (string, error) { |
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock() |
|
||||||
l, ok := mutexes[modelFile] |
|
||||||
if !ok { |
|
||||||
m := &sync.Mutex{} |
|
||||||
mutexes[modelFile] = m |
|
||||||
l = m |
|
||||||
} |
|
||||||
mutexMap.Unlock() |
|
||||||
l.Lock() |
|
||||||
defer l.Unlock() |
|
||||||
|
|
||||||
res, err := fn() |
|
||||||
if tokenCallback != nil && !supportStreams { |
|
||||||
tokenCallback(res) |
|
||||||
} |
|
||||||
return res, err |
|
||||||
}, nil |
|
||||||
} |
|
||||||
|
|
||||||
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
|
||||||
result := []Choice{} |
|
||||||
|
|
||||||
n := input.N |
|
||||||
|
|
||||||
if input.N == 0 { |
|
||||||
n = 1 |
|
||||||
} |
|
||||||
|
|
||||||
// get the model function to call for the result
|
|
||||||
predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback) |
|
||||||
if err != nil { |
|
||||||
return result, err |
|
||||||
} |
|
||||||
|
|
||||||
for i := 0; i < n; i++ { |
|
||||||
prediction, err := predFunc() |
|
||||||
if err != nil { |
|
||||||
return result, err |
|
||||||
} |
|
||||||
|
|
||||||
prediction = Finetune(*config, predInput, prediction) |
|
||||||
cb(prediction, &result) |
|
||||||
|
|
||||||
//result = append(result, Choice{Text: prediction})
|
|
||||||
|
|
||||||
} |
|
||||||
return result, err |
|
||||||
} |
|
||||||
|
|
||||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
|
||||||
var mu sync.Mutex = sync.Mutex{} |
|
||||||
|
|
||||||
func Finetune(config Config, input, prediction string) string { |
|
||||||
if config.Echo { |
|
||||||
prediction = input + prediction |
|
||||||
} |
|
||||||
|
|
||||||
for _, c := range config.Cutstrings { |
|
||||||
mu.Lock() |
|
||||||
reg, ok := cutstrings[c] |
|
||||||
if !ok { |
|
||||||
cutstrings[c] = regexp.MustCompile(c) |
|
||||||
reg = cutstrings[c] |
|
||||||
} |
|
||||||
mu.Unlock() |
|
||||||
prediction = reg.ReplaceAllString(prediction, "") |
|
||||||
} |
|
||||||
|
|
||||||
for _, c := range config.TrimSpace { |
|
||||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
|
||||||
} |
|
||||||
return prediction |
|
||||||
|
|
||||||
} |
|
@ -0,0 +1,22 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
bert "github.com/go-skynet/LocalAI/pkg/grpc/llm/bert" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &bert.Embeddings{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
bloomz "github.com/go-skynet/LocalAI/pkg/grpc/llm/bloomz" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &bloomz.LLM{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transformers.Dolly{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transformers.Falcon{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,25 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// GRPC Falcon server
|
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
falcon "github.com/go-skynet/LocalAI/pkg/grpc/llm/falcon" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &falcon.LLM{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transformers.GPT2{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
gpt4all "github.com/go-skynet/LocalAI/pkg/grpc/llm/gpt4all" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &gpt4all.LLM{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transformers.GPTJ{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transformers.GPTNeoX{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
langchain "github.com/go-skynet/LocalAI/pkg/grpc/llm/langchain" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &langchain.LLM{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,25 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// GRPC Falcon server
|
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transformers.MPT{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
tts "github.com/go-skynet/LocalAI/pkg/grpc/tts" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &tts.Piper{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transformers.Replit{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
rwkv "github.com/go-skynet/LocalAI/pkg/grpc/llm/rwkv" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &rwkv.LLM{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
image "github.com/go-skynet/LocalAI/pkg/grpc/image" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &image.StableDiffusion{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transformers.Starcoder{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,23 @@ |
|||||||
|
package main |
||||||
|
|
||||||
|
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||||
|
|
||||||
|
import ( |
||||||
|
"flag" |
||||||
|
|
||||||
|
transcribe "github.com/go-skynet/LocalAI/pkg/grpc/transcribe" |
||||||
|
|
||||||
|
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||||
|
) |
||||||
|
|
||||||
|
func main() { |
||||||
|
flag.Parse() |
||||||
|
|
||||||
|
if err := grpc.StartServer(*addr, &transcribe.Whisper{}); err != nil { |
||||||
|
panic(err) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,42 @@ |
|||||||
|
package base |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" |
||||||
|
) |
||||||
|
|
||||||
|
type Base struct { |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Base) Load(opts *pb.ModelOptions) error { |
||||||
|
return fmt.Errorf("unimplemented") |
||||||
|
|
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return "", fmt.Errorf("unimplemented") |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Base) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
return fmt.Errorf("unimplemented") |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Base) Embeddings(opts *pb.PredictOptions) ([]float32, error) { |
||||||
|
return []float32{}, fmt.Errorf("unimplemented") |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error { |
||||||
|
return fmt.Errorf("unimplemented") |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) { |
||||||
|
return api.Result{}, fmt.Errorf("unimplemented") |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Base) TTS(*pb.TTSRequest) error { |
||||||
|
return fmt.Errorf("unimplemented") |
||||||
|
} |
@ -0,0 +1,160 @@ |
|||||||
|
package grpc |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"time" |
||||||
|
|
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" |
||||||
|
"google.golang.org/grpc" |
||||||
|
"google.golang.org/grpc/credentials/insecure" |
||||||
|
) |
||||||
|
|
||||||
|
type Client struct { |
||||||
|
address string |
||||||
|
} |
||||||
|
|
||||||
|
func NewClient(address string) *Client { |
||||||
|
return &Client{ |
||||||
|
address: address, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) HealthCheck(ctx context.Context) bool { |
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||||
|
if err != nil { |
||||||
|
fmt.Println(err) |
||||||
|
return false |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
client := pb.NewBackendClient(conn) |
||||||
|
|
||||||
|
// The healthcheck call shouldn't take long time
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) |
||||||
|
defer cancel() |
||||||
|
|
||||||
|
res, err := client.Health(ctx, &pb.HealthMessage{}) |
||||||
|
if err != nil { |
||||||
|
fmt.Println(err) |
||||||
|
|
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
if res.Message == "OK" { |
||||||
|
return true |
||||||
|
} |
||||||
|
return false |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) { |
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
client := pb.NewBackendClient(conn) |
||||||
|
|
||||||
|
return client.Embedding(ctx, in, opts...) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) { |
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
client := pb.NewBackendClient(conn) |
||||||
|
|
||||||
|
return client.Predict(ctx, in, opts...) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) { |
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
client := pb.NewBackendClient(conn) |
||||||
|
return client.LoadModel(ctx, in, opts...) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s string), opts ...grpc.CallOption) error { |
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
client := pb.NewBackendClient(conn) |
||||||
|
|
||||||
|
stream, err := client.PredictStream(ctx, in, opts...) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
for { |
||||||
|
feature, err := stream.Recv() |
||||||
|
if err == io.EOF { |
||||||
|
break |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
fmt.Println("Error", err) |
||||||
|
|
||||||
|
return err |
||||||
|
} |
||||||
|
f(feature.GetMessage()) |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { |
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
client := pb.NewBackendClient(conn) |
||||||
|
return client.GenerateImage(ctx, in, opts...) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { |
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
client := pb.NewBackendClient(conn) |
||||||
|
return client.TTS(ctx, in, opts...) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) { |
||||||
|
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
defer conn.Close() |
||||||
|
client := pb.NewBackendClient(conn) |
||||||
|
res, err := client.AudioTranscription(ctx, in, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
tresult := &api.Result{} |
||||||
|
for _, s := range res.Segments { |
||||||
|
tks := []int{} |
||||||
|
for _, t := range s.Tokens { |
||||||
|
tks = append(tks, int(t)) |
||||||
|
} |
||||||
|
tresult.Segments = append(tresult.Segments, |
||||||
|
api.Segment{ |
||||||
|
Text: s.Text, |
||||||
|
Id: int(s.Id), |
||||||
|
Start: time.Duration(s.Start), |
||||||
|
End: time.Duration(s.End), |
||||||
|
Tokens: tks, |
||||||
|
}) |
||||||
|
} |
||||||
|
tresult.Text = res.Text |
||||||
|
return tresult, err |
||||||
|
} |
@ -0,0 +1,33 @@ |
|||||||
|
package image |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
||||||
|
) |
||||||
|
|
||||||
|
type StableDiffusion struct { |
||||||
|
base.Base |
||||||
|
stablediffusion *stablediffusion.StableDiffusion |
||||||
|
} |
||||||
|
|
||||||
|
func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error { |
||||||
|
var err error |
||||||
|
// Note: the Model here is a path to a directory containing the model files
|
||||||
|
sd.stablediffusion, err = stablediffusion.New(opts.Model) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error { |
||||||
|
return sd.stablediffusion.GenerateImage( |
||||||
|
int(opts.Height), |
||||||
|
int(opts.Width), |
||||||
|
int(opts.Mode), |
||||||
|
int(opts.Step), |
||||||
|
int(opts.Seed), |
||||||
|
opts.PositivePrompt, |
||||||
|
opts.NegativePrompt, |
||||||
|
opts.Dst) |
||||||
|
} |
@ -0,0 +1,16 @@ |
|||||||
|
package grpc |
||||||
|
|
||||||
|
import ( |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" |
||||||
|
) |
||||||
|
|
||||||
|
type LLM interface { |
||||||
|
Predict(*pb.PredictOptions) (string, error) |
||||||
|
PredictStream(*pb.PredictOptions, chan string) error |
||||||
|
Load(*pb.ModelOptions) error |
||||||
|
Embeddings(*pb.PredictOptions) ([]float32, error) |
||||||
|
GenerateImage(*pb.GenerateImageRequest) error |
||||||
|
AudioTranscription(*pb.TranscriptRequest) (api.Result, error) |
||||||
|
TTS(*pb.TTSRequest) error |
||||||
|
} |
@ -0,0 +1,33 @@ |
|||||||
|
package bert |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
bert "github.com/go-skynet/go-bert.cpp" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
) |
||||||
|
|
||||||
|
type Embeddings struct { |
||||||
|
base.Base |
||||||
|
bert *bert.Bert |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Embeddings) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := bert.New(opts.Model) |
||||||
|
llm.bert = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) { |
||||||
|
if len(opts.EmbeddingTokens) > 0 { |
||||||
|
tokens := []int{} |
||||||
|
for _, t := range opts.EmbeddingTokens { |
||||||
|
tokens = append(tokens, int(t)) |
||||||
|
} |
||||||
|
return llm.bert.TokenEmbeddings(tokens, bert.SetThreads(int(opts.Threads))) |
||||||
|
} |
||||||
|
|
||||||
|
return llm.bert.Embeddings(opts.Embeddings, bert.SetThreads(int(opts.Threads))) |
||||||
|
} |
@ -0,0 +1,59 @@ |
|||||||
|
package bloomz |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
"github.com/go-skynet/bloomz.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type LLM struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
bloomz *bloomz.Bloomz |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := bloomz.New(opts.Model) |
||||||
|
llm.bloomz = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption { |
||||||
|
predictOptions := []bloomz.PredictOption{ |
||||||
|
bloomz.SetTemperature(float64(opts.Temperature)), |
||||||
|
bloomz.SetTopP(float64(opts.TopP)), |
||||||
|
bloomz.SetTopK(int(opts.TopK)), |
||||||
|
bloomz.SetTokens(int(opts.Tokens)), |
||||||
|
bloomz.SetThreads(int(opts.Threads)), |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Seed != 0 { |
||||||
|
predictOptions = append(predictOptions, bloomz.SetSeed(int(opts.Seed))) |
||||||
|
} |
||||||
|
|
||||||
|
return predictOptions |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,144 @@ |
|||||||
|
package falcon |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
ggllm "github.com/mudler/go-ggllm.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type LLM struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
falcon *ggllm.Falcon |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||||
|
ggllmOpts := []ggllm.ModelOption{} |
||||||
|
if opts.ContextSize != 0 { |
||||||
|
ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize))) |
||||||
|
} |
||||||
|
// F16 doesn't seem to produce good output at all!
|
||||||
|
//if c.F16 {
|
||||||
|
// llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||||
|
//}
|
||||||
|
|
||||||
|
if opts.NGPULayers != 0 { |
||||||
|
ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(int(opts.NGPULayers))) |
||||||
|
} |
||||||
|
|
||||||
|
ggllmOpts = append(ggllmOpts, ggllm.SetMMap(opts.MMap)) |
||||||
|
ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(opts.MainGPU)) |
||||||
|
ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(opts.TensorSplit)) |
||||||
|
if opts.NBatch != 0 { |
||||||
|
ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(int(opts.NBatch))) |
||||||
|
} else { |
||||||
|
ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512)) |
||||||
|
} |
||||||
|
|
||||||
|
model, err := ggllm.New(opts.Model, ggllmOpts...) |
||||||
|
llm.falcon = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { |
||||||
|
predictOptions := []ggllm.PredictOption{ |
||||||
|
ggllm.SetTemperature(float64(opts.Temperature)), |
||||||
|
ggllm.SetTopP(float64(opts.TopP)), |
||||||
|
ggllm.SetTopK(int(opts.TopK)), |
||||||
|
ggllm.SetTokens(int(opts.Tokens)), |
||||||
|
ggllm.SetThreads(int(opts.Threads)), |
||||||
|
} |
||||||
|
|
||||||
|
if opts.PromptCacheAll { |
||||||
|
predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.PromptCacheRO { |
||||||
|
predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO) |
||||||
|
} |
||||||
|
|
||||||
|
// Expected absolute path
|
||||||
|
if opts.PromptCachePath != "" { |
||||||
|
predictOptions = append(predictOptions, ggllm.SetPathPromptCache(opts.PromptCachePath)) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Mirostat != 0 { |
||||||
|
predictOptions = append(predictOptions, ggllm.SetMirostat(int(opts.Mirostat))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.MirostatETA != 0 { |
||||||
|
predictOptions = append(predictOptions, ggllm.SetMirostatETA(float64(opts.MirostatETA))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.MirostatTAU != 0 { |
||||||
|
predictOptions = append(predictOptions, ggllm.SetMirostatTAU(float64(opts.MirostatTAU))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Debug { |
||||||
|
predictOptions = append(predictOptions, ggllm.Debug) |
||||||
|
} |
||||||
|
|
||||||
|
predictOptions = append(predictOptions, ggllm.SetStopWords(opts.StopPrompts...)) |
||||||
|
|
||||||
|
if opts.PresencePenalty != 0 { |
||||||
|
predictOptions = append(predictOptions, ggllm.SetPenalty(float64(opts.PresencePenalty))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.NKeep != 0 { |
||||||
|
predictOptions = append(predictOptions, ggllm.SetNKeep(int(opts.NKeep))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Batch != 0 { |
||||||
|
predictOptions = append(predictOptions, ggllm.SetBatch(int(opts.Batch))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.IgnoreEOS { |
||||||
|
predictOptions = append(predictOptions, ggllm.IgnoreEOS) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Seed != 0 { |
||||||
|
predictOptions = append(predictOptions, ggllm.SetSeed(int(opts.Seed))) |
||||||
|
} |
||||||
|
|
||||||
|
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||||
|
|
||||||
|
predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(float64(opts.FrequencyPenalty))) |
||||||
|
predictOptions = append(predictOptions, ggllm.SetMlock(opts.MLock)) |
||||||
|
predictOptions = append(predictOptions, ggllm.SetMemoryMap(opts.MMap)) |
||||||
|
predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(opts.MainGPU)) |
||||||
|
predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(opts.TensorSplit)) |
||||||
|
predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ))) |
||||||
|
predictOptions = append(predictOptions, ggllm.SetTypicalP(float64(opts.TypicalP))) |
||||||
|
return predictOptions |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
predictOptions := buildPredictOptions(opts) |
||||||
|
|
||||||
|
predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { |
||||||
|
if token == "<|endoftext|>" { |
||||||
|
return true |
||||||
|
} |
||||||
|
results <- token |
||||||
|
return true |
||||||
|
})) |
||||||
|
|
||||||
|
go func() { |
||||||
|
_, err := llm.falcon.Predict(opts.Prompt, predictOptions...) |
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,62 @@ |
|||||||
|
package gpt4all |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" |
||||||
|
) |
||||||
|
|
||||||
|
type LLM struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
gpt4all *gpt4all.Model |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := gpt4all.New(opts.Model, |
||||||
|
gpt4all.SetThreads(int(opts.Threads)), |
||||||
|
gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath)) |
||||||
|
llm.gpt4all = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption { |
||||||
|
predictOptions := []gpt4all.PredictOption{ |
||||||
|
gpt4all.SetTemperature(float64(opts.Temperature)), |
||||||
|
gpt4all.SetTopP(float64(opts.TopP)), |
||||||
|
gpt4all.SetTopK(int(opts.TopK)), |
||||||
|
gpt4all.SetTokens(int(opts.Tokens)), |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Batch != 0 { |
||||||
|
predictOptions = append(predictOptions, gpt4all.SetBatch(int(opts.Batch))) |
||||||
|
} |
||||||
|
return predictOptions |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
predictOptions := buildPredictOptions(opts) |
||||||
|
|
||||||
|
go func() { |
||||||
|
llm.gpt4all.SetTokenCallback(func(token string) bool { |
||||||
|
results <- token |
||||||
|
return true |
||||||
|
}) |
||||||
|
_, err := llm.gpt4all.Predict(opts.Prompt, predictOptions...) |
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
llm.gpt4all.SetTokenCallback(nil) |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,58 @@ |
|||||||
|
package langchain |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/langchain" |
||||||
|
) |
||||||
|
|
||||||
|
type LLM struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
langchain *langchain.HuggingFace |
||||||
|
model string |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||||
|
llm.langchain, _ = langchain.NewHuggingFace(opts.Model) |
||||||
|
llm.model = opts.Model |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
o := []langchain.PredictOption{ |
||||||
|
langchain.SetModel(llm.model), |
||||||
|
langchain.SetMaxTokens(int(opts.Tokens)), |
||||||
|
langchain.SetTemperature(float64(opts.Temperature)), |
||||||
|
langchain.SetStopWords(opts.StopPrompts), |
||||||
|
} |
||||||
|
pred, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...) |
||||||
|
if err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
return pred.Completion, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
o := []langchain.PredictOption{ |
||||||
|
langchain.SetModel(llm.model), |
||||||
|
langchain.SetMaxTokens(int(opts.Tokens)), |
||||||
|
langchain.SetTemperature(float64(opts.Temperature)), |
||||||
|
langchain.SetStopWords(opts.StopPrompts), |
||||||
|
} |
||||||
|
go func() { |
||||||
|
res, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res.Completion |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,170 @@ |
|||||||
|
package llama |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
"github.com/go-skynet/go-llama.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type LLM struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
llama *llama.LLama |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||||
|
llamaOpts := []llama.ModelOption{} |
||||||
|
|
||||||
|
if opts.ContextSize != 0 { |
||||||
|
llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize))) |
||||||
|
} |
||||||
|
if opts.F16Memory { |
||||||
|
llamaOpts = append(llamaOpts, llama.EnableF16Memory) |
||||||
|
} |
||||||
|
if opts.Embeddings { |
||||||
|
llamaOpts = append(llamaOpts, llama.EnableEmbeddings) |
||||||
|
} |
||||||
|
if opts.NGPULayers != 0 { |
||||||
|
llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers))) |
||||||
|
} |
||||||
|
|
||||||
|
llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap)) |
||||||
|
llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU)) |
||||||
|
llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit)) |
||||||
|
if opts.NBatch != 0 { |
||||||
|
llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch))) |
||||||
|
} else { |
||||||
|
llamaOpts = append(llamaOpts, llama.SetNBatch(512)) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.NUMA { |
||||||
|
llamaOpts = append(llamaOpts, llama.EnableNUMA) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.LowVRAM { |
||||||
|
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) |
||||||
|
} |
||||||
|
|
||||||
|
model, err := llama.New(opts.Model, llamaOpts...) |
||||||
|
llm.llama = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption { |
||||||
|
predictOptions := []llama.PredictOption{ |
||||||
|
llama.SetTemperature(float64(opts.Temperature)), |
||||||
|
llama.SetTopP(float64(opts.TopP)), |
||||||
|
llama.SetTopK(int(opts.TopK)), |
||||||
|
llama.SetTokens(int(opts.Tokens)), |
||||||
|
llama.SetThreads(int(opts.Threads)), |
||||||
|
} |
||||||
|
|
||||||
|
if opts.PromptCacheAll { |
||||||
|
predictOptions = append(predictOptions, llama.EnablePromptCacheAll) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.PromptCacheRO { |
||||||
|
predictOptions = append(predictOptions, llama.EnablePromptCacheRO) |
||||||
|
} |
||||||
|
|
||||||
|
predictOptions = append(predictOptions, llama.WithGrammar(opts.Grammar)) |
||||||
|
|
||||||
|
// Expected absolute path
|
||||||
|
if opts.PromptCachePath != "" { |
||||||
|
predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath)) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Mirostat != 0 { |
||||||
|
predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.MirostatETA != 0 { |
||||||
|
predictOptions = append(predictOptions, llama.SetMirostatETA(float64(opts.MirostatETA))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.MirostatTAU != 0 { |
||||||
|
predictOptions = append(predictOptions, llama.SetMirostatTAU(float64(opts.MirostatTAU))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Debug { |
||||||
|
predictOptions = append(predictOptions, llama.Debug) |
||||||
|
} |
||||||
|
|
||||||
|
predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...)) |
||||||
|
|
||||||
|
if opts.PresencePenalty != 0 { |
||||||
|
predictOptions = append(predictOptions, llama.SetPenalty(float64(opts.PresencePenalty))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.NKeep != 0 { |
||||||
|
predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Batch != 0 { |
||||||
|
predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.F16KV { |
||||||
|
predictOptions = append(predictOptions, llama.EnableF16KV) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.IgnoreEOS { |
||||||
|
predictOptions = append(predictOptions, llama.IgnoreEOS) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Seed != 0 { |
||||||
|
predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed))) |
||||||
|
} |
||||||
|
|
||||||
|
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||||
|
|
||||||
|
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(float64(opts.FrequencyPenalty))) |
||||||
|
predictOptions = append(predictOptions, llama.SetMlock(opts.MLock)) |
||||||
|
predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap)) |
||||||
|
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU)) |
||||||
|
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit)) |
||||||
|
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ))) |
||||||
|
predictOptions = append(predictOptions, llama.SetTypicalP(float64(opts.TypicalP))) |
||||||
|
return predictOptions |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
predictOptions := buildPredictOptions(opts) |
||||||
|
|
||||||
|
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { |
||||||
|
results <- token |
||||||
|
return true |
||||||
|
})) |
||||||
|
|
||||||
|
go func() { |
||||||
|
_, err := llm.llama.Predict(opts.Prompt, predictOptions...) |
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { |
||||||
|
predictOptions := buildPredictOptions(opts) |
||||||
|
|
||||||
|
if len(opts.EmbeddingTokens) > 0 { |
||||||
|
tokens := []int{} |
||||||
|
for _, t := range opts.EmbeddingTokens { |
||||||
|
tokens = append(tokens, int(t)) |
||||||
|
} |
||||||
|
return llm.llama.TokenEmbeddings(tokens, predictOptions...) |
||||||
|
} |
||||||
|
|
||||||
|
return llm.llama.Embeddings(opts.Embeddings, predictOptions...) |
||||||
|
} |
@ -0,0 +1,71 @@ |
|||||||
|
package rwkv |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
"github.com/donomii/go-rwkv.cpp" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
) |
||||||
|
|
||||||
|
const tokenizerSuffix = ".tokenizer.json" |
||||||
|
|
||||||
|
type LLM struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
rwkv *rwkv.RwkvState |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||||
|
modelPath := filepath.Dir(opts.Model) |
||||||
|
modelFile := filepath.Base(opts.Model) |
||||||
|
model := rwkv.LoadFiles(opts.Model, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads())) |
||||||
|
|
||||||
|
if model == nil { |
||||||
|
return fmt.Errorf("could not load model") |
||||||
|
} |
||||||
|
llm.rwkv = model |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
|
||||||
|
stopWord := "\n" |
||||||
|
if len(opts.StopPrompts) > 0 { |
||||||
|
stopWord = opts.StopPrompts[0] |
||||||
|
} |
||||||
|
|
||||||
|
if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
|
||||||
|
response := llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), nil) |
||||||
|
|
||||||
|
return response, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
|
||||||
|
stopWord := "\n" |
||||||
|
if len(opts.StopPrompts) > 0 { |
||||||
|
stopWord = opts.StopPrompts[0] |
||||||
|
} |
||||||
|
|
||||||
|
if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil { |
||||||
|
fmt.Println("Error processing input: ", err) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), func(s string) bool { |
||||||
|
results <- s |
||||||
|
return true |
||||||
|
}) |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,43 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type Dolly struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
dolly *transformers.Dolly |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Dolly) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := transformers.NewDolly(opts.Model) |
||||||
|
llm.dolly = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,43 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type Falcon struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
falcon *transformers.Falcon |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Falcon) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := transformers.NewFalcon(opts.Model) |
||||||
|
llm.falcon = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,42 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type GPT2 struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
gpt2 *transformers.GPT2 |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *GPT2) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := transformers.New(opts.Model) |
||||||
|
llm.gpt2 = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,42 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type GPTJ struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
gptj *transformers.GPTJ |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *GPTJ) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := transformers.NewGPTJ(opts.Model) |
||||||
|
llm.gptj = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,42 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type GPTNeoX struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
gptneox *transformers.GPTNeoX |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := transformers.NewGPTNeoX(opts.Model) |
||||||
|
llm.gptneox = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,42 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type MPT struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
mpt *transformers.MPT |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *MPT) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := transformers.NewMPT(opts.Model) |
||||||
|
llm.mpt = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,26 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
import ( |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
func buildPredictOptions(opts *pb.PredictOptions) []transformers.PredictOption { |
||||||
|
predictOptions := []transformers.PredictOption{ |
||||||
|
transformers.SetTemperature(float64(opts.Temperature)), |
||||||
|
transformers.SetTopP(float64(opts.TopP)), |
||||||
|
transformers.SetTopK(int(opts.TopK)), |
||||||
|
transformers.SetTokens(int(opts.Tokens)), |
||||||
|
transformers.SetThreads(int(opts.Threads)), |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Batch != 0 { |
||||||
|
predictOptions = append(predictOptions, transformers.SetBatch(int(opts.Batch))) |
||||||
|
} |
||||||
|
|
||||||
|
if opts.Seed != 0 { |
||||||
|
predictOptions = append(predictOptions, transformers.SetSeed(int(opts.Seed))) |
||||||
|
} |
||||||
|
|
||||||
|
return predictOptions |
||||||
|
} |
@ -0,0 +1,42 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type Replit struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
replit *transformers.Replit |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Replit) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := transformers.NewReplit(opts.Model) |
||||||
|
llm.replit = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,43 @@ |
|||||||
|
package transformers |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
type Starcoder struct { |
||||||
|
base.Base |
||||||
|
|
||||||
|
starcoder *transformers.Starcoder |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Starcoder) Load(opts *pb.ModelOptions) error { |
||||||
|
model, err := transformers.NewStarcoder(opts.Model) |
||||||
|
llm.starcoder = model |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) { |
||||||
|
return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
} |
||||||
|
|
||||||
|
// fallback to Predict
|
||||||
|
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||||
|
go func() { |
||||||
|
res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
fmt.Println("err: ", err) |
||||||
|
} |
||||||
|
results <- res |
||||||
|
close(results) |
||||||
|
}() |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,129 @@ |
|||||||
|
syntax = "proto3"; |
||||||
|
|
||||||
|
option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto"; |
||||||
|
option java_multiple_files = true; |
||||||
|
option java_package = "io.skynet.localai.backend"; |
||||||
|
option java_outer_classname = "LocalAIBackend"; |
||||||
|
|
||||||
|
package backend; |
||||||
|
|
||||||
|
service Backend { |
||||||
|
rpc Health(HealthMessage) returns (Reply) {} |
||||||
|
rpc Predict(PredictOptions) returns (Reply) {} |
||||||
|
rpc LoadModel(ModelOptions) returns (Result) {} |
||||||
|
rpc PredictStream(PredictOptions) returns (stream Reply) {} |
||||||
|
rpc Embedding(PredictOptions) returns (EmbeddingResult) {} |
||||||
|
rpc GenerateImage(GenerateImageRequest) returns (Result) {} |
||||||
|
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} |
||||||
|
rpc TTS(TTSRequest) returns (Result) {} |
||||||
|
} |
||||||
|
|
||||||
|
message HealthMessage {} |
||||||
|
|
||||||
|
// The request message containing the user's name. |
||||||
|
message PredictOptions { |
||||||
|
string Prompt = 1; |
||||||
|
int32 Seed = 2; |
||||||
|
int32 Threads = 3; |
||||||
|
int32 Tokens = 4; |
||||||
|
int32 TopK = 5; |
||||||
|
int32 Repeat = 6; |
||||||
|
int32 Batch = 7; |
||||||
|
int32 NKeep = 8; |
||||||
|
float Temperature = 9; |
||||||
|
float Penalty = 10; |
||||||
|
bool F16KV = 11; |
||||||
|
bool DebugMode = 12; |
||||||
|
repeated string StopPrompts = 13; |
||||||
|
bool IgnoreEOS = 14; |
||||||
|
float TailFreeSamplingZ = 15; |
||||||
|
float TypicalP = 16; |
||||||
|
float FrequencyPenalty = 17; |
||||||
|
float PresencePenalty = 18; |
||||||
|
int32 Mirostat = 19; |
||||||
|
float MirostatETA = 20; |
||||||
|
float MirostatTAU = 21; |
||||||
|
bool PenalizeNL = 22; |
||||||
|
string LogitBias = 23; |
||||||
|
bool MLock = 25; |
||||||
|
bool MMap = 26; |
||||||
|
bool PromptCacheAll = 27; |
||||||
|
bool PromptCacheRO = 28; |
||||||
|
string Grammar = 29; |
||||||
|
string MainGPU = 30; |
||||||
|
string TensorSplit = 31; |
||||||
|
float TopP = 32; |
||||||
|
string PromptCachePath = 33; |
||||||
|
bool Debug = 34; |
||||||
|
repeated int32 EmbeddingTokens = 35; |
||||||
|
string Embeddings = 36; |
||||||
|
} |
||||||
|
|
||||||
|
// The response message containing the result |
||||||
|
message Reply { |
||||||
|
string message = 1; |
||||||
|
} |
||||||
|
|
||||||
|
message ModelOptions { |
||||||
|
string Model = 1; |
||||||
|
int32 ContextSize = 2; |
||||||
|
int32 Seed = 3; |
||||||
|
int32 NBatch = 4; |
||||||
|
bool F16Memory = 5; |
||||||
|
bool MLock = 6; |
||||||
|
bool MMap = 7; |
||||||
|
bool VocabOnly = 8; |
||||||
|
bool LowVRAM = 9; |
||||||
|
bool Embeddings = 10; |
||||||
|
bool NUMA = 11; |
||||||
|
int32 NGPULayers = 12; |
||||||
|
string MainGPU = 13; |
||||||
|
string TensorSplit = 14; |
||||||
|
int32 Threads = 15; |
||||||
|
string LibrarySearchPath = 16; |
||||||
|
} |
||||||
|
|
||||||
|
message Result { |
||||||
|
string message = 1; |
||||||
|
bool success = 2; |
||||||
|
} |
||||||
|
|
||||||
|
message EmbeddingResult { |
||||||
|
repeated float embeddings = 1; |
||||||
|
} |
||||||
|
|
||||||
|
message TranscriptRequest { |
||||||
|
string dst = 2; |
||||||
|
string language = 3; |
||||||
|
uint32 threads = 4; |
||||||
|
} |
||||||
|
|
||||||
|
message TranscriptResult { |
||||||
|
repeated TranscriptSegment segments = 1; |
||||||
|
string text = 2; |
||||||
|
} |
||||||
|
|
||||||
|
message TranscriptSegment { |
||||||
|
int32 id = 1; |
||||||
|
int64 start = 2; |
||||||
|
int64 end = 3; |
||||||
|
string text = 4; |
||||||
|
repeated int32 tokens = 5; |
||||||
|
} |
||||||
|
|
||||||
|
message GenerateImageRequest { |
||||||
|
int32 height = 1; |
||||||
|
int32 width = 2; |
||||||
|
int32 mode = 3; |
||||||
|
int32 step = 4; |
||||||
|
int32 seed = 5; |
||||||
|
string positive_prompt = 6; |
||||||
|
string negative_prompt = 7; |
||||||
|
string dst = 8; |
||||||
|
} |
||||||
|
|
||||||
|
message TTSRequest { |
||||||
|
string text = 1; |
||||||
|
string model = 2; |
||||||
|
string dst = 3; |
||||||
|
} |
@ -0,0 +1,385 @@ |
|||||||
|
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||||
|
// versions:
|
||||||
|
// - protoc-gen-go-grpc v1.2.0
|
||||||
|
// - protoc v3.15.8
|
||||||
|
// source: pkg/grpc/proto/backend.proto
|
||||||
|
|
||||||
|
package proto |
||||||
|
|
||||||
|
import ( |
||||||
|
context "context" |
||||||
|
grpc "google.golang.org/grpc" |
||||||
|
codes "google.golang.org/grpc/codes" |
||||||
|
status "google.golang.org/grpc/status" |
||||||
|
) |
||||||
|
|
||||||
|
// This is a compile-time assertion to ensure that this generated file
|
||||||
|
// is compatible with the grpc package it is being compiled against.
|
||||||
|
// Requires gRPC-Go v1.32.0 or later.
|
||||||
|
const _ = grpc.SupportPackageIsVersion7 |
||||||
|
|
||||||
|
// BackendClient is the client API for Backend service.
|
||||||
|
//
|
||||||
|
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||||
|
type BackendClient interface { |
||||||
|
Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) |
||||||
|
Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) |
||||||
|
LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) |
||||||
|
PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) |
||||||
|
Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) |
||||||
|
GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) |
||||||
|
AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) |
||||||
|
TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) |
||||||
|
} |
||||||
|
|
||||||
|
type backendClient struct { |
||||||
|
cc grpc.ClientConnInterface |
||||||
|
} |
||||||
|
|
||||||
|
func NewBackendClient(cc grpc.ClientConnInterface) BackendClient { |
||||||
|
return &backendClient{cc} |
||||||
|
} |
||||||
|
|
||||||
|
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { |
||||||
|
out := new(Reply) |
||||||
|
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return out, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { |
||||||
|
out := new(Reply) |
||||||
|
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return out, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { |
||||||
|
out := new(Result) |
||||||
|
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return out, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { |
||||||
|
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
x := &backendPredictStreamClient{stream} |
||||||
|
if err := x.ClientStream.SendMsg(in); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if err := x.ClientStream.CloseSend(); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return x, nil |
||||||
|
} |
||||||
|
|
||||||
|
type Backend_PredictStreamClient interface { |
||||||
|
Recv() (*Reply, error) |
||||||
|
grpc.ClientStream |
||||||
|
} |
||||||
|
|
||||||
|
type backendPredictStreamClient struct { |
||||||
|
grpc.ClientStream |
||||||
|
} |
||||||
|
|
||||||
|
func (x *backendPredictStreamClient) Recv() (*Reply, error) { |
||||||
|
m := new(Reply) |
||||||
|
if err := x.ClientStream.RecvMsg(m); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return m, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { |
||||||
|
out := new(EmbeddingResult) |
||||||
|
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return out, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { |
||||||
|
out := new(Result) |
||||||
|
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return out, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { |
||||||
|
out := new(TranscriptResult) |
||||||
|
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return out, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { |
||||||
|
out := new(Result) |
||||||
|
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return out, nil |
||||||
|
} |
||||||
|
|
||||||
|
// BackendServer is the server API for Backend service.
|
||||||
|
// All implementations must embed UnimplementedBackendServer
|
||||||
|
// for forward compatibility
|
||||||
|
type BackendServer interface { |
||||||
|
Health(context.Context, *HealthMessage) (*Reply, error) |
||||||
|
Predict(context.Context, *PredictOptions) (*Reply, error) |
||||||
|
LoadModel(context.Context, *ModelOptions) (*Result, error) |
||||||
|
PredictStream(*PredictOptions, Backend_PredictStreamServer) error |
||||||
|
Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) |
||||||
|
GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) |
||||||
|
AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) |
||||||
|
TTS(context.Context, *TTSRequest) (*Result, error) |
||||||
|
mustEmbedUnimplementedBackendServer() |
||||||
|
} |
||||||
|
|
||||||
|
// UnimplementedBackendServer must be embedded to have forward compatible implementations.
|
||||||
|
type UnimplementedBackendServer struct { |
||||||
|
} |
||||||
|
|
||||||
|
func (UnimplementedBackendServer) Health(context.Context, *HealthMessage) (*Reply, error) { |
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method Health not implemented") |
||||||
|
} |
||||||
|
func (UnimplementedBackendServer) Predict(context.Context, *PredictOptions) (*Reply, error) { |
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented") |
||||||
|
} |
||||||
|
func (UnimplementedBackendServer) LoadModel(context.Context, *ModelOptions) (*Result, error) { |
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method LoadModel not implemented") |
||||||
|
} |
||||||
|
func (UnimplementedBackendServer) PredictStream(*PredictOptions, Backend_PredictStreamServer) error { |
||||||
|
return status.Errorf(codes.Unimplemented, "method PredictStream not implemented") |
||||||
|
} |
||||||
|
func (UnimplementedBackendServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) { |
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented") |
||||||
|
} |
||||||
|
func (UnimplementedBackendServer) GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) { |
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GenerateImage not implemented") |
||||||
|
} |
||||||
|
func (UnimplementedBackendServer) AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) { |
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method AudioTranscription not implemented") |
||||||
|
} |
||||||
|
func (UnimplementedBackendServer) TTS(context.Context, *TTSRequest) (*Result, error) { |
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method TTS not implemented") |
||||||
|
} |
||||||
|
func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {} |
||||||
|
|
||||||
|
// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service.
|
||||||
|
// Use of this interface is not recommended, as added methods to BackendServer will
|
||||||
|
// result in compilation errors.
|
||||||
|
type UnsafeBackendServer interface { |
||||||
|
mustEmbedUnimplementedBackendServer() |
||||||
|
} |
||||||
|
|
||||||
|
func RegisterBackendServer(s grpc.ServiceRegistrar, srv BackendServer) { |
||||||
|
s.RegisterService(&Backend_ServiceDesc, srv) |
||||||
|
} |
||||||
|
|
||||||
|
func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||||
|
in := new(HealthMessage) |
||||||
|
if err := dec(in); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if interceptor == nil { |
||||||
|
return srv.(BackendServer).Health(ctx, in) |
||||||
|
} |
||||||
|
info := &grpc.UnaryServerInfo{ |
||||||
|
Server: srv, |
||||||
|
FullMethod: "/backend.Backend/Health", |
||||||
|
} |
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||||
|
return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) |
||||||
|
} |
||||||
|
return interceptor(ctx, in, info, handler) |
||||||
|
} |
||||||
|
|
||||||
|
func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||||
|
in := new(PredictOptions) |
||||||
|
if err := dec(in); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if interceptor == nil { |
||||||
|
return srv.(BackendServer).Predict(ctx, in) |
||||||
|
} |
||||||
|
info := &grpc.UnaryServerInfo{ |
||||||
|
Server: srv, |
||||||
|
FullMethod: "/backend.Backend/Predict", |
||||||
|
} |
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||||
|
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) |
||||||
|
} |
||||||
|
return interceptor(ctx, in, info, handler) |
||||||
|
} |
||||||
|
|
||||||
|
func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||||
|
in := new(ModelOptions) |
||||||
|
if err := dec(in); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if interceptor == nil { |
||||||
|
return srv.(BackendServer).LoadModel(ctx, in) |
||||||
|
} |
||||||
|
info := &grpc.UnaryServerInfo{ |
||||||
|
Server: srv, |
||||||
|
FullMethod: "/backend.Backend/LoadModel", |
||||||
|
} |
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||||
|
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) |
||||||
|
} |
||||||
|
return interceptor(ctx, in, info, handler) |
||||||
|
} |
||||||
|
|
||||||
|
func _Backend_PredictStream_Handler(srv interface{}, stream grpc.ServerStream) error { |
||||||
|
m := new(PredictOptions) |
||||||
|
if err := stream.RecvMsg(m); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
return srv.(BackendServer).PredictStream(m, &backendPredictStreamServer{stream}) |
||||||
|
} |
||||||
|
|
||||||
|
type Backend_PredictStreamServer interface { |
||||||
|
Send(*Reply) error |
||||||
|
grpc.ServerStream |
||||||
|
} |
||||||
|
|
||||||
|
type backendPredictStreamServer struct { |
||||||
|
grpc.ServerStream |
||||||
|
} |
||||||
|
|
||||||
|
func (x *backendPredictStreamServer) Send(m *Reply) error { |
||||||
|
return x.ServerStream.SendMsg(m) |
||||||
|
} |
||||||
|
|
||||||
|
func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||||
|
in := new(PredictOptions) |
||||||
|
if err := dec(in); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if interceptor == nil { |
||||||
|
return srv.(BackendServer).Embedding(ctx, in) |
||||||
|
} |
||||||
|
info := &grpc.UnaryServerInfo{ |
||||||
|
Server: srv, |
||||||
|
FullMethod: "/backend.Backend/Embedding", |
||||||
|
} |
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||||
|
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) |
||||||
|
} |
||||||
|
return interceptor(ctx, in, info, handler) |
||||||
|
} |
||||||
|
|
||||||
|
func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||||
|
in := new(GenerateImageRequest) |
||||||
|
if err := dec(in); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if interceptor == nil { |
||||||
|
return srv.(BackendServer).GenerateImage(ctx, in) |
||||||
|
} |
||||||
|
info := &grpc.UnaryServerInfo{ |
||||||
|
Server: srv, |
||||||
|
FullMethod: "/backend.Backend/GenerateImage", |
||||||
|
} |
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||||
|
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) |
||||||
|
} |
||||||
|
return interceptor(ctx, in, info, handler) |
||||||
|
} |
||||||
|
|
||||||
|
func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||||
|
in := new(TranscriptRequest) |
||||||
|
if err := dec(in); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if interceptor == nil { |
||||||
|
return srv.(BackendServer).AudioTranscription(ctx, in) |
||||||
|
} |
||||||
|
info := &grpc.UnaryServerInfo{ |
||||||
|
Server: srv, |
||||||
|
FullMethod: "/backend.Backend/AudioTranscription", |
||||||
|
} |
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||||
|
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) |
||||||
|
} |
||||||
|
return interceptor(ctx, in, info, handler) |
||||||
|
} |
||||||
|
|
||||||
|
func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||||
|
in := new(TTSRequest) |
||||||
|
if err := dec(in); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
if interceptor == nil { |
||||||
|
return srv.(BackendServer).TTS(ctx, in) |
||||||
|
} |
||||||
|
info := &grpc.UnaryServerInfo{ |
||||||
|
Server: srv, |
||||||
|
FullMethod: "/backend.Backend/TTS", |
||||||
|
} |
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||||
|
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) |
||||||
|
} |
||||||
|
return interceptor(ctx, in, info, handler) |
||||||
|
} |
||||||
|
|
||||||
|
// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service.
|
||||||
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
|
// and not to be introspected or modified (even as a copy)
|
||||||
|
var Backend_ServiceDesc = grpc.ServiceDesc{ |
||||||
|
ServiceName: "backend.Backend", |
||||||
|
HandlerType: (*BackendServer)(nil), |
||||||
|
Methods: []grpc.MethodDesc{ |
||||||
|
{ |
||||||
|
MethodName: "Health", |
||||||
|
Handler: _Backend_Health_Handler, |
||||||
|
}, |
||||||
|
{ |
||||||
|
MethodName: "Predict", |
||||||
|
Handler: _Backend_Predict_Handler, |
||||||
|
}, |
||||||
|
{ |
||||||
|
MethodName: "LoadModel", |
||||||
|
Handler: _Backend_LoadModel_Handler, |
||||||
|
}, |
||||||
|
{ |
||||||
|
MethodName: "Embedding", |
||||||
|
Handler: _Backend_Embedding_Handler, |
||||||
|
}, |
||||||
|
{ |
||||||
|
MethodName: "GenerateImage", |
||||||
|
Handler: _Backend_GenerateImage_Handler, |
||||||
|
}, |
||||||
|
{ |
||||||
|
MethodName: "AudioTranscription", |
||||||
|
Handler: _Backend_AudioTranscription_Handler, |
||||||
|
}, |
||||||
|
{ |
||||||
|
MethodName: "TTS", |
||||||
|
Handler: _Backend_TTS_Handler, |
||||||
|
}, |
||||||
|
}, |
||||||
|
Streams: []grpc.StreamDesc{ |
||||||
|
{ |
||||||
|
StreamName: "PredictStream", |
||||||
|
Handler: _Backend_PredictStream_Handler, |
||||||
|
ServerStreams: true, |
||||||
|
}, |
||||||
|
}, |
||||||
|
Metadata: "pkg/grpc/proto/backend.proto", |
||||||
|
} |
@ -0,0 +1,126 @@ |
|||||||
|
package grpc |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"log" |
||||||
|
"net" |
||||||
|
|
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
"google.golang.org/grpc" |
||||||
|
) |
||||||
|
|
||||||
|
// A GRPC Server that allows to run LLM inference.
|
||||||
|
// It is used by the LLMServices to expose the LLM functionalities that are called by the client.
|
||||||
|
// The GRPC Service is general, trying to encompass all the possible LLM options models.
|
||||||
|
// It depends on the real implementer then what can be done or not.
|
||||||
|
//
|
||||||
|
// The server is implemented as a GRPC service, with the following methods:
|
||||||
|
// - Predict: to run the inference with options
|
||||||
|
// - PredictStream: to run the inference with options and stream the results
|
||||||
|
|
||||||
|
// server is used to implement helloworld.GreeterServer.
|
||||||
|
type server struct { |
||||||
|
pb.UnimplementedBackendServer |
||||||
|
llm LLM |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) { |
||||||
|
return &pb.Reply{Message: "OK"}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { |
||||||
|
embeds, err := s.llm.Embeddings(in) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return &pb.EmbeddingResult{Embeddings: embeds}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { |
||||||
|
err := s.llm.Load(in) |
||||||
|
if err != nil { |
||||||
|
return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err |
||||||
|
} |
||||||
|
return &pb.Result{Message: "Loading succeeded", Success: true}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { |
||||||
|
result, err := s.llm.Predict(in) |
||||||
|
return &pb.Reply{Message: result}, err |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) { |
||||||
|
err := s.llm.GenerateImage(in) |
||||||
|
if err != nil { |
||||||
|
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err |
||||||
|
} |
||||||
|
return &pb.Result{Message: "Image generated", Success: true}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { |
||||||
|
err := s.llm.TTS(in) |
||||||
|
if err != nil { |
||||||
|
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err |
||||||
|
} |
||||||
|
return &pb.Result{Message: "Audio generated", Success: true}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { |
||||||
|
result, err := s.llm.AudioTranscription(in) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
tresult := &pb.TranscriptResult{} |
||||||
|
for _, s := range result.Segments { |
||||||
|
tks := []int32{} |
||||||
|
for _, t := range s.Tokens { |
||||||
|
tks = append(tks, int32(t)) |
||||||
|
} |
||||||
|
tresult.Segments = append(tresult.Segments, |
||||||
|
&pb.TranscriptSegment{ |
||||||
|
Text: s.Text, |
||||||
|
Id: int32(s.Id), |
||||||
|
Start: int64(s.Start), |
||||||
|
End: int64(s.End), |
||||||
|
Tokens: tks, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
tresult.Text = result.Text |
||||||
|
return tresult, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { |
||||||
|
|
||||||
|
resultChan := make(chan string) |
||||||
|
|
||||||
|
done := make(chan bool) |
||||||
|
go func() { |
||||||
|
for result := range resultChan { |
||||||
|
stream.Send(&pb.Reply{Message: result}) |
||||||
|
} |
||||||
|
done <- true |
||||||
|
}() |
||||||
|
|
||||||
|
s.llm.PredictStream(in, resultChan) |
||||||
|
<-done |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func StartServer(address string, model LLM) error { |
||||||
|
lis, err := net.Listen("tcp", address) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
s := grpc.NewServer() |
||||||
|
pb.RegisterBackendServer(s, &server{llm: model}) |
||||||
|
log.Printf("gRPC Server listening at %v", lis.Addr()) |
||||||
|
if err := s.Serve(lis); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,27 @@ |
|||||||
|
package transcribe |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
whisperutil "github.com/go-skynet/LocalAI/pkg/grpc/whisper" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" |
||||||
|
) |
||||||
|
|
||||||
|
type Whisper struct { |
||||||
|
base.Base |
||||||
|
whisper whisper.Model |
||||||
|
} |
||||||
|
|
||||||
|
func (sd *Whisper) Load(opts *pb.ModelOptions) error { |
||||||
|
// Note: the Model here is a path to a directory containing the model files
|
||||||
|
w, err := whisper.New(opts.Model) |
||||||
|
sd.whisper = w |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (api.Result, error) { |
||||||
|
return whisperutil.Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) |
||||||
|
} |
@ -0,0 +1,44 @@ |
|||||||
|
package tts |
||||||
|
|
||||||
|
// This is a wrapper to statisfy the GRPC service interface
|
||||||
|
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||||
|
import ( |
||||||
|
"os" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
piper "github.com/mudler/go-piper" |
||||||
|
) |
||||||
|
|
||||||
|
type Piper struct { |
||||||
|
base.Base |
||||||
|
piper *PiperB |
||||||
|
} |
||||||
|
|
||||||
|
func (sd *Piper) Load(opts *pb.ModelOptions) error { |
||||||
|
var err error |
||||||
|
// Note: the Model here is a path to a directory containing the model files
|
||||||
|
sd.piper, err = New(opts.LibrarySearchPath) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
func (sd *Piper) TTS(opts *pb.TTSRequest) error { |
||||||
|
return sd.piper.TTS(opts.Text, opts.Model, opts.Dst) |
||||||
|
} |
||||||
|
|
||||||
|
type PiperB struct { |
||||||
|
assetDir string |
||||||
|
} |
||||||
|
|
||||||
|
func New(assetDir string) (*PiperB, error) { |
||||||
|
if _, err := os.Stat(assetDir); err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
return &PiperB{ |
||||||
|
assetDir: assetDir, |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (s *PiperB) TTS(text, model, dst string) error { |
||||||
|
return piper.TextToWav(text, model, s.assetDir, "", dst) |
||||||
|
} |
@ -0,0 +1,16 @@ |
|||||||
|
package api |
||||||
|
|
||||||
|
import "time" |
||||||
|
|
||||||
|
type Segment struct { |
||||||
|
Id int `json:"id"` |
||||||
|
Start time.Duration `json:"start"` |
||||||
|
End time.Duration `json:"end"` |
||||||
|
Text string `json:"text"` |
||||||
|
Tokens []int `json:"tokens"` |
||||||
|
} |
||||||
|
|
||||||
|
type Result struct { |
||||||
|
Segments []Segment `json:"segments"` |
||||||
|
Text string `json:"text"` |
||||||
|
} |
@ -0,0 +1,66 @@ |
|||||||
|
package model |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
|
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
) |
||||||
|
|
||||||
|
type Options struct { |
||||||
|
backendString string |
||||||
|
modelFile string |
||||||
|
threads uint32 |
||||||
|
assetDir string |
||||||
|
context context.Context |
||||||
|
|
||||||
|
gRPCOptions *pb.ModelOptions |
||||||
|
} |
||||||
|
|
||||||
|
type Option func(*Options) |
||||||
|
|
||||||
|
func WithBackendString(backend string) Option { |
||||||
|
return func(o *Options) { |
||||||
|
o.backendString = backend |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithModelFile(modelFile string) Option { |
||||||
|
return func(o *Options) { |
||||||
|
o.modelFile = modelFile |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithLoadGRPCLLMModelOpts(opts *pb.ModelOptions) Option { |
||||||
|
return func(o *Options) { |
||||||
|
o.gRPCOptions = opts |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithThreads(threads uint32) Option { |
||||||
|
return func(o *Options) { |
||||||
|
o.threads = threads |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithAssetDir(assetDir string) Option { |
||||||
|
return func(o *Options) { |
||||||
|
o.assetDir = assetDir |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func WithContext(ctx context.Context) Option { |
||||||
|
return func(o *Options) { |
||||||
|
o.context = ctx |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func NewOptions(opts ...Option) *Options { |
||||||
|
o := &Options{ |
||||||
|
gRPCOptions: &pb.ModelOptions{}, |
||||||
|
context: context.Background(), |
||||||
|
} |
||||||
|
for _, opt := range opts { |
||||||
|
opt(o) |
||||||
|
} |
||||||
|
return o |
||||||
|
} |
@ -1,12 +0,0 @@ |
|||||||
//go:build tts
|
|
||||||
// +build tts
|
|
||||||
|
|
||||||
package tts |
|
||||||
|
|
||||||
import ( |
|
||||||
piper "github.com/mudler/go-piper" |
|
||||||
) |
|
||||||
|
|
||||||
func tts(text, model, assetDir, arLib, dst string) error { |
|
||||||
return piper.TextToWav(text, model, assetDir, arLib, dst) |
|
||||||
} |
|
@ -1,10 +0,0 @@ |
|||||||
//go:build !tts
|
|
||||||
// +build !tts
|
|
||||||
|
|
||||||
package tts |
|
||||||
|
|
||||||
import "fmt" |
|
||||||
|
|
||||||
func tts(text, model, assetDir, arLib, dst string) error { |
|
||||||
return fmt.Errorf("this version of LocalAI was built without the tts tag") |
|
||||||
} |
|
@ -1,20 +0,0 @@ |
|||||||
package tts |
|
||||||
|
|
||||||
import "os" |
|
||||||
|
|
||||||
type Piper struct { |
|
||||||
assetDir string |
|
||||||
} |
|
||||||
|
|
||||||
func New(assetDir string) (*Piper, error) { |
|
||||||
if _, err := os.Stat(assetDir); err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
return &Piper{ |
|
||||||
assetDir: assetDir, |
|
||||||
}, nil |
|
||||||
} |
|
||||||
|
|
||||||
func (s *Piper) TTS(text, model, dst string) error { |
|
||||||
return tts(text, model, s.assetDir, "", dst) |
|
||||||
} |
|
Loading…
Reference in new issue