From 9fb581739bc9a81b1461c654379badb11a22aa24 Mon Sep 17 00:00:00 2001 From: mudler Date: Sat, 8 Apr 2023 10:46:51 +0200 Subject: [PATCH] Allow to template model prompts inputs --- api.go | 12 +++++++++-- model_loader.go | 55 +++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 8 deletions(-) diff --git a/api.go b/api.go index 18c833b..06ec8de 100644 --- a/api.go +++ b/api.go @@ -103,10 +103,18 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre mess = append(mess, i.Content) } - fmt.Println("Received", input, input.Model) + predInput := strings.Join(mess, "\n") + + templatedInput, err := loader.TemplatePrefix(input.Model, struct { + Input string + }{Input: predInput}) + if err == nil { + predInput = templatedInput + } + // Generate the prediction using the language model prediction, err := model.Predict( - strings.Join(mess, "\n"), + templatedInput, llama.SetTemperature(temperature), llama.SetTopP(topP), llama.SetTopK(topK), diff --git a/model_loader.go b/model_loader.go index 13c860f..7c87079 100644 --- a/model_loader.go +++ b/model_loader.go @@ -1,30 +1,55 @@ package main import ( + "bytes" "fmt" "os" "path/filepath" "sync" + "text/template" llama "github.com/go-skynet/go-llama.cpp" ) type ModelLoader struct { - modelPath string - mu sync.Mutex - models map[string]*llama.LLama + modelPath string + mu sync.Mutex + models map[string]*llama.LLama + promptsTemplates map[string]*template.Template } func NewModelLoader(modelPath string) *ModelLoader { - return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama)} + return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)} } -func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LLama, error) { +func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) { + ml.mu.Lock() + defer ml.mu.Unlock() + + m, ok := ml.promptsTemplates[modelName] + if !ok { + // try to find a s.bin + modelBin := fmt.Sprintf("%s.bin", modelName) + m, ok = ml.promptsTemplates[modelBin] + if !ok { + return "", fmt.Errorf("no prompt template available") + } + } + + var buf bytes.Buffer + + if err := m.Execute(&buf, in); err != nil { + return "", err + } + return buf.String(), nil +} + +func (ml *ModelLoader) LoadModel(modelName string, opts ...llama.ModelOption) (*llama.LLama, error) { ml.mu.Lock() defer ml.mu.Unlock() // Check if we already have a loaded model - modelFile := filepath.Join(ml.modelPath, s) + modelFile := filepath.Join(ml.modelPath, modelName) if m, ok := ml.models[modelFile]; ok { return m, nil @@ -47,6 +72,24 @@ func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LL return nil, err } + // If there is a prompt template, load it + + modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile) + // Check if the model path exists + if _, err := os.Stat(modelTemplateFile); err == nil { + dat, err := os.ReadFile(modelTemplateFile) + if err != nil { + return nil, err + } + + // Parse the template + tmpl, err := template.New("prompt").Parse(string(dat)) + if err != nil { + return nil, err + } + ml.promptsTemplates[modelFile] = tmpl + } + ml.models[modelFile] = model return model, err }