package api

import (
	"fmt"
	"regexp"
	"strings"
	"sync"

	model "github.com/go-skynet/LocalAI/pkg/model"
	gpt2 "github.com/go-skynet/go-gpt2.cpp"
	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
	llama "github.com/go-skynet/go-llama.cpp"
)

// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var (
	// mutexMap guards access to mutexes.
	mutexMap sync.Mutex
	// mutexes holds one mutex per model file name, created lazily.
	mutexes = make(map[string]*sync.Mutex)
)

// modelMutex returns the mutex registered for modelFile, creating and
// registering it on first use. The map access is serialized by mutexMap.
func modelMutex(modelFile string) *sync.Mutex {
	mutexMap.Lock()
	defer mutexMap.Unlock()

	l, ok := mutexes[modelFile]
	if !ok {
		l = &sync.Mutex{}
		mutexes[modelFile] = l
	}
	return l
}

// gpt2PredictOptions builds the prediction options shared by the GPT-2 and
// StableLM backends from c. Batch and Seed are only set when non-zero.
func gpt2PredictOptions(c Config) []gpt2.PredictOption {
	opts := []gpt2.PredictOption{
		gpt2.SetTemperature(c.Temperature),
		gpt2.SetTopP(c.TopP),
		gpt2.SetTopK(c.TopK),
		gpt2.SetTokens(c.Maxtokens),
		gpt2.SetThreads(c.Threads),
	}
	if c.Batch != 0 {
		opts = append(opts, gpt2.SetBatch(c.Batch))
	}
	if c.Seed != 0 {
		opts = append(opts, gpt2.SetSeed(c.Seed))
	}
	return opts
}

// gptjPredictOptions builds the prediction options for the GPT-J backend
// from c. Batch and Seed are only set when non-zero.
func gptjPredictOptions(c Config) []gptj.PredictOption {
	opts := []gptj.PredictOption{
		gptj.SetTemperature(c.Temperature),
		gptj.SetTopP(c.TopP),
		gptj.SetTopK(c.TopK),
		gptj.SetTokens(c.Maxtokens),
		gptj.SetThreads(c.Threads),
	}
	if c.Batch != 0 {
		opts = append(opts, gptj.SetBatch(c.Batch))
	}
	if c.Seed != 0 {
		opts = append(opts, gptj.SetSeed(c.Seed))
	}
	return opts
}

// llamaPredictOptions builds the prediction options for the LLaMA backend
// from c. Stop words are always passed; the remaining optional settings are
// only added when the corresponding config field is set (non-zero / true).
func llamaPredictOptions(c Config) []llama.PredictOption {
	opts := []llama.PredictOption{
		llama.SetTemperature(c.Temperature),
		llama.SetTopP(c.TopP),
		llama.SetTopK(c.TopK),
		llama.SetTokens(c.Maxtokens),
		llama.SetThreads(c.Threads),
	}
	if c.Debug {
		opts = append(opts, llama.Debug)
	}
	opts = append(opts, llama.SetStopWords(c.StopWords...))
	if c.RepeatPenalty != 0 {
		opts = append(opts, llama.SetPenalty(c.RepeatPenalty))
	}
	if c.Keep != 0 {
		opts = append(opts, llama.SetNKeep(c.Keep))
	}
	if c.Batch != 0 {
		opts = append(opts, llama.SetBatch(c.Batch))
	}
	if c.F16 {
		opts = append(opts, llama.EnableF16KV)
	}
	if c.IgnoreEOS {
		opts = append(opts, llama.IgnoreEOS)
	}
	if c.Seed != 0 {
		opts = append(opts, llama.SetSeed(c.Seed))
	}
	return opts
}

// ModelInference loads the model named by c.Model and returns a function that
// runs one prediction for the prompt s using the settings in c.
//
// The model file is tried against each backend in turn (LLaMA, GPT-J, GPT-2,
// StableLM); the first loader that succeeds wins. If every loader fails, the
// returned error reports all four failures.
//
// The returned function holds a per-model-file mutex for the duration of the
// prediction, so predictions against the same model file never run
// concurrently.
func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (string, error), error) {
	var (
		llamaModel    *llama.LLama
		gptjModel     *gptj.GPTJ
		gpt2Model     *gpt2.GPT2
		stableLMModel *gpt2.StableLM
	)

	modelFile := c.Model

	llamaOpts := []llama.ModelOption{}
	if c.ContextSize != 0 {
		llamaOpts = append(llamaOpts, llama.SetContext(c.ContextSize))
	}
	if c.F16 {
		llamaOpts = append(llamaOpts, llama.EnableF16Memory)
	}

	// Try to load the model.
	// TODO: this is ugly, better identifying the model somehow! however, it is
	// a good stab for a first implementation..
	var llamaerr, gpt2err, gptjerr, stableerr error
	llamaModel, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
	if llamaerr != nil {
		gptjModel, gptjerr = loader.LoadGPTJModel(modelFile)
		if gptjerr != nil {
			gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile)
			if gpt2err != nil {
				stableLMModel, stableerr = loader.LoadStableLMModel(modelFile)
				if stableerr != nil {
					// Every backend failed: surface all errors, not just the last one.
					return nil, fmt.Errorf("llama: %s gpt: %s gpt2: %s stableLM: %s", llamaerr.Error(), gptjerr.Error(), gpt2err.Error(), stableerr.Error())
				}
			}
		}
	}

	// Pick the prediction function for whichever backend produced a model.
	// Options are built inside the closure, i.e. on every prediction call.
	var fn func() (string, error)
	switch {
	case stableLMModel != nil:
		fn = func() (string, error) {
			return stableLMModel.Predict(s, gpt2PredictOptions(c)...)
		}
	case gpt2Model != nil:
		fn = func() (string, error) {
			return gpt2Model.Predict(s, gpt2PredictOptions(c)...)
		}
	case gptjModel != nil:
		fn = func() (string, error) {
			return gptjModel.Predict(s, gptjPredictOptions(c)...)
		}
	case llamaModel != nil:
		fn = func() (string, error) {
			return llamaModel.Predict(s, llamaPredictOptions(c)...)
		}
	}

	// Defensive: a loader that reported success without handing back a model
	// would otherwise leave fn nil and cause a nil-function panic on first use.
	if fn == nil {
		return nil, fmt.Errorf("no model instance available for %q", modelFile)
	}

	return func() (string, error) {
		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
		l := modelMutex(modelFile)
		l.Lock()
		defer l.Unlock()

		return fn()
	}, nil
}

// ComputeChoices runs the model input.N times (once when input.N is 0) on
// predInput and returns the collected choices.
//
// Each raw prediction is post-processed with Finetune and then handed to cb
// together with a pointer to the result slice; cb is responsible for adding
// the choice to that slice. On the first failing prediction the choices
// gathered so far are returned alongside the error.
func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice)) ([]Choice, error) {
	result := []Choice{}

	n := input.N
	if n == 0 {
		n = 1
	}

	// get the model function to call for the result
	predFunc, err := ModelInference(predInput, loader, *config)
	if err != nil {
		return result, err
	}

	for i := 0; i < n; i++ {
		prediction, err := predFunc()
		if err != nil {
			return result, err
		}

		prediction = Finetune(*config, predInput, prediction)
		cb(prediction, &result)
	}

	return result, nil
}

var (
	// cutstrings caches compiled cut-string patterns, keyed by pattern text.
	cutstrings = make(map[string]*regexp.Regexp)
	// mu guards access to cutstrings.
	mu sync.Mutex
)

// cutstringRegexp returns the compiled form of pattern, compiling and caching
// it in cutstrings on first use. It panics (via regexp.MustCompile) if pattern
// is not a valid regular expression; the deferred unlock guarantees mu is
// released even in that case, so a recovered panic cannot leave the cache
// locked forever.
func cutstringRegexp(pattern string) *regexp.Regexp {
	mu.Lock()
	defer mu.Unlock()

	reg, ok := cutstrings[pattern]
	if !ok {
		reg = regexp.MustCompile(pattern)
		cutstrings[pattern] = reg
	}
	return reg
}

// Finetune post-processes a raw prediction according to config:
//
//   - when config.Echo is set, input is prepended to the prediction;
//   - every match of each regular expression in config.Cutstrings is removed;
//   - for each entry in config.TrimSpace, that prefix is stripped (if present)
//     and the result is trimmed of surrounding whitespace.
//
// It panics if an entry of config.Cutstrings is not a valid regular expression.
func Finetune(config Config, input, prediction string) string {
	if config.Echo {
		prediction = input + prediction
	}

	for _, c := range config.Cutstrings {
		prediction = cutstringRegexp(c).ReplaceAllString(prediction, "")
	}

	for _, c := range config.TrimSpace {
		prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
	}

	return prediction
}