parent
a6839fd238
commit
f09ddd2983
@ -0,0 +1,50 @@ |
||||
package grammar |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
) |
||||
|
||||
type Function struct { |
||||
Name string `json:"name"` |
||||
Description string `json:"description"` |
||||
Parameters map[string]interface{} `json:"parameters"` |
||||
} |
||||
type Functions []Function |
||||
|
||||
func (f Functions) ToJSONStructure() JSONStructure { |
||||
js := JSONStructure{} |
||||
for _, function := range f { |
||||
// t := function.Parameters["type"]
|
||||
//tt := t.(string)
|
||||
|
||||
properties := function.Parameters["properties"] |
||||
dat, _ := json.Marshal(properties) |
||||
prop := map[string]interface{}{} |
||||
json.Unmarshal(dat, &prop) |
||||
js.OneOf = append(js.OneOf, Item{ |
||||
Type: "object", |
||||
Properties: Properties{ |
||||
Function: FunctionName{Const: function.Name}, |
||||
Arguments: Argument{ |
||||
Type: "object", |
||||
Properties: prop, |
||||
}, |
||||
}, |
||||
}) |
||||
} |
||||
return js |
||||
} |
||||
|
||||
// Select returns a list of functions containing the function with the given name
|
||||
func (f Functions) Select(name string) Functions { |
||||
var funcs Functions |
||||
|
||||
for _, f := range f { |
||||
if f.Name == name { |
||||
funcs = []Function{f} |
||||
break |
||||
} |
||||
} |
||||
|
||||
return funcs |
||||
} |
@ -0,0 +1,13 @@ |
||||
package grammar |
||||
|
||||
import ( |
||||
"testing" |
||||
|
||||
. "github.com/onsi/ginkgo/v2" |
||||
. "github.com/onsi/gomega" |
||||
) |
||||
|
||||
func TestGrammar(t *testing.T) { |
||||
RegisterFailHandler(Fail) |
||||
RunSpecs(t, "Grammar test suite") |
||||
} |
@ -0,0 +1,222 @@ |
||||
package grammar |
||||
|
||||
// a golang port of https://github.com/ggerganov/llama.cpp/pull/1887
|
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"regexp" |
||||
"sort" |
||||
"strings" |
||||
) |
||||
|
||||
var ( |
||||
SPACE_RULE = `" "?` |
||||
|
||||
PRIMITIVE_RULES = map[string]string{ |
||||
"boolean": `("true" | "false") space`, |
||||
"number": `[0-9]+ space`, // TODO complete
|
||||
"string": `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete
|
||||
"null": `"null" space`, |
||||
} |
||||
|
||||
INVALID_RULE_CHARS_RE = regexp.MustCompile(`[^a-zA-Z0-9-]+`) |
||||
GRAMMAR_LITERAL_ESCAPE_RE = regexp.MustCompile(`[\r\n"]`) |
||||
GRAMMAR_LITERAL_ESCAPES = map[string]string{ |
||||
"\r": `\r`, |
||||
"\n": `\n`, |
||||
`"`: `\"`, |
||||
} |
||||
) |
||||
|
||||
type JSONSchemaConverter struct { |
||||
propOrder map[string]int |
||||
rules map[string]string |
||||
} |
||||
|
||||
func NewJSONSchemaConverter(propOrder string) *JSONSchemaConverter { |
||||
propOrderSlice := strings.Split(propOrder, ",") |
||||
propOrderMap := make(map[string]int) |
||||
for idx, name := range propOrderSlice { |
||||
propOrderMap[name] = idx |
||||
} |
||||
|
||||
rules := make(map[string]string) |
||||
rules["space"] = SPACE_RULE |
||||
|
||||
return &JSONSchemaConverter{ |
||||
propOrder: propOrderMap, |
||||
rules: rules, |
||||
} |
||||
} |
||||
|
||||
func (sc *JSONSchemaConverter) formatLiteral(literal interface{}) string { |
||||
escaped := GRAMMAR_LITERAL_ESCAPE_RE.ReplaceAllStringFunc(jsonString(literal), func(match string) string { |
||||
return GRAMMAR_LITERAL_ESCAPES[match] |
||||
}) |
||||
return fmt.Sprintf(`"%s"`, escaped) |
||||
} |
||||
|
||||
func (sc *JSONSchemaConverter) addRule(name, rule string) string { |
||||
escName := INVALID_RULE_CHARS_RE.ReplaceAllString(name, "-") |
||||
key := escName |
||||
if existingRule, ok := sc.rules[escName]; ok && existingRule != rule { |
||||
i := 0 |
||||
for { |
||||
key = fmt.Sprintf("%s%d", escName, i) |
||||
if _, ok := sc.rules[key]; !ok { |
||||
break |
||||
} |
||||
i++ |
||||
} |
||||
} |
||||
sc.rules[key] = rule |
||||
return key |
||||
} |
||||
|
||||
func (sc *JSONSchemaConverter) formatGrammar() string { |
||||
var lines []string |
||||
for name, rule := range sc.rules { |
||||
lines = append(lines, fmt.Sprintf("%s ::= %s", name, rule)) |
||||
} |
||||
return strings.Join(lines, "\n") |
||||
} |
||||
|
||||
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string { |
||||
st, existType := schema["type"] |
||||
var schemaType string |
||||
if existType { |
||||
schemaType = st.(string) |
||||
} |
||||
ruleName := name |
||||
if name == "" { |
||||
ruleName = "root" |
||||
} |
||||
_, oneOfExists := schema["oneOf"] |
||||
_, anyOfExists := schema["anyOf"] |
||||
if oneOfExists || anyOfExists { |
||||
var alternatives []string |
||||
oneOfSchemas, oneOfExists := schema["oneOf"].([]interface{}) |
||||
anyOfSchemas, anyOfExists := schema["anyOf"].([]interface{}) |
||||
|
||||
if oneOfExists { |
||||
for i, altSchema := range oneOfSchemas { |
||||
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i)) |
||||
alternatives = append(alternatives, alternative) |
||||
} |
||||
} else if anyOfExists { |
||||
for i, altSchema := range anyOfSchemas { |
||||
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i)) |
||||
alternatives = append(alternatives, alternative) |
||||
} |
||||
} |
||||
|
||||
rule := strings.Join(alternatives, " | ") |
||||
return sc.addRule(ruleName, rule) |
||||
} else if constVal, exists := schema["const"]; exists { |
||||
return sc.addRule(ruleName, sc.formatLiteral(constVal)) |
||||
} else if enumVals, exists := schema["enum"].([]interface{}); exists { |
||||
var enumRules []string |
||||
for _, enumVal := range enumVals { |
||||
enumRule := sc.formatLiteral(enumVal) |
||||
enumRules = append(enumRules, enumRule) |
||||
} |
||||
rule := strings.Join(enumRules, " | ") |
||||
return sc.addRule(ruleName, rule) |
||||
} else if properties, exists := schema["properties"].(map[string]interface{}); schemaType == "object" && exists { |
||||
propOrder := sc.propOrder |
||||
var propPairs []struct { |
||||
propName string |
||||
propSchema map[string]interface{} |
||||
} |
||||
|
||||
for propName, propSchema := range properties { |
||||
propPairs = append(propPairs, struct { |
||||
propName string |
||||
propSchema map[string]interface{} |
||||
}{propName: propName, propSchema: propSchema.(map[string]interface{})}) |
||||
} |
||||
|
||||
sort.Slice(propPairs, func(i, j int) bool { |
||||
iOrder := propOrder[propPairs[i].propName] |
||||
jOrder := propOrder[propPairs[j].propName] |
||||
if iOrder != 0 && jOrder != 0 { |
||||
return iOrder < jOrder |
||||
} |
||||
return propPairs[i].propName < propPairs[j].propName |
||||
}) |
||||
|
||||
var rule strings.Builder |
||||
rule.WriteString(`"{" space`) |
||||
|
||||
for i, propPair := range propPairs { |
||||
propName := propPair.propName |
||||
propSchema := propPair.propSchema |
||||
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName)) |
||||
|
||||
if i > 0 { |
||||
rule.WriteString(` "," space`) |
||||
} |
||||
|
||||
rule.WriteString(fmt.Sprintf(` %s space ":" space %s`, sc.formatLiteral(propName), propRuleName)) |
||||
} |
||||
|
||||
rule.WriteString(` "}" space`) |
||||
return sc.addRule(ruleName, rule.String()) |
||||
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists { |
||||
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName)) |
||||
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName) |
||||
return sc.addRule(ruleName, rule) |
||||
} else { |
||||
primitiveRule, exists := PRIMITIVE_RULES[schemaType] |
||||
if !exists { |
||||
panic(fmt.Sprintf("Unrecognized schema: %v", schema)) |
||||
} |
||||
return sc.addRule(schemaType, primitiveRule) |
||||
} |
||||
} |
||||
|
||||
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string { |
||||
sc.visit(schema, "") |
||||
return sc.formatGrammar() |
||||
} |
||||
|
||||
func (sc *JSONSchemaConverter) GrammarFromBytes(b []byte) string { |
||||
var schema map[string]interface{} |
||||
_ = json.Unmarshal(b, &schema) |
||||
return sc.Grammar(schema) |
||||
} |
||||
|
||||
func jsonString(v interface{}) string { |
||||
b, _ := json.Marshal(v) |
||||
return string(b) |
||||
} |
||||
|
||||
type FunctionName struct { |
||||
Const string `json:"const"` |
||||
} |
||||
|
||||
type Properties struct { |
||||
Function FunctionName `json:"function"` |
||||
Arguments Argument `json:"arguments"` |
||||
} |
||||
|
||||
type Argument struct { |
||||
Type string `json:"type"` |
||||
Properties map[string]interface{} `json:"properties"` |
||||
} |
||||
|
||||
type Item struct { |
||||
Type string `json:"type"` |
||||
Properties Properties `json:"properties"` |
||||
} |
||||
|
||||
type JSONStructure struct { |
||||
OneOf []Item `json:"oneOf,omitempty"` |
||||
AnyOf []Item `json:"anyOf,omitempty"` |
||||
} |
||||
|
||||
func (j JSONStructure) Grammar(propOrder string) string { |
||||
dat, _ := json.Marshal(j) |
||||
return NewJSONSchemaConverter(propOrder).GrammarFromBytes(dat) |
||||
} |
@ -0,0 +1,113 @@ |
||||
package grammar_test |
||||
|
||||
import ( |
||||
"strings" |
||||
|
||||
. "github.com/go-skynet/LocalAI/pkg/grammar" |
||||
. "github.com/onsi/ginkgo/v2" |
||||
. "github.com/onsi/gomega" |
||||
) |
||||
|
||||
const ( |
||||
testInput1 = ` |
||||
{ |
||||
"oneOf": [ |
||||
{ |
||||
"type": "object", |
||||
"properties": { |
||||
"function": {"const": "create_event"}, |
||||
"arguments": { |
||||
"type": "object", |
||||
"properties": { |
||||
"title": {"type": "string"}, |
||||
"date": {"type": "string"}, |
||||
"time": {"type": "string"} |
||||
} |
||||
} |
||||
} |
||||
}, |
||||
{ |
||||
"type": "object", |
||||
"properties": { |
||||
"function": {"const": "search"}, |
||||
"arguments": { |
||||
"type": "object", |
||||
"properties": { |
||||
"query": {"type": "string"} |
||||
} |
||||
} |
||||
} |
||||
} |
||||
] |
||||
}` |
||||
|
||||
inputResult1 = `root-0-function ::= "\"create_event\"" |
||||
root-0 ::= "{" space "\"arguments\"" space ":" space root-0-arguments "," space "\"function\"" space ":" space root-0-function "}" space |
||||
root-1-arguments ::= "{" space "\"query\"" space ":" space string "}" space |
||||
root ::= root-0 | root-1 |
||||
space ::= " "? |
||||
root-0-arguments ::= "{" space "\"date\"" space ":" space string "," space "\"time\"" space ":" space string "," space "\"title\"" space ":" space string "}" space |
||||
root-1 ::= "{" space "\"arguments\"" space ":" space root-1-arguments "," space "\"function\"" space ":" space root-1-function "}" space |
||||
string ::= "\"" [ \t!#-\[\]-~]* "\"" space |
||||
root-1-function ::= "\"search\""` |
||||
) |
||||
|
||||
var _ = Describe("JSON schema grammar tests", func() { |
||||
Context("JSON", func() { |
||||
It("generates a valid grammar from JSON schema", func() { |
||||
grammar := NewJSONSchemaConverter("").GrammarFromBytes([]byte(testInput1)) |
||||
results := strings.Split(inputResult1, "\n") |
||||
for _, r := range results { |
||||
if r != "" { |
||||
Expect(grammar).To(ContainSubstring(r)) |
||||
} |
||||
} |
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) |
||||
}) |
||||
It("generates a valid grammar from JSON Objects", func() { |
||||
|
||||
structuredGrammar := JSONStructure{ |
||||
OneOf: []Item{ |
||||
{ |
||||
Type: "object", |
||||
Properties: Properties{ |
||||
Function: FunctionName{ |
||||
Const: "create_event", |
||||
}, |
||||
Arguments: Argument{ // this is OpenAI's parameter
|
||||
Type: "object", |
||||
Properties: map[string]interface{}{ |
||||
"title": map[string]string{"type": "string"}, |
||||
"date": map[string]string{"type": "string"}, |
||||
"time": map[string]string{"type": "string"}, |
||||
}, |
||||
}, |
||||
}, |
||||
}, |
||||
{ |
||||
Type: "object", |
||||
Properties: Properties{ |
||||
Function: FunctionName{ |
||||
Const: "search", |
||||
}, |
||||
Arguments: Argument{ |
||||
Type: "object", |
||||
Properties: map[string]interface{}{ |
||||
"query": map[string]string{"type": "string"}, |
||||
}, |
||||
}, |
||||
}, |
||||
}, |
||||
}} |
||||
|
||||
grammar := structuredGrammar.Grammar("") |
||||
results := strings.Split(inputResult1, "\n") |
||||
for _, r := range results { |
||||
if r != "" { |
||||
Expect(grammar).To(ContainSubstring(r)) |
||||
} |
||||
} |
||||
Expect(len(results)).To(Equal(len(strings.Split(grammar, "\n")))) |
||||
}) |
||||
}) |
||||
}) |
Loading…
Reference in new issue