From 96794851b3a638b37ca9041efe03df1cae67bd5b Mon Sep 17 00:00:00 2001
From: Samuel Maynard
Date: Fri, 2 Jun 2023 17:27:03 -0500
Subject: [PATCH] feat: add support for `Stream: true` to completionEndpoint
 (#465)

---
 api/openai.go | 74 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 73 insertions(+), 1 deletion(-)

diff --git a/api/openai.go b/api/openai.go
index b97b4e5..cb93510 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -4,6 +4,7 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/base64"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
@@ -143,13 +144,29 @@ func defaultRequest(modelFile string) OpenAIRequest {
 }
 
 // https://platform.openai.com/docs/api-reference/completions
 func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
-	return func(c *fiber.Ctx) error {
+	process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
+		ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+			resp := OpenAIResponse{
+				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
+				Choices: []Choice{{Text: s}},
+				Object:  "text_completion",
+			}
+			log.Debug().Msgf("Sending goroutine: %s", s)
+
+			responses <- resp
+			return true
+		})
+		close(responses)
+	}
+	return func(c *fiber.Ctx) error {
 		model, input, err := readInput(c, o.loader, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
 
+		log.Debug().Msgf("`input`: %+v", input)
+
 		config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -157,12 +174,67 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 
 		log.Debug().Msgf("Parameter Config: %+v", config)
 
+		if input.Stream {
+			log.Debug().Msgf("Stream request received")
+			c.Context().SetContentType("text/event-stream")
+			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
+			//c.Set("Content-Type", "text/event-stream")
+			c.Set("Cache-Control", "no-cache")
+			c.Set("Connection", "keep-alive")
+			c.Set("Transfer-Encoding", "chunked")
+		}
+
 		templateFile := config.Model
 
 		if config.TemplateConfig.Completion != "" {
 			templateFile = config.TemplateConfig.Completion
 		}
 
+		if input.Stream {
+			if len(config.PromptStrings) > 1 {
+				return errors.New("cannot handle more than 1 `PromptStrings` when `Stream`ing")
+			}
+
+			predInput := config.PromptStrings[0]
+
+			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
+			templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
+				Input string
+			}{Input: predInput})
+			if err == nil {
+				predInput = templatedInput
+				log.Debug().Msgf("Template found, input modified to: %s", predInput)
+			}
+
+			responses := make(chan OpenAIResponse)
+
+			go process(predInput, input, config, o.loader, responses)
+
+			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
+
+				for ev := range responses {
+					var buf bytes.Buffer
+					enc := json.NewEncoder(&buf)
+					enc.Encode(ev)
+
+					log.Debug().Msgf("Sending chunk: %s", buf.String())
+					fmt.Fprintf(w, "data: %v\n", buf.String())
+					w.Flush()
+				}
+
+				resp := &OpenAIResponse{
+					Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+					Choices: []Choice{{FinishReason: "stop"}},
+				}
+				respData, _ := json.Marshal(resp)
+
+				w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
+				w.WriteString("data: [DONE]\n\n")
+				w.Flush()
+			}))
+			return nil
+		}
+
 		var result []Choice
 		for _, i := range config.PromptStrings {
 			// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
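
Note (not part of the patch): a minimal Go client sketch for exercising the new streaming path once this change is applied. The base URL, the /v1/completions route, and the model name "gpt-3.5-turbo" are assumptions about the deployment, not something defined by this patch; adjust them to your instance. The request body mirrors the OpenAI completions API with "stream": true, and the client simply prints each SSE "data:" chunk until the final "data: [DONE]" marker written by the handler above.

package main

import (
	"bufio"
	"bytes"
	"fmt"
	"log"
	"net/http"
)

func main() {
	// "stream": true selects the SSE branch added in completionEndpoint.
	// Model name and prompt are placeholders.
	body := []byte(`{"model": "gpt-3.5-turbo", "prompt": "Once upon a time", "stream": true}`)

	// Base URL is an assumption (default LocalAI port); change it for your setup.
	resp, err := http.Post("http://localhost:8080/v1/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()

	// Each completion chunk arrives as a "data: {...}" line; the stream ends
	// with a final "stop" chunk followed by "data: [DONE]".
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		if line := scanner.Text(); line != "" {
			fmt.Println(line)
		}
	}
	if err := scanner.Err(); err != nil {
		log.Fatal(err)
	}
}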