package api import ( "encoding/json" "fmt" "io/ioutil" "os" "path/filepath" "strings" model "github.com/go-skynet/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" "github.com/rs/zerolog/log" "gopkg.in/yaml.v3" ) type Config struct { OpenAIRequest `yaml:"parameters"` Name string `yaml:"name"` StopWords []string `yaml:"stopwords"` Cutstrings []string `yaml:"cutstrings"` TrimSpace []string `yaml:"trimspace"` ContextSize int `yaml:"context_size"` F16 bool `yaml:"f16"` Threads int `yaml:"threads"` Debug bool `yaml:"debug"` Roles map[string]string `yaml:"roles"` Embeddings bool `yaml:"embeddings"` Backend string `yaml:"backend"` TemplateConfig TemplateConfig `yaml:"template"` MirostatETA float64 `yaml:"mirostat_eta"` MirostatTAU float64 `yaml:"mirostat_tau"` Mirostat int `yaml:"mirostat"` PromptStrings, InputStrings []string InputToken [][]int } type TemplateConfig struct { Completion string `yaml:"completion"` Chat string `yaml:"chat"` Edit string `yaml:"edit"` } type ConfigMerger map[string]Config func ReadConfigFile(file string) ([]*Config, error) { c := &[]*Config{} f, err := os.ReadFile(file) if err != nil { return nil, fmt.Errorf("cannot read config file: %w", err) } if err := yaml.Unmarshal(f, c); err != nil { return nil, fmt.Errorf("cannot unmarshal config file: %w", err) } return *c, nil } func ReadConfig(file string) (*Config, error) { c := &Config{} f, err := os.ReadFile(file) if err != nil { return nil, fmt.Errorf("cannot read config file: %w", err) } if err := yaml.Unmarshal(f, c); err != nil { return nil, fmt.Errorf("cannot unmarshal config file: %w", err) } return c, nil } func (cm ConfigMerger) LoadConfigFile(file string) error { c, err := ReadConfigFile(file) if err != nil { return fmt.Errorf("cannot load config file: %w", err) } for _, cc := range c { cm[cc.Name] = *cc } return nil } func (cm ConfigMerger) LoadConfig(file string) error { c, err := ReadConfig(file) if err != nil { return fmt.Errorf("cannot read config file: %w", err) } cm[c.Name] = *c return nil } func (cm ConfigMerger) LoadConfigs(path string) error { files, err := ioutil.ReadDir(path) if err != nil { return err } for _, file := range files { // Skip templates, YAML and .keep files if !strings.Contains(file.Name(), ".yaml") { continue } c, err := ReadConfig(filepath.Join(path, file.Name())) if err == nil { cm[c.Name] = *c } } return nil } func updateConfig(config *Config, input *OpenAIRequest) { if input.Echo { config.Echo = input.Echo } if input.TopK != 0 { config.TopK = input.TopK } if input.TopP != 0 { config.TopP = input.TopP } if input.Temperature != 0 { config.Temperature = input.Temperature } if input.Maxtokens != 0 { config.Maxtokens = input.Maxtokens } switch stop := input.Stop.(type) { case string: if stop != "" { config.StopWords = append(config.StopWords, stop) } case []interface{}: for _, pp := range stop { if s, ok := pp.(string); ok { config.StopWords = append(config.StopWords, s) } } } if input.RepeatPenalty != 0 { config.RepeatPenalty = input.RepeatPenalty } if input.Keep != 0 { config.Keep = input.Keep } if input.Batch != 0 { config.Batch = input.Batch } if input.F16 { config.F16 = input.F16 } if input.IgnoreEOS { config.IgnoreEOS = input.IgnoreEOS } if input.Seed != 0 { config.Seed = input.Seed } if input.Mirostat != 0 { config.Mirostat = input.Mirostat } if input.MirostatETA != 0 { config.MirostatETA = input.MirostatETA } if input.MirostatTAU != 0 { config.MirostatTAU = input.MirostatTAU } switch inputs := input.Input.(type) { case string: if inputs != "" { config.InputStrings = append(config.InputStrings, inputs) } case []interface{}: for _, pp := range inputs { switch i := pp.(type) { case string: config.InputStrings = append(config.InputStrings, i) case []interface{}: tokens := []int{} for _, ii := range i { tokens = append(tokens, int(ii.(float64))) } config.InputToken = append(config.InputToken, tokens) } } } switch p := input.Prompt.(type) { case string: config.PromptStrings = append(config.PromptStrings, p) case []interface{}: for _, pp := range p { if s, ok := pp.(string); ok { config.PromptStrings = append(config.PromptStrings, s) } } } } func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { input := new(OpenAIRequest) // Get input data from the request body if err := c.BodyParser(input); err != nil { return nil, nil, err } modelFile := input.Model if c.Params("model") != "" { modelFile = c.Params("model") } received, _ := json.Marshal(input) log.Debug().Msgf("Request received: %s", string(received)) // Set model from bearer token, if available bearer := strings.TrimLeft(c.Get("authorization"), "Bearer ") bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) // If no model was specified, take the first available if modelFile == "" && !bearerExists { models, _ := loader.ListModels() if len(models) > 0 { modelFile = models[0] log.Debug().Msgf("No model specified, using: %s", modelFile) } else { log.Debug().Msgf("No model specified, returning error") return nil, nil, fmt.Errorf("no model specified") } } // If a model is found in bearer token takes precedence if bearerExists { log.Debug().Msgf("Using model from bearer token: %s", bearer) modelFile = bearer } // Load a config file if present after the model name modelConfig := filepath.Join(loader.ModelPath, modelFile+".yaml") if _, err := os.Stat(modelConfig); err == nil { if err := cm.LoadConfig(modelConfig); err != nil { return nil, nil, fmt.Errorf("failed loading model config (%s) %s", modelConfig, err.Error()) } } var config *Config cfg, exists := cm[modelFile] if !exists { config = &Config{ OpenAIRequest: defaultRequest(modelFile), ContextSize: ctx, Threads: threads, F16: f16, Debug: debug, } } else { config = &cfg } // Set the parameters for the language model prediction updateConfig(config, input) // Don't allow 0 as setting if config.Threads == 0 { if threads != 0 { config.Threads = threads } else { config.Threads = 4 } } return config, input, nil }