|
|
|
@ -37,7 +37,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri |
|
|
|
|
// TODO: this is ugly; we should find a better way to identify the model type. However, it is a good stab for a first implementation.
|
|
|
|
|
model, llamaerr = loader.LoadLLaMAModel(modelFile, llamaOpts...) |
|
|
|
|
if llamaerr != nil { |
|
|
|
|
gptModel, gptjerr = loader.LoadGPTJModel(modelFile) |
|
|
|
|
gptModel, gptjerr = loader.LoadGPTJModel(modelFile, gptj.SetThreads(c.Threads)) |
|
|
|
|
if gptjerr != nil { |
|
|
|
|
gpt2Model, gpt2err = loader.LoadGPT2Model(modelFile) |
|
|
|
|
if gpt2err != nil { |
|
|
|
@ -108,17 +108,12 @@ func ModelInference(s string, loader *model.ModelLoader, c Config) (func() (stri |
|
|
|
|
gptj.SetTopP(c.TopP), |
|
|
|
|
gptj.SetTopK(c.TopK), |
|
|
|
|
gptj.SetTokens(c.Maxtokens), |
|
|
|
|
gptj.SetThreads(c.Threads), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if c.Batch != 0 { |
|
|
|
|
predictOptions = append(predictOptions, gptj.SetBatch(c.Batch)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if c.Seed != 0 { |
|
|
|
|
predictOptions = append(predictOptions, gptj.SetSeed(c.Seed)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return gptModel.Predict( |
|
|
|
|
s, |
|
|
|
|
predictOptions..., |
|
|
|
|