From 5cba71de700d5ec5d56443e82ef2bda7628401e8 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Fri, 21 Apr 2023 19:46:59 +0200
Subject: [PATCH] Add stopwords, debug mode, and other API enhancements (#54)

Signed-off-by: mudler
---
 Makefile   |  2 +-
 api/api.go | 38 +++++++++++++++++++++++++++++---------
 go.mod     |  2 +-
 go.sum     |  2 ++
 main.go    |  5 +++--
 5 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/Makefile b/Makefile
index 3035be8..91f8de3 100644
--- a/Makefile
+++ b/Makefile
@@ -2,7 +2,7 @@ GOCMD=go
 GOTEST=$(GOCMD) test
 GOVET=$(GOCMD) vet
 BINARY_NAME=local-ai
-GOLLAMA_VERSION?=llama.cpp-5ecff35
+GOLLAMA_VERSION?=llama.cpp-8687c1f
 GOGPT4ALLJ_VERSION?=1f548782d80d48b9a0fac33aae6f129358787bc0
 GOGPT2_VERSION?=1c24f5b86ac428cd5e81dae1f1427b1463bd2b06
 
diff --git a/api/api.go b/api/api.go
index 60e750a..946dced 100644
--- a/api/api.go
+++ b/api/api.go
@@ -48,6 +48,8 @@ type OpenAIRequest struct {
 	// Prompt is read only by completion API calls
 	Prompt string `json:"prompt"`
 
+	Stop string `json:"stop"`
+
 	// Messages is read only by chat/completion API calls
 	Messages []Message `json:"messages"`
 
@@ -61,15 +63,17 @@
 	N int `json:"n"`
 
 	// Custom parameters - not present in the OpenAI API
-	Batch     int  `json:"batch"`
-	F16       bool `json:"f16kv"`
-	IgnoreEOS bool `json:"ignore_eos"`
+	Batch         int     `json:"batch"`
+	F16           bool    `json:"f16kv"`
+	IgnoreEOS     bool    `json:"ignore_eos"`
+	RepeatPenalty float64 `json:"repeat_penalty"`
+	Keep          int     `json:"n_keep"`
 
 	Seed int `json:"seed"`
 }
 
 // https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+func openAIEndpoint(chat, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
 		var model *llama.LLama
@@ -269,6 +273,22 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 			llama.SetThreads(threads),
 		}
 
+		if debug {
+			predictOptions = append(predictOptions, llama.Debug)
+		}
+
+		if input.Stop != "" {
+			predictOptions = append(predictOptions, llama.SetStopWords(input.Stop))
+		}
+
+		if input.RepeatPenalty != 0 {
+			predictOptions = append(predictOptions, llama.SetPenalty(input.RepeatPenalty))
+		}
+
+		if input.Keep != 0 {
+			predictOptions = append(predictOptions, llama.SetNKeep(input.Keep))
+		}
+
 		if input.Batch != 0 {
 			predictOptions = append(predictOptions, llama.SetBatch(input.Batch))
 		}
@@ -341,7 +361,7 @@ func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error {
 	}
 }
 
-func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
+func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool, debug bool) error {
 	// Return errors as JSON responses
 	app := fiber.New(fiber.Config{
 		// Override default error handler
@@ -371,11 +391,11 @@ func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f
 	var mumutex = &sync.Mutex{}
 
 	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
-	app.Post("/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/v1/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/chat/completions", openAIEndpoint(true, debug, loader, threads, ctxSize, f16, mumutex, mu))
 
-	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
-	app.Post("/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/v1/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/completions", openAIEndpoint(false, debug, loader, threads, ctxSize, f16, mumutex, mu))
 
 	app.Get("/v1/models", listModels(loader))
 	app.Get("/models", listModels(loader))
diff --git a/go.mod b/go.mod
index e65ef28..2175ecb 100644
--- a/go.mod
+++ b/go.mod
@@ -5,7 +5,7 @@ go 1.19
 require (
 	github.com/go-skynet/go-gpt2.cpp v0.0.0-20230420213900-1c24f5b86ac4
 	github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94
-	github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640
+	github.com/go-skynet/go-llama.cpp v0.0.0-20230421172644-351a5a40eead
 	github.com/gofiber/fiber/v2 v2.42.0
 	github.com/jaypipes/ghw v0.10.0
 	github.com/rs/zerolog v1.29.1
diff --git a/go.sum b/go.sum
index 0218734..5cf8e44 100644
--- a/go.sum
+++ b/go.sum
@@ -18,6 +18,8 @@ github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 h1:rtrr
 github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod h1:5VZ9XbcINI0XcHhkcX8GPK8TplFGAzu1Hrg4tNiMCtI=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw=
+github.com/go-skynet/go-llama.cpp v0.0.0-20230421172644-351a5a40eead h1:C+lcH1srw+c0qPDx1WF8zjGiiOqoPxVICt7bI1sj5cM=
+github.com/go-skynet/go-llama.cpp v0.0.0-20230421172644-351a5a40eead/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw=
 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 github.com/gofiber/fiber/v2 v2.42.0 h1:Fnp7ybWvS+sjNQsFvkhf4G8OhXswvB6Vee8hM/LyS+8=
diff --git a/main.go b/main.go
index 24861f2..82526ca 100644
--- a/main.go
+++ b/main.go
@@ -81,10 +81,11 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
 		Copyright: "go-skynet authors",
 		Action: func(ctx *cli.Context) error {
 			zerolog.SetGlobalLevel(zerolog.InfoLevel)
-			if ctx.Bool("debug") {
+			debugMode := ctx.Bool("debug")
+			if debugMode {
 				zerolog.SetGlobalLevel(zerolog.DebugLevel)
 			}
-			return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"))
+			return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), debugMode)
 		},
 	}
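
Usage note: the sketch below is a minimal Go client for exercising the new request parameters end to end once the patch is applied. The route and the JSON field names (stop, repeat_penalty, n_keep) come from the diff above; the listen address, the model file name, and the OpenAI-style "model" field (not shown in this diff) are illustrative assumptions, not part of the patch.

    package main

    import (
    	"bytes"
    	"encoding/json"
    	"fmt"
    	"io"
    	"net/http"
    )

    func main() {
    	// Fields mirror the patched OpenAIRequest struct in api/api.go;
    	// "stop", "repeat_penalty" and "n_keep" are what this patch adds.
    	body, err := json.Marshal(map[string]any{
    		"model":          "ggml-model.bin",           // hypothetical model file in the models path
    		"prompt":         "The capital of France is", // read by the completion endpoint
    		"stop":           "\n",                       // new: forwarded to llama.SetStopWords
    		"repeat_penalty": 1.1,                        // new: forwarded to llama.SetPenalty
    		"n_keep":         64,                         // new: forwarded to llama.SetNKeep
    	})
    	if err != nil {
    		panic(err)
    	}

    	// /v1/completions is one of the routes registered in api.Start;
    	// localhost:8080 is an assumed --address value.
    	resp, err := http.Post("http://localhost:8080/v1/completions",
    		"application/json", bytes.NewReader(body))
    	if err != nil {
    		panic(err)
    	}
    	defer resp.Body.Close()

    	out, _ := io.ReadAll(resp.Body)
    	fmt.Println(string(out))
    }

Starting the server with the --debug CLI flag now does double duty: it raises the zerolog level to debug as before, and also appends llama.Debug to the prediction options so llama.cpp's own debug output shows up in the server logs.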