diff --git a/api/config.go b/api/config.go index d5df3de..b032d15 100644 --- a/api/config.go +++ b/api/config.go @@ -1,12 +1,16 @@ package api import ( + "encoding/json" "fmt" "io/ioutil" "os" "path/filepath" "strings" + model "github.com/go-skynet/LocalAI/pkg/model" + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) @@ -106,3 +110,172 @@ func (cm ConfigMerger) LoadConfigs(path string) error { return nil } + +func updateConfig(config *Config, input *OpenAIRequest) { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != 0 { + config.TopK = input.TopK + } + if input.TopP != 0 { + config.TopP = input.TopP + } + + if input.Temperature != 0 { + config.Temperature = input.Temperature + } + + if input.Maxtokens != 0 { + config.Maxtokens = input.Maxtokens + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.F16 { + config.F16 = input.F16 + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != 0 { + config.Seed = input.Seed + } + + if input.Mirostat != 0 { + config.Mirostat = input.Mirostat + } + + if input.MirostatETA != 0 { + config.MirostatETA = input.MirostatETA + } + + if input.MirostatTAU != 0 { + config.MirostatTAU = input.MirostatTAU + } + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + if s, ok := pp.(string); ok { + config.InputStrings = append(config.InputStrings, s) + } + } + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } +} + +func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { + input := new(OpenAIRequest) + // Get input data from the request body + if err := c.BodyParser(input); err != nil { + return nil, nil, err + } + + modelFile := input.Model + + if c.Params("model") != "" { + modelFile = c.Params("model") + } + + received, _ := json.Marshal(input) + + log.Debug().Msgf("Request received: %s", string(received)) + + // Set model from bearer token, if available + bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") + bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) + + // If no model was specified, take the first available + if modelFile == "" && !bearerExists { + models, _ := loader.ListModels() + if len(models) > 0 { + modelFile = models[0] + log.Debug().Msgf("No model specified, using: %s", modelFile) + } else { + log.Debug().Msgf("No model specified, returning error") + return nil, nil, fmt.Errorf("no model specified") + } + } + + // If a model is found in bearer token takes precedence + if bearerExists { + log.Debug().Msgf("Using model from bearer token: %s", bearer) + modelFile = bearer + } + + // Load a config file if present after the model name + modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") + if _, err := os.Stat(modelConfig); err == nil { + if err := cm.LoadConfig(modelConfig); err != nil { + return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) + } + } + + var config *Config + cfg, exists := cm[modelFile] + if !exists { + config = &Config{ + OpenAIRequest: defaultRequest(modelFile), + ContextSize: ctx, + Threads: threads, + F16: f16, + Debug: debug, + } + } else { + config = &cfg + } + + // Set the parameters for the language model prediction + updateConfig(config, input) + + // Don't allow 0 as setting + if config.Threads == 0 { + if threads != 0 { + config.Threads = threads + } else { + config.Threads = 4 + } + } + + return config, input, nil +} diff --git a/api/openai.go b/api/openai.go index 6061e35..d98dc56 100644 --- a/api/openai.go +++ b/api/openai.go @@ -5,8 +5,6 @@ import ( "bytes" "encoding/json" "fmt" - "os" - "path/filepath" "strings" model "github.com/go-skynet/LocalAI/pkg/model" @@ -117,166 +115,6 @@ func defaultRequest(modelFile string) OpenAIRequest { } } -func updateConfig(config *Config, input *OpenAIRequest) { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != 0 { - config.TopK = input.TopK - } - if input.TopP != 0 { - config.TopP = input.TopP - } - - if input.Temperature != 0 { - config.Temperature = input.Temperature - } - - if input.Maxtokens != 0 { - config.Maxtokens = input.Maxtokens - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.F16 { - config.F16 = input.F16 - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != 0 { - config.Seed = input.Seed - } - - if input.Mirostat != 0 { - config.Mirostat = input.Mirostat - } - - if input.MirostatETA != 0 { - config.MirostatETA = input.MirostatETA - } - - if input.MirostatTAU != 0 { - config.MirostatTAU = input.MirostatTAU - } - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - if s, ok := pp.(string); ok { - config.InputStrings = append(config.InputStrings, s) - } - } - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } -} - -func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { - input := new(OpenAIRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return nil, nil, err - } - - modelFile := input.Model - - if c.Params("model") != "" { - modelFile = c.Params("model") - } - - received, _ := json.Marshal(input) - - log.Debug().Msgf("Request received: %s", string(received)) - - // Set model from bearer token, if available - bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") - bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) - - // If no model was specified, take the first available - if modelFile == "" && !bearerExists { - models, _ := loader.ListModels() - if len(models) > 0 { - modelFile = models[0] - log.Debug().Msgf("No model specified, using: %s", modelFile) - } else { - log.Debug().Msgf("No model specified, returning error") - return nil, nil, fmt.Errorf("no model specified") - } - } - - // If a model is found in bearer token takes precedence - if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) - modelFile = bearer - } - - // Load a config file if present after the model name - modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") - if _, err := os.Stat(modelConfig); err == nil { - if err := cm.LoadConfig(modelConfig); err != nil { - return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) - } - } - - var config *Config - cfg, exists := cm[modelFile] - if !exists { - config = &Config{ - OpenAIRequest: defaultRequest(modelFile), - ContextSize: ctx, - Threads: threads, - F16: f16, - Debug: debug, - } - } else { - config = &cfg - } - - // Set the parameters for the language model prediction - updateConfig(config, input) - - return config, input, nil -} - // https://platform.openai.com/docs/api-reference/completions func completionEndpoint(cm ConfigMerger, debug bool, loader *model.ModelLoader, threads, ctx int, f16 bool) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { diff --git a/tests/fixtures/config.yaml b/tests/fixtures/config.yaml index 9910ffa..3deabf9 100644 --- a/tests/fixtures/config.yaml +++ b/tests/fixtures/config.yaml @@ -1,7 +1,10 @@ - name: list1 parameters: model: testmodel - context_size: 128 + top_p: 80 + top_k: 0.9 + temperature: 0.1 + context_size: 10 stopwords: - "HUMAN:" - "### Response:" @@ -13,8 +16,11 @@ chat: ggml-gpt4all-j - name: list2 parameters: + top_p: 80 + top_k: 0.9 + temperature: 0.1 model: testmodel - context_size: 128 + context_size: 10 stopwords: - "HUMAN:" - "### Response:" diff --git a/tests/fixtures/gpt4.yaml b/tests/fixtures/gpt4.yaml index 54743bd..77b72b3 100644 --- a/tests/fixtures/gpt4.yaml +++ b/tests/fixtures/gpt4.yaml @@ -1,7 +1,10 @@ name: gpt4all parameters: model: testmodel -context_size: 128 + top_p: 80 + top_k: 0.9 + temperature: 0.1 +context_size: 10 stopwords: - "HUMAN:" - "### Response:" diff --git a/tests/fixtures/gpt4_2.yaml b/tests/fixtures/gpt4_2.yaml index 43ef5a1..62d9fdb 100644 --- a/tests/fixtures/gpt4_2.yaml +++ b/tests/fixtures/gpt4_2.yaml @@ -1,7 +1,10 @@ name: gpt4all-2 parameters: model: testmodel -context_size: 128 + top_p: 80 + top_k: 0.9 + temperature: 0.1 +context_size: 10 stopwords: - "HUMAN:" - "### Response:"