Add stopwords, debug mode, and other API enhancements (#54)

Signed-off-by: mudler <mudler@mocaccino.org>
Branch: add/first-example
Author: Ettore Di Giacinto, committed 2 years ago via GitHub
parent 4b7e83056d
commit 5cba71de70
Changed files:

1. Makefile (2 changes)
2. api/api.go (38 changes)
3. go.mod (2 changes)
4. go.sum (2 changes)
5. main.go (5 changes)

Makefile:

@@ -2,7 +2,7 @@ GOCMD=go
 GOTEST=$(GOCMD) test
 GOVET=$(GOCMD) vet
 BINARY_NAME=local-ai
-GOLLAMA_VERSION?=llama.cpp-5ecff35
+GOLLAMA_VERSION?=llama.cpp-8687c1f
 GOGPT4ALLJ_VERSION?=1f548782d80d48b9a0fac33aae6f129358787bc0
 GOGPT2_VERSION?=1c24f5b86ac428cd5e81dae1f1427b1463bd2b06

api/api.go:

@@ -48,6 +48,8 @@ type OpenAIRequest struct {
 	// Prompt is read only by completion API calls
 	Prompt string `json:"prompt"`
+	Stop string `json:"stop"`
+
 	// Messages is read only by chat/completion API calls
 	Messages []Message `json:"messages"`
@@ -61,15 +63,17 @@ type OpenAIRequest struct {
 	N int `json:"n"`

 	// Custom parameters - not present in the OpenAI API
 	Batch int `json:"batch"`
 	F16 bool `json:"f16kv"`
 	IgnoreEOS bool `json:"ignore_eos"`
+	RepeatPenalty float64 `json:"repeat_penalty"`
+	Keep int `json:"n_keep"`
 	Seed int `json:"seed"`
 }

 // https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+func openAIEndpoint(chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
 		var model *llama.LLama
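The new `stop`, `repeat_penalty`, and `n_keep` fields ride along in the same JSON body as the existing OpenAI-style parameters. As a minimal sketch of what a request using them could look like (the model name and values here are placeholders, not taken from this commit):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors only the request fields touched by this commit; the real
// OpenAIRequest in api/api.go carries many more.
type completionRequest struct {
	Model         string  `json:"model"`
	Prompt        string  `json:"prompt"`
	Stop          string  `json:"stop"`
	RepeatPenalty float64 `json:"repeat_penalty"`
	Keep          int     `json:"n_keep"`
}

func main() {
	body, _ := json.Marshal(completionRequest{
		Model:         "ggml-model.bin", // placeholder model name
		Prompt:        "Q: What is the capital of France?\nA:",
		Stop:          "Q:", // generation halts when this stop word appears
		RepeatPenalty: 1.1,
		Keep:          64, // prompt tokens the backend keeps when the context slides
	})
	fmt.Println(string(body))
}
```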
@@ -269,6 +273,22 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		llama.SetThreads(threads),
 	}

+	if debug {
+		predictOptions = append(predictOptions, llama.Debug)
+	}
+
+	if input.Stop != "" {
+		predictOptions = append(predictOptions, llama.SetStopWords(input.Stop))
+	}
+
+	if input.RepeatPenalty != 0 {
+		predictOptions = append(predictOptions, llama.SetPenalty(input.RepeatPenalty))
+	}
+
+	if input.Keep != 0 {
+		predictOptions = append(predictOptions, llama.SetNKeep(input.Keep))
+	}
+
 	if input.Batch != 0 {
 		predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
 	}
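The guards above follow go-llama.cpp's functional-options style: an option is appended only when the client actually set the corresponding field, so zero values never clobber the backend's defaults. A self-contained sketch of the same pattern, with a hypothetical `Option` type standing in for `llama.PredictOption`:

```go
package main

import "fmt"

// config and Option are stand-ins for the binding's predict settings.
type config struct {
	stopWords string
	penalty   float64
}

type Option func(*config)

func SetStopWords(s string) Option { return func(c *config) { c.stopWords = s } }
func SetPenalty(p float64) Option  { return func(c *config) { c.penalty = p } }

func main() {
	input := struct {
		Stop          string
		RepeatPenalty float64
	}{Stop: "Q:", RepeatPenalty: 0} // RepeatPenalty left unset by the client

	opts := []Option{}
	// Append only options the client set, so defaults survive.
	if input.Stop != "" {
		opts = append(opts, SetStopWords(input.Stop))
	}
	if input.RepeatPenalty != 0 {
		opts = append(opts, SetPenalty(input.RepeatPenalty))
	}

	cfg := config{penalty: 1.0} // default penalty, kept because the client sent 0
	for _, o := range opts {
		o(&cfg)
	}
	fmt.Printf("%+v\n", cfg)
}
```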
@@ -341,7 +361,7 @@ func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error {
 	}
 }

-func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
+func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool, debug bool) error {
 	// Return errors as JSON responses
 	app := fiber.New(fiber.Config{
 		// Override default error handler
@@ -371,11 +391,11 @@ func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f
 	var mumutex = &sync.Mutex{}

 	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
-	app.Post("/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
-	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
-	app.Post("/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/v1/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/v1/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
 	app.Get("/v1/models", listModels(loader))
 	app.Get("/models", listModels(loader))

go.mod:

@@ -5,7 +5,7 @@ go 1.19
 require (
 	github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4
 	github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94
-	github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640
+	github.com/go-skynet/go-llama.cpp v0.0.0-20230421172644-351a5a40eead
 	github.com/gofiber/fiber/v2 v2.42.0
 	github.com/jaypipes/ghw v0.10.0
 	github.com/rs/zerolog v1.29.1

go.sum:

@@ -18,6 +18,8 @@ github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrr
 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw=
+github.com/go-skynet/go-llama.cpp v0.0.0-20230421172644-351a5a40eead h1:C+lcH1srw+c0qPDx1WF8zjGiiOqoPxVICt7bI1sj5cM=
+github.com/go-skynet/go-llama.cpp v0.0.0-20230421172644-351a5a40eead/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw=
 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 github.com/gofiber/fiber/v2 v2.42.0 h1:Fnp7ybWvS+sjNQsFvkhf4G8OhXswvB6Vee8hM/LyS+8=

main.go:

@@ -81,10 +81,11 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
 		Copyright: "go-skynet authors",
 		Action: func(ctx *cli.Context) error {
 			zerolog.SetGlobalLevel(zerolog.InfoLevel)
-			if ctx.Bool("debug") {
+			debugMode := ctx.Bool("debug")
+			if debugMode {
 				zerolog.SetGlobalLevel(zerolog.DebugLevel)
 			}
-			return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"))
+			return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), debugMode)
 		},
 	}
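Reading the flag once into `debugMode` lets main.go reuse the value for both the logger and `api.Start`. A standalone sketch of that flag handling, assuming urfave/cli v2 and zerolog (the project's exact CLI version is not visible in this diff):

```go
package main

import (
	"os"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
	"github.com/urfave/cli/v2"
)

func main() {
	app := &cli.App{
		Name:  "local-ai",
		Flags: []cli.Flag{&cli.BoolFlag{Name: "debug"}},
		Action: func(ctx *cli.Context) error {
			zerolog.SetGlobalLevel(zerolog.InfoLevel)
			debugMode := ctx.Bool("debug")
			if debugMode {
				zerolog.SetGlobalLevel(zerolog.DebugLevel)
			}
			log.Debug().Msg("debug logging enabled") // visible only with --debug
			// ...thread debugMode down to the API layer, as the commit does.
			return nil
		},
	}
	if err := app.Run(os.Args); err != nil {
		log.Fatal().Err(err).Msg("")
	}
}
```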
