Return model list

add/first-example v0.6
mudler 2 years ago
parent f43aeeb4a1
commit 93d8977ba2
  1. 33
      api.go
  2. 18
      model_loader.go

@ -35,6 +35,11 @@ type Message struct {
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`
} }
// OpenAIModel is one entry of the OpenAI-compatible /v1/models listing:
// a model identifier plus the constant object tag (always "model" here).
type OpenAIModel struct {
ID string `json:"id"`
Object string `json:"object"`
}
//go:embed index.html //go:embed index.html
var indexHTML embed.FS var indexHTML embed.FS
@ -241,11 +246,6 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre
app.Use(recover.New()) app.Use(recover.New())
app.Use(cors.New()) app.Use(cors.New())
app.Use("/", filesystem.New(filesystem.Config{
Root: http.FS(indexHTML),
NotFoundFile: "index.html",
}))
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784 // This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
var mutex = &sync.Mutex{} var mutex = &sync.Mutex{}
mu := map[string]*sync.Mutex{} mu := map[string]*sync.Mutex{}
@ -254,6 +254,29 @@ func api(defaultModel *llama.LLama, loader *ModelLoader, listenAddr string, thre
// openAI compatible API endpoint // openAI compatible API endpoint
app.Post("/v1/chat/completions", chatEndpoint(defaultModel, loader, threads, mutex, mumutex, mu)) app.Post("/v1/chat/completions", chatEndpoint(defaultModel, loader, threads, mutex, mumutex, mu))
app.Post("/v1/completions", completionEndpoint(defaultModel, loader, threads, mutex, mumutex, mu)) app.Post("/v1/completions", completionEndpoint(defaultModel, loader, threads, mutex, mumutex, mu))
// OpenAI-compatible model listing: wraps the loader's model names in the
// standard {"object":"list","data":[...]} envelope.
app.Get("/v1/models", func(c *fiber.Ctx) error {
	names, err := loader.ListModels()
	if err != nil {
		return err
	}
	data := []OpenAIModel{}
	for _, name := range names {
		data = append(data, OpenAIModel{ID: name, Object: "model"})
	}
	resp := struct {
		Object string        `json:"object"`
		Data   []OpenAIModel `json:"data"`
	}{
		Object: "list",
		Data:   data,
	}
	return c.JSON(resp)
})
app.Use("/", filesystem.New(filesystem.Config{
Root: http.FS(indexHTML),
NotFoundFile: "index.html",
}))
/* /*
curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{ curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{

@ -3,8 +3,10 @@ package main
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"text/template" "text/template"
@ -22,6 +24,22 @@ func NewModelLoader(modelPath string) *ModelLoader {
return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)} return &ModelLoader{modelPath: modelPath, models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
} }
// ListModels returns the base names (without the ".bin" extension) of all
// model files found directly inside the loader's model path.
func (ml *ModelLoader) ListModels() ([]string, error) {
	return listModelsIn(ml.modelPath)
}

// listModelsIn scans dir for regular files ending in ".bin" and returns
// their names with that suffix stripped. ioutil.ReadDir yields entries
// sorted by filename, so the result is deterministic.
func listModelsIn(dir string) ([]string, error) {
	files, err := ioutil.ReadDir(dir)
	if err != nil {
		return []string{}, err
	}
	models := []string{}
	for _, file := range files {
		// Skip directories; a directory named "x.bin" is not a model file.
		if file.IsDir() {
			continue
		}
		if strings.HasSuffix(file.Name(), ".bin") {
			// TrimSuffix, not TrimRight: TrimRight treats ".bin" as a
			// character set and would mangle names such as "in.bin" -> "".
			models = append(models, strings.TrimSuffix(file.Name(), ".bin"))
		}
	}
	return models, nil
}
func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) { func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string, error) {
ml.mu.Lock() ml.mu.Lock()
defer ml.mu.Unlock() defer ml.mu.Unlock()

Loading…
Cancel
Save