feat: allow to set a prompt cache path and enable saving state (#395)

Signed-off-by: mudler <mudler@mocaccino.org>
Ettore Di Giacinto (committed by GitHub)
parent 76c881043e
commit 217dbb448e
  1. Makefile (2 changed lines)
  2. api/config.go (4 changed lines)
  3. api/prediction.go (19 changed lines)

Makefile

@@ -3,7 +3,7 @@ GOTEST=$(GOCMD) test
 GOVET=$(GOCMD) vet
 BINARY_NAME=local-ai
-GOLLAMA_VERSION?=8bd97d532e90cf34e755b3ea2d8aa17000443cf2
+GOLLAMA_VERSION?=fbec625895ba0c458f783b62c8569135c5e80d79
 GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
 GPT4ALL_VERSION?=73db20ba85fbbdc66a56e2619394c0eea40dc72b
 GOGGMLTRANSFORMERS_VERSION?=c4c581f1853cf1b66276501c7c0dbea1e3e564b7

api/config.go

@@ -34,6 +34,10 @@ type Config struct {
 	Mirostat              int    `yaml:"mirostat"`
 	NGPULayers            int    `yaml:"gpu_layers"`
 	ImageGenerationAssets string `yaml:"asset_dir"`
+	PromptCachePath       string `yaml:"prompt_cache_path"`
+	PromptCacheAll        bool   `yaml:"prompt_cache_all"`
 	PromptStrings, InputStrings []string
 	InputToken                  [][]int
 }
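Since the two new fields are plain YAML-tagged struct members, they can be set per-model in the model's YAML config. Below is a minimal sketch (not part of the commit) of how those tags map onto struct fields; the trimmed-down struct, the example values, and the use of gopkg.in/yaml.v2 are assumptions for illustration only.

package main

import (
	"fmt"

	"gopkg.in/yaml.v2"
)

// modelConfig is a trimmed-down stand-in for the Config struct above,
// keeping only the two new prompt-cache fields.
type modelConfig struct {
	PromptCachePath string `yaml:"prompt_cache_path"`
	PromptCacheAll  bool   `yaml:"prompt_cache_all"`
}

func main() {
	// Hypothetical model config snippet exercising the new keys.
	data := []byte("prompt_cache_path: \"cache/my-model\"\nprompt_cache_all: true\n")

	var c modelConfig
	if err := yaml.Unmarshal(data, &c); err != nil {
		panic(err)
	}
	fmt.Printf("cache path: %q, cache all: %v\n", c.PromptCachePath, c.PromptCacheAll)
}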

api/prediction.go

@@ -2,6 +2,8 @@ package api
 import (
 	"fmt"
+	"os"
+	"path/filepath"
 	"regexp"
 	"strings"
 	"sync"

@@ -102,7 +104,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
 	switch model := inferenceModel.(type) {
 	case *llama.LLama:
 		fn = func() ([]float32, error) {
-			predictOptions := buildLLamaPredictOptions(c)
+			predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
 			if len(tokens) > 0 {
 				return model.TokenEmbeddings(tokens, predictOptions...)
 			}

@@ -151,7 +153,7 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c Config)
 	}, nil
 }
-func buildLLamaPredictOptions(c Config) []llama.PredictOption {
+func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption {
 	// Generate the prediction using the language model
 	predictOptions := []llama.PredictOption{
 		llama.SetTemperature(c.Temperature),

@@ -161,6 +163,17 @@ func buildLLamaPredictOptions(c Config) []llama.PredictOption {
 		llama.SetThreads(c.Threads),
 	}
+	if c.PromptCacheAll {
+		predictOptions = append(predictOptions, llama.EnablePromptCacheAll)
+	}
+
+	if c.PromptCachePath != "" {
+		// Create parent directory
+		p := filepath.Join(modelPath, c.PromptCachePath)
+		os.MkdirAll(filepath.Dir(p), 0755)
+		predictOptions = append(predictOptions, llama.SetPathPromptCache(p))
+	}
+
 	if c.Mirostat != 0 {
 		predictOptions = append(predictOptions, llama.SetMirostat(c.Mirostat))
 	}

@@ -469,7 +482,7 @@ func ModelInference(s string, loader *model.ModelLoader, c Config, tokenCallback
 		model.SetTokenCallback(tokenCallback)
 	}
-	predictOptions := buildLLamaPredictOptions(c)
+	predictOptions := buildLLamaPredictOptions(c, loader.ModelPath)
 	str, er := model.Predict(
 		s,
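The prompt-cache wiring above resolves prompt_cache_path relative to the model directory (loader.ModelPath) and creates its parent directory before passing the result to llama.SetPathPromptCache. Below is a standalone sketch of that path handling; the helper name and example paths are hypothetical, and unlike the hunk above it also checks the os.MkdirAll error.

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// resolvePromptCachePath mirrors the logic added in buildLLamaPredictOptions:
// the configured cache path is joined onto the model directory, and the
// parent directory is created so the backend can write the cache file there.
func resolvePromptCachePath(modelPath, promptCachePath string) (string, error) {
	p := filepath.Join(modelPath, promptCachePath)
	if err := os.MkdirAll(filepath.Dir(p), 0755); err != nil {
		return "", err
	}
	return p, nil
}

func main() {
	// Example values only.
	p, err := resolvePromptCachePath("/models", "cache/my-model.bin")
	if err != nil {
		panic(err)
	}
	fmt.Println(p) // /models/cache/my-model.bin
}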
