You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
160 lines
3.8 KiB
160 lines
3.8 KiB
package backend
|
|
|
|
import (
	"context"
	"fmt"
	"regexp"
	"strings"
	"sync"

	"github.com/donomii/go-rwkv.cpp"
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	"github.com/go-skynet/LocalAI/pkg/grpc"
	"github.com/go-skynet/LocalAI/pkg/langchain"
	model "github.com/go-skynet/LocalAI/pkg/model"
	"github.com/go-skynet/bloomz.cpp"
)
|
|
|
|
func ModelInference(s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
|
|
supportStreams := false
|
|
modelFile := c.Model
|
|
|
|
grpcOpts := gRPCModelOpts(c)
|
|
|
|
var inferenceModel interface{}
|
|
var err error
|
|
|
|
opts := []model.Option{
|
|
model.WithLoadGRPCOpts(grpcOpts),
|
|
model.WithThreads(uint32(c.Threads)), // GPT4all uses this
|
|
model.WithAssetDir(o.AssetsDestination),
|
|
model.WithModelFile(modelFile),
|
|
}
|
|
|
|
if c.Backend == "" {
|
|
inferenceModel, err = loader.GreedyLoader(opts...)
|
|
} else {
|
|
opts = append(opts, model.WithBackendString(c.Backend))
|
|
inferenceModel, err = loader.BackendLoader(opts...)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var fn func() (string, error)
|
|
|
|
switch model := inferenceModel.(type) {
|
|
case *rwkv.RwkvState:
|
|
supportStreams = true
|
|
|
|
fn = func() (string, error) {
|
|
stopWord := "\n"
|
|
if len(c.StopWords) > 0 {
|
|
stopWord = c.StopWords[0]
|
|
}
|
|
|
|
if err := model.ProcessInput(s); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
|
|
|
|
return response, nil
|
|
}
|
|
case *bloomz.Bloomz:
|
|
fn = func() (string, error) {
|
|
// Generate the prediction using the language model
|
|
predictOptions := []bloomz.PredictOption{
|
|
bloomz.SetTemperature(c.Temperature),
|
|
bloomz.SetTopP(c.TopP),
|
|
bloomz.SetTopK(c.TopK),
|
|
bloomz.SetTokens(c.Maxtokens),
|
|
bloomz.SetThreads(c.Threads),
|
|
}
|
|
|
|
if c.Seed != 0 {
|
|
predictOptions = append(predictOptions, bloomz.SetSeed(c.Seed))
|
|
}
|
|
|
|
return model.Predict(
|
|
s,
|
|
predictOptions...,
|
|
)
|
|
}
|
|
|
|
case *grpc.Client:
|
|
// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
|
|
supportStreams = true
|
|
fn = func() (string, error) {
|
|
|
|
opts := gRPCPredictOpts(c, loader.ModelPath)
|
|
opts.Prompt = s
|
|
if tokenCallback != nil {
|
|
ss := ""
|
|
err := model.PredictStream(context.TODO(), opts, func(s string) {
|
|
tokenCallback(s)
|
|
ss += s
|
|
})
|
|
return ss, err
|
|
} else {
|
|
reply, err := model.Predict(context.TODO(), opts)
|
|
return reply.Message, err
|
|
}
|
|
}
|
|
case *langchain.HuggingFace:
|
|
fn = func() (string, error) {
|
|
|
|
// Generate the prediction using the language model
|
|
predictOptions := []langchain.PredictOption{
|
|
langchain.SetModel(c.Model),
|
|
langchain.SetMaxTokens(c.Maxtokens),
|
|
langchain.SetTemperature(c.Temperature),
|
|
langchain.SetStopWords(c.StopWords),
|
|
}
|
|
|
|
pred, er := model.PredictHuggingFace(s, predictOptions...)
|
|
if er != nil {
|
|
return "", er
|
|
}
|
|
return pred.Completion, nil
|
|
}
|
|
}
|
|
|
|
return func() (string, error) {
|
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
l := Lock(modelFile)
|
|
defer l.Unlock()
|
|
|
|
res, err := fn()
|
|
if tokenCallback != nil && !supportStreams {
|
|
tokenCallback(res)
|
|
}
|
|
return res, err
|
|
}, nil
|
|
}
|
|
|
|
// Cache of compiled Cutstrings patterns, keyed by the raw pattern text.
// Every read or write of cutstrings must hold mu.
var (
	cutstrings = make(map[string]*regexp.Regexp)
	mu         sync.Mutex
)
|
|
|
|
func Finetune(config config.Config, input, prediction string) string {
|
|
if config.Echo {
|
|
prediction = input + prediction
|
|
}
|
|
|
|
for _, c := range config.Cutstrings {
|
|
mu.Lock()
|
|
reg, ok := cutstrings[c]
|
|
if !ok {
|
|
cutstrings[c] = regexp.MustCompile(c)
|
|
reg = cutstrings[c]
|
|
}
|
|
mu.Unlock()
|
|
prediction = reg.ReplaceAllString(prediction, "")
|
|
}
|
|
|
|
for _, c := range config.TrimSpace {
|
|
prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c))
|
|
}
|
|
return prediction
|
|
|
|
}
|
|
|