mirror of https://github.com/docker/cli.git
172 lines
4.2 KiB
Go
172 lines
4.2 KiB
Go
package source
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"go/ast"
|
|
"go/format"
|
|
"go/parser"
|
|
"go/token"
|
|
"os"
|
|
"runtime"
|
|
"strings"
|
|
)
|
|
|
|
// IsUpdate is returns true if the -update flag is set. It indicates the user
|
|
// running the tests would like to update any golden values.
|
|
func IsUpdate() bool {
|
|
if Update {
|
|
return true
|
|
}
|
|
return flag.Lookup("update").Value.(flag.Getter).Get().(bool)
|
|
}
|
|
|
|
// Update, when set to true, makes IsUpdate report true without consulting the
|
|
// -update flag. It is a shim for testing, and for compatibility with the old
// -update-golden flag.
|
|
var Update bool
|
|
|
|
// init makes sure a boolean -update flag exists on the default flag set.
//
// If nothing has registered the flag yet it is defined here with a default of
// false. If another package already defined it, the existing flag is reused,
// provided its value implements flag.Getter and yields a bool; anything else
// is treated as a conflict and causes a panic.
func init() {
	existing := flag.Lookup("update")
	if existing == nil {
		// Nobody has claimed the flag yet, so register it here.
		flag.Bool("update", false, "update golden values")
		return
	}

	const incompatible = "some other package defined an incompatible -update flag, expected a flag.Bool"

	getter, isGetter := existing.Value.(flag.Getter)
	if !isGetter {
		panic(incompatible)
	}
	if _, isBool := getter.Get().(bool); !isBool {
		panic(incompatible)
	}
}
|
|
|
|
// ErrNotFound indicates that UpdateExpectedValue failed to find the
// variable to update, likely because it is not a package level variable.
// It is a sentinel value; callers that receive a wrapped error can detect it
// with errors.Is.
var ErrNotFound = errors.New("failed to find variable for update of golden value")
|
|
|
|
// UpdateExpectedValue looks for a package-level variable with a name that
|
|
// starts with expected in the arguments to the caller. If the variable is
|
|
// found, the value of the variable will be updated to value of the other
|
|
// argument to the caller.
|
|
func UpdateExpectedValue(stackIndex int, x, y interface{}) error {
|
|
_, filename, line, ok := runtime.Caller(stackIndex + 1)
|
|
if !ok {
|
|
return errors.New("failed to get call stack")
|
|
}
|
|
debug("call stack position: %s:%d", filename, line)
|
|
|
|
fileset := token.NewFileSet()
|
|
astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors|parser.ParseComments)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse source file %s: %w", filename, err)
|
|
}
|
|
|
|
expr, err := getCallExprArgs(fileset, astFile, line)
|
|
if err != nil {
|
|
return fmt.Errorf("call from %s:%d: %w", filename, line, err)
|
|
}
|
|
|
|
if len(expr) < 3 {
|
|
debug("not enough arguments %d: %v",
|
|
len(expr), debugFormatNode{Node: &ast.CallExpr{Args: expr}})
|
|
return ErrNotFound
|
|
}
|
|
|
|
argIndex, ident := getIdentForExpectedValueArg(expr)
|
|
if argIndex < 0 || ident == nil {
|
|
debug("no arguments started with the word 'expected': %v",
|
|
debugFormatNode{Node: &ast.CallExpr{Args: expr}})
|
|
return ErrNotFound
|
|
}
|
|
|
|
value := x
|
|
if argIndex == 1 {
|
|
value = y
|
|
}
|
|
|
|
strValue, ok := value.(string)
|
|
if !ok {
|
|
debug("value must be type string, got %T", value)
|
|
return ErrNotFound
|
|
}
|
|
return UpdateVariable(filename, fileset, astFile, ident, strValue)
|
|
}
|
|
|
|
// UpdateVariable writes to filename the contents of astFile with the value of
|
|
// the variable updated to value.
|
|
func UpdateVariable(
|
|
filename string,
|
|
fileset *token.FileSet,
|
|
astFile *ast.File,
|
|
ident *ast.Ident,
|
|
value string,
|
|
) error {
|
|
obj := ident.Obj
|
|
if obj == nil {
|
|
return ErrNotFound
|
|
}
|
|
if obj.Kind != ast.Con && obj.Kind != ast.Var {
|
|
debug("can only update var and const, found %v", obj.Kind)
|
|
return ErrNotFound
|
|
}
|
|
|
|
switch decl := obj.Decl.(type) {
|
|
case *ast.ValueSpec:
|
|
if len(decl.Names) != 1 {
|
|
debug("more than one name in ast.ValueSpec")
|
|
return ErrNotFound
|
|
}
|
|
|
|
decl.Values[0] = &ast.BasicLit{
|
|
Kind: token.STRING,
|
|
Value: "`" + value + "`",
|
|
}
|
|
|
|
case *ast.AssignStmt:
|
|
if len(decl.Lhs) != 1 {
|
|
debug("more than one name in ast.AssignStmt")
|
|
return ErrNotFound
|
|
}
|
|
|
|
decl.Rhs[0] = &ast.BasicLit{
|
|
Kind: token.STRING,
|
|
Value: "`" + value + "`",
|
|
}
|
|
|
|
default:
|
|
debug("can only update *ast.ValueSpec, found %T", obj.Decl)
|
|
return ErrNotFound
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := format.Node(&buf, fileset, astFile); err != nil {
|
|
return fmt.Errorf("failed to format file after update: %w", err)
|
|
}
|
|
|
|
fh, err := os.Create(filename)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open file %v: %w", filename, err)
|
|
}
|
|
if _, err = fh.Write(buf.Bytes()); err != nil {
|
|
return fmt.Errorf("failed to write file %v: %w", filename, err)
|
|
}
|
|
if err := fh.Sync(); err != nil {
|
|
return fmt.Errorf("failed to sync file %v: %w", filename, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getIdentForExpectedValueArg looks at the arguments at index 1 and 2 of expr
// for a plain identifier whose name starts with "expected" (compared without
// regard to case). It returns the index and identifier of the first match, or
// (-1, nil) when neither argument qualifies.
//
// Index 0 is never examined (presumably the testing argument of the assertion
// call; confirm against callers). Slices with fewer than three elements are
// accepted: only the arguments that exist are inspected.
func getIdentForExpectedValueArg(expr []ast.Expr) (int, *ast.Ident) {
	const first, last = 1, 2
	for i := first; i <= last && i < len(expr); i++ {
		ident, ok := expr[i].(*ast.Ident)
		if !ok || ident == nil {
			continue
		}
		if strings.HasPrefix(strings.ToLower(ident.Name), "expected") {
			return i, ident
		}
	}
	return -1, nil
}
|