From f09ddd2983ae1e02a5b7c6fc51ebfc948c0c7cce Mon Sep 17 00:00:00 2001 From: mudler Date: Sun, 2 Jul 2023 11:13:51 +0200 Subject: [PATCH] feat: add grammar and functions call support --- api/config.go | 30 +++- api/openai.go | 150 +++++++++++++++++++- api/prediction.go | 2 + pkg/grammar/functions.go | 50 +++++++ pkg/grammar/grammar_suite_test.go | 13 ++ pkg/grammar/json_schema.go | 222 ++++++++++++++++++++++++++++++ pkg/grammar/json_schema_test.go | 113 +++++++++++++++ 7 files changed, 571 insertions(+), 9 deletions(-) create mode 100644 pkg/grammar/functions.go create mode 100644 pkg/grammar/grammar_suite_test.go create mode 100644 pkg/grammar/json_schema.go create mode 100644 pkg/grammar/json_schema_test.go diff --git a/api/config.go b/api/config.go index ba84e0d..c9a8092 100644 --- a/api/config.go +++ b/api/config.go @@ -46,12 +46,16 @@ type Config struct { PromptCacheAll bool `yaml:"prompt_cache_all"` PromptCacheRO bool `yaml:"prompt_cache_ro"` - PromptStrings, InputStrings []string - InputToken [][]int + Grammar string `yaml:"grammar"` + + PromptStrings, InputStrings []string + InputToken [][]int + functionCallString, functionCallNameString string } type TemplateConfig struct { Completion string `yaml:"completion"` + Functions string `yaml:"function"` Chat string `yaml:"chat"` Edit string `yaml:"edit"` } @@ -181,6 +185,10 @@ func updateConfig(config *Config, input *OpenAIRequest) { config.TopP = input.TopP } + if input.Grammar != "" { + config.Grammar = input.Grammar + } + if input.Temperature != 0 { config.Temperature = input.Temperature } @@ -262,6 +270,24 @@ func updateConfig(config *Config, input *OpenAIRequest) { } } + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.functionCallString = fnc + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if !e { + name = nn + } + } + config.functionCallNameString = name + } + switch p := input.Prompt.(type) { case string: config.PromptStrings = append(config.PromptStrings, p) diff --git a/api/openai.go b/api/openai.go index 403a03b..f361b72 100644 --- a/api/openai.go +++ b/api/openai.go @@ -17,6 +17,7 @@ import ( "strings" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" + "github.com/go-skynet/LocalAI/pkg/grammar" model "github.com/go-skynet/LocalAI/pkg/model" whisperutil "github.com/go-skynet/LocalAI/pkg/whisper" llama "github.com/go-skynet/go-llama.cpp" @@ -73,8 +74,12 @@ type Choice struct { } type Message struct { - Role string `json:"role,omitempty" yaml:"role"` + // The message role + Role string `json:"role,omitempty" yaml:"role"` + // The message content Content string `json:"content,omitempty" yaml:"content"` + // A result of a function call + FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` } type OpenAIModel struct { @@ -104,6 +109,10 @@ type OpenAIRequest struct { // Messages is read only by chat/completion API calls Messages []Message `json:"messages" yaml:"messages"` + // A list of available functions to call + Functions []grammar.Function `json:"functions" yaml:"functions"` + FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object + Stream bool `json:"stream"` Echo bool `json:"echo"` // Common options between all the API calls @@ -134,6 +143,9 @@ type OpenAIRequest struct { Mode int `json:"mode"` Step int `json:"step"` + // A grammar to constrain the LLM output + Grammar string `json:"grammar" yaml:"grammar"` + TypicalP float64 `json:"typical_p" yaml:"typical_p"` } @@ -345,6 +357,23 @@ func embeddingsEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { } func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { + // TODO: replace this with config settings + // Allow the user to set custom actions via config file + // to be "embedded" in each model + const noActionName = "answer" + const noActionDescription = "use this action to answer without performing any action" + + noActionGrammar := grammar.Function{ + Name: noActionName, + Description: noActionDescription, + Parameters: map[string]interface{}{ + "properties": map[string]interface{}{ + "message": map[string]interface{}{ + "type": "string", + "description": "The message to reply the user with", + }}, + }, + } process := func(s string, req *OpenAIRequest, config *Config, loader *model.ModelLoader, responses chan OpenAIResponse) { initialMessage := OpenAIResponse{ @@ -368,6 +397,8 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { close(responses) } return func(c *fiber.Ctx) error { + processFunctions := false + funcs := []grammar.Function{} model, input, err := readInput(c, o.loader, true) if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) @@ -377,8 +408,33 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { if err != nil { return fmt.Errorf("failed reading parameters from request:%w", err) } + log.Debug().Msgf("Configuration read: %+v", config) - log.Debug().Msgf("Parameter Config: %+v", config) + // process functions if we have any defined or if we have a function call string + if len(input.Functions) > 0 && + ((config.functionCallString != "none" || config.functionCallString == "") || len(config.functionCallNameString) > 0) { + log.Debug().Msgf("Response needs to process functions") + + var funcs grammar.Functions = input.Functions + processFunctions = true + + // Force picking one of the functions by the request + if config.functionCallNameString != "" { + funcs = funcs.Select(config.functionCallNameString) + } + + // Append the no action function + funcs = append(funcs, noActionGrammar) + + // Update input grammar + jsStruct := funcs.ToJSONStructure() + config.Grammar = jsStruct.Grammar("") + } + + // functions are not supported in stream mode (yet?) + toStream := input.Stream && !processFunctions + + log.Debug().Msgf("Parameters: %+v", config) var predInput string @@ -397,7 +453,7 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { predInput = strings.Join(mess, "\n") - if input.Stream { + if toStream { log.Debug().Msgf("Stream request received") c.Context().SetContentType("text/event-stream") //c.Response().Header.SetContentType(fiber.MIMETextHTMLCharsetUTF8) @@ -409,20 +465,35 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { templateFile := config.Model - if config.TemplateConfig.Chat != "" { + if config.TemplateConfig.Chat != "" && !processFunctions { templateFile = config.TemplateConfig.Chat } + if config.TemplateConfig.Functions != "" && processFunctions { + templateFile = config.TemplateConfig.Functions + } + // A model can have a "file.bin.tmpl" file associated with a prompt template prefix templatedInput, err := o.loader.TemplatePrefix(templateFile, struct { - Input string - }{Input: predInput}) + Input string + Functions []grammar.Function + }{ + Input: predInput, + Functions: funcs, + }) if err == nil { predInput = templatedInput log.Debug().Msgf("Template found, input modified to: %s", predInput) + } else { + log.Debug().Msgf("Template failed loading: %s", err.Error()) } - if input.Stream { + log.Debug().Msgf("Prompt: %s", predInput) + if processFunctions { + log.Debug().Msgf("Grammar: %+v", config.Grammar) + } + + if toStream { responses := make(chan OpenAIResponse) go process(predInput, input, config, o.loader, responses) @@ -459,6 +530,71 @@ func chatEndpoint(cm *ConfigMerger, o *Option) func(c *fiber.Ctx) error { } result, err := ComputeChoices(predInput, input, config, o, o.loader, func(s string, c *[]Choice) { + if processFunctions { + // As we have to change the result before processing, we can't stream the answer (yet?) + ss := map[string]interface{}{} + json.Unmarshal([]byte(s), &ss) + log.Debug().Msgf("Function return: %s %+v", s, ss) + + // The grammar defines the function name as "function", while OpenAI returns "name" + func_name := ss["function"] + // Similarly, while here arguments is a map[string]interface{}, OpenAI actually want a stringified object + args := ss["arguments"] // arguments needs to be a string, but we return an object from the grammar result (TODO: fix) + d, _ := json.Marshal(args) + + ss["arguments"] = string(d) + ss["name"] = func_name + + // if do nothing, reply with a message + if func_name == noActionName { + log.Debug().Msgf("nothing to do, computing a reply") + + // If there is a message that the LLM already sends as part of the JSON reply, use it + arguments := map[string]interface{}{} + json.Unmarshal([]byte(d), &arguments) + m, exists := arguments["message"] + if exists { + switch message := m.(type) { + case string: + if message != "" { + log.Debug().Msgf("Reply received from LLM: %s", message) + message = Finetune(*config, predInput, message) + log.Debug().Msgf("Reply received from LLM(finetuned): %s", message) + + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: message}}) + return + } + } + } + + log.Debug().Msgf("No action received from LLM, without a message, computing a reply") + // Otherwise ask the LLM to understand the JSON output and the context, and return a message + // Note: This costs (in term of CPU) another computation + config.Grammar = "" + predFunc, err := ModelInference(predInput, o.loader, *config, o, nil) + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction, err := predFunc() + if err != nil { + log.Error().Msgf("inference error: %s", err.Error()) + return + } + + prediction = Finetune(*config, predInput, prediction) + *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: prediction}}) + } else { + // otherwise reply with the function call + *c = append(*c, Choice{ + FinishReason: "function_call", + Message: &Message{Role: "function", FunctionCall: ss}, + }) + } + + return + } *c = append(*c, Choice{Message: &Message{Role: "assistant", Content: s}}) }, nil) if err != nil { diff --git a/api/prediction.go b/api/prediction.go index bc23d86..7daa730 100644 --- a/api/prediction.go +++ b/api/prediction.go @@ -189,6 +189,8 @@ func buildLLamaPredictOptions(c Config, modelPath string) []llama.PredictOption predictOptions = append(predictOptions, llama.EnablePromptCacheRO) } + predictOptions = append(predictOptions, llama.WithGrammar(c.Grammar)) + if c.PromptCachePath != "" { // Create parent directory p := filepath.Join(modelPath, c.PromptCachePath) diff --git a/pkg/grammar/functions.go b/pkg/grammar/functions.go new file mode 100644 index 0000000..0971322 --- /dev/null +++ b/pkg/grammar/functions.go @@ -0,0 +1,50 @@ +package grammar + +import ( + "encoding/json" +) + +type Function struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` +} +type Functions []Function + +func (f Functions) ToJSONStructure() JSONStructure { + js := JSONStructure{} + for _, function := range f { + // t := function.Parameters["type"] + //tt := t.(string) + + properties := function.Parameters["properties"] + dat, _ := json.Marshal(properties) + prop := map[string]interface{}{} + json.Unmarshal(dat, &prop) + js.OneOf = append(js.OneOf, Item{ + Type: "object", + Properties: Properties{ + Function: FunctionName{Const: function.Name}, + Arguments: Argument{ + Type: "object", + Properties: prop, + }, + }, + }) + } + return js +} + +// Select returns a list of functions containing the function with the given name +func (f Functions) Select(name string) Functions { + var funcs Functions + + for _, f := range f { + if f.Name == name { + funcs = []Function{f} + break + } + } + + return funcs +} diff --git a/pkg/grammar/grammar_suite_test.go b/pkg/grammar/grammar_suite_test.go new file mode 100644 index 0000000..652643b --- /dev/null +++ b/pkg/grammar/grammar_suite_test.go @@ -0,0 +1,13 @@ +package grammar + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestGrammar(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Grammar test suite") +} diff --git a/pkg/grammar/json_schema.go b/pkg/grammar/json_schema.go new file mode 100644 index 0000000..447921e --- /dev/null +++ b/pkg/grammar/json_schema.go @@ -0,0 +1,222 @@ +package grammar + +// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887 + +import ( + "encoding/json" + "fmt" + "regexp" + "sort" + "strings" +) + +var ( + SPACE_RULE = `" "?` + + PRIMITIVE_RULES = map[string]string{ + "boolean": `("true" | "false") space`, + "number": `[0-9]+ space`, // TODO complete + "string": `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete + "null": `"null" space`, + } + + INVALID_RULE_CHARS_RE = regexp.MustCompile(`[^a-zA-Z0-9-]+`) + GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`) + GRAMMAR_LITERAL_ESCAPES = map[string]string{ + "\r": `\r`, + "\n": `\n`, + `"`: `\"`, + } +) + +type JSONSchemaConverter struct { + propOrder map[string]int + rules map[string]string +} + +func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter { + propOrderSlice := strings.Split(propOrder, ",") + propOrderMap := make(map[string]int) + for idx, name := range propOrderSlice { + propOrderMap[name] = idx + } + + rules := make(map[string]string) + rules["space"] = SPACE_RULE + + return &JSONSchemaConverter{ + propOrder: propOrderMap, + rules: rules, + } +} + +func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string { + escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string { + return GRAMMAR_LITERAL_ESCAPES[match] + }) + return fmt.Sprintf(`"%s"`, escaped) +} + +func (sc *JSONSchemaConverter) addRule(name, rule string) string { + escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-") + key := escName + if existingRule, ok := sc.rules[escName]; ok && existingRule != rule { + i := 0 + for { + key = fmt.Sprintf("%s%d", escName, i) + if _, ok := sc.rules[key]; !ok { + break + } + i++ + } + } + sc.rules[key] = rule + return key +} + +func (sc *JSONSchemaConverter) formatGrammar() string { + var lines []string + for name, rule := range sc.rules { + lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) + } + return strings.Join(lines, "\n") +} + +func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string { + st, existType := schema["type"] + var schemaType string + if existType { + schemaType = st.(string) + } + ruleName := name + if name == "" { + ruleName = "root" + } + _, oneOfExists := schema["oneOf"] + _, anyOfExists := schema["anyOf"] + if oneOfExists || anyOfExists { + var alternatives []string + oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{}) + anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{}) + + if oneOfExists { + for i, altSchema := range oneOfSchemas { + alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i)) + alternatives = append(alternatives, alternative) + } + } else if anyOfExists { + for i, altSchema := range anyOfSchemas { + alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i)) + alternatives = append(alternatives, alternative) + } + } + + rule := strings.Join(alternatives, " | ") + return sc.addRule(ruleName, rule) + } else if constVal, exists := schema["const"]; exists { + return sc.addRule(ruleName, sc.formatLiteral(constVal)) + } else if enumVals, exists := schema["enum"].([]interface{}); exists { + var enumRules []string + for _, enumVal := range enumVals { + enumRule := sc.formatLiteral(enumVal) + enumRules = append(enumRules, enumRule) + } + rule := strings.Join(enumRules, " | ") + return sc.addRule(ruleName, rule) + } else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists { + propOrder := sc.propOrder + var propPairs []struct { + propName string + propSchema map[string]interface{} + } + + for propName, propSchema := range properties { + propPairs = append(propPairs, struct { + propName string + propSchema map[string]interface{} + }{propName: propName, propSchema: propSchema.(map[string]interface{})}) + } + + sort.Slice(propPairs, func(i, j int) bool { + iOrder := propOrder[propPairs[i].propName] + jOrder := propOrder[propPairs[j].propName] + if iOrder != 0 && jOrder != 0 { + return iOrder < jOrder + } + return propPairs[i].propName < propPairs[j].propName + }) + + var rule strings.Builder + rule.WriteString(`"{" space`) + + for i, propPair := range propPairs { + propName := propPair.propName + propSchema := propPair.propSchema + propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName)) + + if i > 0 { + rule.WriteString(` "," space`) + } + + rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName)) + } + + rule.WriteString(` "}" space`) + return sc.addRule(ruleName, rule.String()) + } else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists { + itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName)) + rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName) + return sc.addRule(ruleName, rule) + } else { + primitiveRule, exists := PRIMITIVE_RULES[schemaType] + if !exists { + panic(fmt.Sprintf("Unrecognized schema: %v", schema)) + } + return sc.addRule(schemaType, primitiveRule) + } +} + +func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string { + sc.visit(schema, "") + return sc.formatGrammar() +} + +func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string { + var schema map[string]interface{} + _ = json.Unmarshal(b, &schema) + return sc.Grammar(schema) +} + +func jsonString(v interface{}) string { + b, _ := json.Marshal(v) + return string(b) +} + +type FunctionName struct { + Const string `json:"const"` +} + +type Properties struct { + Function FunctionName `json:"function"` + Arguments Argument `json:"arguments"` +} + +type Argument struct { + Type string `json:"type"` + Properties map[string]interface{} `json:"properties"` +} + +type Item struct { + Type string `json:"type"` + Properties Properties `json:"properties"` +} + +type JSONStructure struct { + OneOf []Item `json:"oneOf,omitempty"` + AnyOf []Item `json:"anyOf,omitempty"` +} + +func (j JSONStructure) Grammar(propOrder string) string { + dat, _ := json.Marshal(j) + return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat) +} diff --git a/pkg/grammar/json_schema_test.go b/pkg/grammar/json_schema_test.go new file mode 100644 index 0000000..94e2958 --- /dev/null +++ b/pkg/grammar/json_schema_test.go @@ -0,0 +1,113 @@ +package grammar_test + +import ( + "strings" + + . "github.com/go-skynet/LocalAI/pkg/grammar" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +const ( + testInput1 = ` + { + "oneOf": [ + { + "type": "object", + "properties": { + "function": {"const": "create_event"}, + "arguments": { + "type": "object", + "properties": { + "title": {"type": "string"}, + "date": {"type": "string"}, + "time": {"type": "string"} + } + } + } + }, + { + "type": "object", + "properties": { + "function": {"const": "search"}, + "arguments": { + "type": "object", + "properties": { + "query": {"type": "string"} + } + } + } + } + ] + }` + + inputResult1 = `root-0-function ::= "\"create_event\"" +root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space +root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space +root ::= root-0 | root-1 +space ::= " "? +root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space +root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space +string ::= "\"" [ \t!#-\[\]-~]* "\"" space +root-1-function ::= "\"search\""` +) + +var _ = Describe("JSON schema grammar tests", func() { + Context("JSON", func() { + It("generates a valid grammar from JSON schema", func() { + grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1)) + results := strings.Split(inputResult1, "\n") + for _, r := range results { + if r != "" { + Expect(grammar).To(ContainSubstring(r)) + } + } + Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) + }) + It("generates a valid grammar from JSON Objects", func() { + + structuredGrammar := JSONStructure{ + OneOf: []Item{ + { + Type: "object", + Properties: Properties{ + Function: FunctionName{ + Const: "create_event", + }, + Arguments: Argument{ // this is OpenAI's parameter + Type: "object", + Properties: map[string]interface{}{ + "title": map[string]string{"type": "string"}, + "date": map[string]string{"type": "string"}, + "time": map[string]string{"type": "string"}, + }, + }, + }, + }, + { + Type: "object", + Properties: Properties{ + Function: FunctionName{ + Const: "search", + }, + Arguments: Argument{ + Type: "object", + Properties: map[string]interface{}{ + "query": map[string]string{"type": "string"}, + }, + }, + }, + }, + }} + + grammar := structuredGrammar.Grammar("") + results := strings.Split(inputResult1, "\n") + for _, r := range results { + if r != "" { + Expect(grammar).To(ContainSubstring(r)) + } + } + Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) + }) + }) +})