Signed-off-by: Ettore Di Giacinto <mudler@localai.io>renovate/github.com-imdario-mergo-1.x
parent
f2f1d7fe72
commit
5dcfdbe51d
@ -0,0 +1,107 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"sync" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
bert "github.com/go-skynet/go-bert.cpp" |
||||
) |
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { |
||||
if !c.Embeddings { |
||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
||||
} |
||||
|
||||
modelFile := c.Model |
||||
|
||||
grpcOpts := gRPCModelOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
|
||||
opts := []model.Option{ |
||||
model.WithLoadGRPCOpts(grpcOpts), |
||||
model.WithThreads(uint32(c.Threads)), |
||||
model.WithAssetDir(o.AssetsDestination), |
||||
model.WithModelFile(modelFile), |
||||
} |
||||
|
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(opts...) |
||||
} else { |
||||
opts = append(opts, model.WithBackendString(c.Backend)) |
||||
inferenceModel, err = loader.BackendLoader(opts...) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() ([]float32, error) |
||||
switch model := inferenceModel.(type) { |
||||
case *grpc.Client: |
||||
fn = func() ([]float32, error) { |
||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath) |
||||
if len(tokens) > 0 { |
||||
embeds := []int32{} |
||||
|
||||
for _, t := range tokens { |
||||
embeds = append(embeds, int32(t)) |
||||
} |
||||
predictOptions.EmbeddingTokens = embeds |
||||
|
||||
res, err := model.Embeddings(context.TODO(), predictOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return res.Embeddings, nil |
||||
} |
||||
predictOptions.Embeddings = s |
||||
|
||||
res, err := model.Embeddings(context.TODO(), predictOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return res.Embeddings, nil |
||||
} |
||||
|
||||
// bert embeddings
|
||||
case *bert.Bert: |
||||
fn = func() ([]float32, error) { |
||||
if len(tokens) > 0 { |
||||
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) |
||||
} |
||||
return model.Embeddings(s, bert.SetThreads(c.Threads)) |
||||
} |
||||
default: |
||||
fn = func() ([]float32, error) { |
||||
return nil, fmt.Errorf("embeddings not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() ([]float32, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
l := Lock(modelFile) |
||||
defer l.Unlock() |
||||
|
||||
embeds, err := fn() |
||||
if err != nil { |
||||
return embeds, err |
||||
} |
||||
// Remove trailing 0s
|
||||
for i := len(embeds) - 1; i >= 0; i-- { |
||||
if embeds[i] == 0.0 { |
||||
embeds = embeds[:i] |
||||
} else { |
||||
break |
||||
} |
||||
} |
||||
return embeds, nil |
||||
}, nil |
||||
} |
@ -0,0 +1,56 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"fmt" |
||||
"sync" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
||||
) |
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { |
||||
if c.Backend != model.StableDiffusionBackend { |
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
||||
} |
||||
|
||||
inferenceModel, err := loader.BackendLoader( |
||||
model.WithBackendString(c.Backend), |
||||
model.WithAssetDir(o.AssetsDestination), |
||||
model.WithThreads(uint32(c.Threads)), |
||||
model.WithModelFile(c.ImageGenerationAssets), |
||||
) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() error |
||||
switch model := inferenceModel.(type) { |
||||
case *stablediffusion.StableDiffusion: |
||||
fn = func() error { |
||||
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) |
||||
} |
||||
|
||||
default: |
||||
fn = func() error { |
||||
return fmt.Errorf("creation of images not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() error { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[c.Backend] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[c.Backend] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
return fn() |
||||
}, nil |
||||
} |
@ -0,0 +1,160 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"context" |
||||
"regexp" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"github.com/donomii/go-rwkv.cpp" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||
"github.com/go-skynet/LocalAI/pkg/langchain" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/go-skynet/bloomz.cpp" |
||||
) |
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { |
||||
supportStreams := false |
||||
modelFile := c.Model |
||||
|
||||
grpcOpts := gRPCModelOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
|
||||
opts := []model.Option{ |
||||
model.WithLoadGRPCOpts(grpcOpts), |
||||
model.WithThreads(uint32(c.Threads)), // GPT4all uses this
|
||||
model.WithAssetDir(o.AssetsDestination), |
||||
model.WithModelFile(modelFile), |
||||
} |
||||
|
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(opts...) |
||||
} else { |
||||
opts = append(opts, model.WithBackendString(c.Backend)) |
||||
inferenceModel, err = loader.BackendLoader(opts...) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() (string, error) |
||||
|
||||
switch model := inferenceModel.(type) { |
||||
case *rwkv.RwkvState: |
||||
supportStreams = true |
||||
|
||||
fn = func() (string, error) { |
||||
stopWord := "\n" |
||||
if len(c.StopWords) > 0 { |
||||
stopWord = c.StopWords[0] |
||||
} |
||||
|
||||
if err := model.ProcessInput(s); err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) |
||||
|
||||
return response, nil |
||||
} |
||||
case *bloomz.Bloomz: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []bloomz.PredictOption{ |
||||
bloomz.SetTemperature(c.Temperature), |
||||
bloomz.SetTopP(c.TopP), |
||||
bloomz.SetTopK(c.TopK), |
||||
bloomz.SetTokens(c.Maxtokens), |
||||
bloomz.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
|
||||
case *grpc.Client: |
||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||
supportStreams = true |
||||
fn = func() (string, error) { |
||||
|
||||
opts := gRPCPredictOpts(c, loader.ModelPath) |
||||
opts.Prompt = s |
||||
if tokenCallback != nil { |
||||
ss := "" |
||||
err := model.PredictStream(context.TODO(), opts, func(s string) { |
||||
tokenCallback(s) |
||||
ss += s |
||||
}) |
||||
return ss, err |
||||
} else { |
||||
reply, err := model.Predict(context.TODO(), opts) |
||||
return reply.Message, err |
||||
} |
||||
} |
||||
case *langchain.HuggingFace: |
||||
fn = func() (string, error) { |
||||
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []langchain.PredictOption{ |
||||
langchain.SetModel(c.Model), |
||||
langchain.SetMaxTokens(c.Maxtokens), |
||||
langchain.SetTemperature(c.Temperature), |
||||
langchain.SetStopWords(c.StopWords), |
||||
} |
||||
|
||||
pred, er := model.PredictHuggingFace(s, predictOptions...) |
||||
if er != nil { |
||||
return "", er |
||||
} |
||||
return pred.Completion, nil |
||||
} |
||||
} |
||||
|
||||
return func() (string, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
l := Lock(modelFile) |
||||
defer l.Unlock() |
||||
|
||||
res, err := fn() |
||||
if tokenCallback != nil && !supportStreams { |
||||
tokenCallback(res) |
||||
} |
||||
return res, err |
||||
}, nil |
||||
} |
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
||||
var mu sync.Mutex = sync.Mutex{} |
||||
|
||||
func Finetune(config config.Config, input, prediction string) string { |
||||
if config.Echo { |
||||
prediction = input + prediction |
||||
} |
||||
|
||||
for _, c := range config.Cutstrings { |
||||
mu.Lock() |
||||
reg, ok := cutstrings[c] |
||||
if !ok { |
||||
cutstrings[c] = regexp.MustCompile(c) |
||||
reg = cutstrings[c] |
||||
} |
||||
mu.Unlock() |
||||
prediction = reg.ReplaceAllString(prediction, "") |
||||
} |
||||
|
||||
for _, c := range config.TrimSpace { |
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
||||
} |
||||
return prediction |
||||
|
||||
} |
@ -0,0 +1,22 @@ |
||||
package backend |
||||
|
||||
import "sync" |
||||
|
||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutexMap sync.Mutex |
||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
||||
|
||||
func Lock(s string) *sync.Mutex { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[s] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[s] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
|
||||
return l |
||||
} |
@ -0,0 +1,98 @@ |
||||
package backend |
||||
|
||||
import ( |
||||
"os" |
||||
"path/filepath" |
||||
|
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/pkg/langchain" |
||||
"github.com/go-skynet/bloomz.cpp" |
||||
) |
||||
|
||||
func langchainOptions(c config.Config) []langchain.PredictOption { |
||||
return []langchain.PredictOption{ |
||||
langchain.SetModel(c.Model), |
||||
langchain.SetMaxTokens(c.Maxtokens), |
||||
langchain.SetTemperature(c.Temperature), |
||||
langchain.SetStopWords(c.StopWords), |
||||
} |
||||
} |
||||
|
||||
func bloomzOptions(c config.Config) []bloomz.PredictOption { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []bloomz.PredictOption{ |
||||
bloomz.SetTemperature(c.Temperature), |
||||
bloomz.SetTopP(c.TopP), |
||||
bloomz.SetTopK(c.TopK), |
||||
bloomz.SetTokens(c.Maxtokens), |
||||
bloomz.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
||||
} |
||||
return predictOptions |
||||
} |
||||
func gRPCModelOpts(c config.Config) *pb.ModelOptions { |
||||
b := 512 |
||||
if c.Batch != 0 { |
||||
b = c.Batch |
||||
} |
||||
return &pb.ModelOptions{ |
||||
ContextSize: int32(c.ContextSize), |
||||
Seed: int32(c.Seed), |
||||
NBatch: int32(b), |
||||
F16Memory: c.F16, |
||||
MLock: c.MMlock, |
||||
NUMA: c.NUMA, |
||||
Embeddings: c.Embeddings, |
||||
LowVRAM: c.LowVRAM, |
||||
NGPULayers: int32(c.NGPULayers), |
||||
MMap: c.MMap, |
||||
MainGPU: c.MainGPU, |
||||
Threads: int32(c.Threads), |
||||
TensorSplit: c.TensorSplit, |
||||
} |
||||
} |
||||
|
||||
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { |
||||
promptCachePath := "" |
||||
if c.PromptCachePath != "" { |
||||
p := filepath.Join(modelPath, c.PromptCachePath) |
||||
os.MkdirAll(filepath.Dir(p), 0755) |
||||
promptCachePath = p |
||||
} |
||||
return &pb.PredictOptions{ |
||||
Temperature: float32(c.Temperature), |
||||
TopP: float32(c.TopP), |
||||
TopK: int32(c.TopK), |
||||
Tokens: int32(c.Maxtokens), |
||||
Threads: int32(c.Threads), |
||||
PromptCacheAll: c.PromptCacheAll, |
||||
PromptCacheRO: c.PromptCacheRO, |
||||
PromptCachePath: promptCachePath, |
||||
F16KV: c.F16, |
||||
DebugMode: c.Debug, |
||||
Grammar: c.Grammar, |
||||
|
||||
Mirostat: int32(c.Mirostat), |
||||
MirostatETA: float32(c.MirostatETA), |
||||
MirostatTAU: float32(c.MirostatTAU), |
||||
Debug: c.Debug, |
||||
StopPrompts: c.StopWords, |
||||
Repeat: int32(c.RepeatPenalty), |
||||
NKeep: int32(c.Keep), |
||||
Batch: int32(c.Batch), |
||||
IgnoreEOS: c.IgnoreEOS, |
||||
Seed: int32(c.Seed), |
||||
FrequencyPenalty: float32(c.FrequencyPenalty), |
||||
MLock: c.MMlock, |
||||
MMap: c.MMap, |
||||
MainGPU: c.MainGPU, |
||||
TensorSplit: c.TensorSplit, |
||||
TailFreeSamplingZ: float32(c.TFZ), |
||||
TypicalP: float32(c.TypicalP), |
||||
} |
||||
} |
@ -1,401 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/fs" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
"sync" |
||||
|
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"gopkg.in/yaml.v3" |
||||
) |
||||
|
||||
type Config struct { |
||||
OpenAIRequest `yaml:"parameters"` |
||||
Name string `yaml:"name"` |
||||
StopWords []string `yaml:"stopwords"` |
||||
Cutstrings []string `yaml:"cutstrings"` |
||||
TrimSpace []string `yaml:"trimspace"` |
||||
ContextSize int `yaml:"context_size"` |
||||
F16 bool `yaml:"f16"` |
||||
NUMA bool `yaml:"numa"` |
||||
Threads int `yaml:"threads"` |
||||
Debug bool `yaml:"debug"` |
||||
Roles map[string]string `yaml:"roles"` |
||||
Embeddings bool `yaml:"embeddings"` |
||||
Backend string `yaml:"backend"` |
||||
TemplateConfig TemplateConfig `yaml:"template"` |
||||
MirostatETA float64 `yaml:"mirostat_eta"` |
||||
MirostatTAU float64 `yaml:"mirostat_tau"` |
||||
Mirostat int `yaml:"mirostat"` |
||||
NGPULayers int `yaml:"gpu_layers"` |
||||
MMap bool `yaml:"mmap"` |
||||
MMlock bool `yaml:"mmlock"` |
||||
LowVRAM bool `yaml:"low_vram"` |
||||
|
||||
TensorSplit string `yaml:"tensor_split"` |
||||
MainGPU string `yaml:"main_gpu"` |
||||
ImageGenerationAssets string `yaml:"asset_dir"` |
||||
|
||||
PromptCachePath string `yaml:"prompt_cache_path"` |
||||
PromptCacheAll bool `yaml:"prompt_cache_all"` |
||||
PromptCacheRO bool `yaml:"prompt_cache_ro"` |
||||
|
||||
Grammar string `yaml:"grammar"` |
||||
|
||||
FunctionsConfig Functions `yaml:"function"` |
||||
|
||||
PromptStrings, InputStrings []string |
||||
InputToken [][]int |
||||
functionCallString, functionCallNameString string |
||||
} |
||||
|
||||
type Functions struct { |
||||
DisableNoAction bool `yaml:"disable_no_action"` |
||||
NoActionFunctionName string `yaml:"no_action_function_name"` |
||||
NoActionDescriptionName string `yaml:"no_action_description_name"` |
||||
} |
||||
|
||||
type TemplateConfig struct { |
||||
Completion string `yaml:"completion"` |
||||
Functions string `yaml:"function"` |
||||
Chat string `yaml:"chat"` |
||||
Edit string `yaml:"edit"` |
||||
} |
||||
|
||||
type ConfigMerger struct { |
||||
configs map[string]Config |
||||
sync.Mutex |
||||
} |
||||
|
||||
func defaultConfig(modelFile string) *Config { |
||||
return &Config{ |
||||
OpenAIRequest: defaultRequest(modelFile), |
||||
} |
||||
} |
||||
|
||||
func NewConfigMerger() *ConfigMerger { |
||||
return &ConfigMerger{ |
||||
configs: make(map[string]Config), |
||||
} |
||||
} |
||||
func ReadConfigFile(file string) ([]*Config, error) { |
||||
c := &[]*Config{} |
||||
f, err := os.ReadFile(file) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
if err := yaml.Unmarshal(f, c); err != nil { |
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||
} |
||||
|
||||
return *c, nil |
||||
} |
||||
|
||||
func ReadConfig(file string) (*Config, error) { |
||||
c := &Config{} |
||||
f, err := os.ReadFile(file) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
if err := yaml.Unmarshal(f, c); err != nil { |
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||
} |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
func (cm *ConfigMerger) LoadConfigFile(file string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
c, err := ReadConfigFile(file) |
||||
if err != nil { |
||||
return fmt.Errorf("cannot load config file: %w", err) |
||||
} |
||||
|
||||
for _, cc := range c { |
||||
cm.configs[cc.Name] = *cc |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (cm *ConfigMerger) LoadConfig(file string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
c, err := ReadConfig(file) |
||||
if err != nil { |
||||
return fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
|
||||
cm.configs[c.Name] = *c |
||||
return nil |
||||
} |
||||
|
||||
func (cm *ConfigMerger) GetConfig(m string) (Config, bool) { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
v, exists := cm.configs[m] |
||||
return v, exists |
||||
} |
||||
|
||||
func (cm *ConfigMerger) ListConfigs() []string { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
var res []string |
||||
for k := range cm.configs { |
||||
res = append(res, k) |
||||
} |
||||
return res |
||||
} |
||||
|
||||
func (cm *ConfigMerger) LoadConfigs(path string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
entries, err := os.ReadDir(path) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
files := make([]fs.FileInfo, 0, len(entries)) |
||||
for _, entry := range entries { |
||||
info, err := entry.Info() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
files = append(files, info) |
||||
} |
||||
for _, file := range files { |
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") { |
||||
continue |
||||
} |
||||
c, err := ReadConfig(filepath.Join(path, file.Name())) |
||||
if err == nil { |
||||
cm.configs[c.Name] = *c |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
func updateConfig(config *Config, input *OpenAIRequest) { |
||||
if input.Echo { |
||||
config.Echo = input.Echo |
||||
} |
||||
if input.TopK != 0 { |
||||
config.TopK = input.TopK |
||||
} |
||||
if input.TopP != 0 { |
||||
config.TopP = input.TopP |
||||
} |
||||
|
||||
if input.Grammar != "" { |
||||
config.Grammar = input.Grammar |
||||
} |
||||
|
||||
if input.Temperature != 0 { |
||||
config.Temperature = input.Temperature |
||||
} |
||||
|
||||
if input.Maxtokens != 0 { |
||||
config.Maxtokens = input.Maxtokens |
||||
} |
||||
|
||||
switch stop := input.Stop.(type) { |
||||
case string: |
||||
if stop != "" { |
||||
config.StopWords = append(config.StopWords, stop) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range stop { |
||||
if s, ok := pp.(string); ok { |
||||
config.StopWords = append(config.StopWords, s) |
||||
} |
||||
} |
||||
} |
||||
|
||||
if input.RepeatPenalty != 0 { |
||||
config.RepeatPenalty = input.RepeatPenalty |
||||
} |
||||
|
||||
if input.Keep != 0 { |
||||
config.Keep = input.Keep |
||||
} |
||||
|
||||
if input.Batch != 0 { |
||||
config.Batch = input.Batch |
||||
} |
||||
|
||||
if input.F16 { |
||||
config.F16 = input.F16 |
||||
} |
||||
|
||||
if input.IgnoreEOS { |
||||
config.IgnoreEOS = input.IgnoreEOS |
||||
} |
||||
|
||||
if input.Seed != 0 { |
||||
config.Seed = input.Seed |
||||
} |
||||
|
||||
if input.Mirostat != 0 { |
||||
config.Mirostat = input.Mirostat |
||||
} |
||||
|
||||
if input.MirostatETA != 0 { |
||||
config.MirostatETA = input.MirostatETA |
||||
} |
||||
|
||||
if input.MirostatTAU != 0 { |
||||
config.MirostatTAU = input.MirostatTAU |
||||
} |
||||
|
||||
if input.TypicalP != 0 { |
||||
config.TypicalP = input.TypicalP |
||||
} |
||||
|
||||
switch inputs := input.Input.(type) { |
||||
case string: |
||||
if inputs != "" { |
||||
config.InputStrings = append(config.InputStrings, inputs) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range inputs { |
||||
switch i := pp.(type) { |
||||
case string: |
||||
config.InputStrings = append(config.InputStrings, i) |
||||
case []interface{}: |
||||
tokens := []int{} |
||||
for _, ii := range i { |
||||
tokens = append(tokens, int(ii.(float64))) |
||||
} |
||||
config.InputToken = append(config.InputToken, tokens) |
||||
} |
||||
} |
||||
} |
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) { |
||||
case string: |
||||
if fnc != "" { |
||||
config.functionCallString = fnc |
||||
} |
||||
case map[string]interface{}: |
||||
var name string |
||||
n, exists := fnc["name"] |
||||
if exists { |
||||
nn, e := n.(string) |
||||
if e { |
||||
name = nn |
||||
} |
||||
} |
||||
config.functionCallNameString = name |
||||
} |
||||
|
||||
switch p := input.Prompt.(type) { |
||||
case string: |
||||
config.PromptStrings = append(config.PromptStrings, p) |
||||
case []interface{}: |
||||
for _, pp := range p { |
||||
if s, ok := pp.(string); ok { |
||||
config.PromptStrings = append(config.PromptStrings, s) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
||||
input := new(OpenAIRequest) |
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
modelFile := input.Model |
||||
|
||||
if c.Params("model") != "" { |
||||
modelFile = c.Params("model") |
||||
} |
||||
|
||||
received, _ := json.Marshal(input) |
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received)) |
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel { |
||||
models, _ := loader.ListModels() |
||||
if len(models) > 0 { |
||||
modelFile = models[0] |
||||
log.Debug().Msgf("No model specified, using: %s", modelFile) |
||||
} else { |
||||
log.Debug().Msgf("No model specified, returning error") |
||||
return "", nil, fmt.Errorf("no model specified") |
||||
} |
||||
} |
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists { |
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
||||
modelFile = bearer |
||||
} |
||||
return modelFile, input, nil |
||||
} |
||||
|
||||
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { |
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
||||
|
||||
var config *Config |
||||
|
||||
defaults := func() { |
||||
config = defaultConfig(modelFile) |
||||
config.ContextSize = ctx |
||||
config.Threads = threads |
||||
config.F16 = f16 |
||||
config.Debug = debug |
||||
} |
||||
|
||||
cfg, exists := cm.GetConfig(modelFile) |
||||
if !exists { |
||||
if _, err := os.Stat(modelConfig); err == nil { |
||||
if err := cm.LoadConfig(modelConfig); err != nil { |
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
||||
} |
||||
cfg, exists = cm.GetConfig(modelFile) |
||||
if exists { |
||||
config = &cfg |
||||
} else { |
||||
defaults() |
||||
} |
||||
} else { |
||||
defaults() |
||||
} |
||||
} else { |
||||
config = &cfg |
||||
} |
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(config, input) |
||||
|
||||
// Don't allow 0 as setting
|
||||
if config.Threads == 0 { |
||||
if threads != 0 { |
||||
config.Threads = threads |
||||
} else { |
||||
config.Threads = 4 |
||||
} |
||||
} |
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug { |
||||
config.Debug = true |
||||
} |
||||
|
||||
return config, input, nil |
||||
} |
@ -0,0 +1,209 @@ |
||||
package api_config |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io/fs" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"gopkg.in/yaml.v3" |
||||
) |
||||
|
||||
type Config struct { |
||||
PredictionOptions `yaml:"parameters"` |
||||
Name string `yaml:"name"` |
||||
StopWords []string `yaml:"stopwords"` |
||||
Cutstrings []string `yaml:"cutstrings"` |
||||
TrimSpace []string `yaml:"trimspace"` |
||||
ContextSize int `yaml:"context_size"` |
||||
F16 bool `yaml:"f16"` |
||||
NUMA bool `yaml:"numa"` |
||||
Threads int `yaml:"threads"` |
||||
Debug bool `yaml:"debug"` |
||||
Roles map[string]string `yaml:"roles"` |
||||
Embeddings bool `yaml:"embeddings"` |
||||
Backend string `yaml:"backend"` |
||||
TemplateConfig TemplateConfig `yaml:"template"` |
||||
MirostatETA float64 `yaml:"mirostat_eta"` |
||||
MirostatTAU float64 `yaml:"mirostat_tau"` |
||||
Mirostat int `yaml:"mirostat"` |
||||
NGPULayers int `yaml:"gpu_layers"` |
||||
MMap bool `yaml:"mmap"` |
||||
MMlock bool `yaml:"mmlock"` |
||||
LowVRAM bool `yaml:"low_vram"` |
||||
|
||||
TensorSplit string `yaml:"tensor_split"` |
||||
MainGPU string `yaml:"main_gpu"` |
||||
ImageGenerationAssets string `yaml:"asset_dir"` |
||||
|
||||
PromptCachePath string `yaml:"prompt_cache_path"` |
||||
PromptCacheAll bool `yaml:"prompt_cache_all"` |
||||
PromptCacheRO bool `yaml:"prompt_cache_ro"` |
||||
|
||||
Grammar string `yaml:"grammar"` |
||||
|
||||
PromptStrings, InputStrings []string |
||||
InputToken [][]int |
||||
functionCallString, functionCallNameString string |
||||
|
||||
FunctionsConfig Functions `yaml:"function"` |
||||
} |
||||
|
||||
type Functions struct { |
||||
DisableNoAction bool `yaml:"disable_no_action"` |
||||
NoActionFunctionName string `yaml:"no_action_function_name"` |
||||
NoActionDescriptionName string `yaml:"no_action_description_name"` |
||||
} |
||||
|
||||
type TemplateConfig struct { |
||||
Completion string `yaml:"completion"` |
||||
Functions string `yaml:"function"` |
||||
Chat string `yaml:"chat"` |
||||
Edit string `yaml:"edit"` |
||||
} |
||||
|
||||
type ConfigLoader struct { |
||||
configs map[string]Config |
||||
sync.Mutex |
||||
} |
||||
|
||||
func (c *Config) SetFunctionCallString(s string) { |
||||
c.functionCallString = s |
||||
} |
||||
|
||||
func (c *Config) SetFunctionCallNameString(s string) { |
||||
c.functionCallNameString = s |
||||
} |
||||
|
||||
func (c *Config) ShouldUseFunctions() bool { |
||||
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) |
||||
} |
||||
|
||||
func (c *Config) ShouldCallSpecificFunction() bool { |
||||
return len(c.functionCallNameString) > 0 |
||||
} |
||||
|
||||
func (c *Config) FunctionToCall() string { |
||||
return c.functionCallNameString |
||||
} |
||||
|
||||
func defaultPredictOptions(modelFile string) PredictionOptions { |
||||
return PredictionOptions{ |
||||
TopP: 0.7, |
||||
TopK: 80, |
||||
Maxtokens: 512, |
||||
Temperature: 0.9, |
||||
Model: modelFile, |
||||
} |
||||
} |
||||
|
||||
func DefaultConfig(modelFile string) *Config { |
||||
return &Config{ |
||||
PredictionOptions: defaultPredictOptions(modelFile), |
||||
} |
||||
} |
||||
|
||||
func NewConfigLoader() *ConfigLoader { |
||||
return &ConfigLoader{ |
||||
configs: make(map[string]Config), |
||||
} |
||||
} |
||||
func ReadConfigFile(file string) ([]*Config, error) { |
||||
c := &[]*Config{} |
||||
f, err := os.ReadFile(file) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
if err := yaml.Unmarshal(f, c); err != nil { |
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||
} |
||||
|
||||
return *c, nil |
||||
} |
||||
|
||||
func ReadConfig(file string) (*Config, error) { |
||||
c := &Config{} |
||||
f, err := os.ReadFile(file) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
if err := yaml.Unmarshal(f, c); err != nil { |
||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||
} |
||||
|
||||
return c, nil |
||||
} |
||||
|
||||
func (cm *ConfigLoader) LoadConfigFile(file string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
c, err := ReadConfigFile(file) |
||||
if err != nil { |
||||
return fmt.Errorf("cannot load config file: %w", err) |
||||
} |
||||
|
||||
for _, cc := range c { |
||||
cm.configs[cc.Name] = *cc |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (cm *ConfigLoader) LoadConfig(file string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
c, err := ReadConfig(file) |
||||
if err != nil { |
||||
return fmt.Errorf("cannot read config file: %w", err) |
||||
} |
||||
|
||||
cm.configs[c.Name] = *c |
||||
return nil |
||||
} |
||||
|
||||
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
v, exists := cm.configs[m] |
||||
return v, exists |
||||
} |
||||
|
||||
func (cm *ConfigLoader) ListConfigs() []string { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
var res []string |
||||
for k := range cm.configs { |
||||
res = append(res, k) |
||||
} |
||||
return res |
||||
} |
||||
|
||||
func (cm *ConfigLoader) LoadConfigs(path string) error { |
||||
cm.Lock() |
||||
defer cm.Unlock() |
||||
entries, err := os.ReadDir(path) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
files := make([]fs.FileInfo, 0, len(entries)) |
||||
for _, entry := range entries { |
||||
info, err := entry.Info() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
files = append(files, info) |
||||
} |
||||
for _, file := range files { |
||||
// Skip templates, YAML and .keep files
|
||||
if !strings.Contains(file.Name(), ".yaml") { |
||||
continue |
||||
} |
||||
c, err := ReadConfig(filepath.Join(path, file.Name())) |
||||
if err == nil { |
||||
cm.configs[c.Name] = *c |
||||
} |
||||
} |
||||
|
||||
return nil |
||||
} |
@ -0,0 +1,37 @@ |
||||
package api_config |
||||
|
||||
type PredictionOptions struct { |
||||
|
||||
// Also part of the OpenAI official spec
|
||||
Model string `json:"model" yaml:"model"` |
||||
|
||||
// Also part of the OpenAI official spec
|
||||
Language string `json:"language"` |
||||
|
||||
// Also part of the OpenAI official spec. use it for returning multiple results
|
||||
N int `json:"n"` |
||||
|
||||
// Common options between all the API calls, part of the OpenAI spec
|
||||
TopP float64 `json:"top_p" yaml:"top_p"` |
||||
TopK int `json:"top_k" yaml:"top_k"` |
||||
Temperature float64 `json:"temperature" yaml:"temperature"` |
||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"` |
||||
Echo bool `json:"echo"` |
||||
|
||||
// Custom parameters - not present in the OpenAI API
|
||||
Batch int `json:"batch" yaml:"batch"` |
||||
F16 bool `json:"f16" yaml:"f16"` |
||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` |
||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` |
||||
Keep int `json:"n_keep" yaml:"n_keep"` |
||||
|
||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` |
||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` |
||||
Mirostat int `json:"mirostat" yaml:"mirostat"` |
||||
|
||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` |
||||
TFZ float64 `json:"tfz" yaml:"tfz"` |
||||
|
||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"` |
||||
Seed int `json:"seed" yaml:"seed"` |
||||
} |
@ -1,973 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"io/ioutil" |
||||
"net/http" |
||||
"os" |
||||
"path" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" |
||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"github.com/valyala/fasthttp" |
||||
) |
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
type APIError struct { |
||||
Code any `json:"code,omitempty"` |
||||
Message string `json:"message"` |
||||
Param *string `json:"param,omitempty"` |
||||
Type string `json:"type"` |
||||
} |
||||
|
||||
type ErrorResponse struct { |
||||
Error *APIError `json:"error,omitempty"` |
||||
} |
||||
|
||||
type OpenAIUsage struct { |
||||
PromptTokens int `json:"prompt_tokens"` |
||||
CompletionTokens int `json:"completion_tokens"` |
||||
TotalTokens int `json:"total_tokens"` |
||||
} |
||||
|
||||
type Item struct { |
||||
Embedding []float32 `json:"embedding"` |
||||
Index int `json:"index"` |
||||
Object string `json:"object,omitempty"` |
||||
|
||||
// Images
|
||||
URL string `json:"url,omitempty"` |
||||
B64JSON string `json:"b64_json,omitempty"` |
||||
} |
||||
|
||||
type OpenAIResponse struct { |
||||
Created int `json:"created,omitempty"` |
||||
Object string `json:"object,omitempty"` |
||||
ID string `json:"id,omitempty"` |
||||
Model string `json:"model,omitempty"` |
||||
Choices []Choice `json:"choices,omitempty"` |
||||
Data []Item `json:"data,omitempty"` |
||||
|
||||
Usage OpenAIUsage `json:"usage"` |
||||
} |
||||
|
||||
type Choice struct { |
||||
Index int `json:"index,omitempty"` |
||||
FinishReason string `json:"finish_reason,omitempty"` |
||||
Message *Message `json:"message,omitempty"` |
||||
Delta *Message `json:"delta,omitempty"` |
||||
Text string `json:"text,omitempty"` |
||||
} |
||||
|
||||
type Message struct { |
||||
// The message role
|
||||
Role string `json:"role,omitempty" yaml:"role"` |
||||
// The message content
|
||||
Content *string `json:"content" yaml:"content"` |
||||
// A result of a function call
|
||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` |
||||
} |
||||
|
||||
type OpenAIModel struct { |
||||
ID string `json:"id"` |
||||
Object string `json:"object"` |
||||
} |
||||
|
||||
type OpenAIRequest struct { |
||||
Model string `json:"model" yaml:"model"` |
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"` |
||||
Language string `json:"language"` |
||||
//whisper/image
|
||||
ResponseFormat string `json:"response_format"` |
||||
// image
|
||||
Size string `json:"size"` |
||||
// Prompt is read only by completion/image API calls
|
||||
Prompt interface{} `json:"prompt" yaml:"prompt"` |
||||
|
||||
// Edit endpoint
|
||||
Instruction string `json:"instruction" yaml:"instruction"` |
||||
Input interface{} `json:"input" yaml:"input"` |
||||
|
||||
Stop interface{} `json:"stop" yaml:"stop"` |
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages" yaml:"messages"` |
||||
|
||||
// A list of available functions to call
|
||||
Functions []grammar.Function `json:"functions" yaml:"functions"` |
||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||
|
||||
Stream bool `json:"stream"` |
||||
Echo bool `json:"echo"` |
||||
// Common options between all the API calls
|
||||
TopP float64 `json:"top_p" yaml:"top_p"` |
||||
TopK int `json:"top_k" yaml:"top_k"` |
||||
Temperature float64 `json:"temperature" yaml:"temperature"` |
||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"` |
||||
|
||||
N int `json:"n"` |
||||
|
||||
// Custom parameters - not present in the OpenAI API
|
||||
Batch int `json:"batch" yaml:"batch"` |
||||
F16 bool `json:"f16" yaml:"f16"` |
||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` |
||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` |
||||
Keep int `json:"n_keep" yaml:"n_keep"` |
||||
|
||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` |
||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` |
||||
Mirostat int `json:"mirostat" yaml:"mirostat"` |
||||
|
||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` |
||||
TFZ float64 `json:"tfz" yaml:"tfz"` |
||||
|
||||
Seed int `json:"seed" yaml:"seed"` |
||||
|
||||
// Image (not supported by OpenAI)
|
||||
Mode int `json:"mode"` |
||||
Step int `json:"step"` |
||||
|
||||
// A grammar to constrain the LLM output
|
||||
Grammar string `json:"grammar" yaml:"grammar"` |
||||
// A grammar object
|
||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` |
||||
|
||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"` |
||||
} |
||||
|
||||
func defaultRequest(modelFile string) OpenAIRequest { |
||||
return OpenAIRequest{ |
||||
TopP: 0.7, |
||||
TopK: 80, |
||||
Maxtokens: 512, |
||||
Temperature: 0.9, |
||||
Model: modelFile, |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
Index: 0, |
||||
Text: s, |
||||
}, |
||||
}, |
||||
Object: "text_completion", |
||||
} |
||||
log.Debug().Msgf("Sending goroutine: %s", s) |
||||
|
||||
responses <- resp |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
|
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("`input`: %+v", input) |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
if input.Stream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Completion != "" { |
||||
templateFile = config.TemplateConfig.Completion |
||||
} |
||||
|
||||
if input.Stream { |
||||
if len(config.PromptStrings) > 1 { |
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") |
||||
} |
||||
|
||||
predInput := config.PromptStrings[0] |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{ |
||||
Input: predInput, |
||||
}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} |
||||
|
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
Index: 0, |
||||
FinishReason: "stop", |
||||
}, |
||||
}, |
||||
Object: "text_completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.PromptStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{ |
||||
Input: i, |
||||
}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "text_completion", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
items := []Item{} |
||||
|
||||
for i, s := range config.InputToken { |
||||
// get the model function to call for the result
|
||||
embedFn, err := ModelEmbedding("", s, o.loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
for i, s := range config.InputStrings { |
||||
// get the model function to call for the result
|
||||
embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items, |
||||
Object: "list", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
func isEOS(s string) bool { |
||||
if s == "<|endoftext|>" { |
||||
return true |
||||
} |
||||
|
||||
return false |
||||
} |
||||
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
|
||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
initialMessage := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
responses <- initialMessage |
||||
|
||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
log.Debug().Msgf("Sending goroutine: %s", s) |
||||
|
||||
if s != "" && !isEOS(s) { |
||||
responses <- resp |
||||
} |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
return func(c *fiber.Ctx) error { |
||||
processFunctions := false |
||||
funcs := grammar.Functions{} |
||||
model, input, err := readInput(c, o.loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
log.Debug().Msgf("Configuration read: %+v", config) |
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer" |
||||
noActionDescription := "use this action to answer without performing any action" |
||||
|
||||
if config.FunctionsConfig.NoActionFunctionName != "" { |
||||
noActionName = config.FunctionsConfig.NoActionFunctionName |
||||
} |
||||
if config.FunctionsConfig.NoActionDescriptionName != "" { |
||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName |
||||
} |
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && |
||||
((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) { |
||||
log.Debug().Msgf("Response needs to process functions") |
||||
|
||||
processFunctions = true |
||||
|
||||
noActionGrammar := grammar.Function{ |
||||
Name: noActionName, |
||||
Description: noActionDescription, |
||||
Parameters: map[string]interface{}{ |
||||
"properties": map[string]interface{}{ |
||||
"message": map[string]interface{}{ |
||||
"type": "string", |
||||
"description": "The message to reply the user with", |
||||
}}, |
||||
}, |
||||
} |
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...) |
||||
if !config.FunctionsConfig.DisableNoAction { |
||||
funcs = append(funcs, noActionGrammar) |
||||
} |
||||
|
||||
// Force picking one of the functions by the request
|
||||
if config.functionCallNameString != "" { |
||||
funcs = funcs.Select(config.functionCallNameString) |
||||
} |
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure() |
||||
config.Grammar = jsStruct.Grammar("") |
||||
} else if input.JSONFunctionGrammarObject != nil { |
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") |
||||
} |
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream && !processFunctions |
||||
|
||||
log.Debug().Msgf("Parameters: %+v", config) |
||||
|
||||
var predInput string |
||||
|
||||
mess := []string{} |
||||
for _, i := range input.Messages { |
||||
var content string |
||||
role := i.Role |
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" { |
||||
roleFn := "assistant_function_call" |
||||
r := config.Roles[roleFn] |
||||
if r != "" { |
||||
role = roleFn |
||||
} |
||||
} |
||||
r := config.Roles[role] |
||||
contentExists := i.Content != nil && *i.Content != "" |
||||
if r != "" { |
||||
if contentExists { |
||||
content = fmt.Sprint(r, " ", *i.Content) |
||||
} |
||||
if i.FunctionCall != nil { |
||||
j, err := json.Marshal(i.FunctionCall) |
||||
if err == nil { |
||||
if contentExists { |
||||
content += "\n" + fmt.Sprint(r, " ", string(j)) |
||||
} else { |
||||
content = fmt.Sprint(r, " ", string(j)) |
||||
} |
||||
} |
||||
} |
||||
} else { |
||||
if contentExists { |
||||
content = fmt.Sprint(*i.Content) |
||||
} |
||||
if i.FunctionCall != nil { |
||||
j, err := json.Marshal(i.FunctionCall) |
||||
if err == nil { |
||||
if contentExists { |
||||
content += "\n" + string(j) |
||||
} else { |
||||
content = string(j) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
mess = append(mess, content) |
||||
} |
||||
|
||||
predInput = strings.Join(mess, "\n") |
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput) |
||||
|
||||
if toStream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions { |
||||
templateFile = config.TemplateConfig.Chat |
||||
} |
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions { |
||||
templateFile = config.TemplateConfig.Functions |
||||
} |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Functions []grammar.Function |
||||
}{ |
||||
Input: predInput, |
||||
Functions: funcs, |
||||
}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} else { |
||||
log.Debug().Msgf("Template failed loading: %s", err.Error()) |
||||
} |
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput) |
||||
if processFunctions { |
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar) |
||||
} |
||||
|
||||
if toStream { |
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
FinishReason: "stop", |
||||
Index: 0, |
||||
Delta: &Message{}, |
||||
}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) { |
||||
if processFunctions { |
||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||
ss := map[string]interface{}{} |
||||
json.Unmarshal([]byte(s), &ss) |
||||
log.Debug().Msgf("Function return: %s %+v", s, ss) |
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name := ss["function"] |
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
d, _ := json.Marshal(args) |
||||
|
||||
ss["arguments"] = string(d) |
||||
ss["name"] = func_name |
||||
|
||||
// if do nothing, reply with a message
|
||||
if func_name == noActionName { |
||||
log.Debug().Msgf("nothing to do, computing a reply") |
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{} |
||||
json.Unmarshal([]byte(d), &arguments) |
||||
m, exists := arguments["message"] |
||||
if exists { |
||||
switch message := m.(type) { |
||||
case string: |
||||
if message != "" { |
||||
log.Debug().Msgf("Reply received from LLM: %s", message) |
||||
message = Finetune(*config, predInput, message) |
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) |
||||
|
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply") |
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in term of CPU) another computation
|
||||
config.Grammar = "" |
||||
predFunc, err := ModelInference(predInput, o.loader, *config, o, nil) |
||||
if err != nil { |
||||
log.Error().Msgf("inference error: %s", err.Error()) |
||||
return |
||||
} |
||||
|
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
log.Error().Msgf("inference error: %s", err.Error()) |
||||
return |
||||
} |
||||
|
||||
prediction = Finetune(*config, predInput, prediction) |
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) |
||||
} else { |
||||
// otherwise reply with the function call
|
||||
*c = append(*c, Choice{ |
||||
FinishReason: "function_call", |
||||
Message: &Message{Role: "assistant", FunctionCall: ss}, |
||||
}) |
||||
} |
||||
|
||||
return |
||||
} |
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "chat.completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", respData) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Edit != "" { |
||||
templateFile = config.TemplateConfig.Edit |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.InputStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Instruction string |
||||
}{Input: i}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "edit", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/* |
||||
* |
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{ |
||||
"prompt": "A cute baby sea otter", |
||||
"n": 1, |
||||
"size": "512x512" |
||||
}' |
||||
|
||||
* |
||||
*/ |
||||
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
if m == "" { |
||||
m = model.StableDiffusionBackend |
||||
} |
||||
log.Debug().Msgf("Loading model: %+v", m) |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
// XXX: Only stablediffusion is supported for now
|
||||
if config.Backend == "" { |
||||
config.Backend = model.StableDiffusionBackend |
||||
} |
||||
|
||||
sizeParts := strings.Split(input.Size, "x") |
||||
if len(sizeParts) != 2 { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
width, err := strconv.Atoi(sizeParts[0]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
height, err := strconv.Atoi(sizeParts[1]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
|
||||
b64JSON := false |
||||
if input.ResponseFormat == "b64_json" { |
||||
b64JSON = true |
||||
} |
||||
|
||||
var result []Item |
||||
for _, i := range config.PromptStrings { |
||||
n := input.N |
||||
if input.N == 0 { |
||||
n = 1 |
||||
} |
||||
for j := 0; j < n; j++ { |
||||
prompts := strings.Split(i, "|") |
||||
positive_prompt := prompts[0] |
||||
negative_prompt := "" |
||||
if len(prompts) > 1 { |
||||
negative_prompt = prompts[1] |
||||
} |
||||
|
||||
mode := 0 |
||||
step := 15 |
||||
|
||||
if input.Mode != 0 { |
||||
mode = input.Mode |
||||
} |
||||
|
||||
if input.Step != 0 { |
||||
step = input.Step |
||||
} |
||||
|
||||
tempDir := "" |
||||
if !b64JSON { |
||||
tempDir = o.imageDir |
||||
} |
||||
// Create a temporary file
|
||||
outputFile, err := ioutil.TempFile(tempDir, "b64") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
outputFile.Close() |
||||
output := outputFile.Name() + ".png" |
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
baseURL := c.BaseURL() |
||||
|
||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err := fn(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
item := &Item{} |
||||
|
||||
if b64JSON { |
||||
defer os.RemoveAll(output) |
||||
data, err := os.ReadFile(output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
||||
} else { |
||||
base := filepath.Base(output) |
||||
item.URL = baseURL + "/generated-images/" + base |
||||
} |
||||
|
||||
result = append(result, *item) |
||||
} |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Data: result, |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
f, err := file.Open() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer f.Close() |
||||
|
||||
dir, err := os.MkdirTemp("", "whisper") |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer os.RemoveAll(dir) |
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename)) |
||||
dstFile, err := os.Create(dst) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil { |
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst) |
||||
|
||||
whisperModel, err := o.loader.BackendLoader( |
||||
model.WithBackendString(model.WhisperBackend), |
||||
model.WithModelFile(config.Model), |
||||
model.WithThreads(uint32(config.Threads)), |
||||
model.WithAssetDir(o.assetsDestination)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if whisperModel == nil { |
||||
return fmt.Errorf("could not load whisper model") |
||||
} |
||||
|
||||
w, ok := whisperModel.(whisper.Model) |
||||
if !ok { |
||||
return fmt.Errorf("loader returned non-whisper object") |
||||
} |
||||
|
||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr) |
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(tr) |
||||
} |
||||
} |
||||
|
||||
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
models, err := loader.ListModels() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var mm map[string]interface{} = map[string]interface{}{} |
||||
|
||||
dataModels := []OpenAIModel{} |
||||
for _, m := range models { |
||||
mm[m] = nil |
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
||||
} |
||||
|
||||
for _, k := range cm.ListConfigs() { |
||||
if _, exists := mm[k]; !exists { |
||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
||||
} |
||||
} |
||||
|
||||
return c.JSON(struct { |
||||
Object string `json:"object"` |
||||
Data []OpenAIModel `json:"data"` |
||||
}{ |
||||
Object: "list", |
||||
Data: dataModels, |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,105 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
|
||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||
) |
||||
|
||||
// APIError provides error information returned by the OpenAI API.
|
||||
type APIError struct { |
||||
Code any `json:"code,omitempty"` |
||||
Message string `json:"message"` |
||||
Param *string `json:"param,omitempty"` |
||||
Type string `json:"type"` |
||||
} |
||||
|
||||
type ErrorResponse struct { |
||||
Error *APIError `json:"error,omitempty"` |
||||
} |
||||
|
||||
type OpenAIUsage struct { |
||||
PromptTokens int `json:"prompt_tokens"` |
||||
CompletionTokens int `json:"completion_tokens"` |
||||
TotalTokens int `json:"total_tokens"` |
||||
} |
||||
|
||||
type Item struct { |
||||
Embedding []float32 `json:"embedding"` |
||||
Index int `json:"index"` |
||||
Object string `json:"object,omitempty"` |
||||
|
||||
// Images
|
||||
URL string `json:"url,omitempty"` |
||||
B64JSON string `json:"b64_json,omitempty"` |
||||
} |
||||
|
||||
type OpenAIResponse struct { |
||||
Created int `json:"created,omitempty"` |
||||
Object string `json:"object,omitempty"` |
||||
ID string `json:"id,omitempty"` |
||||
Model string `json:"model,omitempty"` |
||||
Choices []Choice `json:"choices,omitempty"` |
||||
Data []Item `json:"data,omitempty"` |
||||
|
||||
Usage OpenAIUsage `json:"usage"` |
||||
} |
||||
|
||||
type Choice struct { |
||||
Index int `json:"index,omitempty"` |
||||
FinishReason string `json:"finish_reason,omitempty"` |
||||
Message *Message `json:"message,omitempty"` |
||||
Delta *Message `json:"delta,omitempty"` |
||||
Text string `json:"text,omitempty"` |
||||
} |
||||
|
||||
type Message struct { |
||||
// The message role
|
||||
Role string `json:"role,omitempty" yaml:"role"` |
||||
// The message content
|
||||
Content *string `json:"content" yaml:"content"` |
||||
// A result of a function call
|
||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` |
||||
} |
||||
|
||||
type OpenAIModel struct { |
||||
ID string `json:"id"` |
||||
Object string `json:"object"` |
||||
} |
||||
|
||||
type OpenAIRequest struct { |
||||
config.PredictionOptions |
||||
|
||||
// whisper
|
||||
File string `json:"file" validate:"required"` |
||||
//whisper/image
|
||||
ResponseFormat string `json:"response_format"` |
||||
// image
|
||||
Size string `json:"size"` |
||||
// Prompt is read only by completion/image API calls
|
||||
Prompt interface{} `json:"prompt" yaml:"prompt"` |
||||
|
||||
// Edit endpoint
|
||||
Instruction string `json:"instruction" yaml:"instruction"` |
||||
Input interface{} `json:"input" yaml:"input"` |
||||
|
||||
Stop interface{} `json:"stop" yaml:"stop"` |
||||
|
||||
// Messages is read only by chat/completion API calls
|
||||
Messages []Message `json:"messages" yaml:"messages"` |
||||
|
||||
// A list of available functions to call
|
||||
Functions []grammar.Function `json:"functions" yaml:"functions"` |
||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||
|
||||
Stream bool `json:"stream"` |
||||
|
||||
// Image (not supported by OpenAI)
|
||||
Mode int `json:"mode"` |
||||
Step int `json:"step"` |
||||
|
||||
// A grammar to constrain the LLM output
|
||||
Grammar string `json:"grammar" yaml:"grammar"` |
||||
|
||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` |
||||
} |
@ -0,0 +1,320 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/json" |
||||
"fmt" |
||||
"strings" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"github.com/valyala/fasthttp" |
||||
) |
||||
|
||||
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
initialMessage := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
responses <- initialMessage |
||||
|
||||
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
|
||||
responses <- resp |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
return func(c *fiber.Ctx) error { |
||||
processFunctions := false |
||||
funcs := grammar.Functions{} |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
log.Debug().Msgf("Configuration read: %+v", config) |
||||
|
||||
// Allow the user to set custom actions via config file
|
||||
// to be "embedded" in each model
|
||||
noActionName := "answer" |
||||
noActionDescription := "use this action to answer without performing any action" |
||||
|
||||
if config.FunctionsConfig.NoActionFunctionName != "" { |
||||
noActionName = config.FunctionsConfig.NoActionFunctionName |
||||
} |
||||
if config.FunctionsConfig.NoActionDescriptionName != "" { |
||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName |
||||
} |
||||
|
||||
// process functions if we have any defined or if we have a function call string
|
||||
if len(input.Functions) > 0 && config.ShouldUseFunctions() { |
||||
log.Debug().Msgf("Response needs to process functions") |
||||
|
||||
processFunctions = true |
||||
|
||||
noActionGrammar := grammar.Function{ |
||||
Name: noActionName, |
||||
Description: noActionDescription, |
||||
Parameters: map[string]interface{}{ |
||||
"properties": map[string]interface{}{ |
||||
"message": map[string]interface{}{ |
||||
"type": "string", |
||||
"description": "The message to reply the user with", |
||||
}}, |
||||
}, |
||||
} |
||||
|
||||
// Append the no action function
|
||||
funcs = append(funcs, input.Functions...) |
||||
if !config.FunctionsConfig.DisableNoAction { |
||||
funcs = append(funcs, noActionGrammar) |
||||
} |
||||
|
||||
// Force picking one of the functions by the request
|
||||
if config.FunctionToCall() != "" { |
||||
funcs = funcs.Select(config.FunctionToCall()) |
||||
} |
||||
|
||||
// Update input grammar
|
||||
jsStruct := funcs.ToJSONStructure() |
||||
config.Grammar = jsStruct.Grammar("") |
||||
} else if input.JSONFunctionGrammarObject != nil { |
||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") |
||||
} |
||||
|
||||
// functions are not supported in stream mode (yet?)
|
||||
toStream := input.Stream && !processFunctions |
||||
|
||||
log.Debug().Msgf("Parameters: %+v", config) |
||||
|
||||
var predInput string |
||||
|
||||
mess := []string{} |
||||
for _, i := range input.Messages { |
||||
var content string |
||||
role := i.Role |
||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||
if i.FunctionCall != nil && i.Role == "assistant" { |
||||
roleFn := "assistant_function_call" |
||||
r := config.Roles[roleFn] |
||||
if r != "" { |
||||
role = roleFn |
||||
} |
||||
} |
||||
r := config.Roles[role] |
||||
contentExists := i.Content != nil && *i.Content != "" |
||||
if r != "" { |
||||
if contentExists { |
||||
content = fmt.Sprint(r, " ", *i.Content) |
||||
} |
||||
if i.FunctionCall != nil { |
||||
j, err := json.Marshal(i.FunctionCall) |
||||
if err == nil { |
||||
if contentExists { |
||||
content += "\n" + fmt.Sprint(r, " ", string(j)) |
||||
} else { |
||||
content = fmt.Sprint(r, " ", string(j)) |
||||
} |
||||
} |
||||
} |
||||
} else { |
||||
if contentExists { |
||||
content = fmt.Sprint(*i.Content) |
||||
} |
||||
if i.FunctionCall != nil { |
||||
j, err := json.Marshal(i.FunctionCall) |
||||
if err == nil { |
||||
if contentExists { |
||||
content += "\n" + string(j) |
||||
} else { |
||||
content = string(j) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
mess = append(mess, content) |
||||
} |
||||
|
||||
predInput = strings.Join(mess, "\n") |
||||
log.Debug().Msgf("Prompt (before templating): %s", predInput) |
||||
|
||||
if toStream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
// c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Chat != "" && !processFunctions { |
||||
templateFile = config.TemplateConfig.Chat |
||||
} |
||||
|
||||
if config.TemplateConfig.Functions != "" && processFunctions { |
||||
templateFile = config.TemplateConfig.Functions |
||||
} |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Functions []grammar.Function |
||||
}{ |
||||
Input: predInput, |
||||
Functions: funcs, |
||||
}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} else { |
||||
log.Debug().Msgf("Template failed loading: %s", err.Error()) |
||||
} |
||||
|
||||
log.Debug().Msgf("Prompt (after templating): %s", predInput) |
||||
if processFunctions { |
||||
log.Debug().Msgf("Grammar: %+v", config.Grammar) |
||||
} |
||||
|
||||
if toStream { |
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.Loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
FinishReason: "stop", |
||||
Index: 0, |
||||
Delta: &Message{}, |
||||
}}, |
||||
Object: "chat.completion.chunk", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
if processFunctions { |
||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||
ss := map[string]interface{}{} |
||||
json.Unmarshal([]byte(s), &ss) |
||||
log.Debug().Msgf("Function return: %s %+v", s, ss) |
||||
|
||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||
func_name := ss["function"] |
||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||
d, _ := json.Marshal(args) |
||||
|
||||
ss["arguments"] = string(d) |
||||
ss["name"] = func_name |
||||
|
||||
// if do nothing, reply with a message
|
||||
if func_name == noActionName { |
||||
log.Debug().Msgf("nothing to do, computing a reply") |
||||
|
||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||
arguments := map[string]interface{}{} |
||||
json.Unmarshal([]byte(d), &arguments) |
||||
m, exists := arguments["message"] |
||||
if exists { |
||||
switch message := m.(type) { |
||||
case string: |
||||
if message != "" { |
||||
log.Debug().Msgf("Reply received from LLM: %s", message) |
||||
message = backend.Finetune(*config, predInput, message) |
||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) |
||||
|
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) |
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply") |
||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||
// Note: This costs (in term of CPU) another computation
|
||||
config.Grammar = "" |
||||
predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil) |
||||
if err != nil { |
||||
log.Error().Msgf("inference error: %s", err.Error()) |
||||
return |
||||
} |
||||
|
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
log.Error().Msgf("inference error: %s", err.Error()) |
||||
return |
||||
} |
||||
|
||||
prediction = backend.Finetune(*config, predInput, prediction) |
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) |
||||
} else { |
||||
// otherwise reply with the function call
|
||||
*c = append(*c, Choice{ |
||||
FinishReason: "function_call", |
||||
Message: &Message{Role: "assistant", FunctionCall: ss}, |
||||
}) |
||||
} |
||||
|
||||
return |
||||
} |
||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "chat.completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", respData) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,159 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
"github.com/valyala/fasthttp" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/completions
|
||||
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||
resp := OpenAIResponse{ |
||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
Index: 0, |
||||
Text: s, |
||||
}, |
||||
}, |
||||
Object: "text_completion", |
||||
} |
||||
log.Debug().Msgf("Sending goroutine: %s", s) |
||||
|
||||
responses <- resp |
||||
return true |
||||
}) |
||||
close(responses) |
||||
} |
||||
|
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("`input`: %+v", input) |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
if input.Stream { |
||||
log.Debug().Msgf("Stream request received") |
||||
c.Context().SetContentType("text/event-stream") |
||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||
//c.Set("Content-Type", "text/event-stream")
|
||||
c.Set("Cache-Control", "no-cache") |
||||
c.Set("Connection", "keep-alive") |
||||
c.Set("Transfer-Encoding", "chunked") |
||||
} |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Completion != "" { |
||||
templateFile = config.TemplateConfig.Completion |
||||
} |
||||
|
||||
if input.Stream { |
||||
if len(config.PromptStrings) > 1 { |
||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") |
||||
} |
||||
|
||||
predInput := config.PromptStrings[0] |
||||
|
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{ |
||||
Input: predInput, |
||||
}) |
||||
if err == nil { |
||||
predInput = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||
} |
||||
|
||||
responses := make(chan OpenAIResponse) |
||||
|
||||
go process(predInput, input, config, o.Loader, responses) |
||||
|
||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||
|
||||
for ev := range responses { |
||||
var buf bytes.Buffer |
||||
enc := json.NewEncoder(&buf) |
||||
enc.Encode(ev) |
||||
|
||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||
w.Flush() |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: []Choice{ |
||||
{ |
||||
Index: 0, |
||||
FinishReason: "stop", |
||||
}, |
||||
}, |
||||
Object: "text_completion", |
||||
} |
||||
respData, _ := json.Marshal(resp) |
||||
|
||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||
w.WriteString("data: [DONE]\n\n") |
||||
w.Flush() |
||||
})) |
||||
return nil |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.PromptStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
}{ |
||||
Input: i, |
||||
}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "text_completion", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,67 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
templateFile := config.Model |
||||
|
||||
if config.TemplateConfig.Edit != "" { |
||||
templateFile = config.TemplateConfig.Edit |
||||
} |
||||
|
||||
var result []Choice |
||||
for _, i := range config.InputStrings { |
||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||
Input string |
||||
Instruction string |
||||
}{Input: i}) |
||||
if err == nil { |
||||
i = templatedInput |
||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||
} |
||||
|
||||
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||
*c = append(*c, Choice{Text: s}) |
||||
}, nil) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
result = append(result, r...) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Choices: result, |
||||
Object: "edit", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,70 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/embeddings
|
||||
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
model, input, err := readInput(c, o.Loader, true) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
items := []Item{} |
||||
|
||||
for i, s := range config.InputToken { |
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
for i, s := range config.InputStrings { |
||||
// get the model function to call for the result
|
||||
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
embeddings, err := embedFn() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||
Data: items, |
||||
Object: "list", |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,158 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/base64" |
||||
"encoding/json" |
||||
"fmt" |
||||
"io/ioutil" |
||||
"os" |
||||
"path/filepath" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/images/create
|
||||
|
||||
/* |
||||
* |
||||
|
||||
curl http://localhost:8080/v1/images/generations \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{ |
||||
"prompt": "A cute baby sea otter", |
||||
"n": 1, |
||||
"size": "512x512" |
||||
}' |
||||
|
||||
* |
||||
*/ |
||||
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.Loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
if m == "" { |
||||
m = model.StableDiffusionBackend |
||||
} |
||||
log.Debug().Msgf("Loading model: %+v", m) |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
log.Debug().Msgf("Parameter Config: %+v", config) |
||||
|
||||
// XXX: Only stablediffusion is supported for now
|
||||
if config.Backend == "" { |
||||
config.Backend = model.StableDiffusionBackend |
||||
} |
||||
|
||||
sizeParts := strings.Split(input.Size, "x") |
||||
if len(sizeParts) != 2 { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
width, err := strconv.Atoi(sizeParts[0]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
height, err := strconv.Atoi(sizeParts[1]) |
||||
if err != nil { |
||||
return fmt.Errorf("Invalid value for 'size'") |
||||
} |
||||
|
||||
b64JSON := false |
||||
if input.ResponseFormat == "b64_json" { |
||||
b64JSON = true |
||||
} |
||||
|
||||
var result []Item |
||||
for _, i := range config.PromptStrings { |
||||
n := input.N |
||||
if input.N == 0 { |
||||
n = 1 |
||||
} |
||||
for j := 0; j < n; j++ { |
||||
prompts := strings.Split(i, "|") |
||||
positive_prompt := prompts[0] |
||||
negative_prompt := "" |
||||
if len(prompts) > 1 { |
||||
negative_prompt = prompts[1] |
||||
} |
||||
|
||||
mode := 0 |
||||
step := 15 |
||||
|
||||
if input.Mode != 0 { |
||||
mode = input.Mode |
||||
} |
||||
|
||||
if input.Step != 0 { |
||||
step = input.Step |
||||
} |
||||
|
||||
tempDir := "" |
||||
if !b64JSON { |
||||
tempDir = o.ImageDir |
||||
} |
||||
// Create a temporary file
|
||||
outputFile, err := ioutil.TempFile(tempDir, "b64") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
outputFile.Close() |
||||
output := outputFile.Name() + ".png" |
||||
// Rename the temporary file
|
||||
err = os.Rename(outputFile.Name(), output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
baseURL := c.BaseURL() |
||||
|
||||
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err := fn(); err != nil { |
||||
return err |
||||
} |
||||
|
||||
item := &Item{} |
||||
|
||||
if b64JSON { |
||||
defer os.RemoveAll(output) |
||||
data, err := os.ReadFile(output) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
||||
} else { |
||||
base := filepath.Base(output) |
||||
item.URL = baseURL + "/generated-images/" + base |
||||
} |
||||
|
||||
result = append(result, *item) |
||||
} |
||||
} |
||||
|
||||
resp := &OpenAIResponse{ |
||||
Data: result, |
||||
} |
||||
|
||||
jsonResult, _ := json.Marshal(resp) |
||||
log.Debug().Msgf("Response: %s", jsonResult) |
||||
|
||||
// Return the prediction in the response body
|
||||
return c.JSON(resp) |
||||
} |
||||
} |
@ -0,0 +1,36 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"github.com/go-skynet/LocalAI/api/backend" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
) |
||||
|
||||
func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
||||
result := []Choice{} |
||||
|
||||
if n == 0 { |
||||
n = 1 |
||||
} |
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback) |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
for i := 0; i < n; i++ { |
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
prediction = backend.Finetune(*config, predInput, prediction) |
||||
cb(prediction, &result) |
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
} |
||||
return result, err |
||||
} |
@ -0,0 +1,37 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
) |
||||
|
||||
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
models, err := loader.ListModels() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
var mm map[string]interface{} = map[string]interface{}{} |
||||
|
||||
dataModels := []OpenAIModel{} |
||||
for _, m := range models { |
||||
mm[m] = nil |
||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
||||
} |
||||
|
||||
for _, k := range cm.ListConfigs() { |
||||
if _, exists := mm[k]; !exists { |
||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
||||
} |
||||
} |
||||
|
||||
return c.JSON(struct { |
||||
Object string `json:"object"` |
||||
Data []OpenAIModel `json:"data"` |
||||
}{ |
||||
Object: "list", |
||||
Data: dataModels, |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,234 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"strings" |
||||
|
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
||||
input := new(OpenAIRequest) |
||||
// Get input data from the request body
|
||||
if err := c.BodyParser(input); err != nil { |
||||
return "", nil, err |
||||
} |
||||
|
||||
modelFile := input.Model |
||||
|
||||
if c.Params("model") != "" { |
||||
modelFile = c.Params("model") |
||||
} |
||||
|
||||
received, _ := json.Marshal(input) |
||||
|
||||
log.Debug().Msgf("Request received: %s", string(received)) |
||||
|
||||
// Set model from bearer token, if available
|
||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
||||
|
||||
// If no model was specified, take the first available
|
||||
if modelFile == "" && !bearerExists && randomModel { |
||||
models, _ := loader.ListModels() |
||||
if len(models) > 0 { |
||||
modelFile = models[0] |
||||
log.Debug().Msgf("No model specified, using: %s", modelFile) |
||||
} else { |
||||
log.Debug().Msgf("No model specified, returning error") |
||||
return "", nil, fmt.Errorf("no model specified") |
||||
} |
||||
} |
||||
|
||||
// If a model is found in bearer token takes precedence
|
||||
if bearerExists { |
||||
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
||||
modelFile = bearer |
||||
} |
||||
return modelFile, input, nil |
||||
} |
||||
|
||||
func updateConfig(config *config.Config, input *OpenAIRequest) { |
||||
if input.Echo { |
||||
config.Echo = input.Echo |
||||
} |
||||
if input.TopK != 0 { |
||||
config.TopK = input.TopK |
||||
} |
||||
if input.TopP != 0 { |
||||
config.TopP = input.TopP |
||||
} |
||||
|
||||
if input.Grammar != "" { |
||||
config.Grammar = input.Grammar |
||||
} |
||||
|
||||
if input.Temperature != 0 { |
||||
config.Temperature = input.Temperature |
||||
} |
||||
|
||||
if input.Maxtokens != 0 { |
||||
config.Maxtokens = input.Maxtokens |
||||
} |
||||
|
||||
switch stop := input.Stop.(type) { |
||||
case string: |
||||
if stop != "" { |
||||
config.StopWords = append(config.StopWords, stop) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range stop { |
||||
if s, ok := pp.(string); ok { |
||||
config.StopWords = append(config.StopWords, s) |
||||
} |
||||
} |
||||
} |
||||
|
||||
if input.RepeatPenalty != 0 { |
||||
config.RepeatPenalty = input.RepeatPenalty |
||||
} |
||||
|
||||
if input.Keep != 0 { |
||||
config.Keep = input.Keep |
||||
} |
||||
|
||||
if input.Batch != 0 { |
||||
config.Batch = input.Batch |
||||
} |
||||
|
||||
if input.F16 { |
||||
config.F16 = input.F16 |
||||
} |
||||
|
||||
if input.IgnoreEOS { |
||||
config.IgnoreEOS = input.IgnoreEOS |
||||
} |
||||
|
||||
if input.Seed != 0 { |
||||
config.Seed = input.Seed |
||||
} |
||||
|
||||
if input.Mirostat != 0 { |
||||
config.Mirostat = input.Mirostat |
||||
} |
||||
|
||||
if input.MirostatETA != 0 { |
||||
config.MirostatETA = input.MirostatETA |
||||
} |
||||
|
||||
if input.MirostatTAU != 0 { |
||||
config.MirostatTAU = input.MirostatTAU |
||||
} |
||||
|
||||
if input.TypicalP != 0 { |
||||
config.TypicalP = input.TypicalP |
||||
} |
||||
|
||||
switch inputs := input.Input.(type) { |
||||
case string: |
||||
if inputs != "" { |
||||
config.InputStrings = append(config.InputStrings, inputs) |
||||
} |
||||
case []interface{}: |
||||
for _, pp := range inputs { |
||||
switch i := pp.(type) { |
||||
case string: |
||||
config.InputStrings = append(config.InputStrings, i) |
||||
case []interface{}: |
||||
tokens := []int{} |
||||
for _, ii := range i { |
||||
tokens = append(tokens, int(ii.(float64))) |
||||
} |
||||
config.InputToken = append(config.InputToken, tokens) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Can be either a string or an object
|
||||
switch fnc := input.FunctionCall.(type) { |
||||
case string: |
||||
if fnc != "" { |
||||
config.SetFunctionCallString(fnc) |
||||
} |
||||
case map[string]interface{}: |
||||
var name string |
||||
n, exists := fnc["name"] |
||||
if exists { |
||||
nn, e := n.(string) |
||||
if !e { |
||||
name = nn |
||||
} |
||||
} |
||||
config.SetFunctionCallNameString(name) |
||||
} |
||||
|
||||
switch p := input.Prompt.(type) { |
||||
case string: |
||||
config.PromptStrings = append(config.PromptStrings, p) |
||||
case []interface{}: |
||||
for _, pp := range p { |
||||
if s, ok := pp.(string); ok { |
||||
config.PromptStrings = append(config.PromptStrings, s) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) { |
||||
// Load a config file if present after the model name
|
||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
||||
|
||||
var cfg *config.Config |
||||
|
||||
defaults := func() { |
||||
cfg = config.DefaultConfig(modelFile) |
||||
cfg.ContextSize = ctx |
||||
cfg.Threads = threads |
||||
cfg.F16 = f16 |
||||
cfg.Debug = debug |
||||
} |
||||
|
||||
cfgExisting, exists := cm.GetConfig(modelFile) |
||||
if !exists { |
||||
if _, err := os.Stat(modelConfig); err == nil { |
||||
if err := cm.LoadConfig(modelConfig); err != nil { |
||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
||||
} |
||||
cfgExisting, exists = cm.GetConfig(modelFile) |
||||
if exists { |
||||
cfg = &cfgExisting |
||||
} else { |
||||
defaults() |
||||
} |
||||
} else { |
||||
defaults() |
||||
} |
||||
} else { |
||||
cfg = &cfgExisting |
||||
} |
||||
|
||||
// Set the parameters for the language model prediction
|
||||
updateConfig(cfg, input) |
||||
|
||||
// Don't allow 0 as setting
|
||||
if cfg.Threads == 0 { |
||||
if threads != 0 { |
||||
cfg.Threads = threads |
||||
} else { |
||||
cfg.Threads = 4 |
||||
} |
||||
} |
||||
|
||||
// Enforce debug flag if passed from CLI
|
||||
if debug { |
||||
cfg.Debug = true |
||||
} |
||||
|
||||
return cfg, input, nil |
||||
} |
@ -0,0 +1,91 @@ |
||||
package openai |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"net/http" |
||||
"os" |
||||
"path" |
||||
"path/filepath" |
||||
|
||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" |
||||
config "github.com/go-skynet/LocalAI/api/config" |
||||
"github.com/go-skynet/LocalAI/api/options" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" |
||||
|
||||
"github.com/gofiber/fiber/v2" |
||||
"github.com/rs/zerolog/log" |
||||
) |
||||
|
||||
// https://platform.openai.com/docs/api-reference/audio/create
|
||||
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||
return func(c *fiber.Ctx) error { |
||||
m, input, err := readInput(c, o.Loader, false) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
|
||||
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||
} |
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file") |
||||
if err != nil { |
||||
return err |
||||
} |
||||
f, err := file.Open() |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer f.Close() |
||||
|
||||
dir, err := os.MkdirTemp("", "whisper") |
||||
|
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer os.RemoveAll(dir) |
||||
|
||||
dst := filepath.Join(dir, path.Base(file.Filename)) |
||||
dstFile, err := os.Create(dst) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if _, err := io.Copy(dstFile, f); err != nil { |
||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst) |
||||
|
||||
whisperModel, err := o.Loader.BackendLoader( |
||||
model.WithBackendString(model.WhisperBackend), |
||||
model.WithModelFile(config.Model), |
||||
model.WithThreads(uint32(config.Threads)), |
||||
model.WithAssetDir(o.AssetsDestination)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
if whisperModel == nil { |
||||
return fmt.Errorf("could not load whisper model") |
||||
} |
||||
|
||||
w, ok := whisperModel.(whisper.Model) |
||||
if !ok { |
||||
return fmt.Errorf("loader returned non-whisper object") |
||||
} |
||||
|
||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
log.Debug().Msgf("Trascribed: %+v", tr) |
||||
// TODO: handle different outputs here
|
||||
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) |
||||
} |
||||
} |
@ -1,415 +0,0 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"os" |
||||
"path/filepath" |
||||
"regexp" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"github.com/donomii/go-rwkv.cpp" |
||||
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||
"github.com/go-skynet/LocalAI/pkg/langchain" |
||||
model "github.com/go-skynet/LocalAI/pkg/model" |
||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
||||
"github.com/go-skynet/bloomz.cpp" |
||||
bert "github.com/go-skynet/go-bert.cpp" |
||||
) |
||||
|
||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
var mutexMap sync.Mutex |
||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
||||
|
||||
func gRPCModelOpts(c Config) *pb.ModelOptions { |
||||
b := 512 |
||||
if c.Batch != 0 { |
||||
b = c.Batch |
||||
} |
||||
return &pb.ModelOptions{ |
||||
ContextSize: int32(c.ContextSize), |
||||
Seed: int32(c.Seed), |
||||
NBatch: int32(b), |
||||
F16Memory: c.F16, |
||||
MLock: c.MMlock, |
||||
NUMA: c.NUMA, |
||||
Embeddings: c.Embeddings, |
||||
LowVRAM: c.LowVRAM, |
||||
NGPULayers: int32(c.NGPULayers), |
||||
MMap: c.MMap, |
||||
MainGPU: c.MainGPU, |
||||
Threads: int32(c.Threads), |
||||
TensorSplit: c.TensorSplit, |
||||
} |
||||
} |
||||
|
||||
func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { |
||||
promptCachePath := "" |
||||
if c.PromptCachePath != "" { |
||||
p := filepath.Join(modelPath, c.PromptCachePath) |
||||
os.MkdirAll(filepath.Dir(p), 0755) |
||||
promptCachePath = p |
||||
} |
||||
return &pb.PredictOptions{ |
||||
Temperature: float32(c.Temperature), |
||||
TopP: float32(c.TopP), |
||||
TopK: int32(c.TopK), |
||||
Tokens: int32(c.Maxtokens), |
||||
Threads: int32(c.Threads), |
||||
PromptCacheAll: c.PromptCacheAll, |
||||
PromptCacheRO: c.PromptCacheRO, |
||||
PromptCachePath: promptCachePath, |
||||
F16KV: c.F16, |
||||
DebugMode: c.Debug, |
||||
Grammar: c.Grammar, |
||||
|
||||
Mirostat: int32(c.Mirostat), |
||||
MirostatETA: float32(c.MirostatETA), |
||||
MirostatTAU: float32(c.MirostatTAU), |
||||
Debug: c.Debug, |
||||
StopPrompts: c.StopWords, |
||||
Repeat: int32(c.RepeatPenalty), |
||||
NKeep: int32(c.Keep), |
||||
Batch: int32(c.Batch), |
||||
IgnoreEOS: c.IgnoreEOS, |
||||
Seed: int32(c.Seed), |
||||
FrequencyPenalty: float32(c.FrequencyPenalty), |
||||
MLock: c.MMlock, |
||||
MMap: c.MMap, |
||||
MainGPU: c.MainGPU, |
||||
TensorSplit: c.TensorSplit, |
||||
TailFreeSamplingZ: float32(c.TFZ), |
||||
TypicalP: float32(c.TypicalP), |
||||
} |
||||
} |
||||
|
||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { |
||||
if c.Backend != model.StableDiffusionBackend { |
||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
||||
} |
||||
|
||||
inferenceModel, err := loader.BackendLoader( |
||||
model.WithBackendString(c.Backend), |
||||
model.WithAssetDir(o.assetsDestination), |
||||
model.WithThreads(uint32(c.Threads)), |
||||
model.WithModelFile(c.ImageGenerationAssets), |
||||
) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() error |
||||
switch model := inferenceModel.(type) { |
||||
case *stablediffusion.StableDiffusion: |
||||
fn = func() error { |
||||
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) |
||||
} |
||||
|
||||
default: |
||||
fn = func() error { |
||||
return fmt.Errorf("creation of images not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() error { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[c.Backend] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[c.Backend] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
return fn() |
||||
}, nil |
||||
} |
||||
|
||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) { |
||||
if !c.Embeddings { |
||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
||||
} |
||||
|
||||
modelFile := c.Model |
||||
|
||||
grpcOpts := gRPCModelOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
|
||||
opts := []model.Option{ |
||||
model.WithLoadGRPCOpts(grpcOpts), |
||||
model.WithThreads(uint32(c.Threads)), |
||||
model.WithAssetDir(o.assetsDestination), |
||||
model.WithModelFile(modelFile), |
||||
} |
||||
|
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(opts...) |
||||
} else { |
||||
opts = append(opts, model.WithBackendString(c.Backend)) |
||||
inferenceModel, err = loader.BackendLoader(opts...) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() ([]float32, error) |
||||
switch model := inferenceModel.(type) { |
||||
case *grpc.Client: |
||||
fn = func() ([]float32, error) { |
||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath) |
||||
if len(tokens) > 0 { |
||||
embeds := []int32{} |
||||
|
||||
for _, t := range tokens { |
||||
embeds = append(embeds, int32(t)) |
||||
} |
||||
predictOptions.EmbeddingTokens = embeds |
||||
|
||||
res, err := model.Embeddings(context.TODO(), predictOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return res.Embeddings, nil |
||||
} |
||||
predictOptions.Embeddings = s |
||||
|
||||
res, err := model.Embeddings(context.TODO(), predictOptions) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return res.Embeddings, nil |
||||
} |
||||
|
||||
// bert embeddings
|
||||
case *bert.Bert: |
||||
fn = func() ([]float32, error) { |
||||
if len(tokens) > 0 { |
||||
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) |
||||
} |
||||
return model.Embeddings(s, bert.SetThreads(c.Threads)) |
||||
} |
||||
default: |
||||
fn = func() ([]float32, error) { |
||||
return nil, fmt.Errorf("embeddings not supported by the backend") |
||||
} |
||||
} |
||||
|
||||
return func() ([]float32, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[modelFile] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[modelFile] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
embeds, err := fn() |
||||
if err != nil { |
||||
return embeds, err |
||||
} |
||||
// Remove trailing 0s
|
||||
for i := len(embeds) - 1; i >= 0; i-- { |
||||
if embeds[i] == 0.0 { |
||||
embeds = embeds[:i] |
||||
} else { |
||||
break |
||||
} |
||||
} |
||||
return embeds, nil |
||||
}, nil |
||||
} |
||||
|
||||
func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) { |
||||
supportStreams := false |
||||
modelFile := c.Model |
||||
|
||||
grpcOpts := gRPCModelOpts(c) |
||||
|
||||
var inferenceModel interface{} |
||||
var err error |
||||
|
||||
opts := []model.Option{ |
||||
model.WithLoadGRPCOpts(grpcOpts), |
||||
model.WithThreads(uint32(c.Threads)), // GPT4all uses this
|
||||
model.WithAssetDir(o.assetsDestination), |
||||
model.WithModelFile(modelFile), |
||||
} |
||||
|
||||
if c.Backend == "" { |
||||
inferenceModel, err = loader.GreedyLoader(opts...) |
||||
} else { |
||||
opts = append(opts, model.WithBackendString(c.Backend)) |
||||
inferenceModel, err = loader.BackendLoader(opts...) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var fn func() (string, error) |
||||
|
||||
switch model := inferenceModel.(type) { |
||||
case *rwkv.RwkvState: |
||||
supportStreams = true |
||||
|
||||
fn = func() (string, error) { |
||||
stopWord := "\n" |
||||
if len(c.StopWords) > 0 { |
||||
stopWord = c.StopWords[0] |
||||
} |
||||
|
||||
if err := model.ProcessInput(s); err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) |
||||
|
||||
return response, nil |
||||
} |
||||
case *bloomz.Bloomz: |
||||
fn = func() (string, error) { |
||||
// Generate the prediction using the language model
|
||||
predictOptions := []bloomz.PredictOption{ |
||||
bloomz.SetTemperature(c.Temperature), |
||||
bloomz.SetTopP(c.TopP), |
||||
bloomz.SetTopK(c.TopK), |
||||
bloomz.SetTokens(c.Maxtokens), |
||||
bloomz.SetThreads(c.Threads), |
||||
} |
||||
|
||||
if c.Seed != 0 { |
||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
||||
} |
||||
|
||||
return model.Predict( |
||||
s, |
||||
predictOptions..., |
||||
) |
||||
} |
||||
|
||||
case *grpc.Client: |
||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||
supportStreams = true |
||||
fn = func() (string, error) { |
||||
|
||||
opts := gRPCPredictOpts(c, loader.ModelPath) |
||||
opts.Prompt = s |
||||
if tokenCallback != nil { |
||||
ss := "" |
||||
err := model.PredictStream(context.TODO(), opts, func(s string) { |
||||
tokenCallback(s) |
||||
ss += s |
||||
}) |
||||
return ss, err |
||||
} else { |
||||
reply, err := model.Predict(context.TODO(), opts) |
||||
return reply.Message, err |
||||
} |
||||
} |
||||
case *langchain.HuggingFace: |
||||
fn = func() (string, error) { |
||||
|
||||
// Generate the prediction using the language model
|
||||
predictOptions := []langchain.PredictOption{ |
||||
langchain.SetModel(c.Model), |
||||
langchain.SetMaxTokens(c.Maxtokens), |
||||
langchain.SetTemperature(c.Temperature), |
||||
langchain.SetStopWords(c.StopWords), |
||||
} |
||||
|
||||
pred, er := model.PredictHuggingFace(s, predictOptions...) |
||||
if er != nil { |
||||
return "", er |
||||
} |
||||
return pred.Completion, nil |
||||
} |
||||
} |
||||
|
||||
return func() (string, error) { |
||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||
mutexMap.Lock() |
||||
l, ok := mutexes[modelFile] |
||||
if !ok { |
||||
m := &sync.Mutex{} |
||||
mutexes[modelFile] = m |
||||
l = m |
||||
} |
||||
mutexMap.Unlock() |
||||
l.Lock() |
||||
defer l.Unlock() |
||||
|
||||
res, err := fn() |
||||
if tokenCallback != nil && !supportStreams { |
||||
tokenCallback(res) |
||||
} |
||||
return res, err |
||||
}, nil |
||||
} |
||||
|
||||
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
||||
result := []Choice{} |
||||
|
||||
n := input.N |
||||
|
||||
if input.N == 0 { |
||||
n = 1 |
||||
} |
||||
|
||||
// get the model function to call for the result
|
||||
predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback) |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
for i := 0; i < n; i++ { |
||||
prediction, err := predFunc() |
||||
if err != nil { |
||||
return result, err |
||||
} |
||||
|
||||
prediction = Finetune(*config, predInput, prediction) |
||||
cb(prediction, &result) |
||||
|
||||
//result = append(result, Choice{Text: prediction})
|
||||
|
||||
} |
||||
return result, err |
||||
} |
||||
|
||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
||||
var mu sync.Mutex = sync.Mutex{} |
||||
|
||||
func Finetune(config Config, input, prediction string) string { |
||||
if config.Echo { |
||||
prediction = input + prediction |
||||
} |
||||
|
||||
for _, c := range config.Cutstrings { |
||||
mu.Lock() |
||||
reg, ok := cutstrings[c] |
||||
if !ok { |
||||
cutstrings[c] = regexp.MustCompile(c) |
||||
reg = cutstrings[c] |
||||
} |
||||
mu.Unlock() |
||||
prediction = reg.ReplaceAllString(prediction, "") |
||||
} |
||||
|
||||
for _, c := range config.TrimSpace { |
||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
||||
} |
||||
return prediction |
||||
|
||||
} |
Loading…
Reference in new issue