feat: resolve JSONSchema refs (planners) (#774)

renovate/github.com-sashabaranov-go-openai-1.x
Ettore Di Giacinto 1 year ago committed by GitHub
parent a38dc497b2
commit 236497e331
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      pkg/grammar/functions.go
  2. 38
      pkg/grammar/json_schema.go

@ -18,9 +18,17 @@ func (f Functions) ToJSONStructure() JSONFunctionStructure {
//tt := t.(string) //tt := t.(string)
properties := function.Parameters["properties"] properties := function.Parameters["properties"]
defs := function.Parameters["$defs"]
dat, _ := json.Marshal(properties) dat, _ := json.Marshal(properties)
dat2, _ := json.Marshal(defs)
prop := map[string]interface{}{} prop := map[string]interface{}{}
defsD := map[string]interface{}{}
json.Unmarshal(dat, &prop) json.Unmarshal(dat, &prop)
json.Unmarshal(dat2, &defsD)
if js.Defs == nil {
js.Defs = defsD
}
js.OneOf = append(js.OneOf, Item{ js.OneOf = append(js.OneOf, Item{
Type: "object", Type: "object",
Properties: Properties{ Properties: Properties{

@ -16,6 +16,7 @@ var (
PRIMITIVE_RULES = map[string]string{ PRIMITIVE_RULES = map[string]string{
"boolean": `("true" | "false") space`, "boolean": `("true" | "false") space`,
"number": `[0-9]+ space`, // TODO complete "number": `[0-9]+ space`, // TODO complete
"integer": `[0-9]+ space`, // TODO complete
"string": `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete "string": `"\"" [ \t!#-\[\]-~]* "\"" space`, // TODO complete
"null": `"null" space`, "null": `"null" space`,
} }
@ -82,7 +83,7 @@ func (sc *JSONSchemaConverter) formatGrammar() string {
return strings.Join(lines, "\n") return strings.Join(lines, "\n")
} }
func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string) string { func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string, rootSchema map[string]interface{}) string {
st, existType := schema["type"] st, existType := schema["type"]
var schemaType string var schemaType string
if existType { if existType {
@ -101,18 +102,21 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string)
if oneOfExists { if oneOfExists {
for i, altSchema := range oneOfSchemas { for i, altSchema := range oneOfSchemas {
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i)) alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
alternatives = append(alternatives, alternative) alternatives = append(alternatives, alternative)
} }
} else if anyOfExists { } else if anyOfExists {
for i, altSchema := range anyOfSchemas { for i, altSchema := range anyOfSchemas {
alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i)) alternative := sc.visit(altSchema.(map[string]interface{}), fmt.Sprintf("%s-%d", ruleName, i), rootSchema)
alternatives = append(alternatives, alternative) alternatives = append(alternatives, alternative)
} }
} }
rule := strings.Join(alternatives, " | ") rule := strings.Join(alternatives, " | ")
return sc.addRule(ruleName, rule) return sc.addRule(ruleName, rule)
} else if ref, exists := schema["$ref"].(string); exists {
referencedSchema := sc.resolveReference(ref, rootSchema)
return sc.visit(referencedSchema, name, rootSchema)
} else if constVal, exists := schema["const"]; exists { } else if constVal, exists := schema["const"]; exists {
return sc.addRule(ruleName, sc.formatLiteral(constVal)) return sc.addRule(ruleName, sc.formatLiteral(constVal))
} else if enumVals, exists := schema["enum"].([]interface{}); exists { } else if enumVals, exists := schema["enum"].([]interface{}); exists {
@ -152,7 +156,7 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string)
for i, propPair := range propPairs { for i, propPair := range propPairs {
propName := propPair.propName propName := propPair.propName
propSchema := propPair.propSchema propSchema := propPair.propSchema
propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName)) propRuleName := sc.visit(propSchema, fmt.Sprintf("%s-%s", ruleName, propName), rootSchema)
if i > 0 { if i > 0 {
rule.WriteString(` "," space`) rule.WriteString(` "," space`)
@ -164,7 +168,7 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string)
rule.WriteString(` "}" space`) rule.WriteString(` "}" space`)
return sc.addRule(ruleName, rule.String()) return sc.addRule(ruleName, rule.String())
} else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists { } else if items, exists := schema["items"].(map[string]interface{}); schemaType == "array" && exists {
itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName)) itemRuleName := sc.visit(items, fmt.Sprintf("%s-item", ruleName), rootSchema)
rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName) rule := fmt.Sprintf(`"[" space (%s ("," space %s)*)? "]" space`, itemRuleName, itemRuleName)
return sc.addRule(ruleName, rule) return sc.addRule(ruleName, rule)
} else { } else {
@ -175,9 +179,30 @@ func (sc *JSONSchemaConverter) visit(schema map[string]interface{}, name string)
return sc.addRule(schemaType, primitiveRule) return sc.addRule(schemaType, primitiveRule)
} }
} }
func (sc *JSONSchemaConverter) resolveReference(ref string, rootSchema map[string]interface{}) map[string]interface{} {
if !strings.HasPrefix(ref, "#/$defs/") {
panic(fmt.Sprintf("Invalid reference format: %s", ref))
}
defKey := strings.TrimPrefix(ref, "#/$defs/")
definitions, exists := rootSchema["$defs"].(map[string]interface{})
if !exists {
fmt.Println(rootSchema)
panic("No definitions found in the schema")
}
def, exists := definitions[defKey].(map[string]interface{})
if !exists {
fmt.Println(definitions)
panic(fmt.Sprintf("Definition not found: %s", defKey))
}
return def
}
func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string { func (sc *JSONSchemaConverter) Grammar(schema map[string]interface{}) string {
sc.visit(schema, "") sc.visit(schema, "", schema)
return sc.formatGrammar() return sc.formatGrammar()
} }
@ -214,6 +239,7 @@ type Item struct {
type JSONFunctionStructure struct { type JSONFunctionStructure struct {
OneOf []Item `json:"oneOf,omitempty"` OneOf []Item `json:"oneOf,omitempty"`
AnyOf []Item `json:"anyOf,omitempty"` AnyOf []Item `json:"anyOf,omitempty"`
Defs map[string]interface{} `json:"$defs,omitempty"`
} }
func (j JSONFunctionStructure) Grammar(propOrder string) string { func (j JSONFunctionStructure) Grammar(propOrder string) string {

Loading…
Cancel
Save