From d517a54e28e0557ed6311f2f8db222937cef52fb Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Thu, 20 Apr 2023 18:33:02 +0200
Subject: [PATCH] Major API enhancements (#44)

---
 .env                |   1 +
 README.md           |   2 +-
 api/api.go          | 143 ++++++++++++++++++++++++++++----------------
 docker-compose.yaml |   1 +
 go.mod              |   3 +-
 go.sum              |  10 ++++
 main.go             |  17 +++++-
 pkg/model/loader.go |  96 +++++++++++++++--------------
 8 files changed, 170 insertions(+), 103 deletions(-)

diff --git a/.env b/.env
index d53b55b..90d038d 100644
--- a/.env
+++ b/.env
@@ -1,3 +1,4 @@
 THREADS=14
 CONTEXT_SIZE=512
 MODELS_PATH=/models
+# DEBUG=true
\ No newline at end of file
diff --git a/README.md b/README.md
index 38efd46..6cc217b 100644
--- a/README.md
+++ b/README.md
@@ -82,7 +82,7 @@ See the [prompt-templates](https://github.com/go-skynet/LocalAI/tree/master/prom
 Example of starting the API with `docker`:
 
 ```bash
-docker run -p 8080:8080 -ti --rm quay.io/go-skynet/local-api:latest --models-path /path/to/models --context-size 700 --threads 4
+docker run -p 8080:8080 -ti --rm quay.io/go-skynet/local-ai:latest --models-path /path/to/models --context-size 700 --threads 4
 ```
 
 And you'll see:
diff --git a/api/api.go b/api/api.go
index b5d93b9..8c6fc8d 100644
--- a/api/api.go
+++ b/api/api.go
@@ -1,6 +1,8 @@
 package api
 
 import (
+	"encoding/json"
+	"errors"
 	"fmt"
 	"strings"
 	"sync"
@@ -11,6 +13,7 @@ import (
 	"github.com/gofiber/fiber/v2"
 	"github.com/gofiber/fiber/v2/middleware/cors"
 	"github.com/gofiber/fiber/v2/middleware/recover"
+	"github.com/rs/zerolog/log"
 )
 
 type OpenAIResponse struct {
@@ -65,7 +68,7 @@ type OpenAIRequest struct {
 }
 
 // https://platform.openai.com/docs/api-reference/completions
-func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, defaultMutex *sync.Mutex, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
+func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16 bool, mutexMap *sync.Mutex, mutexes map[string]*sync.Mutex) func(c *fiber.Ctx) error {
 	return func(c *fiber.Ctx) error {
 		var err error
 		var model *llama.LLama
@@ -76,45 +79,52 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		if err := c.BodyParser(input); err != nil {
 			return err
 		}
+		modelFile := input.Model
+		received, _ := json.Marshal(input)
 
-		if input.Model == "" {
+		log.Debug().Msgf("Request received: %s", string(received))
+
+		// Set model from bearer token, if available
+		bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ")
+		bearerExists := bearer != "" && loader.ExistsInModelPath(bearer)
+		if modelFile == "" && !bearerExists {
 			return fmt.Errorf("no model specified")
-		} else {
-			// Try to load the model with both
-			var llamaerr error
-			llamaOpts := []llama.ModelOption{}
-			if ctx != 0 {
-				llamaOpts = append(llamaOpts, llama.SetContext(ctx))
-			}
-			if f16 {
-				llamaOpts = append(llamaOpts, llama.EnableF16Memory)
-			}
+		}
 
-			model, llamaerr = loader.LoadLLaMAModel(input.Model, llamaOpts...)
-			if llamaerr != nil {
-				gptModel, err = loader.LoadGPTJModel(input.Model)
-				if err != nil {
-					return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
-				}
+		if bearerExists { // model specified in bearer token takes precedence
+			log.Debug().Msgf("Using model from bearer token: %s", bearer)
+			modelFile = bearer
+		}
+
+		// Try to load the model with both
+		var llamaerr error
+		llamaOpts := []llama.ModelOption{}
+		if ctx != 0 {
+			llamaOpts = append(llamaOpts, llama.SetContext(ctx))
+		}
+		if f16 {
+			llamaOpts = append(llamaOpts, llama.EnableF16Memory)
+		}
+
+		model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...)
+		if llamaerr != nil {
+			gptModel, err = loader.LoadGPTJModel(modelFile)
+			if err != nil {
+				return fmt.Errorf("llama: %s gpt: %s", llamaerr.Error(), err.Error()) // llama failed first, so we want to catch both errors
 			}
 		}
 
 		// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
-		if input.Model != "" {
-			mutexMap.Lock()
-			l, ok := mutexes[input.Model]
-			if !ok {
-				m := &sync.Mutex{}
-				mutexes[input.Model] = m
-				l = m
-			}
-			mutexMap.Unlock()
-			l.Lock()
-			defer l.Unlock()
-		} else {
-			defaultMutex.Lock()
-			defer defaultMutex.Unlock()
+		mutexMap.Lock()
+		l, ok := mutexes[modelFile]
+		if !ok {
+			m := &sync.Mutex{}
+			mutexes[modelFile] = m
+			l = m
 		}
+		mutexMap.Unlock()
+		l.Lock()
+		defer l.Unlock()
 
 		// Set the parameters for the language model prediction
 		topP := input.TopP
@@ -139,6 +149,7 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		predInput := input.Prompt
 		if chat {
 			mess := []string{}
+			// TODO: encode roles
 			for _, i := range input.Messages {
 				mess = append(mess, i.Content)
 			}
@@ -147,11 +158,12 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		}
 
 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
-		templatedInput, err := loader.TemplatePrefix(input.Model, struct {
+		templatedInput, err := loader.TemplatePrefix(modelFile, struct {
 			Input string
 		}{Input: predInput})
 		if err == nil {
 			predInput = templatedInput
+			log.Debug().Msgf("Template found, input modified to: %s", predInput)
 		}
 
 		result := []Choice{}
@@ -223,8 +235,6 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 		}
 
 		for i := 0; i < n; i++ {
-			var prediction string
-
 			prediction, err := predFunc()
 			if err != nil {
 				return err
@@ -241,30 +251,19 @@ func openAIEndpoint(chat bool, loader *model.ModelLoader, threads, ctx int, f16
 			}
 		}
 
+		jsonResult, _ := json.Marshal(result)
+		log.Debug().Msgf("Response: %s", jsonResult)
+
 		// Return the prediction in the response body
 		return c.JSON(OpenAIResponse{
-			Model:   input.Model,
+			Model:   input.Model, // we have to return what the user sent here, due to OpenAI spec.
 			Choices: result,
 		})
 	}
 }
 
-func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
-	app := fiber.New()
-
-	// Default middleware config
-	app.Use(recover.New())
-	app.Use(cors.New())
-
-	// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
-	var mutex = &sync.Mutex{}
-	mu := map[string]*sync.Mutex{}
-	var mumutex = &sync.Mutex{}
-
-	// openAI compatible API endpoint
-	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mutex, mumutex, mu))
-	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mutex, mumutex, mu))
-	app.Get("/v1/models", func(c *fiber.Ctx) error {
+func listModels(loader *model.ModelLoader) func(ctx *fiber.Ctx) error {
+	return func(c *fiber.Ctx) error {
 		models, err := loader.ListModels()
 		if err != nil {
 			return err
@@ -281,8 +280,48 @@ func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f
 			Object: "list",
 			Data:   dataModels,
 		})
+	}
+}
+
+func Start(loader *model.ModelLoader, listenAddr string, threads, ctxSize int, f16 bool) error {
+	// Return errors as JSON responses
+	app := fiber.New(fiber.Config{
+		// Override default error handler
+		ErrorHandler: func(ctx *fiber.Ctx, err error) error {
+			// Status code defaults to 500
+			code := fiber.StatusInternalServerError
+
+			// Retrieve the custom status code if it's a *fiber.Error
+			var e *fiber.Error
+			if errors.As(err, &e) {
+				code = e.Code
+			}
+
+			// Send custom error page
+			return ctx.Status(code).JSON(struct {
+				Error string `json:"error"`
+			}{Error: err.Error()})
+		},
 	})
 
+	// Default middleware config
+	app.Use(recover.New())
+	app.Use(cors.New())
+
+	// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
+	mu := map[string]*sync.Mutex{}
+	var mumutex = &sync.Mutex{}
+
+	// openAI compatible API endpoint
+	app.Post("/v1/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/chat/completions", openAIEndpoint(true, loader, threads, ctxSize, f16, mumutex, mu))
+
+	app.Post("/v1/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
+	app.Post("/completions", openAIEndpoint(false, loader, threads, ctxSize, f16, mumutex, mu))
+
+	app.Get("/v1/models", listModels(loader))
+	app.Get("/models", listModels(loader))
+
 	// Start the server
 	app.Listen(listenAddr)
 	return nil
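The api/api.go changes above let a request name the model either in the JSON body or, when the body omits it, via the Authorization bearer token, and they route all handler errors through a JSON-producing error handler. A minimal client-side sketch of how this could be exercised — the host, port, and model file name below are illustrative placeholders, not part of the patch:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

func main() {
	// The body deliberately omits "model"; the server is expected to fall back
	// to the bearer token when the named file exists in the models path.
	body := []byte(`{"prompt": "Hello, how are you?"}`)

	req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/v1/completions", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Hypothetical model file name passed as the bearer token.
	req.Header.Set("Authorization", "Bearer ggml-gpt4all-j.bin")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	// On failure the custom Fiber error handler should answer with {"error": "..."}.
	fmt.Println(resp.Status, string(out))
}
```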
diff --git a/docker-compose.yaml b/docker-compose.yaml
index d68177f..ad061c3 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -14,5 +14,6 @@ services:
       - MODELS_PATH=$MODELS_PATH
       - CONTEXT_SIZE=$CONTEXT_SIZE
       - THREADS=$THREADS
+      - DEBUG=$DEBUG
     volumes:
       - ./models:/models:cached
\ No newline at end of file
diff --git a/go.mod b/go.mod
index 732468b..f7375d5 100644
--- a/go.mod
+++ b/go.mod
@@ -3,15 +3,16 @@ module github.com/go-skynet/LocalAI
 go 1.19
 
 require (
+	github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94
 	github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640
 	github.com/gofiber/fiber/v2 v2.42.0
+	github.com/rs/zerolog v1.29.1
 	github.com/urfave/cli/v2 v2.25.0
 )
 
 require (
 	github.com/andybalholm/brotli v1.0.4 // indirect
 	github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
-	github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94 // indirect
 	github.com/google/uuid v1.3.0 // indirect
 	github.com/klauspost/compress v1.15.9 // indirect
 	github.com/mattn/go-colorable v0.1.13 // indirect
diff --git a/go.sum b/go.sum
index 3d283f2..41a2071 100644
--- a/go.sum
+++ b/go.sum
@@ -1,5 +1,6 @@
 github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
 github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
+github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
 github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/go-logr/logr v1.2.3 h1:2DntVwHkVopvECVRSlL5PSo9eG+cAkDCuckLubN+rq0=
@@ -8,6 +9,7 @@ github.com/go-skynet/go-gpt4all-j.cpp v0.0.0-20230419091210-303cf2a59a94/go.mod
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640 h1:8SSVbQ3yvq7JnfLCLF4USV0PkQnnduUkaNCv/hHDa3E=
 github.com/go-skynet/go-llama.cpp v0.0.0-20230415213228-bac222030640/go.mod h1:35AKIEMY+YTKCBJIa/8GZcNGJ2J+nQk1hQiWo/OnEWw=
 github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
+github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
 github.com/gofiber/fiber/v2 v2.42.0 h1:Fnp7ybWvS+sjNQsFvkhf4G8OhXswvB6Vee8hM/LyS+8=
 github.com/gofiber/fiber/v2 v2.42.0/go.mod h1:3+SGNjqMh5VQH5Vz2Wdi43zTIV16ktlFd3x3R6O1Zlc=
 github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
@@ -16,8 +18,10 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY=
 github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU=
+github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
 github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
 github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
+github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
 github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
 github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng=
 github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
@@ -27,8 +31,12 @@ github.com/onsi/ginkgo/v2 v2.9.2 h1:BA2GMJOtfGAfagzYtrAlufIP0lq6QERkFmHLMLPwFSU=
 github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
 github.com/philhofer/fwd v1.1.1 h1:GdGcTjf5RNAxwS4QLsiMzJYj5KEvPJD3Abr261yRQXQ=
 github.com/philhofer/fwd v1.1.1/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU=
+github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
 github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
+github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
+github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc=
+github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU=
 github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/savsgio/dictpool v0.0.0-20221023140959-7bf2e61cea94 h1:rmMl4fXJhKMNWl+K+r/fq4FbbKI+Ia2m9hYBLm2h4G4=
@@ -67,6 +75,8 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ=
diff --git a/main.go b/main.go
index de956ec..6a2051b 100644
--- a/main.go
+++ b/main.go
@@ -1,20 +1,23 @@
 package main
 
 import (
-	"fmt"
 	"os"
 	"runtime"
 
 	api "github.com/go-skynet/LocalAI/api"
 	model "github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/rs/zerolog"
+	"github.com/rs/zerolog/log"
 
 	"github.com/urfave/cli/v2"
 )
 
 func main() {
+	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
+
 	path, err := os.Getwd()
 	if err != nil {
-		fmt.Println(err)
+		log.Error().Msgf("error: %s", err.Error())
 		os.Exit(1)
 	}
 
@@ -26,6 +29,10 @@ func main() {
 			Name:    "f16",
 			EnvVars: []string{"F16"},
 		},
+		&cli.BoolFlag{
+			Name:    "debug",
+			EnvVars: []string{"DEBUG"},
+		},
 		&cli.IntFlag{
 			Name:        "threads",
 			DefaultText: "Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested.",
@@ -66,13 +73,17 @@ It uses llama.cpp and gpt4all as backend, supporting all the models supported by
 		UsageText: `local-ai [options]`,
 		Copyright: "go-skynet authors",
 		Action: func(ctx *cli.Context) error {
+			zerolog.SetGlobalLevel(zerolog.InfoLevel)
+			if ctx.Bool("debug") {
+				zerolog.SetGlobalLevel(zerolog.DebugLevel)
+			}
 			return api.Start(model.NewModelLoader(ctx.String("models-path")), ctx.String("address"), ctx.Int("threads"), ctx.Int("context-size"), ctx.Bool("f16"))
 		},
 	}
 
 	err = app.Run(os.Args)
 	if err != nil {
-		fmt.Println(err)
+		log.Error().Msgf("error: %s", err.Error())
 		os.Exit(1)
 	}
 }
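The main.go changes above hook a new --debug flag (DEBUG env var) to zerolog's global level and send human-readable console output to stderr. A standalone sketch of the same logging pattern, using a plain environment check instead of the urfave/cli flag wired up in the patch:

```go
package main

import (
	"os"

	"github.com/rs/zerolog"
	"github.com/rs/zerolog/log"
)

func main() {
	// Console writer on stderr, mirroring the setup in main.go above.
	log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})

	// Default to Info; switch to Debug when requested.
	zerolog.SetGlobalLevel(zerolog.InfoLevel)
	if os.Getenv("DEBUG") == "true" {
		zerolog.SetGlobalLevel(zerolog.DebugLevel)
	}

	log.Debug().Msg("only printed when DEBUG=true")
	log.Info().Msg("always printed")
}
```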
diff --git a/pkg/model/loader.go b/pkg/model/loader.go
index 83a9fa4..09f57db 100644
--- a/pkg/model/loader.go
+++ b/pkg/model/loader.go
@@ -10,6 +10,8 @@ import (
 	"sync"
 	"text/template"
 
+	"github.com/rs/zerolog/log"
+
 	gptj "github.com/go-skynet/go-gpt4all-j.cpp"
 	llama "github.com/go-skynet/go-llama.cpp"
 )
@@ -26,6 +28,11 @@ func NewModelLoader(modelPath string) *ModelLoader {
 	return &ModelLoader{modelPath: modelPath, gptmodels: make(map[string]*gptj.GPTJ), models: make(map[string]*llama.LLama), promptsTemplates: make(map[string]*template.Template)}
 }
 
+func (ml *ModelLoader) ExistsInModelPath(s string) bool {
+	_, err := os.Stat(filepath.Join(ml.modelPath, s))
+	return err == nil
+}
+
 func (ml *ModelLoader) ListModels() ([]string, error) {
 	files, err := ioutil.ReadDir(ml.modelPath)
 	if err != nil {
@@ -34,9 +41,12 @@ func (ml *ModelLoader) ListModels() ([]string, error) {
 
 	models := []string{}
 	for _, file := range files {
-		if strings.HasSuffix(file.Name(), ".bin") {
-			models = append(models, strings.TrimRight(file.Name(), ".bin"))
+		// Skip templates, YAML and .keep files
+		if strings.HasSuffix(file.Name(), ".tmpl") || strings.HasSuffix(file.Name(), ".keep") || strings.HasSuffix(file.Name(), ".yaml") || strings.HasSuffix(file.Name(), ".yml") {
+			continue
 		}
+
+		models = append(models, file.Name())
 	}
 
 	return models, nil
@@ -48,12 +58,7 @@ func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string,
 
 	m, ok := ml.promptsTemplates[modelName]
 	if !ok {
-		// try to find a s.bin
-		modelBin := fmt.Sprintf("%s.bin", modelName)
-		m, ok = ml.promptsTemplates[modelBin]
-		if !ok {
-			return "", fmt.Errorf("no prompt template available")
-		}
+		return "", fmt.Errorf("no prompt template available")
 	}
 
 	var buf bytes.Buffer
@@ -64,15 +69,21 @@ func (ml *ModelLoader) TemplatePrefix(modelName string, in interface{}) (string,
 	return buf.String(), nil
 }
 
-func (ml *ModelLoader) loadTemplate(modelName, modelFile string) error {
-	modelTemplateFile := fmt.Sprintf("%s.tmpl", modelFile)
+func (ml *ModelLoader) loadTemplateIfExists(modelName, modelFile string) error {
+	// Check if the template was already loaded
+	if _, ok := ml.promptsTemplates[modelName]; ok {
+		return nil
+	}
 
 	// Check if the model path exists
-	if _, err := os.Stat(modelTemplateFile); err != nil {
+	// skip any error here - we run anyway if a template does not exist
+	modelTemplateFile := fmt.Sprintf("%s.tmpl", modelName)
+
+	if !ml.ExistsInModelPath(modelTemplateFile) {
 		return nil
 	}
 
-	dat, err := os.ReadFile(modelTemplateFile)
+	dat, err := os.ReadFile(filepath.Join(ml.modelPath, modelTemplateFile))
 	if err != nil {
 		return err
 	}
@@ -92,36 +103,30 @@ func (ml *ModelLoader) LoadGPTJModel(modelName string) (*gptj.GPTJ, error) {
 	defer ml.mu.Unlock()
 
 	// Check if we already have a loaded model
-	modelFile := filepath.Join(ml.modelPath, modelName)
-
-	if m, ok := ml.gptmodels[modelFile]; ok {
-		return m, nil
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist")
 	}
 
-	// Check if the model path exists
-	if _, err := os.Stat(modelFile); os.IsNotExist(err) {
-		// try to find a s.bin
-		modelBin := fmt.Sprintf("%s.bin", modelFile)
-		if _, err := os.Stat(modelBin); os.IsNotExist(err) {
-			return nil, err
-		} else {
-			modelName = fmt.Sprintf("%s.bin", modelName)
-			modelFile = modelBin
-		}
+	if m, ok := ml.gptmodels[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
+		return m, nil
 	}
 
 	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.modelPath, modelName)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
 	model, err := gptj.New(modelFile)
 	if err != nil {
 		return nil, err
 	}
 
 	// If there is a prompt template, load it
-	if err := ml.loadTemplate(modelName, modelFile); err != nil {
+	if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
 		return nil, err
 	}
 
-	ml.gptmodels[modelFile] = model
+	ml.gptmodels[modelName] = model
 	return model, err
 }
 
@@ -129,40 +134,39 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
 	ml.mu.Lock()
 	defer ml.mu.Unlock()
 
+	log.Debug().Msgf("Loading model name: %s", modelName)
+
 	// Check if we already have a loaded model
-	modelFile := filepath.Join(ml.modelPath, modelName)
-	if m, ok := ml.models[modelFile]; ok {
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist")
+	}
+
+	if m, ok := ml.models[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
 		return m, nil
 	}
 
+	// TODO: This needs refactoring, it's really bad to have it in here
-	// Check if we have a GPTJ model loaded instead
-	if _, ok := ml.gptmodels[modelFile]; ok {
+	// Check if we have a GPTJ model loaded instead - if we do we return an error so the API tries with GPTJ
+	if _, ok := ml.gptmodels[modelName]; ok {
+		log.Debug().Msgf("Model is GPTJ: %s", modelName)
 		return nil, fmt.Errorf("this model is a GPTJ one")
 	}
 
-	// Check if the model path exists
-	if _, err := os.Stat(modelFile); os.IsNotExist(err) {
-		// try to find a s.bin
-		modelBin := fmt.Sprintf("%s.bin", modelFile)
-		if _, err := os.Stat(modelBin); os.IsNotExist(err) {
-			return nil, err
-		} else {
-			modelName = fmt.Sprintf("%s.bin", modelName)
-			modelFile = modelBin
-		}
-	}
-
 	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.modelPath, modelName)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
 	model, err := llama.New(modelFile, opts...)
 	if err != nil {
 		return nil, err
 	}
 
	// If there is a prompt template, load it
-	if err := ml.loadTemplate(modelName, modelFile); err != nil {
+	if err := ml.loadTemplateIfExists(modelName, modelFile); err != nil {
 		return nil, err
 	}
 
-	ml.models[modelFile] = model
+	ml.models[modelName] = model
 	return model, err
 }
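The loader keeps the pre-existing prompt-template behaviour: when a `<model>.tmpl` file sits in the models path, its contents are parsed as a Go text/template and executed against a struct with an `Input` field before prediction. A self-contained sketch of that mechanism — the template text and file name are invented for illustration:

```go
package main

import (
	"bytes"
	"fmt"
	"text/template"
)

func main() {
	// Hypothetical contents of a "ggml-model.bin.tmpl" file next to the model.
	const tmplText = `Below is an instruction that describes a task. Write a response that completes the request.

### Instruction:
{{.Input}}

### Response:`

	tmpl, err := template.New("ggml-model.bin.tmpl").Parse(tmplText)
	if err != nil {
		panic(err)
	}

	// Same shape the API uses: an anonymous struct with an Input field.
	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, struct{ Input string }{Input: "What is LocalAI?"}); err != nil {
		panic(err)
	}

	// buf.String() is what would be handed to the model as the prompt.
	fmt.Println(buf.String())
}
```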