|
|
@ -1,6 +1,8 @@ |
|
|
|
package api |
|
|
|
package api |
|
|
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
import ( |
|
|
|
|
|
|
|
"encoding/json" |
|
|
|
|
|
|
|
"errors" |
|
|
|
"fmt" |
|
|
|
"fmt" |
|
|
|
"strings" |
|
|
|
"strings" |
|
|
|
"sync" |
|
|
|
"sync" |
|
|
@ -11,6 +13,7 @@ import ( |
|
|
|
"github.com/gofiber/fiber/v2" |
|
|
|
"github.com/gofiber/fiber/v2" |
|
|
|
"github.com/gofiber/fiber/v2/middleware/cors" |
|
|
|
"github.com/gofiber/fiber/v2/middleware/cors" |
|
|
|
"github.com/gofiber/fiber/v2/middleware/recover" |
|
|
|
"github.com/gofiber/fiber/v2/middleware/recover" |
|
|
|
|
|
|
|
"github.com/rs/zerolog/log" |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
type OpenAIResponse struct { |
|
|
|
type OpenAIResponse struct { |
|
|
@ -65,7 +68,7 @@ type OpenAIRequest struct { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// https://platform.openai.com/docs/api-reference/completions
|
|
|
|
// https://platform.openai.com/docs/api-reference/completions
|
|
|
|
func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error { |
|
|
|
func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error { |
|
|
|
return func(c *fiber.Ctx) error { |
|
|
|
return func(c *fiber.Ctx) error { |
|
|
|
var err error |
|
|
|
var err error |
|
|
|
var model *llama.LLama |
|
|
|
var model *llama.LLama |
|
|
@ -76,10 +79,23 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 |
|
|
|
if err := c.BodyParser(input); err != nil { |
|
|
|
if err := c.BodyParser(input); err != nil { |
|
|
|
return err |
|
|
|
return err |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
modelFile := input.Model |
|
|
|
|
|
|
|
received, _ := json.Marshal(input) |
|
|
|
|
|
|
|
|
|
|
|
if input.Model == "" { |
|
|
|
log.Debug().Msgf("Request received: %s", string(received)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Set model from bearer token, if available
|
|
|
|
|
|
|
|
bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") |
|
|
|
|
|
|
|
bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) |
|
|
|
|
|
|
|
if modelFile == "" && !bearerExists { |
|
|
|
return fmt.Errorf("no model specified") |
|
|
|
return fmt.Errorf("no model specified") |
|
|
|
} else { |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if bearerExists { // model specified in bearer token takes precedence
|
|
|
|
|
|
|
|
log.Debug().Msgf("Using model from bearer token: %s", bearer) |
|
|
|
|
|
|
|
modelFile = bearer |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Try to load the model with both
|
|
|
|
// Try to load the model with both
|
|
|
|
var llamaerr error |
|
|
|
var llamaerr error |
|
|
|
llamaOpts := []llama.ModelOption{} |
|
|
|
llamaOpts := []llama.ModelOption{} |
|
|
@ -90,31 +106,25 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 |
|
|
|
llamaOpts = append(llamaOpts, llama.EnableF16Memory) |
|
|
|
llamaOpts = append(llamaOpts, llama.EnableF16Memory) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
model, llamaerr = loader.LoadLLaMAModel(input.Model, llamaOpts...) |
|
|
|
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...) |
|
|
|
if llamaerr != nil { |
|
|
|
if llamaerr != nil { |
|
|
|
gptModel, err = loader.LoadGPTJModel(input.Model) |
|
|
|
gptModel, err = loader.LoadGPTJModel(modelFile) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
|
|
|
|
return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
|
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
|
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
|
|
if input.Model != "" { |
|
|
|
|
|
|
|
mutexMap.Lock() |
|
|
|
mutexMap.Lock() |
|
|
|
l, ok := mutexes[input.Model] |
|
|
|
l, ok := mutexes[modelFile] |
|
|
|
if !ok { |
|
|
|
if !ok { |
|
|
|
m := &sync.Mutex{} |
|
|
|
m := &sync.Mutex{} |
|
|
|
mutexes[input.Model] = m |
|
|
|
mutexes[modelFile] = m |
|
|
|
l = m |
|
|
|
l = m |
|
|
|
} |
|
|
|
} |
|
|
|
mutexMap.Unlock() |
|
|
|
mutexMap.Unlock() |
|
|
|
l.Lock() |
|
|
|
l.Lock() |
|
|
|
defer l.Unlock() |
|
|
|
defer l.Unlock() |
|
|
|
} else { |
|
|
|
|
|
|
|
defaultMutex.Lock() |
|
|
|
|
|
|
|
defer defaultMutex.Unlock() |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Set the parameters for the language model prediction
|
|
|
|
// Set the parameters for the language model prediction
|
|
|
|
topP := input.TopP |
|
|
|
topP := input.TopP |
|
|
@ -139,6 +149,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 |
|
|
|
predInput := input.Prompt |
|
|
|
predInput := input.Prompt |
|
|
|
if chat { |
|
|
|
if chat { |
|
|
|
mess := []string{} |
|
|
|
mess := []string{} |
|
|
|
|
|
|
|
// TODO: encode roles
|
|
|
|
for _, i := range input.Messages { |
|
|
|
for _, i := range input.Messages { |
|
|
|
mess = append(mess, i.Content) |
|
|
|
mess = append(mess, i.Content) |
|
|
|
} |
|
|
|
} |
|
|
@ -147,11 +158,12 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
|
|
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
|
|
|
|
templatedInput, err := loader.TemplatePrefix(input.Model, struct { |
|
|
|
templatedInput, err := loader.TemplatePrefix(modelFile, struct { |
|
|
|
Input string |
|
|
|
Input string |
|
|
|
}{Input: predInput}) |
|
|
|
}{Input: predInput}) |
|
|
|
if err == nil { |
|
|
|
if err == nil { |
|
|
|
predInput = templatedInput |
|
|
|
predInput = templatedInput |
|
|
|
|
|
|
|
log.Debug().Msgf("Template found, input modified to: %s", predInput) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
result := []Choice{} |
|
|
|
result := []Choice{} |
|
|
@ -223,8 +235,6 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
for i := 0; i < n; i++ { |
|
|
|
for i := 0; i < n; i++ { |
|
|
|
var prediction string |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prediction, err := predFunc() |
|
|
|
prediction, err := predFunc() |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
return err |
|
|
@ -241,30 +251,19 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
jsonResult, _ := json.Marshal(result) |
|
|
|
|
|
|
|
log.Debug().Msgf("Response: %s", jsonResult) |
|
|
|
|
|
|
|
|
|
|
|
// Return the prediction in the response body
|
|
|
|
// Return the prediction in the response body
|
|
|
|
return c.JSON(OpenAIResponse{ |
|
|
|
return c.JSON(OpenAIResponse{ |
|
|
|
Model: input.Model, |
|
|
|
Model: input.Model, // we have to return what the user sent here, due to OpenAI spec.
|
|
|
|
Choices: result, |
|
|
|
Choices: result, |
|
|
|
}) |
|
|
|
}) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error { |
|
|
|
func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error { |
|
|
|
app := fiber.New() |
|
|
|
return func(c *fiber.Ctx) error { |
|
|
|
|
|
|
|
|
|
|
|
// Default middleware config
|
|
|
|
|
|
|
|
app.Use(recover.New()) |
|
|
|
|
|
|
|
app.Use(cors.New()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
|
|
|
|
|
|
var mutex = &sync.Mutex{} |
|
|
|
|
|
|
|
mu := map[string]*sync.Mutex{} |
|
|
|
|
|
|
|
var mumutex = &sync.Mutex{} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// openAI compatible API endpoint
|
|
|
|
|
|
|
|
app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mutex, mumutex, mu)) |
|
|
|
|
|
|
|
app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mutex, mumutex, mu)) |
|
|
|
|
|
|
|
app.Get("/v1/models", func(c *fiber.Ctx) error { |
|
|
|
|
|
|
|
models, err := loader.ListModels() |
|
|
|
models, err := loader.ListModels() |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
return err |
|
|
@ -281,8 +280,48 @@ func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f |
|
|
|
Object: "list", |
|
|
|
Object: "list", |
|
|
|
Data: dataModels, |
|
|
|
Data: dataModels, |
|
|
|
}) |
|
|
|
}) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error { |
|
|
|
|
|
|
|
// Return errors as JSON responses
|
|
|
|
|
|
|
|
app := fiber.New(fiber.Config{ |
|
|
|
|
|
|
|
// Override default error handler
|
|
|
|
|
|
|
|
ErrorHandler: func(ctx *fiber.Ctx, err error) error { |
|
|
|
|
|
|
|
// Status code defaults to 500
|
|
|
|
|
|
|
|
code := fiber.StatusInternalServerError |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Retrieve the custom status code if it's a *fiber.Error
|
|
|
|
|
|
|
|
var e *fiber.Error |
|
|
|
|
|
|
|
if errors.As(err, &e) { |
|
|
|
|
|
|
|
code = e.Code |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Send custom error page
|
|
|
|
|
|
|
|
return ctx.Status(code).JSON(struct { |
|
|
|
|
|
|
|
Error string `json:"error"` |
|
|
|
|
|
|
|
}{Error: err.Error()}) |
|
|
|
|
|
|
|
}, |
|
|
|
}) |
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Default middleware config
|
|
|
|
|
|
|
|
app.Use(recover.New()) |
|
|
|
|
|
|
|
app.Use(cors.New()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
|
|
|
|
|
|
|
|
mu := map[string]*sync.Mutex{} |
|
|
|
|
|
|
|
var mumutex = &sync.Mutex{} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// openAI compatible API endpoint
|
|
|
|
|
|
|
|
app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu)) |
|
|
|
|
|
|
|
app.Post("/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu)) |
|
|
|
|
|
|
|
app.Post("/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.Get("/v1/models", listModels(loader)) |
|
|
|
|
|
|
|
app.Get("/models", listModels(loader)) |
|
|
|
|
|
|
|
|
|
|
|
// Start the server
|
|
|
|
// Start the server
|
|
|
|
app.Listen(listenAddr) |
|
|
|
app.Listen(listenAddr) |
|
|
|
return nil |
|
|
|
return nil |
|
|
|