diff --git a/model_loader.go b/model_loader.go index 7c87079..1548f4f 100644 --- a/model_loader.go +++ b/model_loader.go @@ -62,6 +62,7 @@ func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (* if _, err := os.Stat(modelBin); os.IsNotExist(err) { return nil, err } else { + modelName = fmt.Sprintf("%s.bin", modelName) modelFile = modelBin } } @@ -87,7 +88,7 @@ func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (* if err != nil { return nil, err } - ml.promptsTemplates[modelFile] = tmpl + ml.promptsTemplates[modelName] = tmpl } ml.models[modelFile] = model