From 4aa78843c0b15764491e3b12be2a8c0b9a1c0aa8 Mon Sep 17 00:00:00 2001 From: Robert Hambrock Date: Sun, 21 May 2023 15:24:04 +0200 Subject: [PATCH] fix: spec compliant instantiation and termination of streams (#341) --- api/openai.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/api/openai.go b/api/openai.go index dffdcbf..b97b4e5 100644 --- a/api/openai.go +++ b/api/openai.go @@ -259,10 +259,17 @@ func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { + initialMessage := OpenAIResponse{ + Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. + Choices: []Choice{{Delta: &Message{Role: "assistant"}}}, + Object: "chat.completion.chunk", + } + responses <- initialMessage + ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool { resp := OpenAIResponse{ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Role: "assistant", Content: s}}}, + Choices: []Choice{{Delta: &Message{Content: s}}}, Object: "chat.completion.chunk", } log.Debug().Msgf("Sending goroutine: %s", s) @@ -339,13 +346,11 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { enc := json.NewEncoder(&buf) enc.Encode(ev) - fmt.Fprintf(w, "event: data\n\n") - fmt.Fprintf(w, "data: %v\n\n", buf.String()) log.Debug().Msgf("Sending chunk: %s", buf.String()) + fmt.Fprintf(w, "data: %v\n", buf.String()) w.Flush() } - w.WriteString("event: data\n\n") resp := &OpenAIResponse{ Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []Choice{{FinishReason: "stop"}}, @@ -353,6 +358,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { respData, _ := json.Marshal(resp) w.WriteString(fmt.Sprintf("data: %s\n\n", respData)) + w.WriteString("data: [DONE]\n\n") w.Flush() })) return nil