feat: expose mirostat to config (#193)

agent
Ettore Di Giacinto 2 years ago committed by GitHub
parent c839b334eb
commit 961cf29217
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      api/config.go
  2. 16
      api/openai.go
  3. 12
      api/prediction.go

@ -24,6 +24,9 @@ type Config struct {
Embeddings bool `yaml:"embeddings"` Embeddings bool `yaml:"embeddings"`
Backend string `yaml:"backend"` Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"` TemplateConfig TemplateConfig `yaml:"template"`
MirostatETA float64 `yaml:"mirostat_eta"`
MirostatTAU float64 `yaml:"mirostat_tau"`
Mirostat int `yaml:"mirostat"`
} }
type TemplateConfig struct { type TemplateConfig struct {

@ -100,6 +100,10 @@ type OpenAIRequest struct {
RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"` RepeatPenalty float64 `json:"repeat_penalty" yaml:"repeat_penalty"`
Keep int `json:"n_keep" yaml:"n_keep"` Keep int `json:"n_keep" yaml:"n_keep"`
MirostatETA float64 `json:"mirostat_eta" yaml:"mirostat_eta"`
MirostatTAU float64 `json:"mirostat_tau" yaml:"mirostat_tau"`
Mirostat int `json:"mirostat" yaml:"mirostat"`
Seed int `json:"seed" yaml:"seed"` Seed int `json:"seed" yaml:"seed"`
} }
@ -168,6 +172,18 @@ func updateConfig(config *Config, input *OpenAIRequest) {
if input.Seed != 0 { if input.Seed != 0 {
config.Seed = input.Seed config.Seed = input.Seed
} }
if input.Mirostat != 0 {
config.Mirostat = input.Mirostat
}
if input.MirostatETA != 0 {
config.MirostatETA = input.MirostatETA
}
if input.MirostatTAU != 0 {
config.MirostatTAU = input.MirostatTAU
}
} }
func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) { func readConfig(cm ConfigMerger, c *fiber.Ctx, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*Config, *OpenAIRequest, error) {

@ -206,6 +206,18 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
llama.SetThreads(c.Threads), llama.SetThreads(c.Threads),
} }
if c.Mirostat != 0 {
predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat))
}
if c.MirostatETA != 0 {
predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA))
}
if c.MirostatTAU != 0 {
predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU))
}
if c.Debug { if c.Debug {
predictOptions = append(predictOptions, llama.Debug) predictOptions = append(predictOptions, llama.Debug)
} }

Loading…
Cancel
Save