From 6f54cab3f04a486eeb2be6fd65107c7bfb0a8cc4 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Sun, 21 May 2023 14:38:25 +0200 Subject: [PATCH] feat: allow to set cors (#339) --- api/api.go | 69 +++++++++++++++++-------------- api/api_test.go | 6 +-- api/openai.go | 60 +++++++++++++-------------- api/options.go | 108 ++++++++++++++++++++++++++++++++++++++++++++++++ main.go | 21 +++++++++- 5 files changed, 199 insertions(+), 65 deletions(-) create mode 100644 api/options.go diff --git a/api/api.go b/api/api.go index b81a89f..b8d77f2 100644 --- a/api/api.go +++ b/api/api.go @@ -1,10 +1,8 @@ package api import ( - "context" "errors" - model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/logger" @@ -13,16 +11,18 @@ import ( "github.com/rs/zerolog/log" ) -func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App { +func App(opts ...AppOption) *fiber.App { + options := newOptions(opts...) + zerolog.SetGlobalLevel(zerolog.InfoLevel) - if debug { + if options.debug { zerolog.SetGlobalLevel(zerolog.DebugLevel) } // Return errors as JSON responses app := fiber.New(fiber.Config{ - BodyLimit: uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB - DisableStartupMessage: disableMessage, + BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB + DisableStartupMessage: options.disableMessage, // Override default error handler ErrorHandler: func(ctx *fiber.Ctx, err error) error { // Status code defaults to 500 @@ -43,24 +43,24 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload }, }) - if debug { + if options.debug { app.Use(logger.New(logger.Config{ Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", })) } cm := NewConfigMerger() - if err := cm.LoadConfigs(loader.ModelPath); err != nil { + if err := cm.LoadConfigs(options.loader.ModelPath); err != nil { log.Error().Msgf("error loading config files: %s", err.Error()) } - if configFile != "" { - if err := cm.LoadConfigFile(configFile); err != nil { + if options.configFile != "" { + if err := cm.LoadConfigFile(options.configFile); err != nil { log.Error().Msgf("error loading config file: %s", err.Error()) } } - if debug { + if options.debug { for _, v := range cm.ListConfigs() { cfg, _ := cm.GetConfig(v) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) @@ -68,46 +68,55 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload } // Default middleware config app.Use(recover.New()) - app.Use(cors.New()) + + if options.cors { + if options.corsAllowOrigins == "" { + app.Use(cors.New()) + } else { + app.Use(cors.New(cors.Config{ + AllowOrigins: options.corsAllowOrigins, + })) + } + } // LocalAI API endpoints - applier := newGalleryApplier(loader.ModelPath) - applier.start(c, cm) - app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C)) + applier := newGalleryApplier(options.loader.ModelPath) + applier.start(options.context, cm) + app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C)) app.Get("/models/jobs/:uuid", getOpStatus(applier)) // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16)) - app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16)) + app.Post("/v1/chat/completions", chatEndpoint(cm, options)) + app.Post("/chat/completions", chatEndpoint(cm, options)) // edit - app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16)) - app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16)) + app.Post("/v1/edits", editEndpoint(cm, options)) + app.Post("/edits", editEndpoint(cm, options)) // completion - app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16)) - app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16)) + app.Post("/v1/completions", completionEndpoint(cm, options)) + app.Post("/completions", completionEndpoint(cm, options)) // embeddings - app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) - app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) - app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) + app.Post("/v1/embeddings", embeddingsEndpoint(cm, options)) + app.Post("/embeddings", embeddingsEndpoint(cm, options)) + app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options)) // audio - app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16)) + app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options)) // images - app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir)) + app.Post("/v1/images/generations", imageEndpoint(cm, options)) - if imageDir != "" { - app.Static("/generated-images", imageDir) + if options.imageDir != "" { + app.Static("/generated-images", options.imageDir) } // models - app.Get("/v1/models", listModels(loader, cm)) - app.Get("/models", listModels(loader, cm)) + app.Get("/v1/models", listModels(options.loader, cm)) + app.Get("/models", listModels(options.loader, cm)) return app } diff --git a/api/api_test.go b/api/api_test.go index f061527..4b24514 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -114,7 +114,7 @@ var _ = Describe("API test", func() { modelLoader = model.NewModelLoader(tmpdir) c, cancel = context.WithCancel(context.Background()) - app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "") + app = App(WithContext(c), WithModelLoader(modelLoader)) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -198,7 +198,7 @@ var _ = Describe("API test", func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) c, cancel = context.WithCancel(context.Background()) - app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "") + app = App(WithContext(c), WithModelLoader(modelLoader)) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") @@ -316,7 +316,7 @@ var _ = Describe("API test", func() { modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) c, cancel = context.WithCancel(context.Background()) - app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "") + app = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE"))) go app.Listen("127.0.0.1:9090") defaultConfig := openai.DefaultConfig("") diff --git a/api/openai.go b/api/openai.go index 0a85349..dffdcbf 100644 --- a/api/openai.go +++ b/api/openai.go @@ -142,15 +142,15 @@ func defaultRequest(modelFile string) OpenAIRequest { } // https://platform.openai.com/docs/api-reference/completions -func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - model, input, err := readInput(c, loader, true) + model, input, err := readInput(c, o.loader, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) + config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -166,7 +166,7 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, var result []Choice for _, i := range config.PromptStrings { // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := loader.TemplatePrefix(templateFile, struct { + templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { Input string }{Input: i}) if err == nil { @@ -174,7 +174,7 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, log.Debug().Msgf("Template found, input modified to: %s", i) } - r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) { + r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) { *c = append(*c, Choice{Text: s}) }, nil) if err != nil { @@ -199,14 +199,14 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, } // https://platform.openai.com/docs/api-reference/embeddings -func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - model, input, err := readInput(c, loader, true) + model, input, err := readInput(c, o.loader, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) + config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -216,7 +216,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, for i, s := range config.InputToken { // get the model function to call for the result - embedFn, err := ModelEmbedding("", s, loader, *config) + embedFn, err := ModelEmbedding("", s, o.loader, *config) if err != nil { return err } @@ -230,7 +230,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, for i, s := range config.InputStrings { // get the model function to call for the result - embedFn, err := ModelEmbedding(s, []int{}, loader, *config) + embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config) if err != nil { return err } @@ -256,7 +256,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, } } -func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool { @@ -273,12 +273,12 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa close(responses) } return func(c *fiber.Ctx) error { - model, input, err := readInput(c, loader, true) + model, input, err := readInput(c, o.loader, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) + config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -319,7 +319,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa } // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := loader.TemplatePrefix(templateFile, struct { + templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { Input string }{Input: predInput}) if err == nil { @@ -330,7 +330,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa if input.Stream { responses := make(chan OpenAIResponse) - go process(predInput, input, config, loader, responses) + go process(predInput, input, config, o.loader, responses) c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { @@ -358,7 +358,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa return nil } - result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) { + result, err := ComputeChoices(predInput, input, config, o.loader, func(s string, c *[]Choice) { *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}}) }, nil) if err != nil { @@ -378,14 +378,14 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa } } -func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - model, input, err := readInput(c, loader, true) + model, input, err := readInput(c, o.loader, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) + config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -401,7 +401,7 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa var result []Choice for _, i := range config.InputStrings { // A model can have a "file.bin.tmpl" file associated with a prompt template prefix - templatedInput, err := loader.TemplatePrefix(templateFile, struct { + templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { Input string Instruction string }{Input: i}) @@ -410,7 +410,7 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa log.Debug().Msgf("Template found, input modified to: %s", i) } - r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) { + r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) { *c = append(*c, Choice{Text: s}) }, nil) if err != nil { @@ -449,9 +449,9 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa * */ -func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error { +func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readInput(c, loader, false) + m, input, err := readInput(c, o.loader, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -461,7 +461,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag } log.Debug().Msgf("Loading model: %+v", m) - config, input, err := readConfig(m, input, cm, loader, debug, 0, 0, false) + config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -518,7 +518,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag tempDir := "" if !b64JSON { - tempDir = imageDir + tempDir = o.imageDir } // Create a temporary file outputFile, err := ioutil.TempFile(tempDir, "b64") @@ -535,7 +535,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag baseURL := c.BaseURL() - fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, loader, *config) + fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config) if err != nil { return err } @@ -574,14 +574,14 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag } // https://platform.openai.com/docs/api-reference/audio/create -func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { +func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readInput(c, loader, false) + m, input, err := readInput(c, o.loader, false) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } - config, input, err := readConfig(m, input, cm, loader, debug, threads, ctx, f16) + config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } @@ -616,7 +616,7 @@ func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, log.Debug().Msgf("Audio file copied to: %+v", dst) - whisperModel, err := loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads)) + whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads)) if err != nil { return err } diff --git a/api/options.go b/api/options.go new file mode 100644 index 0000000..f99dda4 --- /dev/null +++ b/api/options.go @@ -0,0 +1,108 @@ +package api + +import ( + "context" + + model "github.com/go-skynet/LocalAI/pkg/model" +) + +type Option struct { + context context.Context + configFile string + loader *model.ModelLoader + uploadLimitMB, threads, ctxSize int + f16 bool + debug, disableMessage bool + imageDir string + cors bool + corsAllowOrigins string +} + +type AppOption func(*Option) + +func newOptions(o ...AppOption) *Option { + opt := &Option{ + context: context.Background(), + uploadLimitMB: 15, + threads: 1, + ctxSize: 512, + debug: true, + disableMessage: true, + } + for _, oo := range o { + oo(opt) + } + return opt +} + +func WithCors(b bool) AppOption { + return func(o *Option) { + o.cors = b + } +} + +func WithCorsAllowOrigins(b string) AppOption { + return func(o *Option) { + o.corsAllowOrigins = b + } +} + +func WithContext(ctx context.Context) AppOption { + return func(o *Option) { + o.context = ctx + } +} + +func WithConfigFile(configFile string) AppOption { + return func(o *Option) { + o.configFile = configFile + } +} + +func WithModelLoader(loader *model.ModelLoader) AppOption { + return func(o *Option) { + o.loader = loader + } +} + +func WithUploadLimitMB(limit int) AppOption { + return func(o *Option) { + o.uploadLimitMB = limit + } +} + +func WithThreads(threads int) AppOption { + return func(o *Option) { + o.threads = threads + } +} + +func WithContextSize(ctxSize int) AppOption { + return func(o *Option) { + o.ctxSize = ctxSize + } +} + +func WithF16(f16 bool) AppOption { + return func(o *Option) { + o.f16 = f16 + } +} + +func WithDebug(debug bool) AppOption { + return func(o *Option) { + o.debug = debug + } +} + +func WithDisableMessage(disableMessage bool) AppOption { + return func(o *Option) { + o.disableMessage = disableMessage + } +} + +func WithImageDir(imageDir string) AppOption { + return func(o *Option) { + o.imageDir = imageDir + } +} diff --git a/main.go b/main.go index f3ffc03..c52399e 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "os" "path/filepath" @@ -34,6 +33,14 @@ func main() { Name: "debug", EnvVars: []string{"DEBUG"}, }, + &cli.BoolFlag{ + Name: "cors", + EnvVars: []string{"CORS"}, + }, + &cli.StringFlag{ + Name: "cors-allow-origins", + EnvVars: []string{"CORS_ALLOW_ORIGINS"}, + }, &cli.IntFlag{ Name: "threads", DefaultText: "Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested.", @@ -94,7 +101,17 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings. Copyright: "go-skynet authors", Action: func(ctx *cli.Context) error { fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) - return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address")) + return api.App( + api.WithConfigFile(ctx.String("config-file")), + api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))), + api.WithContextSize(ctx.Int("context-size")), + api.WithDebug(ctx.Bool("debug")), + api.WithImageDir(ctx.String("image-path")), + api.WithF16(ctx.Bool("f16")), + api.WithCors(ctx.Bool("cors")), + api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")), + api.WithThreads(ctx.Int("threads")), + api.WithUploadLimitMB(ctx.Int("upload-limit"))).Listen(ctx.String("address")) }, }