2018-02-28 10:11:02 -05:00
|
|
|
package assert
|
|
|
|
|
|
|
|
import (
|
2022-09-22 09:38:19 -04:00
|
|
|
"errors"
|
2018-02-28 10:11:02 -05:00
|
|
|
"fmt"
|
|
|
|
"go/ast"
|
|
|
|
|
2020-02-22 12:12:14 -05:00
|
|
|
"gotest.tools/v3/assert/cmp"
|
|
|
|
"gotest.tools/v3/internal/format"
|
|
|
|
"gotest.tools/v3/internal/source"
|
2018-02-28 10:11:02 -05:00
|
|
|
)
|
|
|
|
|
2022-03-01 09:50:32 -05:00
|
|
|
// RunComparison and return Comparison.Success. If the comparison fails a messages
|
|
|
|
// will be printed using t.Log.
|
|
|
|
func RunComparison(
|
|
|
|
t LogT,
|
2017-12-21 15:25:54 -05:00
|
|
|
argSelector argSelector,
|
2018-02-28 10:11:02 -05:00
|
|
|
f cmp.Comparison,
|
|
|
|
msgAndArgs ...interface{},
|
|
|
|
) bool {
|
|
|
|
if ht, ok := t.(helperT); ok {
|
|
|
|
ht.Helper()
|
|
|
|
}
|
|
|
|
result := f()
|
|
|
|
if result.Success() {
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
|
2023-10-20 11:39:10 -04:00
|
|
|
if source.IsUpdate() {
|
2022-09-22 09:38:19 -04:00
|
|
|
if updater, ok := result.(updateExpected); ok {
|
|
|
|
const stackIndex = 3 // Assert/Check, assert, RunComparison
|
|
|
|
err := updater.UpdatedExpected(stackIndex)
|
|
|
|
switch {
|
|
|
|
case err == nil:
|
|
|
|
return true
|
|
|
|
case errors.Is(err, source.ErrNotFound):
|
|
|
|
// do nothing, fallthrough to regular failure message
|
|
|
|
default:
|
|
|
|
t.Log("failed to update source", err)
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2018-02-28 10:11:02 -05:00
|
|
|
var message string
|
|
|
|
switch typed := result.(type) {
|
|
|
|
case resultWithComparisonArgs:
|
2022-03-01 09:50:32 -05:00
|
|
|
const stackIndex = 3 // Assert/Check, assert, RunComparison
|
2018-02-28 10:11:02 -05:00
|
|
|
args, err := source.CallExprArgs(stackIndex)
|
|
|
|
if err != nil {
|
|
|
|
t.Log(err.Error())
|
|
|
|
}
|
2017-12-21 15:25:54 -05:00
|
|
|
message = typed.FailureMessage(filterPrintableExpr(argSelector(args)))
|
2018-02-28 10:11:02 -05:00
|
|
|
case resultBasic:
|
|
|
|
message = typed.FailureMessage()
|
|
|
|
default:
|
|
|
|
message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
|
|
|
|
}
|
|
|
|
|
|
|
|
t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
// Optional capabilities of a cmp.Result. RunComparison probes a failed
// Result for each of these with a type assertion.
type (
	// resultWithComparisonArgs is a Result whose failure message is built
	// from the (filtered) source expressions of the assertion's arguments.
	resultWithComparisonArgs interface {
		FailureMessage(args []ast.Expr) string
	}

	// resultBasic is a Result that can describe its failure without any
	// information taken from the assertion's source.
	resultBasic interface {
		FailureMessage() string
	}

	// updateExpected is a Result that can update the expected value it was
	// compared against; RunComparison uses it only when source updating is
	// enabled. stackIndex identifies the assertion's call frame
	// (RunComparison passes 3).
	updateExpected interface {
		UpdatedExpected(stackIndex int) error
	}
)
|
|
|
|
|
2017-12-21 15:25:54 -05:00
|
|
|
// filterPrintableExpr filters the ast.Expr slice to only include Expr that are
|
2018-02-28 10:11:02 -05:00
|
|
|
// easy to read when printed and contain relevant information to an assertion.
|
|
|
|
//
|
|
|
|
// Ident and SelectorExpr are included because they print nicely and the variable
|
|
|
|
// names may provide additional context to their values.
|
|
|
|
// BasicLit and CompositeLit are excluded because their source is equivalent to
|
|
|
|
// their value, which is already available.
|
|
|
|
// Other types are ignored for now, but could be added if they are relevant.
|
|
|
|
func filterPrintableExpr(args []ast.Expr) []ast.Expr {
|
|
|
|
result := make([]ast.Expr, len(args))
|
|
|
|
for i, arg := range args {
|
2017-12-21 15:25:54 -05:00
|
|
|
if isShortPrintableExpr(arg) {
|
2018-02-28 10:11:02 -05:00
|
|
|
result[i] = arg
|
2017-12-21 15:25:54 -05:00
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
if starExpr, ok := arg.(*ast.StarExpr); ok {
|
|
|
|
result[i] = starExpr.X
|
|
|
|
continue
|
2018-02-28 10:11:02 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
return result
|
|
|
|
}
|
|
|
|
|
2017-12-21 15:25:54 -05:00
|
|
|
// isShortPrintableExpr reports whether expr is one of the expression kinds
// whose source text is compact enough to print verbatim in a failure message.
// A nil expr is not printable.
func isShortPrintableExpr(expr ast.Expr) bool {
	switch expr.(type) {
	case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr,
		*ast.BinaryExpr, *ast.UnaryExpr:
		return true
	}
	// Everything else: CallExpr, ParenExpr, TypeAssertExpr, KeyValueExpr,
	// StarExpr, literals, ...
	return false
}
|
|
|
|
|
|
|
|
// argSelector picks, from the arguments of the assertion call located in the
// test's source, the expressions whose text may be shown in a failure message.
// ArgsAfterT, ArgsFromComparisonCall and ArgsAtZeroIndex have this signature.
type argSelector func([]ast.Expr) []ast.Expr
|
|
|
|
|
2022-03-01 09:50:32 -05:00
|
|
|
// ArgsAfterT returns every argument except the first one. It is meant for
// assertion helpers whose first argument is a testing.T, so only the
// arguments that follow it are selected. It returns nil when args is empty.
func ArgsAfterT(args []ast.Expr) []ast.Expr {
	if len(args) == 0 {
		return nil
	}
	return args[1:]
}
|
|
|
|
|
2022-03-01 09:50:32 -05:00
|
|
|
// ArgsFromComparisonCall returns the arguments of the call expression at
// position 1 of args. It is meant for assertion helpers whose first argument
// is a testing.T and whose second argument is a call that builds the
// cmp.Comparison, e.g. Check(t, comparison(x, y)). It returns nil when there
// is no second argument or when that argument is not a function call.
func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
	if len(args) < 2 {
		return nil
	}
	call, isCall := args[1].(*ast.CallExpr)
	if !isCall {
		return nil
	}
	return call.Args
}
|
2022-03-01 09:50:32 -05:00
|
|
|
|
|
|
|
// ArgsAtZeroIndex selects args from the CallExpression at position 1.
|
|
|
|
// Used when the caller accepts a single cmp.Comparison argument.
|
|
|
|
func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
|
|
|
|
if len(args) == 0 {
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
if callExpr, ok := args[0].(*ast.CallExpr); ok {
|
|
|
|
return callExpr.Args
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|