@@ -6,6 +6,7 @@ import (
 	"strings"
 	"sync"
 
+	"github.com/donomii/go-rwkv.cpp"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	gpt2 "github.com/go-skynet/go-gpt2.cpp"
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
@@ -13,6 +14,8 @@ import (
 	"github.com/hashicorp/go-multierror"
 )
 
+const tokenizerSuffix = ".tokenizer.json"
+
 // mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
 var mutexMap sync.Mutex
 var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
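The mutex map above exists because llama.cpp contexts are not safe for concurrent inference (see the linked discussion); one mutex per model file serializes calls. The locking site is outside this hunk, so the helper below is only a minimal sketch of how such a map is typically consumed, with a hypothetical lockModel name:

    package main

    import "sync"

    var mutexMap sync.Mutex
    var mutexes = map[string]*sync.Mutex{}

    // lockModel acquires (creating on first use) the per-model mutex and
    // returns the matching unlock, so callers can `defer lockModel(f)()`.
    // mutexMap only guards insertion into the mutexes map itself.
    func lockModel(modelFile string) func() {
        mutexMap.Lock()
        l, ok := mutexes[modelFile]
        if !ok {
            l = &sync.Mutex{}
            mutexes[modelFile] = l
        }
        mutexMap.Unlock()
        l.Lock()
        return l.Unlock
    }

    func main() {
        defer lockModel("model.bin")() // hold the per-model lock while inferring
    }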
@@ -20,7 +23,7 @@ var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
 var loadedModels map[string]interface{} = map[string]interface{}{}
 var muModels sync.Mutex
 
-func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
+func backendLoader(backendString string, loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
 	switch strings.ToLower(backendString) {
 	case "llama":
 		return loader.LoadLLaMAModel(modelFile, llamaOpts...)
@@ -30,12 +33,14 @@ func backendLoader(backendString string, loader *model.ModelLoader, modelFile st
 		return loader.LoadGPT2Model(modelFile)
 	case "gptj":
 		return loader.LoadGPTJModel(modelFile)
+	case "rwkv":
+		return loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
 	default:
 		return nil, fmt.Errorf("backend unsupported: %s", backendString)
 	}
 }
 
-func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption) (model interface{}, err error) {
+func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
 	updateModels := func(model interface{}) {
 		muModels.Lock()
 		defer muModels.Unlock()
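The new rwkv case derives its tokenizer path by appending tokenizerSuffix to the model path, so LoadRWKV expects the tokenizer JSON to sit right next to the weights. A small sketch of that convention, with a hypothetical model path:

    package main

    import "fmt"

    const tokenizerSuffix = ".tokenizer.json"

    func main() {
        // Hypothetical path: LoadRWKV looks for the sidecar tokenizer next
        // to the weights, at modelFile + tokenizerSuffix.
        modelFile := "/models/rwkv-4-raven-7B.bin"
        fmt.Println(modelFile + tokenizerSuffix)
        // prints: /models/rwkv-4-raven-7B.bin.tokenizer.json
    }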
@@ -82,6 +87,14 @@ func greedyLoader(loader *model.ModelLoader, modelFile string, llamaOpts []llama
 		err = multierror.Append(err, modelerr)
 	}
 
+	model, modelerr = loader.LoadRWKV(modelFile, modelFile+tokenizerSuffix, threads)
+	if modelerr == nil {
+		updateModels(model)
+		return model, nil
+	} else {
+		err = multierror.Append(err, modelerr)
+	}
+
 	return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
 }
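greedyLoader tries every backend in a fixed order and, when all of them fail, returns one aggregated error built with go-multierror, which is why the RWKV attempt above appends its failure rather than returning it. A compact sketch of that pattern, using a hypothetical loadFirst helper:

    package main

    import (
        "fmt"

        "github.com/hashicorp/go-multierror"
    )

    type attempt struct {
        name string
        load func() (interface{}, error)
    }

    // loadFirst mirrors the strategy: return the first model that loads, and
    // on total failure report every backend's error in one aggregated message.
    func loadFirst(attempts []attempt) (interface{}, error) {
        var err error
        for _, a := range attempts {
            m, e := a.load()
            if e == nil {
                return m, nil
            }
            err = multierror.Append(err, fmt.Errorf("%s: %w", a.name, e))
        }
        return nil, fmt.Errorf("could not load model - all backends returned error: %s", err.Error())
    }

    func main() {
        _, err := loadFirst([]attempt{
            {"llama", func() (interface{}, error) { return nil, fmt.Errorf("wrong magic") }},
            {"rwkv", func() (interface{}, error) { return nil, fmt.Errorf("tokenizer missing") }},
        })
        fmt.Println(err)
    }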
@@ -101,9 +114,9 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 	var inferenceModel interface{}
 	var err error
 	if c.Backend == "" {
-		inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts)
+		inferenceModel, err = greedyLoader(loader, modelFile, llamaOpts, uint32(c.Threads))
 	} else {
-		inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts)
+		inferenceModel, err = backendLoader(c.Backend, loader, modelFile, llamaOpts, uint32(c.Threads))
 	}
 	if err != nil {
 		return nil, err
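Both loader calls now forward c.Threads, and the rwkv case added below streams tokens through tokenCallback. Its func(string) bool shape is an assumption read off these call sites rather than a confirmed API; a hedged sketch of a callback that streams tokens to stdout:

    package main

    import "fmt"

    func main() {
        // Assumed shape: invoked once per generated token; returning true
        // lets generation continue (the rwkv case also stops at a stop word).
        tokenCallback := func(token string) bool {
            fmt.Print(token)
            return true
        }
        tokenCallback("hello ")
        tokenCallback("world\n")
    }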
@@ -112,6 +125,20 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 	var fn func() (string, error)
 	switch model := inferenceModel.(type) {
+	case *rwkv.RwkvState:
+		supportStreams = true
+		fn = func() (string, error) {
+			//model.ProcessInput("You are a chatbot that is very good at chatting. blah blah blah")
+			stopWord := "\n"
+			if len(c.StopWords) > 0 {
+				stopWord = c.StopWords[0]
+			}
+			response := model.GenerateResponse(c.Maxtokens, stopWord, float32(c.Temperature), float32(c.TopP), tokenCallback)
+			return response, nil
+		}
 	case *gpt2.StableLM:
 		fn = func() (string, error) {
 			// Generate the prediction using the language model