parent
a6839fd238
commit
f09ddd2983
@ -0,0 +1,50 @@ |
|||||||
|
package grammar |
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
) |
||||||
|
|
||||||
|
type Function struct { |
||||||
|
Name string `json:"name"` |
||||||
|
Description string `json:"description"` |
||||||
|
Parameters map[string]interface{} `json:"parameters"` |
||||||
|
} |
||||||
|
type Functions []Function |
||||||
|
|
||||||
|
func (f Functions) ToJSONStructure() JSONStructure { |
||||||
|
js := JSONStructure{} |
||||||
|
for _, function := range f { |
||||||
|
// t := function.Parameters["type"]
|
||||||
|
//tt := t.(string)
|
||||||
|
|
||||||
|
properties := function.Parameters["properties"] |
||||||
|
dat, _ := json.Marshal(properties) |
||||||
|
prop := map[string]interface{}{} |
||||||
|
json.Unmarshal(dat, &prop) |
||||||
|
js.OneOf = append(js.OneOf, Item{ |
||||||
|
Type: "object", |
||||||
|
Properties: Properties{ |
||||||
|
Function: FunctionName{Const: function.Name}, |
||||||
|
Arguments: Argument{ |
||||||
|
Type: "object", |
||||||
|
Properties: prop, |
||||||
|
}, |
||||||
|
}, |
||||||
|
}) |
||||||
|
} |
||||||
|
return js |
||||||
|
} |
||||||
|
|
||||||
|
// Select returns a list of functions containing the function with the given name
|
||||||
|
func (f Functions) Select(name string) Functions { |
||||||
|
var funcs Functions |
||||||
|
|
||||||
|
for _, f := range f { |
||||||
|
if f.Name == name { |
||||||
|
funcs = []Function{f} |
||||||
|
break |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
return funcs |
||||||
|
} |
@ -0,0 +1,13 @@ |
|||||||
|
package grammar |
||||||
|
|
||||||
|
import ( |
||||||
|
"testing" |
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo/v2" |
||||||
|
. "github.com/onsi/gomega" |
||||||
|
) |
||||||
|
|
||||||
|
func TestGrammar(t *testing.T) { |
||||||
|
RegisterFailHandler(Fail) |
||||||
|
RunSpecs(t, "Grammar test suite") |
||||||
|
} |
@ -0,0 +1,222 @@ |
|||||||
|
package grammar |
||||||
|
|
||||||
|
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
|
||||||
|
|
||||||
|
import ( |
||||||
|
"encoding/json" |
||||||
|
"fmt" |
||||||
|
"regexp" |
||||||
|
"sort" |
||||||
|
"strings" |
||||||
|
) |
||||||
|
|
||||||
|
var ( |
||||||
|
SPACE_RULE = `" "?` |
||||||
|
|
||||||
|
PRIMITIVE_RULES = map[string]string{ |
||||||
|
"boolean": `("true" | "false") space`, |
||||||
|
"number": `[0-9]+ space`, // TODO complete
|
||||||
|
"string": `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete
|
||||||
|
"null": `"null" space`, |
||||||
|
} |
||||||
|
|
||||||
|
INVALID_RULE_CHARS_RE = regexp.MustCompile(`[^a-zA-Z0-9-]+`) |
||||||
|
GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`) |
||||||
|
GRAMMAR_LITERAL_ESCAPES = map[string]string{ |
||||||
|
"\r": `\r`, |
||||||
|
"\n": `\n`, |
||||||
|
`"`: `\"`, |
||||||
|
} |
||||||
|
) |
||||||
|
|
||||||
|
type JSONSchemaConverter struct { |
||||||
|
propOrder map[string]int |
||||||
|
rules map[string]string |
||||||
|
} |
||||||
|
|
||||||
|
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter { |
||||||
|
propOrderSlice := strings.Split(propOrder, ",") |
||||||
|
propOrderMap := make(map[string]int) |
||||||
|
for idx, name := range propOrderSlice { |
||||||
|
propOrderMap[name] = idx |
||||||
|
} |
||||||
|
|
||||||
|
rules := make(map[string]string) |
||||||
|
rules["space"] = SPACE_RULE |
||||||
|
|
||||||
|
return &JSONSchemaConverter{ |
||||||
|
propOrder: propOrderMap, |
||||||
|
rules: rules, |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string { |
||||||
|
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string { |
||||||
|
return GRAMMAR_LITERAL_ESCAPES[match] |
||||||
|
}) |
||||||
|
return fmt.Sprintf(`"%s"`, escaped) |
||||||
|
} |
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) addRule(name, rule string) string { |
||||||
|
escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-") |
||||||
|
key := escName |
||||||
|
if existingRule, ok := sc.rules[escName]; ok && existingRule != rule { |
||||||
|
i := 0 |
||||||
|
for { |
||||||
|
key = fmt.Sprintf("%s%d", escName, i) |
||||||
|
if _, ok := sc.rules[key]; !ok { |
||||||
|
break |
||||||
|
} |
||||||
|
i++ |
||||||
|
} |
||||||
|
} |
||||||
|
sc.rules[key] = rule |
||||||
|
return key |
||||||
|
} |
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) formatGrammar() string { |
||||||
|
var lines []string |
||||||
|
for name, rule := range sc.rules { |
||||||
|
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) |
||||||
|
} |
||||||
|
return strings.Join(lines, "\n") |
||||||
|
} |
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string { |
||||||
|
st, existType := schema["type"] |
||||||
|
var schemaType string |
||||||
|
if existType { |
||||||
|
schemaType = st.(string) |
||||||
|
} |
||||||
|
ruleName := name |
||||||
|
if name == "" { |
||||||
|
ruleName = "root" |
||||||
|
} |
||||||
|
_, oneOfExists := schema["oneOf"] |
||||||
|
_, anyOfExists := schema["anyOf"] |
||||||
|
if oneOfExists || anyOfExists { |
||||||
|
var alternatives []string |
||||||
|
oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{}) |
||||||
|
anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{}) |
||||||
|
|
||||||
|
if oneOfExists { |
||||||
|
for i, altSchema := range oneOfSchemas { |
||||||
|
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i)) |
||||||
|
alternatives = append(alternatives, alternative) |
||||||
|
} |
||||||
|
} else if anyOfExists { |
||||||
|
for i, altSchema := range anyOfSchemas { |
||||||
|
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i)) |
||||||
|
alternatives = append(alternatives, alternative) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
rule := strings.Join(alternatives, " | ") |
||||||
|
return sc.addRule(ruleName, rule) |
||||||
|
} else if constVal, exists := schema["const"]; exists { |
||||||
|
return sc.addRule(ruleName, sc.formatLiteral(constVal)) |
||||||
|
} else if enumVals, exists := schema["enum"].([]interface{}); exists { |
||||||
|
var enumRules []string |
||||||
|
for _, enumVal := range enumVals { |
||||||
|
enumRule := sc.formatLiteral(enumVal) |
||||||
|
enumRules = append(enumRules, enumRule) |
||||||
|
} |
||||||
|
rule := strings.Join(enumRules, " | ") |
||||||
|
return sc.addRule(ruleName, rule) |
||||||
|
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists { |
||||||
|
propOrder := sc.propOrder |
||||||
|
var propPairs []struct { |
||||||
|
propName string |
||||||
|
propSchema map[string]interface{} |
||||||
|
} |
||||||
|
|
||||||
|
for propName, propSchema := range properties { |
||||||
|
propPairs = append(propPairs, struct { |
||||||
|
propName string |
||||||
|
propSchema map[string]interface{} |
||||||
|
}{propName: propName, propSchema: propSchema.(map[string]interface{})}) |
||||||
|
} |
||||||
|
|
||||||
|
sort.Slice(propPairs, func(i, j int) bool { |
||||||
|
iOrder := propOrder[propPairs[i].propName] |
||||||
|
jOrder := propOrder[propPairs[j].propName] |
||||||
|
if iOrder != 0 && jOrder != 0 { |
||||||
|
return iOrder < jOrder |
||||||
|
} |
||||||
|
return propPairs[i].propName < propPairs[j].propName |
||||||
|
}) |
||||||
|
|
||||||
|
var rule strings.Builder |
||||||
|
rule.WriteString(`"{" space`) |
||||||
|
|
||||||
|
for i, propPair := range propPairs { |
||||||
|
propName := propPair.propName |
||||||
|
propSchema := propPair.propSchema |
||||||
|
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName)) |
||||||
|
|
||||||
|
if i > 0 { |
||||||
|
rule.WriteString(` "," space`) |
||||||
|
} |
||||||
|
|
||||||
|
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName)) |
||||||
|
} |
||||||
|
|
||||||
|
rule.WriteString(` "}" space`) |
||||||
|
return sc.addRule(ruleName, rule.String()) |
||||||
|
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists { |
||||||
|
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName)) |
||||||
|
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName) |
||||||
|
return sc.addRule(ruleName, rule) |
||||||
|
} else { |
||||||
|
primitiveRule, exists := PRIMITIVE_RULES[schemaType] |
||||||
|
if !exists { |
||||||
|
panic(fmt.Sprintf("Unrecognized schema: %v", schema)) |
||||||
|
} |
||||||
|
return sc.addRule(schemaType, primitiveRule) |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string { |
||||||
|
sc.visit(schema, "") |
||||||
|
return sc.formatGrammar() |
||||||
|
} |
||||||
|
|
||||||
|
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string { |
||||||
|
var schema map[string]interface{} |
||||||
|
_ = json.Unmarshal(b, &schema) |
||||||
|
return sc.Grammar(schema) |
||||||
|
} |
||||||
|
|
||||||
|
func jsonString(v interface{}) string { |
||||||
|
b, _ := json.Marshal(v) |
||||||
|
return string(b) |
||||||
|
} |
||||||
|
|
||||||
|
type FunctionName struct { |
||||||
|
Const string `json:"const"` |
||||||
|
} |
||||||
|
|
||||||
|
type Properties struct { |
||||||
|
Function FunctionName `json:"function"` |
||||||
|
Arguments Argument `json:"arguments"` |
||||||
|
} |
||||||
|
|
||||||
|
type Argument struct { |
||||||
|
Type string `json:"type"` |
||||||
|
Properties map[string]interface{} `json:"properties"` |
||||||
|
} |
||||||
|
|
||||||
|
type Item struct { |
||||||
|
Type string `json:"type"` |
||||||
|
Properties Properties `json:"properties"` |
||||||
|
} |
||||||
|
|
||||||
|
type JSONStructure struct { |
||||||
|
OneOf []Item `json:"oneOf,omitempty"` |
||||||
|
AnyOf []Item `json:"anyOf,omitempty"` |
||||||
|
} |
||||||
|
|
||||||
|
func (j JSONStructure) Grammar(propOrder string) string { |
||||||
|
dat, _ := json.Marshal(j) |
||||||
|
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat) |
||||||
|
} |
@ -0,0 +1,113 @@ |
|||||||
|
package grammar_test |
||||||
|
|
||||||
|
import ( |
||||||
|
"strings" |
||||||
|
|
||||||
|
. "github.com/go-skynet/LocalAI/pkg/grammar" |
||||||
|
. "github.com/onsi/ginkgo/v2" |
||||||
|
. "github.com/onsi/gomega" |
||||||
|
) |
||||||
|
|
||||||
|
const ( |
||||||
|
testInput1 = ` |
||||||
|
{ |
||||||
|
"oneOf": [ |
||||||
|
{ |
||||||
|
"type": "object", |
||||||
|
"properties": { |
||||||
|
"function": {"const": "create_event"}, |
||||||
|
"arguments": { |
||||||
|
"type": "object", |
||||||
|
"properties": { |
||||||
|
"title": {"type": "string"}, |
||||||
|
"date": {"type": "string"}, |
||||||
|
"time": {"type": "string"} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
}, |
||||||
|
{ |
||||||
|
"type": "object", |
||||||
|
"properties": { |
||||||
|
"function": {"const": "search"}, |
||||||
|
"arguments": { |
||||||
|
"type": "object", |
||||||
|
"properties": { |
||||||
|
"query": {"type": "string"} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
} |
||||||
|
] |
||||||
|
}` |
||||||
|
|
||||||
|
inputResult1 = `root-0-function ::= "\"create_event\"" |
||||||
|
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space |
||||||
|
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space |
||||||
|
root ::= root-0 | root-1 |
||||||
|
space ::= " "? |
||||||
|
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space |
||||||
|
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space |
||||||
|
string ::= "\"" [ \t!#-\[\]-~]* "\"" space |
||||||
|
root-1-function ::= "\"search\""` |
||||||
|
) |
||||||
|
|
||||||
|
var _ = Describe("JSON schema grammar tests", func() { |
||||||
|
Context("JSON", func() { |
||||||
|
It("generates a valid grammar from JSON schema", func() { |
||||||
|
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1)) |
||||||
|
results := strings.Split(inputResult1, "\n") |
||||||
|
for _, r := range results { |
||||||
|
if r != "" { |
||||||
|
Expect(grammar).To(ContainSubstring(r)) |
||||||
|
} |
||||||
|
} |
||||||
|
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) |
||||||
|
}) |
||||||
|
It("generates a valid grammar from JSON Objects", func() { |
||||||
|
|
||||||
|
structuredGrammar := JSONStructure{ |
||||||
|
OneOf: []Item{ |
||||||
|
{ |
||||||
|
Type: "object", |
||||||
|
Properties: Properties{ |
||||||
|
Function: FunctionName{ |
||||||
|
Const: "create_event", |
||||||
|
}, |
||||||
|
Arguments: Argument{ // this is OpenAI's parameter
|
||||||
|
Type: "object", |
||||||
|
Properties: map[string]interface{}{ |
||||||
|
"title": map[string]string{"type": "string"}, |
||||||
|
"date": map[string]string{"type": "string"}, |
||||||
|
"time": map[string]string{"type": "string"}, |
||||||
|
}, |
||||||
|
}, |
||||||
|
}, |
||||||
|
}, |
||||||
|
{ |
||||||
|
Type: "object", |
||||||
|
Properties: Properties{ |
||||||
|
Function: FunctionName{ |
||||||
|
Const: "search", |
||||||
|
}, |
||||||
|
Arguments: Argument{ |
||||||
|
Type: "object", |
||||||
|
Properties: map[string]interface{}{ |
||||||
|
"query": map[string]string{"type": "string"}, |
||||||
|
}, |
||||||
|
}, |
||||||
|
}, |
||||||
|
}, |
||||||
|
}} |
||||||
|
|
||||||
|
grammar := structuredGrammar.Grammar("") |
||||||
|
results := strings.Split(inputResult1, "\n") |
||||||
|
for _, r := range results { |
||||||
|
if r != "" { |
||||||
|
Expect(grammar).To(ContainSubstring(r)) |
||||||
|
} |
||||||
|
} |
||||||
|
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) |
||||||
|
}) |
||||||
|
}) |
||||||
|
}) |
Loading…
Reference in new issue