diff --git a/.dockerignore b/.dockerignore
index e73b1f9..cf96388 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -1,4 +1,3 @@
-.git
 .idea
 models
 examples/chatbot-ui/models
diff --git a/.env b/.env
index 040bba3..890ff94 100644
--- a/.env
+++ b/.env
@@ -26,8 +26,8 @@ MODELS_PATH=/models
 ## Specify a build type. Available: cublas, openblas, clblas.
 # BUILD_TYPE=openblas
 
-## Uncomment and set to false to disable rebuilding from source
-# REBUILD=false
+## Uncomment and set to true to enable rebuilding from source
+# REBUILD=true
 
 ## Enable go tags, available: stablediffusion, tts
 ## stablediffusion: image generation with stablediffusion
diff --git a/Dockerfile b/Dockerfile
index 6e21da3..c22f5dc 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -83,6 +83,8 @@ RUN make get-sources
 COPY go.mod .
 RUN make prepare
 COPY . .
+COPY .git .
+
 RUN ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data make build
 
 ###################################
@@ -92,7 +94,7 @@ FROM requirements
 
 ARG FFMPEG
 
-ENV REBUILD=true
+ENV REBUILD=false
 ENV HEALTHCHECK_ENDPOINT=http://localhost:8080/readyz
 
 # Add FFmpeg
diff --git a/Makefile b/Makefile
index d25ce75..ecd1774 100644
--- a/Makefile
+++ b/Makefile
@@ -3,24 +3,51 @@ GOTEST=$(GOCMD) test
 GOVET=$(GOCMD) vet
 BINARY_NAME=local-ai
 
-GOLLAMA_VERSION?=ecd358d2f144b4282a73df443d60474fca5db9ec
+# llama.cpp versions
+# Temporarily pinned to https://github.com/go-skynet/go-llama.cpp/pull/124
+GOLLAMA_VERSION?=cb8d7cd4cb95725a04504a9e3a26dd72a12b69ac
+
+# Temporarily set a specific version of llama.cpp
+# containing: https://github.com/ggerganov/llama.cpp/pull/1773 and
+# rebased on top of master.
+# This pin can be dropped when the PR above is merged, and go-llama has merged changes as well.
+# Set empty to use the version pinned by go-llama.
+LLAMA_CPP_REPO?=https://github.com/mudler/llama.cpp
+LLAMA_CPP_VERSION?=48ce8722a05a018681634af801fd0fd45b3a87cc
+
+# gpt4all version
 GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
 GPT4ALL_VERSION?=70cbff70cc2a9ad26d492d44ab582d32e6219956
+
+# go-ggml-transformers version
 GOGGMLTRANSFORMERS_VERSION?=8e31841dcddca16468c11b2e7809f279fa76a832
+
+# go-rwkv version
 RWKV_REPO?=https://github.com/donomii/go-rwkv.cpp
 RWKV_VERSION?=f5a8c45396741470583f59b916a2a7641e63bcd0
+
+# whisper.cpp version
 WHISPER_CPP_VERSION?=85ed71aaec8e0612a84c0b67804bde75aa75a273
+
+# bert.cpp version
 BERT_VERSION?=6069103f54b9969c02e789d0fb12a23bd614285f
+
+# go-piper version
 PIPER_VERSION?=56b8a81b4760a6fbee1a82e62f007ae7e8f010a7
+
+# go-bloomz version
 BLOOMZ_VERSION?=1834e77b83faafe912ad4092ccf7f77937349e2f
+
+# stablediffusion version
+STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632
+
 export BUILD_TYPE?=
 CGO_LDFLAGS?=
 CUDA_LIBPATH?=/usr/local/cuda/lib64/
-STABLEDIFFUSION_VERSION?=d89260f598afb809279bc72aa0107b4292587632
 GO_TAGS?=
 BUILD_ID?=git
-VERSION?=$(shell git describe --always --tags --dirty || echo "dev" )
+VERSION?=$(shell git describe --always --tags || echo "dev" )
 # go tool nm ./local-ai | grep Commit
 LD_FLAGS?=
 override LD_FLAGS += -X "github.com/go-skynet/LocalAI/internal.Version=$(VERSION)"
@@ -201,6 +228,9 @@ whisper.cpp/libwhisper.a: whisper.cpp
 go-llama:
 	git clone --recurse-submodules https://github.com/go-skynet/go-llama.cpp go-llama
 	cd go-llama && git checkout -b build $(GOLLAMA_VERSION) && git submodule update --init --recursive --depth 1
+ifneq ($(LLAMA_CPP_REPO),)
+	cd go-llama && rm -rf llama.cpp && git clone $(LLAMA_CPP_REPO) llama.cpp && cd llama.cpp && git checkout -b build $(LLAMA_CPP_VERSION) && git submodule update --init --recursive --depth 1
+endif
 
 go-llama/libbinding.a: go-llama
 	$(MAKE) -C go-llama BUILD_TYPE=$(BUILD_TYPE) libbinding.a
@@ -227,6 +257,7 @@ prepare-sources: get-sources replace
 
 ## GENERIC
 rebuild: ## Rebuilds the project
+	$(GOCMD) clean -cache
 	$(MAKE) -C go-llama clean
 	$(MAKE) -C gpt4all/gpt4all-bindings/golang/ clean
 	$(MAKE) -C go-ggml-transformers clean
@@ -242,6 +273,7 @@ prepare: prepare-sources backend-assets/gpt4all $(OPTIONAL_TARGETS) go-llama/lib
 	touch $@
 
 clean: ## Remove build related file
+	$(GOCMD) clean -cache
 	rm -fr ./go-llama
 	rm -rf ./gpt4all
 	rm -rf ./go-gpt2
diff --git a/api/api.go b/api/api.go
index e4aac2f..543e756 100644
--- a/api/api.go
+++ b/api/api.go
@@ -51,6 +51,9 @@ func App(opts ...AppOption) (*fiber.App, error) {
 		}))
 	}
 
+	log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.threads, options.loader.ModelPath)
+	log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())
+
 	cm := NewConfigMerger()
 	if err := cm.LoadConfigs(options.loader.ModelPath); err != nil {
 		log.Error().Msgf("error loading config files: %s", err.Error())
diff --git a/api/config.go b/api/config.go
index ba84e0d..57fe0d1 100644
--- a/api/config.go
+++ b/api/config.go
@@ -46,12 +46,24 @@ type Config struct {
 	PromptCacheAll bool   `yaml:"prompt_cache_all"`
 	PromptCacheRO  bool   `yaml:"prompt_cache_ro"`
 
-	PromptStrings, InputStrings []string
-	InputToken                  [][]int
+	Grammar string `yaml:"grammar"`
+
+	FunctionsConfig Functions `yaml:"function"`
+
+	PromptStrings, InputStrings                []string
+	InputToken                                 [][]int
+	functionCallString, functionCallNameString string
+}
+
+type Functions struct {
+	DisableNoAction         bool   `yaml:"disable_no_action"`
+	NoActionFunctionName    string `yaml:"no_action_function_name"`
+	NoActionDescriptionName string `yaml:"no_action_description_name"`
 }
 
 type TemplateConfig struct {
 	Completion string `yaml:"completion"`
+	Functions  string `yaml:"function"`
 	Chat       string `yaml:"chat"`
 	Edit       string `yaml:"edit"`
 }
@@ -181,6 +193,10 @@ func updateConfig(config *Config, input *OpenAIRequest) {
 		config.TopP = input.TopP
 	}
 
+	if input.Grammar != "" {
+		config.Grammar = input.Grammar
+	}
+
 	if input.Temperature != 0 {
 		config.Temperature = input.Temperature
 	}
@@ -261,6 +277,23 @@ func updateConfig(config *Config, input *OpenAIRequest) {
 			}
 		}
 	}
+	// Can be either a string or an object
+	switch fnc := input.FunctionCall.(type) {
+	case string:
+		if fnc != "" {
+			config.functionCallString = fnc
+		}
+	case map[string]interface{}:
+		var name string
+		n, exists := fnc["name"]
+		if exists {
+			nn, e := n.(string)
+			if e {
+				name = nn
+			}
+		}
+		config.functionCallNameString = name
+	}
 
 	switch p := input.Prompt.(type) {
 	case string:
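
Note: the `function_call` field mirrors OpenAI's API, where it may arrive either as a plain string ("none"/"auto") or as an object naming one specific function, which is why the switch above matches both a string and a map. A minimal, self-contained sketch of how such a field decodes from JSON (the request struct and payloads here are illustrative, not part of this patch):

package main

import (
	"encoding/json"
	"fmt"
)

type request struct {
	// interface{} lets json.Unmarshal produce either a
	// string or a map[string]interface{} for this field.
	FunctionCall interface{} `json:"function_call"`
}

func main() {
	for _, payload := range []string{
		`{"function_call": "auto"}`,
		`{"function_call": {"name": "search"}}`,
	} {
		var r request
		_ = json.Unmarshal([]byte(payload), &r)
		switch fnc := r.FunctionCall.(type) {
		case string:
			fmt.Println("string form:", fnc)
		case map[string]interface{}:
			fmt.Println("object form, name:", fnc["name"])
		}
	}
}
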
diff --git a/api/openai.go b/api/openai.go
index 403a03b..77d2c8e 100644
--- a/api/openai.go
+++ b/api/openai.go
@@ -17,6 +17,7 @@ import (
 	"strings"
 
 	"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
+	"github.com/go-skynet/LocalAI/pkg/grammar"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	whisperutil "github.com/go-skynet/LocalAI/pkg/whisper"
 	llama "github.com/go-skynet/go-llama.cpp"
@@ -73,8 +74,12 @@ type Choice struct {
 }
 
 type Message struct {
-	Role    string `json:"role,omitempty" yaml:"role"`
-	Content string `json:"content,omitempty" yaml:"content"`
+	// The message role
+	Role string `json:"role,omitempty" yaml:"role"`
+	// The message content
+	Content *string `json:"content" yaml:"content"`
+	// A result of a function call
+	FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"`
 }
 
 type OpenAIModel struct {
@@ -104,6 +109,10 @@ type OpenAIRequest struct {
 	// Messages is read only by chat/completion API calls
 	Messages []Message `json:"messages" yaml:"messages"`
 
+	// A list of available functions to call
+	Functions    []grammar.Function `json:"functions" yaml:"functions"`
+	FunctionCall interface{}        `json:"function_call" yaml:"function_call"` // might be a string or an object
+
 	Stream bool `json:"stream"`
 	Echo   bool `json:"echo"`
 	// Common options between all the API calls
@@ -134,6 +143,11 @@ type OpenAIRequest struct {
 	Mode int `json:"mode"`
 	Step int `json:"step"`
 
+	// A grammar to constrain the LLM output
+	Grammar string `json:"grammar" yaml:"grammar"`
+	// A grammar object
+	JSONFunctionGrammarObject *grammar.JSONFunctionStructure `json:"grammar_json_functions" yaml:"grammar_json_functions"`
+
 	TypicalP float64 `json:"typical_p" yaml:"typical_p"`
 }
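
Note: with these request fields in place, a chat request can declare callable functions exactly as with the upstream OpenAI API. A hedged client-side sketch against LocalAI's OpenAI-compatible endpoint (the model name and function schema are placeholders for whatever is installed locally):

package main

import (
	"bytes"
	"fmt"
	"net/http"
)

func main() {
	// Declare one function; the server turns the declaration into a grammar
	// that constrains the model to emit a matching JSON function call.
	body := []byte(`{
	  "model": "gpt-3.5-turbo",
	  "messages": [{"role": "user", "content": "Schedule lunch tomorrow at noon"}],
	  "functions": [{
	    "name": "create_event",
	    "description": "Create a calendar event",
	    "parameters": {
	      "type": "object",
	      "properties": {
	        "title": {"type": "string"},
	        "date":  {"type": "string"},
	        "time":  {"type": "string"}
	      }
	    }
	  }],
	  "function_call": "auto"
	}`)

	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
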
@@ -202,7 +216,7 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 
 	if input.Stream {
 		if len(config.PromptStrings) > 1 {
-			return errors.New("cannot handle more than 1 `PromptStrings` when `Stream`ing")
+			return errors.New("cannot handle more than 1 `PromptStrings` when Streaming")
 		}
 
 		predInput := config.PromptStrings[0]
@@ -210,7 +224,9 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
 		templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
 			Input string
-		}{Input: predInput})
+		}{
+			Input: predInput,
+		})
 		if err == nil {
 			predInput = templatedInput
 			log.Debug().Msgf("Template found, input modified to: %s", predInput)
@@ -256,7 +272,9 @@ func completionEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
 		templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
 			Input string
-		}{Input: i})
+		}{
+			Input: i,
+		})
 		if err == nil {
 			i = templatedInput
 			log.Debug().Msgf("Template found, input modified to: %s", i)
@@ -357,7 +375,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 			ComputeChoices(s, req, config, o, loader, func(s string, c *[]Choice) {}, func(s string) bool {
 				resp := OpenAIResponse{
 					Model: req.Model, // we have to return what the user sent here, due to OpenAI spec.
-					Choices: []Choice{{Delta: &Message{Content: s}, Index: 0}},
+					Choices: []Choice{{Delta: &Message{Content: &s}, Index: 0}},
 					Object: "chat.completion.chunk",
 				}
 				log.Debug().Msgf("Sending goroutine: %s", s)
@@ -368,6 +386,8 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		close(responses)
 	}
 	return func(c *fiber.Ctx) error {
+		processFunctions := false
+		funcs := grammar.Functions{}
 		model, input, err := readInput(c, o.loader, true)
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
@@ -377,27 +397,116 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		if err != nil {
 			return fmt.Errorf("failed reading parameters from request:%w", err)
 		}
+		log.Debug().Msgf("Configuration read: %+v", config)
 
-		log.Debug().Msgf("Parameter Config: %+v", config)
+		// Allow the user to set custom actions via config file
+		// to be "embedded" in each model
+		noActionName := "answer"
+		noActionDescription := "use this action to answer without performing any action"
+
+		if config.FunctionsConfig.NoActionFunctionName != "" {
+			noActionName = config.FunctionsConfig.NoActionFunctionName
+		}
+		if config.FunctionsConfig.NoActionDescriptionName != "" {
+			noActionDescription = config.FunctionsConfig.NoActionDescriptionName
+		}
+
+		// process functions if we have any defined or if we have a function call string
+		if len(input.Functions) > 0 &&
+			((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) {
+			log.Debug().Msgf("Response needs to process functions")
+
+			processFunctions = true
+
+			noActionGrammar := grammar.Function{
+				Name:        noActionName,
+				Description: noActionDescription,
+				Parameters: map[string]interface{}{
+					"properties": map[string]interface{}{
+						"message": map[string]interface{}{
+							"type":        "string",
+							"description": "The message to reply the user with",
+						}},
+				},
+			}
+
+			// Append the no action function
+			funcs = append(funcs, input.Functions...)
+			if !config.FunctionsConfig.DisableNoAction {
+				funcs = append(funcs, noActionGrammar)
+			}
+
+			// Force picking one of the functions by the request
+			if config.functionCallNameString != "" {
+				funcs = funcs.Select(config.functionCallNameString)
+			}
+
+			// Update input grammar
+			jsStruct := funcs.ToJSONStructure()
+			config.Grammar = jsStruct.Grammar("")
+		} else if input.JSONFunctionGrammarObject != nil {
+			config.Grammar = input.JSONFunctionGrammarObject.Grammar("")
+		}
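
Note: the branch above is where function declarations become a grammar: each declared function (plus the implicit "no action" one) becomes a "oneOf" branch of a JSON schema, which the converter lowers to a llama.cpp GBNF grammar. A small sketch exercising the same call chain with the package added by this PR (the function definition is illustrative):

package main

import (
	"fmt"

	"github.com/go-skynet/LocalAI/pkg/grammar"
)

func main() {
	funcs := grammar.Functions{
		{
			Name:        "create_event",
			Description: "Create a calendar event",
			Parameters: map[string]interface{}{
				"properties": map[string]interface{}{
					"title": map[string]interface{}{"type": "string"},
					"date":  map[string]interface{}{"type": "string"},
				},
			},
		},
	}

	// Same two steps the chat endpoint performs:
	// schema first, then the grammar text handed to the backend.
	js := funcs.ToJSONStructure()
	fmt.Println(js.Grammar(""))
}
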
+
+		// functions are not supported in stream mode (yet?)
+		toStream := input.Stream && !processFunctions
+
+		log.Debug().Msgf("Parameters: %+v", config)
 
 		var predInput string
 
 		mess := []string{}
 		for _, i := range input.Messages {
 			var content string
-			r := config.Roles[i.Role]
+			role := i.Role
+			// if function call, we might want to customize the role so we can display better that the "assistant called a json action"
+			// if an "assistant_function_call" role is defined, we use it, otherwise we use the role that is passed by in the request
+			if i.FunctionCall != nil && i.Role == "assistant" {
+				roleFn := "assistant_function_call"
+				r := config.Roles[roleFn]
+				if r != "" {
+					role = roleFn
+				}
+			}
+			r := config.Roles[role]
+			contentExists := i.Content != nil && *i.Content != ""
 			if r != "" {
-				content = fmt.Sprint(r, " ", i.Content)
+				if contentExists {
+					content = fmt.Sprint(r, " ", *i.Content)
+				}
+				if i.FunctionCall != nil {
+					j, err := json.Marshal(i.FunctionCall)
+					if err == nil {
+						if contentExists {
+							content += "\n" + fmt.Sprint(r, " ", string(j))
+						} else {
+							content = fmt.Sprint(r, " ", string(j))
+						}
+					}
+				}
 			} else {
-				content = i.Content
+				if contentExists {
+					content = fmt.Sprint(*i.Content)
+				}
+				if i.FunctionCall != nil {
+					j, err := json.Marshal(i.FunctionCall)
+					if err == nil {
+						if contentExists {
+							content += "\n" + string(j)
+						} else {
+							content = string(j)
+						}
+					}
+				}
 			}
 			mess = append(mess, content)
 		}
 
 		predInput = strings.Join(mess, "\n")
+		log.Debug().Msgf("Prompt (before templating): %s", predInput)
 
-		if input.Stream {
+		if toStream {
 			log.Debug().Msgf("Stream request received")
 			c.Context().SetContentType("text/event-stream")
 			//c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8)
@@ -409,20 +518,35 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 
 		templateFile := config.Model
 
-		if config.TemplateConfig.Chat != "" {
+		if config.TemplateConfig.Chat != "" && !processFunctions {
 			templateFile = config.TemplateConfig.Chat
 		}
 
+		if config.TemplateConfig.Functions != "" && processFunctions {
+			templateFile = config.TemplateConfig.Functions
+		}
+
 		// A model can have a "file.bin.tmpl" file associated with a prompt template prefix
 		templatedInput, err := o.loader.TemplatePrefix(templateFile, struct {
-			Input string
-		}{Input: predInput})
+			Input     string
+			Functions []grammar.Function
+		}{
+			Input:     predInput,
+			Functions: funcs,
+		})
 		if err == nil {
 			predInput = templatedInput
 			log.Debug().Msgf("Template found, input modified to: %s", predInput)
+		} else {
+			log.Debug().Msgf("Template failed loading: %s", err.Error())
 		}
 
-		if input.Stream {
+		log.Debug().Msgf("Prompt (after templating): %s", predInput)
+		if processFunctions {
+			log.Debug().Msgf("Grammar: %+v", config.Grammar)
+		}
+
+		if toStream {
			responses := make(chan OpenAIResponse)
 
 			go process(predInput, input, config, o.loader, responses)
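
Note: the template now receives both the flattened chat history (Input) and the declared functions, so a model's functions template can render the tool list into the prompt. Model templates in LocalAI are Go text/templates; a hedged sketch of what such a template could look like (the wording is up to the model author, not fixed by this PR):

package main

import (
	"os"
	"text/template"
)

// Mirrors the struct the endpoint passes to TemplatePrefix.
type promptData struct {
	Input     string
	Functions []struct{ Name, Description string }
}

func main() {
	// An illustrative functions template; real ones live in "<model>.tmpl" files.
	const tmpl = `You can call the following functions:
{{range .Functions}}- {{.Name}}: {{.Description}}
{{end}}{{.Input}}`

	t := template.Must(template.New("functions").Parse(tmpl))
	_ = t.Execute(os.Stdout, promptData{
		Input: "USER: Schedule lunch tomorrow",
		Functions: []struct{ Name, Description string }{
			{"create_event", "Create a calendar event"},
			{"answer", "Reply without performing any action"},
		},
	})
}
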
@@ -459,7 +583,72 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error {
 		}
 
 		result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) {
-			*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}})
+			if processFunctions {
+				// As we have to change the result before processing, we can't stream the answer (yet?)
+				ss := map[string]interface{}{}
+				json.Unmarshal([]byte(s), &ss)
+				log.Debug().Msgf("Function return: %s %+v", s, ss)
+
+				// The grammar defines the function name as "function", while OpenAI returns "name"
+				func_name := ss["function"]
+				// Similarly, while here arguments is a map[string]interface{}, OpenAI actually wants a stringified object
+				args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix)
+				d, _ := json.Marshal(args)
+
+				ss["arguments"] = string(d)
+				ss["name"] = func_name
+
+				// if the LLM chose to do nothing, reply with a message
+				if func_name == noActionName {
+					log.Debug().Msgf("nothing to do, computing a reply")
+
+					// If there is a message that the LLM already sends as part of the JSON reply, use it
+					arguments := map[string]interface{}{}
+					json.Unmarshal([]byte(d), &arguments)
+					m, exists := arguments["message"]
+					if exists {
+						switch message := m.(type) {
+						case string:
+							if message != "" {
+								log.Debug().Msgf("Reply received from LLM: %s", message)
+								message = Finetune(*config, predInput, message)
+								log.Debug().Msgf("Reply received from LLM(finetuned): %s", message)
+
+								*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &message}})
+								return
+							}
+						}
+					}
+
+					log.Debug().Msgf("No action received from LLM, without a message, computing a reply")
+					// Otherwise ask the LLM to understand the JSON output and the context, and return a message
+					// Note: this costs (in terms of CPU) another computation
+					config.Grammar = ""
+					predFunc, err := ModelInference(predInput, o.loader, *config, o, nil)
+					if err != nil {
+						log.Error().Msgf("inference error: %s", err.Error())
+						return
+					}
+
+					prediction, err := predFunc()
+					if err != nil {
+						log.Error().Msgf("inference error: %s", err.Error())
+						return
+					}
+
+					prediction = Finetune(*config, predInput, prediction)
+					*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &prediction}})
+				} else {
+					// otherwise reply with the function call
+					*c = append(*c, Choice{
+						FinishReason: "function_call",
+						Message:      &Message{Role: "assistant", FunctionCall: ss},
+					})
+				}
+
+				return
+			}
+			*c = append(*c, Choice{Message: &Message{Role: "assistant", Content: &s}})
 		}, nil)
 		if err != nil {
 			return err
diff --git a/api/options.go b/api/options.go
index b4669bc..923288a 100644
--- a/api/options.go
+++ b/api/options.go
@@ -3,9 +3,11 @@ package api
 import (
 	"context"
 	"embed"
+	"encoding/json"
 
 	"github.com/go-skynet/LocalAI/pkg/gallery"
 	model "github.com/go-skynet/LocalAI/pkg/model"
+	"github.com/rs/zerolog/log"
 )
 
 type Option struct {
@@ -69,6 +71,20 @@ func WithBackendAssets(f embed.FS) AppOption {
 	}
 }
 
+func WithStringGalleries(galls string) AppOption {
+	return func(o *Option) {
+		if galls == "" {
+			log.Debug().Msgf("no galleries to load")
+			return
+		}
+		var galleries []gallery.Gallery
+		if err := json.Unmarshal([]byte(galls), &galleries); err != nil {
+			log.Error().Msgf("failed loading galleries: %s", err.Error())
+		}
+		o.galleries = append(o.galleries, galleries...)
+	}
+}
+
 func WithGalleries(galleries []gallery.Gallery) AppOption {
 	return func(o *Option) {
 		o.galleries = append(o.galleries, galleries...)
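
Note: the handler above bridges two shapes of the same data. The grammar makes the model emit {"function": ..., "arguments": {...}}, while OpenAI clients expect {"name": ..., "arguments": "<stringified JSON>"}. A distilled sketch of that transformation in isolation (the payload is illustrative):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// What the grammar-constrained model emits...
	raw := `{"function": "create_event", "arguments": {"title": "lunch", "date": "tomorrow"}}`

	ss := map[string]interface{}{}
	_ = json.Unmarshal([]byte(raw), &ss)

	// ...versus what an OpenAI client expects: "name" instead of
	// "function", and "arguments" as a stringified JSON object.
	d, _ := json.Marshal(ss["arguments"])
	out := map[string]interface{}{
		"name":      ss["function"],
		"arguments": string(d),
	}

	j, _ := json.Marshal(out)
	fmt.Println(string(j))
}
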
diff --git a/api/prediction.go b/api/prediction.go
index bc23d86..7daa730 100644
--- a/api/prediction.go
+++ b/api/prediction.go
@@ -189,6 +189,8 @@ func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption
 		predictOptions = append(predictOptions, llama.EnablePromptCacheRO)
 	}
 
+	predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar))
+
 	if c.PromptCachePath != "" {
 		// Create parent directory
 		p := filepath.Join(modelPath, c.PromptCachePath)
diff --git a/entrypoint.sh b/entrypoint.sh
index 2bd8d02..b787649 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -6,6 +6,16 @@ cd /build
 
 if [ "$REBUILD" != "false" ]; then
 	rm -rf ./local-ai
 	ESPEAK_DATA=/build/lib/Linux-$(uname -m)/piper_phonemize/lib/espeak-ng-data make build -j${BUILD_PARALLELISM:-1}
+else
+	echo "@@@@@"
+	echo "Skipping rebuild"
+	echo "@@@@@"
+	echo "If you are experiencing issues with the pre-compiled builds, try setting REBUILD=true"
+	echo "If you are still experiencing issues with the build, try setting CMAKE_ARGS and disable the instruction sets as needed:"
+	echo 'CMAKE_ARGS="-DLLAMA_F16C=OFF -DLLAMA_AVX512=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF"'
+	echo "see the documentation at: https://localai.io/basics/build/index.html"
+	echo "Note: See also https://github.com/go-skynet/LocalAI/issues/288"
+	echo "@@@@@"
 fi
 
 ./local-ai "$@"
diff --git a/internal/version.go b/internal/version.go
index 12246c2..86588b4 100644
--- a/internal/version.go
+++ b/internal/version.go
@@ -6,5 +6,5 @@ var Version = ""
 var Commit = ""
 
 func PrintableVersion() string {
-	return fmt.Sprintf("LocalAI %s (%s)", Version, Commit)
+	return fmt.Sprintf("%s (%s)", Version, Commit)
 }
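
Note: llama.WithGrammar threads the GBNF grammar string down to llama.cpp's grammar-based sampling; an empty string leaves generation unconstrained, which is why it can be appended unconditionally above. A hedged sketch of using the binding directly, outside the API server, assuming the go-llama.cpp API at the pinned revision (the model path is hypothetical):

package main

import (
	"fmt"

	llama "github.com/go-skynet/go-llama.cpp"
)

func main() {
	// Hypothetical model file; any ggml llama model would do here.
	l, err := llama.New("/models/ggml-model.bin")
	if err != nil {
		panic(err)
	}

	// Constrain the output to "yes" or "no" with a tiny GBNF grammar.
	grammar := `root ::= "yes" | "no"`
	out, err := l.Predict("Is the sky blue? Answer:", llama.WithGrammar(grammar))
	if err != nil {
		panic(err)
	}
	fmt.Println(out)
}
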
diff --git a/main.go b/main.go
index 12b129c..fc1dea0 100644
--- a/main.go
+++ b/main.go
@@ -1,14 +1,11 @@
 package main
 
 import (
-	"encoding/json"
-	"fmt"
 	"os"
 	"path/filepath"
 
 	api "github.com/go-skynet/LocalAI/api"
 	"github.com/go-skynet/LocalAI/internal"
-	"github.com/go-skynet/LocalAI/pkg/gallery"
 	model "github.com/go-skynet/LocalAI/pkg/model"
 	"github.com/rs/zerolog"
 	"github.com/rs/zerolog/log"
@@ -126,19 +123,13 @@ Some of the models compatible are:
 - Alpaca
 - StableLM (ggml quantized)
 
-It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
+For a list of compatible models, check out: https://localai.io/model-compatibility/index.html
 `,
 		UsageText: `local-ai [options]`,
-		Copyright: "go-skynet authors",
+		Copyright: "Ettore Di Giacinto",
 		Action: func(ctx *cli.Context) error {
-			fmt.Printf("Starting LocalAI using %d threads, with models path: %s\n", ctx.Int("threads"), ctx.String("models-path"))
-			galls := ctx.String("galleries")
-			var galleries []gallery.Gallery
-			err := json.Unmarshal([]byte(galls), &galleries)
-			fmt.Println(err)
 			app, err := api.App(
 				api.WithConfigFile(ctx.String("config-file")),
-				api.WithGalleries(galleries),
 				api.WithJSONStringPreload(ctx.String("preload-models")),
 				api.WithYAMLConfigPreload(ctx.String("preload-models-config")),
 				api.WithModelLoader(model.NewModelLoader(ctx.String("models-path"))),
@@ -147,6 +138,7 @@ It uses llama.cpp, ggml and gpt4all as backend with golang c bindings.
 				api.WithImageDir(ctx.String("image-path")),
 				api.WithAudioDir(ctx.String("audio-path")),
 				api.WithF16(ctx.Bool("f16")),
+				api.WithStringGalleries(ctx.String("galleries")),
 				api.WithDisableMessage(false),
 				api.WithCors(ctx.Bool("cors")),
 				api.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
diff --git a/pkg/gallery/gallery.go b/pkg/gallery/gallery.go
index aed5251..8e08592 100644
--- a/pkg/gallery/gallery.go
+++ b/pkg/gallery/gallery.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"os"
 	"path/filepath"
+	"strings"
 
 	"github.com/go-skynet/LocalAI/pkg/utils"
 	"github.com/imdario/mergo"
@@ -17,6 +18,10 @@ type Gallery struct {
 
 // Installs a model from the gallery (galleryname@modelname)
 func InstallModelFromGallery(galleries []Gallery, name string, basePath string, req GalleryModel, downloadStatus func(string, string, string, float64)) error {
+
+	// os.PathSeparator is not allowed in model names. Replace it with "__" to avoid conflicts with file paths.
+	name = strings.ReplaceAll(name, string(os.PathSeparator), "__")
+
 	models, err := AvailableGalleryModels(galleries, basePath)
 	if err != nil {
 		return err
diff --git a/pkg/grammar/functions.go b/pkg/grammar/functions.go
new file mode 100644
index 0000000..c468a89
--- /dev/null
+++ b/pkg/grammar/functions.go
@@ -0,0 +1,50 @@
+package grammar
+
+import (
+	"encoding/json"
+)
+
+type Function struct {
+	Name        string                 `json:"name"`
+	Description string                 `json:"description"`
+	Parameters  map[string]interface{} `json:"parameters"`
+}
+type Functions []Function
+
+func (f Functions) ToJSONStructure() JSONFunctionStructure {
+	js := JSONFunctionStructure{}
+	for _, function := range f {
+		// t := function.Parameters["type"]
+		// tt := t.(string)
+
+		properties := function.Parameters["properties"]
+		dat, _ := json.Marshal(properties)
+		prop := map[string]interface{}{}
+		json.Unmarshal(dat, &prop)
+		js.OneOf = append(js.OneOf, Item{
+			Type: "object",
+			Properties: Properties{
+				Function: FunctionName{Const: function.Name},
+				Arguments: Argument{
+					Type:       "object",
+					Properties: prop,
+				},
+			},
+		})
+	}
+	return js
+}
+
+// Select returns a list of functions containing the function with the given name
+func (f Functions) Select(name string) Functions {
+	var funcs Functions
+
+	for _, f := range f {
+		if f.Name == name {
+			funcs = []Function{f}
+			break
+		}
+	}
+
+	return funcs
+}
diff --git a/pkg/grammar/functions_test.go b/pkg/grammar/functions_test.go
new file mode 100644
index 0000000..6e8a56e
--- /dev/null
+++ b/pkg/grammar/functions_test.go
@@ -0,0 +1,63 @@
+package grammar_test
+
+import (
+	. "github.com/go-skynet/LocalAI/pkg/grammar"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+var _ = Describe("LocalAI grammar functions", func() {
+	Describe("ToJSONStructure()", func() {
+		It("converts a list of functions to a JSON structure that can be parsed to a grammar", func() {
+			var functions Functions = []Function{
+				{
+					Name: "create_event",
+					Parameters: map[string]interface{}{
+						"properties": map[string]interface{}{
+							"event_name": map[string]interface{}{
+								"type": "string",
+							},
+							"event_date": map[string]interface{}{
+								"type": "string",
+							},
+						},
+					},
+				},
+				{
+					Name: "search",
+					Parameters: map[string]interface{}{
+						"properties": map[string]interface{}{
+							"query": map[string]interface{}{
+								"type": "string",
+							},
+						},
+					},
+				},
+			}
+
+			js := functions.ToJSONStructure()
+			Expect(len(js.OneOf)).To(Equal(2))
+			Expect(js.OneOf[0].Properties.Function.Const).To(Equal("create_event"))
+			Expect(js.OneOf[0].Properties.Arguments.Properties["event_name"].(map[string]interface{})["type"]).To(Equal("string"))
+			Expect(js.OneOf[0].Properties.Arguments.Properties["event_date"].(map[string]interface{})["type"]).To(Equal("string"))
+			Expect(js.OneOf[1].Properties.Function.Const).To(Equal("search"))
+			Expect(js.OneOf[1].Properties.Arguments.Properties["query"].(map[string]interface{})["type"]).To(Equal("string"))
+		})
+	})
+	Context("Select()", func() {
+		It("selects one of the functions and returns a list containing only the selected one", func() {
+			var functions Functions = []Function{
+				{
+					Name: "create_event",
+				},
+				{
+					Name: "search",
+				},
+			}
+
+			functions = functions.Select("create_event")
+			Expect(len(functions)).To(Equal(1))
+			Expect(functions[0].Name).To(Equal("create_event"))
+		})
+	})
+})
diff --git a/pkg/grammar/grammar_suite_test.go b/pkg/grammar/grammar_suite_test.go
new file mode 100644
index 0000000..652643b
--- /dev/null
+++ b/pkg/grammar/grammar_suite_test.go
@@ -0,0 +1,13 @@
+package grammar
+
+import (
+	"testing"
+
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+func TestGrammar(t *testing.T) {
+	RegisterFailHandler(Fail)
+	RunSpecs(t, "Grammar test suite")
+}
diff --git a/pkg/grammar/json_schema.go b/pkg/grammar/json_schema.go
new file mode 100644
index 0000000..5db2bca
--- /dev/null
+++ b/pkg/grammar/json_schema.go
@@ -0,0 +1,222 @@
+package grammar
+
+// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
+
+import (
+	"encoding/json"
+	"fmt"
+	"regexp"
+	"sort"
+	"strings"
+)
+
+var (
+	SPACE_RULE = `" "?`
+
+	PRIMITIVE_RULES = map[string]string{
+		"boolean": `("true" | "false") space`,
+		"number":  `[0-9]+ space`, // TODO: complete
+		"string":  `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO: complete
+		"null":    `"null" space`,
+	}
+
+	INVALID_RULE_CHARS_RE     = regexp.MustCompile(`[^a-zA-Z0-9-]+`)
+	GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`)
+	GRAMMAR_LITERAL_ESCAPES   = map[string]string{
+		"\r": `\r`,
+		"\n": `\n`,
+		`"`:  `\"`,
+	}
+)
+
+type JSONSchemaConverter struct {
+	propOrder map[string]int
+	rules     map[string]string
+}
+
+func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter {
+	propOrderSlice := strings.Split(propOrder, ",")
+	propOrderMap := make(map[string]int)
+	for idx, name := range propOrderSlice {
+		propOrderMap[name] = idx
+	}
+
+	rules := make(map[string]string)
+	rules["space"] = SPACE_RULE
+
+	return &JSONSchemaConverter{
+		propOrder: propOrderMap,
+		rules:     rules,
+	}
+}
+
+func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string {
+	escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string {
+		return GRAMMAR_LITERAL_ESCAPES[match]
+	})
+	return fmt.Sprintf(`"%s"`, escaped)
+}
+
+func (sc *JSONSchemaConverter) addRule(name, rule string) string {
+	escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-")
+	key := escName
+	if existingRule, ok := sc.rules[escName]; ok && existingRule != rule {
+		i := 0
+		for {
+			key = fmt.Sprintf("%s%d", escName, i)
+			if _, ok := sc.rules[key]; !ok {
+				break
+			}
+			i++
+		}
+	}
+	sc.rules[key] = rule
+	return key
+}
+
+func (sc *JSONSchemaConverter) formatGrammar() string {
+	var lines []string
+	for name, rule := range sc.rules {
+		lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule))
+	}
+	return strings.Join(lines, "\n")
+}
+
+func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string {
+	st, existType := schema["type"]
+	var schemaType string
+	if existType {
+		schemaType = st.(string)
+	}
+	ruleName := name
+	if name == "" {
+		ruleName = "root"
+	}
+	_, oneOfExists := schema["oneOf"]
+	_, anyOfExists := schema["anyOf"]
+	if oneOfExists || anyOfExists {
+		var alternatives []string
+		oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{})
+		anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{})
+
+		if oneOfExists {
+			for i, altSchema := range oneOfSchemas {
+				alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
+				alternatives = append(alternatives, alternative)
+			}
+		} else if anyOfExists {
+			for i, altSchema := range anyOfSchemas {
+				alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i))
+				alternatives = append(alternatives, alternative)
+			}
+		}
+
+		rule := strings.Join(alternatives, " | ")
+		return sc.addRule(ruleName, rule)
+	} else if constVal, exists := schema["const"]; exists {
+		return sc.addRule(ruleName, sc.formatLiteral(constVal))
+	} else if enumVals, exists := schema["enum"].([]interface{}); exists {
+		var enumRules []string
+		for _, enumVal := range enumVals {
+			enumRule := sc.formatLiteral(enumVal)
+			enumRules = append(enumRules, enumRule)
+		}
+		rule := strings.Join(enumRules, " | ")
+		return sc.addRule(ruleName, rule)
+	} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists {
+		propOrder := sc.propOrder
+		var propPairs []struct {
+			propName   string
+			propSchema map[string]interface{}
+		}
+
+		for propName, propSchema := range properties {
+			propPairs = append(propPairs, struct {
+				propName   string
+				propSchema map[string]interface{}
+			}{propName: propName, propSchema: propSchema.(map[string]interface{})})
+		}
+
+		sort.Slice(propPairs, func(i, j int) bool {
+			iOrder := propOrder[propPairs[i].propName]
+			jOrder := propOrder[propPairs[j].propName]
+			if iOrder != 0 && jOrder != 0 {
+				return iOrder < jOrder
+			}
+			return propPairs[i].propName < propPairs[j].propName
+		})
+
+		var rule strings.Builder
+		rule.WriteString(`"{" space`)
+
+		for i, propPair := range propPairs {
+			propName := propPair.propName
+			propSchema := propPair.propSchema
+			propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName))
+
+			if i > 0 {
+				rule.WriteString(` "," space`)
+			}
+
+			rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName))
+		}
+
+		rule.WriteString(` "}" space`)
+		return sc.addRule(ruleName, rule.String())
+	} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
+		itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName))
+		rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
+		return sc.addRule(ruleName, rule)
+	} else {
+		primitiveRule, exists := PRIMITIVE_RULES[schemaType]
+		if !exists {
+			panic(fmt.Sprintf("Unrecognized schema: %v", schema))
+		}
+		return sc.addRule(schemaType, primitiveRule)
+	}
+}
+
+func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
+	sc.visit(schema, "")
+	return sc.formatGrammar()
+}
+
+func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string {
+	var schema map[string]interface{}
+	_ = json.Unmarshal(b, &schema)
+	return sc.Grammar(schema)
+}
+
+func jsonString(v interface{}) string {
+	b, _ := json.Marshal(v)
+	return string(b)
+}
+
+type FunctionName struct {
+	Const string `json:"const"`
+}
+
+type Properties struct {
+	Function  FunctionName `json:"function"`
+	Arguments Argument     `json:"arguments"`
+}
+
+type Argument struct {
+	Type       string                 `json:"type"`
+	Properties map[string]interface{} `json:"properties"`
+}
+
+type Item struct {
+	Type       string     `json:"type"`
+	Properties Properties `json:"properties"`
+}
+
+type JSONFunctionStructure struct {
+	OneOf []Item `json:"oneOf,omitempty"`
+	AnyOf []Item `json:"anyOf,omitempty"`
+}
+
+func (j JSONFunctionStructure) Grammar(propOrder string) string {
+	dat, _ := json.Marshal(j)
+	return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat)
+}
diff --git a/pkg/grammar/json_schema_test.go b/pkg/grammar/json_schema_test.go
new file mode 100644
index 0000000..0d8dd99
--- /dev/null
+++ b/pkg/grammar/json_schema_test.go
@@ -0,0 +1,113 @@
+package grammar_test
+
+import (
+	"strings"
+
+	. "github.com/go-skynet/LocalAI/pkg/grammar"
+	. "github.com/onsi/ginkgo/v2"
+	. "github.com/onsi/gomega"
+)
+
+const (
+	testInput1 = `
+	{
+		"oneOf": [
+			{
+				"type": "object",
+				"properties": {
+					"function": {"const": "create_event"},
+					"arguments": {
+						"type": "object",
+						"properties": {
+							"title": {"type": "string"},
+							"date": {"type": "string"},
+							"time": {"type": "string"}
+						}
+					}
+				}
+			},
+			{
+				"type": "object",
+				"properties": {
+					"function": {"const": "search"},
+					"arguments": {
+						"type": "object",
+						"properties": {
+							"query": {"type": "string"}
+						}
+					}
+				}
+			}
+		]
+	}`
+
+	inputResult1 = `root-0-function ::= "\"create_event\""
+root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space
+root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space
+root ::= root-0 | root-1
+space ::= " "?
+root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space
+root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space
+string ::= "\"" [ \t!#-\[\]-~]* "\"" space
+root-1-function ::= "\"search\""`
+)
+
+var _ = Describe("JSON schema grammar tests", func() {
+	Context("JSON", func() {
+		It("generates a valid grammar from JSON schema", func() {
+			grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1))
+			results := strings.Split(inputResult1, "\n")
+			for _, r := range results {
+				if r != "" {
+					Expect(grammar).To(ContainSubstring(r))
+				}
+			}
+			Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
+		})
+		It("generates a valid grammar from JSON Objects", func() {
+
+			structuredGrammar := JSONFunctionStructure{
+				OneOf: []Item{
+					{
+						Type: "object",
+						Properties: Properties{
+							Function: FunctionName{
+								Const: "create_event",
+							},
+							Arguments: Argument{ // this is OpenAI's "parameters"
+								Type: "object",
+								Properties: map[string]interface{}{
+									"title": map[string]string{"type": "string"},
+									"date":  map[string]string{"type": "string"},
+									"time":  map[string]string{"type": "string"},
+								},
+							},
+						},
+					},
+					{
+						Type: "object",
+						Properties: Properties{
+							Function: FunctionName{
+								Const: "search",
+							},
+							Arguments: Argument{
+								Type: "object",
+								Properties: map[string]interface{}{
+									"query": map[string]string{"type": "string"},
+								},
+							},
+						},
+					},
+				}}
+
+			grammar := structuredGrammar.Grammar("")
+			results := strings.Split(inputResult1, "\n")
+			for _, r := range results {
+				if r != "" {
+					Expect(grammar).To(ContainSubstring(r))
+				}
+			}
+			Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n"))))
+		})
+	})
+})