diff --git a/api/prediction.go b/api/prediction.go index 009641a..47229d6 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -56,7 +56,8 @@ func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]fl switch model := inferenceModel.(type) { case *llama.LLama: fn = func() ([]float32, error) { - return model.Embeddings(s) + predictOptions := buildLLamaPredictOptions(c) + return model.Embeddings(s, predictOptions...) } default: fn = func() ([]float32, error) { @@ -81,6 +82,61 @@ func ModelEmbedding(s string, loader *model.ModelLoader, c Config) (func() ([]fl }, nil } +func buildLLamaPredictOptions(c Config) []llama.PredictOption { + // Generate the prediction using the language model + predictOptions := []llama.PredictOption{ + llama.SetTemperature(c.Temperature), + llama.SetTopP(c.TopP), + llama.SetTopK(c.TopK), + llama.SetTokens(c.Maxtokens), + llama.SetThreads(c.Threads), + } + + if c.Mirostat != 0 { + predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) + } + + if c.MirostatETA != 0 { + predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) + } + + if c.MirostatTAU != 0 { + predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) + } + + if c.Debug { + predictOptions = append(predictOptions, llama.Debug) + } + + predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) + + if c.RepeatPenalty != 0 { + predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) + } + + if c.Keep != 0 { + predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) + } + + if c.Batch != 0 { + predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) + } + + if c.F16 { + predictOptions = append(predictOptions, llama.EnableF16KV) + } + + if c.IgnoreEOS { + predictOptions = append(predictOptions, llama.IgnoreEOS) + } + + if c.Seed != 0 { + predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) + } + + return predictOptions +} + func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback func(string) bool) (func() (string, error), error) { supportStreams := false modelFile := c.Model @@ -198,56 +254,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback model.SetTokenCallback(tokenCallback) } - // Generate the prediction using the language model - predictOptions := []llama.PredictOption{ - llama.SetTemperature(c.Temperature), - llama.SetTopP(c.TopP), - llama.SetTopK(c.TopK), - llama.SetTokens(c.Maxtokens), - llama.SetThreads(c.Threads), - } - - if c.Mirostat != 0 { - predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat)) - } - - if c.MirostatETA != 0 { - predictOptions = append(predictOptions, llama.SetMirostatETA(c.MirostatETA)) - } - - if c.MirostatTAU != 0 { - predictOptions = append(predictOptions, llama.SetMirostatTAU(c.MirostatTAU)) - } - - if c.Debug { - predictOptions = append(predictOptions, llama.Debug) - } - - predictOptions = append(predictOptions, llama.SetStopWords(c.StopWords...)) - - if c.RepeatPenalty != 0 { - predictOptions = append(predictOptions, llama.SetPenalty(c.RepeatPenalty)) - } - - if c.Keep != 0 { - predictOptions = append(predictOptions, llama.SetNKeep(c.Keep)) - } - - if c.Batch != 0 { - predictOptions = append(predictOptions, llama.SetBatch(c.Batch)) - } - - if c.F16 { - predictOptions = append(predictOptions, llama.EnableF16KV) - } - - if c.IgnoreEOS { - predictOptions = append(predictOptions, llama.IgnoreEOS) - } - - if c.Seed != 0 { - predictOptions = append(predictOptions, llama.SetSeed(c.Seed)) - } + predictOptions := buildLLamaPredictOptions(c) str, er := model.Predict( s,