diff --git a/api.go b/api.go index 4a1a0b8..18c833b 100644 --- a/api.go +++ b/api.go @@ -2,24 +2,128 @@ package main import ( "embed" + "fmt" "net/http" "strconv" + "strings" "sync" llama "github.com/go-skynet/go-llama.cpp" "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/filesystem" + "github.com/gofiber/fiber/v2/middleware/recover" ) +type OpenAIResponse struct { + Created int `json:"created"` + Object string `json:"chat.completion"` + ID string `json:"id"` + Model string `json:"model"` + Choices []Choice `json:"choices"` +} + +type Choice struct { + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + Message Message `json:"message"` +} + +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + //go:embed index.html var indexHTML embed.FS -func api(l *llama.LLama, listenAddr string, threads int) error { +func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, threads int) error { app := fiber.New() + + // Default middleware config + app.Use(recover.New()) + app.Use(cors.New()) + app.Use("/", filesystem.New(filesystem.Config{ Root: http.FS(indexHTML), NotFoundFile: "index.html", })) + + var mutex = &sync.Mutex{} + + // openAI compatible API endpoint + app.Post("/v1/chat/completions", func(c *fiber.Ctx) error { + var err error + var model *llama.LLama + + // Get input data from the request body + input := new(struct { + Messages []Message `json:"messages"` + Model string `json:"model"` + }) + if err := c.BodyParser(input); err != nil { + return err + } + + if input.Model == "" { + if defaultModel == nil { + return fmt.Errorf("no default model loaded, and no model specified") + } + model = defaultModel + } else { + model, err = loader.LoadModel(input.Model) + if err != nil { + return err + } + } + + // Set the parameters for the language model prediction + topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9 + if err != nil { + return err + } + + topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40 + if err != nil { + return err + } + + temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5 + if err != nil { + return err + } + + tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128 + if err != nil { + return err + } + + mess := []string{} + for _, i := range input.Messages { + mess = append(mess, i.Content) + } + + fmt.Println("Received", input, input.Model) + // Generate the prediction using the language model + prediction, err := model.Predict( + strings.Join(mess, "\n"), + llama.SetTemperature(temperature), + llama.SetTopP(topP), + llama.SetTopK(topK), + llama.SetTokens(tokens), + llama.SetThreads(threads), + ) + if err != nil { + return err + } + + // Return the prediction in the response body + return c.JSON(OpenAIResponse{ + Model: input.Model, + Choices: []Choice{{Message: Message{Role: "assistant", Content: prediction}}}, + }) + }) + /* curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{ "text": "What is an alpaca?", @@ -29,8 +133,6 @@ func api(l *llama.LLama, listenAddr string, threads int) error { "tokens": 100 }' */ - var mutex = &sync.Mutex{} - // Endpoint to generate the prediction app.Post("/predict", func(c *fiber.Ctx) error { mutex.Lock() @@ -65,7 +167,7 @@ func api(l *llama.LLama, listenAddr string, threads int) error { } // Generate the prediction using the language model - prediction, err := l.Predict( + prediction, err := defaultModel.Predict( input.Text, llama.SetTemperature(temperature), llama.SetTopP(topP), @@ -86,6 +188,6 @@ func api(l *llama.LLama, listenAddr string, threads int) error { }) // Start the server - app.Listen(":8080") + app.Listen(listenAddr) return nil } diff --git a/main.go b/main.go index 9bcbdd3..5b9d91a 100644 --- a/main.go +++ b/main.go @@ -146,8 +146,12 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came Value: runtime.NumCPU(), }, &cli.StringFlag{ - Name: "model", - EnvVars: []string{"MODEL_PATH"}, + Name: "models-path", + EnvVars: []string{"MODELS_PATH"}, + }, + &cli.StringFlag{ + Name: "default-model", + EnvVars: []string{"default-model"}, }, &cli.StringFlag{ Name: "address", @@ -161,13 +165,19 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came }, }, Action: func(ctx *cli.Context) error { - l, err := llamaFromOptions(ctx) - if err != nil { - fmt.Println("Loading the model failed:", err.Error()) - os.Exit(1) + + var defaultModel *llama.LLama + defModel := ctx.String("default-model") + if defModel != "" { + opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))} + var err error + defaultModel, err = llama.New(ctx.String("default-model"), opts...) + if err != nil { + return err + } } - return api(l, ctx.String("address"), ctx.Int("threads")) + return api(defaultModel, NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads")) }, }, }, diff --git a/model_loader.go b/model_loader.go new file mode 100644 index 0000000..13c860f --- /dev/null +++ b/model_loader.go @@ -0,0 +1,52 @@ +package main + +import ( + "fmt" + "os" + "path/filepath" + "sync" + + llama "github.com/go-skynet/go-llama.cpp" +) + +type ModelLoader struct { + modelPath string + mu sync.Mutex + models map[string]*llama.LLama +} + +func NewModelLoader(modelPath string) *ModelLoader { + return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama)} +} + +func (ml *ModelLoader) LoadModel(s string, opts ...llama.ModelOption) (*llama.LLama, error) { + ml.mu.Lock() + defer ml.mu.Unlock() + + // Check if we already have a loaded model + modelFile := filepath.Join(ml.modelPath, s) + + if m, ok := ml.models[modelFile]; ok { + return m, nil + } + + // Check if the model path exists + if _, err := os.Stat(modelFile); os.IsNotExist(err) { + // try to find a s.bin + modelBin := fmt.Sprintf("%s.bin", modelFile) + if _, err := os.Stat(modelBin); os.IsNotExist(err) { + return nil, err + } else { + modelFile = modelBin + } + } + + // Load the model and keep it in memory for later use + model, err := llama.New(modelFile, opts...) + if err != nil { + return nil, err + } + + ml.models[modelFile] = model + return model, err +}