feat: enhance API, expose more parameters (#24)

Signed-off-by: mudler <mudler@c3os.io>
add/first-example
Ettore Di Giacinto 2 years ago committed by GitHub
parent c37175271f
commit b062f3142b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 54
      api/api.go

@@ -28,7 +28,7 @@ type OpenAIResponse struct {
type Choice struct { type Choice struct {
Index int `json:"index,omitempty"` Index int `json:"index,omitempty"`
FinishReason string `json:"finish_reason,omitempty"` FinishReason string `json:"finish_reason,omitempty"`
Message Message `json:"message,omitempty"` Message *Message `json:"message,omitempty"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
} }
@@ -51,16 +51,25 @@ type OpenAIRequest struct {
// Messages is read only by chat/completion API calls // Messages is read only by chat/completion API calls
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Echo bool `json:"echo"`
// Common options between all the API calls // Common options between all the API calls
TopP float64 `json:"top_p"` TopP float64 `json:"top_p"`
TopK int `json:"top_k"` TopK int `json:"top_k"`
Temperature float64 `json:"temperature"` Temperature float64 `json:"temperature"`
Maxtokens int `json:"max_tokens"` Maxtokens int `json:"max_tokens"`
N int `json:"n"`
// Custom parameters - not present in the OpenAI API
Batch int `json:"batch"`
F16 bool `json:"f16kv"`
IgnoreEOS bool `json:"ignore_eos"`
} }
//go:embed index.html //go:embed index.html
var indexHTML embed.FS var indexHTML embed.FS
// https://platform.openai.com/docs/api-reference/completions
func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error { func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
var err error var err error
@@ -139,31 +148,58 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
predInput = templatedInput predInput = templatedInput
} }
result := []Choice{}
n := input.N
if input.N == 0 {
n = 1
}
for i := 0; i < n; i++ {
// Generate the prediction using the language model // Generate the prediction using the language model
prediction, err := model.Predict( predictOptions := []llama.PredictOption{
predInput,
llama.SetTemperature(temperature), llama.SetTemperature(temperature),
llama.SetTopP(topP), llama.SetTopP(topP),
llama.SetTopK(topK), llama.SetTopK(topK),
llama.SetTokens(tokens), llama.SetTokens(tokens),
llama.SetThreads(threads), llama.SetThreads(threads),
}
if input.Batch != 0 {
predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
}
if input.F16 {
predictOptions = append(predictOptions, llama.EnableF16KV)
}
if input.IgnoreEOS {
predictOptions = append(predictOptions, llama.IgnoreEOS)
}
prediction, err := model.Predict(
predInput,
predictOptions...,
) )
if err != nil { if err != nil {
return err return err
} }
if input.Echo {
prediction = predInput + prediction
}
if chat { if chat {
// Return the chat prediction in the response body result = append(result, Choice{Message: &Message{Role: "assistant", Content: prediction}})
return c.JSON(OpenAIResponse{ } else {
Model: input.Model, result = append(result, Choice{Text: prediction})
Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}}, }
})
} }
// Return the prediction in the response body // Return the prediction in the response body
return c.JSON(OpenAIResponse{ return c.JSON(OpenAIResponse{
Model: input.Model, Model: input.Model,
Choices: []Choice{{Text: prediction}}, Choices: result,
}) })
} }
} }

Loading…
Cancel
Save