@ -10,22 +10,86 @@ import (
gpt2 "github.com/go-skynet/go-gpt2.cpp"
gptj "github.com/go-skynet/go-gpt4all-j.cpp"
llama "github.com/go-skynet/go-llama.cpp"
"github.com/hashicorp/go-multierror"
)
// mutex still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var mutexMap sync . Mutex
var mutexes map [ string ] * sync . Mutex = make ( map [ string ] * sync . Mutex )
var loadedModels map [ string ] interface { } = map [ string ] interface { } { }
var muModels sync . Mutex
func backendLoader ( backendString string , loader * model . ModelLoader , modelFile string , llamaOpts [ ] llama . ModelOption ) ( model interface { } , err error ) {
switch strings . ToLower ( backendString ) {
case "llama" :
return loader . LoadLLaMAModel ( modelFile , llamaOpts ... )
case "stablelm" :
return loader . LoadStableLMModel ( modelFile )
case "gpt2" :
return loader . LoadGPT2Model ( modelFile )
case "gptj" :
return loader . LoadGPTJModel ( modelFile )
default :
return nil , fmt . Errorf ( "backend unsupported: %s" , backendString )
}
}
func greedyLoader ( loader * model . ModelLoader , modelFile string , llamaOpts [ ] llama . ModelOption ) ( model interface { } , err error ) {
updateModels := func ( model interface { } ) {
muModels . Lock ( )
defer muModels . Unlock ( )
loadedModels [ modelFile ] = model
}
muModels . Lock ( )
m , exists := loadedModels [ modelFile ]
if exists {
muModels . Unlock ( )
return m , nil
}
muModels . Unlock ( )
model , modelerr := loader . LoadLLaMAModel ( modelFile , llamaOpts ... )
if modelerr == nil {
updateModels ( model )
return model , nil
} else {
err = multierror . Append ( err , modelerr )
}
model , modelerr = loader . LoadGPTJModel ( modelFile )
if modelerr == nil {
updateModels ( model )
return model , nil
} else {
err = multierror . Append ( err , modelerr )
}
model , modelerr = loader . LoadGPT2Model ( modelFile )
if modelerr == nil {
updateModels ( model )
return model , nil
} else {
err = multierror . Append ( err , modelerr )
}
model , modelerr = loader . LoadStableLMModel ( modelFile )
if modelerr == nil {
updateModels ( model )
return model , nil
} else {
err = multierror . Append ( err , modelerr )
}
return nil , fmt . Errorf ( "could not load model - all backends returned error: %s" , err . Error ( ) )
}
func ModelInference ( s string , loader * model . ModelLoader , c Config , tokenCallback func ( string ) bool ) ( func ( ) ( string , error ) , error ) {
var model * llama . LLama
var gptModel * gptj . GPTJ
var gpt2Model * gpt2 . GPT2
var stableLMModel * gpt2 . StableLM
supportStreams := false
modelFile := c . Model
// Try to load the model
var llamaerr , gpt2err , gptjerr , stableerr error
llamaOpts := [ ] llama . ModelOption { }
if c . ContextSize != 0 {
llamaOpts = append ( llamaOpts , llama . SetContext ( c . ContextSize ) )
@ -34,25 +98,21 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
llamaOpts = append ( llamaOpts , llama . EnableF16Memory )
}
// TODO: this is ugly, better identifying the model somehow! however, it is a good stab for a first implementation..
model , llamaerr = loader . LoadLLaMAModel ( modelFile , llamaOpts ... )
if llamaerr != nil {
gptModel , gptjerr = loader . LoadGPTJModel ( modelFile )
if gptjerr != nil {
gpt2Model , gpt2err = loader . LoadGPT2Model ( modelFile )
if gpt2err != nil {
stableLMModel , stableerr = loader . LoadStableLMModel ( modelFile )
if stableerr != nil {
return nil , fmt . Errorf ( "llama: %s gpt: %s gpt2: %s stableLM: %s" , llamaerr . Error ( ) , gptjerr . Error ( ) , gpt2err . Error ( ) , stableerr . Error ( ) ) // llama failed first, so we want to catch both errors
}
}
var inferenceModel interface { }
var err error
if c . Backend == "" {
inferenceModel , err = greedyLoader ( loader , modelFile , llamaOpts )
} else {
inferenceModel , err = backendLoader ( c . Backend , loader , modelFile , llamaOpts )
}
if err != nil {
return nil , err
}
var fn func ( ) ( string , error )
switch {
case stableLMModel != nil :
switch model := inferenceModel . ( type ) {
case * gpt2 . StableLM :
fn = func ( ) ( string , error ) {
// Generate the prediction using the language model
predictOptions := [ ] gpt2 . PredictOption {
@ -71,12 +131,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
predictOptions = append ( predictOptions , gpt2 . SetSeed ( c . Seed ) )
}
return stableLMM odel. Predict (
return m odel. Predict (
s ,
predictOptions ... ,
)
}
case gpt2Model != nil :
case * gpt2 . GPT2 :
fn = func ( ) ( string , error ) {
// Generate the prediction using the language model
predictOptions := [ ] gpt2 . PredictOption {
@ -95,12 +155,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
predictOptions = append ( predictOptions , gpt2 . SetSeed ( c . Seed ) )
}
return gpt2M odel. Predict (
return m odel. Predict (
s ,
predictOptions ... ,
)
}
case gptModel != nil :
case * gptj . GPTJ :
fn = func ( ) ( string , error ) {
// Generate the prediction using the language model
predictOptions := [ ] gptj . PredictOption {
@ -119,12 +179,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
predictOptions = append ( predictOptions , gptj . SetSeed ( c . Seed ) )
}
return gptM odel. Predict (
return m odel. Predict (
s ,
predictOptions ... ,
)
}
case model != nil :
case * llama . LLama :
supportStreams = true
fn = func ( ) ( string , error ) {