From 714bfcd45bd2e2d53a4a280e0f02ed83bd8ff9f3 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 4 May 2023 19:49:43 +0200
Subject: [PATCH] fix: missing returning error and free callback stream (#187)

---
 api/openai.go       | 29 ++++++++++++++++-------------
 api/prediction.go   |  7 ++++++-
 pkg/model/loader.go |  3 +--
 3 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/api/openai.go b/api/openai.go
index 1afbb06..02e4093 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -299,6 +299,21 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
 }
 
 func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error {
+
+	process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
+		ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+			resp := OpenAIResponse{
+				Model:   req.Model, // we have to return what the user sent here, due to OpenAI spec.
+				Choices: []Choice{{Delta: &Message{Role: "assistant", Content: s}}},
+				Object:  "chat.completion.chunk",
+			}
+			log.Debug().Msgf("Sending goroutine: %s", s)
+
+			responses <- resp
+			return true
+		})
+		close(responses)
+	}
 	return func(c *fiber.Ctx) error {
 		config, input, err := readConfig(cm, c, loader, debug, threads, ctx, f16)
 		if err != nil {
@@ -350,19 +365,7 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
 		if input.Stream {
 			responses := make(chan OpenAIResponse)
 
-			go func() {
-				ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
-					resp := OpenAIResponse{
-						Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
-						Choices: []Choice{{Delta: &Message{Role: "assistant", Content: s}}},
-						Object:  "chat.completion.chunk",
-					}
-
-					responses <- resp
-					return true
-				})
-				close(responses)
-			}()
+			go process(predInput, input, config, loader, responses)
 
 			c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
 
diff --git a/api/prediction.go b/api/prediction.go
index 1fbb57b..5f6a8e2 100644
--- a/api/prediction.go
+++ b/api/prediction.go
@@ -261,10 +261,15 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 		predictOptions = append(predictOptions, llama.SetSeed(c.Seed))
 	}
 
-	return model.Predict(
+	str, er := model.Predict(
 		s,
 		predictOptions...,
 	)
+	// Seems that if we don't free the callback explicitly we leave functions registered (that might try to send on closed channels)
+	// For instance otherwise the API returns: {"error":{"code":500,"message":"send on closed channel","type":""}}
+	// after a stream event has occurred
+	model.SetTokenCallback(nil)
+	return str, er
 }
 }
 
diff --git a/pkg/model/loader.go b/pkg/model/loader.go
index 7037e86..7a22b71 100644
--- a/pkg/model/loader.go
+++ b/pkg/model/loader.go
@@ -81,10 +81,9 @@ func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string,
 		if exists {
 			m = t
 		}
-
 	}
 	if m == nil {
-		return "", nil
+		return "", fmt.Errorf("failed loading any template")
 	}
 
 	var buf bytes.Buffer
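
Note (commentary appended below the patch, not part of the applied diff): the
api/prediction.go hunk works around a token callback that stays registered in
the binding after a streamed response finishes. The stand-alone sketch below
reproduces that failure mode in isolation; the names tokenCallback and predict
are illustrative stand-ins and not the actual go-llama binding API. Only the
idea of clearing the registered callback, which model.SetTokenCallback(nil)
does in the hunk above, comes from the patch itself.

    package main

    import "fmt"

    // tokenCallback stands in for a callback registry that outlives a
    // single Predict call, as the comment in prediction.go describes.
    var tokenCallback func(string) bool

    // predict registers cb (when one is given) and then streams a token
    // through whatever callback is currently registered.
    func predict(cb func(string) bool) {
    	if cb != nil {
    		tokenCallback = cb
    	}
    	if tokenCallback != nil {
    		tokenCallback("token")
    	}
    }

    func main() {
    	responses := make(chan string)
    	done := make(chan struct{})
    	go func() {
    		for r := range responses {
    			fmt.Println("chunk:", r)
    		}
    		close(done)
    	}()

    	// First request: stream tokens into the channel, then close it,
    	// mirroring close(responses) in chatEndpoint's process func.
    	predict(func(s string) bool { responses <- s; return true })
    	close(responses)
    	<-done

    	// The callback is still registered. Uncommenting the next line is
    	// the equivalent of model.SetTokenCallback(nil) and prevents the
    	// panic below.
    	// tokenCallback = nil

    	// Second request: the stale callback sends on the closed channel.
    	defer func() {
    		if r := recover(); r != nil {
    			fmt.Println("recovered:", r) // "send on closed channel"
    		}
    	}()
    	predict(nil)
    }

As written it prints "chunk: token" followed by "recovered: send on closed
channel", the same error the API surfaced as a 500 before this fix; with the
callback cleared between requests, the second predict is a harmless no-op.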