Signed-off-by: Ettore Di Giacinto <mudler@localai.io>renovate/github.com-imdario-mergo-1.x
parent
f2f1d7fe72
commit
5dcfdbe51d
@ -0,0 +1,107 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"fmt" |
||||||
|
"sync" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
bert "github.com/go-skynet/go-bert.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.Config, o *options.Option) (func() ([]float32, error), error) { |
||||||
|
if !c.Embeddings { |
||||||
|
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
||||||
|
} |
||||||
|
|
||||||
|
modelFile := c.Model |
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(c) |
||||||
|
|
||||||
|
var inferenceModel interface{} |
||||||
|
var err error |
||||||
|
|
||||||
|
opts := []model.Option{ |
||||||
|
model.WithLoadGRPCOpts(grpcOpts), |
||||||
|
model.WithThreads(uint32(c.Threads)), |
||||||
|
model.WithAssetDir(o.AssetsDestination), |
||||||
|
model.WithModelFile(modelFile), |
||||||
|
} |
||||||
|
|
||||||
|
if c.Backend == "" { |
||||||
|
inferenceModel, err = loader.GreedyLoader(opts...) |
||||||
|
} else { |
||||||
|
opts = append(opts, model.WithBackendString(c.Backend)) |
||||||
|
inferenceModel, err = loader.BackendLoader(opts...) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var fn func() ([]float32, error) |
||||||
|
switch model := inferenceModel.(type) { |
||||||
|
case *grpc.Client: |
||||||
|
fn = func() ([]float32, error) { |
||||||
|
predictOptions := gRPCPredictOpts(c, loader.ModelPath) |
||||||
|
if len(tokens) > 0 { |
||||||
|
embeds := []int32{} |
||||||
|
|
||||||
|
for _, t := range tokens { |
||||||
|
embeds = append(embeds, int32(t)) |
||||||
|
} |
||||||
|
predictOptions.EmbeddingTokens = embeds |
||||||
|
|
||||||
|
res, err := model.Embeddings(context.TODO(), predictOptions) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return res.Embeddings, nil |
||||||
|
} |
||||||
|
predictOptions.Embeddings = s |
||||||
|
|
||||||
|
res, err := model.Embeddings(context.TODO(), predictOptions) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
return res.Embeddings, nil |
||||||
|
} |
||||||
|
|
||||||
|
// bert embeddings
|
||||||
|
case *bert.Bert: |
||||||
|
fn = func() ([]float32, error) { |
||||||
|
if len(tokens) > 0 { |
||||||
|
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) |
||||||
|
} |
||||||
|
return model.Embeddings(s, bert.SetThreads(c.Threads)) |
||||||
|
} |
||||||
|
default: |
||||||
|
fn = func() ([]float32, error) { |
||||||
|
return nil, fmt.Errorf("embeddings not supported by the backend") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return func() ([]float32, error) { |
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
l := Lock(modelFile) |
||||||
|
defer l.Unlock() |
||||||
|
|
||||||
|
embeds, err := fn() |
||||||
|
if err != nil { |
||||||
|
return embeds, err |
||||||
|
} |
||||||
|
// Remove trailing 0s
|
||||||
|
for i := len(embeds) - 1; i >= 0; i-- { |
||||||
|
if embeds[i] == 0.0 { |
||||||
|
embeds = embeds[:i] |
||||||
|
} else { |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
return embeds, nil |
||||||
|
}, nil |
||||||
|
} |
@ -0,0 +1,56 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"sync" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
||||||
|
) |
||||||
|
|
||||||
|
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c config.Config, o *options.Option) (func() error, error) { |
||||||
|
if c.Backend != model.StableDiffusionBackend { |
||||||
|
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
||||||
|
} |
||||||
|
|
||||||
|
inferenceModel, err := loader.BackendLoader( |
||||||
|
model.WithBackendString(c.Backend), |
||||||
|
model.WithAssetDir(o.AssetsDestination), |
||||||
|
model.WithThreads(uint32(c.Threads)), |
||||||
|
model.WithModelFile(c.ImageGenerationAssets), |
||||||
|
) |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var fn func() error |
||||||
|
switch model := inferenceModel.(type) { |
||||||
|
case *stablediffusion.StableDiffusion: |
||||||
|
fn = func() error { |
||||||
|
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) |
||||||
|
} |
||||||
|
|
||||||
|
default: |
||||||
|
fn = func() error { |
||||||
|
return fmt.Errorf("creation of images not supported by the backend") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return func() error { |
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
mutexMap.Lock() |
||||||
|
l, ok := mutexes[c.Backend] |
||||||
|
if !ok { |
||||||
|
m := &sync.Mutex{} |
||||||
|
mutexes[c.Backend] = m |
||||||
|
l = m |
||||||
|
} |
||||||
|
mutexMap.Unlock() |
||||||
|
l.Lock() |
||||||
|
defer l.Unlock() |
||||||
|
|
||||||
|
return fn() |
||||||
|
}, nil |
||||||
|
} |
@ -0,0 +1,160 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import ( |
||||||
|
"context" |
||||||
|
"regexp" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"github.com/donomii/go-rwkv.cpp" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grpc" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/langchain" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/go-skynet/bloomz.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) { |
||||||
|
supportStreams := false |
||||||
|
modelFile := c.Model |
||||||
|
|
||||||
|
grpcOpts := gRPCModelOpts(c) |
||||||
|
|
||||||
|
var inferenceModel interface{} |
||||||
|
var err error |
||||||
|
|
||||||
|
opts := []model.Option{ |
||||||
|
model.WithLoadGRPCOpts(grpcOpts), |
||||||
|
model.WithThreads(uint32(c.Threads)), // GPT4all uses this
|
||||||
|
model.WithAssetDir(o.AssetsDestination), |
||||||
|
model.WithModelFile(modelFile), |
||||||
|
} |
||||||
|
|
||||||
|
if c.Backend == "" { |
||||||
|
inferenceModel, err = loader.GreedyLoader(opts...) |
||||||
|
} else { |
||||||
|
opts = append(opts, model.WithBackendString(c.Backend)) |
||||||
|
inferenceModel, err = loader.BackendLoader(opts...) |
||||||
|
} |
||||||
|
if err != nil { |
||||||
|
return nil, err |
||||||
|
} |
||||||
|
|
||||||
|
var fn func() (string, error) |
||||||
|
|
||||||
|
switch model := inferenceModel.(type) { |
||||||
|
case *rwkv.RwkvState: |
||||||
|
supportStreams = true |
||||||
|
|
||||||
|
fn = func() (string, error) { |
||||||
|
stopWord := "\n" |
||||||
|
if len(c.StopWords) > 0 { |
||||||
|
stopWord = c.StopWords[0] |
||||||
|
} |
||||||
|
|
||||||
|
if err := model.ProcessInput(s); err != nil { |
||||||
|
return "", err |
||||||
|
} |
||||||
|
|
||||||
|
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) |
||||||
|
|
||||||
|
return response, nil |
||||||
|
} |
||||||
|
case *bloomz.Bloomz: |
||||||
|
fn = func() (string, error) { |
||||||
|
// Generate the prediction using the language model
|
||||||
|
predictOptions := []bloomz.PredictOption{ |
||||||
|
bloomz.SetTemperature(c.Temperature), |
||||||
|
bloomz.SetTopP(c.TopP), |
||||||
|
bloomz.SetTopK(c.TopK), |
||||||
|
bloomz.SetTokens(c.Maxtokens), |
||||||
|
bloomz.SetThreads(c.Threads), |
||||||
|
} |
||||||
|
|
||||||
|
if c.Seed != 0 { |
||||||
|
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
||||||
|
} |
||||||
|
|
||||||
|
return model.Predict( |
||||||
|
s, |
||||||
|
predictOptions..., |
||||||
|
) |
||||||
|
} |
||||||
|
|
||||||
|
case *grpc.Client: |
||||||
|
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
||||||
|
supportStreams = true |
||||||
|
fn = func() (string, error) { |
||||||
|
|
||||||
|
opts := gRPCPredictOpts(c, loader.ModelPath) |
||||||
|
opts.Prompt = s |
||||||
|
if tokenCallback != nil { |
||||||
|
ss := "" |
||||||
|
err := model.PredictStream(context.TODO(), opts, func(s string) { |
||||||
|
tokenCallback(s) |
||||||
|
ss += s |
||||||
|
}) |
||||||
|
return ss, err |
||||||
|
} else { |
||||||
|
reply, err := model.Predict(context.TODO(), opts) |
||||||
|
return reply.Message, err |
||||||
|
} |
||||||
|
} |
||||||
|
case *langchain.HuggingFace: |
||||||
|
fn = func() (string, error) { |
||||||
|
|
||||||
|
// Generate the prediction using the language model
|
||||||
|
predictOptions := []langchain.PredictOption{ |
||||||
|
langchain.SetModel(c.Model), |
||||||
|
langchain.SetMaxTokens(c.Maxtokens), |
||||||
|
langchain.SetTemperature(c.Temperature), |
||||||
|
langchain.SetStopWords(c.StopWords), |
||||||
|
} |
||||||
|
|
||||||
|
pred, er := model.PredictHuggingFace(s, predictOptions...) |
||||||
|
if er != nil { |
||||||
|
return "", er |
||||||
|
} |
||||||
|
return pred.Completion, nil |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return func() (string, error) { |
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
l := Lock(modelFile) |
||||||
|
defer l.Unlock() |
||||||
|
|
||||||
|
res, err := fn() |
||||||
|
if tokenCallback != nil && !supportStreams { |
||||||
|
tokenCallback(res) |
||||||
|
} |
||||||
|
return res, err |
||||||
|
}, nil |
||||||
|
} |
||||||
|
|
||||||
|
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
||||||
|
var mu sync.Mutex = sync.Mutex{} |
||||||
|
|
||||||
|
func Finetune(config config.Config, input, prediction string) string { |
||||||
|
if config.Echo { |
||||||
|
prediction = input + prediction |
||||||
|
} |
||||||
|
|
||||||
|
for _, c := range config.Cutstrings { |
||||||
|
mu.Lock() |
||||||
|
reg, ok := cutstrings[c] |
||||||
|
if !ok { |
||||||
|
cutstrings[c] = regexp.MustCompile(c) |
||||||
|
reg = cutstrings[c] |
||||||
|
} |
||||||
|
mu.Unlock() |
||||||
|
prediction = reg.ReplaceAllString(prediction, "") |
||||||
|
} |
||||||
|
|
||||||
|
for _, c := range config.TrimSpace { |
||||||
|
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
||||||
|
} |
||||||
|
return prediction |
||||||
|
|
||||||
|
} |
@ -0,0 +1,22 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import "sync" |
||||||
|
|
||||||
|
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
var mutexMap sync.Mutex |
||||||
|
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
||||||
|
|
||||||
|
func Lock(s string) *sync.Mutex { |
||||||
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
||||||
|
mutexMap.Lock() |
||||||
|
l, ok := mutexes[s] |
||||||
|
if !ok { |
||||||
|
m := &sync.Mutex{} |
||||||
|
mutexes[s] = m |
||||||
|
l = m |
||||||
|
} |
||||||
|
mutexMap.Unlock() |
||||||
|
l.Lock() |
||||||
|
|
||||||
|
return l |
||||||
|
} |
@ -0,0 +1,98 @@ |
|||||||
|
package backend |
||||||
|
|
||||||
|
import ( |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/langchain" |
||||||
|
"github.com/go-skynet/bloomz.cpp" |
||||||
|
) |
||||||
|
|
||||||
|
func langchainOptions(c config.Config) []langchain.PredictOption { |
||||||
|
return []langchain.PredictOption{ |
||||||
|
langchain.SetModel(c.Model), |
||||||
|
langchain.SetMaxTokens(c.Maxtokens), |
||||||
|
langchain.SetTemperature(c.Temperature), |
||||||
|
langchain.SetStopWords(c.StopWords), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func bloomzOptions(c config.Config) []bloomz.PredictOption { |
||||||
|
// Generate the prediction using the language model
|
||||||
|
predictOptions := []bloomz.PredictOption{ |
||||||
|
bloomz.SetTemperature(c.Temperature), |
||||||
|
bloomz.SetTopP(c.TopP), |
||||||
|
bloomz.SetTopK(c.TopK), |
||||||
|
bloomz.SetTokens(c.Maxtokens), |
||||||
|
bloomz.SetThreads(c.Threads), |
||||||
|
} |
||||||
|
|
||||||
|
if c.Seed != 0 { |
||||||
|
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
||||||
|
} |
||||||
|
return predictOptions |
||||||
|
} |
||||||
|
func gRPCModelOpts(c config.Config) *pb.ModelOptions { |
||||||
|
b := 512 |
||||||
|
if c.Batch != 0 { |
||||||
|
b = c.Batch |
||||||
|
} |
||||||
|
return &pb.ModelOptions{ |
||||||
|
ContextSize: int32(c.ContextSize), |
||||||
|
Seed: int32(c.Seed), |
||||||
|
NBatch: int32(b), |
||||||
|
F16Memory: c.F16, |
||||||
|
MLock: c.MMlock, |
||||||
|
NUMA: c.NUMA, |
||||||
|
Embeddings: c.Embeddings, |
||||||
|
LowVRAM: c.LowVRAM, |
||||||
|
NGPULayers: int32(c.NGPULayers), |
||||||
|
MMap: c.MMap, |
||||||
|
MainGPU: c.MainGPU, |
||||||
|
Threads: int32(c.Threads), |
||||||
|
TensorSplit: c.TensorSplit, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func gRPCPredictOpts(c config.Config, modelPath string) *pb.PredictOptions { |
||||||
|
promptCachePath := "" |
||||||
|
if c.PromptCachePath != "" { |
||||||
|
p := filepath.Join(modelPath, c.PromptCachePath) |
||||||
|
os.MkdirAll(filepath.Dir(p), 0755) |
||||||
|
promptCachePath = p |
||||||
|
} |
||||||
|
return &pb.PredictOptions{ |
||||||
|
Temperature: float32(c.Temperature), |
||||||
|
TopP: float32(c.TopP), |
||||||
|
TopK: int32(c.TopK), |
||||||
|
Tokens: int32(c.Maxtokens), |
||||||
|
Threads: int32(c.Threads), |
||||||
|
PromptCacheAll: c.PromptCacheAll, |
||||||
|
PromptCacheRO: c.PromptCacheRO, |
||||||
|
PromptCachePath: promptCachePath, |
||||||
|
F16KV: c.F16, |
||||||
|
DebugMode: c.Debug, |
||||||
|
Grammar: c.Grammar, |
||||||
|
|
||||||
|
Mirostat: int32(c.Mirostat), |
||||||
|
MirostatETA: float32(c.MirostatETA), |
||||||
|
MirostatTAU: float32(c.MirostatTAU), |
||||||
|
Debug: c.Debug, |
||||||
|
StopPrompts: c.StopWords, |
||||||
|
Repeat: int32(c.RepeatPenalty), |
||||||
|
NKeep: int32(c.Keep), |
||||||
|
Batch: int32(c.Batch), |
||||||
|
IgnoreEOS: c.IgnoreEOS, |
||||||
|
Seed: int32(c.Seed), |
||||||
|
FrequencyPenalty: float32(c.FrequencyPenalty), |
||||||
|
MLock: c.MMlock, |
||||||
|
MMap: c.MMap, |
||||||
|
MainGPU: c.MainGPU, |
||||||
|
TensorSplit: c.TensorSplit, |
||||||
|
TailFreeSamplingZ: float32(c.TFZ), |
||||||
|
TypicalP: float32(c.TypicalP), |
||||||
|
} |
||||||
|
} |
@ -1,401 +0,0 @@ |
|||||||
package api |
|
||||||
|
|
||||||
import ( |
|
||||||
"encoding/json" |
|
||||||
"fmt" |
|
||||||
"io/fs" |
|
||||||
"os" |
|
||||||
"path/filepath" |
|
||||||
"strings" |
|
||||||
"sync" |
|
||||||
|
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model" |
|
||||||
"github.com/gofiber/fiber/v2" |
|
||||||
"github.com/rs/zerolog/log" |
|
||||||
"gopkg.in/yaml.v3" |
|
||||||
) |
|
||||||
|
|
||||||
type Config struct { |
|
||||||
OpenAIRequest `yaml:"parameters"` |
|
||||||
Name string `yaml:"name"` |
|
||||||
StopWords []string `yaml:"stopwords"` |
|
||||||
Cutstrings []string `yaml:"cutstrings"` |
|
||||||
TrimSpace []string `yaml:"trimspace"` |
|
||||||
ContextSize int `yaml:"context_size"` |
|
||||||
F16 bool `yaml:"f16"` |
|
||||||
NUMA bool `yaml:"numa"` |
|
||||||
Threads int `yaml:"threads"` |
|
||||||
Debug bool `yaml:"debug"` |
|
||||||
Roles map[string]string `yaml:"roles"` |
|
||||||
Embeddings bool `yaml:"embeddings"` |
|
||||||
Backend string `yaml:"backend"` |
|
||||||
TemplateConfig TemplateConfig `yaml:"template"` |
|
||||||
MirostatETA float64 `yaml:"mirostat_eta"` |
|
||||||
MirostatTAU float64 `yaml:"mirostat_tau"` |
|
||||||
Mirostat int `yaml:"mirostat"` |
|
||||||
NGPULayers int `yaml:"gpu_layers"` |
|
||||||
MMap bool `yaml:"mmap"` |
|
||||||
MMlock bool `yaml:"mmlock"` |
|
||||||
LowVRAM bool `yaml:"low_vram"` |
|
||||||
|
|
||||||
TensorSplit string `yaml:"tensor_split"` |
|
||||||
MainGPU string `yaml:"main_gpu"` |
|
||||||
ImageGenerationAssets string `yaml:"asset_dir"` |
|
||||||
|
|
||||||
PromptCachePath string `yaml:"prompt_cache_path"` |
|
||||||
PromptCacheAll bool `yaml:"prompt_cache_all"` |
|
||||||
PromptCacheRO bool `yaml:"prompt_cache_ro"` |
|
||||||
|
|
||||||
Grammar string `yaml:"grammar"` |
|
||||||
|
|
||||||
FunctionsConfig Functions `yaml:"function"` |
|
||||||
|
|
||||||
PromptStrings, InputStrings []string |
|
||||||
InputToken [][]int |
|
||||||
functionCallString, functionCallNameString string |
|
||||||
} |
|
||||||
|
|
||||||
type Functions struct { |
|
||||||
DisableNoAction bool `yaml:"disable_no_action"` |
|
||||||
NoActionFunctionName string `yaml:"no_action_function_name"` |
|
||||||
NoActionDescriptionName string `yaml:"no_action_description_name"` |
|
||||||
} |
|
||||||
|
|
||||||
type TemplateConfig struct { |
|
||||||
Completion string `yaml:"completion"` |
|
||||||
Functions string `yaml:"function"` |
|
||||||
Chat string `yaml:"chat"` |
|
||||||
Edit string `yaml:"edit"` |
|
||||||
} |
|
||||||
|
|
||||||
type ConfigMerger struct { |
|
||||||
configs map[string]Config |
|
||||||
sync.Mutex |
|
||||||
} |
|
||||||
|
|
||||||
func defaultConfig(modelFile string) *Config { |
|
||||||
return &Config{ |
|
||||||
OpenAIRequest: defaultRequest(modelFile), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func NewConfigMerger() *ConfigMerger { |
|
||||||
return &ConfigMerger{ |
|
||||||
configs: make(map[string]Config), |
|
||||||
} |
|
||||||
} |
|
||||||
func ReadConfigFile(file string) ([]*Config, error) { |
|
||||||
c := &[]*Config{} |
|
||||||
f, err := os.ReadFile(file) |
|
||||||
if err != nil { |
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
|
||||||
} |
|
||||||
if err := yaml.Unmarshal(f, c); err != nil { |
|
||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
return *c, nil |
|
||||||
} |
|
||||||
|
|
||||||
func ReadConfig(file string) (*Config, error) { |
|
||||||
c := &Config{} |
|
||||||
f, err := os.ReadFile(file) |
|
||||||
if err != nil { |
|
||||||
return nil, fmt.Errorf("cannot read config file: %w", err) |
|
||||||
} |
|
||||||
if err := yaml.Unmarshal(f, c); err != nil { |
|
||||||
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
return c, nil |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfigFile(file string) error { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
c, err := ReadConfigFile(file) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("cannot load config file: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
for _, cc := range c { |
|
||||||
cm.configs[cc.Name] = *cc |
|
||||||
} |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfig(file string) error { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
c, err := ReadConfig(file) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("cannot read config file: %w", err) |
|
||||||
} |
|
||||||
|
|
||||||
cm.configs[c.Name] = *c |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) GetConfig(m string) (Config, bool) { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
v, exists := cm.configs[m] |
|
||||||
return v, exists |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) ListConfigs() []string { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
var res []string |
|
||||||
for k := range cm.configs { |
|
||||||
res = append(res, k) |
|
||||||
} |
|
||||||
return res |
|
||||||
} |
|
||||||
|
|
||||||
func (cm *ConfigMerger) LoadConfigs(path string) error { |
|
||||||
cm.Lock() |
|
||||||
defer cm.Unlock() |
|
||||||
entries, err := os.ReadDir(path) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
files := make([]fs.FileInfo, 0, len(entries)) |
|
||||||
for _, entry := range entries { |
|
||||||
info, err := entry.Info() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
files = append(files, info) |
|
||||||
} |
|
||||||
for _, file := range files { |
|
||||||
// Skip templates, YAML and .keep files
|
|
||||||
if !strings.Contains(file.Name(), ".yaml") { |
|
||||||
continue |
|
||||||
} |
|
||||||
c, err := ReadConfig(filepath.Join(path, file.Name())) |
|
||||||
if err == nil { |
|
||||||
cm.configs[c.Name] = *c |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
func updateConfig(config *Config, input *OpenAIRequest) { |
|
||||||
if input.Echo { |
|
||||||
config.Echo = input.Echo |
|
||||||
} |
|
||||||
if input.TopK != 0 { |
|
||||||
config.TopK = input.TopK |
|
||||||
} |
|
||||||
if input.TopP != 0 { |
|
||||||
config.TopP = input.TopP |
|
||||||
} |
|
||||||
|
|
||||||
if input.Grammar != "" { |
|
||||||
config.Grammar = input.Grammar |
|
||||||
} |
|
||||||
|
|
||||||
if input.Temperature != 0 { |
|
||||||
config.Temperature = input.Temperature |
|
||||||
} |
|
||||||
|
|
||||||
if input.Maxtokens != 0 { |
|
||||||
config.Maxtokens = input.Maxtokens |
|
||||||
} |
|
||||||
|
|
||||||
switch stop := input.Stop.(type) { |
|
||||||
case string: |
|
||||||
if stop != "" { |
|
||||||
config.StopWords = append(config.StopWords, stop) |
|
||||||
} |
|
||||||
case []interface{}: |
|
||||||
for _, pp := range stop { |
|
||||||
if s, ok := pp.(string); ok { |
|
||||||
config.StopWords = append(config.StopWords, s) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
if input.RepeatPenalty != 0 { |
|
||||||
config.RepeatPenalty = input.RepeatPenalty |
|
||||||
} |
|
||||||
|
|
||||||
if input.Keep != 0 { |
|
||||||
config.Keep = input.Keep |
|
||||||
} |
|
||||||
|
|
||||||
if input.Batch != 0 { |
|
||||||
config.Batch = input.Batch |
|
||||||
} |
|
||||||
|
|
||||||
if input.F16 { |
|
||||||
config.F16 = input.F16 |
|
||||||
} |
|
||||||
|
|
||||||
if input.IgnoreEOS { |
|
||||||
config.IgnoreEOS = input.IgnoreEOS |
|
||||||
} |
|
||||||
|
|
||||||
if input.Seed != 0 { |
|
||||||
config.Seed = input.Seed |
|
||||||
} |
|
||||||
|
|
||||||
if input.Mirostat != 0 { |
|
||||||
config.Mirostat = input.Mirostat |
|
||||||
} |
|
||||||
|
|
||||||
if input.MirostatETA != 0 { |
|
||||||
config.MirostatETA = input.MirostatETA |
|
||||||
} |
|
||||||
|
|
||||||
if input.MirostatTAU != 0 { |
|
||||||
config.MirostatTAU = input.MirostatTAU |
|
||||||
} |
|
||||||
|
|
||||||
if input.TypicalP != 0 { |
|
||||||
config.TypicalP = input.TypicalP |
|
||||||
} |
|
||||||
|
|
||||||
switch inputs := input.Input.(type) { |
|
||||||
case string: |
|
||||||
if inputs != "" { |
|
||||||
config.InputStrings = append(config.InputStrings, inputs) |
|
||||||
} |
|
||||||
case []interface{}: |
|
||||||
for _, pp := range inputs { |
|
||||||
switch i := pp.(type) { |
|
||||||
case string: |
|
||||||
config.InputStrings = append(config.InputStrings, i) |
|
||||||
case []interface{}: |
|
||||||
tokens := []int{} |
|
||||||
for _, ii := range i { |
|
||||||
tokens = append(tokens, int(ii.(float64))) |
|
||||||
} |
|
||||||
config.InputToken = append(config.InputToken, tokens) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
// Can be either a string or an object
|
|
||||||
switch fnc := input.FunctionCall.(type) { |
|
||||||
case string: |
|
||||||
if fnc != "" { |
|
||||||
config.functionCallString = fnc |
|
||||||
} |
|
||||||
case map[string]interface{}: |
|
||||||
var name string |
|
||||||
n, exists := fnc["name"] |
|
||||||
if exists { |
|
||||||
nn, e := n.(string) |
|
||||||
if e { |
|
||||||
name = nn |
|
||||||
} |
|
||||||
} |
|
||||||
config.functionCallNameString = name |
|
||||||
} |
|
||||||
|
|
||||||
switch p := input.Prompt.(type) { |
|
||||||
case string: |
|
||||||
config.PromptStrings = append(config.PromptStrings, p) |
|
||||||
case []interface{}: |
|
||||||
for _, pp := range p { |
|
||||||
if s, ok := pp.(string); ok { |
|
||||||
config.PromptStrings = append(config.PromptStrings, s) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
|
||||||
input := new(OpenAIRequest) |
|
||||||
// Get input data from the request body
|
|
||||||
if err := c.BodyParser(input); err != nil { |
|
||||||
return "", nil, err |
|
||||||
} |
|
||||||
|
|
||||||
modelFile := input.Model |
|
||||||
|
|
||||||
if c.Params("model") != "" { |
|
||||||
modelFile = c.Params("model") |
|
||||||
} |
|
||||||
|
|
||||||
received, _ := json.Marshal(input) |
|
||||||
|
|
||||||
log.Debug().Msgf("Request received: %s", string(received)) |
|
||||||
|
|
||||||
// Set model from bearer token, if available
|
|
||||||
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
|
||||||
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
|
||||||
|
|
||||||
// If no model was specified, take the first available
|
|
||||||
if modelFile == "" && !bearerExists && randomModel { |
|
||||||
models, _ := loader.ListModels() |
|
||||||
if len(models) > 0 { |
|
||||||
modelFile = models[0] |
|
||||||
log.Debug().Msgf("No model specified, using: %s", modelFile) |
|
||||||
} else { |
|
||||||
log.Debug().Msgf("No model specified, returning error") |
|
||||||
return "", nil, fmt.Errorf("no model specified") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// If a model is found in bearer token takes precedence
|
|
||||||
if bearerExists { |
|
||||||
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
|
||||||
modelFile = bearer |
|
||||||
} |
|
||||||
return modelFile, input, nil |
|
||||||
} |
|
||||||
|
|
||||||
func readConfig(modelFile string, input *OpenAIRequest, cm *ConfigMerger, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { |
|
||||||
// Load a config file if present after the model name
|
|
||||||
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
|
||||||
|
|
||||||
var config *Config |
|
||||||
|
|
||||||
defaults := func() { |
|
||||||
config = defaultConfig(modelFile) |
|
||||||
config.ContextSize = ctx |
|
||||||
config.Threads = threads |
|
||||||
config.F16 = f16 |
|
||||||
config.Debug = debug |
|
||||||
} |
|
||||||
|
|
||||||
cfg, exists := cm.GetConfig(modelFile) |
|
||||||
if !exists { |
|
||||||
if _, err := os.Stat(modelConfig); err == nil { |
|
||||||
if err := cm.LoadConfig(modelConfig); err != nil { |
|
||||||
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
|
||||||
} |
|
||||||
cfg, exists = cm.GetConfig(modelFile) |
|
||||||
if exists { |
|
||||||
config = &cfg |
|
||||||
} else { |
|
||||||
defaults() |
|
||||||
} |
|
||||||
} else { |
|
||||||
defaults() |
|
||||||
} |
|
||||||
} else { |
|
||||||
config = &cfg |
|
||||||
} |
|
||||||
|
|
||||||
// Set the parameters for the language model prediction
|
|
||||||
updateConfig(config, input) |
|
||||||
|
|
||||||
// Don't allow 0 as setting
|
|
||||||
if config.Threads == 0 { |
|
||||||
if threads != 0 { |
|
||||||
config.Threads = threads |
|
||||||
} else { |
|
||||||
config.Threads = 4 |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// Enforce debug flag if passed from CLI
|
|
||||||
if debug { |
|
||||||
config.Debug = true |
|
||||||
} |
|
||||||
|
|
||||||
return config, input, nil |
|
||||||
} |
|
@ -0,0 +1,209 @@ |
|||||||
|
package api_config |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"io/fs" |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
"strings" |
||||||
|
"sync" |
||||||
|
|
||||||
|
"gopkg.in/yaml.v3" |
||||||
|
) |
||||||
|
|
||||||
|
type Config struct { |
||||||
|
PredictionOptions `yaml:"parameters"` |
||||||
|
Name string `yaml:"name"` |
||||||
|
StopWords []string `yaml:"stopwords"` |
||||||
|
Cutstrings []string `yaml:"cutstrings"` |
||||||
|
TrimSpace []string `yaml:"trimspace"` |
||||||
|
ContextSize int `yaml:"context_size"` |
||||||
|
F16 bool `yaml:"f16"` |
||||||
|
NUMA bool `yaml:"numa"` |
||||||
|
Threads int `yaml:"threads"` |
||||||
|
Debug bool `yaml:"debug"` |
||||||
|
Roles map[string]string `yaml:"roles"` |
||||||
|
Embeddings bool `yaml:"embeddings"` |
||||||
|
Backend string `yaml:"backend"` |
||||||
|
TemplateConfig TemplateConfig `yaml:"template"` |
||||||
|
MirostatETA float64 `yaml:"mirostat_eta"` |
||||||
|
MirostatTAU float64 `yaml:"mirostat_tau"` |
||||||
|
Mirostat int `yaml:"mirostat"` |
||||||
|
NGPULayers int `yaml:"gpu_layers"` |
||||||
|
MMap bool `yaml:"mmap"` |
||||||
|
MMlock bool `yaml:"mmlock"` |
||||||
|
LowVRAM bool `yaml:"low_vram"` |
||||||
|
|
||||||
|
TensorSplit string `yaml:"tensor_split"` |
||||||
|
MainGPU string `yaml:"main_gpu"` |
||||||
|
ImageGenerationAssets string `yaml:"asset_dir"` |
||||||
|
|
||||||
|
PromptCachePath string `yaml:"prompt_cache_path"` |
||||||
|
PromptCacheAll bool `yaml:"prompt_cache_all"` |
||||||
|
PromptCacheRO bool `yaml:"prompt_cache_ro"` |
||||||
|
|
||||||
|
Grammar string `yaml:"grammar"` |
||||||
|
|
||||||
|
PromptStrings, InputStrings []string |
||||||
|
InputToken [][]int |
||||||
|
functionCallString, functionCallNameString string |
||||||
|
|
||||||
|
FunctionsConfig Functions `yaml:"function"` |
||||||
|
} |
||||||
|
|
||||||
|
type Functions struct { |
||||||
|
DisableNoAction bool `yaml:"disable_no_action"` |
||||||
|
NoActionFunctionName string `yaml:"no_action_function_name"` |
||||||
|
NoActionDescriptionName string `yaml:"no_action_description_name"` |
||||||
|
} |
||||||
|
|
||||||
|
type TemplateConfig struct { |
||||||
|
Completion string `yaml:"completion"` |
||||||
|
Functions string `yaml:"function"` |
||||||
|
Chat string `yaml:"chat"` |
||||||
|
Edit string `yaml:"edit"` |
||||||
|
} |
||||||
|
|
||||||
|
type ConfigLoader struct { |
||||||
|
configs map[string]Config |
||||||
|
sync.Mutex |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) SetFunctionCallString(s string) { |
||||||
|
c.functionCallString = s |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) SetFunctionCallNameString(s string) { |
||||||
|
c.functionCallNameString = s |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) ShouldUseFunctions() bool { |
||||||
|
return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) ShouldCallSpecificFunction() bool { |
||||||
|
return len(c.functionCallNameString) > 0 |
||||||
|
} |
||||||
|
|
||||||
|
func (c *Config) FunctionToCall() string { |
||||||
|
return c.functionCallNameString |
||||||
|
} |
||||||
|
|
||||||
|
func defaultPredictOptions(modelFile string) PredictionOptions { |
||||||
|
return PredictionOptions{ |
||||||
|
TopP: 0.7, |
||||||
|
TopK: 80, |
||||||
|
Maxtokens: 512, |
||||||
|
Temperature: 0.9, |
||||||
|
Model: modelFile, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func DefaultConfig(modelFile string) *Config { |
||||||
|
return &Config{ |
||||||
|
PredictionOptions: defaultPredictOptions(modelFile), |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func NewConfigLoader() *ConfigLoader { |
||||||
|
return &ConfigLoader{ |
||||||
|
configs: make(map[string]Config), |
||||||
|
} |
||||||
|
} |
||||||
|
func ReadConfigFile(file string) ([]*Config, error) { |
||||||
|
c := &[]*Config{} |
||||||
|
f, err := os.ReadFile(file) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||||
|
} |
||||||
|
if err := yaml.Unmarshal(f, c); err != nil { |
||||||
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
return *c, nil |
||||||
|
} |
||||||
|
|
||||||
|
func ReadConfig(file string) (*Config, error) { |
||||||
|
c := &Config{} |
||||||
|
f, err := os.ReadFile(file) |
||||||
|
if err != nil { |
||||||
|
return nil, fmt.Errorf("cannot read config file: %w", err) |
||||||
|
} |
||||||
|
if err := yaml.Unmarshal(f, c); err != nil { |
||||||
|
return nil, fmt.Errorf("cannot unmarshal config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
return c, nil |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfigFile(file string) error { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
c, err := ReadConfigFile(file) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("cannot load config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
for _, cc := range c { |
||||||
|
cm.configs[cc.Name] = *cc |
||||||
|
} |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfig(file string) error { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
c, err := ReadConfig(file) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("cannot read config file: %w", err) |
||||||
|
} |
||||||
|
|
||||||
|
cm.configs[c.Name] = *c |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) GetConfig(m string) (Config, bool) { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
v, exists := cm.configs[m] |
||||||
|
return v, exists |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) ListConfigs() []string { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
var res []string |
||||||
|
for k := range cm.configs { |
||||||
|
res = append(res, k) |
||||||
|
} |
||||||
|
return res |
||||||
|
} |
||||||
|
|
||||||
|
func (cm *ConfigLoader) LoadConfigs(path string) error { |
||||||
|
cm.Lock() |
||||||
|
defer cm.Unlock() |
||||||
|
entries, err := os.ReadDir(path) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
files := make([]fs.FileInfo, 0, len(entries)) |
||||||
|
for _, entry := range entries { |
||||||
|
info, err := entry.Info() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
files = append(files, info) |
||||||
|
} |
||||||
|
for _, file := range files { |
||||||
|
// Skip templates, YAML and .keep files
|
||||||
|
if !strings.Contains(file.Name(), ".yaml") { |
||||||
|
continue |
||||||
|
} |
||||||
|
c, err := ReadConfig(filepath.Join(path, file.Name())) |
||||||
|
if err == nil { |
||||||
|
cm.configs[c.Name] = *c |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return nil |
||||||
|
} |
@ -0,0 +1,37 @@ |
|||||||
|
package api_config |
||||||
|
|
||||||
|
type PredictionOptions struct { |
||||||
|
|
||||||
|
// Also part of the OpenAI official spec
|
||||||
|
Model string `json:"model" yaml:"model"` |
||||||
|
|
||||||
|
// Also part of the OpenAI official spec
|
||||||
|
Language string `json:"language"` |
||||||
|
|
||||||
|
// Also part of the OpenAI official spec. use it for returning multiple results
|
||||||
|
N int `json:"n"` |
||||||
|
|
||||||
|
// Common options between all the API calls, part of the OpenAI spec
|
||||||
|
TopP float64 `json:"top_p" yaml:"top_p"` |
||||||
|
TopK int `json:"top_k" yaml:"top_k"` |
||||||
|
Temperature float64 `json:"temperature" yaml:"temperature"` |
||||||
|
Maxtokens int `json:"max_tokens" yaml:"max_tokens"` |
||||||
|
Echo bool `json:"echo"` |
||||||
|
|
||||||
|
// Custom parameters - not present in the OpenAI API
|
||||||
|
Batch int `json:"batch" yaml:"batch"` |
||||||
|
F16 bool `json:"f16" yaml:"f16"` |
||||||
|
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` |
||||||
|
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` |
||||||
|
Keep int `json:"n_keep" yaml:"n_keep"` |
||||||
|
|
||||||
|
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` |
||||||
|
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` |
||||||
|
Mirostat int `json:"mirostat" yaml:"mirostat"` |
||||||
|
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` |
||||||
|
TFZ float64 `json:"tfz" yaml:"tfz"` |
||||||
|
|
||||||
|
TypicalP float64 `json:"typical_p" yaml:"typical_p"` |
||||||
|
Seed int `json:"seed" yaml:"seed"` |
||||||
|
} |
@ -1,973 +0,0 @@ |
|||||||
package api |
|
||||||
|
|
||||||
import ( |
|
||||||
"bufio" |
|
||||||
"bytes" |
|
||||||
"encoding/base64" |
|
||||||
"encoding/json" |
|
||||||
"errors" |
|
||||||
"fmt" |
|
||||||
"io" |
|
||||||
"io/ioutil" |
|
||||||
"net/http" |
|
||||||
"os" |
|
||||||
"path" |
|
||||||
"path/filepath" |
|
||||||
"strconv" |
|
||||||
"strings" |
|
||||||
|
|
||||||
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/grammar" |
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model" |
|
||||||
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" |
|
||||||
"github.com/gofiber/fiber/v2" |
|
||||||
"github.com/rs/zerolog/log" |
|
||||||
"github.com/valyala/fasthttp" |
|
||||||
) |
|
||||||
|
|
||||||
// APIError provides error information returned by the OpenAI API.
|
|
||||||
type APIError struct { |
|
||||||
Code any `json:"code,omitempty"` |
|
||||||
Message string `json:"message"` |
|
||||||
Param *string `json:"param,omitempty"` |
|
||||||
Type string `json:"type"` |
|
||||||
} |
|
||||||
|
|
||||||
type ErrorResponse struct { |
|
||||||
Error *APIError `json:"error,omitempty"` |
|
||||||
} |
|
||||||
|
|
||||||
type OpenAIUsage struct { |
|
||||||
PromptTokens int `json:"prompt_tokens"` |
|
||||||
CompletionTokens int `json:"completion_tokens"` |
|
||||||
TotalTokens int `json:"total_tokens"` |
|
||||||
} |
|
||||||
|
|
||||||
type Item struct { |
|
||||||
Embedding []float32 `json:"embedding"` |
|
||||||
Index int `json:"index"` |
|
||||||
Object string `json:"object,omitempty"` |
|
||||||
|
|
||||||
// Images
|
|
||||||
URL string `json:"url,omitempty"` |
|
||||||
B64JSON string `json:"b64_json,omitempty"` |
|
||||||
} |
|
||||||
|
|
||||||
type OpenAIResponse struct { |
|
||||||
Created int `json:"created,omitempty"` |
|
||||||
Object string `json:"object,omitempty"` |
|
||||||
ID string `json:"id,omitempty"` |
|
||||||
Model string `json:"model,omitempty"` |
|
||||||
Choices []Choice `json:"choices,omitempty"` |
|
||||||
Data []Item `json:"data,omitempty"` |
|
||||||
|
|
||||||
Usage OpenAIUsage `json:"usage"` |
|
||||||
} |
|
||||||
|
|
||||||
type Choice struct { |
|
||||||
Index int `json:"index,omitempty"` |
|
||||||
FinishReason string `json:"finish_reason,omitempty"` |
|
||||||
Message *Message `json:"message,omitempty"` |
|
||||||
Delta *Message `json:"delta,omitempty"` |
|
||||||
Text string `json:"text,omitempty"` |
|
||||||
} |
|
||||||
|
|
||||||
type Message struct { |
|
||||||
// The message role
|
|
||||||
Role string `json:"role,omitempty" yaml:"role"` |
|
||||||
// The message content
|
|
||||||
Content *string `json:"content" yaml:"content"` |
|
||||||
// A result of a function call
|
|
||||||
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` |
|
||||||
} |
|
||||||
|
|
||||||
type OpenAIModel struct { |
|
||||||
ID string `json:"id"` |
|
||||||
Object string `json:"object"` |
|
||||||
} |
|
||||||
|
|
||||||
type OpenAIRequest struct { |
|
||||||
Model string `json:"model" yaml:"model"` |
|
||||||
|
|
||||||
// whisper
|
|
||||||
File string `json:"file" validate:"required"` |
|
||||||
Language string `json:"language"` |
|
||||||
//whisper/image
|
|
||||||
ResponseFormat string `json:"response_format"` |
|
||||||
// image
|
|
||||||
Size string `json:"size"` |
|
||||||
// Prompt is read only by completion/image API calls
|
|
||||||
Prompt interface{} `json:"prompt" yaml:"prompt"` |
|
||||||
|
|
||||||
// Edit endpoint
|
|
||||||
Instruction string `json:"instruction" yaml:"instruction"` |
|
||||||
Input interface{} `json:"input" yaml:"input"` |
|
||||||
|
|
||||||
Stop interface{} `json:"stop" yaml:"stop"` |
|
||||||
|
|
||||||
// Messages is read only by chat/completion API calls
|
|
||||||
Messages []Message `json:"messages" yaml:"messages"` |
|
||||||
|
|
||||||
// A list of available functions to call
|
|
||||||
Functions []grammar.Function `json:"functions" yaml:"functions"` |
|
||||||
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
|
||||||
|
|
||||||
Stream bool `json:"stream"` |
|
||||||
Echo bool `json:"echo"` |
|
||||||
// Common options between all the API calls
|
|
||||||
TopP float64 `json:"top_p" yaml:"top_p"` |
|
||||||
TopK int `json:"top_k" yaml:"top_k"` |
|
||||||
Temperature float64 `json:"temperature" yaml:"temperature"` |
|
||||||
Maxtokens int `json:"max_tokens" yaml:"max_tokens"` |
|
||||||
|
|
||||||
N int `json:"n"` |
|
||||||
|
|
||||||
// Custom parameters - not present in the OpenAI API
|
|
||||||
Batch int `json:"batch" yaml:"batch"` |
|
||||||
F16 bool `json:"f16" yaml:"f16"` |
|
||||||
IgnoreEOS bool `json:"ignore_eos" yaml:"ignore_eos"` |
|
||||||
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` |
|
||||||
Keep int `json:"n_keep" yaml:"n_keep"` |
|
||||||
|
|
||||||
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` |
|
||||||
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` |
|
||||||
Mirostat int `json:"mirostat" yaml:"mirostat"` |
|
||||||
|
|
||||||
FrequencyPenalty float64 `json:"frequency_penalty" yaml:"frequency_penalty"` |
|
||||||
TFZ float64 `json:"tfz" yaml:"tfz"` |
|
||||||
|
|
||||||
Seed int `json:"seed" yaml:"seed"` |
|
||||||
|
|
||||||
// Image (not supported by OpenAI)
|
|
||||||
Mode int `json:"mode"` |
|
||||||
Step int `json:"step"` |
|
||||||
|
|
||||||
// A grammar to constrain the LLM output
|
|
||||||
Grammar string `json:"grammar" yaml:"grammar"` |
|
||||||
// A grammar object
|
|
||||||
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` |
|
||||||
|
|
||||||
TypicalP float64 `json:"typical_p" yaml:"typical_p"` |
|
||||||
} |
|
||||||
|
|
||||||
func defaultRequest(modelFile string) OpenAIRequest { |
|
||||||
return OpenAIRequest{ |
|
||||||
TopP: 0.7, |
|
||||||
TopK: 80, |
|
||||||
Maxtokens: 512, |
|
||||||
Temperature: 0.9, |
|
||||||
Model: modelFile, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/completions
|
|
||||||
func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
|
||||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
|
||||||
resp := OpenAIResponse{ |
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{ |
|
||||||
{ |
|
||||||
Index: 0, |
|
||||||
Text: s, |
|
||||||
}, |
|
||||||
}, |
|
||||||
Object: "text_completion", |
|
||||||
} |
|
||||||
log.Debug().Msgf("Sending goroutine: %s", s) |
|
||||||
|
|
||||||
responses <- resp |
|
||||||
return true |
|
||||||
}) |
|
||||||
close(responses) |
|
||||||
} |
|
||||||
|
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
model, input, err := readInput(c, o.loader, true) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("`input`: %+v", input) |
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config) |
|
||||||
|
|
||||||
if input.Stream { |
|
||||||
log.Debug().Msgf("Stream request received") |
|
||||||
c.Context().SetContentType("text/event-stream") |
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
||||||
//c.Set("Content-Type", "text/event-stream")
|
|
||||||
c.Set("Cache-Control", "no-cache") |
|
||||||
c.Set("Connection", "keep-alive") |
|
||||||
c.Set("Transfer-Encoding", "chunked") |
|
||||||
} |
|
||||||
|
|
||||||
templateFile := config.Model |
|
||||||
|
|
||||||
if config.TemplateConfig.Completion != "" { |
|
||||||
templateFile = config.TemplateConfig.Completion |
|
||||||
} |
|
||||||
|
|
||||||
if input.Stream { |
|
||||||
if len(config.PromptStrings) > 1 { |
|
||||||
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") |
|
||||||
} |
|
||||||
|
|
||||||
predInput := config.PromptStrings[0] |
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
|
||||||
Input string |
|
||||||
}{ |
|
||||||
Input: predInput, |
|
||||||
}) |
|
||||||
if err == nil { |
|
||||||
predInput = templatedInput |
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
|
||||||
} |
|
||||||
|
|
||||||
responses := make(chan OpenAIResponse) |
|
||||||
|
|
||||||
go process(predInput, input, config, o.loader, responses) |
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
|
||||||
|
|
||||||
for ev := range responses { |
|
||||||
var buf bytes.Buffer |
|
||||||
enc := json.NewEncoder(&buf) |
|
||||||
enc.Encode(ev) |
|
||||||
|
|
||||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
|
||||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
|
||||||
w.Flush() |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{ |
|
||||||
{ |
|
||||||
Index: 0, |
|
||||||
FinishReason: "stop", |
|
||||||
}, |
|
||||||
}, |
|
||||||
Object: "text_completion", |
|
||||||
} |
|
||||||
respData, _ := json.Marshal(resp) |
|
||||||
|
|
||||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
|
||||||
w.WriteString("data: [DONE]\n\n") |
|
||||||
w.Flush() |
|
||||||
})) |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
var result []Choice |
|
||||||
for _, i := range config.PromptStrings { |
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
|
||||||
Input string |
|
||||||
}{ |
|
||||||
Input: i, |
|
||||||
}) |
|
||||||
if err == nil { |
|
||||||
i = templatedInput |
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
|
||||||
} |
|
||||||
|
|
||||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { |
|
||||||
*c = append(*c, Choice{Text: s}) |
|
||||||
}, nil) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
result = append(result, r...) |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result, |
|
||||||
Object: "text_completion", |
|
||||||
} |
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", jsonResult) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/embeddings
|
|
||||||
func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
model, input, err := readInput(c, o.loader, true) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config) |
|
||||||
items := []Item{} |
|
||||||
|
|
||||||
for i, s := range config.InputToken { |
|
||||||
// get the model function to call for the result
|
|
||||||
embedFn, err := ModelEmbedding("", s, o.loader, *config, o) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
embeddings, err := embedFn() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
|
||||||
} |
|
||||||
|
|
||||||
for i, s := range config.InputStrings { |
|
||||||
// get the model function to call for the result
|
|
||||||
embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config, o) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
embeddings, err := embedFn() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Data: items, |
|
||||||
Object: "list", |
|
||||||
} |
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", jsonResult) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func isEOS(s string) bool { |
|
||||||
if s == "<|endoftext|>" { |
|
||||||
return true |
|
||||||
} |
|
||||||
|
|
||||||
return false |
|
||||||
} |
|
||||||
func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
|
|
||||||
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
|
||||||
initialMessage := OpenAIResponse{ |
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, |
|
||||||
Object: "chat.completion.chunk", |
|
||||||
} |
|
||||||
responses <- initialMessage |
|
||||||
|
|
||||||
ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
|
||||||
resp := OpenAIResponse{ |
|
||||||
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, |
|
||||||
Object: "chat.completion.chunk", |
|
||||||
} |
|
||||||
log.Debug().Msgf("Sending goroutine: %s", s) |
|
||||||
|
|
||||||
if s != "" && !isEOS(s) { |
|
||||||
responses <- resp |
|
||||||
} |
|
||||||
return true |
|
||||||
}) |
|
||||||
close(responses) |
|
||||||
} |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
processFunctions := false |
|
||||||
funcs := grammar.Functions{} |
|
||||||
model, input, err := readInput(c, o.loader, true) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
log.Debug().Msgf("Configuration read: %+v", config) |
|
||||||
|
|
||||||
// Allow the user to set custom actions via config file
|
|
||||||
// to be "embedded" in each model
|
|
||||||
noActionName := "answer" |
|
||||||
noActionDescription := "use this action to answer without performing any action" |
|
||||||
|
|
||||||
if config.FunctionsConfig.NoActionFunctionName != "" { |
|
||||||
noActionName = config.FunctionsConfig.NoActionFunctionName |
|
||||||
} |
|
||||||
if config.FunctionsConfig.NoActionDescriptionName != "" { |
|
||||||
noActionDescription = config.FunctionsConfig.NoActionDescriptionName |
|
||||||
} |
|
||||||
|
|
||||||
// process functions if we have any defined or if we have a function call string
|
|
||||||
if len(input.Functions) > 0 && |
|
||||||
((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) { |
|
||||||
log.Debug().Msgf("Response needs to process functions") |
|
||||||
|
|
||||||
processFunctions = true |
|
||||||
|
|
||||||
noActionGrammar := grammar.Function{ |
|
||||||
Name: noActionName, |
|
||||||
Description: noActionDescription, |
|
||||||
Parameters: map[string]interface{}{ |
|
||||||
"properties": map[string]interface{}{ |
|
||||||
"message": map[string]interface{}{ |
|
||||||
"type": "string", |
|
||||||
"description": "The message to reply the user with", |
|
||||||
}}, |
|
||||||
}, |
|
||||||
} |
|
||||||
|
|
||||||
// Append the no action function
|
|
||||||
funcs = append(funcs, input.Functions...) |
|
||||||
if !config.FunctionsConfig.DisableNoAction { |
|
||||||
funcs = append(funcs, noActionGrammar) |
|
||||||
} |
|
||||||
|
|
||||||
// Force picking one of the functions by the request
|
|
||||||
if config.functionCallNameString != "" { |
|
||||||
funcs = funcs.Select(config.functionCallNameString) |
|
||||||
} |
|
||||||
|
|
||||||
// Update input grammar
|
|
||||||
jsStruct := funcs.ToJSONStructure() |
|
||||||
config.Grammar = jsStruct.Grammar("") |
|
||||||
} else if input.JSONFunctionGrammarObject != nil { |
|
||||||
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") |
|
||||||
} |
|
||||||
|
|
||||||
// functions are not supported in stream mode (yet?)
|
|
||||||
toStream := input.Stream && !processFunctions |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameters: %+v", config) |
|
||||||
|
|
||||||
var predInput string |
|
||||||
|
|
||||||
mess := []string{} |
|
||||||
for _, i := range input.Messages { |
|
||||||
var content string |
|
||||||
role := i.Role |
|
||||||
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
|
||||||
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
|
||||||
if i.FunctionCall != nil && i.Role == "assistant" { |
|
||||||
roleFn := "assistant_function_call" |
|
||||||
r := config.Roles[roleFn] |
|
||||||
if r != "" { |
|
||||||
role = roleFn |
|
||||||
} |
|
||||||
} |
|
||||||
r := config.Roles[role] |
|
||||||
contentExists := i.Content != nil && *i.Content != "" |
|
||||||
if r != "" { |
|
||||||
if contentExists { |
|
||||||
content = fmt.Sprint(r, " ", *i.Content) |
|
||||||
} |
|
||||||
if i.FunctionCall != nil { |
|
||||||
j, err := json.Marshal(i.FunctionCall) |
|
||||||
if err == nil { |
|
||||||
if contentExists { |
|
||||||
content += "\n" + fmt.Sprint(r, " ", string(j)) |
|
||||||
} else { |
|
||||||
content = fmt.Sprint(r, " ", string(j)) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} else { |
|
||||||
if contentExists { |
|
||||||
content = fmt.Sprint(*i.Content) |
|
||||||
} |
|
||||||
if i.FunctionCall != nil { |
|
||||||
j, err := json.Marshal(i.FunctionCall) |
|
||||||
if err == nil { |
|
||||||
if contentExists { |
|
||||||
content += "\n" + string(j) |
|
||||||
} else { |
|
||||||
content = string(j) |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
mess = append(mess, content) |
|
||||||
} |
|
||||||
|
|
||||||
predInput = strings.Join(mess, "\n") |
|
||||||
log.Debug().Msgf("Prompt (before templating): %s", predInput) |
|
||||||
|
|
||||||
if toStream { |
|
||||||
log.Debug().Msgf("Stream request received") |
|
||||||
c.Context().SetContentType("text/event-stream") |
|
||||||
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
|
||||||
// c.Set("Content-Type", "text/event-stream")
|
|
||||||
c.Set("Cache-Control", "no-cache") |
|
||||||
c.Set("Connection", "keep-alive") |
|
||||||
c.Set("Transfer-Encoding", "chunked") |
|
||||||
} |
|
||||||
|
|
||||||
templateFile := config.Model |
|
||||||
|
|
||||||
if config.TemplateConfig.Chat != "" && !processFunctions { |
|
||||||
templateFile = config.TemplateConfig.Chat |
|
||||||
} |
|
||||||
|
|
||||||
if config.TemplateConfig.Functions != "" && processFunctions { |
|
||||||
templateFile = config.TemplateConfig.Functions |
|
||||||
} |
|
||||||
|
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
|
||||||
Input string |
|
||||||
Functions []grammar.Function |
|
||||||
}{ |
|
||||||
Input: predInput, |
|
||||||
Functions: funcs, |
|
||||||
}) |
|
||||||
if err == nil { |
|
||||||
predInput = templatedInput |
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
|
||||||
} else { |
|
||||||
log.Debug().Msgf("Template failed loading: %s", err.Error()) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Prompt (after templating): %s", predInput) |
|
||||||
if processFunctions { |
|
||||||
log.Debug().Msgf("Grammar: %+v", config.Grammar) |
|
||||||
} |
|
||||||
|
|
||||||
if toStream { |
|
||||||
responses := make(chan OpenAIResponse) |
|
||||||
|
|
||||||
go process(predInput, input, config, o.loader, responses) |
|
||||||
|
|
||||||
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
|
||||||
|
|
||||||
for ev := range responses { |
|
||||||
var buf bytes.Buffer |
|
||||||
enc := json.NewEncoder(&buf) |
|
||||||
enc.Encode(ev) |
|
||||||
|
|
||||||
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
|
||||||
fmt.Fprintf(w, "data: %v\n", buf.String()) |
|
||||||
w.Flush() |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: []Choice{ |
|
||||||
{ |
|
||||||
FinishReason: "stop", |
|
||||||
Index: 0, |
|
||||||
Delta: &Message{}, |
|
||||||
}}, |
|
||||||
Object: "chat.completion.chunk", |
|
||||||
} |
|
||||||
respData, _ := json.Marshal(resp) |
|
||||||
|
|
||||||
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
|
||||||
w.WriteString("data: [DONE]\n\n") |
|
||||||
w.Flush() |
|
||||||
})) |
|
||||||
return nil |
|
||||||
} |
|
||||||
|
|
||||||
result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) { |
|
||||||
if processFunctions { |
|
||||||
// As we have to change the result before processing, we can't stream the answer (yet?)
|
|
||||||
ss := map[string]interface{}{} |
|
||||||
json.Unmarshal([]byte(s), &ss) |
|
||||||
log.Debug().Msgf("Function return: %s %+v", s, ss) |
|
||||||
|
|
||||||
// The grammar defines the function name as "function", while OpenAI returns "name"
|
|
||||||
func_name := ss["function"] |
|
||||||
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
|
||||||
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
|
||||||
d, _ := json.Marshal(args) |
|
||||||
|
|
||||||
ss["arguments"] = string(d) |
|
||||||
ss["name"] = func_name |
|
||||||
|
|
||||||
// if do nothing, reply with a message
|
|
||||||
if func_name == noActionName { |
|
||||||
log.Debug().Msgf("nothing to do, computing a reply") |
|
||||||
|
|
||||||
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
|
||||||
arguments := map[string]interface{}{} |
|
||||||
json.Unmarshal([]byte(d), &arguments) |
|
||||||
m, exists := arguments["message"] |
|
||||||
if exists { |
|
||||||
switch message := m.(type) { |
|
||||||
case string: |
|
||||||
if message != "" { |
|
||||||
log.Debug().Msgf("Reply received from LLM: %s", message) |
|
||||||
message = Finetune(*config, predInput, message) |
|
||||||
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) |
|
||||||
|
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) |
|
||||||
return |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("No action received from LLM, without a message, computing a reply") |
|
||||||
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
|
||||||
// Note: This costs (in term of CPU) another computation
|
|
||||||
config.Grammar = "" |
|
||||||
predFunc, err := ModelInference(predInput, o.loader, *config, o, nil) |
|
||||||
if err != nil { |
|
||||||
log.Error().Msgf("inference error: %s", err.Error()) |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
prediction, err := predFunc() |
|
||||||
if err != nil { |
|
||||||
log.Error().Msgf("inference error: %s", err.Error()) |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
prediction = Finetune(*config, predInput, prediction) |
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) |
|
||||||
} else { |
|
||||||
// otherwise reply with the function call
|
|
||||||
*c = append(*c, Choice{ |
|
||||||
FinishReason: "function_call", |
|
||||||
Message: &Message{Role: "assistant", FunctionCall: ss}, |
|
||||||
}) |
|
||||||
} |
|
||||||
|
|
||||||
return |
|
||||||
} |
|
||||||
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) |
|
||||||
}, nil) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result, |
|
||||||
Object: "chat.completion", |
|
||||||
} |
|
||||||
respData, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", respData) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
model, input, err := readInput(c, o.loader, true) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config) |
|
||||||
|
|
||||||
templateFile := config.Model |
|
||||||
|
|
||||||
if config.TemplateConfig.Edit != "" { |
|
||||||
templateFile = config.TemplateConfig.Edit |
|
||||||
} |
|
||||||
|
|
||||||
var result []Choice |
|
||||||
for _, i := range config.InputStrings { |
|
||||||
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
||||||
templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { |
|
||||||
Input string |
|
||||||
Instruction string |
|
||||||
}{Input: i}) |
|
||||||
if err == nil { |
|
||||||
i = templatedInput |
|
||||||
log.Debug().Msgf("Template found, input modified to: %s", i) |
|
||||||
} |
|
||||||
|
|
||||||
r, err := ComputeChoices(i, input, config, o, o.loader, func(s string, c *[]Choice) { |
|
||||||
*c = append(*c, Choice{Text: s}) |
|
||||||
}, nil) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
result = append(result, r...) |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
||||||
Choices: result, |
|
||||||
Object: "edit", |
|
||||||
} |
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", jsonResult) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/images/create
|
|
||||||
|
|
||||||
/* |
|
||||||
* |
|
||||||
|
|
||||||
curl http://localhost:8080/v1/images/generations \
|
|
||||||
-H "Content-Type: application/json" \
|
|
||||||
-d '{ |
|
||||||
"prompt": "A cute baby sea otter", |
|
||||||
"n": 1, |
|
||||||
"size": "512x512" |
|
||||||
}' |
|
||||||
|
|
||||||
* |
|
||||||
*/ |
|
||||||
func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
m, input, err := readInput(c, o.loader, false) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
if m == "" { |
|
||||||
m = model.StableDiffusionBackend |
|
||||||
} |
|
||||||
log.Debug().Msgf("Loading model: %+v", m) |
|
||||||
|
|
||||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Parameter Config: %+v", config) |
|
||||||
|
|
||||||
// XXX: Only stablediffusion is supported for now
|
|
||||||
if config.Backend == "" { |
|
||||||
config.Backend = model.StableDiffusionBackend |
|
||||||
} |
|
||||||
|
|
||||||
sizeParts := strings.Split(input.Size, "x") |
|
||||||
if len(sizeParts) != 2 { |
|
||||||
return fmt.Errorf("Invalid value for 'size'") |
|
||||||
} |
|
||||||
width, err := strconv.Atoi(sizeParts[0]) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("Invalid value for 'size'") |
|
||||||
} |
|
||||||
height, err := strconv.Atoi(sizeParts[1]) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("Invalid value for 'size'") |
|
||||||
} |
|
||||||
|
|
||||||
b64JSON := false |
|
||||||
if input.ResponseFormat == "b64_json" { |
|
||||||
b64JSON = true |
|
||||||
} |
|
||||||
|
|
||||||
var result []Item |
|
||||||
for _, i := range config.PromptStrings { |
|
||||||
n := input.N |
|
||||||
if input.N == 0 { |
|
||||||
n = 1 |
|
||||||
} |
|
||||||
for j := 0; j < n; j++ { |
|
||||||
prompts := strings.Split(i, "|") |
|
||||||
positive_prompt := prompts[0] |
|
||||||
negative_prompt := "" |
|
||||||
if len(prompts) > 1 { |
|
||||||
negative_prompt = prompts[1] |
|
||||||
} |
|
||||||
|
|
||||||
mode := 0 |
|
||||||
step := 15 |
|
||||||
|
|
||||||
if input.Mode != 0 { |
|
||||||
mode = input.Mode |
|
||||||
} |
|
||||||
|
|
||||||
if input.Step != 0 { |
|
||||||
step = input.Step |
|
||||||
} |
|
||||||
|
|
||||||
tempDir := "" |
|
||||||
if !b64JSON { |
|
||||||
tempDir = o.imageDir |
|
||||||
} |
|
||||||
// Create a temporary file
|
|
||||||
outputFile, err := ioutil.TempFile(tempDir, "b64") |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
outputFile.Close() |
|
||||||
output := outputFile.Name() + ".png" |
|
||||||
// Rename the temporary file
|
|
||||||
err = os.Rename(outputFile.Name(), output) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
baseURL := c.BaseURL() |
|
||||||
|
|
||||||
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config, o) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
if err := fn(); err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
item := &Item{} |
|
||||||
|
|
||||||
if b64JSON { |
|
||||||
defer os.RemoveAll(output) |
|
||||||
data, err := os.ReadFile(output) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
|
||||||
} else { |
|
||||||
base := filepath.Base(output) |
|
||||||
item.URL = baseURL + "/generated-images/" + base |
|
||||||
} |
|
||||||
|
|
||||||
result = append(result, *item) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
resp := &OpenAIResponse{ |
|
||||||
Data: result, |
|
||||||
} |
|
||||||
|
|
||||||
jsonResult, _ := json.Marshal(resp) |
|
||||||
log.Debug().Msgf("Response: %s", jsonResult) |
|
||||||
|
|
||||||
// Return the prediction in the response body
|
|
||||||
return c.JSON(resp) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
// https://platform.openai.com/docs/api-reference/audio/create
|
|
||||||
func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
m, input, err := readInput(c, o.loader, false) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
|
|
||||||
config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) |
|
||||||
if err != nil { |
|
||||||
return fmt.Errorf("failed reading parameters from request:%w", err) |
|
||||||
} |
|
||||||
// retrieve the file data from the request
|
|
||||||
file, err := c.FormFile("file") |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
f, err := file.Open() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
defer f.Close() |
|
||||||
|
|
||||||
dir, err := os.MkdirTemp("", "whisper") |
|
||||||
|
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
defer os.RemoveAll(dir) |
|
||||||
|
|
||||||
dst := filepath.Join(dir, path.Base(file.Filename)) |
|
||||||
dstFile, err := os.Create(dst) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
if _, err := io.Copy(dstFile, f); err != nil { |
|
||||||
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst) |
|
||||||
|
|
||||||
whisperModel, err := o.loader.BackendLoader( |
|
||||||
model.WithBackendString(model.WhisperBackend), |
|
||||||
model.WithModelFile(config.Model), |
|
||||||
model.WithThreads(uint32(config.Threads)), |
|
||||||
model.WithAssetDir(o.assetsDestination)) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
if whisperModel == nil { |
|
||||||
return fmt.Errorf("could not load whisper model") |
|
||||||
} |
|
||||||
|
|
||||||
w, ok := whisperModel.(whisper.Model) |
|
||||||
if !ok { |
|
||||||
return fmt.Errorf("loader returned non-whisper object") |
|
||||||
} |
|
||||||
|
|
||||||
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
|
|
||||||
log.Debug().Msgf("Trascribed: %+v", tr) |
|
||||||
// TODO: handle different outputs here
|
|
||||||
return c.Status(http.StatusOK).JSON(tr) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func listModels(loader *model.ModelLoader, cm *ConfigMerger) func(ctx *fiber.Ctx) error { |
|
||||||
return func(c *fiber.Ctx) error { |
|
||||||
models, err := loader.ListModels() |
|
||||||
if err != nil { |
|
||||||
return err |
|
||||||
} |
|
||||||
var mm map[string]interface{} = map[string]interface{}{} |
|
||||||
|
|
||||||
dataModels := []OpenAIModel{} |
|
||||||
for _, m := range models { |
|
||||||
mm[m] = nil |
|
||||||
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
|
||||||
} |
|
||||||
|
|
||||||
for _, k := range cm.ListConfigs() { |
|
||||||
if _, exists := mm[k]; !exists { |
|
||||||
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return c.JSON(struct { |
|
||||||
Object string `json:"object"` |
|
||||||
Data []OpenAIModel `json:"data"` |
|
||||||
}{ |
|
||||||
Object: "list", |
|
||||||
Data: dataModels, |
|
||||||
}) |
|
||||||
} |
|
||||||
} |
|
@ -0,0 +1,105 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||||
|
) |
||||||
|
|
||||||
|
// APIError provides error information returned by the OpenAI API.
|
||||||
|
type APIError struct { |
||||||
|
Code any `json:"code,omitempty"` |
||||||
|
Message string `json:"message"` |
||||||
|
Param *string `json:"param,omitempty"` |
||||||
|
Type string `json:"type"` |
||||||
|
} |
||||||
|
|
||||||
|
type ErrorResponse struct { |
||||||
|
Error *APIError `json:"error,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
type OpenAIUsage struct { |
||||||
|
PromptTokens int `json:"prompt_tokens"` |
||||||
|
CompletionTokens int `json:"completion_tokens"` |
||||||
|
TotalTokens int `json:"total_tokens"` |
||||||
|
} |
||||||
|
|
||||||
|
type Item struct { |
||||||
|
Embedding []float32 `json:"embedding"` |
||||||
|
Index int `json:"index"` |
||||||
|
Object string `json:"object,omitempty"` |
||||||
|
|
||||||
|
// Images
|
||||||
|
URL string `json:"url,omitempty"` |
||||||
|
B64JSON string `json:"b64_json,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
type OpenAIResponse struct { |
||||||
|
Created int `json:"created,omitempty"` |
||||||
|
Object string `json:"object,omitempty"` |
||||||
|
ID string `json:"id,omitempty"` |
||||||
|
Model string `json:"model,omitempty"` |
||||||
|
Choices []Choice `json:"choices,omitempty"` |
||||||
|
Data []Item `json:"data,omitempty"` |
||||||
|
|
||||||
|
Usage OpenAIUsage `json:"usage"` |
||||||
|
} |
||||||
|
|
||||||
|
type Choice struct { |
||||||
|
Index int `json:"index,omitempty"` |
||||||
|
FinishReason string `json:"finish_reason,omitempty"` |
||||||
|
Message *Message `json:"message,omitempty"` |
||||||
|
Delta *Message `json:"delta,omitempty"` |
||||||
|
Text string `json:"text,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
type Message struct { |
||||||
|
// The message role
|
||||||
|
Role string `json:"role,omitempty" yaml:"role"` |
||||||
|
// The message content
|
||||||
|
Content *string `json:"content" yaml:"content"` |
||||||
|
// A result of a function call
|
||||||
|
FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
type OpenAIModel struct { |
||||||
|
ID string `json:"id"` |
||||||
|
Object string `json:"object"` |
||||||
|
} |
||||||
|
|
||||||
|
type OpenAIRequest struct { |
||||||
|
config.PredictionOptions |
||||||
|
|
||||||
|
// whisper
|
||||||
|
File string `json:"file" validate:"required"` |
||||||
|
//whisper/image
|
||||||
|
ResponseFormat string `json:"response_format"` |
||||||
|
// image
|
||||||
|
Size string `json:"size"` |
||||||
|
// Prompt is read only by completion/image API calls
|
||||||
|
Prompt interface{} `json:"prompt" yaml:"prompt"` |
||||||
|
|
||||||
|
// Edit endpoint
|
||||||
|
Instruction string `json:"instruction" yaml:"instruction"` |
||||||
|
Input interface{} `json:"input" yaml:"input"` |
||||||
|
|
||||||
|
Stop interface{} `json:"stop" yaml:"stop"` |
||||||
|
|
||||||
|
// Messages is read only by chat/completion API calls
|
||||||
|
Messages []Message `json:"messages" yaml:"messages"` |
||||||
|
|
||||||
|
// A list of available functions to call
|
||||||
|
Functions []grammar.Function `json:"functions" yaml:"functions"` |
||||||
|
FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object
|
||||||
|
|
||||||
|
Stream bool `json:"stream"` |
||||||
|
|
||||||
|
// Image (not supported by OpenAI)
|
||||||
|
Mode int `json:"mode"` |
||||||
|
Step int `json:"step"` |
||||||
|
|
||||||
|
// A grammar to constrain the LLM output
|
||||||
|
Grammar string `json:"grammar" yaml:"grammar"` |
||||||
|
|
||||||
|
JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"` |
||||||
|
} |
@ -0,0 +1,320 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"bufio" |
||||||
|
"bytes" |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/go-skynet/LocalAI/pkg/grammar" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
"github.com/valyala/fasthttp" |
||||||
|
) |
||||||
|
|
||||||
|
func ChatEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||||
|
initialMessage := OpenAIResponse{ |
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, |
||||||
|
Object: "chat.completion.chunk", |
||||||
|
} |
||||||
|
responses <- initialMessage |
||||||
|
|
||||||
|
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||||
|
resp := OpenAIResponse{ |
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, |
||||||
|
Object: "chat.completion.chunk", |
||||||
|
} |
||||||
|
|
||||||
|
responses <- resp |
||||||
|
return true |
||||||
|
}) |
||||||
|
close(responses) |
||||||
|
} |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
processFunctions := false |
||||||
|
funcs := grammar.Functions{} |
||||||
|
model, input, err := readInput(c, o.Loader, true) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
log.Debug().Msgf("Configuration read: %+v", config) |
||||||
|
|
||||||
|
// Allow the user to set custom actions via config file
|
||||||
|
// to be "embedded" in each model
|
||||||
|
noActionName := "answer" |
||||||
|
noActionDescription := "use this action to answer without performing any action" |
||||||
|
|
||||||
|
if config.FunctionsConfig.NoActionFunctionName != "" { |
||||||
|
noActionName = config.FunctionsConfig.NoActionFunctionName |
||||||
|
} |
||||||
|
if config.FunctionsConfig.NoActionDescriptionName != "" { |
||||||
|
noActionDescription = config.FunctionsConfig.NoActionDescriptionName |
||||||
|
} |
||||||
|
|
||||||
|
// process functions if we have any defined or if we have a function call string
|
||||||
|
if len(input.Functions) > 0 && config.ShouldUseFunctions() { |
||||||
|
log.Debug().Msgf("Response needs to process functions") |
||||||
|
|
||||||
|
processFunctions = true |
||||||
|
|
||||||
|
noActionGrammar := grammar.Function{ |
||||||
|
Name: noActionName, |
||||||
|
Description: noActionDescription, |
||||||
|
Parameters: map[string]interface{}{ |
||||||
|
"properties": map[string]interface{}{ |
||||||
|
"message": map[string]interface{}{ |
||||||
|
"type": "string", |
||||||
|
"description": "The message to reply the user with", |
||||||
|
}}, |
||||||
|
}, |
||||||
|
} |
||||||
|
|
||||||
|
// Append the no action function
|
||||||
|
funcs = append(funcs, input.Functions...) |
||||||
|
if !config.FunctionsConfig.DisableNoAction { |
||||||
|
funcs = append(funcs, noActionGrammar) |
||||||
|
} |
||||||
|
|
||||||
|
// Force picking one of the functions by the request
|
||||||
|
if config.FunctionToCall() != "" { |
||||||
|
funcs = funcs.Select(config.FunctionToCall()) |
||||||
|
} |
||||||
|
|
||||||
|
// Update input grammar
|
||||||
|
jsStruct := funcs.ToJSONStructure() |
||||||
|
config.Grammar = jsStruct.Grammar("") |
||||||
|
} else if input.JSONFunctionGrammarObject != nil { |
||||||
|
config.Grammar = input.JSONFunctionGrammarObject.Grammar("") |
||||||
|
} |
||||||
|
|
||||||
|
// functions are not supported in stream mode (yet?)
|
||||||
|
toStream := input.Stream && !processFunctions |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameters: %+v", config) |
||||||
|
|
||||||
|
var predInput string |
||||||
|
|
||||||
|
mess := []string{} |
||||||
|
for _, i := range input.Messages { |
||||||
|
var content string |
||||||
|
role := i.Role |
||||||
|
// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
|
||||||
|
// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
|
||||||
|
if i.FunctionCall != nil && i.Role == "assistant" { |
||||||
|
roleFn := "assistant_function_call" |
||||||
|
r := config.Roles[roleFn] |
||||||
|
if r != "" { |
||||||
|
role = roleFn |
||||||
|
} |
||||||
|
} |
||||||
|
r := config.Roles[role] |
||||||
|
contentExists := i.Content != nil && *i.Content != "" |
||||||
|
if r != "" { |
||||||
|
if contentExists { |
||||||
|
content = fmt.Sprint(r, " ", *i.Content) |
||||||
|
} |
||||||
|
if i.FunctionCall != nil { |
||||||
|
j, err := json.Marshal(i.FunctionCall) |
||||||
|
if err == nil { |
||||||
|
if contentExists { |
||||||
|
content += "\n" + fmt.Sprint(r, " ", string(j)) |
||||||
|
} else { |
||||||
|
content = fmt.Sprint(r, " ", string(j)) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} else { |
||||||
|
if contentExists { |
||||||
|
content = fmt.Sprint(*i.Content) |
||||||
|
} |
||||||
|
if i.FunctionCall != nil { |
||||||
|
j, err := json.Marshal(i.FunctionCall) |
||||||
|
if err == nil { |
||||||
|
if contentExists { |
||||||
|
content += "\n" + string(j) |
||||||
|
} else { |
||||||
|
content = string(j) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
mess = append(mess, content) |
||||||
|
} |
||||||
|
|
||||||
|
predInput = strings.Join(mess, "\n") |
||||||
|
log.Debug().Msgf("Prompt (before templating): %s", predInput) |
||||||
|
|
||||||
|
if toStream { |
||||||
|
log.Debug().Msgf("Stream request received") |
||||||
|
c.Context().SetContentType("text/event-stream") |
||||||
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||||
|
// c.Set("Content-Type", "text/event-stream")
|
||||||
|
c.Set("Cache-Control", "no-cache") |
||||||
|
c.Set("Connection", "keep-alive") |
||||||
|
c.Set("Transfer-Encoding", "chunked") |
||||||
|
} |
||||||
|
|
||||||
|
templateFile := config.Model |
||||||
|
|
||||||
|
if config.TemplateConfig.Chat != "" && !processFunctions { |
||||||
|
templateFile = config.TemplateConfig.Chat |
||||||
|
} |
||||||
|
|
||||||
|
if config.TemplateConfig.Functions != "" && processFunctions { |
||||||
|
templateFile = config.TemplateConfig.Functions |
||||||
|
} |
||||||
|
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||||
|
Input string |
||||||
|
Functions []grammar.Function |
||||||
|
}{ |
||||||
|
Input: predInput, |
||||||
|
Functions: funcs, |
||||||
|
}) |
||||||
|
if err == nil { |
||||||
|
predInput = templatedInput |
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||||
|
} else { |
||||||
|
log.Debug().Msgf("Template failed loading: %s", err.Error()) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Prompt (after templating): %s", predInput) |
||||||
|
if processFunctions { |
||||||
|
log.Debug().Msgf("Grammar: %+v", config.Grammar) |
||||||
|
} |
||||||
|
|
||||||
|
if toStream { |
||||||
|
responses := make(chan OpenAIResponse) |
||||||
|
|
||||||
|
go process(predInput, input, config, o.Loader, responses) |
||||||
|
|
||||||
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||||
|
|
||||||
|
for ev := range responses { |
||||||
|
var buf bytes.Buffer |
||||||
|
enc := json.NewEncoder(&buf) |
||||||
|
enc.Encode(ev) |
||||||
|
|
||||||
|
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||||
|
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||||
|
w.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{ |
||||||
|
{ |
||||||
|
FinishReason: "stop", |
||||||
|
Index: 0, |
||||||
|
Delta: &Message{}, |
||||||
|
}}, |
||||||
|
Object: "chat.completion.chunk", |
||||||
|
} |
||||||
|
respData, _ := json.Marshal(resp) |
||||||
|
|
||||||
|
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||||
|
w.WriteString("data: [DONE]\n\n") |
||||||
|
w.Flush() |
||||||
|
})) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
result, err := ComputeChoices(predInput, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||||
|
if processFunctions { |
||||||
|
// As we have to change the result before processing, we can't stream the answer (yet?)
|
||||||
|
ss := map[string]interface{}{} |
||||||
|
json.Unmarshal([]byte(s), &ss) |
||||||
|
log.Debug().Msgf("Function return: %s %+v", s, ss) |
||||||
|
|
||||||
|
// The grammar defines the function name as "function", while OpenAI returns "name"
|
||||||
|
func_name := ss["function"] |
||||||
|
// Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object
|
||||||
|
args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
|
||||||
|
d, _ := json.Marshal(args) |
||||||
|
|
||||||
|
ss["arguments"] = string(d) |
||||||
|
ss["name"] = func_name |
||||||
|
|
||||||
|
// if do nothing, reply with a message
|
||||||
|
if func_name == noActionName { |
||||||
|
log.Debug().Msgf("nothing to do, computing a reply") |
||||||
|
|
||||||
|
// If there is a message that the LLM already sends as part of the JSON reply, use it
|
||||||
|
arguments := map[string]interface{}{} |
||||||
|
json.Unmarshal([]byte(d), &arguments) |
||||||
|
m, exists := arguments["message"] |
||||||
|
if exists { |
||||||
|
switch message := m.(type) { |
||||||
|
case string: |
||||||
|
if message != "" { |
||||||
|
log.Debug().Msgf("Reply received from LLM: %s", message) |
||||||
|
message = backend.Finetune(*config, predInput, message) |
||||||
|
log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) |
||||||
|
|
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) |
||||||
|
return |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("No action received from LLM, without a message, computing a reply") |
||||||
|
// Otherwise ask the LLM to understand the JSON output and the context, and return a message
|
||||||
|
// Note: This costs (in term of CPU) another computation
|
||||||
|
config.Grammar = "" |
||||||
|
predFunc, err := backend.ModelInference(predInput, o.Loader, *config, o, nil) |
||||||
|
if err != nil { |
||||||
|
log.Error().Msgf("inference error: %s", err.Error()) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
prediction, err := predFunc() |
||||||
|
if err != nil { |
||||||
|
log.Error().Msgf("inference error: %s", err.Error()) |
||||||
|
return |
||||||
|
} |
||||||
|
|
||||||
|
prediction = backend.Finetune(*config, predInput, prediction) |
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) |
||||||
|
} else { |
||||||
|
// otherwise reply with the function call
|
||||||
|
*c = append(*c, Choice{ |
||||||
|
FinishReason: "function_call", |
||||||
|
Message: &Message{Role: "assistant", FunctionCall: ss}, |
||||||
|
}) |
||||||
|
} |
||||||
|
|
||||||
|
return |
||||||
|
} |
||||||
|
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) |
||||||
|
}, nil) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result, |
||||||
|
Object: "chat.completion", |
||||||
|
} |
||||||
|
respData, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", respData) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,159 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"bufio" |
||||||
|
"bytes" |
||||||
|
"encoding/json" |
||||||
|
"errors" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
"github.com/valyala/fasthttp" |
||||||
|
) |
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/completions
|
||||||
|
func CompletionEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
process := func(s string, req *OpenAIRequest, config *config.Config, loader *model.ModelLoader, responses chan OpenAIResponse) { |
||||||
|
ComputeChoices(s, req.N, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { |
||||||
|
resp := OpenAIResponse{ |
||||||
|
Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{ |
||||||
|
{ |
||||||
|
Index: 0, |
||||||
|
Text: s, |
||||||
|
}, |
||||||
|
}, |
||||||
|
Object: "text_completion", |
||||||
|
} |
||||||
|
log.Debug().Msgf("Sending goroutine: %s", s) |
||||||
|
|
||||||
|
responses <- resp |
||||||
|
return true |
||||||
|
}) |
||||||
|
close(responses) |
||||||
|
} |
||||||
|
|
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
model, input, err := readInput(c, o.Loader, true) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("`input`: %+v", input) |
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config) |
||||||
|
|
||||||
|
if input.Stream { |
||||||
|
log.Debug().Msgf("Stream request received") |
||||||
|
c.Context().SetContentType("text/event-stream") |
||||||
|
//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
|
||||||
|
//c.Set("Content-Type", "text/event-stream")
|
||||||
|
c.Set("Cache-Control", "no-cache") |
||||||
|
c.Set("Connection", "keep-alive") |
||||||
|
c.Set("Transfer-Encoding", "chunked") |
||||||
|
} |
||||||
|
|
||||||
|
templateFile := config.Model |
||||||
|
|
||||||
|
if config.TemplateConfig.Completion != "" { |
||||||
|
templateFile = config.TemplateConfig.Completion |
||||||
|
} |
||||||
|
|
||||||
|
if input.Stream { |
||||||
|
if len(config.PromptStrings) > 1 { |
||||||
|
return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") |
||||||
|
} |
||||||
|
|
||||||
|
predInput := config.PromptStrings[0] |
||||||
|
|
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||||
|
Input string |
||||||
|
}{ |
||||||
|
Input: predInput, |
||||||
|
}) |
||||||
|
if err == nil { |
||||||
|
predInput = templatedInput |
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
||||||
|
} |
||||||
|
|
||||||
|
responses := make(chan OpenAIResponse) |
||||||
|
|
||||||
|
go process(predInput, input, config, o.Loader, responses) |
||||||
|
|
||||||
|
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { |
||||||
|
|
||||||
|
for ev := range responses { |
||||||
|
var buf bytes.Buffer |
||||||
|
enc := json.NewEncoder(&buf) |
||||||
|
enc.Encode(ev) |
||||||
|
|
||||||
|
log.Debug().Msgf("Sending chunk: %s", buf.String()) |
||||||
|
fmt.Fprintf(w, "data: %v\n", buf.String()) |
||||||
|
w.Flush() |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: []Choice{ |
||||||
|
{ |
||||||
|
Index: 0, |
||||||
|
FinishReason: "stop", |
||||||
|
}, |
||||||
|
}, |
||||||
|
Object: "text_completion", |
||||||
|
} |
||||||
|
respData, _ := json.Marshal(resp) |
||||||
|
|
||||||
|
w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) |
||||||
|
w.WriteString("data: [DONE]\n\n") |
||||||
|
w.Flush() |
||||||
|
})) |
||||||
|
return nil |
||||||
|
} |
||||||
|
|
||||||
|
var result []Choice |
||||||
|
for _, i := range config.PromptStrings { |
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||||
|
Input string |
||||||
|
}{ |
||||||
|
Input: i, |
||||||
|
}) |
||||||
|
if err == nil { |
||||||
|
i = templatedInput |
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||||
|
} |
||||||
|
|
||||||
|
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||||
|
*c = append(*c, Choice{Text: s}) |
||||||
|
}, nil) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
result = append(result, r...) |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result, |
||||||
|
Object: "text_completion", |
||||||
|
} |
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", jsonResult) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,67 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
func EditEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
model, input, err := readInput(c, o.Loader, true) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config) |
||||||
|
|
||||||
|
templateFile := config.Model |
||||||
|
|
||||||
|
if config.TemplateConfig.Edit != "" { |
||||||
|
templateFile = config.TemplateConfig.Edit |
||||||
|
} |
||||||
|
|
||||||
|
var result []Choice |
||||||
|
for _, i := range config.InputStrings { |
||||||
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
||||||
|
templatedInput, err := o.Loader.TemplatePrefix(templateFile, struct { |
||||||
|
Input string |
||||||
|
Instruction string |
||||||
|
}{Input: i}) |
||||||
|
if err == nil { |
||||||
|
i = templatedInput |
||||||
|
log.Debug().Msgf("Template found, input modified to: %s", i) |
||||||
|
} |
||||||
|
|
||||||
|
r, err := ComputeChoices(i, input.N, config, o, o.Loader, func(s string, c *[]Choice) { |
||||||
|
*c = append(*c, Choice{Text: s}) |
||||||
|
}, nil) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
result = append(result, r...) |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Choices: result, |
||||||
|
Object: "edit", |
||||||
|
} |
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", jsonResult) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,70 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/embeddings
|
||||||
|
func EmbeddingsEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
model, input, err := readInput(c, o.Loader, true) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
config, input, err := readConfig(model, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config) |
||||||
|
items := []Item{} |
||||||
|
|
||||||
|
for i, s := range config.InputToken { |
||||||
|
// get the model function to call for the result
|
||||||
|
embedFn, err := backend.ModelEmbedding("", s, o.Loader, *config, o) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
embeddings, err := embedFn() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||||
|
} |
||||||
|
|
||||||
|
for i, s := range config.InputStrings { |
||||||
|
// get the model function to call for the result
|
||||||
|
embedFn, err := backend.ModelEmbedding(s, []int{}, o.Loader, *config, o) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
embeddings, err := embedFn() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
items = append(items, Item{Embedding: embeddings, Index: i, Object: "embedding"}) |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
||||||
|
Data: items, |
||||||
|
Object: "list", |
||||||
|
} |
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", jsonResult) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,158 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/base64" |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"io/ioutil" |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
"strconv" |
||||||
|
"strings" |
||||||
|
|
||||||
|
"github.com/go-skynet/LocalAI/api/backend" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/images/create
|
||||||
|
|
||||||
|
/* |
||||||
|
* |
||||||
|
|
||||||
|
curl http://localhost:8080/v1/images/generations \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{ |
||||||
|
"prompt": "A cute baby sea otter", |
||||||
|
"n": 1, |
||||||
|
"size": "512x512" |
||||||
|
}' |
||||||
|
|
||||||
|
* |
||||||
|
*/ |
||||||
|
func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
m, input, err := readInput(c, o.Loader, false) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
if m == "" { |
||||||
|
m = model.StableDiffusionBackend |
||||||
|
} |
||||||
|
log.Debug().Msgf("Loading model: %+v", m) |
||||||
|
|
||||||
|
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, 0, 0, false) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Parameter Config: %+v", config) |
||||||
|
|
||||||
|
// XXX: Only stablediffusion is supported for now
|
||||||
|
if config.Backend == "" { |
||||||
|
config.Backend = model.StableDiffusionBackend |
||||||
|
} |
||||||
|
|
||||||
|
sizeParts := strings.Split(input.Size, "x") |
||||||
|
if len(sizeParts) != 2 { |
||||||
|
return fmt.Errorf("Invalid value for 'size'") |
||||||
|
} |
||||||
|
width, err := strconv.Atoi(sizeParts[0]) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("Invalid value for 'size'") |
||||||
|
} |
||||||
|
height, err := strconv.Atoi(sizeParts[1]) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("Invalid value for 'size'") |
||||||
|
} |
||||||
|
|
||||||
|
b64JSON := false |
||||||
|
if input.ResponseFormat == "b64_json" { |
||||||
|
b64JSON = true |
||||||
|
} |
||||||
|
|
||||||
|
var result []Item |
||||||
|
for _, i := range config.PromptStrings { |
||||||
|
n := input.N |
||||||
|
if input.N == 0 { |
||||||
|
n = 1 |
||||||
|
} |
||||||
|
for j := 0; j < n; j++ { |
||||||
|
prompts := strings.Split(i, "|") |
||||||
|
positive_prompt := prompts[0] |
||||||
|
negative_prompt := "" |
||||||
|
if len(prompts) > 1 { |
||||||
|
negative_prompt = prompts[1] |
||||||
|
} |
||||||
|
|
||||||
|
mode := 0 |
||||||
|
step := 15 |
||||||
|
|
||||||
|
if input.Mode != 0 { |
||||||
|
mode = input.Mode |
||||||
|
} |
||||||
|
|
||||||
|
if input.Step != 0 { |
||||||
|
step = input.Step |
||||||
|
} |
||||||
|
|
||||||
|
tempDir := "" |
||||||
|
if !b64JSON { |
||||||
|
tempDir = o.ImageDir |
||||||
|
} |
||||||
|
// Create a temporary file
|
||||||
|
outputFile, err := ioutil.TempFile(tempDir, "b64") |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
outputFile.Close() |
||||||
|
output := outputFile.Name() + ".png" |
||||||
|
// Rename the temporary file
|
||||||
|
err = os.Rename(outputFile.Name(), output) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
baseURL := c.BaseURL() |
||||||
|
|
||||||
|
fn, err := backend.ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.Loader, *config, o) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
if err := fn(); err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
item := &Item{} |
||||||
|
|
||||||
|
if b64JSON { |
||||||
|
defer os.RemoveAll(output) |
||||||
|
data, err := os.ReadFile(output) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
item.B64JSON = base64.StdEncoding.EncodeToString(data) |
||||||
|
} else { |
||||||
|
base := filepath.Base(output) |
||||||
|
item.URL = baseURL + "/generated-images/" + base |
||||||
|
} |
||||||
|
|
||||||
|
result = append(result, *item) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
resp := &OpenAIResponse{ |
||||||
|
Data: result, |
||||||
|
} |
||||||
|
|
||||||
|
jsonResult, _ := json.Marshal(resp) |
||||||
|
log.Debug().Msgf("Response: %s", jsonResult) |
||||||
|
|
||||||
|
// Return the prediction in the response body
|
||||||
|
return c.JSON(resp) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,36 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"github.com/go-skynet/LocalAI/api/backend" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
) |
||||||
|
|
||||||
|
func ComputeChoices(predInput string, n int, config *config.Config, o *options.Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
||||||
|
result := []Choice{} |
||||||
|
|
||||||
|
if n == 0 { |
||||||
|
n = 1 |
||||||
|
} |
||||||
|
|
||||||
|
// get the model function to call for the result
|
||||||
|
predFunc, err := backend.ModelInference(predInput, loader, *config, o, tokenCallback) |
||||||
|
if err != nil { |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
for i := 0; i < n; i++ { |
||||||
|
prediction, err := predFunc() |
||||||
|
if err != nil { |
||||||
|
return result, err |
||||||
|
} |
||||||
|
|
||||||
|
prediction = backend.Finetune(*config, predInput, prediction) |
||||||
|
cb(prediction, &result) |
||||||
|
|
||||||
|
//result = append(result, Choice{Text: prediction})
|
||||||
|
|
||||||
|
} |
||||||
|
return result, err |
||||||
|
} |
@ -0,0 +1,37 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
) |
||||||
|
|
||||||
|
func ListModelsEndpoint(loader *model.ModelLoader, cm *config.ConfigLoader) func(ctx *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
models, err := loader.ListModels() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
var mm map[string]interface{} = map[string]interface{}{} |
||||||
|
|
||||||
|
dataModels := []OpenAIModel{} |
||||||
|
for _, m := range models { |
||||||
|
mm[m] = nil |
||||||
|
dataModels = append(dataModels, OpenAIModel{ID: m, Object: "model"}) |
||||||
|
} |
||||||
|
|
||||||
|
for _, k := range cm.ListConfigs() { |
||||||
|
if _, exists := mm[k]; !exists { |
||||||
|
dataModels = append(dataModels, OpenAIModel{ID: k, Object: "model"}) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return c.JSON(struct { |
||||||
|
Object string `json:"object"` |
||||||
|
Data []OpenAIModel `json:"data"` |
||||||
|
}{ |
||||||
|
Object: "list", |
||||||
|
Data: dataModels, |
||||||
|
}) |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,234 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"os" |
||||||
|
"path/filepath" |
||||||
|
"strings" |
||||||
|
|
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
func readInput(c *fiber.Ctx, loader *model.ModelLoader, randomModel bool) (string, *OpenAIRequest, error) { |
||||||
|
input := new(OpenAIRequest) |
||||||
|
// Get input data from the request body
|
||||||
|
if err := c.BodyParser(input); err != nil { |
||||||
|
return "", nil, err |
||||||
|
} |
||||||
|
|
||||||
|
modelFile := input.Model |
||||||
|
|
||||||
|
if c.Params("model") != "" { |
||||||
|
modelFile = c.Params("model") |
||||||
|
} |
||||||
|
|
||||||
|
received, _ := json.Marshal(input) |
||||||
|
|
||||||
|
log.Debug().Msgf("Request received: %s", string(received)) |
||||||
|
|
||||||
|
// Set model from bearer token, if available
|
||||||
|
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
||||||
|
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
||||||
|
|
||||||
|
// If no model was specified, take the first available
|
||||||
|
if modelFile == "" && !bearerExists && randomModel { |
||||||
|
models, _ := loader.ListModels() |
||||||
|
if len(models) > 0 { |
||||||
|
modelFile = models[0] |
||||||
|
log.Debug().Msgf("No model specified, using: %s", modelFile) |
||||||
|
} else { |
||||||
|
log.Debug().Msgf("No model specified, returning error") |
||||||
|
return "", nil, fmt.Errorf("no model specified") |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// If a model is found in bearer token takes precedence
|
||||||
|
if bearerExists { |
||||||
|
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
||||||
|
modelFile = bearer |
||||||
|
} |
||||||
|
return modelFile, input, nil |
||||||
|
} |
||||||
|
|
||||||
|
func updateConfig(config *config.Config, input *OpenAIRequest) { |
||||||
|
if input.Echo { |
||||||
|
config.Echo = input.Echo |
||||||
|
} |
||||||
|
if input.TopK != 0 { |
||||||
|
config.TopK = input.TopK |
||||||
|
} |
||||||
|
if input.TopP != 0 { |
||||||
|
config.TopP = input.TopP |
||||||
|
} |
||||||
|
|
||||||
|
if input.Grammar != "" { |
||||||
|
config.Grammar = input.Grammar |
||||||
|
} |
||||||
|
|
||||||
|
if input.Temperature != 0 { |
||||||
|
config.Temperature = input.Temperature |
||||||
|
} |
||||||
|
|
||||||
|
if input.Maxtokens != 0 { |
||||||
|
config.Maxtokens = input.Maxtokens |
||||||
|
} |
||||||
|
|
||||||
|
switch stop := input.Stop.(type) { |
||||||
|
case string: |
||||||
|
if stop != "" { |
||||||
|
config.StopWords = append(config.StopWords, stop) |
||||||
|
} |
||||||
|
case []interface{}: |
||||||
|
for _, pp := range stop { |
||||||
|
if s, ok := pp.(string); ok { |
||||||
|
config.StopWords = append(config.StopWords, s) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
if input.RepeatPenalty != 0 { |
||||||
|
config.RepeatPenalty = input.RepeatPenalty |
||||||
|
} |
||||||
|
|
||||||
|
if input.Keep != 0 { |
||||||
|
config.Keep = input.Keep |
||||||
|
} |
||||||
|
|
||||||
|
if input.Batch != 0 { |
||||||
|
config.Batch = input.Batch |
||||||
|
} |
||||||
|
|
||||||
|
if input.F16 { |
||||||
|
config.F16 = input.F16 |
||||||
|
} |
||||||
|
|
||||||
|
if input.IgnoreEOS { |
||||||
|
config.IgnoreEOS = input.IgnoreEOS |
||||||
|
} |
||||||
|
|
||||||
|
if input.Seed != 0 { |
||||||
|
config.Seed = input.Seed |
||||||
|
} |
||||||
|
|
||||||
|
if input.Mirostat != 0 { |
||||||
|
config.Mirostat = input.Mirostat |
||||||
|
} |
||||||
|
|
||||||
|
if input.MirostatETA != 0 { |
||||||
|
config.MirostatETA = input.MirostatETA |
||||||
|
} |
||||||
|
|
||||||
|
if input.MirostatTAU != 0 { |
||||||
|
config.MirostatTAU = input.MirostatTAU |
||||||
|
} |
||||||
|
|
||||||
|
if input.TypicalP != 0 { |
||||||
|
config.TypicalP = input.TypicalP |
||||||
|
} |
||||||
|
|
||||||
|
switch inputs := input.Input.(type) { |
||||||
|
case string: |
||||||
|
if inputs != "" { |
||||||
|
config.InputStrings = append(config.InputStrings, inputs) |
||||||
|
} |
||||||
|
case []interface{}: |
||||||
|
for _, pp := range inputs { |
||||||
|
switch i := pp.(type) { |
||||||
|
case string: |
||||||
|
config.InputStrings = append(config.InputStrings, i) |
||||||
|
case []interface{}: |
||||||
|
tokens := []int{} |
||||||
|
for _, ii := range i { |
||||||
|
tokens = append(tokens, int(ii.(float64))) |
||||||
|
} |
||||||
|
config.InputToken = append(config.InputToken, tokens) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Can be either a string or an object
|
||||||
|
switch fnc := input.FunctionCall.(type) { |
||||||
|
case string: |
||||||
|
if fnc != "" { |
||||||
|
config.SetFunctionCallString(fnc) |
||||||
|
} |
||||||
|
case map[string]interface{}: |
||||||
|
var name string |
||||||
|
n, exists := fnc["name"] |
||||||
|
if exists { |
||||||
|
nn, e := n.(string) |
||||||
|
if !e { |
||||||
|
name = nn |
||||||
|
} |
||||||
|
} |
||||||
|
config.SetFunctionCallNameString(name) |
||||||
|
} |
||||||
|
|
||||||
|
switch p := input.Prompt.(type) { |
||||||
|
case string: |
||||||
|
config.PromptStrings = append(config.PromptStrings, p) |
||||||
|
case []interface{}: |
||||||
|
for _, pp := range p { |
||||||
|
if s, ok := pp.(string); ok { |
||||||
|
config.PromptStrings = append(config.PromptStrings, s) |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func readConfig(modelFile string, input *OpenAIRequest, cm *config.ConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.Config, *OpenAIRequest, error) { |
||||||
|
// Load a config file if present after the model name
|
||||||
|
modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") |
||||||
|
|
||||||
|
var cfg *config.Config |
||||||
|
|
||||||
|
defaults := func() { |
||||||
|
cfg = config.DefaultConfig(modelFile) |
||||||
|
cfg.ContextSize = ctx |
||||||
|
cfg.Threads = threads |
||||||
|
cfg.F16 = f16 |
||||||
|
cfg.Debug = debug |
||||||
|
} |
||||||
|
|
||||||
|
cfgExisting, exists := cm.GetConfig(modelFile) |
||||||
|
if !exists { |
||||||
|
if _, err := os.Stat(modelConfig); err == nil { |
||||||
|
if err := cm.LoadConfig(modelConfig); err != nil { |
||||||
|
return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) |
||||||
|
} |
||||||
|
cfgExisting, exists = cm.GetConfig(modelFile) |
||||||
|
if exists { |
||||||
|
cfg = &cfgExisting |
||||||
|
} else { |
||||||
|
defaults() |
||||||
|
} |
||||||
|
} else { |
||||||
|
defaults() |
||||||
|
} |
||||||
|
} else { |
||||||
|
cfg = &cfgExisting |
||||||
|
} |
||||||
|
|
||||||
|
// Set the parameters for the language model prediction
|
||||||
|
updateConfig(cfg, input) |
||||||
|
|
||||||
|
// Don't allow 0 as setting
|
||||||
|
if cfg.Threads == 0 { |
||||||
|
if threads != 0 { |
||||||
|
cfg.Threads = threads |
||||||
|
} else { |
||||||
|
cfg.Threads = 4 |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
// Enforce debug flag if passed from CLI
|
||||||
|
if debug { |
||||||
|
cfg.Debug = true |
||||||
|
} |
||||||
|
|
||||||
|
return cfg, input, nil |
||||||
|
} |
@ -0,0 +1,91 @@ |
|||||||
|
package openai |
||||||
|
|
||||||
|
import ( |
||||||
|
"fmt" |
||||||
|
"io" |
||||||
|
"net/http" |
||||||
|
"os" |
||||||
|
"path" |
||||||
|
"path/filepath" |
||||||
|
|
||||||
|
"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" |
||||||
|
config "github.com/go-skynet/LocalAI/api/config" |
||||||
|
"github.com/go-skynet/LocalAI/api/options" |
||||||
|
model "github.com/go-skynet/LocalAI/pkg/model" |
||||||
|
whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" |
||||||
|
|
||||||
|
"github.com/gofiber/fiber/v2" |
||||||
|
"github.com/rs/zerolog/log" |
||||||
|
) |
||||||
|
|
||||||
|
// https://platform.openai.com/docs/api-reference/audio/create
|
||||||
|
func TranscriptEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx) error { |
||||||
|
return func(c *fiber.Ctx) error { |
||||||
|
m, input, err := readInput(c, o.Loader, false) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
|
||||||
|
config, input, err := readConfig(m, input, cm, o.Loader, o.Debug, o.Threads, o.ContextSize, o.F16) |
||||||
|
if err != nil { |
||||||
|
return fmt.Errorf("failed reading parameters from request:%w", err) |
||||||
|
} |
||||||
|
// retrieve the file data from the request
|
||||||
|
file, err := c.FormFile("file") |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
f, err := file.Open() |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
defer f.Close() |
||||||
|
|
||||||
|
dir, err := os.MkdirTemp("", "whisper") |
||||||
|
|
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
defer os.RemoveAll(dir) |
||||||
|
|
||||||
|
dst := filepath.Join(dir, path.Base(file.Filename)) |
||||||
|
dstFile, err := os.Create(dst) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if _, err := io.Copy(dstFile, f); err != nil { |
||||||
|
log.Debug().Msgf("Audio file copying error %+v - %+v - err %+v", file.Filename, dst, err) |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Audio file copied to: %+v", dst) |
||||||
|
|
||||||
|
whisperModel, err := o.Loader.BackendLoader( |
||||||
|
model.WithBackendString(model.WhisperBackend), |
||||||
|
model.WithModelFile(config.Model), |
||||||
|
model.WithThreads(uint32(config.Threads)), |
||||||
|
model.WithAssetDir(o.AssetsDestination)) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
if whisperModel == nil { |
||||||
|
return fmt.Errorf("could not load whisper model") |
||||||
|
} |
||||||
|
|
||||||
|
w, ok := whisperModel.(whisper.Model) |
||||||
|
if !ok { |
||||||
|
return fmt.Errorf("loader returned non-whisper object") |
||||||
|
} |
||||||
|
|
||||||
|
tr, err := whisperutil.Transcript(w, dst, input.Language, uint(config.Threads)) |
||||||
|
if err != nil { |
||||||
|
return err |
||||||
|
} |
||||||
|
|
||||||
|
log.Debug().Msgf("Trascribed: %+v", tr) |
||||||
|
// TODO: handle different outputs here
|
||||||
|
return c.Status(http.StatusOK).JSON(fiber.Map{"text": tr}) |
||||||
|
} |
||||||
|
} |
@ -1,415 +0,0 @@ |
|||||||
package api |
|
||||||
|
|
||||||
import ( |
|
||||||
"context" |
|
||||||
"fmt" |
|
||||||
"os" |
|
||||||
"path/filepath" |
|
||||||
"regexp" |
|
||||||
"strings" |
|
||||||
"sync" |
|
||||||
|
|
||||||
"github.com/donomii/go-rwkv.cpp" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/grpc" |
|
||||||
pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/langchain" |
|
||||||
model "github.com/go-skynet/LocalAI/pkg/model" |
|
||||||
"github.com/go-skynet/LocalAI/pkg/stablediffusion" |
|
||||||
"github.com/go-skynet/bloomz.cpp" |
|
||||||
bert "github.com/go-skynet/go-bert.cpp" |
|
||||||
) |
|
||||||
|
|
||||||
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
var mutexMap sync.Mutex |
|
||||||
var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex) |
|
||||||
|
|
||||||
func gRPCModelOpts(c Config) *pb.ModelOptions { |
|
||||||
b := 512 |
|
||||||
if c.Batch != 0 { |
|
||||||
b = c.Batch |
|
||||||
} |
|
||||||
return &pb.ModelOptions{ |
|
||||||
ContextSize: int32(c.ContextSize), |
|
||||||
Seed: int32(c.Seed), |
|
||||||
NBatch: int32(b), |
|
||||||
F16Memory: c.F16, |
|
||||||
MLock: c.MMlock, |
|
||||||
NUMA: c.NUMA, |
|
||||||
Embeddings: c.Embeddings, |
|
||||||
LowVRAM: c.LowVRAM, |
|
||||||
NGPULayers: int32(c.NGPULayers), |
|
||||||
MMap: c.MMap, |
|
||||||
MainGPU: c.MainGPU, |
|
||||||
Threads: int32(c.Threads), |
|
||||||
TensorSplit: c.TensorSplit, |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func gRPCPredictOpts(c Config, modelPath string) *pb.PredictOptions { |
|
||||||
promptCachePath := "" |
|
||||||
if c.PromptCachePath != "" { |
|
||||||
p := filepath.Join(modelPath, c.PromptCachePath) |
|
||||||
os.MkdirAll(filepath.Dir(p), 0755) |
|
||||||
promptCachePath = p |
|
||||||
} |
|
||||||
return &pb.PredictOptions{ |
|
||||||
Temperature: float32(c.Temperature), |
|
||||||
TopP: float32(c.TopP), |
|
||||||
TopK: int32(c.TopK), |
|
||||||
Tokens: int32(c.Maxtokens), |
|
||||||
Threads: int32(c.Threads), |
|
||||||
PromptCacheAll: c.PromptCacheAll, |
|
||||||
PromptCacheRO: c.PromptCacheRO, |
|
||||||
PromptCachePath: promptCachePath, |
|
||||||
F16KV: c.F16, |
|
||||||
DebugMode: c.Debug, |
|
||||||
Grammar: c.Grammar, |
|
||||||
|
|
||||||
Mirostat: int32(c.Mirostat), |
|
||||||
MirostatETA: float32(c.MirostatETA), |
|
||||||
MirostatTAU: float32(c.MirostatTAU), |
|
||||||
Debug: c.Debug, |
|
||||||
StopPrompts: c.StopWords, |
|
||||||
Repeat: int32(c.RepeatPenalty), |
|
||||||
NKeep: int32(c.Keep), |
|
||||||
Batch: int32(c.Batch), |
|
||||||
IgnoreEOS: c.IgnoreEOS, |
|
||||||
Seed: int32(c.Seed), |
|
||||||
FrequencyPenalty: float32(c.FrequencyPenalty), |
|
||||||
MLock: c.MMlock, |
|
||||||
MMap: c.MMap, |
|
||||||
MainGPU: c.MainGPU, |
|
||||||
TensorSplit: c.TensorSplit, |
|
||||||
TailFreeSamplingZ: float32(c.TFZ), |
|
||||||
TypicalP: float32(c.TypicalP), |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negative_prompt, dst string, loader *model.ModelLoader, c Config, o *Option) (func() error, error) { |
|
||||||
if c.Backend != model.StableDiffusionBackend { |
|
||||||
return nil, fmt.Errorf("endpoint only working with stablediffusion models") |
|
||||||
} |
|
||||||
|
|
||||||
inferenceModel, err := loader.BackendLoader( |
|
||||||
model.WithBackendString(c.Backend), |
|
||||||
model.WithAssetDir(o.assetsDestination), |
|
||||||
model.WithThreads(uint32(c.Threads)), |
|
||||||
model.WithModelFile(c.ImageGenerationAssets), |
|
||||||
) |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
var fn func() error |
|
||||||
switch model := inferenceModel.(type) { |
|
||||||
case *stablediffusion.StableDiffusion: |
|
||||||
fn = func() error { |
|
||||||
return model.GenerateImage(height, width, mode, step, seed, positive_prompt, negative_prompt, dst) |
|
||||||
} |
|
||||||
|
|
||||||
default: |
|
||||||
fn = func() error { |
|
||||||
return fmt.Errorf("creation of images not supported by the backend") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return func() error { |
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock() |
|
||||||
l, ok := mutexes[c.Backend] |
|
||||||
if !ok { |
|
||||||
m := &sync.Mutex{} |
|
||||||
mutexes[c.Backend] = m |
|
||||||
l = m |
|
||||||
} |
|
||||||
mutexMap.Unlock() |
|
||||||
l.Lock() |
|
||||||
defer l.Unlock() |
|
||||||
|
|
||||||
return fn() |
|
||||||
}, nil |
|
||||||
} |
|
||||||
|
|
||||||
func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config, o *Option) (func() ([]float32, error), error) { |
|
||||||
if !c.Embeddings { |
|
||||||
return nil, fmt.Errorf("endpoint disabled for this model by API configuration") |
|
||||||
} |
|
||||||
|
|
||||||
modelFile := c.Model |
|
||||||
|
|
||||||
grpcOpts := gRPCModelOpts(c) |
|
||||||
|
|
||||||
var inferenceModel interface{} |
|
||||||
var err error |
|
||||||
|
|
||||||
opts := []model.Option{ |
|
||||||
model.WithLoadGRPCOpts(grpcOpts), |
|
||||||
model.WithThreads(uint32(c.Threads)), |
|
||||||
model.WithAssetDir(o.assetsDestination), |
|
||||||
model.WithModelFile(modelFile), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Backend == "" { |
|
||||||
inferenceModel, err = loader.GreedyLoader(opts...) |
|
||||||
} else { |
|
||||||
opts = append(opts, model.WithBackendString(c.Backend)) |
|
||||||
inferenceModel, err = loader.BackendLoader(opts...) |
|
||||||
} |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
var fn func() ([]float32, error) |
|
||||||
switch model := inferenceModel.(type) { |
|
||||||
case *grpc.Client: |
|
||||||
fn = func() ([]float32, error) { |
|
||||||
predictOptions := gRPCPredictOpts(c, loader.ModelPath) |
|
||||||
if len(tokens) > 0 { |
|
||||||
embeds := []int32{} |
|
||||||
|
|
||||||
for _, t := range tokens { |
|
||||||
embeds = append(embeds, int32(t)) |
|
||||||
} |
|
||||||
predictOptions.EmbeddingTokens = embeds |
|
||||||
|
|
||||||
res, err := model.Embeddings(context.TODO(), predictOptions) |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
return res.Embeddings, nil |
|
||||||
} |
|
||||||
predictOptions.Embeddings = s |
|
||||||
|
|
||||||
res, err := model.Embeddings(context.TODO(), predictOptions) |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
return res.Embeddings, nil |
|
||||||
} |
|
||||||
|
|
||||||
// bert embeddings
|
|
||||||
case *bert.Bert: |
|
||||||
fn = func() ([]float32, error) { |
|
||||||
if len(tokens) > 0 { |
|
||||||
return model.TokenEmbeddings(tokens, bert.SetThreads(c.Threads)) |
|
||||||
} |
|
||||||
return model.Embeddings(s, bert.SetThreads(c.Threads)) |
|
||||||
} |
|
||||||
default: |
|
||||||
fn = func() ([]float32, error) { |
|
||||||
return nil, fmt.Errorf("embeddings not supported by the backend") |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return func() ([]float32, error) { |
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock() |
|
||||||
l, ok := mutexes[modelFile] |
|
||||||
if !ok { |
|
||||||
m := &sync.Mutex{} |
|
||||||
mutexes[modelFile] = m |
|
||||||
l = m |
|
||||||
} |
|
||||||
mutexMap.Unlock() |
|
||||||
l.Lock() |
|
||||||
defer l.Unlock() |
|
||||||
|
|
||||||
embeds, err := fn() |
|
||||||
if err != nil { |
|
||||||
return embeds, err |
|
||||||
} |
|
||||||
// Remove trailing 0s
|
|
||||||
for i := len(embeds) - 1; i >= 0; i-- { |
|
||||||
if embeds[i] == 0.0 { |
|
||||||
embeds = embeds[:i] |
|
||||||
} else { |
|
||||||
break |
|
||||||
} |
|
||||||
} |
|
||||||
return embeds, nil |
|
||||||
}, nil |
|
||||||
} |
|
||||||
|
|
||||||
func ModelInference(s string, loader *model.ModelLoader, c Config, o *Option, tokenCallback func(string) bool) (func() (string, error), error) { |
|
||||||
supportStreams := false |
|
||||||
modelFile := c.Model |
|
||||||
|
|
||||||
grpcOpts := gRPCModelOpts(c) |
|
||||||
|
|
||||||
var inferenceModel interface{} |
|
||||||
var err error |
|
||||||
|
|
||||||
opts := []model.Option{ |
|
||||||
model.WithLoadGRPCOpts(grpcOpts), |
|
||||||
model.WithThreads(uint32(c.Threads)), // GPT4all uses this
|
|
||||||
model.WithAssetDir(o.assetsDestination), |
|
||||||
model.WithModelFile(modelFile), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Backend == "" { |
|
||||||
inferenceModel, err = loader.GreedyLoader(opts...) |
|
||||||
} else { |
|
||||||
opts = append(opts, model.WithBackendString(c.Backend)) |
|
||||||
inferenceModel, err = loader.BackendLoader(opts...) |
|
||||||
} |
|
||||||
if err != nil { |
|
||||||
return nil, err |
|
||||||
} |
|
||||||
|
|
||||||
var fn func() (string, error) |
|
||||||
|
|
||||||
switch model := inferenceModel.(type) { |
|
||||||
case *rwkv.RwkvState: |
|
||||||
supportStreams = true |
|
||||||
|
|
||||||
fn = func() (string, error) { |
|
||||||
stopWord := "\n" |
|
||||||
if len(c.StopWords) > 0 { |
|
||||||
stopWord = c.StopWords[0] |
|
||||||
} |
|
||||||
|
|
||||||
if err := model.ProcessInput(s); err != nil { |
|
||||||
return "", err |
|
||||||
} |
|
||||||
|
|
||||||
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback) |
|
||||||
|
|
||||||
return response, nil |
|
||||||
} |
|
||||||
case *bloomz.Bloomz: |
|
||||||
fn = func() (string, error) { |
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []bloomz.PredictOption{ |
|
||||||
bloomz.SetTemperature(c.Temperature), |
|
||||||
bloomz.SetTopP(c.TopP), |
|
||||||
bloomz.SetTopK(c.TopK), |
|
||||||
bloomz.SetTokens(c.Maxtokens), |
|
||||||
bloomz.SetThreads(c.Threads), |
|
||||||
} |
|
||||||
|
|
||||||
if c.Seed != 0 { |
|
||||||
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed)) |
|
||||||
} |
|
||||||
|
|
||||||
return model.Predict( |
|
||||||
s, |
|
||||||
predictOptions..., |
|
||||||
) |
|
||||||
} |
|
||||||
|
|
||||||
case *grpc.Client: |
|
||||||
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
|
||||||
supportStreams = true |
|
||||||
fn = func() (string, error) { |
|
||||||
|
|
||||||
opts := gRPCPredictOpts(c, loader.ModelPath) |
|
||||||
opts.Prompt = s |
|
||||||
if tokenCallback != nil { |
|
||||||
ss := "" |
|
||||||
err := model.PredictStream(context.TODO(), opts, func(s string) { |
|
||||||
tokenCallback(s) |
|
||||||
ss += s |
|
||||||
}) |
|
||||||
return ss, err |
|
||||||
} else { |
|
||||||
reply, err := model.Predict(context.TODO(), opts) |
|
||||||
return reply.Message, err |
|
||||||
} |
|
||||||
} |
|
||||||
case *langchain.HuggingFace: |
|
||||||
fn = func() (string, error) { |
|
||||||
|
|
||||||
// Generate the prediction using the language model
|
|
||||||
predictOptions := []langchain.PredictOption{ |
|
||||||
langchain.SetModel(c.Model), |
|
||||||
langchain.SetMaxTokens(c.Maxtokens), |
|
||||||
langchain.SetTemperature(c.Temperature), |
|
||||||
langchain.SetStopWords(c.StopWords), |
|
||||||
} |
|
||||||
|
|
||||||
pred, er := model.PredictHuggingFace(s, predictOptions...) |
|
||||||
if er != nil { |
|
||||||
return "", er |
|
||||||
} |
|
||||||
return pred.Completion, nil |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
return func() (string, error) { |
|
||||||
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
||||||
mutexMap.Lock() |
|
||||||
l, ok := mutexes[modelFile] |
|
||||||
if !ok { |
|
||||||
m := &sync.Mutex{} |
|
||||||
mutexes[modelFile] = m |
|
||||||
l = m |
|
||||||
} |
|
||||||
mutexMap.Unlock() |
|
||||||
l.Lock() |
|
||||||
defer l.Unlock() |
|
||||||
|
|
||||||
res, err := fn() |
|
||||||
if tokenCallback != nil && !supportStreams { |
|
||||||
tokenCallback(res) |
|
||||||
} |
|
||||||
return res, err |
|
||||||
}, nil |
|
||||||
} |
|
||||||
|
|
||||||
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, o *Option, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) { |
|
||||||
result := []Choice{} |
|
||||||
|
|
||||||
n := input.N |
|
||||||
|
|
||||||
if input.N == 0 { |
|
||||||
n = 1 |
|
||||||
} |
|
||||||
|
|
||||||
// get the model function to call for the result
|
|
||||||
predFunc, err := ModelInference(predInput, loader, *config, o, tokenCallback) |
|
||||||
if err != nil { |
|
||||||
return result, err |
|
||||||
} |
|
||||||
|
|
||||||
for i := 0; i < n; i++ { |
|
||||||
prediction, err := predFunc() |
|
||||||
if err != nil { |
|
||||||
return result, err |
|
||||||
} |
|
||||||
|
|
||||||
prediction = Finetune(*config, predInput, prediction) |
|
||||||
cb(prediction, &result) |
|
||||||
|
|
||||||
//result = append(result, Choice{Text: prediction})
|
|
||||||
|
|
||||||
} |
|
||||||
return result, err |
|
||||||
} |
|
||||||
|
|
||||||
var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) |
|
||||||
var mu sync.Mutex = sync.Mutex{} |
|
||||||
|
|
||||||
func Finetune(config Config, input, prediction string) string { |
|
||||||
if config.Echo { |
|
||||||
prediction = input + prediction |
|
||||||
} |
|
||||||
|
|
||||||
for _, c := range config.Cutstrings { |
|
||||||
mu.Lock() |
|
||||||
reg, ok := cutstrings[c] |
|
||||||
if !ok { |
|
||||||
cutstrings[c] = regexp.MustCompile(c) |
|
||||||
reg = cutstrings[c] |
|
||||||
} |
|
||||||
mu.Unlock() |
|
||||||
prediction = reg.ReplaceAllString(prediction, "") |
|
||||||
} |
|
||||||
|
|
||||||
for _, c := range config.TrimSpace { |
|
||||||
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) |
|
||||||
} |
|
||||||
return prediction |
|
||||||
|
|
||||||
} |
|
Loading…
Reference in new issue