@@ -11,6 +11,7 @@ import (
 	"text/template"

 	rwkv "github.com/donomii/go-rwkv.cpp"
+	whisper "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
 	bloomz "github.com/go-skynet/bloomz.cpp"
 	bert "github.com/go-skynet/go-bert.cpp"
 	gpt2 "github.com/go-skynet/go-gpt2.cpp"
@@ -32,9 +33,9 @@ type ModelLoader struct {
 	redpajama        map[string]*gpt2.RedPajama
 	rwkv             map[string]*rwkv.RwkvState
 	bloomz           map[string]*bloomz.Bloomz
-
 	bert             map[string]*bert.Bert
 	promptsTemplates map[string]*template.Template
+	whisperModels    map[string]whisper.Model
 }

 func NewModelLoader(modelPath string) *ModelLoader {
@@ -50,6 +51,7 @@ func NewModelLoader(modelPath string) *ModelLoader {
 		bloomz:           make(map[string]*bloomz.Bloomz),
 		bert:             make(map[string]*bert.Bert),
 		promptsTemplates: make(map[string]*template.Template),
+		whisperModels:    make(map[string]whisper.Model),
 	}
 }

@@ -422,6 +424,33 @@ func (ml *ModelLoader) LoadLLaMAModel(modelName string, opts ...llama.ModelOptio
 	return model, err
 }

+func (ml *ModelLoader) LoadWhisperModel(modelName string) (whisper.Model, error) {
+	ml.mu.Lock()
+	defer ml.mu.Unlock()
+
+	// Check if we already have a loaded model
+	if !ml.ExistsInModelPath(modelName) {
+		return nil, fmt.Errorf("model does not exist -- %s", modelName)
+	}
+
+	if m, ok := ml.whisperModels[modelName]; ok {
+		log.Debug().Msgf("Model already loaded in memory: %s", modelName)
+		return m, nil
+	}
+
+	// Load the model and keep it in memory for later use
+	modelFile := filepath.Join(ml.ModelPath, modelName)
+	log.Debug().Msgf("Loading model in memory from file: %s", modelFile)
+
+	model, err := whisper.New(modelFile)
+	if err != nil {
+		return nil, err
+	}
+
+	ml.whisperModels[modelName] = model
+	return model, err
+}
+
 const tokenizerSuffix = ".tokenizer.json"

 var loadedModels map[string]interface{} = map[string]interface{}{}
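Usage sketch (not part of the patch): the snippet below exercises the new LoadWhisperModel, including the cache hit on a second call. The import path and the ggml-base.en.bin model filename are illustrative assumptions, not taken from the diff; any whisper.cpp GGML file placed in the loader's model directory would work the same way.

package main

import (
	"fmt"
	"log"

	// Assumed import path for the package this patch modifies.
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func main() {
	// Point the loader at the directory holding GGML model files.
	ml := model.NewModelLoader("./models")

	// "ggml-base.en.bin" is a placeholder filename; LoadWhisperModel
	// joins it with the loader's ModelPath before loading.
	m, err := ml.LoadWhisperModel("ggml-base.en.bin")
	if err != nil {
		log.Fatal(err)
	}

	// A second call hits the whisperModels map and returns the same
	// cached whisper.Model rather than reloading it from disk.
	again, err := ml.LoadWhisperModel("ggml-base.en.bin")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("same cached instance:", m == again)
}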
@@ -452,6 +481,16 @@ func (ml *ModelLoader) BackendLoader(backendString string, modelFile string, lla
 	}
 }

+func (ml *ModelLoader) WhisperLoader(backendString string, modelFile string) (model whisper.Model, err error) {
+	//TODO expose more whisper options in next PR
+	switch strings.ToLower(backendString) {
+	case "whisper":
+		return ml.LoadWhisperModel(modelFile)
+	default:
+		return nil, fmt.Errorf("whisper backend unsupported: %s", backendString)
+	}
+}
+
 func (ml *ModelLoader) GreedyLoader(modelFile string, llamaOpts []llama.ModelOption, threads uint32) (model interface{}, err error) {
 	updateModels := func(model interface{}) {
 		muModels.Lock()
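A second sketch, under the same assumptions as above, showing the new WhisperLoader dispatch: the backend string is lower-cased before matching, so any casing of "whisper" resolves to LoadWhisperModel, and every other backend name falls through to the unsupported-backend error.

package main

import (
	"fmt"
	"log"

	// Assumed import path, as in the previous sketch.
	model "github.com/go-skynet/LocalAI/pkg/model"
)

func main() {
	ml := model.NewModelLoader("./models")

	// Backend matching is case-insensitive via strings.ToLower, so
	// "WHISPER" resolves to the whisper.cpp loader.
	if _, err := ml.WhisperLoader("WHISPER", "ggml-base.en.bin"); err != nil {
		log.Fatal(err)
	}

	// Any other backend string is rejected by the default branch:
	// "whisper backend unsupported: stt".
	if _, err := ml.WhisperLoader("stt", "ggml-base.en.bin"); err != nil {
		fmt.Println(err)
	}
}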