package assert import ( "errors" "fmt" "go/ast" "gotest.tools/v3/assert/cmp" "gotest.tools/v3/internal/format" "gotest.tools/v3/internal/source" ) // RunComparison and return Comparison.Success. If the comparison fails a messages // will be printed using t.Log. func RunComparison( t LogT, argSelector argSelector, f cmp.Comparison, msgAndArgs ...interface{}, ) bool { if ht, ok := t.(helperT); ok { ht.Helper() } result := f() if result.Success() { return true } if source.IsUpdate() { if updater, ok := result.(updateExpected); ok { const stackIndex = 3 // Assert/Check, assert, RunComparison err := updater.UpdatedExpected(stackIndex) switch { case err == nil: return true case errors.Is(err, source.ErrNotFound): // do nothing, fallthrough to regular failure message default: t.Log("failed to update source", err) return false } } } var message string switch typed := result.(type) { case resultWithComparisonArgs: const stackIndex = 3 // Assert/Check, assert, RunComparison args, err := source.CallExprArgs(stackIndex) if err != nil { t.Log(err.Error()) } message = typed.FailureMessage(filterPrintableExpr(argSelector(args))) case resultBasic: message = typed.FailureMessage() default: message = fmt.Sprintf("comparison returned invalid Result type: %T", result) } t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...)) return false } type resultWithComparisonArgs interface { FailureMessage(args []ast.Expr) string } type resultBasic interface { FailureMessage() string } type updateExpected interface { UpdatedExpected(stackIndex int) error } // filterPrintableExpr filters the ast.Expr slice to only include Expr that are // easy to read when printed and contain relevant information to an assertion. // // Ident and SelectorExpr are included because they print nicely and the variable // names may provide additional context to their values. // BasicLit and CompositeLit are excluded because their source is equivalent to // their value, which is already available. 
// Other types are ignored for now, but could be added if they are relevant.
//
// The exact set of accepted expressions is defined by isShortPrintableExpr,
// which, in addition to Ident and SelectorExpr, also accepts IndexExpr,
// SliceExpr, BinaryExpr and UnaryExpr. A StarExpr is replaced by its operand X
// (whatever X is). Every other expression becomes nil in the result, so the
// returned slice always has len(args) entries and result[i] corresponds to
// args[i].
func filterPrintableExpr(args []ast.Expr) []ast.Expr {
	result := make([]ast.Expr, len(args))
	for i, arg := range args {
		if isShortPrintableExpr(arg) {
			result[i] = arg
		} else if starExpr, ok := arg.(*ast.StarExpr); ok {
			// Show the pointer expression without the leading '*'.
			result[i] = starExpr.X
		}
	}
	return result
}

// isShortPrintableExpr reports whether expr is one of the expression kinds
// that are short enough to be shown verbatim in a failure message.
func isShortPrintableExpr(expr ast.Expr) bool {
	switch expr.(type) {
	case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr,
		*ast.BinaryExpr, *ast.UnaryExpr:
		return true
	default:
		// CallExpr, ParenExpr, TypeAssertExpr, KeyValueExpr, StarExpr
		// (StarExpr is unwrapped separately by filterPrintableExpr).
		return false
	}
}

// argSelector selects, from the arguments of an assertion's call expression,
// the expressions that are relevant to the failure message.
type argSelector func([]ast.Expr) []ast.Expr

// ArgsAfterT selects args starting at position 1. Used when the caller has a
// testing.T as the first argument, and the args to select should follow it.
// It returns nil when args is empty.
func ArgsAfterT(args []ast.Expr) []ast.Expr {
	if len(args) < 1 {
		return nil
	}
	return args[1:]
}

// ArgsFromComparisonCall selects args from the CallExpression at position 1.
// Used when the caller has a testing.T as the first argument, and the args to
// select are passed to the cmp.Comparison at position 1.
// It returns nil when there is no argument at position 1 or it is not a
// call expression.
func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
	if len(args) <= 1 {
		return nil
	}
	if callExpr, ok := args[1].(*ast.CallExpr); ok {
		return callExpr.Args
	}
	return nil
}

// ArgsAtZeroIndex selects args from the CallExpression at position 0.
// Used when the caller accepts a single cmp.Comparison argument.
// It returns nil when args is empty or args[0] is not a call expression.
func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
	if len(args) == 0 {
		return nil
	}
	if callExpr, ok := args[0].(*ast.CallExpr); ok {
		return callExpr.Args
	}
	return nil
}