feat: allow to set cors (#339)

swagger2
Ettore Di Giacinto 2 years ago committed by GitHub
parent ed5df1e68e
commit 6f54cab3f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 69
      api/api.go
  2. 6
      api/api_test.go
  3. 60
      api/openai.go
  4. 108
      api/options.go
  5. 21
      main.go

@ -1,10 +1,8 @@
package api package api
import ( import (
"context"
"errors" "errors"
model "github.com/go-skynet/LocalAI/pkg/model"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/cors" "github.com/gofiber/fiber/v2/middleware/cors"
"github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/logger"
@ -13,16 +11,18 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func App(c context.Context, configFile string, loader *model.ModelLoader, uploadLimitMB, threads, ctxSize int, f16 bool, debug, disableMessage bool, imageDir string) *fiber.App { func App(opts ...AppOption) *fiber.App {
options := newOptions(opts...)
zerolog.SetGlobalLevel(zerolog.InfoLevel) zerolog.SetGlobalLevel(zerolog.InfoLevel)
if debug { if options.debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel) zerolog.SetGlobalLevel(zerolog.DebugLevel)
} }
// Return errors as JSON responses // Return errors as JSON responses
app := fiber.New(fiber.Config{ app := fiber.New(fiber.Config{
BodyLimit: uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB BodyLimit: options.uploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
DisableStartupMessage: disableMessage, DisableStartupMessage: options.disableMessage,
// Override default error handler // Override default error handler
ErrorHandler: func(ctx *fiber.Ctx, err error) error { ErrorHandler: func(ctx *fiber.Ctx, err error) error {
// Status code defaults to 500 // Status code defaults to 500
@ -43,24 +43,24 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload
}, },
}) })
if debug { if options.debug {
app.Use(logger.New(logger.Config{ app.Use(logger.New(logger.Config{
Format: "[${ip}]:${port} ${status} - ${method} ${path}\n", Format: "[${ip}]:${port} ${status} - ${method} ${path}\n",
})) }))
} }
cm := NewConfigMerger() cm := NewConfigMerger()
if err := cm.LoadConfigs(loader.ModelPath); err != nil { if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error()) log.Error().Msgf("error loading config files: %s", err.Error())
} }
if configFile != "" { if options.configFile != "" {
if err := cm.LoadConfigFile(configFile); err != nil { if err := cm.LoadConfigFile(options.configFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error()) log.Error().Msgf("error loading config file: %s", err.Error())
} }
} }
if debug { if options.debug {
for _, v := range cm.ListConfigs() { for _, v := range cm.ListConfigs() {
cfg, _ := cm.GetConfig(v) cfg, _ := cm.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg) log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
@ -68,46 +68,55 @@ func App(c context.Context, configFile string, loader *model.ModelLoader, upload
} }
// Default middleware config // Default middleware config
app.Use(recover.New()) app.Use(recover.New())
app.Use(cors.New())
if options.cors {
if options.corsAllowOrigins == "" {
app.Use(cors.New())
} else {
app.Use(cors.New(cors.Config{
AllowOrigins: options.corsAllowOrigins,
}))
}
}
// LocalAI API endpoints // LocalAI API endpoints
applier := newGalleryApplier(loader.ModelPath) applier := newGalleryApplier(options.loader.ModelPath)
applier.start(c, cm) applier.start(options.context, cm)
app.Post("/models/apply", applyModelGallery(loader.ModelPath, cm, applier.C)) app.Post("/models/apply", applyModelGallery(options.loader.ModelPath, cm, applier.C))
app.Get("/models/jobs/:uuid", getOpStatus(applier)) app.Get("/models/jobs/:uuid", getOpStatus(applier))
// openAI compatible API endpoint // openAI compatible API endpoint
// chat // chat
app.Post("/v1/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/v1/chat/completions", chatEndpoint(cm, options))
app.Post("/chat/completions", chatEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/chat/completions", chatEndpoint(cm, options))
// edit // edit
app.Post("/v1/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/v1/edits", editEndpoint(cm, options))
app.Post("/edits", editEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/edits", editEndpoint(cm, options))
// completion // completion
app.Post("/v1/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/v1/completions", completionEndpoint(cm, options))
app.Post("/completions", completionEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/completions", completionEndpoint(cm, options))
// embeddings // embeddings
app.Post("/v1/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/v1/embeddings", embeddingsEndpoint(cm, options))
app.Post("/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/embeddings", embeddingsEndpoint(cm, options))
app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/v1/engines/:model/embeddings", embeddingsEndpoint(cm, options))
// audio // audio
app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, debug, loader, threads, ctxSize, f16)) app.Post("/v1/audio/transcriptions", transcriptEndpoint(cm, options))
// images // images
app.Post("/v1/images/generations", imageEndpoint(cm, debug, loader, imageDir)) app.Post("/v1/images/generations", imageEndpoint(cm, options))
if imageDir != "" { if options.imageDir != "" {
app.Static("/generated-images", imageDir) app.Static("/generated-images", options.imageDir)
} }
// models // models
app.Get("/v1/models", listModels(loader, cm)) app.Get("/v1/models", listModels(options.loader, cm))
app.Get("/models", listModels(loader, cm)) app.Get("/models", listModels(options.loader, cm))
return app return app
} }

@ -114,7 +114,7 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(tmpdir) modelLoader = model.NewModelLoader(tmpdir)
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "") app = App(WithContext(c), WithModelLoader(modelLoader))
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -198,7 +198,7 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
app = App(c, "", modelLoader, 15, 1, 512, false, true, true, "") app = App(WithContext(c), WithModelLoader(modelLoader))
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")
@ -316,7 +316,7 @@ var _ = Describe("API test", func() {
modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH")) modelLoader = model.NewModelLoader(os.Getenv("MODELS_PATH"))
c, cancel = context.WithCancel(context.Background()) c, cancel = context.WithCancel(context.Background())
app = App(c, os.Getenv("CONFIG_FILE"), modelLoader, 5, 1, 512, false, true, true, "") app = App(WithContext(c), WithModelLoader(modelLoader), WithConfigFile(os.Getenv("CONFIG_FILE")))
go app.Listen("127.0.0.1:9090") go app.Listen("127.0.0.1:9090")
defaultConfig := openai.DefaultConfig("") defaultConfig := openai.DefaultConfig("")

@ -142,15 +142,15 @@ func defaultRequest(modelFile string) OpenAIRequest {
} }
// https://platform.openai.com/docs/api-reference/completions // https://platform.openai.com/docs/api-reference/completions
func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true) model, input, err := readInput(c, o.loader, true)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -166,7 +166,7 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
var result []Choice var result []Choice
for _, i := range config.PromptStrings { for _, i := range config.PromptStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct { templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string Input string
}{Input: i}) }{Input: i})
if err == nil { if err == nil {
@ -174,7 +174,7 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Template found, input modified to: %s", i) log.Debug().Msgf("Template found, input modified to: %s", i)
} }
r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) { r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s}) *c = append(*c, Choice{Text: s})
}, nil) }, nil)
if err != nil { if err != nil {
@ -199,14 +199,14 @@ func completionEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
} }
// https://platform.openai.com/docs/api-reference/embeddings // https://platform.openai.com/docs/api-reference/embeddings
func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true) model, input, err := readInput(c, o.loader, true)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -216,7 +216,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
for i, s := range config.InputToken { for i, s := range config.InputToken {
// get the model function to call for the result // get the model function to call for the result
embedFn, err := ModelEmbedding("", s, loader, *config) embedFn, err := ModelEmbedding("", s, o.loader, *config)
if err != nil { if err != nil {
return err return err
} }
@ -230,7 +230,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
for i, s := range config.InputStrings { for i, s := range config.InputStrings {
// get the model function to call for the result // get the model function to call for the result
embedFn, err := ModelEmbedding(s, []int{}, loader, *config) embedFn, err := ModelEmbedding(s, []int{}, o.loader, *config)
if err != nil { if err != nil {
return err return err
} }
@ -256,7 +256,7 @@ func embeddingsEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
} }
} }
func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) {
ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool { ComputeChoices(s, req, config, loader, func(s string, c *[]Choice) {}, func(s string) bool {
@ -273,12 +273,12 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
close(responses) close(responses)
} }
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true) model, input, err := readInput(c, o.loader, true)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -319,7 +319,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
} }
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct { templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string Input string
}{Input: predInput}) }{Input: predInput})
if err == nil { if err == nil {
@ -330,7 +330,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
if input.Stream { if input.Stream {
responses := make(chan OpenAIResponse) responses := make(chan OpenAIResponse)
go process(predInput, input, config, loader, responses) go process(predInput, input, config, o.loader, responses)
c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) {
@ -358,7 +358,7 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
return nil return nil
} }
result, err := ComputeChoices(predInput, input, config, loader, func(s string, c *[]Choice) { result, err := ComputeChoices(predInput, input, config, o.loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}}) *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
}, nil) }, nil)
if err != nil { if err != nil {
@ -378,14 +378,14 @@ func chatEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
} }
} }
func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { func editEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
model, input, err := readInput(c, loader, true) model, input, err := readInput(c, o.loader, true)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := readConfig(model, input, cm, loader, debug, threads, ctx, f16) config, input, err := readConfig(model, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -401,7 +401,7 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
var result []Choice var result []Choice
for _, i := range config.InputStrings { for _, i := range config.InputStrings {
// A model can have a "file.bin.tmpl" file associated with a prompt template prefix // A model can have a "file.bin.tmpl" file associated with a prompt template prefix
templatedInput, err := loader.TemplatePrefix(templateFile, struct { templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
Input string Input string
Instruction string Instruction string
}{Input: i}) }{Input: i})
@ -410,7 +410,7 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
log.Debug().Msgf("Template found, input modified to: %s", i) log.Debug().Msgf("Template found, input modified to: %s", i)
} }
r, err := ComputeChoices(i, input, config, loader, func(s string, c *[]Choice) { r, err := ComputeChoices(i, input, config, o.loader, func(s string, c *[]Choice) {
*c = append(*c, Choice{Text: s}) *c = append(*c, Choice{Text: s})
}, nil) }, nil)
if err != nil { if err != nil {
@ -449,9 +449,9 @@ func editEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threa
* *
*/ */
func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imageDir string) func(c *fiber.Ctx) error { func imageEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false) m, input, err := readInput(c, o.loader, false)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -461,7 +461,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
} }
log.Debug().Msgf("Loading model: %+v", m) log.Debug().Msgf("Loading model: %+v", m)
config, input, err := readConfig(m, input, cm, loader, debug, 0, 0, false) config, input, err := readConfig(m, input, cm, o.loader, o.debug, 0, 0, false)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -518,7 +518,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
tempDir := "" tempDir := ""
if !b64JSON { if !b64JSON {
tempDir = imageDir tempDir = o.imageDir
} }
// Create a temporary file // Create a temporary file
outputFile, err := ioutil.TempFile(tempDir, "b64") outputFile, err := ioutil.TempFile(tempDir, "b64")
@ -535,7 +535,7 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
baseURL := c.BaseURL() baseURL := c.BaseURL()
fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, loader, *config) fn, err := ImageGeneration(height, width, mode, step, input.Seed, positive_prompt, negative_prompt, output, o.loader, *config)
if err != nil { if err != nil {
return err return err
} }
@ -574,14 +574,14 @@ func imageEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, imag
} }
// https://platform.openai.com/docs/api-reference/audio/create // https://platform.openai.com/docs/api-reference/audio/create
func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { func transcriptEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
m, input, err := readInput(c, loader, false) m, input, err := readInput(c, o.loader, false)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
config, input, err := readConfig(m, input, cm, loader, debug, threads, ctx, f16) config, input, err := readConfig(m, input, cm, o.loader, o.debug, o.threads, o.ctxSize, o.f16)
if err != nil { if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) return fmt.Errorf("failed reading parameters from request:%w", err)
} }
@ -616,7 +616,7 @@ func transcriptEndpoint(cm *ConfigMerger, debug bool, loader *model.ModelLoader,
log.Debug().Msgf("Audio file copied to: %+v", dst) log.Debug().Msgf("Audio file copied to: %+v", dst)
whisperModel, err := loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads)) whisperModel, err := o.loader.BackendLoader(model.WhisperBackend, config.Model, []llama.ModelOption{}, uint32(config.Threads))
if err != nil { if err != nil {
return err return err
} }

@ -0,0 +1,108 @@
package api
import (
"context"
model "github.com/go-skynet/LocalAI/pkg/model"
)
type Option struct {
context context.Context
configFile string
loader *model.ModelLoader
uploadLimitMB, threads, ctxSize int
f16 bool
debug, disableMessage bool
imageDir string
cors bool
corsAllowOrigins string
}
type AppOption func(*Option)
func newOptions(o ...AppOption) *Option {
opt := &Option{
context: context.Background(),
uploadLimitMB: 15,
threads: 1,
ctxSize: 512,
debug: true,
disableMessage: true,
}
for _, oo := range o {
oo(opt)
}
return opt
}
func WithCors(b bool) AppOption {
return func(o *Option) {
o.cors = b
}
}
func WithCorsAllowOrigins(b string) AppOption {
return func(o *Option) {
o.corsAllowOrigins = b
}
}
func WithContext(ctx context.Context) AppOption {
return func(o *Option) {
o.context = ctx
}
}
func WithConfigFile(configFile string) AppOption {
return func(o *Option) {
o.configFile = configFile
}
}
func WithModelLoader(loader *model.ModelLoader) AppOption {
return func(o *Option) {
o.loader = loader
}
}
func WithUploadLimitMB(limit int) AppOption {
return func(o *Option) {
o.uploadLimitMB = limit
}
}
func WithThreads(threads int) AppOption {
return func(o *Option) {
o.threads = threads
}
}
func WithContextSize(ctxSize int) AppOption {
return func(o *Option) {
o.ctxSize = ctxSize
}
}
func WithF16(f16 bool) AppOption {
return func(o *Option) {
o.f16 = f16
}
}
func WithDebug(debug bool) AppOption {
return func(o *Option) {
o.debug = debug
}
}
func WithDisableMessage(disableMessage bool) AppOption {
return func(o *Option) {
o.disableMessage = disableMessage
}
}
func WithImageDir(imageDir string) AppOption {
return func(o *Option) {
o.imageDir = imageDir
}
}

@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@ -34,6 +33,14 @@ func main() {
Name: "debug", Name: "debug",
EnvVars: []string{"DEBUG"}, EnvVars: []string{"DEBUG"},
}, },
&cli.BoolFlag{
Name: "cors",
EnvVars: []string{"CORS"},
},
&cli.StringFlag{
Name: "cors-allow-origins",
EnvVars: []string{"CORS_ALLOW_ORIGINS"},
},
&cli.IntFlag{ &cli.IntFlag{
Name: "threads", Name: "threads",
DefaultText: "Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested.", DefaultText: "Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested.",
@ -94,7 +101,17 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
Copyright: "go-skynet authors", Copyright: "go-skynet authors",
Action: func(ctx *cli.Context) error { Action: func(ctx *cli.Context) error {
fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path")) fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
return api.App(context.Background(), ctx.String("config-file"), model.NewModelLoader(ctx.String("models-path")), ctx.Int("upload-limit"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"), ctx.Bool("debug"), false, ctx.String("image-path")).Listen(ctx.String("address")) return api.App(
api.WithConfigFile(ctx.String("config-file")),
api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
api.WithContextSize(ctx.Int("context-size")),
api.WithDebug(ctx.Bool("debug")),
api.WithImageDir(ctx.String("image-path")),
api.WithF16(ctx.Bool("f16")),
api.WithCors(ctx.Bool("cors")),
api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
api.WithThreads(ctx.Int("threads")),
api.WithUploadLimitMB(ctx.Int("upload-limit"))).Listen(ctx.String("address"))
}, },
} }

Loading…
Cancel
Save