From b062f3142bd28e89750639ffd71c36edef34c493 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 16 Apr 2023 10:16:48 +0200 Subject: [PATCH] feat: enhance API, expose more parameters (#24) Signed-off-by: mudler --- api/api.go | 82 +++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/api/api.go b/api/api.go index 8892fbb..1a13bb7 100644 --- a/api/api.go +++ b/api/api.go @@ -26,10 +26,10 @@ type OpenAIResponse struct { } type Choice struct { - Index int `json:"index,omitempty"` - FinishReason string `json:"finish_reason,omitempty"` - Message Message `json:"message,omitempty"` - Text string `json:"text,omitempty"` + Index int `json:"index,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Message *Message `json:"message,omitempty"` + Text string `json:"text,omitempty"` } type Message struct { @@ -47,20 +47,29 @@ type OpenAIRequest struct { // Prompt is read only by completion API calls Prompt string `json:"prompt"` - + // Messages is read only by chat/completion API calls Messages []Message `json:"messages"` + Echo bool `json:"echo"` // Common options between all the API calls TopP float64 `json:"top_p"` TopK int `json:"top_k"` Temperature float64 `json:"temperature"` Maxtokens int `json:"max_tokens"` + + N int `json:"n"` + + // Custom parameters - not present in the OpenAI API + Batch int `json:"batch"` + F16 bool `json:"f16kv"` + IgnoreEOS bool `json:"ignore_eos"` } //go:embed index.html var indexHTML embed.FS +// https://platform.openai.com/docs/api-reference/completions func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { var err error @@ -139,31 +148,58 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa predInput = templatedInput } - // Generate the prediction using the language model - prediction, err := model.Predict( - predInput, - llama.SetTemperature(temperature), - llama.SetTopP(topP), - llama.SetTopK(topK), - llama.SetTokens(tokens), - llama.SetThreads(threads), - ) - if err != nil { - return err + result := []Choice{} + + n := input.N + + if input.N == 0 { + n = 1 } - if chat { - // Return the chat prediction in the response body - return c.JSON(OpenAIResponse{ - Model: input.Model, - Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}}, - }) + for i := 0; i < n; i++ { + // Generate the prediction using the language model + predictOptions := []llama.PredictOption{ + llama.SetTemperature(temperature), + llama.SetTopP(topP), + llama.SetTopK(topK), + llama.SetTokens(tokens), + llama.SetThreads(threads), + } + + if input.Batch != 0 { + predictOptions = append(predictOptions, llama.SetBatch(input.Batch)) + } + + if input.F16 { + predictOptions = append(predictOptions, llama.EnableF16KV) + } + + if input.IgnoreEOS { + predictOptions = append(predictOptions, llama.IgnoreEOS) + } + + prediction, err := model.Predict( + predInput, + predictOptions..., + ) + if err != nil { + return err + } + + if input.Echo { + prediction = predInput + prediction + } + if chat { + result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}}) + } else { + result = append(result, Choice{Text: prediction}) + } } // Return the prediction in the response body return c.JSON(OpenAIResponse{ Model: input.Model, - Choices: []Choice{{Text: prediction}}, + Choices: result, }) } }