From 220d6fd59b2c374629ea7ab6afcef7680a1809df Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Tue, 2 May 2023 20:03:35 +0200
Subject: [PATCH] feat: add stream events (#152)

---
 .github/workflows/image.yml |  4 +-
 api/openai.go               | 79 +++++++++++++++++++++----------------
 api/prediction.go           | 20 +++++++---
 3 files changed, 62 insertions(+), 41 deletions(-)

diff --git a/.github/workflows/image.yml b/.github/workflows/image.yml
index c159095..d83f58d 100644
--- a/.github/workflows/image.yml
+++ b/.github/workflows/image.yml
@@ -54,8 +54,8 @@ jobs:
         uses: docker/login-action@v2
         with:
           registry: quay.io
-          username: ${{ secrets.QUAY_USERNAME }}
-          password: ${{ secrets.QUAY_PASSWORD }}
+          username: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+          password: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
       - name: Build
         if: github.event_name != 'pull_request'
         uses: docker/build-push-action@v4
diff --git a/api/openai.go b/api/openai.go
index 80322eb..50aa503 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -2,6 +2,7 @@ package api
 
 import (
 	"bufio"
+	"bytes"
 	"encoding/json"
 	"fmt"
 	"os"
@@ -245,7 +246,7 @@ func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader,
 
 	result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
 		*c = append(*c, Choice{Text: s})
-	})
+	}, nil)
 	if err != nil {
 		return err
 	}
@@ -290,8 +291,9 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
 	if input.Stream {
 		log.Debug().Msgf("Stream request received")
+		c.Context().SetContentType("text/event-stream")
 		//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
-		c.Set("Content-Type", "text/event-stream; charset=utf-8")
+		// c.Set("Content-Type", "text/event-stream")
 		c.Set("Cache-Control", "no-cache")
 		c.Set("Connection", "keep-alive")
 		c.Set("Transfer-Encoding", "chunked")
 	}
@@ -312,53 +314,62 @@ func chatEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
 		log.Debug().Msgf("Template found, input modified to: %s", predInput)
 	}
 
-	result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
-		if input.Stream {
-			*c = append(*c, Choice{Delta: &Message{Role: "assistant", Content: s}})
-		} else {
-			*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
-		}
-	})
-	if err != nil {
-		return err
-	}
-
-	resp := &OpenAIResponse{
-		Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
-		Choices: result,
-		Object:  "chat.completion",
-	}
-
 	if input.Stream {
-		resp.Object = "chat.completion.chunk"
-		jsonResult, _ := json.Marshal(resp)
-		log.Debug().Msgf("Response: %s", jsonResult)
-
 		log.Debug().Msgf("Handling stream request")
+		responses := make(chan OpenAIResponse)
+
+		go func() {
+			ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
+				resp := OpenAIResponse{
+					Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+					Choices: []Choice{{Delta: &Message{Role: "assistant", Content: s}}},
+					Object:  "chat.completion.chunk",
+				}
+
+				responses <- resp
+				return true
+			})
+			close(responses)
+		}()
+
 		c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
-			fmt.Fprintf(w, "event: data\n")
-			w.Flush()
-
-			fmt.Fprintf(w, "data: %s\n\n", jsonResult)
-			w.Flush()
+			for ev := range responses {
+				var buf bytes.Buffer
+				enc := json.NewEncoder(&buf)
+				enc.Encode(ev)
 
-			fmt.Fprintf(w, "event: data\n")
-			w.Flush()
+				fmt.Fprintf(w, "event: data\n\n")
+				fmt.Fprintf(w, "data: %v\n\n", buf.String())
+				log.Debug().Msgf("Sending chunk: %s", buf.String())
+				w.Flush()
+			}
 
+			w.WriteString("event: data\n\n")
 			resp := &OpenAIResponse{
 				Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
 				Choices: []Choice{{FinishReason: "stop"}},
 			}
			respData, _ := json.Marshal(resp)
 
-			fmt.Fprintf(w, "data: %s\n\n", respData)
+			w.WriteString(fmt.Sprintf("data: %s\n\n", respData))
 			w.Flush()
-
-			// fmt.Fprintf(w, "data: [DONE]\n\n")
-			// w.Flush()
 		}))
 		return nil
 	}
 
+	result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
+		*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
+	}, nil)
+	if err != nil {
+		return err
+	}
+
+	resp := &OpenAIResponse{
+		Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
+		Choices: result,
+		Object:  "chat.completion",
+	}
+
 	// Return the prediction in the response body
 	return c.JSON(resp)
 }
@@ -392,7 +403,7 @@ func editEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, thread
 
 	result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) {
 		*c = append(*c, Choice{Text: s})
-	})
+	}, nil)
 	if err != nil {
 		return err
 	}
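The api/openai.go rewrite above boils down to one pattern: a producer goroutine pushes a per-token OpenAIResponse chunk into a channel, and the fasthttp body-stream writer drains that channel, flushing one SSE frame per chunk. Below is a minimal, stdlib-only sketch of that pattern for reference; the chunk struct and the use of os.Stdout in place of the HTTP response body are illustrative stand-ins, not LocalAI's actual types.

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"os"
)

// chunk is a stand-in for the OpenAIResponse chunks the patch streams.
type chunk struct {
	Model   string `json:"model"`
	Content string `json:"content"`
	Object  string `json:"object"`
}

func main() {
	responses := make(chan chunk)

	// Producer: in the patch this is the goroutine wrapping ComputeChoices,
	// whose token callback sends one chunk per generated token.
	go func() {
		for _, tok := range []string{"Hello", ",", " world"} {
			responses <- chunk{Model: "example-model", Content: tok, Object: "chat.completion.chunk"}
		}
		close(responses) // closing the channel ends the writer loop below
	}()

	// Consumer: in the patch this body runs inside SetBodyStreamWriter;
	// os.Stdout stands in for the HTTP response body here.
	w := bufio.NewWriter(os.Stdout)
	for ev := range responses {
		data, _ := json.Marshal(ev)
		fmt.Fprintf(w, "data: %s\n\n", data) // one SSE frame per chunk
		w.Flush()                            // flush so the client sees it immediately
	}
	w.Flush()
}

Closing the channel once ComputeChoices returns is what terminates the writer's range loop, which in turn lets the handler append the final chunk with FinishReason "stop" before the body ends.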
diff --git a/api/prediction.go b/api/prediction.go
index 65cfce9..4f01abb 100644
--- a/api/prediction.go
+++ b/api/prediction.go
@@ -16,12 +16,12 @@ import (
 var mutexMap sync.Mutex
 var mutexes map[string]*sync.Mutex = make(map[string]*sync.Mutex)
 
-func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (string, error), error) {
+func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback func(string) bool) (func() (string, error), error) {
 	var model *llama.LLama
 	var gptModel *gptj.GPTJ
 	var gpt2Model *gpt2.GPT2
 	var stableLMModel *gpt2.StableLM
-
+	supportStreams := false
 	modelFile := c.Model
 
 	// Try to load the model
@@ -125,7 +125,13 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
 			)
 		}
 	case model != nil:
+		supportStreams = true
 		fn = func() (string, error) {
+
+			if tokenCallback != nil {
+				model.SetTokenCallback(tokenCallback)
+			}
+
 			// Generate the prediction using the language model
 			predictOptions := []llama.PredictOption{
 				llama.SetTemperature(c.Temperature),
@@ -185,11 +191,15 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri
 		l.Lock()
 		defer l.Unlock()
 
-		return fn()
+		res, err := fn()
+		if tokenCallback != nil && !supportStreams {
+			tokenCallback(res)
+		}
+		return res, err
 	}, nil
 }
 
-func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice)) ([]Choice, error) {
+func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, loader *model.ModelLoader, cb func(string, *[]Choice), tokenCallback func(string) bool) ([]Choice, error) {
 	result := []Choice{}
 
 	n := input.N
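The hunks above thread a tokenCallback through ModelInference. Only the llama backend streams natively (it sets supportStreams and registers the callback via model.SetTokenCallback); every other backend falls back to firing the callback once, with the complete result, after inference finishes. A minimal sketch of that fallback under stand-in names (wrapInference and predict are hypothetical, not LocalAI's API):

package main

import "fmt"

// wrapInference mirrors the closure returned at the bottom of ModelInference:
// run the prediction, and if the backend could not stream, deliver the whole
// completion to the callback as a single late "token".
func wrapInference(fn func() (string, error), tokenCallback func(string) bool, supportStreams bool) func() (string, error) {
	return func() (string, error) {
		res, err := fn()
		if tokenCallback != nil && !supportStreams {
			tokenCallback(res)
		}
		return res, err
	}
}

func main() {
	predict := func() (string, error) { return "full completion", nil }
	// A non-streaming backend: the callback fires once, after the fact.
	pred := wrapInference(predict, func(s string) bool {
		fmt.Println("callback received:", s)
		return true
	}, false)
	pred()
}

The upshot is that callers can pass one callback regardless of backend: streaming backends invoke it per token, the rest invoke it once, so the SSE path degrades to a single chunk instead of breaking.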
@@ -199,7 +209,7 @@ func ComputeChoices(predInput string, input *OpenAIRequest, config *Config, load
 	}
 
 	// get the model function to call for the result
-	predFunc, err := ModelInference(predInput, loader, *config)
+	predFunc, err := ModelInference(predInput, loader, *config, tokenCallback)
 	if err != nil {
 		return result, err
 	}
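ComputeChoices itself only forwards the new parameter to ModelInference. The completion and edit endpoints pass nil (no streaming), while the chat endpoint supplies a callback that returns true to keep generation going. A hypothetical stub sketching the two call styles from the caller's side; computeChoices and its token source here are illustrative only:

package main

import "fmt"

// computeChoices copies the shape of the new signature, not its implementation.
func computeChoices(prompt string, cb func(string, *[]string), tokenCallback func(string) bool) []string {
	result := []string{}
	full := ""
	for _, tok := range []string{"Hello", " ", "world"} { // stand-in token source
		full += tok
		// Streaming callers see each token; returning false stops generation.
		if tokenCallback != nil && !tokenCallback(tok) {
			break
		}
	}
	cb(full, &result)
	return result
}

func main() {
	// Non-streaming endpoints (completion, edit) pass nil, as the patch does.
	out := computeChoices("hi", func(s string, c *[]string) { *c = append(*c, s) }, nil)
	fmt.Println(out)

	// The streaming chat endpoint passes a callback that forwards each token.
	computeChoices("hi", func(string, *[]string) {}, func(tok string) bool {
		fmt.Print(tok)
		return true
	})
	fmt.Println()
}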