From 67992a7d99e30c09054f5e1c0f253b3cd1b9c492 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Wed, 3 May 2023 13:13:31 +0200 Subject: [PATCH] feat: support slices or strings in the prompt completion endpoint (#162) Signed-off-by: mudler --- api/openai.go | 56 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/api/openai.go b/api/openai.go index 63b2b32..c1d4001 100644 --- a/api/openai.go +++ b/api/openai.go @@ -57,7 +57,7 @@ type OpenAIRequest struct { Model string `json:"model" yaml:"model"` // Prompt is read only by completion API calls - Prompt string `json:"prompt" yaml:"prompt"` + Prompt interface{} `json:"prompt" yaml:"prompt"` // Edit endpoint Instruction string `json:"instruction" yaml:"instruction"` @@ -122,9 +122,12 @@ func updateConfig(config *Config, input *OpenAIRequest) { if stop != "" { config.StopWords = append(config.StopWords, stop) } - case []string: - config.StopWords = append(config.StopWords, stop...) - + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } } if input.RepeatPenalty != 0 { @@ -234,27 +237,44 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, log.Debug().Msgf("Parameter Config: %+v", config) - predInput := input.Prompt + predInput := []string{} + + switch p := input.Prompt.(type) { + case string: + predInput = append(predInput, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + predInput = append(predInput, s) + } + } + } + templateFile := config.Model if config.TemplateConfig.Completion != "" { templateFile = config.TemplateConfig.Completion } - // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := loader.TemplatePrefix(templateFile, struct { - Input string - }{Input: predInput}) - if err == nil { - predInput = templatedInput - log.Debug().Msgf("Template found, input modified to: %s", predInput) - } + var result []Choice + for _, i := range predInput { + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix + templatedInput, err := loader.TemplatePrefix(templateFile, struct { + Input string + }{Input: i}) + if err == nil { + i = templatedInput + log.Debug().Msgf("Template found, input modified to: %s", i) + } - result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) { - *c = append(*c, Choice{Text: s}) - }, nil) - if err != nil { - return err + r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) { + *c = append(*c, Choice{Text: s}) + }, nil) + if err != nil { + return err + } + + result = append(result, r...) } resp := &OpenAIResponse{