From bbc44689086191bbcaaef9d3205bf8996a7aa655 Mon Sep 17 00:00:00 2001 From: mudler Date: Sun, 9 Jul 2023 10:02:09 +0200 Subject: [PATCH] Make functions more compatible with OpenAI specs --- api/openai.go | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/api/openai.go b/api/openai.go index 3d09f3c..41796ac 100644 --- a/api/openai.go +++ b/api/openai.go @@ -77,7 +77,7 @@ type Message struct { // The message role Role string `json:"role,omitempty" yaml:"role"` // The message content - Content string `json:"content,omitempty" yaml:"content"` + Content *string `json:"content" yaml:"content"` // A result of a function call FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` } @@ -392,7 +392,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool { resp := OpenAIResponse{ Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. - Choices: []Choice{{Delta: &Message{Content: s}, Index: 0}}, + Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}}, Object: "chat.completion.chunk", } log.Debug().Msgf("Sending goroutine: %s", s) @@ -460,12 +460,15 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { } } r := config.Roles[role] + contentExists := i.Content != nil && *i.Content != "" if r != "" { - content = fmt.Sprint(r, " ", i.Content) + if contentExists { + content = fmt.Sprint(r, " ", *i.Content) + } if i.FunctionCall != nil { j, err := json.Marshal(i.FunctionCall) if err == nil { - if i.Content != "" { + if contentExists { content += "\n" + fmt.Sprint(r, " ", string(j)) } else { content = fmt.Sprint(r, " ", string(j)) @@ -473,11 +476,13 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { } } } else { - content = i.Content + if contentExists { + content = fmt.Sprint(*i.Content) + } if i.FunctionCall != nil { j, err := json.Marshal(i.FunctionCall) if err == nil { - if i.Content != "" { + if contentExists { content += "\n" + string(j) } else { content = string(j) @@ -600,7 +605,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { message = Finetune(*config, predInput, message) log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: message}}) + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}}) return } } @@ -623,18 +628,18 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { } prediction = Finetune(*config, predInput, prediction) - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: prediction}}) + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}}) } else { // otherwise reply with the function call *c = append(*c, Choice{ FinishReason: "function_call", - Message: &Message{Role: "function", FunctionCall: ss}, + Message: &Message{Role: "assistant", FunctionCall: ss}, }) } return } - *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}}) + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}}) }, nil) if err != nil { return err