feat: `gRPC`-based backends (#743)
commit e3cabb555d
@@ -0,0 +1,105 @@
package backend

import (
	"fmt"
	"sync"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) {
	if !c.Embeddings {
		return nil, fmt.Errorf("endpoint disabled for this model by API configuration")
	}

	modelFile := c.Model

	grpcOpts := gRPCModelOpts(c)

	var inferenceModel interface{}
	var err error

	opts := []model.Option{
		model.WithLoadGRPCLLMModelOpts(grpcOpts),
		model.WithThreads(uint32(c.Threads)),
		model.WithAssetDir(o.AssetsDestination),
		model.WithModelFile(modelFile),
		model.WithContext(o.Context),
	}

	if c.Backend == "" {
		inferenceModel, err = loader.GreedyLoader(opts...)
	} else {
		opts = append(opts, model.WithBackendString(c.Backend))
		inferenceModel, err = loader.BackendLoader(opts...)
	}
	if err != nil {
		return nil, err
	}

	var fn func() ([]float32, error)
	switch model := inferenceModel.(type) {
	case *grpc.Client:
		fn = func() ([]float32, error) {
			predictOptions := gRPCPredictOpts(c, loader.ModelPath)
			if len(tokens) > 0 {
				embeds := []int32{}

				for _, t := range tokens {
					embeds = append(embeds, int32(t))
				}
				predictOptions.EmbeddingTokens = embeds

				res, err := model.Embeddings(o.Context, predictOptions)
				if err != nil {
					return nil, err
				}

				return res.Embeddings, nil
			}
			predictOptions.Embeddings = s

			res, err := model.Embeddings(o.Context, predictOptions)
			if err != nil {
				return nil, err
			}

			return res.Embeddings, nil
		}
	default:
		fn = func() ([]float32, error) {
			return nil, fmt.Errorf("embeddings not supported by the backend")
		}
	}

	return func() ([]float32, error) {
		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
		mutexMap.Lock()
		l, ok := mutexes[modelFile]
		if !ok {
			m := &sync.Mutex{}
			mutexes[modelFile] = m
			l = m
		}
		mutexMap.Unlock()
		l.Lock()
		defer l.Unlock()

		embeds, err := fn()
		if err != nil {
			return embeds, err
		}
		// Remove trailing 0s
		for i := len(embeds) - 1; i >= 0; i-- {
			if embeds[i] == 0.0 {
				embeds = embeds[:i]
			} else {
				break
			}
		}
		return embeds, nil
	}, nil
}
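ModelEmbedding returns a deferred function rather than computing the result eagerly: the gRPC call (and the per-model lock) only happens when the returned closure is invoked. A minimal caller sketch, not part of this diff (names and wiring are illustrative only):

// Hypothetical helper in package backend, assuming loader/cfg/o are set up as elsewhere in this PR.
func embedText(loader *model.ModelLoader, cfg config.Config, o *options.Option, text string) ([]float32, error) {
	// cfg.Embeddings must be true, otherwise ModelEmbedding refuses the request.
	embedFn, err := ModelEmbedding(text, nil, loader, cfg, o)
	if err != nil {
		return nil, err
	}
	// The actual backend call happens here, serialized per model file.
	return embedFn()
}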
@@ -0,0 +1,60 @@
package backend

import (
	"fmt"
	"sync"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) {
	if c.Backend != model.StableDiffusionBackend {
		return nil, fmt.Errorf("endpoint only working with stablediffusion models")
	}

	inferenceModel, err := loader.BackendLoader(
		model.WithBackendString(c.Backend),
		model.WithAssetDir(o.AssetsDestination),
		model.WithThreads(uint32(c.Threads)),
		model.WithContext(o.Context),
		model.WithModelFile(c.ImageGenerationAssets),
	)
	if err != nil {
		return nil, err
	}

	fn := func() error {
		_, err := inferenceModel.GenerateImage(
			o.Context,
			&proto.GenerateImageRequest{
				Height:         int32(height),
				Width:          int32(width),
				Mode:           int32(mode),
				Step:           int32(step),
				Seed:           int32(seed),
				PositivePrompt: positive_prompt,
				NegativePrompt: negative_prompt,
				Dst:            dst,
			})
		return err
	}

	return func() error {
		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
		mutexMap.Lock()
		l, ok := mutexes[c.Backend]
		if !ok {
			m := &sync.Mutex{}
			mutexes[c.Backend] = m
			l = m
		}
		mutexMap.Unlock()
		l.Lock()
		defer l.Unlock()

		return fn()
	}, nil
}
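As with the embeddings backend, the returned closure performs the actual generation (writing the result to dst) and is serialized per backend. A caller sketch, not part of this diff; the prompt is taken from the curl example later in this commit and the remaining numbers are illustrative assumptions:

// Hypothetical helper in package backend.
func generateImage(loader *model.ModelLoader, cfg config.Config, o *options.Option, dst string) error {
	// cfg.Backend must be model.StableDiffusionBackend, otherwise ImageGeneration returns an error.
	fn, err := ImageGeneration(512, 512, 0, 15, 0, "A cute baby sea otter", "", dst, loader, cfg, o)
	if err != nil {
		return err
	}
	return fn() // runs the stablediffusion backend over gRPC and writes dst
}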
@@ -0,0 +1,98 @@
package backend

import (
	"regexp"
	"strings"
	"sync"

	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
	modelFile := c.Model

	grpcOpts := gRPCModelOpts(c)

	var inferenceModel *grpc.Client
	var err error

	opts := []model.Option{
		model.WithLoadGRPCLLMModelOpts(grpcOpts),
		model.WithThreads(uint32(c.Threads)), // some models use this to allocate threads during startup
		model.WithAssetDir(o.AssetsDestination),
		model.WithModelFile(modelFile),
		model.WithContext(o.Context),
	}

	if c.Backend == "" {
		inferenceModel, err = loader.GreedyLoader(opts...)
	} else {
		opts = append(opts, model.WithBackendString(c.Backend))
		inferenceModel, err = loader.BackendLoader(opts...)
	}
	if err != nil {
		return nil, err
	}

	// in gRPC, the backend is supposed to answer with a single token if streaming is not supported
	fn := func() (string, error) {
		opts := gRPCPredictOpts(c, loader.ModelPath)
		opts.Prompt = s
		if tokenCallback != nil {
			ss := ""
			err := inferenceModel.PredictStream(o.Context, opts, func(s string) {
				tokenCallback(s)
				ss += s
			})
			return ss, err
		} else {
			reply, err := inferenceModel.Predict(o.Context, opts)
			return reply.Message, err
		}
	}

	return func() (string, error) {
		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
		mutexMap.Lock()
		l, ok := mutexes[modelFile]
		if !ok {
			m := &sync.Mutex{}
			mutexes[modelFile] = m
			l = m
		}
		mutexMap.Unlock()
		l.Lock()
		defer l.Unlock()

		return fn()
	}, nil
}

var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
var mu sync.Mutex = sync.Mutex{}

func Finetune(config config.Config, input, prediction string) string {
	if config.Echo {
		prediction = input + prediction
	}

	for _, c := range config.Cutstrings {
		mu.Lock()
		reg, ok := cutstrings[c]
		if !ok {
			cutstrings[c] = regexp.MustCompile(c)
			reg = cutstrings[c]
		}
		mu.Unlock()
		prediction = reg.ReplaceAllString(prediction, "")
	}

	for _, c := range config.TrimSpace {
		prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
	}
	return prediction
}
@@ -0,0 +1,22 @@
package backend

import "sync"

// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var mutexMap sync.Mutex
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)

func Lock(s string) *sync.Mutex {
	// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
	mutexMap.Lock()
	l, ok := mutexes[s]
	if !ok {
		m := &sync.Mutex{}
		mutexes[s] = m
		l = m
	}
	mutexMap.Unlock()
	l.Lock()

	return l
}
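Lock centralizes the per-model mutex pattern that embeddings.go, image.go and llm.go above still spell out inline; since it returns the already-locked mutex, the intended call shape is a one-liner. A sketch, not part of this diff:

// Hypothetical usage inside package backend: hold the per-model lock for one inference.
func withModelLock(modelFile string, do func() error) error {
	defer Lock(modelFile).Unlock() // Lock() locks the mutex for modelFile and returns it
	return do()
}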
@@ -0,0 +1,72 @@
package backend

import (
	"os"
	"path/filepath"

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"

	config "github.com/go-skynet/LocalAI/api/config"
)

func gRPCModelOpts(c config.Config) *pb.ModelOptions {
	b := 512
	if c.Batch != 0 {
		b = c.Batch
	}
	return &pb.ModelOptions{
		ContextSize: int32(c.ContextSize),
		Seed:        int32(c.Seed),
		NBatch:      int32(b),
		F16Memory:   c.F16,
		MLock:       c.MMlock,
		NUMA:        c.NUMA,
		Embeddings:  c.Embeddings,
		LowVRAM:     c.LowVRAM,
		NGPULayers:  int32(c.NGPULayers),
		MMap:        c.MMap,
		MainGPU:     c.MainGPU,
		Threads:     int32(c.Threads),
		TensorSplit: c.TensorSplit,
	}
}

func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions {
	promptCachePath := ""
	if c.PromptCachePath != "" {
		p := filepath.Join(modelPath, c.PromptCachePath)
		os.MkdirAll(filepath.Dir(p), 0755)
		promptCachePath = p
	}
	return &pb.PredictOptions{
		Temperature:     float32(c.Temperature),
		TopP:            float32(c.TopP),
		TopK:            int32(c.TopK),
		Tokens:          int32(c.Maxtokens),
		Threads:         int32(c.Threads),
		PromptCacheAll:  c.PromptCacheAll,
		PromptCacheRO:   c.PromptCacheRO,
		PromptCachePath: promptCachePath,
		F16KV:           c.F16,
		DebugMode:       c.Debug,
		Grammar:         c.Grammar,

		Mirostat:          int32(c.Mirostat),
		MirostatETA:       float32(c.MirostatETA),
		MirostatTAU:       float32(c.MirostatTAU),
		Debug:             c.Debug,
		StopPrompts:       c.StopWords,
		Repeat:            int32(c.RepeatPenalty),
		NKeep:             int32(c.Keep),
		Batch:             int32(c.Batch),
		IgnoreEOS:         c.IgnoreEOS,
		Seed:              int32(c.Seed),
		FrequencyPenalty:  float32(c.FrequencyPenalty),
		MLock:             c.MMlock,
		MMap:              c.MMap,
		MainGPU:           c.MainGPU,
		TensorSplit:       c.TensorSplit,
		TailFreeSamplingZ: float32(c.TFZ),
		TypicalP:          float32(c.TypicalP),
	}
}
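Both helpers are pure translations from the YAML-backed config.Config into the protobuf messages: gRPCModelOpts covers load-time options (context size, GPU layers, batch, ...), gRPCPredictOpts covers per-request sampling options. A sketch of how they combine when talking to a backend client directly, mirroring what ModelInference does internally (not part of this diff):

// Hypothetical helper in package backend, assuming client is a *grpc.Client as in llm.go above.
func predictOnce(client *grpc.Client, cfg config.Config, loader *model.ModelLoader, o *options.Option, prompt string) (string, error) {
	opts := gRPCPredictOpts(cfg, loader.ModelPath) // per-request sampling options
	opts.Prompt = prompt
	reply, err := client.Predict(o.Context, opts)
	if err != nil {
		return "", err
	}
	return reply.Message, nil
}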
@@ -1,401 +0,0 @@
package api

import (
	"encoding/json"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"sync"

	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
	"gopkg.in/yaml.v3"
)

type Config struct {
	OpenAIRequest         `yaml:"parameters"`
	Name                  string            `yaml:"name"`
	StopWords             []string          `yaml:"stopwords"`
	Cutstrings            []string          `yaml:"cutstrings"`
	TrimSpace             []string          `yaml:"trimspace"`
	ContextSize           int               `yaml:"context_size"`
	F16                   bool              `yaml:"f16"`
	NUMA                  bool              `yaml:"numa"`
	Threads               int               `yaml:"threads"`
	Debug                 bool              `yaml:"debug"`
	Roles                 map[string]string `yaml:"roles"`
	Embeddings            bool              `yaml:"embeddings"`
	Backend               string            `yaml:"backend"`
	TemplateConfig        TemplateConfig    `yaml:"template"`
	MirostatETA           float64           `yaml:"mirostat_eta"`
	MirostatTAU           float64           `yaml:"mirostat_tau"`
	Mirostat              int               `yaml:"mirostat"`
	NGPULayers            int               `yaml:"gpu_layers"`
	MMap                  bool              `yaml:"mmap"`
	MMlock                bool              `yaml:"mmlock"`
	LowVRAM               bool              `yaml:"low_vram"`

	TensorSplit           string `yaml:"tensor_split"`
	MainGPU               string `yaml:"main_gpu"`
	ImageGenerationAssets string `yaml:"asset_dir"`

	PromptCachePath string `yaml:"prompt_cache_path"`
	PromptCacheAll  bool   `yaml:"prompt_cache_all"`
	PromptCacheRO   bool   `yaml:"prompt_cache_ro"`

	Grammar string `yaml:"grammar"`

	FunctionsConfig Functions `yaml:"function"`

	PromptStrings, InputStrings                []string
	InputToken                                 [][]int
	functionCallString, functionCallNameString string
}

type Functions struct {
	DisableNoAction         bool   `yaml:"disable_no_action"`
	NoActionFunctionName    string `yaml:"no_action_function_name"`
	NoActionDescriptionName string `yaml:"no_action_description_name"`
}

type TemplateConfig struct {
	Completion string `yaml:"completion"`
	Functions  string `yaml:"function"`
	Chat       string `yaml:"chat"`
	Edit       string `yaml:"edit"`
}

type ConfigMerger struct {
	configs map[string]Config
	sync.Mutex
}

func defaultConfig(modelFile string) *Config {
	return &Config{
		OpenAIRequest: defaultRequest(modelFile),
	}
}

func NewConfigMerger() *ConfigMerger {
	return &ConfigMerger{
		configs: make(map[string]Config),
	}
}

func ReadConfigFile(file string) ([]*Config, error) {
	c := &[]*Config{}
	f, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}
	if err := yaml.Unmarshal(f, c); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}

	return *c, nil
}

func ReadConfig(file string) (*Config, error) {
	c := &Config{}
	f, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}
	if err := yaml.Unmarshal(f, c); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}

	return c, nil
}

func (cm *ConfigMerger) LoadConfigFile(file string) error {
	cm.Lock()
	defer cm.Unlock()
	c, err := ReadConfigFile(file)
	if err != nil {
		return fmt.Errorf("cannot load config file: %w", err)
	}

	for _, cc := range c {
		cm.configs[cc.Name] = *cc
	}
	return nil
}

func (cm *ConfigMerger) LoadConfig(file string) error {
	cm.Lock()
	defer cm.Unlock()
	c, err := ReadConfig(file)
	if err != nil {
		return fmt.Errorf("cannot read config file: %w", err)
	}

	cm.configs[c.Name] = *c
	return nil
}

func (cm *ConfigMerger) GetConfig(m string) (Config, bool) {
	cm.Lock()
	defer cm.Unlock()
	v, exists := cm.configs[m]
	return v, exists
}

func (cm *ConfigMerger) ListConfigs() []string {
	cm.Lock()
	defer cm.Unlock()
	var res []string
	for k := range cm.configs {
		res = append(res, k)
	}
	return res
}

func (cm *ConfigMerger) LoadConfigs(path string) error {
	cm.Lock()
	defer cm.Unlock()
	entries, err := os.ReadDir(path)
	if err != nil {
		return err
	}
	files := make([]fs.FileInfo, 0, len(entries))
	for _, entry := range entries {
		info, err := entry.Info()
		if err != nil {
			return err
		}
		files = append(files, info)
	}
	for _, file := range files {
		// Skip templates, YAML and .keep files
		if !strings.Contains(file.Name(), ".yaml") {
			continue
		}
		c, err := ReadConfig(filepath.Join(path, file.Name()))
		if err == nil {
			cm.configs[c.Name] = *c
		}
	}

	return nil
}

func updateConfig(config *Config, input *OpenAIRequest) {
	if input.Echo {
		config.Echo = input.Echo
	}
	if input.TopK != 0 {
		config.TopK = input.TopK
	}
	if input.TopP != 0 {
		config.TopP = input.TopP
	}

	if input.Grammar != "" {
		config.Grammar = input.Grammar
	}

	if input.Temperature != 0 {
		config.Temperature = input.Temperature
	}

	if input.Maxtokens != 0 {
		config.Maxtokens = input.Maxtokens
	}

	switch stop := input.Stop.(type) {
	case string:
		if stop != "" {
			config.StopWords = append(config.StopWords, stop)
		}
	case []interface{}:
		for _, pp := range stop {
			if s, ok := pp.(string); ok {
				config.StopWords = append(config.StopWords, s)
			}
		}
	}

	if input.RepeatPenalty != 0 {
		config.RepeatPenalty = input.RepeatPenalty
	}

	if input.Keep != 0 {
		config.Keep = input.Keep
	}

	if input.Batch != 0 {
		config.Batch = input.Batch
	}

	if input.F16 {
		config.F16 = input.F16
	}

	if input.IgnoreEOS {
		config.IgnoreEOS = input.IgnoreEOS
	}

	if input.Seed != 0 {
		config.Seed = input.Seed
	}

	if input.Mirostat != 0 {
		config.Mirostat = input.Mirostat
	}

	if input.MirostatETA != 0 {
		config.MirostatETA = input.MirostatETA
	}

	if input.MirostatTAU != 0 {
		config.MirostatTAU = input.MirostatTAU
	}

	if input.TypicalP != 0 {
		config.TypicalP = input.TypicalP
	}

	switch inputs := input.Input.(type) {
	case string:
		if inputs != "" {
			config.InputStrings = append(config.InputStrings, inputs)
		}
	case []interface{}:
		for _, pp := range inputs {
			switch i := pp.(type) {
			case string:
				config.InputStrings = append(config.InputStrings, i)
			case []interface{}:
				tokens := []int{}
				for _, ii := range i {
					tokens = append(tokens, int(ii.(float64)))
				}
				config.InputToken = append(config.InputToken, tokens)
			}
		}
	}

	// Can be either a string or an object
	switch fnc := input.FunctionCall.(type) {
	case string:
		if fnc != "" {
			config.functionCallString = fnc
		}
	case map[string]interface{}:
		var name string
		n, exists := fnc["name"]
		if exists {
			nn, e := n.(string)
			if e {
				name = nn
			}
		}
		config.functionCallNameString = name
	}

	switch p := input.Prompt.(type) {
	case string:
		config.PromptStrings = append(config.PromptStrings, p)
	case []interface{}:
		for _, pp := range p {
			if s, ok := pp.(string); ok {
				config.PromptStrings = append(config.PromptStrings, s)
			}
		}
	}
}

func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) {
	input := new(OpenAIRequest)
	// Get input data from the request body
	if err := c.BodyParser(input); err != nil {
		return "", nil, err
	}

	modelFile := input.Model

	if c.Params("model") != "" {
		modelFile = c.Params("model")
	}

	received, _ := json.Marshal(input)

	log.Debug().Msgf("Request received: %s", string(received))

	// Set model from bearer token, if available
	bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
	bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)

	// If no model was specified, take the first available
	if modelFile == "" && !bearerExists && randomModel {
		models, _ := loader.ListModels()
		if len(models) > 0 {
			modelFile = models[0]
			log.Debug().Msgf("No model specified, using: %s", modelFile)
		} else {
			log.Debug().Msgf("No model specified, returning error")
			return "", nil, fmt.Errorf("no model specified")
		}
	}

	// If a model is found in bearer token takes precedence
	if bearerExists {
		log.Debug().Msgf("Using model from bearer token: %s", bearer)
		modelFile = bearer
	}
	return modelFile, input, nil
}

func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {
	// Load a config file if present after the model name
	modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml")

	var config *Config

	defaults := func() {
		config = defaultConfig(modelFile)
		config.ContextSize = ctx
		config.Threads = threads
		config.F16 = f16
		config.Debug = debug
	}

	cfg, exists := cm.GetConfig(modelFile)
	if !exists {
		if _, err := os.Stat(modelConfig); err == nil {
			if err := cm.LoadConfig(modelConfig); err != nil {
				return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error())
			}
			cfg, exists = cm.GetConfig(modelFile)
			if exists {
				config = &cfg
			} else {
				defaults()
			}
		} else {
			defaults()
		}
	} else {
		config = &cfg
	}

	// Set the parameters for the language model prediction
	updateConfig(config, input)

	// Don't allow 0 as setting
	if config.Threads == 0 {
		if threads != 0 {
			config.Threads = threads
		} else {
			config.Threads = 4
		}
	}

	// Enforce debug flag if passed from CLI
	if debug {
		config.Debug = true
	}

	return config, input, nil
}
@@ -0,0 +1,209 @@
package api_config

import (
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"sync"

	"gopkg.in/yaml.v3"
)

type Config struct {
	PredictionOptions     `yaml:"parameters"`
	Name                  string            `yaml:"name"`
	StopWords             []string          `yaml:"stopwords"`
	Cutstrings            []string          `yaml:"cutstrings"`
	TrimSpace             []string          `yaml:"trimspace"`
	ContextSize           int               `yaml:"context_size"`
	F16                   bool              `yaml:"f16"`
	NUMA                  bool              `yaml:"numa"`
	Threads               int               `yaml:"threads"`
	Debug                 bool              `yaml:"debug"`
	Roles                 map[string]string `yaml:"roles"`
	Embeddings            bool              `yaml:"embeddings"`
	Backend               string            `yaml:"backend"`
	TemplateConfig        TemplateConfig    `yaml:"template"`
	MirostatETA           float64           `yaml:"mirostat_eta"`
	MirostatTAU           float64           `yaml:"mirostat_tau"`
	Mirostat              int               `yaml:"mirostat"`
	NGPULayers            int               `yaml:"gpu_layers"`
	MMap                  bool              `yaml:"mmap"`
	MMlock                bool              `yaml:"mmlock"`
	LowVRAM               bool              `yaml:"low_vram"`

	TensorSplit           string `yaml:"tensor_split"`
	MainGPU               string `yaml:"main_gpu"`
	ImageGenerationAssets string `yaml:"asset_dir"`

	PromptCachePath string `yaml:"prompt_cache_path"`
	PromptCacheAll  bool   `yaml:"prompt_cache_all"`
	PromptCacheRO   bool   `yaml:"prompt_cache_ro"`

	Grammar string `yaml:"grammar"`

	PromptStrings, InputStrings                []string
	InputToken                                 [][]int
	functionCallString, functionCallNameString string

	FunctionsConfig Functions `yaml:"function"`
}

type Functions struct {
	DisableNoAction         bool   `yaml:"disable_no_action"`
	NoActionFunctionName    string `yaml:"no_action_function_name"`
	NoActionDescriptionName string `yaml:"no_action_description_name"`
}

type TemplateConfig struct {
	Completion string `yaml:"completion"`
	Functions  string `yaml:"function"`
	Chat       string `yaml:"chat"`
	Edit       string `yaml:"edit"`
}

type ConfigLoader struct {
	configs map[string]Config
	sync.Mutex
}

func (c *Config) SetFunctionCallString(s string) {
	c.functionCallString = s
}

func (c *Config) SetFunctionCallNameString(s string) {
	c.functionCallNameString = s
}

func (c *Config) ShouldUseFunctions() bool {
	return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}

func (c *Config) ShouldCallSpecificFunction() bool {
	return len(c.functionCallNameString) > 0
}

func (c *Config) FunctionToCall() string {
	return c.functionCallNameString
}

func defaultPredictOptions(modelFile string) PredictionOptions {
	return PredictionOptions{
		TopP:        0.7,
		TopK:        80,
		Maxtokens:   512,
		Temperature: 0.9,
		Model:       modelFile,
	}
}

func DefaultConfig(modelFile string) *Config {
	return &Config{
		PredictionOptions: defaultPredictOptions(modelFile),
	}
}

func NewConfigLoader() *ConfigLoader {
	return &ConfigLoader{
		configs: make(map[string]Config),
	}
}

func ReadConfigFile(file string) ([]*Config, error) {
	c := &[]*Config{}
	f, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}
	if err := yaml.Unmarshal(f, c); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}

	return *c, nil
}

func ReadConfig(file string) (*Config, error) {
	c := &Config{}
	f, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read config file: %w", err)
	}
	if err := yaml.Unmarshal(f, c); err != nil {
		return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
	}

	return c, nil
}

func (cm *ConfigLoader) LoadConfigFile(file string) error {
	cm.Lock()
	defer cm.Unlock()
	c, err := ReadConfigFile(file)
	if err != nil {
		return fmt.Errorf("cannot load config file: %w", err)
	}

	for _, cc := range c {
		cm.configs[cc.Name] = *cc
	}
	return nil
}

func (cm *ConfigLoader) LoadConfig(file string) error {
	cm.Lock()
	defer cm.Unlock()
	c, err := ReadConfig(file)
	if err != nil {
		return fmt.Errorf("cannot read config file: %w", err)
	}

	cm.configs[c.Name] = *c
	return nil
}

func (cm *ConfigLoader) GetConfig(m string) (Config, bool) {
	cm.Lock()
	defer cm.Unlock()
	v, exists := cm.configs[m]
	return v, exists
}

func (cm *ConfigLoader) ListConfigs() []string {
	cm.Lock()
	defer cm.Unlock()
	var res []string
	for k := range cm.configs {
		res = append(res, k)
	}
	return res
}

func (cm *ConfigLoader) LoadConfigs(path string) error {
	cm.Lock()
	defer cm.Unlock()
	entries, err := os.ReadDir(path)
	if err != nil {
		return err
	}
	files := make([]fs.FileInfo, 0, len(entries))
	for _, entry := range entries {
		info, err := entry.Info()
		if err != nil {
			return err
		}
		files = append(files, info)
	}
	for _, file := range files {
		// Skip templates, YAML and .keep files
		if !strings.Contains(file.Name(), ".yaml") {
			continue
		}
		c, err := ReadConfig(filepath.Join(path, file.Name()))
		if err == nil {
			cm.configs[c.Name] = *c
		}
	}

	return nil
}
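The ConfigLoader replaces the old ConfigMerger with the same load-then-lookup flow, keyed by the config's name field. A startup sketch, not part of this diff (names are illustrative only):

// Hypothetical wiring in package api_config: load every *.yaml under the model path,
// then resolve a named config at request time.
func loadAll(modelPath, modelName string) (Config, bool, error) {
	cl := NewConfigLoader()
	if err := cl.LoadConfigs(modelPath); err != nil {
		return Config{}, false, err
	}
	cfg, exists := cl.GetConfig(modelName) // falls back to DefaultConfig elsewhere when !exists
	return cfg, exists, nil
}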
@@ -0,0 +1,37 @@
package api_config

type PredictionOptions struct {

	// Also part of the OpenAI official spec
	Model string `json:"model" yaml:"model"`

	// Also part of the OpenAI official spec
	Language string `json:"language"`

	// Also part of the OpenAI official spec. use it for returning multiple results
	N int `json:"n"`

	// Common options between all the API calls, part of the OpenAI spec
	TopP        float64 `json:"top_p" yaml:"top_p"`
	TopK        int     `json:"top_k" yaml:"top_k"`
	Temperature float64 `json:"temperature" yaml:"temperature"`
	Maxtokens   int     `json:"max_tokens" yaml:"max_tokens"`
	Echo        bool    `json:"echo"`

	// Custom parameters - not present in the OpenAI API
	Batch         int     `json:"batch" yaml:"batch"`
	F16           bool    `json:"f16" yaml:"f16"`
	IgnoreEOS     bool    `json:"ignore_eos" yaml:"ignore_eos"`
	RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
	Keep          int     `json:"n_keep" yaml:"n_keep"`

	MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
	MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
	Mirostat    int     `json:"mirostat" yaml:"mirostat"`

	FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
	TFZ              float64 `json:"tfz" yaml:"tfz"`

	TypicalP float64 `json:"typical_p" yaml:"typical_p"`
	Seed     int     `json:"seed" yaml:"seed"`
}
@@ -1,78 +0,0 @@
package api

import (
	"fmt"
	"os"
	"path/filepath"

	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/tts"
	"github.com/go-skynet/LocalAI/pkg/utils"
	llama "github.com/go-skynet/go-llama.cpp"
	"github.com/gofiber/fiber/v2"
)

type TTSRequest struct {
	Model string `json:"model" yaml:"model"`
	Input string `json:"input" yaml:"input"`
}

func generateUniqueFileName(dir, baseName, ext string) string {
	counter := 1
	fileName := baseName + ext

	for {
		filePath := filepath.Join(dir, fileName)
		_, err := os.Stat(filePath)
		if os.IsNotExist(err) {
			return fileName
		}

		counter++
		fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
	}
}

func ttsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {

		input := new(TTSRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		piperModel, err := o.loader.BackendLoader(model.PiperBackend, input.Model, []llama.ModelOption{}, uint32(0), o.assetsDestination)
		if err != nil {
			return err
		}

		if piperModel == nil {
			return fmt.Errorf("could not load piper model")
		}

		w, ok := piperModel.(*tts.Piper)
		if !ok {
			return fmt.Errorf("loader returned non-piper object %+v", w)
		}

		if err := os.MkdirAll(o.audioDir, 0755); err != nil {
			return err
		}

		fileName := generateUniqueFileName(o.audioDir, "piper", ".wav")
		filePath := filepath.Join(o.audioDir, fileName)

		modelPath := filepath.Join(o.loader.ModelPath, input.Model)

		if err := utils.VerifyPath(modelPath, o.loader.ModelPath); err != nil {
			return err
		}

		if err := w.TTS(input.Input, modelPath, filePath); err != nil {
			return err
		}

		return c.Download(filePath)
	}
}
@@ -0,0 +1,84 @@
package localai

import (
	"context"
	"fmt"
	"os"
	"path/filepath"

	config "github.com/go-skynet/LocalAI/api/config"

	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc/proto"
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/LocalAI/pkg/utils"
	"github.com/gofiber/fiber/v2"
)

type TTSRequest struct {
	Model string `json:"model" yaml:"model"`
	Input string `json:"input" yaml:"input"`
}

func generateUniqueFileName(dir, baseName, ext string) string {
	counter := 1
	fileName := baseName + ext

	for {
		filePath := filepath.Join(dir, fileName)
		_, err := os.Stat(filePath)
		if os.IsNotExist(err) {
			return fileName
		}

		counter++
		fileName = fmt.Sprintf("%s_%d%s", baseName, counter, ext)
	}
}

func TTSEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {

		input := new(TTSRequest)
		// Get input data from the request body
		if err := c.BodyParser(input); err != nil {
			return err
		}

		piperModel, err := o.Loader.BackendLoader(
			model.WithBackendString(model.PiperBackend),
			model.WithModelFile(input.Model),
			model.WithContext(o.Context),
			model.WithAssetDir(o.AssetsDestination))
		if err != nil {
			return err
		}

		if piperModel == nil {
			return fmt.Errorf("could not load piper model")
		}

		if err := os.MkdirAll(o.AudioDir, 0755); err != nil {
			return fmt.Errorf("failed creating audio directory: %s", err)
		}

		fileName := generateUniqueFileName(o.AudioDir, "piper", ".wav")
		filePath := filepath.Join(o.AudioDir, fileName)

		modelPath := filepath.Join(o.Loader.ModelPath, input.Model)

		if err := utils.VerifyPath(modelPath, o.Loader.ModelPath); err != nil {
			return err
		}

		if _, err := piperModel.TTS(context.Background(), &proto.TTSRequest{
			Text:  input.Input,
			Model: modelPath,
			Dst:   filePath,
		}); err != nil {
			return err
		}

		return c.Download(filePath)
	}
}
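TTSEndpoint is a plain fiber handler factory, so wiring it into the HTTP app stays a one-liner. A sketch, not part of this diff, and the route path here is an assumption rather than something this commit defines:

// Hypothetical route registration in package localai.
func registerTTS(app *fiber.App, cm *config.ConfigLoader, o *options.Option) {
	app.Post("/tts", TTSEndpoint(cm, o)) // "/tts" is an assumed path for illustration
}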
@@ -1,961 +0,0 @@
package api

import (
	"bufio"
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"

	"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
	"github.com/go-skynet/LocalAI/pkg/grammar"
	model "github.com/go-skynet/LocalAI/pkg/model"
	whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
	llama "github.com/go-skynet/go-llama.cpp"
	"github.com/gofiber/fiber/v2"
	"github.com/rs/zerolog/log"
	"github.com/valyala/fasthttp"
)

// APIError provides error information returned by the OpenAI API.
type APIError struct {
	Code    any     `json:"code,omitempty"`
	Message string  `json:"message"`
	Param   *string `json:"param,omitempty"`
	Type    string  `json:"type"`
}

type ErrorResponse struct {
	Error *APIError `json:"error,omitempty"`
}

type OpenAIUsage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type Item struct {
	Embedding []float32 `json:"embedding"`
	Index     int       `json:"index"`
	Object    string    `json:"object,omitempty"`

	// Images
	URL     string `json:"url,omitempty"`
	B64JSON string `json:"b64_json,omitempty"`
}

type OpenAIResponse struct {
	Created int      `json:"created,omitempty"`
	Object  string   `json:"object,omitempty"`
	ID      string   `json:"id,omitempty"`
	Model   string   `json:"model,omitempty"`
	Choices []Choice `json:"choices,omitempty"`
	Data    []Item   `json:"data,omitempty"`

	Usage OpenAIUsage `json:"usage"`
}

type Choice struct {
	Index        int      `json:"index,omitempty"`
	FinishReason string   `json:"finish_reason,omitempty"`
	Message      *Message `json:"message,omitempty"`
	Delta        *Message `json:"delta,omitempty"`
	Text         string   `json:"text,omitempty"`
}

type Message struct {
	// The message role
	Role string `json:"role,omitempty" yaml:"role"`
	// The message content
	Content *string `json:"content" yaml:"content"`
	// A result of a function call
	FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
}

type OpenAIModel struct {
	ID     string `json:"id"`
	Object string `json:"object"`
}

type OpenAIRequest struct {
	Model string `json:"model" yaml:"model"`

	// whisper
	File     string `json:"file" validate:"required"`
	Language string `json:"language"`
	// whisper/image
	ResponseFormat string `json:"response_format"`
	// image
	Size string `json:"size"`
	// Prompt is read only by completion/image API calls
	Prompt interface{} `json:"prompt" yaml:"prompt"`

	// Edit endpoint
	Instruction string      `json:"instruction" yaml:"instruction"`
	Input       interface{} `json:"input" yaml:"input"`

	Stop interface{} `json:"stop" yaml:"stop"`

	// Messages is read only by chat/completion API calls
	Messages []Message `json:"messages" yaml:"messages"`

	// A list of available functions to call
	Functions    []grammar.Function `json:"functions" yaml:"functions"`
	FunctionCall interface{}        `json:"function_call" yaml:"function_call"` // might be a string or an object

	Stream bool `json:"stream"`
	Echo   bool `json:"echo"`
	// Common options between all the API calls
	TopP        float64 `json:"top_p" yaml:"top_p"`
	TopK        int     `json:"top_k" yaml:"top_k"`
	Temperature float64 `json:"temperature" yaml:"temperature"`
	Maxtokens   int     `json:"max_tokens" yaml:"max_tokens"`

	N int `json:"n"`

	// Custom parameters - not present in the OpenAI API
	Batch         int     `json:"batch" yaml:"batch"`
	F16           bool    `json:"f16" yaml:"f16"`
	IgnoreEOS     bool    `json:"ignore_eos" yaml:"ignore_eos"`
	RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
	Keep          int     `json:"n_keep" yaml:"n_keep"`

	MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
	MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
	Mirostat    int     `json:"mirostat" yaml:"mirostat"`

	FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"`
	TFZ              float64 `json:"tfz" yaml:"tfz"`

	Seed int `json:"seed" yaml:"seed"`

	// Image (not supported by OpenAI)
	Mode int `json:"mode"`
	Step int `json:"step"`

	// A grammar to constrain the LLM output
	Grammar string `json:"grammar" yaml:"grammar"`
	// A grammar object
	JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`

	TypicalP float64 `json:"typical_p" yaml:"typical_p"`
}

func defaultRequest(modelFile string) OpenAIRequest {
	return OpenAIRequest{
		TopP:        0.7,
		TopK:        80,
		Maxtokens:   512,
		Temperature: 0.9,
		Model:       modelFile,
	}
}

// https://platform.openai.com/docs/api-reference/completions
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
	process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
		ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
			resp := OpenAIResponse{
				Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
				Choices: []Choice{
					{
						Index: 0,
						Text:  s,
					},
				},
				Object: "text_completion",
			}
			log.Debug().Msgf("Sending goroutine: %s", s)

			responses <- resp
			return true
		})
		close(responses)
	}

	return func(c *fiber.Ctx) error {
		model, input, err := readInput(c, o.loader, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("`input`: %+v", input)

		config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)

		if input.Stream {
			log.Debug().Msgf("Stream request received")
			c.Context().SetContentType("text/event-stream")
			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
			//c.Set("Content-Type", "text/event-stream")
			c.Set("Cache-Control", "no-cache")
			c.Set("Connection", "keep-alive")
			c.Set("Transfer-Encoding", "chunked")
		}

		templateFile := config.Model

		if config.TemplateConfig.Completion != "" {
			templateFile = config.TemplateConfig.Completion
		}

		if input.Stream {
			if len(config.PromptStrings) > 1 {
				return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
			}

			predInput := config.PromptStrings[0]

			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
			templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
				Input string
			}{
				Input: predInput,
			})
			if err == nil {
				predInput = templatedInput
				log.Debug().Msgf("Template found, input modified to: %s", predInput)
			}

			responses := make(chan OpenAIResponse)

			go process(predInput, input, config, o.loader, responses)

			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

				for ev := range responses {
					var buf bytes.Buffer
					enc := json.NewEncoder(&buf)
					enc.Encode(ev)

					log.Debug().Msgf("Sending chunk: %s", buf.String())
					fmt.Fprintf(w, "data: %v\n", buf.String())
					w.Flush()
				}

				resp := &OpenAIResponse{
					Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
					Choices: []Choice{
						{
							Index:        0,
							FinishReason: "stop",
						},
					},
					Object: "text_completion",
				}
				respData, _ := json.Marshal(resp)

				w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
				w.WriteString("data: [DONE]\n\n")
				w.Flush()
			}))
			return nil
		}

		var result []Choice
		for _, i := range config.PromptStrings {
			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
			templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
				Input string
			}{
				Input: i,
			})
			if err == nil {
				i = templatedInput
				log.Debug().Msgf("Template found, input modified to: %s", i)
			}

			r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) {
				*c = append(*c, Choice{Text: s})
			}, nil)
			if err != nil {
				return err
			}

			result = append(result, r...)
		}

		resp := &OpenAIResponse{
			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: result,
			Object:  "text_completion",
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}

// https://platform.openai.com/docs/api-reference/embeddings
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		model, input, err := readInput(c, o.loader, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)
		items := []Item{}

		for i, s := range config.InputToken {
			// get the model function to call for the result
			embedFn, err := ModelEmbedding("", s, o.loader, *config, o)
			if err != nil {
				return err
			}

			embeddings, err := embedFn()
			if err != nil {
				return err
			}
			items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
		}

		for i, s := range config.InputStrings {
			// get the model function to call for the result
			embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o)
			if err != nil {
				return err
			}

			embeddings, err := embedFn()
			if err != nil {
				return err
			}
			items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"})
		}

		resp := &OpenAIResponse{
			Model:  input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Data:   items,
			Object: "list",
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}

func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {

	process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
		initialMessage := OpenAIResponse{
			Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: []Choice{{Delta: &Message{Role: "assistant"}}},
			Object:  "chat.completion.chunk",
		}
		responses <- initialMessage

		ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
			resp := OpenAIResponse{
				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
				Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
				Object:  "chat.completion.chunk",
			}
			log.Debug().Msgf("Sending goroutine: %s", s)

			responses <- resp
			return true
		})
		close(responses)
	}
	return func(c *fiber.Ctx) error {
		processFunctions := false
		funcs := grammar.Functions{}
		model, input, err := readInput(c, o.loader, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}
		log.Debug().Msgf("Configuration read: %+v", config)

		// Allow the user to set custom actions via config file
		// to be "embedded" in each model
		noActionName := "answer"
		noActionDescription := "use this action to answer without performing any action"

		if config.FunctionsConfig.NoActionFunctionName != "" {
			noActionName = config.FunctionsConfig.NoActionFunctionName
		}
		if config.FunctionsConfig.NoActionDescriptionName != "" {
			noActionDescription = config.FunctionsConfig.NoActionDescriptionName
		}

		// process functions if we have any defined or if we have a function call string
		if len(input.Functions) > 0 &&
			((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) {
			log.Debug().Msgf("Response needs to process functions")

			processFunctions = true

			noActionGrammar := grammar.Function{
				Name:        noActionName,
				Description: noActionDescription,
				Parameters: map[string]interface{}{
					"properties": map[string]interface{}{
						"message": map[string]interface{}{
							"type":        "string",
							"description": "The message to reply the user with",
						}},
				},
			}

			// Append the no action function
			funcs = append(funcs, input.Functions...)
			if !config.FunctionsConfig.DisableNoAction {
				funcs = append(funcs, noActionGrammar)
			}

			// Force picking one of the functions by the request
			if config.functionCallNameString != "" {
				funcs = funcs.Select(config.functionCallNameString)
			}

			// Update input grammar
			jsStruct := funcs.ToJSONStructure()
			config.Grammar = jsStruct.Grammar("")
		} else if input.JSONFunctionGrammarObject != nil {
			config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
		}

		// functions are not supported in stream mode (yet?)
		toStream := input.Stream && !processFunctions

		log.Debug().Msgf("Parameters: %+v", config)

		var predInput string

		mess := []string{}
		for _, i := range input.Messages {
			var content string
			role := i.Role
			// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
			// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
			if i.FunctionCall != nil && i.Role == "assistant" {
				roleFn := "assistant_function_call"
				r := config.Roles[roleFn]
				if r != "" {
					role = roleFn
				}
			}
			r := config.Roles[role]
			contentExists := i.Content != nil && *i.Content != ""
			if r != "" {
				if contentExists {
					content = fmt.Sprint(r, " ", *i.Content)
				}
				if i.FunctionCall != nil {
					j, err := json.Marshal(i.FunctionCall)
					if err == nil {
						if contentExists {
							content += "\n" + fmt.Sprint(r, " ", string(j))
						} else {
							content = fmt.Sprint(r, " ", string(j))
						}
					}
				}
			} else {
				if contentExists {
					content = fmt.Sprint(*i.Content)
				}
				if i.FunctionCall != nil {
					j, err := json.Marshal(i.FunctionCall)
					if err == nil {
						if contentExists {
							content += "\n" + string(j)
						} else {
							content = string(j)
						}
					}
				}
			}

			mess = append(mess, content)
		}

		predInput = strings.Join(mess, "\n")
		log.Debug().Msgf("Prompt (before templating): %s", predInput)

		if toStream {
			log.Debug().Msgf("Stream request received")
			c.Context().SetContentType("text/event-stream")
			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
			// c.Set("Content-Type", "text/event-stream")
			c.Set("Cache-Control", "no-cache")
			c.Set("Connection", "keep-alive")
			c.Set("Transfer-Encoding", "chunked")
		}

		templateFile := config.Model

		if config.TemplateConfig.Chat != "" && !processFunctions {
			templateFile = config.TemplateConfig.Chat
		}

		if config.TemplateConfig.Functions != "" && processFunctions {
			templateFile = config.TemplateConfig.Functions
		}

		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
		templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
			Input     string
			Functions []grammar.Function
		}{
			Input:     predInput,
			Functions: funcs,
		})
		if err == nil {
			predInput = templatedInput
			log.Debug().Msgf("Template found, input modified to: %s", predInput)
		} else {
			log.Debug().Msgf("Template failed loading: %s", err.Error())
		}

		log.Debug().Msgf("Prompt (after templating): %s", predInput)
		if processFunctions {
			log.Debug().Msgf("Grammar: %+v", config.Grammar)
		}

		if toStream {
			responses := make(chan OpenAIResponse)

			go process(predInput, input, config, o.loader, responses)

			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {

				for ev := range responses {
					var buf bytes.Buffer
					enc := json.NewEncoder(&buf)
					enc.Encode(ev)

					log.Debug().Msgf("Sending chunk: %s", buf.String())
					fmt.Fprintf(w, "data: %v\n", buf.String())
					w.Flush()
				}

				resp := &OpenAIResponse{
					Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
					Choices: []Choice{
						{
							FinishReason: "stop",
							Index:        0,
							Delta:        &Message{},
						}},
					Object: "chat.completion.chunk",
				}
				respData, _ := json.Marshal(resp)

				w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
				w.WriteString("data: [DONE]\n\n")
				w.Flush()
			}))
			return nil
		}

		result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) {
			if processFunctions {
				// As we have to change the result before processing, we can't stream the answer (yet?)
				ss := map[string]interface{}{}
				json.Unmarshal([]byte(s), &ss)
				log.Debug().Msgf("Function return: %s %+v", s, ss)

				// The grammar defines the function name as "function", while OpenAI returns "name"
				func_name := ss["function"]
				// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
				args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
				d, _ := json.Marshal(args)

				ss["arguments"] = string(d)
				ss["name"] = func_name

				// if do nothing, reply with a message
				if func_name == noActionName {
					log.Debug().Msgf("nothing to do, computing a reply")

					// If there is a message that the LLM already sends as part of the JSON reply, use it
					arguments := map[string]interface{}{}
					json.Unmarshal([]byte(d), &arguments)
					m, exists := arguments["message"]
					if exists {
						switch message := m.(type) {
						case string:
							if message != "" {
								log.Debug().Msgf("Reply received from LLM: %s", message)
								message = Finetune(*config, predInput, message)
								log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)

								*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}})
								return
							}
						}
					}

					log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
					// Otherwise ask the LLM to understand the JSON output and the context, and return a message
					// Note: This costs (in term of CPU) another computation
					config.Grammar = ""
					predFunc, err := ModelInference(predInput, o.loader, *config, o, nil)
					if err != nil {
						log.Error().Msgf("inference error: %s", err.Error())
						return
					}

					prediction, err := predFunc()
					if err != nil {
						log.Error().Msgf("inference error: %s", err.Error())
						return
					}

					prediction = Finetune(*config, predInput, prediction)
					*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}})
				} else {
					// otherwise reply with the function call
					*c = append(*c, Choice{
						FinishReason: "function_call",
						Message:      &Message{Role: "assistant", FunctionCall: ss},
					})
				}

				return
			}
			*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}})
		}, nil)
		if err != nil {
			return err
		}

		resp := &OpenAIResponse{
			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: result,
			Object:  "chat.completion",
		}
		respData, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", respData)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}

func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		model, input, err := readInput(c, o.loader, true)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)

		templateFile := config.Model

		if config.TemplateConfig.Edit != "" {
			templateFile = config.TemplateConfig.Edit
		}

		var result []Choice
		for _, i := range config.InputStrings {
			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
			templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
				Input       string
				Instruction string
			}{Input: i})
			if err == nil {
				i = templatedInput
				log.Debug().Msgf("Template found, input modified to: %s", i)
			}

			r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) {
				*c = append(*c, Choice{Text: s})
			}, nil)
			if err != nil {
				return err
			}

			result = append(result, r...)
		}

		resp := &OpenAIResponse{
			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
			Choices: result,
			Object:  "edit",
		}

		jsonResult, _ := json.Marshal(resp)
		log.Debug().Msgf("Response: %s", jsonResult)

		// Return the prediction in the response body
		return c.JSON(resp)
	}
}

// https://platform.openai.com/docs/api-reference/images/create

/*
*
curl http://localhost:8080/v1/images/generations \
  -H "Content-Type: application/json" \
  -d '{
    "prompt": "A cute baby sea otter",
    "n": 1,
    "size": "512x512"
  }'

*
*/
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
	return func(c *fiber.Ctx) error {
		m, input, err := readInput(c, o.loader, false)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		if m == "" {
			m = model.StableDiffusionBackend
		}
		log.Debug().Msgf("Loading model: %+v", m)

		config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false)
		if err != nil {
			return fmt.Errorf("failed reading parameters from request:%w", err)
		}

		log.Debug().Msgf("Parameter Config: %+v", config)

		// XXX: Only stablediffusion is supported for now
		if config.Backend == "" {
			config.Backend = model.StableDiffusionBackend
		}

		sizeParts := strings.Split(input.Size, "x")
		if len(sizeParts) != 2 {
			return fmt.Errorf("Invalid value for 'size'")
		}
		width, err := strconv.Atoi(sizeParts[0])
		if err != nil {
			return fmt.Errorf("Invalid value for 'size'")
		}
		height, err := strconv.Atoi(sizeParts[1])
		if err != nil {
			return fmt.Errorf("Invalid value for 'size'")
		}

		b64JSON := false
		if input.ResponseFormat == "b64_json" {
			b64JSON = true
		}

		var result []Item
		for _, i := range config.PromptStrings {
			n := input.N
			if input.N == 0 {
				n = 1
			}
			for j := 0; j < n; j++ {
				prompts := strings.Split(i, "|")
				positive_prompt := prompts[0]
				negative_prompt := ""
				if len(prompts) > 1 {
||||
negative_prompt = prompts[1] |
||||
} |
||||
|
||||
mode := 0 |
||||
step := 15 |
||||
|
||||
if input.Mode != 0 { |
||||
mode = input.Mode |
||||
} |
||||
|
||||
if input.Step != 0 { |
||||
step = input.Step |
||||
} |
||||
|
||||
tempDir := "" |
||||
if !b64JSON { |
||||
tempDir = o.imageDir |
||||
} |
||||
// Create a temporary file
|
||||
outputFile, err := ioutil.TempFile(tempDir, "b64") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
outputFile.Close() |
||||
output := outputFile.Name() + ".png" |
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
baseURL := c.BaseURL() |
||||
|
||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err := fn(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
item := &Item{} |
||||
|
||||
if b64JSON { |
||||
defer os.RemoveAll(output) |
||||
data, err := os.ReadFile(output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
||||
} else { |
||||
base := filepath.Base(output) |
||||
item.URL = baseURL + "/generated-images/" + base |
||||
} |
||||
|
||||
result = append(result, *item) |
||||
} |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Data: result, |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
f, err := file.Open() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer f.Close() |
||||
|
||||
dir, err := os.MkdirTemp("", "whisper") |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer os.RemoveAll(dir) |
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename)) |
||||
dstFile, err := os.Create(dst) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil { |
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst) |
||||
|
||||
whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads), o.assetsDestination) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if whisperModel == nil { |
||||
return fmt.Errorf("could not load whisper model") |
||||
} |
||||
|
||||
w, ok := whisperModel.(whisper.Model) |
||||
if !ok { |
||||
return fmt.Errorf("loader returned non-whisper object") |
||||
} |
||||
|
||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr) |
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr) |
||||
} |
||||
} |
||||
|
||||
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
models, err := loader.ListModels() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var mm map[string]interface{} = map[string]interface{}{} |
||||
|
||||
dataModels := []OpenAIModel{} |
||||
for _, m := range models { |
||||
mm[m] = nil |
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
||||
} |
||||
|
||||
for _, k := range cm.ListConfigs() { |
||||
if _, exists := mm[k]; !exists { |
||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
||||
} |
||||
} |
||||
|
||||
return c.JSON(struct { |
||||
Object string `json:"object"` |
||||
Data []OpenAIModel `json:"data"` |
||||
}{ |
||||
Object: "list", |
||||
Data: dataModels, |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,105 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||
) |
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
type APIError struct { |
||||
Code any `json:"code,omitempty"` |
||||
Message string `json:"message"` |
||||
Param *string `json:"param,omitempty"` |
||||
Type string `json:"type"` |
||||
} |
||||
|
||||
type ErrorResponse struct { |
||||
Error *APIError `json:"error,omitempty"` |
||||
} |
||||
|
||||
type OpenAIUsage struct { |
||||
PromptTokens int `json:"prompt_tokens"` |
||||
CompletionTokens int `json:"completion_tokens"` |
||||
TotalTokens int `json:"total_tokens"` |
||||
} |
||||
|
||||
type Item struct { |
||||
Embedding []float32 `json:"embedding"` |
||||
Index int `json:"index"` |
||||
Object string `json:"object,omitempty"` |
||||
|
||||
// Images
|
||||
URL string `json:"url,omitempty"` |
||||
B64JSON string `json:"b64_json,omitempty"` |
||||
} |
||||
|
||||
type OpenAIResponse struct { |
||||
Created int `json:"created,omitempty"` |
||||
Object string `json:"object,omitempty"` |
||||
ID string `json:"id,omitempty"` |
||||
Model string `json:"model,omitempty"` |
||||
Choices []Choice `json:"choices,omitempty"` |
||||
Data []Item `json:"data,omitempty"` |
||||
|
||||
Usage OpenAIUsage `json:"usage"` |
||||
} |
||||
|
||||
type Choice struct { |
||||
Index int `json:"index,omitempty"` |
||||
FinishReason string `json:"finish_reason,omitempty"` |
||||
Message *Message `json:"message,omitempty"` |
||||
Delta *Message `json:"delta,omitempty"` |
||||
Text string `json:"text,omitempty"` |
||||
} |
||||
|
||||
type Message struct { |
||||
// The message role
|
||||
Role string `json:"role,omitempty" yaml:"role"` |
||||
// The message content
|
||||
Content *string `json:"content" yaml:"content"` |
||||
// A result of a function call
|
||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` |
||||
} |
||||
|
||||
type OpenAIModel struct { |
||||
ID string `json:"id"` |
||||
Object string `json:"object"` |
||||
} |
||||
|
||||
type OpenAIRequest struct { |
||||
config.PredictionOptions |
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"` |
||||
//whisper/image
|
||||
ResponseFormat string `json:"response_format"` |
||||
// image
|
||||
Size string `json:"size"` |
||||
// Prompt is read only by completion/image API calls
|
||||
Prompt interface{} `json:"prompt" yaml:"prompt"` |
||||
|
||||
// Edit endpoint
|
||||
Instruction string `json:"instruction" yaml:"instruction"` |
||||
Input interface{} `json:"input" yaml:"input"` |
||||
|
||||
Stop interface{} `json:"stop" yaml:"stop"` |
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages" yaml:"messages"` |
||||
|
||||
// A list of available functions to call
|
||||
Functions []grammar.Function `json:"functions" yaml:"functions"` |
||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||
|
||||
Stream bool `json:"stream"` |
||||
|
||||
// Image (not supported by OpenAI)
|
||||
Mode int `json:"mode"` |
||||
Step int `json:"step"` |
||||
|
||||
// A grammar to constrain the LLM output
|
||||
Grammar string `json:"grammar" yaml:"grammar"` |
||||
|
||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` |
||||
} |
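For reference, a minimal sketch (not part of the diff) of how a chat request carrying a function definition decodes into the request type declared above. The struct definitions below are trimmed stand-ins so the snippet compiles on its own; only the JSON field names mirror the tags in this file, and the payload values are made up.

package main

import (
	"encoding/json"
	"fmt"
)

// Trimmed stand-ins for the types declared above, kept only to make the
// example self-contained.
type exampleMessage struct {
	Role    string  `json:"role"`
	Content *string `json:"content"`
}

type exampleFunction struct {
	Name       string                 `json:"name"`
	Parameters map[string]interface{} `json:"parameters"`
}

type exampleRequest struct {
	Model     string            `json:"model"`
	Messages  []exampleMessage  `json:"messages"`
	Functions []exampleFunction `json:"functions"`
	Stream    bool              `json:"stream"`
}

func main() {
	body := `{
	  "model": "my-model",
	  "messages": [{"role": "user", "content": "What is the weather in Rome?"}],
	  "functions": [{"name": "get_weather", "parameters": {"type": "object"}}]
	}`

	var req exampleRequest
	if err := json.Unmarshal([]byte(body), &req); err != nil {
		panic(err)
	}
	fmt.Println(req.Model, req.Functions[0].Name) // my-model get_weather
}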
@ -0,0 +1,320 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/json" |
||||
"fmt" |
||||
"strings" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"github.com/valyala/fasthttp" |
||||
) |
||||
|
||||
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
initialMessage := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
responses <- initialMessage |
||||
|
||||
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
|
||||
responses <- resp |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
return func(c *fiber.Ctx) error { |
||||
processFunctions := false |
||||
funcs := grammar.Functions{} |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
log.Debug().Msgf("Configuration read: %+v", config) |
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer" |
||||
noActionDescription := "use this action to answer without performing any action" |
||||
|
||||
if config.FunctionsConfig.NoActionFunctionName != "" { |
||||
noActionName = config.FunctionsConfig.NoActionFunctionName |
||||
} |
||||
if config.FunctionsConfig.NoActionDescriptionName != "" { |
||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName |
||||
} |
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() { |
||||
log.Debug().Msgf("Response needs to process functions") |
||||
|
||||
processFunctions = true |
||||
|
||||
noActionGrammar := grammar.Function{ |
||||
Name: noActionName, |
||||
Description: noActionDescription, |
||||
Parameters: map[string]interface{}{ |
||||
"properties": map[string]interface{}{ |
||||
"message": map[string]interface{}{ |
||||
"type": "string", |
||||
"description": "The message to reply the user with", |
||||
}}, |
||||
}, |
||||
} |
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...) |
||||
if !config.FunctionsConfig.DisableNoAction { |
||||
funcs = append(funcs, noActionGrammar) |
||||
} |
||||
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" { |
||||
funcs = funcs.Select(config.FunctionToCall()) |
||||
} |
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure() |
||||
config.Grammar = jsStruct.Grammar("") |
||||
} else if input.JSONFunctionGrammarObject != nil { |
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") |
||||
} |
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream && !processFunctions |
||||
|
||||
log.Debug().Msgf("Parameters: %+v", config) |
||||
|
||||
var predInput string |
||||
|
||||
mess := []string{} |
||||
for _, i := range input.Messages { |
||||
var content string |
||||
role := i.Role |
||||
// if this is a function call, we might want to customize the role so we can better display that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" { |
||||
roleFn := "assistant_function_call" |
||||
r := config.Roles[roleFn] |
||||
if r != "" { |
||||
role = roleFn |
||||
} |
||||
} |
||||
r := config.Roles[role] |
||||
contentExists := i.Content != nil && *i.Content != "" |
||||
if r != "" { |
||||
if contentExists { |
||||
content = fmt.Sprint(r, " ", *i.Content) |
||||
} |
||||
if i.FunctionCall != nil { |
||||
j, err := json.Marshal(i.FunctionCall) |
||||
if err == nil { |
||||
if contentExists { |
||||
content += "\n" + fmt.Sprint(r, " ", string(j)) |
||||
} else { |
||||
content = fmt.Sprint(r, " ", string(j)) |
||||
} |
||||
} |
||||
} |
||||
} else { |
||||
if contentExists { |
||||
content = fmt.Sprint(*i.Content) |
||||
} |
||||
if i.FunctionCall != nil { |
||||
j, err := json.Marshal(i.FunctionCall) |
||||
if err == nil { |
||||
if contentExists { |
||||
content += "\n" + string(j) |
||||
} else { |
||||
content = string(j) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
mess = append(mess, content) |
||||
} |
||||
|
||||
predInput = strings.Join(mess, "\n") |
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput) |
||||
|
||||
if toStream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions { |
||||
templateFile = config.TemplateConfig.Chat |
||||
} |
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions { |
||||
templateFile = config.TemplateConfig.Functions |
||||
} |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Functions []grammar.Function |
||||
}{ |
||||
Input: predInput, |
||||
Functions: funcs, |
||||
}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} else { |
||||
log.Debug().Msgf("Template failed loading: %s", err.Error()) |
||||
} |
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput) |
||||
if processFunctions { |
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar) |
||||
} |
||||
|
||||
if toStream { |
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.Loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
FinishReason: "stop", |
||||
Index: 0, |
||||
Delta: &Message{}, |
||||
}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
if processFunctions { |
||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||
ss := map[string]interface{}{} |
||||
json.Unmarshal([]byte(s), &ss) |
||||
log.Debug().Msgf("Function return: %s %+v", s, ss) |
||||
|
||||
// The grammar defines the function name as "function", while the OpenAI API uses "name"
|
||||
func_name := ss["function"] |
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
|
||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
d, _ := json.Marshal(args) |
||||
|
||||
ss["arguments"] = string(d) |
||||
ss["name"] = func_name |
||||
|
||||
// if there is nothing to do, reply with a message
|
||||
if func_name == noActionName { |
||||
log.Debug().Msgf("nothing to do, computing a reply") |
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{} |
||||
json.Unmarshal([]byte(d), &arguments) |
||||
m, exists := arguments["message"] |
||||
if exists { |
||||
switch message := m.(type) { |
||||
case string: |
||||
if message != "" { |
||||
log.Debug().Msgf("Reply received from LLM: %s", message) |
||||
message = backend.Finetune(*config, predInput, message) |
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) |
||||
|
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply") |
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in terms of CPU) another computation
|
||||
config.Grammar = "" |
||||
predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil) |
||||
if err != nil { |
||||
log.Error().Msgf("inference error: %s", err.Error()) |
||||
return |
||||
} |
||||
|
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
log.Error().Msgf("inference error: %s", err.Error()) |
||||
return |
||||
} |
||||
|
||||
prediction = backend.Finetune(*config, predInput, prediction) |
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) |
||||
} else { |
||||
// otherwise reply with the function call
|
||||
*c = append(*c, Choice{ |
||||
FinishReason: "function_call", |
||||
Message: &Message{Role: "assistant", FunctionCall: ss}, |
||||
}) |
||||
} |
||||
|
||||
return |
||||
} |
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "chat.completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", respData) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
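A standalone sketch (not part of the diff) of the reshaping performed above when the model answers through the JSON grammar: the grammar emits the chosen tool under "function" with "arguments" as an object, while an OpenAI-style function_call carries "name" and a stringified "arguments". The sample model output below is hypothetical.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Hypothetical raw output produced by the model under the JSON grammar.
	raw := `{"function": "get_weather", "arguments": {"city": "Rome"}}`

	ss := map[string]interface{}{}
	if err := json.Unmarshal([]byte(raw), &ss); err != nil {
		panic(err)
	}

	// Mirror the endpoint logic: expose the value under the OpenAI key "name"
	// and flatten "arguments" into a JSON-encoded string.
	ss["name"] = ss["function"]
	d, _ := json.Marshal(ss["arguments"])
	ss["arguments"] = string(d)

	out, _ := json.Marshal(ss)
	fmt.Println(string(out))
	// {"arguments":"{\"city\":\"Rome\"}","function":"get_weather","name":"get_weather"}
}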
@ -0,0 +1,159 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"github.com/valyala/fasthttp" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
Index: 0, |
||||
Text: s, |
||||
}, |
||||
}, |
||||
Object: "text_completion", |
||||
} |
||||
log.Debug().Msgf("Sending goroutine: %s", s) |
||||
|
||||
responses <- resp |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
|
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("`input`: %+v", input) |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
if input.Stream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Completion != "" { |
||||
templateFile = config.TemplateConfig.Completion |
||||
} |
||||
|
||||
if input.Stream { |
||||
if len(config.PromptStrings) > 1 { |
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") |
||||
} |
||||
|
||||
predInput := config.PromptStrings[0] |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{ |
||||
Input: predInput, |
||||
}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} |
||||
|
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.Loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
Index: 0, |
||||
FinishReason: "stop", |
||||
}, |
||||
}, |
||||
Object: "text_completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.PromptStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{ |
||||
Input: i, |
||||
}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "text_completion", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
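A hypothetical request against the streaming path above, following the same curl-in-comment convention used for the image endpoint; it assumes the handler is mounted at the OpenAI-compatible /v1/completions path, and the host, port and model name are placeholders. Each generated token arrives as a `data: {...}` server-sent-event chunk, terminated by `data: [DONE]`.

/*

curl http://localhost:8080/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "my-model",
    "prompt": "A long time ago",
    "stream": true
  }'

*/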
@ -0,0 +1,67 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Edit != "" { |
||||
templateFile = config.TemplateConfig.Edit |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.InputStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Instruction string |
||||
}{Input: i}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "edit", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,70 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
items := []Item{} |
||||
|
||||
for i, s := range config.InputToken { |
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
for i, s := range config.InputStrings { |
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items, |
||||
Object: "list", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
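A hypothetical call to the embeddings handler above, assuming it is mounted at the OpenAI-compatible /v1/embeddings path; the model name and host are placeholders. "input" may be a single string, a list of strings, or a list of token lists, matching the InputStrings/InputToken handling in this endpoint.

/*

curl http://localhost:8080/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{
    "model": "my-embedding-model",
    "input": "A long time ago in a galaxy far, far away"
  }'

*/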
@ -0,0 +1,158 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/* |
||||
* |
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{ |
||||
"prompt": "A cute baby sea otter", |
||||
"n": 1, |
||||
"size": "512x512" |
||||
}' |
||||
|
||||
* |
||||
*/ |
||||
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.Loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
if m == "" { |
||||
m = model.StableDiffusionBackend |
||||
} |
||||
log.Debug().Msgf("Loading model: %+v", m) |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
// XXX: Only stablediffusion is supported for now
|
||||
if config.Backend == "" { |
||||
config.Backend = model.StableDiffusionBackend |
||||
} |
||||
|
||||
sizeParts := strings.Split(input.Size, "x") |
||||
if len(sizeParts) != 2 { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
width, err := strconv.Atoi(sizeParts[0]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
height, err := strconv.Atoi(sizeParts[1]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
|
||||
b64JSON := false |
||||
if input.ResponseFormat == "b64_json" { |
||||
b64JSON = true |
||||
} |
||||
|
||||
var result []Item |
||||
for _, i := range config.PromptStrings { |
||||
n := input.N |
||||
if input.N == 0 { |
||||
n = 1 |
||||
} |
||||
for j := 0; j < n; j++ { |
||||
prompts := strings.Split(i, "|") |
||||
positive_prompt := prompts[0] |
||||
negative_prompt := "" |
||||
if len(prompts) > 1 { |
||||
negative_prompt = prompts[1] |
||||
} |
||||
|
||||
mode := 0 |
||||
step := 15 |
||||
|
||||
if input.Mode != 0 { |
||||
mode = input.Mode |
||||
} |
||||
|
||||
if input.Step != 0 { |
||||
step = input.Step |
||||
} |
||||
|
||||
tempDir := "" |
||||
if !b64JSON { |
||||
tempDir = o.ImageDir |
||||
} |
||||
// Create a temporary file
|
||||
outputFile, err := ioutil.TempFile(tempDir, "b64") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
outputFile.Close() |
||||
output := outputFile.Name() + ".png" |
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
baseURL := c.BaseURL() |
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err := fn(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
item := &Item{} |
||||
|
||||
if b64JSON { |
||||
defer os.RemoveAll(output) |
||||
data, err := os.ReadFile(output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
||||
} else { |
||||
base := filepath.Base(output) |
||||
item.URL = baseURL + "/generated-images/" + base |
||||
} |
||||
|
||||
result = append(result, *item) |
||||
} |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Data: result, |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,36 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
) |
||||
|
||||
func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
||||
result := []Choice{} |
||||
|
||||
if n == 0 { |
||||
n = 1 |
||||
} |
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback) |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
for i := 0; i < n; i++ { |
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
prediction = backend.Finetune(*config, predInput, prediction) |
||||
cb(prediction, &result) |
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
} |
||||
return result, err |
||||
} |
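The helper above hides the difference between blocking and streaming use behind its two callbacks: `cb` collects finished completions, while `tokenCallback` (nil for blocking calls) is forwarded to the backend so partial tokens can be pushed to the client as they are produced. A reduced, self-contained sketch of that shape, with a stub predictor standing in for backend.ModelInference:

package main

import "fmt"

// computeChoices mirrors the control flow above with plain types: run the
// predictor n times and let the caller decide how to accumulate results.
func computeChoices(n int, predict func() (string, error), cb func(string, *[]string)) ([]string, error) {
	if n == 0 {
		n = 1
	}
	result := []string{}
	for i := 0; i < n; i++ {
		s, err := predict()
		if err != nil {
			return result, err
		}
		cb(s, &result)
	}
	return result, nil
}

func main() {
	// Stub predictor standing in for the model backend.
	predict := func() (string, error) { return "hello world", nil }

	out, _ := computeChoices(2, predict, func(s string, c *[]string) { *c = append(*c, s) })
	fmt.Println(out) // [hello world hello world]
}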
@ -0,0 +1,37 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
) |
||||
|
||||
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
models, err := loader.ListModels() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var mm map[string]interface{} = map[string]interface{}{} |
||||
|
||||
dataModels := []OpenAIModel{} |
||||
for _, m := range models { |
||||
mm[m] = nil |
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
||||
} |
||||
|
||||
for _, k := range cm.ListConfigs() { |
||||
if _, exists := mm[k]; !exists { |
||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
||||
} |
||||
} |
||||
|
||||
return c.JSON(struct { |
||||
Object string `json:"object"` |
||||
Data []OpenAIModel `json:"data"` |
||||
}{ |
||||
Object: "list", |
||||
Data: dataModels, |
||||
}) |
||||
} |
||||
} |
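A hypothetical request to the listing handler above, assuming it is mounted at the OpenAI-compatible /v1/models path. The response contains both the files discovered in the model directory and any additional entries declared through config files.

/*

curl http://localhost:8080/v1/models

*/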
@ -0,0 +1,234 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
||||
input := new(OpenAIRequest) |
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
modelFile := input.Model |
||||
|
||||
if c.Params("model") != "" { |
||||
modelFile = c.Params("model") |
||||
} |
||||
|
||||
received, _ := json.Marshal(input) |
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received)) |
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimPrefix(c.Get("authorization"), "Bearer ") // strip only the scheme prefix, not individual characters |
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel { |
||||
models, _ := loader.ListModels() |
||||
if len(models) > 0 { |
||||
modelFile = models[0] |
||||
log.Debug().Msgf("No model specified, using: %s", modelFile) |
||||
} else { |
||||
log.Debug().Msgf("No model specified, returning error") |
||||
return "", nil, fmt.Errorf("no model specified") |
||||
} |
||||
} |
||||
|
||||
// If a model is found in the bearer token, it takes precedence
|
||||
if bearerExists { |
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
||||
modelFile = bearer |
||||
} |
||||
return modelFile, input, nil |
||||
} |
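A hypothetical request showing the bearer-token shortcut handled above: when the Authorization header names a file that exists in the model path, that model takes precedence over the body and the route parameter. The model name and host below are placeholders.

/*

curl http://localhost:8080/v1/chat/completions \
  -H "Authorization: Bearer my-model" \
  -H "Content-Type: application/json" \
  -d '{"messages": [{"role": "user", "content": "How are you?"}]}'

*/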
||||
|
||||
func updateConfig(config *config.Config, input *OpenAIRequest) { |
||||
if input.Echo { |
||||
config.Echo = input.Echo |
||||
} |
||||
if input.TopK != 0 { |
||||
config.TopK = input.TopK |
||||
} |
||||
if input.TopP != 0 { |
||||
config.TopP = input.TopP |
||||
} |
||||
|
||||
if input.Grammar != "" { |
||||
config.Grammar = input.Grammar |
||||
} |
||||
|
||||
if input.Temperature != 0 { |
||||
config.Temperature = input.Temperature |
||||
} |
||||
|
||||
if input.Maxtokens != 0 { |
||||
config.Maxtokens = input.Maxtokens |
||||
} |
||||
|
||||
switch stop := input.Stop.(type) { |
||||
case string: |
||||
if stop != "" { |
||||
config.StopWords = append(config.StopWords, stop) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range stop { |
||||
if s, ok := pp.(string); ok { |
||||
config.StopWords = append(config.StopWords, s) |
||||
} |
||||
} |
||||
} |
||||
|
||||
if input.RepeatPenalty != 0 { |
||||
config.RepeatPenalty = input.RepeatPenalty |
||||
} |
||||
|
||||
if input.Keep != 0 { |
||||
config.Keep = input.Keep |
||||
} |
||||
|
||||
if input.Batch != 0 { |
||||
config.Batch = input.Batch |
||||
} |
||||
|
||||
if input.F16 { |
||||
config.F16 = input.F16 |
||||
} |
||||
|
||||
if input.IgnoreEOS { |
||||
config.IgnoreEOS = input.IgnoreEOS |
||||
} |
||||
|
||||
if input.Seed != 0 { |
||||
config.Seed = input.Seed |
||||
} |
||||
|
||||
if input.Mirostat != 0 { |
||||
config.Mirostat = input.Mirostat |
||||
} |
||||
|
||||
if input.MirostatETA != 0 { |
||||
config.MirostatETA = input.MirostatETA |
||||
} |
||||
|
||||
if input.MirostatTAU != 0 { |
||||
config.MirostatTAU = input.MirostatTAU |
||||
} |
||||
|
||||
if input.TypicalP != 0 { |
||||
config.TypicalP = input.TypicalP |
||||
} |
||||
|
||||
switch inputs := input.Input.(type) { |
||||
case string: |
||||
if inputs != "" { |
||||
config.InputStrings = append(config.InputStrings, inputs) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range inputs { |
||||
switch i := pp.(type) { |
||||
case string: |
||||
config.InputStrings = append(config.InputStrings, i) |
||||
case []interface{}: |
||||
tokens := []int{} |
||||
for _, ii := range i { |
||||
tokens = append(tokens, int(ii.(float64))) |
||||
} |
||||
config.InputToken = append(config.InputToken, tokens) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) { |
||||
case string: |
||||
if fnc != "" { |
||||
config.SetFunctionCallString(fnc) |
||||
} |
||||
case map[string]interface{}: |
||||
var name string |
||||
n, exists := fnc["name"] |
||||
if exists { |
||||
nn, e := n.(string) |
||||
if e { // only use the value when the type assertion succeeds |
||||
name = nn |
||||
} |
||||
} |
||||
config.SetFunctionCallNameString(name) |
||||
} |
||||
|
||||
switch p := input.Prompt.(type) { |
||||
case string: |
||||
config.PromptStrings = append(config.PromptStrings, p) |
||||
case []interface{}: |
||||
for _, pp := range p { |
||||
if s, ok := pp.(string); ok { |
||||
config.PromptStrings = append(config.PromptStrings, s) |
||||
} |
||||
} |
||||
} |
||||
} |
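A standalone sketch (not part of the diff) of the normalization performed above for the "prompt" and "input" fields, which may arrive as a plain string, a list of strings, or, for "input", a list of token lists. The payloads are hypothetical.

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Hypothetical payloads for the two accepted shapes of "prompt".
	var single, many interface{}
	json.Unmarshal([]byte(`"one prompt"`), &single)
	json.Unmarshal([]byte(`["first prompt", "second prompt"]`), &many)

	normalize := func(prompt interface{}) []string {
		out := []string{}
		switch p := prompt.(type) {
		case string:
			out = append(out, p)
		case []interface{}:
			for _, pp := range p {
				if s, ok := pp.(string); ok {
					out = append(out, s)
				}
			}
		}
		return out
	}

	fmt.Println(normalize(single)) // [one prompt]
	fmt.Println(normalize(many))   // [first prompt second prompt]
}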
||||
|
||||
func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) { |
||||
// Load a config file named after the model, if present
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
||||
|
||||
var cfg *config.Config |
||||
|
||||
defaults := func() { |
||||
cfg = config.DefaultConfig(modelFile) |
||||
cfg.ContextSize = ctx |
||||
cfg.Threads = threads |
||||
cfg.F16 = f16 |
||||
cfg.Debug = debug |
||||
} |
||||
|
||||
cfgExisting, exists := cm.GetConfig(modelFile) |
||||
if !exists { |
||||
if _, err := os.Stat(modelConfig); err == nil { |
||||
if err := cm.LoadConfig(modelConfig); err != nil { |
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
||||
} |
||||
cfgExisting, exists = cm.GetConfig(modelFile) |
||||
if exists { |
||||
cfg = &cfgExisting |
||||
} else { |
||||
defaults() |
||||
} |
||||
} else { |
||||
defaults() |
||||
} |
||||
} else { |
||||
cfg = &cfgExisting |
||||
} |
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(cfg, input) |
||||
|
||||
// Don't allow 0 as setting
|
||||
if cfg.Threads == 0 { |
||||
if threads != 0 { |
||||
cfg.Threads = threads |
||||
} else { |
||||
cfg.Threads = 4 |
||||
} |
||||
} |
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug { |
||||
cfg.Debug = true |
||||
} |
||||
|
||||
return cfg, input, nil |
||||
} |
@ -0,0 +1,91 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"io" |
||||
"net/http" |
||||
"os" |
||||
"path" |
||||
"path/filepath" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
|
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.Loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
f, err := file.Open() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer f.Close() |
||||
|
||||
dir, err := os.MkdirTemp("", "whisper") |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer os.RemoveAll(dir) |
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename)) |
||||
dstFile, err := os.Create(dst) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil { |
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst) |
||||
|
||||
whisperModel, err := o.Loader.BackendLoader( |
||||
model.WithBackendString(model.WhisperBackend), |
||||
model.WithModelFile(config.Model), |
||||
model.WithContext(o.Context), |
||||
model.WithThreads(uint32(config.Threads)), |
||||
model.WithAssetDir(o.AssetsDestination)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if whisperModel == nil { |
||||
return fmt.Errorf("could not load whisper model") |
||||
} |
||||
|
||||
tr, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ |
||||
Dst: dst, |
||||
Language: input.Language, |
||||
Threads: uint32(config.Threads), |
||||
}) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr) |
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr) |
||||
} |
||||
} |
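A hypothetical multipart request against the transcription handler above, assuming it is mounted at the OpenAI-compatible /v1/audio/transcriptions path; the file path and model name are placeholders. The uploaded audio is copied into a temporary directory before being handed to the whisper backend over gRPC.

/*

curl http://localhost:8080/v1/audio/transcriptions \
  -H "Content-Type: multipart/form-data" \
  -F file="@/path/to/audio.wav" \
  -F model="whisper-model"

*/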
@ -1,649 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"regexp" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"github.com/donomii/go-rwkv.cpp" |
||||
"github.com/go-skynet/LocalAI/pkg/langchain" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
||||
"github.com/go-skynet/bloomz.cpp" |
||||
bert "github.com/go-skynet/go-bert.cpp" |
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
llama "github.com/go-skynet/go-llama.cpp" |
||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" |
||||
) |
||||
|
||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutexMap sync.Mutex |
||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
||||
|
||||
func defaultLLamaOpts(c Config) []llama.ModelOption { |
||||
llamaOpts := []llama.ModelOption{} |
||||
if c.ContextSize != 0 { |
||||
llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize)) |
||||
} |
||||
if c.F16 { |
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory) |
||||
} |
||||
if c.Embeddings { |
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings) |
||||
} |
||||
|
||||
if c.NGPULayers != 0 { |
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(c.NGPULayers)) |
||||
} |
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(c.MMap)) |
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(c.MainGPU)) |
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(c.TensorSplit)) |
||||
if c.Batch != 0 { |
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(c.Batch)) |
||||
} else { |
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512)) |
||||
} |
||||
|
||||
if c.NUMA { |
||||
llamaOpts = append(llamaOpts, llama.EnableNUMA) |
||||
} |
||||
|
||||
if c.LowVRAM { |
||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) |
||||
} |
||||
|
||||
return llamaOpts |
||||
} |
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { |
||||
if c.Backend != model.StableDiffusionBackend { |
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
||||
} |
||||
inferenceModel, err := loader.BackendLoader(c.Backend, c.ImageGenerationAssets, []llama.ModelOption{}, uint32(c.Threads), o.assetsDestination) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() error |
||||
switch model := inferenceModel.(type) { |
||||
case *stablediffusion.StableDiffusion: |
||||
fn = func() error { |
||||
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) |
||||
} |
||||
|
||||
default: |
||||
fn = func() error { |
||||
return fmt.Errorf("creation of images not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() error { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[c.Backend] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[c.Backend] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
return fn() |
||||
}, nil |
||||
} |
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) { |
||||
if !c.Embeddings { |
||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
||||
} |
||||
|
||||
modelFile := c.Model |
||||
|
||||
llamaOpts := defaultLLamaOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) |
||||
} else { |
||||
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() ([]float32, error) |
||||
switch model := inferenceModel.(type) { |
||||
case *llama.LLama: |
||||
fn = func() ([]float32, error) { |
||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) |
||||
if len(tokens) > 0 { |
||||
return model.TokenEmbeddings(tokens, predictOptions...) |
||||
} |
||||
return model.Embeddings(s, predictOptions...) |
||||
} |
||||
// bert embeddings
|
||||
case *bert.Bert: |
||||
fn = func() ([]float32, error) { |
||||
if len(tokens) > 0 { |
||||
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) |
||||
} |
||||
return model.Embeddings(s, bert.SetThreads(c.Threads)) |
||||
} |
||||
default: |
||||
fn = func() ([]float32, error) { |
||||
return nil, fmt.Errorf("embeddings not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() ([]float32, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[modelFile] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[modelFile] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
embeds, err := fn() |
||||
if err != nil { |
||||
return embeds, err |
||||
} |
||||
// Remove trailing 0s
|
||||
for i := len(embeds) - 1; i >= 0; i-- { |
||||
if embeds[i] == 0.0 { |
||||
embeds = embeds[:i] |
||||
} else { |
||||
break |
||||
} |
||||
} |
||||
return embeds, nil |
||||
}, nil |
||||
} |
||||
|
||||
func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []llama.PredictOption{ |
||||
llama.SetTemperature(c.Temperature), |
||||
llama.SetTopP(c.TopP), |
||||
llama.SetTopK(c.TopK), |
||||
llama.SetTokens(c.Maxtokens), |
||||
llama.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.PromptCacheAll { |
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll) |
||||
} |
||||
|
||||
if c.PromptCacheRO { |
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO) |
||||
} |
||||
|
||||
predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar)) |
||||
|
||||
if c.PromptCachePath != "" { |
||||
// Create parent directory
|
||||
p := filepath.Join(modelPath, c.PromptCachePath) |
||||
os.MkdirAll(filepath.Dir(p), 0755) |
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(p)) |
||||
} |
||||
|
||||
if c.Mirostat != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) |
||||
} |
||||
|
||||
if c.MirostatETA != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) |
||||
} |
||||
|
||||
if c.MirostatTAU != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) |
||||
} |
||||
|
||||
if c.Debug { |
||||
predictOptions = append(predictOptions, llama.Debug) |
||||
} |
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) |
||||
|
||||
if c.RepeatPenalty != 0 { |
||||
predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) |
||||
} |
||||
|
||||
if c.Keep != 0 { |
||||
predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.F16 { |
||||
predictOptions = append(predictOptions, llama.EnableF16KV) |
||||
} |
||||
|
||||
if c.IgnoreEOS { |
||||
predictOptions = append(predictOptions, llama.IgnoreEOS) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(c.FrequencyPenalty)) |
||||
predictOptions = append(predictOptions, llama.SetMlock(c.MMlock)) |
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(c.MMap)) |
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(c.MainGPU)) |
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(c.TensorSplit)) |
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(c.TFZ)) |
||||
predictOptions = append(predictOptions, llama.SetTypicalP(c.TypicalP)) |
||||
|
||||
return predictOptions |
||||
} |
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) { |
||||
supportStreams := false |
||||
modelFile := c.Model |
||||
|
||||
llamaOpts := defaultLLamaOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) |
||||
} else { |
||||
inferenceModel, err = loader.BackendLoader(c.Backend, modelFile, llamaOpts, uint32(c.Threads), o.assetsDestination) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() (string, error) |
||||
|
||||
switch model := inferenceModel.(type) { |
||||
case *rwkv.RwkvState: |
||||
supportStreams = true |
||||
|
||||
fn = func() (string, error) { |
||||
stopWord := "\n" |
||||
if len(c.StopWords) > 0 { |
||||
stopWord = c.StopWords[0] |
||||
} |
||||
|
||||
if err := model.ProcessInput(s); err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) |
||||
|
||||
return response, nil |
||||
} |
||||
case *transformers.GPTNeoX: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.Replit: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.Starcoder: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.MPT: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *bloomz.Bloomz: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []bloomz.PredictOption{ |
||||
bloomz.SetTemperature(c.Temperature), |
||||
bloomz.SetTopP(c.TopP), |
||||
bloomz.SetTopK(c.TopK), |
||||
bloomz.SetTokens(c.Maxtokens), |
||||
bloomz.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.Falcon: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.GPTJ: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.Dolly: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *transformers.GPT2: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(c.Temperature), |
||||
transformers.SetTopP(c.TopP), |
||||
transformers.SetTopK(c.TopK), |
||||
transformers.SetTokens(c.Maxtokens), |
||||
transformers.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
case *gpt4all.Model: |
||||
supportStreams = true |
||||
|
||||
fn = func() (string, error) { |
||||
if tokenCallback != nil { |
||||
model.SetTokenCallback(tokenCallback) |
||||
} |
||||
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []gpt4all.PredictOption{ |
||||
gpt4all.SetTemperature(c.Temperature), |
||||
gpt4all.SetTopP(c.TopP), |
||||
gpt4all.SetTopK(c.TopK), |
||||
gpt4all.SetTokens(c.Maxtokens), |
||||
} |
||||
|
||||
if c.Batch != 0 { |
||||
predictOptions = append(predictOptions, gpt4all.SetBatch(c.Batch)) |
||||
} |
||||
|
||||
str, er := model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
		// It seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
||||
// after a stream event has occurred
|
||||
model.SetTokenCallback(nil) |
||||
return str, er |
||||
} |
||||
case *llama.LLama: |
||||
supportStreams = true |
||||
fn = func() (string, error) { |
||||
|
||||
if tokenCallback != nil { |
||||
model.SetTokenCallback(tokenCallback) |
||||
} |
||||
|
||||
predictOptions := buildLLamaPredictOptions(c, loader.ModelPath) |
||||
|
||||
str, er := model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
		// It seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
|
||||
// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
|
||||
// after a stream event has occurred
|
||||
model.SetTokenCallback(nil) |
||||
return str, er |
||||
} |
||||
case *langchain.HuggingFace: |
||||
fn = func() (string, error) { |
||||
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []langchain.PredictOption{ |
||||
langchain.SetModel(c.Model), |
||||
langchain.SetMaxTokens(c.Maxtokens), |
||||
langchain.SetTemperature(c.Temperature), |
||||
langchain.SetStopWords(c.StopWords), |
||||
} |
||||
|
||||
pred, er := model.PredictHuggingFace(s, predictOptions...) |
||||
if er != nil { |
||||
return "", er |
||||
} |
||||
return pred.Completion, nil |
||||
} |
||||
} |
||||
|
||||
return func() (string, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[modelFile] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[modelFile] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
res, err := fn() |
||||
if tokenCallback != nil && !supportStreams { |
||||
tokenCallback(res) |
||||
} |
||||
return res, err |
||||
}, nil |
||||
} |
||||
|
||||
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
||||
result := []Choice{} |
||||
|
||||
n := input.N |
||||
|
||||
if input.N == 0 { |
||||
n = 1 |
||||
} |
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback) |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
for i := 0; i < n; i++ { |
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
prediction = Finetune(*config, predInput, prediction) |
||||
cb(prediction, &result) |
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
} |
||||
return result, err |
||||
} |
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
||||
var mu sync.Mutex = sync.Mutex{} |
||||
|
||||
func Finetune(config Config, input, prediction string) string { |
||||
if config.Echo { |
||||
prediction = input + prediction |
||||
} |
||||
|
||||
for _, c := range config.Cutstrings { |
||||
mu.Lock() |
||||
reg, ok := cutstrings[c] |
||||
if !ok { |
||||
cutstrings[c] = regexp.MustCompile(c) |
||||
reg = cutstrings[c] |
||||
} |
||||
mu.Unlock() |
||||
prediction = reg.ReplaceAllString(prediction, "") |
||||
} |
||||
|
||||
for _, c := range config.TrimSpace { |
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
||||
} |
||||
return prediction |
||||
|
||||
} |
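// Illustrative sketch of how Finetune post-processes a prediction: Cutstrings
// entries are compiled once (cached in the mutex-guarded map above) and stripped
// via regex, then TrimSpace entries are trimmed as prefixes. The package name,
// config literal and expected output are assumptions for the example, not taken
// from a real model run.
package api

import "fmt"

func ExampleFinetune() {
	cfg := Config{
		Echo:       false,                       // when true, the input is prepended to the prediction
		Cutstrings: []string{`<\|endoftext\|>`}, // regexes removed from the prediction
		TrimSpace:  []string{"Answer:"},         // prefixes trimmed together with surrounding spaces
	}

	fmt.Println(Finetune(cfg, "What is 2+2?", "Answer: 4<|endoftext|>"))
	// Output: 4
}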
@ -0,0 +1,22 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
bert "github.com/go-skynet/LocalAI/pkg/grpc/llm/bert" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &bert.Embeddings{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
bloomz "github.com/go-skynet/LocalAI/pkg/grpc/llm/bloomz" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &bloomz.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Dolly{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Falcon{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,25 @@ |
||||
package main |
||||
|
||||
// GRPC Falcon server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
falcon "github.com/go-skynet/LocalAI/pkg/grpc/llm/falcon" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &falcon.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPT2{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
gpt4all "github.com/go-skynet/LocalAI/pkg/grpc/llm/gpt4all" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &gpt4all.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPTJ{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.GPTNeoX{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
langchain "github.com/go-skynet/LocalAI/pkg/grpc/llm/langchain" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &langchain.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,25 @@ |
||||
package main |
||||
|
||||
// GRPC llama server
|
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
llama "github.com/go-skynet/LocalAI/pkg/grpc/llm/llama" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &llama.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.MPT{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
tts "github.com/go-skynet/LocalAI/pkg/grpc/tts" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &tts.Piper{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Replit{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
rwkv "github.com/go-skynet/LocalAI/pkg/grpc/llm/rwkv" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &rwkv.LLM{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
image "github.com/go-skynet/LocalAI/pkg/grpc/image" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &image.StableDiffusion{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transformers "github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transformers.Starcoder{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,23 @@ |
||||
package main |
||||
|
||||
// Note: this is started internally by LocalAI and a server is allocated for each model
|
||||
|
||||
import ( |
||||
"flag" |
||||
|
||||
transcribe "github.com/go-skynet/LocalAI/pkg/grpc/transcribe" |
||||
|
||||
grpc "github.com/go-skynet/LocalAI/pkg/grpc" |
||||
) |
||||
|
||||
var ( |
||||
addr = flag.String("addr", "localhost:50051", "the address to connect to") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
|
||||
if err := grpc.StartServer(*addr, &transcribe.Whisper{}); err != nil { |
||||
panic(err) |
||||
} |
||||
} |
@ -0,0 +1,42 @@ |
||||
package base

// This is a wrapper to satisfy the gRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)

import (
	"fmt"

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
)

type Base struct {
}

func (llm *Base) Load(opts *pb.ModelOptions) error {
	return fmt.Errorf("unimplemented")
}

func (llm *Base) Predict(opts *pb.PredictOptions) (string, error) {
	return "", fmt.Errorf("unimplemented")
}

func (llm *Base) PredictStream(opts *pb.PredictOptions, results chan string) error {
	return fmt.Errorf("unimplemented")
}

func (llm *Base) Embeddings(opts *pb.PredictOptions) ([]float32, error) {
	return []float32{}, fmt.Errorf("unimplemented")
}

func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
	return fmt.Errorf("unimplemented")
}

func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (api.Result, error) {
	return api.Result{}, fmt.Errorf("unimplemented")
}

func (llm *Base) TTS(*pb.TTSRequest) error {
	return fmt.Errorf("unimplemented")
}
@ -0,0 +1,160 @@ |
||||
package grpc |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"io" |
||||
"time" |
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" |
||||
"google.golang.org/grpc" |
||||
"google.golang.org/grpc/credentials/insecure" |
||||
) |
||||
|
||||
type Client struct { |
||||
address string |
||||
} |
||||
|
||||
func NewClient(address string) *Client { |
||||
return &Client{ |
||||
address: address, |
||||
} |
||||
} |
||||
|
||||
func (c *Client) HealthCheck(ctx context.Context) bool { |
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
return false |
||||
} |
||||
defer conn.Close() |
||||
client := pb.NewBackendClient(conn) |
||||
|
||||
	// The healthcheck call shouldn't take a long time
|
||||
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) |
||||
defer cancel() |
||||
|
||||
res, err := client.Health(ctx, &pb.HealthMessage{}) |
||||
if err != nil { |
||||
fmt.Println(err) |
||||
|
||||
return false |
||||
} |
||||
|
||||
if res.Message == "OK" { |
||||
return true |
||||
} |
||||
return false |
||||
} |
||||
|
||||
func (c *Client) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.EmbeddingResult, error) { |
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer conn.Close() |
||||
client := pb.NewBackendClient(conn) |
||||
|
||||
return client.Embedding(ctx, in, opts...) |
||||
} |
||||
|
||||
func (c *Client) Predict(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.Reply, error) { |
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer conn.Close() |
||||
client := pb.NewBackendClient(conn) |
||||
|
||||
return client.Predict(ctx, in, opts...) |
||||
} |
||||
|
||||
func (c *Client) LoadModel(ctx context.Context, in *pb.ModelOptions, opts ...grpc.CallOption) (*pb.Result, error) { |
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer conn.Close() |
||||
client := pb.NewBackendClient(conn) |
||||
return client.LoadModel(ctx, in, opts...) |
||||
} |
||||
|
||||
func (c *Client) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(s string), opts ...grpc.CallOption) error { |
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer conn.Close() |
||||
client := pb.NewBackendClient(conn) |
||||
|
||||
stream, err := client.PredictStream(ctx, in, opts...) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
for { |
||||
feature, err := stream.Recv() |
||||
if err == io.EOF { |
||||
break |
||||
} |
||||
if err != nil { |
||||
fmt.Println("Error", err) |
||||
|
||||
return err |
||||
} |
||||
f(feature.GetMessage()) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (c *Client) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error) { |
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer conn.Close() |
||||
client := pb.NewBackendClient(conn) |
||||
return client.GenerateImage(ctx, in, opts...) |
||||
} |
||||
|
||||
func (c *Client) TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error) { |
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer conn.Close() |
||||
client := pb.NewBackendClient(conn) |
||||
return client.TTS(ctx, in, opts...) |
||||
} |
||||
|
||||
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*api.Result, error) { |
||||
conn, err := grpc.Dial(c.address, grpc.WithTransportCredentials(insecure.NewCredentials())) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
defer conn.Close() |
||||
client := pb.NewBackendClient(conn) |
||||
res, err := client.AudioTranscription(ctx, in, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
tresult := &api.Result{} |
||||
for _, s := range res.Segments { |
||||
tks := []int{} |
||||
for _, t := range s.Tokens { |
||||
tks = append(tks, int(t)) |
||||
} |
||||
tresult.Segments = append(tresult.Segments, |
||||
api.Segment{ |
||||
Text: s.Text, |
||||
Id: int(s.Id), |
||||
Start: time.Duration(s.Start), |
||||
End: time.Duration(s.End), |
||||
Tokens: tks, |
||||
}) |
||||
} |
||||
tresult.Text = res.Text |
||||
return tresult, err |
||||
} |
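// Illustrative sketch of driving the client above against a backend process
// started with grpc.StartServer (see the per-backend main packages below).
// The address, model path and prompt are assumptions for the example; inside
// LocalAI the model loader normally spawns the backend and picks the port.
package main

import (
	"context"
	"fmt"

	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func main() {
	c := grpc.NewClient("localhost:50051")
	ctx := context.Background()

	if !c.HealthCheck(ctx) {
		fmt.Println("backend not ready")
		return
	}

	if _, err := c.LoadModel(ctx, &pb.ModelOptions{Model: "/models/ggml-model.bin"}); err != nil {
		panic(err)
	}

	// Stream the reply token by token; PredictStream invokes f for every message received.
	err := c.PredictStream(ctx, &pb.PredictOptions{Prompt: "Hello", Threads: 4}, func(token string) {
		fmt.Print(token)
	})
	if err != nil {
		panic(err)
	}
}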
@ -0,0 +1,33 @@ |
||||
package image |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
||||
) |
||||
|
||||
type StableDiffusion struct { |
||||
base.Base |
||||
stablediffusion *stablediffusion.StableDiffusion |
||||
} |
||||
|
||||
func (sd *StableDiffusion) Load(opts *pb.ModelOptions) error { |
||||
var err error |
||||
// Note: the Model here is a path to a directory containing the model files
|
||||
sd.stablediffusion, err = stablediffusion.New(opts.Model) |
||||
return err |
||||
} |
||||
|
||||
func (sd *StableDiffusion) GenerateImage(opts *pb.GenerateImageRequest) error { |
||||
return sd.stablediffusion.GenerateImage( |
||||
int(opts.Height), |
||||
int(opts.Width), |
||||
int(opts.Mode), |
||||
int(opts.Step), |
||||
int(opts.Seed), |
||||
opts.PositivePrompt, |
||||
opts.NegativePrompt, |
||||
opts.Dst) |
||||
} |
@ -0,0 +1,16 @@ |
||||
package grpc

import (
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api"
)

type LLM interface {
	Predict(*pb.PredictOptions) (string, error)
	PredictStream(*pb.PredictOptions, chan string) error
	Load(*pb.ModelOptions) error
	Embeddings(*pb.PredictOptions) ([]float32, error)
	GenerateImage(*pb.GenerateImageRequest) error
	AudioTranscription(*pb.TranscriptRequest) (api.Result, error)
	TTS(*pb.TTSRequest) error
}
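// To illustrate the contract above: a hypothetical backend only needs to embed
// base.Base (whose methods all return "unimplemented") and override what it
// actually supports; a tiny main then hands it to grpc.StartServer exactly like
// the per-backend executables above. The names below are illustrative.
package main

import (
	"flag"

	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// Echo satisfies the LLM interface via the embedded base.Base stubs.
type Echo struct {
	base.Base
}

func (e *Echo) Load(opts *pb.ModelOptions) error {
	return nil // nothing to load for an echo backend
}

func (e *Echo) Predict(opts *pb.PredictOptions) (string, error) {
	return opts.Prompt, nil
}

var addr = flag.String("addr", "localhost:50051", "the address to connect to")

func main() {
	flag.Parse()
	if err := grpc.StartServer(*addr, &Echo{}); err != nil {
		panic(err)
	}
}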
@ -0,0 +1,33 @@ |
||||
package bert |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
bert "github.com/go-skynet/go-bert.cpp" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
) |
||||
|
||||
type Embeddings struct { |
||||
base.Base |
||||
bert *bert.Bert |
||||
} |
||||
|
||||
func (llm *Embeddings) Load(opts *pb.ModelOptions) error { |
||||
model, err := bert.New(opts.Model) |
||||
llm.bert = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *Embeddings) Embeddings(opts *pb.PredictOptions) ([]float32, error) { |
||||
if len(opts.EmbeddingTokens) > 0 { |
||||
tokens := []int{} |
||||
for _, t := range opts.EmbeddingTokens { |
||||
tokens = append(tokens, int(t)) |
||||
} |
||||
return llm.bert.TokenEmbeddings(tokens, bert.SetThreads(int(opts.Threads))) |
||||
} |
||||
|
||||
return llm.bert.Embeddings(opts.Embeddings, bert.SetThreads(int(opts.Threads))) |
||||
} |
@ -0,0 +1,59 @@ |
||||
package bloomz |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
"github.com/go-skynet/bloomz.cpp" |
||||
) |
||||
|
||||
type LLM struct { |
||||
base.Base |
||||
|
||||
bloomz *bloomz.Bloomz |
||||
} |
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||
model, err := bloomz.New(opts.Model) |
||||
llm.bloomz = model |
||||
return err |
||||
} |
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []bloomz.PredictOption { |
||||
predictOptions := []bloomz.PredictOption{ |
||||
bloomz.SetTemperature(float64(opts.Temperature)), |
||||
bloomz.SetTopP(float64(opts.TopP)), |
||||
bloomz.SetTopK(int(opts.TopK)), |
||||
bloomz.SetTokens(int(opts.Tokens)), |
||||
bloomz.SetThreads(int(opts.Threads)), |
||||
} |
||||
|
||||
if opts.Seed != 0 { |
||||
predictOptions = append(predictOptions, bloomz.SetSeed(int(opts.Seed))) |
||||
} |
||||
|
||||
return predictOptions |
||||
} |
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.bloomz.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,144 @@ |
||||
package falcon |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
ggllm "github.com/mudler/go-ggllm.cpp" |
||||
) |
||||
|
||||
type LLM struct { |
||||
base.Base |
||||
|
||||
falcon *ggllm.Falcon |
||||
} |
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||
ggllmOpts := []ggllm.ModelOption{} |
||||
if opts.ContextSize != 0 { |
||||
ggllmOpts = append(ggllmOpts, ggllm.SetContext(int(opts.ContextSize))) |
||||
} |
||||
// F16 doesn't seem to produce good output at all!
|
||||
//if c.F16 {
|
||||
// llamaOpts = append(llamaOpts, llama.EnableF16Memory)
|
||||
//}
|
||||
|
||||
if opts.NGPULayers != 0 { |
||||
ggllmOpts = append(ggllmOpts, ggllm.SetGPULayers(int(opts.NGPULayers))) |
||||
} |
||||
|
||||
ggllmOpts = append(ggllmOpts, ggllm.SetMMap(opts.MMap)) |
||||
ggllmOpts = append(ggllmOpts, ggllm.SetMainGPU(opts.MainGPU)) |
||||
ggllmOpts = append(ggllmOpts, ggllm.SetTensorSplit(opts.TensorSplit)) |
||||
if opts.NBatch != 0 { |
||||
ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(int(opts.NBatch))) |
||||
} else { |
||||
ggllmOpts = append(ggllmOpts, ggllm.SetNBatch(512)) |
||||
} |
||||
|
||||
model, err := ggllm.New(opts.Model, ggllmOpts...) |
||||
llm.falcon = model |
||||
return err |
||||
} |
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []ggllm.PredictOption { |
||||
predictOptions := []ggllm.PredictOption{ |
||||
ggllm.SetTemperature(float64(opts.Temperature)), |
||||
ggllm.SetTopP(float64(opts.TopP)), |
||||
ggllm.SetTopK(int(opts.TopK)), |
||||
ggllm.SetTokens(int(opts.Tokens)), |
||||
ggllm.SetThreads(int(opts.Threads)), |
||||
} |
||||
|
||||
if opts.PromptCacheAll { |
||||
predictOptions = append(predictOptions, ggllm.EnablePromptCacheAll) |
||||
} |
||||
|
||||
if opts.PromptCacheRO { |
||||
predictOptions = append(predictOptions, ggllm.EnablePromptCacheRO) |
||||
} |
||||
|
||||
// Expected absolute path
|
||||
if opts.PromptCachePath != "" { |
||||
predictOptions = append(predictOptions, ggllm.SetPathPromptCache(opts.PromptCachePath)) |
||||
} |
||||
|
||||
if opts.Mirostat != 0 { |
||||
predictOptions = append(predictOptions, ggllm.SetMirostat(int(opts.Mirostat))) |
||||
} |
||||
|
||||
if opts.MirostatETA != 0 { |
||||
predictOptions = append(predictOptions, ggllm.SetMirostatETA(float64(opts.MirostatETA))) |
||||
} |
||||
|
||||
if opts.MirostatTAU != 0 { |
||||
predictOptions = append(predictOptions, ggllm.SetMirostatTAU(float64(opts.MirostatTAU))) |
||||
} |
||||
|
||||
if opts.Debug { |
||||
predictOptions = append(predictOptions, ggllm.Debug) |
||||
} |
||||
|
||||
predictOptions = append(predictOptions, ggllm.SetStopWords(opts.StopPrompts...)) |
||||
|
||||
if opts.PresencePenalty != 0 { |
||||
predictOptions = append(predictOptions, ggllm.SetPenalty(float64(opts.PresencePenalty))) |
||||
} |
||||
|
||||
if opts.NKeep != 0 { |
||||
predictOptions = append(predictOptions, ggllm.SetNKeep(int(opts.NKeep))) |
||||
} |
||||
|
||||
if opts.Batch != 0 { |
||||
predictOptions = append(predictOptions, ggllm.SetBatch(int(opts.Batch))) |
||||
} |
||||
|
||||
if opts.IgnoreEOS { |
||||
predictOptions = append(predictOptions, ggllm.IgnoreEOS) |
||||
} |
||||
|
||||
if opts.Seed != 0 { |
||||
predictOptions = append(predictOptions, ggllm.SetSeed(int(opts.Seed))) |
||||
} |
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, ggllm.SetFrequencyPenalty(float64(opts.FrequencyPenalty))) |
||||
predictOptions = append(predictOptions, ggllm.SetMlock(opts.MLock)) |
||||
predictOptions = append(predictOptions, ggllm.SetMemoryMap(opts.MMap)) |
||||
predictOptions = append(predictOptions, ggllm.SetPredictionMainGPU(opts.MainGPU)) |
||||
predictOptions = append(predictOptions, ggllm.SetPredictionTensorSplit(opts.TensorSplit)) |
||||
predictOptions = append(predictOptions, ggllm.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ))) |
||||
predictOptions = append(predictOptions, ggllm.SetTypicalP(float64(opts.TypicalP))) |
||||
return predictOptions |
||||
} |
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
predictOptions := buildPredictOptions(opts) |
||||
|
||||
predictOptions = append(predictOptions, ggllm.SetTokenCallback(func(token string) bool { |
||||
if token == "<|endoftext|>" { |
||||
return true |
||||
} |
||||
results <- token |
||||
return true |
||||
})) |
||||
|
||||
go func() { |
||||
_, err := llm.falcon.Predict(opts.Prompt, predictOptions...) |
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
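// The streaming contract is the same for every backend: PredictStream returns
// immediately and the goroutine feeds tokens into the channel until it closes
// it. A sketch of how a caller might drain it; the package name, buffer size
// and printing are illustrative (the real server forwards tokens over gRPC).
package example

import (
	"fmt"

	grpc "github.com/go-skynet/LocalAI/pkg/grpc"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

func drain(llm grpc.LLM, opts *pb.PredictOptions) (string, error) {
	results := make(chan string, 10)
	if err := llm.PredictStream(opts, results); err != nil {
		return "", err
	}
	out := ""
	for token := range results {
		fmt.Print(token) // tokens arrive until the backend closes the channel
		out += token
	}
	return out, nil
}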
@ -0,0 +1,62 @@ |
||||
package gpt4all |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
gpt4all "github.com/nomic-ai/gpt4all/gpt4all-bindings/golang" |
||||
) |
||||
|
||||
type LLM struct { |
||||
base.Base |
||||
|
||||
gpt4all *gpt4all.Model |
||||
} |
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||
model, err := gpt4all.New(opts.Model, |
||||
gpt4all.SetThreads(int(opts.Threads)), |
||||
gpt4all.SetLibrarySearchPath(opts.LibrarySearchPath)) |
||||
llm.gpt4all = model |
||||
return err |
||||
} |
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []gpt4all.PredictOption { |
||||
predictOptions := []gpt4all.PredictOption{ |
||||
gpt4all.SetTemperature(float64(opts.Temperature)), |
||||
gpt4all.SetTopP(float64(opts.TopP)), |
||||
gpt4all.SetTopK(int(opts.TopK)), |
||||
gpt4all.SetTokens(int(opts.Tokens)), |
||||
} |
||||
|
||||
if opts.Batch != 0 { |
||||
predictOptions = append(predictOptions, gpt4all.SetBatch(int(opts.Batch))) |
||||
} |
||||
return predictOptions |
||||
} |
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.gpt4all.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
predictOptions := buildPredictOptions(opts) |
||||
|
||||
go func() { |
||||
llm.gpt4all.SetTokenCallback(func(token string) bool { |
||||
results <- token |
||||
return true |
||||
}) |
||||
_, err := llm.gpt4all.Predict(opts.Prompt, predictOptions...) |
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
llm.gpt4all.SetTokenCallback(nil) |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,58 @@ |
||||
package langchain |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
"github.com/go-skynet/LocalAI/pkg/langchain" |
||||
) |
||||
|
||||
type LLM struct { |
||||
base.Base |
||||
|
||||
langchain *langchain.HuggingFace |
||||
model string |
||||
} |
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||
llm.langchain, _ = langchain.NewHuggingFace(opts.Model) |
||||
llm.model = opts.Model |
||||
return nil |
||||
} |
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||
o := []langchain.PredictOption{ |
||||
langchain.SetModel(llm.model), |
||||
langchain.SetMaxTokens(int(opts.Tokens)), |
||||
langchain.SetTemperature(float64(opts.Temperature)), |
||||
langchain.SetStopWords(opts.StopPrompts), |
||||
} |
||||
pred, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return pred.Completion, nil |
||||
} |
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
o := []langchain.PredictOption{ |
||||
langchain.SetModel(llm.model), |
||||
langchain.SetMaxTokens(int(opts.Tokens)), |
||||
langchain.SetTemperature(float64(opts.Temperature)), |
||||
langchain.SetStopWords(opts.StopPrompts), |
||||
} |
||||
go func() { |
||||
res, err := llm.langchain.PredictHuggingFace(opts.Prompt, o...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res.Completion |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,170 @@ |
||||
package llama |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
"github.com/go-skynet/go-llama.cpp" |
||||
) |
||||
|
||||
type LLM struct { |
||||
base.Base |
||||
|
||||
llama *llama.LLama |
||||
} |
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||
llamaOpts := []llama.ModelOption{} |
||||
|
||||
if opts.ContextSize != 0 { |
||||
llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize))) |
||||
} |
||||
if opts.F16Memory { |
||||
llamaOpts = append(llamaOpts, llama.EnableF16Memory) |
||||
} |
||||
if opts.Embeddings { |
||||
llamaOpts = append(llamaOpts, llama.EnableEmbeddings) |
||||
} |
||||
if opts.NGPULayers != 0 { |
||||
llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers))) |
||||
} |
||||
|
||||
llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap)) |
||||
llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU)) |
||||
llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit)) |
||||
if opts.NBatch != 0 { |
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch))) |
||||
} else { |
||||
llamaOpts = append(llamaOpts, llama.SetNBatch(512)) |
||||
} |
||||
|
||||
if opts.NUMA { |
||||
llamaOpts = append(llamaOpts, llama.EnableNUMA) |
||||
} |
||||
|
||||
if opts.LowVRAM { |
||||
llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) |
||||
} |
||||
|
||||
model, err := llama.New(opts.Model, llamaOpts...) |
||||
llm.llama = model |
||||
return err |
||||
} |
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption { |
||||
predictOptions := []llama.PredictOption{ |
||||
llama.SetTemperature(float64(opts.Temperature)), |
||||
llama.SetTopP(float64(opts.TopP)), |
||||
llama.SetTopK(int(opts.TopK)), |
||||
llama.SetTokens(int(opts.Tokens)), |
||||
llama.SetThreads(int(opts.Threads)), |
||||
} |
||||
|
||||
if opts.PromptCacheAll { |
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheAll) |
||||
} |
||||
|
||||
if opts.PromptCacheRO { |
||||
predictOptions = append(predictOptions, llama.EnablePromptCacheRO) |
||||
} |
||||
|
||||
predictOptions = append(predictOptions, llama.WithGrammar(opts.Grammar)) |
||||
|
||||
// Expected absolute path
|
||||
if opts.PromptCachePath != "" { |
||||
predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath)) |
||||
} |
||||
|
||||
if opts.Mirostat != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat))) |
||||
} |
||||
|
||||
if opts.MirostatETA != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostatETA(float64(opts.MirostatETA))) |
||||
} |
||||
|
||||
if opts.MirostatTAU != 0 { |
||||
predictOptions = append(predictOptions, llama.SetMirostatTAU(float64(opts.MirostatTAU))) |
||||
} |
||||
|
||||
if opts.Debug { |
||||
predictOptions = append(predictOptions, llama.Debug) |
||||
} |
||||
|
||||
predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...)) |
||||
|
||||
if opts.PresencePenalty != 0 { |
||||
predictOptions = append(predictOptions, llama.SetPenalty(float64(opts.PresencePenalty))) |
||||
} |
||||
|
||||
if opts.NKeep != 0 { |
||||
predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep))) |
||||
} |
||||
|
||||
if opts.Batch != 0 { |
||||
predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch))) |
||||
} |
||||
|
||||
if opts.F16KV { |
||||
predictOptions = append(predictOptions, llama.EnableF16KV) |
||||
} |
||||
|
||||
if opts.IgnoreEOS { |
||||
predictOptions = append(predictOptions, llama.IgnoreEOS) |
||||
} |
||||
|
||||
if opts.Seed != 0 { |
||||
predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed))) |
||||
} |
||||
|
||||
//predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed))
|
||||
|
||||
predictOptions = append(predictOptions, llama.SetFrequencyPenalty(float64(opts.FrequencyPenalty))) |
||||
predictOptions = append(predictOptions, llama.SetMlock(opts.MLock)) |
||||
predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap)) |
||||
predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU)) |
||||
predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit)) |
||||
predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(float64(opts.TailFreeSamplingZ))) |
||||
predictOptions = append(predictOptions, llama.SetTypicalP(float64(opts.TypicalP))) |
||||
return predictOptions |
||||
} |
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
predictOptions := buildPredictOptions(opts) |
||||
|
||||
predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { |
||||
results <- token |
||||
return true |
||||
})) |
||||
|
||||
go func() { |
||||
_, err := llm.llama.Predict(opts.Prompt, predictOptions...) |
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { |
||||
predictOptions := buildPredictOptions(opts) |
||||
|
||||
if len(opts.EmbeddingTokens) > 0 { |
||||
tokens := []int{} |
||||
for _, t := range opts.EmbeddingTokens { |
||||
tokens = append(tokens, int(t)) |
||||
} |
||||
return llm.llama.TokenEmbeddings(tokens, predictOptions...) |
||||
} |
||||
|
||||
return llm.llama.Embeddings(opts.Embeddings, predictOptions...) |
||||
} |
@ -0,0 +1,71 @@ |
||||
package rwkv |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
"path/filepath" |
||||
|
||||
"github.com/donomii/go-rwkv.cpp" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
) |
||||
|
||||
const tokenizerSuffix = ".tokenizer.json" |
||||
|
||||
type LLM struct { |
||||
base.Base |
||||
|
||||
rwkv *rwkv.RwkvState |
||||
} |
||||
|
||||
func (llm *LLM) Load(opts *pb.ModelOptions) error { |
||||
modelPath := filepath.Dir(opts.Model) |
||||
modelFile := filepath.Base(opts.Model) |
||||
model := rwkv.LoadFiles(opts.Model, filepath.Join(modelPath, modelFile+tokenizerSuffix), uint32(opts.GetThreads())) |
||||
|
||||
if model == nil { |
||||
return fmt.Errorf("could not load model") |
||||
} |
||||
llm.rwkv = model |
||||
return nil |
||||
} |
||||
|
||||
func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { |
||||
|
||||
stopWord := "\n" |
||||
if len(opts.StopPrompts) > 0 { |
||||
stopWord = opts.StopPrompts[0] |
||||
} |
||||
|
||||
if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
response := llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), nil) |
||||
|
||||
return response, nil |
||||
} |
||||
|
||||
func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
|
||||
stopWord := "\n" |
||||
if len(opts.StopPrompts) > 0 { |
||||
stopWord = opts.StopPrompts[0] |
||||
} |
||||
|
||||
if err := llm.rwkv.ProcessInput(opts.Prompt); err != nil { |
||||
fmt.Println("Error processing input: ", err) |
||||
return |
||||
} |
||||
|
||||
llm.rwkv.GenerateResponse(int(opts.Tokens), stopWord, float32(opts.Temperature), float32(opts.TopP), func(s string) bool { |
||||
results <- s |
||||
return true |
||||
}) |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,43 @@ |
||||
package transformers |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
type Dolly struct { |
||||
base.Base |
||||
|
||||
dolly *transformers.Dolly |
||||
} |
||||
|
||||
func (llm *Dolly) Load(opts *pb.ModelOptions) error { |
||||
model, err := transformers.NewDolly(opts.Model) |
||||
llm.dolly = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *Dolly) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *Dolly) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.dolly.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,43 @@ |
||||
package transformers |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
type Falcon struct { |
||||
base.Base |
||||
|
||||
falcon *transformers.Falcon |
||||
} |
||||
|
||||
func (llm *Falcon) Load(opts *pb.ModelOptions) error { |
||||
model, err := transformers.NewFalcon(opts.Model) |
||||
llm.falcon = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *Falcon) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *Falcon) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.falcon.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,42 @@ |
||||
package transformers |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
type GPT2 struct { |
||||
base.Base |
||||
|
||||
gpt2 *transformers.GPT2 |
||||
} |
||||
|
||||
func (llm *GPT2) Load(opts *pb.ModelOptions) error { |
||||
model, err := transformers.New(opts.Model) |
||||
llm.gpt2 = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *GPT2) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPT2) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.gpt2.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
return nil |
||||
} |
@ -0,0 +1,42 @@ |
||||
package transformers |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
type GPTJ struct { |
||||
base.Base |
||||
|
||||
gptj *transformers.GPTJ |
||||
} |
||||
|
||||
func (llm *GPTJ) Load(opts *pb.ModelOptions) error { |
||||
model, err := transformers.NewGPTJ(opts.Model) |
||||
llm.gptj = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *GPTJ) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPTJ) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.gptj.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
return nil |
||||
} |
@ -0,0 +1,42 @@ |
||||
package transformers |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
type GPTNeoX struct { |
||||
base.Base |
||||
|
||||
gptneox *transformers.GPTNeoX |
||||
} |
||||
|
||||
func (llm *GPTNeoX) Load(opts *pb.ModelOptions) error { |
||||
model, err := transformers.NewGPTNeoX(opts.Model) |
||||
llm.gptneox = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *GPTNeoX) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *GPTNeoX) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.gptneox.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
return nil |
||||
} |
@ -0,0 +1,42 @@ |
||||
package transformers |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
type MPT struct { |
||||
base.Base |
||||
|
||||
mpt *transformers.MPT |
||||
} |
||||
|
||||
func (llm *MPT) Load(opts *pb.ModelOptions) error { |
||||
model, err := transformers.NewMPT(opts.Model) |
||||
llm.mpt = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *MPT) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *MPT) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.mpt.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
return nil |
||||
} |
@ -0,0 +1,26 @@ |
||||
package transformers |
||||
|
||||
import ( |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
func buildPredictOptions(opts *pb.PredictOptions) []transformers.PredictOption { |
||||
predictOptions := []transformers.PredictOption{ |
||||
transformers.SetTemperature(float64(opts.Temperature)), |
||||
transformers.SetTopP(float64(opts.TopP)), |
||||
transformers.SetTopK(int(opts.TopK)), |
||||
transformers.SetTokens(int(opts.Tokens)), |
||||
transformers.SetThreads(int(opts.Threads)), |
||||
} |
||||
|
||||
if opts.Batch != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetBatch(int(opts.Batch))) |
||||
} |
||||
|
||||
if opts.Seed != 0 { |
||||
predictOptions = append(predictOptions, transformers.SetSeed(int(opts.Seed))) |
||||
} |
||||
|
||||
return predictOptions |
||||
} |
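buildPredictOptions is the single translation point between the wire-level pb.PredictOptions and the go-ggml-transformers predict options, with Batch and Seed only forwarded when non-zero. A small sketch of that behaviour, in the same package, with illustrative values:

package transformers

import (
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// examplePredictOptions is an illustrative sketch: with Batch and Seed left at
// zero, only the five unconditional options are produced.
func examplePredictOptions() int {
	opts := &pb.PredictOptions{
		Temperature: 0.7,
		TopP:        0.9,
		TopK:        40,
		Tokens:      128,
		Threads:     4,
	}
	return len(buildPredictOptions(opts)) // 5
}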
@ -0,0 +1,42 @@ |
||||
package transformers |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
type Replit struct { |
||||
base.Base |
||||
|
||||
replit *transformers.Replit |
||||
} |
||||
|
||||
func (llm *Replit) Load(opts *pb.ModelOptions) error { |
||||
model, err := transformers.NewReplit(opts.Model) |
||||
llm.replit = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *Replit) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *Replit) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.replit.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
return nil |
||||
} |
@ -0,0 +1,43 @@ |
||||
package transformers |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
transformers "github.com/go-skynet/go-ggml-transformers.cpp" |
||||
) |
||||
|
||||
type Starcoder struct { |
||||
base.Base |
||||
|
||||
starcoder *transformers.Starcoder |
||||
} |
||||
|
||||
func (llm *Starcoder) Load(opts *pb.ModelOptions) error { |
||||
model, err := transformers.NewStarcoder(opts.Model) |
||||
llm.starcoder = model |
||||
return err |
||||
} |
||||
|
||||
func (llm *Starcoder) Predict(opts *pb.PredictOptions) (string, error) { |
||||
return llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
} |
||||
|
||||
// fallback to Predict
|
||||
func (llm *Starcoder) PredictStream(opts *pb.PredictOptions, results chan string) error { |
||||
go func() { |
||||
res, err := llm.starcoder.Predict(opts.Prompt, buildPredictOptions(opts)...) |
||||
|
||||
if err != nil { |
||||
fmt.Println("err: ", err) |
||||
} |
||||
results <- res |
||||
close(results) |
||||
}() |
||||
|
||||
return nil |
||||
} |
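All of the wrappers above implement PredictStream with the same contract: the result is written to the results channel and the channel is closed when generation finishes, so a caller can simply range over it. An in-process sketch of that contract; the wrapper import path and the model path are assumptions:

package main

import (
	"fmt"
	"log"

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	// Assumed import path for the wrapper package above.
	"github.com/go-skynet/LocalAI/pkg/grpc/llm/transformers"
)

func main() {
	llm := &transformers.Starcoder{}
	if err := llm.Load(&pb.ModelOptions{Model: "/models/starcoder.bin"}); err != nil { // illustrative path
		log.Fatal(err)
	}

	results := make(chan string)
	if err := llm.PredictStream(&pb.PredictOptions{Prompt: "func main() {", Threads: 4}, results); err != nil {
		log.Fatal(err)
	}
	// The wrapper closes the channel once the prediction is done, ending the range.
	for chunk := range results {
		fmt.Print(chunk)
	}
}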
File diff suppressed because it is too large
@ -0,0 +1,129 @@ |
||||
syntax = "proto3"; |
||||
|
||||
option go_package = "github.com/go-skynet/LocalAI/pkg/grpc/proto"; |
||||
option java_multiple_files = true; |
||||
option java_package = "io.skynet.localai.backend"; |
||||
option java_outer_classname = "LocalAIBackend"; |
||||
|
||||
package backend; |
||||
|
||||
service Backend { |
||||
rpc Health(HealthMessage) returns (Reply) {} |
||||
rpc Predict(PredictOptions) returns (Reply) {} |
||||
rpc LoadModel(ModelOptions) returns (Result) {} |
||||
rpc PredictStream(PredictOptions) returns (stream Reply) {} |
||||
rpc Embedding(PredictOptions) returns (EmbeddingResult) {} |
||||
rpc GenerateImage(GenerateImageRequest) returns (Result) {} |
||||
rpc AudioTranscription(TranscriptRequest) returns (TranscriptResult) {} |
||||
rpc TTS(TTSRequest) returns (Result) {} |
||||
} |
||||
|
||||
message HealthMessage {} |
||||
|
||||
// The request message containing the prediction options. |
||||
message PredictOptions { |
||||
string Prompt = 1; |
||||
int32 Seed = 2; |
||||
int32 Threads = 3; |
||||
int32 Tokens = 4; |
||||
int32 TopK = 5; |
||||
int32 Repeat = 6; |
||||
int32 Batch = 7; |
||||
int32 NKeep = 8; |
||||
float Temperature = 9; |
||||
float Penalty = 10; |
||||
bool F16KV = 11; |
||||
bool DebugMode = 12; |
||||
repeated string StopPrompts = 13; |
||||
bool IgnoreEOS = 14; |
||||
float TailFreeSamplingZ = 15; |
||||
float TypicalP = 16; |
||||
float FrequencyPenalty = 17; |
||||
float PresencePenalty = 18; |
||||
int32 Mirostat = 19; |
||||
float MirostatETA = 20; |
||||
float MirostatTAU = 21; |
||||
bool PenalizeNL = 22; |
||||
string LogitBias = 23; |
||||
bool MLock = 25; |
||||
bool MMap = 26; |
||||
bool PromptCacheAll = 27; |
||||
bool PromptCacheRO = 28; |
||||
string Grammar = 29; |
||||
string MainGPU = 30; |
||||
string TensorSplit = 31; |
||||
float TopP = 32; |
||||
string PromptCachePath = 33; |
||||
bool Debug = 34; |
||||
repeated int32 EmbeddingTokens = 35; |
||||
string Embeddings = 36; |
||||
} |
||||
|
||||
// The response message containing the result |
||||
message Reply { |
||||
string message = 1; |
||||
} |
||||
|
||||
message ModelOptions { |
||||
string Model = 1; |
||||
int32 ContextSize = 2; |
||||
int32 Seed = 3; |
||||
int32 NBatch = 4; |
||||
bool F16Memory = 5; |
||||
bool MLock = 6; |
||||
bool MMap = 7; |
||||
bool VocabOnly = 8; |
||||
bool LowVRAM = 9; |
||||
bool Embeddings = 10; |
||||
bool NUMA = 11; |
||||
int32 NGPULayers = 12; |
||||
string MainGPU = 13; |
||||
string TensorSplit = 14; |
||||
int32 Threads = 15; |
||||
string LibrarySearchPath = 16; |
||||
} |
||||
|
||||
message Result { |
||||
string message = 1; |
||||
bool success = 2; |
||||
} |
||||
|
||||
message EmbeddingResult { |
||||
repeated float embeddings = 1; |
||||
} |
||||
|
||||
message TranscriptRequest { |
||||
string dst = 2; |
||||
string language = 3; |
||||
uint32 threads = 4; |
||||
} |
||||
|
||||
message TranscriptResult { |
||||
repeated TranscriptSegment segments = 1; |
||||
string text = 2; |
||||
} |
||||
|
||||
message TranscriptSegment { |
||||
int32 id = 1; |
||||
int64 start = 2; |
||||
int64 end = 3; |
||||
string text = 4; |
||||
repeated int32 tokens = 5; |
||||
} |
||||
|
||||
message GenerateImageRequest { |
||||
int32 height = 1; |
||||
int32 width = 2; |
||||
int32 mode = 3; |
||||
int32 step = 4; |
||||
int32 seed = 5; |
||||
string positive_prompt = 6; |
||||
string negative_prompt = 7; |
||||
string dst = 8; |
||||
} |
||||
|
||||
message TTSRequest { |
||||
string text = 1; |
||||
string model = 2; |
||||
string dst = 3; |
||||
} |
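A minimal Go client sketch against this service, using the generated bindings that follow; the address, model path, and prompt are illustrative, and the connection is left insecure for brevity:

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	conn, err := grpc.Dial("127.0.0.1:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	c := pb.NewBackendClient(conn)
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	// Load the model, then run a single prediction.
	if _, err := c.LoadModel(ctx, &pb.ModelOptions{Model: "/models/model.bin", Threads: 4}); err != nil { // illustrative path
		log.Fatal(err)
	}
	reply, err := c.Predict(ctx, &pb.PredictOptions{Prompt: "Hello", Tokens: 64, Threads: 4})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(reply.Message)
}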
@ -0,0 +1,385 @@ |
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc v3.15.8
|
||||
// source: pkg/grpc/proto/backend.proto
|
||||
|
||||
package proto |
||||
|
||||
import ( |
||||
context "context" |
||||
grpc "google.golang.org/grpc" |
||||
codes "google.golang.org/grpc/codes" |
||||
status "google.golang.org/grpc/status" |
||||
) |
||||
|
||||
// This is a compile-time assertion to ensure that this generated file
|
||||
// is compatible with the grpc package it is being compiled against.
|
||||
// Requires gRPC-Go v1.32.0 or later.
|
||||
const _ = grpc.SupportPackageIsVersion7 |
||||
|
||||
// BackendClient is the client API for Backend service.
|
||||
//
|
||||
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
|
||||
type BackendClient interface { |
||||
Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) |
||||
Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) |
||||
LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) |
||||
PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) |
||||
Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) |
||||
GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) |
||||
AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) |
||||
TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) |
||||
} |
||||
|
||||
type backendClient struct { |
||||
cc grpc.ClientConnInterface |
||||
} |
||||
|
||||
func NewBackendClient(cc grpc.ClientConnInterface) BackendClient { |
||||
return &backendClient{cc} |
||||
} |
||||
|
||||
func (c *backendClient) Health(ctx context.Context, in *HealthMessage, opts ...grpc.CallOption) (*Reply, error) { |
||||
out := new(Reply) |
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Health", in, out, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
func (c *backendClient) Predict(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*Reply, error) { |
||||
out := new(Reply) |
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Predict", in, out, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
func (c *backendClient) LoadModel(ctx context.Context, in *ModelOptions, opts ...grpc.CallOption) (*Result, error) { |
||||
out := new(Result) |
||||
err := c.cc.Invoke(ctx, "/backend.Backend/LoadModel", in, out, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
func (c *backendClient) PredictStream(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (Backend_PredictStreamClient, error) { |
||||
stream, err := c.cc.NewStream(ctx, &Backend_ServiceDesc.Streams[0], "/backend.Backend/PredictStream", opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
x := &backendPredictStreamClient{stream} |
||||
if err := x.ClientStream.SendMsg(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if err := x.ClientStream.CloseSend(); err != nil { |
||||
return nil, err |
||||
} |
||||
return x, nil |
||||
} |
||||
|
||||
type Backend_PredictStreamClient interface { |
||||
Recv() (*Reply, error) |
||||
grpc.ClientStream |
||||
} |
||||
|
||||
type backendPredictStreamClient struct { |
||||
grpc.ClientStream |
||||
} |
||||
|
||||
func (x *backendPredictStreamClient) Recv() (*Reply, error) { |
||||
m := new(Reply) |
||||
if err := x.ClientStream.RecvMsg(m); err != nil { |
||||
return nil, err |
||||
} |
||||
return m, nil |
||||
} |
||||
|
||||
func (c *backendClient) Embedding(ctx context.Context, in *PredictOptions, opts ...grpc.CallOption) (*EmbeddingResult, error) { |
||||
out := new(EmbeddingResult) |
||||
err := c.cc.Invoke(ctx, "/backend.Backend/Embedding", in, out, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
func (c *backendClient) GenerateImage(ctx context.Context, in *GenerateImageRequest, opts ...grpc.CallOption) (*Result, error) { |
||||
out := new(Result) |
||||
err := c.cc.Invoke(ctx, "/backend.Backend/GenerateImage", in, out, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
func (c *backendClient) AudioTranscription(ctx context.Context, in *TranscriptRequest, opts ...grpc.CallOption) (*TranscriptResult, error) { |
||||
out := new(TranscriptResult) |
||||
err := c.cc.Invoke(ctx, "/backend.Backend/AudioTranscription", in, out, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
func (c *backendClient) TTS(ctx context.Context, in *TTSRequest, opts ...grpc.CallOption) (*Result, error) { |
||||
out := new(Result) |
||||
err := c.cc.Invoke(ctx, "/backend.Backend/TTS", in, out, opts...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
// BackendServer is the server API for Backend service.
|
||||
// All implementations must embed UnimplementedBackendServer
|
||||
// for forward compatibility
|
||||
type BackendServer interface { |
||||
Health(context.Context, *HealthMessage) (*Reply, error) |
||||
Predict(context.Context, *PredictOptions) (*Reply, error) |
||||
LoadModel(context.Context, *ModelOptions) (*Result, error) |
||||
PredictStream(*PredictOptions, Backend_PredictStreamServer) error |
||||
Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) |
||||
GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) |
||||
AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) |
||||
TTS(context.Context, *TTSRequest) (*Result, error) |
||||
mustEmbedUnimplementedBackendServer() |
||||
} |
||||
|
||||
// UnimplementedBackendServer must be embedded to have forward compatible implementations.
|
||||
type UnimplementedBackendServer struct { |
||||
} |
||||
|
||||
func (UnimplementedBackendServer) Health(context.Context, *HealthMessage) (*Reply, error) { |
||||
return nil, status.Errorf(codes.Unimplemented, "method Health not implemented") |
||||
} |
||||
func (UnimplementedBackendServer) Predict(context.Context, *PredictOptions) (*Reply, error) { |
||||
return nil, status.Errorf(codes.Unimplemented, "method Predict not implemented") |
||||
} |
||||
func (UnimplementedBackendServer) LoadModel(context.Context, *ModelOptions) (*Result, error) { |
||||
return nil, status.Errorf(codes.Unimplemented, "method LoadModel not implemented") |
||||
} |
||||
func (UnimplementedBackendServer) PredictStream(*PredictOptions, Backend_PredictStreamServer) error { |
||||
return status.Errorf(codes.Unimplemented, "method PredictStream not implemented") |
||||
} |
||||
func (UnimplementedBackendServer) Embedding(context.Context, *PredictOptions) (*EmbeddingResult, error) { |
||||
return nil, status.Errorf(codes.Unimplemented, "method Embedding not implemented") |
||||
} |
||||
func (UnimplementedBackendServer) GenerateImage(context.Context, *GenerateImageRequest) (*Result, error) { |
||||
return nil, status.Errorf(codes.Unimplemented, "method GenerateImage not implemented") |
||||
} |
||||
func (UnimplementedBackendServer) AudioTranscription(context.Context, *TranscriptRequest) (*TranscriptResult, error) { |
||||
return nil, status.Errorf(codes.Unimplemented, "method AudioTranscription not implemented") |
||||
} |
||||
func (UnimplementedBackendServer) TTS(context.Context, *TTSRequest) (*Result, error) { |
||||
return nil, status.Errorf(codes.Unimplemented, "method TTS not implemented") |
||||
} |
||||
func (UnimplementedBackendServer) mustEmbedUnimplementedBackendServer() {} |
||||
|
||||
// UnsafeBackendServer may be embedded to opt out of forward compatibility for this service.
|
||||
// Use of this interface is not recommended, as added methods to BackendServer will
|
||||
// result in compilation errors.
|
||||
type UnsafeBackendServer interface { |
||||
mustEmbedUnimplementedBackendServer() |
||||
} |
||||
|
||||
func RegisterBackendServer(s grpc.ServiceRegistrar, srv BackendServer) { |
||||
s.RegisterService(&Backend_ServiceDesc, srv) |
||||
} |
||||
|
||||
func _Backend_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||
in := new(HealthMessage) |
||||
if err := dec(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if interceptor == nil { |
||||
return srv.(BackendServer).Health(ctx, in) |
||||
} |
||||
info := &grpc.UnaryServerInfo{ |
||||
Server: srv, |
||||
FullMethod: "/backend.Backend/Health", |
||||
} |
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return srv.(BackendServer).Health(ctx, req.(*HealthMessage)) |
||||
} |
||||
return interceptor(ctx, in, info, handler) |
||||
} |
||||
|
||||
func _Backend_Predict_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||
in := new(PredictOptions) |
||||
if err := dec(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if interceptor == nil { |
||||
return srv.(BackendServer).Predict(ctx, in) |
||||
} |
||||
info := &grpc.UnaryServerInfo{ |
||||
Server: srv, |
||||
FullMethod: "/backend.Backend/Predict", |
||||
} |
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return srv.(BackendServer).Predict(ctx, req.(*PredictOptions)) |
||||
} |
||||
return interceptor(ctx, in, info, handler) |
||||
} |
||||
|
||||
func _Backend_LoadModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||
in := new(ModelOptions) |
||||
if err := dec(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if interceptor == nil { |
||||
return srv.(BackendServer).LoadModel(ctx, in) |
||||
} |
||||
info := &grpc.UnaryServerInfo{ |
||||
Server: srv, |
||||
FullMethod: "/backend.Backend/LoadModel", |
||||
} |
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return srv.(BackendServer).LoadModel(ctx, req.(*ModelOptions)) |
||||
} |
||||
return interceptor(ctx, in, info, handler) |
||||
} |
||||
|
||||
func _Backend_PredictStream_Handler(srv interface{}, stream grpc.ServerStream) error { |
||||
m := new(PredictOptions) |
||||
if err := stream.RecvMsg(m); err != nil { |
||||
return err |
||||
} |
||||
return srv.(BackendServer).PredictStream(m, &backendPredictStreamServer{stream}) |
||||
} |
||||
|
||||
type Backend_PredictStreamServer interface { |
||||
Send(*Reply) error |
||||
grpc.ServerStream |
||||
} |
||||
|
||||
type backendPredictStreamServer struct { |
||||
grpc.ServerStream |
||||
} |
||||
|
||||
func (x *backendPredictStreamServer) Send(m *Reply) error { |
||||
return x.ServerStream.SendMsg(m) |
||||
} |
||||
|
||||
func _Backend_Embedding_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||
in := new(PredictOptions) |
||||
if err := dec(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if interceptor == nil { |
||||
return srv.(BackendServer).Embedding(ctx, in) |
||||
} |
||||
info := &grpc.UnaryServerInfo{ |
||||
Server: srv, |
||||
FullMethod: "/backend.Backend/Embedding", |
||||
} |
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return srv.(BackendServer).Embedding(ctx, req.(*PredictOptions)) |
||||
} |
||||
return interceptor(ctx, in, info, handler) |
||||
} |
||||
|
||||
func _Backend_GenerateImage_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||
in := new(GenerateImageRequest) |
||||
if err := dec(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if interceptor == nil { |
||||
return srv.(BackendServer).GenerateImage(ctx, in) |
||||
} |
||||
info := &grpc.UnaryServerInfo{ |
||||
Server: srv, |
||||
FullMethod: "/backend.Backend/GenerateImage", |
||||
} |
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return srv.(BackendServer).GenerateImage(ctx, req.(*GenerateImageRequest)) |
||||
} |
||||
return interceptor(ctx, in, info, handler) |
||||
} |
||||
|
||||
func _Backend_AudioTranscription_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||
in := new(TranscriptRequest) |
||||
if err := dec(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if interceptor == nil { |
||||
return srv.(BackendServer).AudioTranscription(ctx, in) |
||||
} |
||||
info := &grpc.UnaryServerInfo{ |
||||
Server: srv, |
||||
FullMethod: "/backend.Backend/AudioTranscription", |
||||
} |
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return srv.(BackendServer).AudioTranscription(ctx, req.(*TranscriptRequest)) |
||||
} |
||||
return interceptor(ctx, in, info, handler) |
||||
} |
||||
|
||||
func _Backend_TTS_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { |
||||
in := new(TTSRequest) |
||||
if err := dec(in); err != nil { |
||||
return nil, err |
||||
} |
||||
if interceptor == nil { |
||||
return srv.(BackendServer).TTS(ctx, in) |
||||
} |
||||
info := &grpc.UnaryServerInfo{ |
||||
Server: srv, |
||||
FullMethod: "/backend.Backend/TTS", |
||||
} |
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) { |
||||
return srv.(BackendServer).TTS(ctx, req.(*TTSRequest)) |
||||
} |
||||
return interceptor(ctx, in, info, handler) |
||||
} |
||||
|
||||
// Backend_ServiceDesc is the grpc.ServiceDesc for Backend service.
|
||||
// It's only intended for direct use with grpc.RegisterService,
|
||||
// and not to be introspected or modified (even as a copy)
|
||||
var Backend_ServiceDesc = grpc.ServiceDesc{ |
||||
ServiceName: "backend.Backend", |
||||
HandlerType: (*BackendServer)(nil), |
||||
Methods: []grpc.MethodDesc{ |
||||
{ |
||||
MethodName: "Health", |
||||
Handler: _Backend_Health_Handler, |
||||
}, |
||||
{ |
||||
MethodName: "Predict", |
||||
Handler: _Backend_Predict_Handler, |
||||
}, |
||||
{ |
||||
MethodName: "LoadModel", |
||||
Handler: _Backend_LoadModel_Handler, |
||||
}, |
||||
{ |
||||
MethodName: "Embedding", |
||||
Handler: _Backend_Embedding_Handler, |
||||
}, |
||||
{ |
||||
MethodName: "GenerateImage", |
||||
Handler: _Backend_GenerateImage_Handler, |
||||
}, |
||||
{ |
||||
MethodName: "AudioTranscription", |
||||
Handler: _Backend_AudioTranscription_Handler, |
||||
}, |
||||
{ |
||||
MethodName: "TTS", |
||||
Handler: _Backend_TTS_Handler, |
||||
}, |
||||
}, |
||||
Streams: []grpc.StreamDesc{ |
||||
{ |
||||
StreamName: "PredictStream", |
||||
Handler: _Backend_PredictStream_Handler, |
||||
ServerStreams: true, |
||||
}, |
||||
}, |
||||
Metadata: "pkg/grpc/proto/backend.proto", |
||||
} |
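The PredictStream RPC is consumed with the usual client-side Recv loop, reading until the server closes the stream. A sketch, assuming an already-dialed BackendClient such as the one constructed in the previous example; the package name and token values are illustrative:

package client // illustrative package name

import (
	"context"
	"fmt"
	"io"

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// StreamPredict prints streamed replies until the server closes the stream.
func StreamPredict(ctx context.Context, c pb.BackendClient, prompt string) error {
	stream, err := c.PredictStream(ctx, &pb.PredictOptions{Prompt: prompt, Tokens: 64, Threads: 4})
	if err != nil {
		return err
	}
	for {
		reply, err := stream.Recv()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}
		fmt.Print(reply.Message)
	}
}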
@ -0,0 +1,126 @@ |
||||
package grpc |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"log" |
||||
"net" |
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
"google.golang.org/grpc" |
||||
) |
||||
|
||||
// A gRPC server that runs LLM inference.
|
||||
// It is used by the LLM services to expose the functionality that clients call.
|
||||
// The gRPC service is deliberately general, trying to encompass all the possible model options.
|
||||
// What is actually supported depends on the concrete backend implementation.
|
||||
//
|
||||
// The server is implemented as a gRPC service, with the following methods:
|
||||
// - Predict: to run the inference with options
|
||||
// - PredictStream: to run the inference with options and stream the results
|
||||
|
||||
// server implements the Backend gRPC service defined in backend.proto.
|
||||
type server struct { |
||||
pb.UnimplementedBackendServer |
||||
llm LLM |
||||
} |
||||
|
||||
func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) { |
||||
return &pb.Reply{Message: "OK"}, nil |
||||
} |
||||
|
||||
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) { |
||||
embeds, err := s.llm.Embeddings(in) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &pb.EmbeddingResult{Embeddings: embeds}, nil |
||||
} |
||||
|
||||
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) { |
||||
err := s.llm.Load(in) |
||||
if err != nil { |
||||
return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err |
||||
} |
||||
return &pb.Result{Message: "Loading succeeded", Success: true}, nil |
||||
} |
||||
|
||||
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) { |
||||
result, err := s.llm.Predict(in) |
||||
return &pb.Reply{Message: result}, err |
||||
} |
||||
|
||||
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) { |
||||
err := s.llm.GenerateImage(in) |
||||
if err != nil { |
||||
return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err |
||||
} |
||||
return &pb.Result{Message: "Image generated", Success: true}, nil |
||||
} |
||||
|
||||
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) { |
||||
err := s.llm.TTS(in) |
||||
if err != nil { |
||||
return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err |
||||
} |
||||
return &pb.Result{Message: "Audio generated", Success: true}, nil |
||||
} |
||||
|
||||
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) { |
||||
result, err := s.llm.AudioTranscription(in) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
tresult := &pb.TranscriptResult{} |
||||
for _, s := range result.Segments { |
||||
tks := []int32{} |
||||
for _, t := range s.Tokens { |
||||
tks = append(tks, int32(t)) |
||||
} |
||||
tresult.Segments = append(tresult.Segments, |
||||
&pb.TranscriptSegment{ |
||||
Text: s.Text, |
||||
Id: int32(s.Id), |
||||
Start: int64(s.Start), |
||||
End: int64(s.End), |
||||
Tokens: tks, |
||||
}) |
||||
} |
||||
|
||||
tresult.Text = result.Text |
||||
return tresult, nil |
||||
} |
||||
|
||||
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error { |
||||
|
||||
resultChan := make(chan string) |
||||
|
||||
done := make(chan bool) |
||||
go func() { |
||||
for result := range resultChan { |
||||
stream.Send(&pb.Reply{Message: result}) |
||||
} |
||||
done <- true |
||||
}() |
||||
|
||||
s.llm.PredictStream(in, resultChan) |
||||
<-done |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func StartServer(address string, model LLM) error { |
||||
lis, err := net.Listen("tcp", address) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
s := grpc.NewServer() |
||||
pb.RegisterBackendServer(s, &server{llm: model}) |
||||
log.Printf("gRPC Server listening at %v", lis.Addr()) |
||||
if err := s.Serve(lis); err != nil { |
||||
return err |
||||
} |
||||
|
||||
return nil |
||||
} |
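StartServer ties everything together: any value satisfying the LLM interface can be served as a backend. A hedged sketch of a trivial custom backend follows; it assumes base.Base supplies default implementations for the LLM methods that are not overridden here, which is how the wrappers earlier in this diff appear to use it, and the address is illustrative.

package main

import (
	"log"

	"github.com/go-skynet/LocalAI/pkg/grpc"
	"github.com/go-skynet/LocalAI/pkg/grpc/base"
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// EchoLLM is an illustrative backend that just echoes the prompt back.
type EchoLLM struct {
	base.Base // assumed to stub the remaining LLM interface methods
}

func (e *EchoLLM) Load(opts *pb.ModelOptions) error { return nil }

func (e *EchoLLM) Predict(opts *pb.PredictOptions) (string, error) {
	return opts.Prompt, nil
}

func (e *EchoLLM) PredictStream(opts *pb.PredictOptions, results chan string) error {
	go func() {
		results <- opts.Prompt
		close(results)
	}()
	return nil
}

func main() {
	if err := grpc.StartServer("localhost:50051", &EchoLLM{}); err != nil {
		log.Fatal(err)
	}
}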
@ -0,0 +1,27 @@ |
||||
package transcribe |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
whisperutil "github.com/go-skynet/LocalAI/pkg/grpc/whisper" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc/whisper/api" |
||||
) |
||||
|
||||
type Whisper struct { |
||||
base.Base |
||||
whisper whisper.Model |
||||
} |
||||
|
||||
func (sd *Whisper) Load(opts *pb.ModelOptions) error { |
||||
// Note: the Model here is the path to the whisper model file
|
||||
w, err := whisper.New(opts.Model) |
||||
sd.whisper = w |
||||
return err |
||||
} |
||||
|
||||
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (api.Result, error) { |
||||
return whisperutil.Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) |
||||
} |
@ -0,0 +1,44 @@ |
||||
package tts |
||||
|
||||
// This is a wrapper to satisfy the gRPC service interface
|
||||
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
|
||||
import ( |
||||
"os" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grpc/base" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
piper "github.com/mudler/go-piper" |
||||
) |
||||
|
||||
type Piper struct { |
||||
base.Base |
||||
piper *PiperB |
||||
} |
||||
|
||||
func (sd *Piper) Load(opts *pb.ModelOptions) error { |
||||
var err error |
||||
// Note: LibrarySearchPath here is a path to the directory containing the piper assets
|
||||
sd.piper, err = New(opts.LibrarySearchPath) |
||||
return err |
||||
} |
||||
|
||||
func (sd *Piper) TTS(opts *pb.TTSRequest) error { |
||||
return sd.piper.TTS(opts.Text, opts.Model, opts.Dst) |
||||
} |
||||
|
||||
type PiperB struct { |
||||
assetDir string |
||||
} |
||||
|
||||
func New(assetDir string) (*PiperB, error) { |
||||
if _, err := os.Stat(assetDir); err != nil { |
||||
return nil, err |
||||
} |
||||
return &PiperB{ |
||||
assetDir: assetDir, |
||||
}, nil |
||||
} |
||||
|
||||
func (s *PiperB) TTS(text, model, dst string) error { |
||||
return piper.TextToWav(text, model, s.assetDir, "", dst) |
||||
} |
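From the client side, both audio endpoints map directly onto the proto messages defined earlier: TTS writes a wav file to dst, and AudioTranscription returns the segments plus the full text. A sketch, assuming a dialed pb.BackendClient; the package name, model name, and file paths are illustrative:

package client // illustrative package name

import (
	"context"
	"fmt"

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// SpeakAndTranscribe synthesizes text to a wav file, then transcribes it.
func SpeakAndTranscribe(ctx context.Context, c pb.BackendClient) error {
	// Text-to-speech: Model names the piper voice model (illustrative).
	if _, err := c.TTS(ctx, &pb.TTSRequest{
		Text:  "Hello from LocalAI",
		Model: "en-us-model.onnx",
		Dst:   "/tmp/hello.wav",
	}); err != nil {
		return err
	}

	// Transcription: the whisper backend reads the audio at Dst.
	res, err := c.AudioTranscription(ctx, &pb.TranscriptRequest{
		Dst:      "/tmp/hello.wav",
		Language: "en",
		Threads:  4,
	})
	if err != nil {
		return err
	}
	for _, s := range res.Segments {
		fmt.Println(s.Id, s.Text)
	}
	fmt.Println(res.Text)
	return nil
}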
@ -0,0 +1,16 @@ |
||||
package api |
||||
|
||||
import "time" |
||||
|
||||
type Segment struct { |
||||
Id int `json:"id"` |
||||
Start time.Duration `json:"start"` |
||||
End time.Duration `json:"end"` |
||||
Text string `json:"text"` |
||||
Tokens []int `json:"tokens"` |
||||
} |
||||
|
||||
type Result struct { |
||||
Segments []Segment `json:"segments"` |
||||
Text string `json:"text"` |
||||
} |
@ -0,0 +1,66 @@ |
||||
package model |
||||
|
||||
import ( |
||||
"context" |
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
) |
||||
|
||||
type Options struct { |
||||
backendString string |
||||
modelFile string |
||||
threads uint32 |
||||
assetDir string |
||||
context context.Context |
||||
|
||||
gRPCOptions *pb.ModelOptions |
||||
} |
||||
|
||||
type Option func(*Options) |
||||
|
||||
func WithBackendString(backend string) Option { |
||||
return func(o *Options) { |
||||
o.backendString = backend |
||||
} |
||||
} |
||||
|
||||
func WithModelFile(modelFile string) Option { |
||||
return func(o *Options) { |
||||
o.modelFile = modelFile |
||||
} |
||||
} |
||||
|
||||
func WithLoadGRPCLLMModelOpts(opts *pb.ModelOptions) Option { |
||||
return func(o *Options) { |
||||
o.gRPCOptions = opts |
||||
} |
||||
} |
||||
|
||||
func WithThreads(threads uint32) Option { |
||||
return func(o *Options) { |
||||
o.threads = threads |
||||
} |
||||
} |
||||
|
||||
func WithAssetDir(assetDir string) Option { |
||||
return func(o *Options) { |
||||
o.assetDir = assetDir |
||||
} |
||||
} |
||||
|
||||
func WithContext(ctx context.Context) Option { |
||||
return func(o *Options) { |
||||
o.context = ctx |
||||
} |
||||
} |
||||
|
||||
func NewOptions(opts ...Option) *Options { |
||||
o := &Options{ |
||||
gRPCOptions: &pb.ModelOptions{}, |
||||
context: context.Background(), |
||||
} |
||||
for _, opt := range opts { |
||||
opt(o) |
||||
} |
||||
return o |
||||
} |
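These functional options are what the model loader consumes when spawning and configuring a backend. A sketch of composing an Options value, in the same package, with illustrative values:

package model

import (
	"context"

	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// exampleOptions is an illustrative sketch of composing loader options.
func exampleOptions() *Options {
	return NewOptions(
		WithBackendString("llama"),      // illustrative backend name
		WithModelFile("ggml-model.bin"), // illustrative model file
		WithThreads(4),
		WithAssetDir("/tmp/localai/backend_data"), // illustrative asset dir
		WithContext(context.Background()),
		WithLoadGRPCLLMModelOpts(&pb.ModelOptions{ContextSize: 512}),
	)
}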
@ -1,12 +0,0 @@ |
||||
//go:build tts
|
||||
// +build tts
|
||||
|
||||
package tts |
||||
|
||||
import ( |
||||
piper "github.com/mudler/go-piper" |
||||
) |
||||
|
||||
func tts(text, model, assetDir, arLib, dst string) error { |
||||
return piper.TextToWav(text, model, assetDir, arLib, dst) |
||||
} |
@ -1,10 +0,0 @@ |
||||
//go:build !tts
|
||||
// +build !tts
|
||||
|
||||
package tts |
||||
|
||||
import "fmt" |
||||
|
||||
func tts(text, model, assetDir, arLib, dst string) error { |
||||
return fmt.Errorf("this version of LocalAI was built without the tts tag") |
||||
} |
@ -1,20 +0,0 @@ |
||||
package tts |
||||
|
||||
import "os" |
||||
|
||||
type Piper struct { |
||||
assetDir string |
||||
} |
||||
|
||||
func New(assetDir string) (*Piper, error) { |
||||
if _, err := os.Stat(assetDir); err != nil { |
||||
return nil, err |
||||
} |
||||
return &Piper{ |
||||
assetDir: assetDir, |
||||
}, nil |
||||
} |
||||
|
||||
func (s *Piper) TTS(text, model, dst string) error { |
||||
return tts(text, model, s.assetDir, "", dst) |
||||
} |