diff --git a/api.go b/api.go index 79e4e0f..98a6d3f 100644 --- a/api.go +++ b/api.go @@ -64,6 +64,7 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre input := new(struct { Messages []Message `json:"messages"` Model string `json:"model"` + Prompt string `json:"prompt"` }) if err := c.BodyParser(input); err != nil { return err @@ -126,12 +127,16 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre predInput := strings.Join(mess, "\n") - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := loader.TemplatePrefix(input.Model, struct { - Input string - }{Input: predInput}) - if err == nil { - predInput = templatedInput + if input.Prompt == "" { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := loader.TemplatePrefix(input.Model, struct { + Input string + }{Input: predInput}) + if err == nil { + predInput = templatedInput + } + } else { + predInput = input.Prompt + predInput } // Generate the prediction using the language model