package assert import ( "fmt" "go/ast" "gotest.tools/assert/cmp" "gotest.tools/internal/format" "gotest.tools/internal/source" ) func runComparison( t TestingT, argSelector argSelector, f cmp.Comparison, msgAndArgs ...interface{}, ) bool { if ht, ok := t.(helperT); ok { ht.Helper() } result := f() if result.Success() { return true } var message string switch typed := result.(type) { case resultWithComparisonArgs: const stackIndex = 3 // Assert/Check, assert, runComparison args, err := source.CallExprArgs(stackIndex) if err != nil { t.Log(err.Error()) } message = typed.FailureMessage(filterPrintableExpr(argSelector(args))) case resultBasic: message = typed.FailureMessage() default: message = fmt.Sprintf("comparison returned invalid Result type: %T", result) } t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...)) return false } type resultWithComparisonArgs interface { FailureMessage(args []ast.Expr) string } type resultBasic interface { FailureMessage() string } // filterPrintableExpr filters the ast.Expr slice to only include Expr that are // easy to read when printed and contain relevant information to an assertion. // // Ident and SelectorExpr are included because they print nicely and the variable // names may provide additional context to their values. // BasicLit and CompositeLit are excluded because their source is equivalent to // their value, which is already available. // Other types are ignored for now, but could be added if they are relevant. func filterPrintableExpr(args []ast.Expr) []ast.Expr { result := make([]ast.Expr, len(args)) for i, arg := range args { if isShortPrintableExpr(arg) { result[i] = arg continue } if starExpr, ok := arg.(*ast.StarExpr); ok { result[i] = starExpr.X continue } } return result } func isShortPrintableExpr(expr ast.Expr) bool { switch expr.(type) { case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr: return true case *ast.BinaryExpr, *ast.UnaryExpr: return true default: // CallExpr, ParenExpr, TypeAssertExpr, KeyValueExpr, StarExpr return false } } type argSelector func([]ast.Expr) []ast.Expr func argsAfterT(args []ast.Expr) []ast.Expr { if len(args) < 1 { return nil } return args[1:] } func argsFromComparisonCall(args []ast.Expr) []ast.Expr { if len(args) < 1 { return nil } if callExpr, ok := args[1].(*ast.CallExpr); ok { return callExpr.Args } return nil }