@@ -21,15 +21,23 @@ type ModelLoader struct {
 	modelPath string
 	mu        sync.Mutex
-	models    map[string]*llama.LLama
-	gptmodels map[string]*gptj.GPTJ
-	gpt2models map[string]*gpt2.GPT2
+	models            map[string]*llama.LLama
+	gptmodels         map[string]*gptj.GPTJ
+	gpt2models        map[string]*gpt2.GPT2
+	gptstablelmmodels map[string]*gpt2.StableLM

 	promptsTemplates map[string]*template.Template
 }

 func NewModelLoader(modelPath string) *ModelLoader {
-	return &ModelLoader{modelPath: modelPath, gpt2models: make(map[string]*gpt2.GPT2), gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
+	return &ModelLoader{
+		modelPath:         modelPath,
+		gpt2models:        make(map[string]*gpt2.GPT2),
+		gptmodels:         make(map[string]*gptj.GPTJ),
+		gptstablelmmodels: make(map[string]*gpt2.StableLM),
+		models:            make(map[string]*llama.LLama),
+		promptsTemplates:  make(map[string]*template.Template),
+	}
 }

 func (ml *ModelLoader) ExistsInModelPath(s string) bool {
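Reviewer note: the loader keeps one cache map per backend, all guarded by the shared mutex, and the constructor now initializes the new StableLM map alongside the others. A minimal caller-side sketch of how this is meant to be used; the import path, model directory, and file name are assumptions for illustration, not part of this diff:

// Hypothetical usage sketch; adjust the import path to wherever ModelLoader lives.
package main

import (
	"fmt"

	model "example.com/yourproject/pkg/model" // assumed import path
)

func main() {
	loader := model.NewModelLoader("/models")

	// The first call loads the file via gpt2.NewStableLM and caches it in
	// gptstablelmmodels; repeated calls return the in-memory instance.
	m, err := loader.LoadStableLMModel("stablelm-3b.bin") // illustrative file name
	if err != nil {
		fmt.Println("loading StableLM model failed:", err)
		return
	}
	fmt.Printf("loaded %T\n", m)
}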
@@ -102,6 +110,38 @@ func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
 	return nil
 }

+func (ml *ModelLoader) LoadStableLMModel(modelName string) (*gpt2.StableLM, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	// Check if we already have a loaded model
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist")
+	}
+
+	if m, ok := ml.gptstablelmmodels[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
+		return m, nil
+	}
+
+	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.modelPath, modelName)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
+	model, err := gpt2.NewStableLM(modelFile)
+	if err != nil {
+		return nil, err
+	}
+
+	// If there is a prompt template, load it
+	if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
+		return nil, err
+	}
+
+	ml.gptstablelmmodels[modelName] = model
+	return model, err
+}
+
 func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
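The guards added below (and in the GPT-J and LLaMA loaders further down) return a distinctive error when the requested name is already cached under the StableLM backend, which suggests the calling API tries one loader after another until one succeeds. A sketch of that implied fall-through, reusing the hypothetical model import from the sketch above; the function name and any-typed return are illustrative, not part of this diff:

// loadAny is a hypothetical illustration of the fall-through pattern implied
// by the backend-mismatch errors ("this model is a GPT2 one",
// "this model is a GPTStableLM one", ...) returned by the loaders in this diff.
func loadAny(ml *model.ModelLoader, name string) (interface{}, error) {
	// Try the llama.cpp backend first; on any error fall through.
	if m, err := ml.LoadLLaMAModel(name); err == nil {
		return m, nil
	}
	// Then GPT-J, then GPT2, and finally the new StableLM loader.
	if m, err := ml.LoadGPTJModel(name); err == nil {
		return m, nil
	}
	if m, err := ml.LoadGPT2Model(name); err == nil {
		return m, nil
	}
	return ml.LoadStableLMModel(name)
}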
@@ -116,6 +156,13 @@ func (ml *ModelLoader) LoadGPT2Model(modelName string) (*gpt2.GPT2, error) {
 		return m, nil
 	}

+	// TODO: This needs refactoring, it's really bad to have it in here
+	// Check if a StableLM model is already loaded under this name - if so, return an error so the API falls back to the StableLM loader
+	if _, ok := ml.gptstablelmmodels[modelName]; ok {
+		log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
+		return nil, fmt.Errorf("this model is a GPTStableLM one")
+	}
+
 	// Load the model and keep it in memory for later use
 	modelFile := filepath.Join(ml.modelPath, modelName)
 	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
@@ -154,6 +201,10 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
 		log.Debug().Msgf("Model is GPT2: %s", modelName)
 		return nil, fmt.Errorf("this model is a GPT2 one")
 	}
+	if _, ok := ml.gptstablelmmodels[modelName]; ok {
+		log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
+		return nil, fmt.Errorf("this model is a GPTStableLM one")
+	}

 	// Load the model and keep it in memory for later use
 	modelFile := filepath.Join(ml.modelPath, modelName)
@@ -199,6 +250,10 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) {
 		log.Debug().Msgf("Model is GPT2: %s", modelName)
 		return nil, fmt.Errorf("this model is a GPT2 one")
 	}
+	if _, ok := ml.gptstablelmmodels[modelName]; ok {
+		log.Debug().Msgf("Model is GPTStableLM: %s", modelName)
+		return nil, fmt.Errorf("this model is a GPTStableLM one")
+	}

 	// Load the model and keep it in memory for later use
 	modelFile := filepath.Join(ml.modelPath, modelName)