feat: drop default model and llama-specific API (#26)

Signed-off-by: mudler <mudler@c3os.io>
add/first-example v0.9.1
Ettore Di Giacinto 2 years ago committed by GitHub
parent 1370b4482f
commit 63601fabd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 29
      README.md
  2. 77
      api/api.go
  3. 75
      client/client.go
  4. 51
      client/options.go
  5. 20
      main.go

@ -27,6 +27,7 @@ docker compose up -d --build
# Now API is accessible at localhost:8080 # Now API is accessible at localhost:8080
curl http://localhost:8080/v1/models curl http://localhost:8080/v1/models
# {"object":"list","data":[{"id":"your-model.bin","object":"model"}]} # {"object":"list","data":[{"id":"your-model.bin","object":"model"}]}
curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{ curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d '{
"model": "your-model.bin", "model": "your-model.bin",
@ -88,7 +89,7 @@ llama-cli --model <model_path> --instruction <instruction> [--input <input>] [--
| template | TEMPLATE | | A file containing a template for output formatting (optional). | | template | TEMPLATE | | A file containing a template for output formatting (optional). |
| instruction | INSTRUCTION | | Input prompt text or instruction. "-" for STDIN. | | instruction | INSTRUCTION | | Input prompt text or instruction. "-" for STDIN. |
| input | INPUT | - | Path to text or "-" for STDIN. | | input | INPUT | - | Path to text or "-" for STDIN. |
| model | MODEL_PATH | | The path to the pre-trained GPT-based model. | | model | MODEL | | The path to the pre-trained GPT-based model. |
| tokens | TOKENS | 128 | The maximum number of tokens to generate. | | tokens | TOKENS | 128 | The maximum number of tokens to generate. |
| threads | THREADS | NumCPU() | The number of threads to use for text generation. | | threads | THREADS | NumCPU() | The number of threads to use for text generation. |
| temperature | TEMPERATURE | 0.95 | Sampling temperature for model output. ( values between `0.1` and `1.0` ) | | temperature | TEMPERATURE | 0.95 | Sampling temperature for model output. ( values between `0.1` and `1.0` ) |
@ -216,32 +217,6 @@ python 828bddec6162a023114ce19146cb2b82/gistfile1.txt models tokenizer.model
# There will be a new model with the ".tmp" extension, you have to use that one! # There will be a new model with the ".tmp" extension, you have to use that one!
``` ```
### Golang client API
The `llama-cli` codebase has also a small client in go that can be used alongside with the api:
```golang
package main
import (
"fmt"
client "github.com/go-skynet/llama-cli/client"
)
func main() {
cli := client.NewClient("http://ip:port")
out, err := cli.Predict("What's an alpaca?")
if err != nil {
panic(err)
}
fmt.Println(out)
}
```
### Windows compatibility ### Windows compatibility
It should work, however you need to make sure you give enough resources to the container. See https://github.com/go-skynet/llama-cli/issues/2 It should work, however you need to make sure you give enough resources to the container. See https://github.com/go-skynet/llama-cli/issues/2

@ -4,7 +4,6 @@ import (
"embed" "embed"
"fmt" "fmt"
"net/http" "net/http"
"strconv"
"strings" "strings"
"sync" "sync"
@ -70,7 +69,7 @@ type OpenAIRequest struct {
var indexHTML embed.FS var indexHTML embed.FS
// https://platform.openai.com/docs/api-reference/completions // https://platform.openai.com/docs/api-reference/completions
func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error { func openAIEndpoint(chat bool, loader *model.ModelLoader, threads int, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
var err error var err error
var model *llama.LLama var model *llama.LLama
@ -82,10 +81,7 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
} }
if input.Model == "" { if input.Model == "" {
if defaultModel == nil { return fmt.Errorf("no model specified")
return fmt.Errorf("no default model loaded, and no model specified")
}
model = defaultModel
} else { } else {
model, err = loader.LoadModel(input.Model) model, err = loader.LoadModel(input.Model)
if err != nil { if err != nil {
@ -204,7 +200,7 @@ func openAIEndpoint(chat bool, defaultModel *llama.LLama, loader *model.ModelLoa
} }
} }
func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr string, threads int) error { func Start(loader *model.ModelLoader, listenAddr string, threads int) error {
app := fiber.New() app := fiber.New()
// Default middleware config // Default middleware config
@ -217,8 +213,8 @@ func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr stri
var mumutex = &sync.Mutex{} var mumutex = &sync.Mutex{}
// openAI compatible API endpoint // openAI compatible API endpoint
app.Post("/v1/chat/completions", openAIEndpoint(true, defaultModel, loader, threads, mutex, mumutex, mu)) app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, mutex, mumutex, mu))
app.Post("/v1/completions", openAIEndpoint(false, defaultModel, loader, threads, mutex, mumutex, mu)) app.Post("/v1/completions", openAIEndpoint(false, loader, threads, mutex, mumutex, mu))
app.Get("/v1/models", func(c *fiber.Ctx) error { app.Get("/v1/models", func(c *fiber.Ctx) error {
models, err := loader.ListModels() models, err := loader.ListModels()
if err != nil { if err != nil {
@ -243,69 +239,6 @@ func Start(defaultModel *llama.LLama, loader *model.ModelLoader, listenAddr stri
NotFoundFile: "index.html", NotFoundFile: "index.html",
})) }))
/*
curl --location --request POST 'http://localhost:8080/predict' --header 'Content-Type: application/json' --data-raw '{
"text": "What is an alpaca?",
"topP": 0.8,
"topK": 50,
"temperature": 0.7,
"tokens": 100
}'
*/
// Endpoint to generate the prediction
app.Post("/predict", func(c *fiber.Ctx) error {
mutex.Lock()
defer mutex.Unlock()
// Get input data from the request body
input := new(struct {
Text string `json:"text"`
})
if err := c.BodyParser(input); err != nil {
return err
}
// Set the parameters for the language model prediction
topP, err := strconv.ParseFloat(c.Query("topP", "0.9"), 64) // Default value of topP is 0.9
if err != nil {
return err
}
topK, err := strconv.Atoi(c.Query("topK", "40")) // Default value of topK is 40
if err != nil {
return err
}
temperature, err := strconv.ParseFloat(c.Query("temperature", "0.5"), 64) // Default value of temperature is 0.5
if err != nil {
return err
}
tokens, err := strconv.Atoi(c.Query("tokens", "128")) // Default value of tokens is 128
if err != nil {
return err
}
// Generate the prediction using the language model
prediction, err := defaultModel.Predict(
input.Text,
llama.SetTemperature(temperature),
llama.SetTopP(topP),
llama.SetTopK(topK),
llama.SetTokens(tokens),
llama.SetThreads(threads),
)
if err != nil {
return err
}
// Return the prediction in the response body
return c.JSON(struct {
Prediction string `json:"prediction"`
}{
Prediction: prediction,
})
})
// Start the server // Start the server
app.Listen(listenAddr) app.Listen(listenAddr)
return nil return nil

@ -1,75 +0,0 @@
package client
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
)
type Prediction struct {
Prediction string `json:"prediction"`
}
type Client struct {
baseURL string
client *http.Client
endpoint string
}
func NewClient(baseURL string) *Client {
return &Client{
baseURL: baseURL,
client: &http.Client{},
endpoint: "/predict",
}
}
type InputData struct {
Text string `json:"text"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Tokens int `json:"tokens,omitempty"`
}
func (c *Client) Predict(text string, opts ...InputOption) (string, error) {
input := NewInputData(opts...)
input.Text = text
// encode input data to JSON format
inputBytes, err := json.Marshal(input)
if err != nil {
return "", err
}
// create HTTP request
url := c.baseURL + c.endpoint
req, err := http.NewRequest("POST", url, bytes.NewBuffer(inputBytes))
if err != nil {
return "", err
}
// set request headers
req.Header.Set("Content-Type", "application/json")
// send request and get response
resp, err := c.client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("request failed with status %d", resp.StatusCode)
}
// decode response body to Prediction struct
var prediction Prediction
err = json.NewDecoder(resp.Body).Decode(&prediction)
if err != nil {
return "", err
}
return prediction.Prediction, nil
}

@ -1,51 +0,0 @@
package client
import "net/http"
type ClientOption func(c *Client)
func WithHTTPClient(httpClient *http.Client) ClientOption {
return func(c *Client) {
c.client = httpClient
}
}
func WithEndpoint(endpoint string) ClientOption {
return func(c *Client) {
c.endpoint = endpoint
}
}
type InputOption func(d *InputData)
func NewInputData(opts ...InputOption) *InputData {
data := &InputData{}
for _, opt := range opts {
opt(data)
}
return data
}
func WithTopP(topP float64) InputOption {
return func(d *InputData) {
d.TopP = topP
}
}
func WithTopK(topK int) InputOption {
return func(d *InputData) {
d.TopK = topK
}
}
func WithTemperature(temperature float64) InputOption {
return func(d *InputData) {
d.Temperature = temperature
}
}
func WithTokens(tokens int) InputOption {
return func(d *InputData) {
d.Tokens = tokens
}
}

@ -57,7 +57,7 @@ func templateString(t string, in interface{}) (string, error) {
var modelFlags = []cli.Flag{ var modelFlags = []cli.Flag{
&cli.StringFlag{ &cli.StringFlag{
Name: "model", Name: "model",
EnvVars: []string{"MODEL_PATH"}, EnvVars: []string{"MODEL"},
}, },
&cli.IntFlag{ &cli.IntFlag{
Name: "tokens", Name: "tokens",
@ -134,10 +134,6 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
Name: "models-path", Name: "models-path",
EnvVars: []string{"MODELS_PATH"}, EnvVars: []string{"MODELS_PATH"},
}, },
&cli.StringFlag{
Name: "default-model",
EnvVars: []string{"DEFAULT_MODEL"},
},
&cli.StringFlag{ &cli.StringFlag{
Name: "address", Name: "address",
EnvVars: []string{"ADDRESS"}, EnvVars: []string{"ADDRESS"},
@ -150,19 +146,7 @@ echo "An Alpaca (Vicugna pacos) is a domesticated species of South American came
}, },
}, },
Action: func(ctx *cli.Context) error { Action: func(ctx *cli.Context) error {
return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
var defaultModel *llama.LLama
defModel := ctx.String("default-model")
if defModel != "" {
opts := []llama.ModelOption{llama.SetContext(ctx.Int("context-size"))}
var err error
defaultModel, err = llama.New(ctx.String("default-model"), opts...)
if err != nil {
return err
}
}
return api.Start(defaultModel, model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"))
}, },
}, },
}, },

Loading…
Cancel
Save