From 63601fabd186ebb149d5baf88ffd0aa1bb775193 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Sun, 16 Apr 2023 10:40:50 +0200
Subject: [PATCH] feat: drop default model and llama-specific API (#26)

Signed-off-by: mudler
---
 README.md         | 29 ++----------------
 api/api.go        | 77 +++--------------------------------------------
 client/client.go  | 75 ---------------------------------------------
 client/options.go | 51 -------------------------------
 main.go           | 20 ++----------
 5 files changed, 9 insertions(+), 243 deletions(-)
 delete mode 100644 client/client.go
 delete mode 100644 client/options.go

diff --git a/README.md b/README.md
index e626118..11b1be6 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,7 @@ docker compose up -d --build

 # Now API is accessible at localhost:8080
 curl http://localhost:8080/v1/models
+# {"object":"list","data":[{"id":"your-model.bin","object":"model"}]}

 curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
     "model": "your-model.bin",
@@ -88,7 +89,7 @@ llama-cli --model --instruction [--input ] [--
 | template | TEMPLATE | | A file containing a template for output formatting (optional). |
 | instruction | INSTRUCTION | | Input prompt text or instruction. "-" for STDIN. |
 | input | INPUT | - | Path to text or "-" for STDIN. |
-| model | MODEL_PATH | | The path to the pre-trained GPT-based model. |
+| model | MODEL | | The path to the pre-trained GPT-based model. |
 | tokens | TOKENS | 128 | The maximum number of tokens to generate. |
 | threads | THREADS | NumCPU() | The number of threads to use for text generation. |
 | temperature | TEMPERATURE | 0.95 | Sampling temperature for model output. ( values between `0.1` and `1.0` ) |
@@ -216,32 +217,6 @@ python 828bddec6162a023114ce19146cb2b82/gistfile1.txt models tokenizer.model
 # There will be a new model with the ".tmp" extension, you have to use that one!
 ```

-### Golang client API
-
-The `llama-cli` codebase has also a small client in go that can be used alongside with the api:
-
-```golang
-package main
-
-import (
-	"fmt"
-
-	client "github.com/go-skynet/llama-cli/client"
-)
-
-func main() {
-
-	cli := client.NewClient("http://ip:port")
-
-	out, err := cli.Predict("What's an alpaca?")
-	if err != nil {
-		panic(err)
-	}
-
-	fmt.Println(out)
-}
-```
-
 ### Windows compatibility

 It should work, however you need to make sure you give enough resources to the container.
 See https://github.com/go-skynet/llama-cli/issues/2
diff --git a/api/api.go b/api/api.go
index 1a13bb7..3667d39 100644
--- a/api/api.go
+++ b/api/api.go
@@ -4,7 +4,6 @@ import (
 	"embed"
 	"fmt"
 	"net/http"
-	"strconv"
 	"strings"
 	"sync"

@@ -70,7 +69,7 @@ type OpenAIRequest struct {
 var indexHTML embed.FS

 // https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
 		var model *llama.LLama
@@ -82,10 +81,7 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
 		}

 		if input.Model == "" {
-			if defaultModel == nil {
-				return fmt.Errorf("no default model loaded, and no model specified")
-			}
-			model = defaultModel
+			return fmt.Errorf("no model specified")
 		} else {
 			model, err = loader.LoadModel(input.Model)
 			if err != nil {
@@ -204,7 +200,7 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
 	}
 }

-func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr string, threads int) error {
+func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
 	app := fiber.New()

 	// Default middleware config
@@ -217,8 +213,8 @@ func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr stri
 	var mumutex = &sync.Mutex{}

 	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(true, defaultModel, loader, threads, mutex, mumutex, mu))
-	app.Post("/v1/completions", openAIEndpoint(false, defaultModel, loader, threads, mutex, mumutex, mu))
+	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, mutex, mumutex, mu))
+	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, mutex, mumutex, mu))
 	app.Get("/v1/models", func(c *fiber.Ctx) error {
 		models, err := loader.ListModels()
 		if err != nil {
@@ -243,69 +239,6 @@ func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr stri
 		NotFoundFile: "index.html",
 	}))

-	/*
-		curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
-		    "text": "What is an alpaca?",
-		    "topP": 0.8,
-		    "topK": 50,
-		    "temperature": 0.7,
-		    "tokens": 100
-		}'
-	*/
-	// Endpoint to generate the prediction
-	app.Post("/predict", func(c *fiber.Ctx) error {
-		mutex.Lock()
-		defer mutex.Unlock()
-		// Get input data from the request body
-		input := new(struct {
-			Text string `json:"text"`
-		})
-		if err := c.BodyParser(input); err != nil {
-			return err
-		}
-
-		// Set the parameters for the language model prediction
-		topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
-		if err != nil {
-			return err
-		}
-
-		topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
-		if err != nil {
-			return err
-		}
-
-		temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
-		if err != nil {
-			return err
-		}
-
-		tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
-		if err != nil {
-			return err
-		}
-
-		// Generate the prediction using the language model
-		prediction, err := defaultModel.Predict(
-			input.Text,
-			llama.SetTemperature(temperature),
-			llama.SetTopP(topP),
-			llama.SetTopK(topK),
-			llama.SetTokens(tokens),
-			llama.SetThreads(threads),
-		)
-		if err != nil {
-			return err
-		}
-
-		// Return the prediction in the response body
-		return c.JSON(struct {
-			Prediction string `json:"prediction"`
-		}{
-			Prediction: prediction,
-		})
-	})
-
 	// Start the server
 	app.Listen(listenAddr)
 	return nil
diff --git a/client/client.go b/client/client.go
deleted file mode 100644
index 785a46d..0000000
--- a/client/client.go
+++ /dev/null
@@ -1,75 +0,0 @@
-package client
-
-import (
-	"bytes"
-	"encoding/json"
-	"fmt"
-	"net/http"
-)
-
-type Prediction struct {
-	Prediction string `json:"prediction"`
-}
-
-type Client struct {
-	baseURL  string
-	client   *http.Client
-	endpoint string
-}
-
-func NewClient(baseURL string) *Client {
-	return &Client{
-		baseURL:  baseURL,
-		client:   &http.Client{},
-		endpoint: "/predict",
-	}
-}
-
-type InputData struct {
-	Text        string  `json:"text"`
-	TopP        float64 `json:"topP,omitempty"`
-	TopK        int     `json:"topK,omitempty"`
-	Temperature float64 `json:"temperature,omitempty"`
-	Tokens      int     `json:"tokens,omitempty"`
-}
-
-func (c *Client) Predict(text string, opts ...InputOption) (string, error) {
-	input := NewInputData(opts...)
-	input.Text = text
-
-	// encode input data to JSON format
-	inputBytes, err := json.Marshal(input)
-	if err != nil {
-		return "", err
-	}
-
-	// create HTTP request
-	url := c.baseURL + c.endpoint
-	req, err := http.NewRequest("POST", url, bytes.NewBuffer(inputBytes))
-	if err != nil {
-		return "", err
-	}
-
-	// set request headers
-	req.Header.Set("Content-Type", "application/json")
-
-	// send request and get response
-	resp, err := c.client.Do(req)
-	if err != nil {
-		return "", err
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode != http.StatusOK {
-		return "", fmt.Errorf("request failed with status %d", resp.StatusCode)
-	}
-
-	// decode response body to Prediction struct
-	var prediction Prediction
-	err = json.NewDecoder(resp.Body).Decode(&prediction)
-	if err != nil {
-		return "", err
-	}
-
-	return prediction.Prediction, nil
-}
diff --git a/client/options.go b/client/options.go
deleted file mode 100644
index 6635763..0000000
--- a/client/options.go
+++ /dev/null
@@ -1,51 +0,0 @@
-package client
-
-import "net/http"
-
-type ClientOption func(c *Client)
-
-func WithHTTPClient(httpClient *http.Client) ClientOption {
-	return func(c *Client) {
-		c.client = httpClient
-	}
-}
-
-func WithEndpoint(endpoint string) ClientOption {
-	return func(c *Client) {
-		c.endpoint = endpoint
-	}
-}
-
-type InputOption func(d *InputData)
-
-func NewInputData(opts ...InputOption) *InputData {
-	data := &InputData{}
-	for _, opt := range opts {
-		opt(data)
-	}
-	return data
-}
-
-func WithTopP(topP float64) InputOption {
-	return func(d *InputData) {
-		d.TopP = topP
-	}
-}
-
-func WithTopK(topK int) InputOption {
-	return func(d *InputData) {
-		d.TopK = topK
-	}
-}
-
-func WithTemperature(temperature float64) InputOption {
-	return func(d *InputData) {
-		d.Temperature = temperature
-	}
-}
-
-func WithTokens(tokens int) InputOption {
-	return func(d *InputData) {
-		d.Tokens = tokens
-	}
-}
diff --git a/main.go b/main.go
index 6fc7321..91397e2 100644
--- a/main.go
+++ b/main.go
@@ -57,7 +57,7 @@ func templateString(t string, in interface{}) (string, error) {
 var modelFlags = []cli.Flag{
 	&cli.StringFlag{
 		Name:    "model",
-		EnvVars: []string{"MODEL_PATH"},
+		EnvVars: []string{"MODEL"},
 	},
 	&cli.IntFlag{
 		Name:    "tokens",
@@ -134,10 +134,6 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
 			Name:    "models-path",
 			EnvVars: []string{"MODELS_PATH"},
 		},
-		&cli.StringFlag{
-			Name:    "default-model",
-			EnvVars: []string{"DEFAULT_MODEL"},
-		},
 		&cli.StringFlag{
 			Name:    "address",
 			EnvVars: []string{"ADDRESS"},
@@ -150,19 +146,7 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
 			},
 		},
 		Action: func(ctx *cli.Context) error {
-
-			var defaultModel *llama.LLama
-			defModel := ctx.String("default-model")
-			if defModel != "" {
-				opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
-				var err error
-				defaultModel, err = llama.New(ctx.String("default-model"), opts...)
-				if err != nil {
-					return err
-				}
-			}
-
-			return api.Start(defaultModel, model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
+			return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
 		},
 	},
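
With the `/predict` route and the bundled `client` package removed, the OpenAI-compatible API is the only way to query the server. As a rough replacement for the deleted Go client, here is a minimal sketch of calling `/v1/completions` from Go; it assumes the standard OpenAI completions request/response shape (`prompt` in the request, `choices[].text` in the response) and uses a placeholder base URL and model name:

```golang
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// completionRequest mirrors the fields in the README's curl example;
// the response shape below is assumed from the OpenAI completions API.
type completionRequest struct {
	Model       string  `json:"model"`
	Prompt      string  `json:"prompt"`
	Temperature float64 `json:"temperature,omitempty"`
}

type completionResponse struct {
	Choices []struct {
		Text string `json:"text"`
	} `json:"choices"`
}

func main() {
	// Build the request body; "your-model.bin" must match a file in --models-path.
	body, err := json.Marshal(completionRequest{
		Model:       "your-model.bin",
		Prompt:      "What's an alpaca?",
		Temperature: 0.7,
	})
	if err != nil {
		panic(err)
	}

	// POST to the OpenAI-compatible endpoint (placeholder host/port).
	resp, err := http.Post("http://localhost:8080/v1/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		panic(fmt.Sprintf("request failed with status %d", resp.StatusCode))
	}

	// Decode the completion and print the first choice, if any.
	var out completionResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	if len(out.Choices) > 0 {
		fmt.Println(out.Choices[0].Text)
	}
}
```

This is the same call as the curl example in the README, just without the deleted `client` package; chat-style prompts go through the `/v1/chat/completions` route registered above with the analogous OpenAI chat payload.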