|
|
@@ -4,6 +4,7 @@ import ( |
|
|
|
"embed" |
|
|
|
"embed" |
|
|
|
"net/http" |
|
|
|
"net/http" |
|
|
|
"strconv" |
|
|
|
"strconv" |
|
|
|
|
|
|
|
"sync" |
|
|
|
|
|
|
|
|
|
|
|
llama "github.com/go-skynet/llama/go" |
|
|
|
llama "github.com/go-skynet/llama/go" |
|
|
|
"github.com/gofiber/fiber/v2" |
|
|
|
"github.com/gofiber/fiber/v2" |
|
|
@@ -28,9 +29,12 @@ func api(l *llama.LLama, listenAddr string, threads int) error { |
|
|
|
"tokens": 100 |
|
|
|
"tokens": 100 |
|
|
|
}' |
|
|
|
}' |
|
|
|
*/ |
|
|
|
*/ |
|
|
|
|
|
|
|
var mutex = &sync.Mutex{} |
|
|
|
|
|
|
|
|
|
|
|
// Endpoint to generate the prediction
|
|
|
|
// Endpoint to generate the prediction
|
|
|
|
app.Post("/predict", func(c *fiber.Ctx) error { |
|
|
|
app.Post("/predict", func(c *fiber.Ctx) error { |
|
|
|
|
|
|
|
mutex.Lock() |
|
|
|
|
|
|
|
defer mutex.Unlock() |
|
|
|
// Get input data from the request body
|
|
|
|
// Get input data from the request body
|
|
|
|
input := new(struct { |
|
|
|
input := new(struct { |
|
|
|
Text string `json:"text"` |
|
|
|
Text string `json:"text"` |
|
|
|