diff --git a/api/config.go b/api/config.go index 1ad0352..8e550e1 100644 --- a/api/config.go +++ b/api/config.go @@ -24,6 +24,9 @@ type Config struct { Embeddings bool `yaml:"embeddings"` Backend string `yaml:"backend"` TemplateConfig TemplateConfig `yaml:"template"` + MirostatETA float64 `yaml:"mirostat_eta"` + MirostatTAU float64 `yaml:"mirostat_tau"` + Mirostat int `yaml:"mirostat"` } type TemplateConfig struct { diff --git a/api/openai.go b/api/openai.go index 2ac966f..fc982f2 100644 --- a/api/openai.go +++ b/api/openai.go @@ -100,6 +100,10 @@ type OpenAIRequest struct { RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` Keep int `json:"n_keep" yaml:"n_keep"` + MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"` + MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"` + Mirostat int `json:"mirostat" yaml:"mirostat"` + Seed int `json:"seed" yaml:"seed"` } @@ -168,6 +172,18 @@ func updateConfig(config *Config, input *OpenAIRequest) { if input.Seed != 0 { config.Seed = input.Seed } + + if input.Mirostat != 0 { + config.Mirostat = input.Mirostat + } + + if input.MirostatETA != 0 { + config.MirostatETA = input.MirostatETA + } + + if input.MirostatTAU != 0 { + config.MirostatTAU = input.MirostatTAU + } } func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { diff --git a/api/prediction.go b/api/prediction.go index b2dfbb1..45db078 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -206,6 +206,18 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback llama.SetThreads(c.Threads), } + if c.Mirostat != 0 { + predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) + } + + if c.MirostatETA != 0 { + predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) + } + + if c.MirostatTAU != 0 { + predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) + } + if c.Debug { predictOptions = append(predictOptions, llama.Debug) }