diff --git a/cli/command/builder/client_test.go b/cli/command/builder/client_test.go new file mode 100644 index 0000000000..6214dcf30d --- /dev/null +++ b/cli/command/builder/client_test.go @@ -0,0 +1,20 @@ +package builder + +import ( + "context" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/client" +) + +type fakeClient struct { + client.Client + builderPruneFunc func(ctx context.Context, opts types.BuildCachePruneOptions) (*types.BuildCachePruneReport, error) +} + +func (c *fakeClient) BuildCachePrune(ctx context.Context, opts types.BuildCachePruneOptions) (*types.BuildCachePruneReport, error) { + if c.builderPruneFunc != nil { + return c.builderPruneFunc(ctx, opts) + } + return nil, nil +} diff --git a/cli/command/builder/prune.go b/cli/command/builder/prune.go index 6286218cac..cd453cbbc0 100644 --- a/cli/command/builder/prune.go +++ b/cli/command/builder/prune.go @@ -66,8 +66,10 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) if options.all { warning = allCacheWarning } - if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { - return 0, "", nil + if !options.force { + if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil { + return 0, "", err + } } report, err := dockerCli.Client().BuildCachePrune(ctx, types.BuildCachePruneOptions{ diff --git a/cli/command/builder/prune_test.go b/cli/command/builder/prune_test.go new file mode 100644 index 0000000000..57ea995445 --- /dev/null +++ b/cli/command/builder/prune_test.go @@ -0,0 +1,28 @@ +package builder + +import ( + "context" + "errors" + "testing" + + "github.com/docker/cli/cli/command" + "github.com/docker/cli/internal/test" + "github.com/docker/docker/api/types" + "gotest.tools/v3/assert" +) + +func TestBuilderPromptTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{ + builderPruneFunc: func(ctx context.Context, opts types.BuildCachePruneOptions) (*types.BuildCachePruneReport, error) { + return nil, errors.New("fakeClient builderPruneFunc should not be called") + }, + }) + cmd := NewPruneCommand(cli) + test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, command.ErrPromptTerminated) + }) +} diff --git a/cli/command/container/client_test.go b/cli/command/container/client_test.go index 4b17fb3ca9..c45f34040f 100644 --- a/cli/command/container/client_test.go +++ b/cli/command/container/client_test.go @@ -6,6 +6,7 @@ import ( "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/network" "github.com/docker/docker/api/types/system" @@ -35,6 +36,7 @@ type fakeClient struct { containerExecResizeFunc func(id string, options container.ResizeOptions) error containerRemoveFunc func(ctx context.Context, containerID string, options container.RemoveOptions) error containerKillFunc func(ctx context.Context, containerID, signal string) error + containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) Version string } @@ -164,3 +166,10 @@ func (f *fakeClient) ContainerKill(ctx context.Context, containerID, signal stri } return nil } + +func (f *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) { + if f.containerPruneFunc != nil { + return f.containerPruneFunc(ctx, pruneFilters) + } + return types.ContainersPruneReport{}, nil +} diff --git a/cli/command/container/prune.go b/cli/command/container/prune.go index 0db41092e8..583473fac7 100644 --- a/cli/command/container/prune.go +++ b/cli/command/container/prune.go @@ -53,8 +53,10 @@ Are you sure you want to continue?` func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) (spaceReclaimed uint64, output string, err error) { pruneFilters := command.PruneFilters(dockerCli, options.filter.Value()) - if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { - return 0, "", nil + if !options.force { + if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil { + return 0, "", err + } } report, err := dockerCli.Client().ContainersPrune(ctx, pruneFilters) diff --git a/cli/command/container/prune_test.go b/cli/command/container/prune_test.go new file mode 100644 index 0000000000..05cf87b37b --- /dev/null +++ b/cli/command/container/prune_test.go @@ -0,0 +1,29 @@ +package container + +import ( + "context" + "testing" + + "github.com/docker/cli/cli/command" + "github.com/docker/cli/internal/test" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/filters" + "github.com/pkg/errors" + "gotest.tools/v3/assert" +) + +func TestContainerPrunePromptTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{ + containerPruneFunc: func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) { + return types.ContainersPruneReport{}, errors.New("fakeClient containerPruneFunc should not be called") + }, + }) + cmd := NewPruneCommand(cli) + test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, command.ErrPromptTerminated) + }) +} diff --git a/cli/command/image/prune.go b/cli/command/image/prune.go index 205a9dedb3..d627633cd5 100644 --- a/cli/command/image/prune.go +++ b/cli/command/image/prune.go @@ -67,8 +67,10 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) if options.all { warning = allImageWarning } - if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { - return 0, "", nil + if !options.force { + if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil { + return 0, "", err + } } report, err := dockerCli.Client().ImagesPrune(ctx, pruneFilters) diff --git a/cli/command/image/prune_test.go b/cli/command/image/prune_test.go index a60c283c68..472f9f7f6d 100644 --- a/cli/command/image/prune_test.go +++ b/cli/command/image/prune_test.go @@ -1,10 +1,12 @@ package image import ( + "context" "fmt" "io" "testing" + "github.com/docker/cli/cli/command" "github.com/docker/cli/internal/test" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/filters" @@ -101,3 +103,19 @@ func TestNewPruneCommandSuccess(t *testing.T) { golden.Assert(t, cli.OutBuffer().String(), fmt.Sprintf("prune-command-success.%s.golden", tc.name)) } } + +func TestPrunePromptTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{ + imagesPruneFunc: func(pruneFilter filters.Args) (types.ImagesPruneReport, error) { + return types.ImagesPruneReport{}, errors.New("fakeClient imagesPruneFunc should not be called") + }, + }) + cmd := NewPruneCommand(cli) + test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, command.ErrPromptTerminated) + }) +} diff --git a/cli/command/network/client_test.go b/cli/command/network/client_test.go index 18c91e251d..d2fa4111fb 100644 --- a/cli/command/network/client_test.go +++ b/cli/command/network/client_test.go @@ -4,6 +4,7 @@ import ( "context" "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/network" "github.com/docker/docker/client" ) @@ -15,6 +16,8 @@ type fakeClient struct { networkDisconnectFunc func(ctx context.Context, networkID, container string, force bool) error networkRemoveFunc func(ctx context.Context, networkID string) error networkListFunc func(ctx context.Context, options types.NetworkListOptions) ([]types.NetworkResource, error) + networkPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.NetworksPruneReport, error) + networkInspectFunc func(ctx context.Context, networkID string, options types.NetworkInspectOptions) (types.NetworkResource, []byte, error) } func (c *fakeClient) NetworkCreate(ctx context.Context, name string, options types.NetworkCreate) (types.NetworkCreateResponse, error) { @@ -52,6 +55,16 @@ func (c *fakeClient) NetworkRemove(ctx context.Context, networkID string) error return nil } -func (c *fakeClient) NetworkInspectWithRaw(context.Context, string, types.NetworkInspectOptions) (types.NetworkResource, []byte, error) { +func (c *fakeClient) NetworkInspectWithRaw(ctx context.Context, networkID string, opts types.NetworkInspectOptions) (types.NetworkResource, []byte, error) { + if c.networkInspectFunc != nil { + return c.networkInspectFunc(ctx, networkID, opts) + } return types.NetworkResource{}, nil, nil } + +func (c *fakeClient) NetworksPrune(ctx context.Context, pruneFilter filters.Args) (types.NetworksPruneReport, error) { + if c.networkPruneFunc != nil { + return c.networkPruneFunc(ctx, pruneFilter) + } + return types.NetworksPruneReport{}, nil +} diff --git a/cli/command/network/prune.go b/cli/command/network/prune.go index f097878ac8..5f06770327 100644 --- a/cli/command/network/prune.go +++ b/cli/command/network/prune.go @@ -49,8 +49,10 @@ Are you sure you want to continue?` func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) (output string, err error) { pruneFilters := command.PruneFilters(dockerCli, options.filter.Value()) - if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { - return "", nil + if !options.force { + if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil { + return "", err + } } report, err := dockerCli.Client().NetworksPrune(ctx, pruneFilters) diff --git a/cli/command/network/prune_test.go b/cli/command/network/prune_test.go new file mode 100644 index 0000000000..9cf62a38e0 --- /dev/null +++ b/cli/command/network/prune_test.go @@ -0,0 +1,29 @@ +package network + +import ( + "context" + "testing" + + "github.com/docker/cli/cli/command" + "github.com/docker/cli/internal/test" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/filters" + "github.com/pkg/errors" + "gotest.tools/v3/assert" +) + +func TestNetworkPrunePromptTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{ + networkPruneFunc: func(ctx context.Context, pruneFilters filters.Args) (types.NetworksPruneReport, error) { + return types.NetworksPruneReport{}, errors.New("fakeClient networkPruneFunc should not be called") + }, + }) + cmd := NewPruneCommand(cli) + test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, command.ErrPromptTerminated) + }) +} diff --git a/cli/command/network/remove.go b/cli/command/network/remove.go index cefd06c065..b584a4c37c 100644 --- a/cli/command/network/remove.go +++ b/cli/command/network/remove.go @@ -46,10 +46,15 @@ func runRemove(ctx context.Context, dockerCli command.Cli, networks []string, op status := 0 for _, name := range networks { - if nw, _, err := client.NetworkInspectWithRaw(ctx, name, types.NetworkInspectOptions{}); err == nil && - nw.Ingress && - !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), ingressWarning) { - continue + nw, _, err := client.NetworkInspectWithRaw(ctx, name, types.NetworkInspectOptions{}) + if err == nil && nw.Ingress { + r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), ingressWarning) + if err != nil { + return err + } + if !r { + continue + } } if err := client.NetworkRemove(ctx, name); err != nil { if opts.force && errdefs.IsNotFound(err) { diff --git a/cli/command/network/remove_test.go b/cli/command/network/remove_test.go index 98ea2592f6..5ad2a61df9 100644 --- a/cli/command/network/remove_test.go +++ b/cli/command/network/remove_test.go @@ -5,7 +5,9 @@ import ( "io" "testing" + "github.com/docker/cli/cli/command" "github.com/docker/cli/internal/test" + "github.com/docker/docker/api/types" "github.com/docker/docker/errdefs" "github.com/pkg/errors" "gotest.tools/v3/assert" @@ -94,3 +96,27 @@ func TestNetworkRemoveForce(t *testing.T) { }) } } + +func TestNetworkRemovePromptTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{ + networkRemoveFunc: func(ctx context.Context, networkID string) error { + return errors.New("fakeClient networkRemoveFunc should not be called") + }, + networkInspectFunc: func(ctx context.Context, networkID string, options types.NetworkInspectOptions) (types.NetworkResource, []byte, error) { + return types.NetworkResource{ + ID: "existing-network", + Name: "existing-network", + Ingress: true, + }, nil, nil + }, + }) + cmd := newRemoveCommand(cli) + cmd.SetArgs([]string{"existing-network"}) + test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, command.ErrPromptTerminated) + }) +} diff --git a/cli/command/plugin/client_test.go b/cli/command/plugin/client_test.go index c2bd136d0c..8eaf06896d 100644 --- a/cli/command/plugin/client_test.go +++ b/cli/command/plugin/client_test.go @@ -19,6 +19,7 @@ type fakeClient struct { pluginInstallFunc func(name string, options types.PluginInstallOptions) (io.ReadCloser, error) pluginListFunc func(filter filters.Args) (types.PluginsListResponse, error) pluginInspectFunc func(name string) (*types.Plugin, []byte, error) + pluginUpgradeFunc func(name string, options types.PluginInstallOptions) (io.ReadCloser, error) } func (c *fakeClient) PluginCreate(_ context.Context, createContext io.Reader, createOptions types.PluginCreateOptions) error { @@ -75,3 +76,10 @@ func (c *fakeClient) PluginInspectWithRaw(_ context.Context, name string) (*type func (c *fakeClient) Info(context.Context) (system.Info, error) { return system.Info{}, nil } + +func (c *fakeClient) PluginUpgrade(ctx context.Context, name string, options types.PluginInstallOptions) (io.ReadCloser, error) { + if c.pluginUpgradeFunc != nil { + return c.pluginUpgradeFunc(name, options) + } + return nil, nil +} diff --git a/cli/command/plugin/install.go b/cli/command/plugin/install.go index beea9af86d..f1aa9d8725 100644 --- a/cli/command/plugin/install.go +++ b/cli/command/plugin/install.go @@ -142,6 +142,7 @@ func acceptPrivileges(dockerCli command.Cli, name string) func(privileges types. for _, privilege := range privileges { fmt.Fprintf(dockerCli.Out(), " - %s: %v\n", privilege.Name, privilege.Value) } - return command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), "Do you grant the above permissions?"), nil + ctx := context.TODO() + return command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), "Do you grant the above permissions?") } } diff --git a/cli/command/plugin/testdata/plugin-upgrade-terminate.golden b/cli/command/plugin/testdata/plugin-upgrade-terminate.golden new file mode 100644 index 0000000000..511aad36cc --- /dev/null +++ b/cli/command/plugin/testdata/plugin-upgrade-terminate.golden @@ -0,0 +1,2 @@ +Upgrading plugin foo/bar from localhost:5000/foo/bar:v0.1.0 to localhost:5000/foo/bar:v1.0.0 +Plugin images do not match, are you sure? [y/N] \ No newline at end of file diff --git a/cli/command/plugin/upgrade.go b/cli/command/plugin/upgrade.go index 5134f3642d..fc7abb4127 100644 --- a/cli/command/plugin/upgrade.go +++ b/cli/command/plugin/upgrade.go @@ -63,7 +63,10 @@ func runUpgrade(ctx context.Context, dockerCli command.Cli, opts pluginOptions) fmt.Fprintf(dockerCli.Out(), "Upgrading plugin %s from %s to %s\n", p.Name, reference.FamiliarString(old), reference.FamiliarString(remote)) if !opts.skipRemoteCheck && remote.String() != old.String() { - if !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), "Plugin images do not match, are you sure?") { + if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), "Plugin images do not match, are you sure?"); !r || err != nil { + if err != nil { + return errors.Wrap(err, "canceling upgrade request") + } return errors.New("canceling upgrade request") } } diff --git a/cli/command/plugin/upgrade_test.go b/cli/command/plugin/upgrade_test.go new file mode 100644 index 0000000000..6511e4364c --- /dev/null +++ b/cli/command/plugin/upgrade_test.go @@ -0,0 +1,42 @@ +package plugin + +import ( + "context" + "io" + "testing" + + "github.com/docker/cli/cli/command" + "github.com/docker/cli/internal/test" + "github.com/docker/docker/api/types" + "github.com/pkg/errors" + "gotest.tools/v3/assert" + "gotest.tools/v3/golden" +) + +func TestUpgradePromptTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{ + pluginUpgradeFunc: func(name string, options types.PluginInstallOptions) (io.ReadCloser, error) { + return nil, errors.New("should not be called") + }, + pluginInspectFunc: func(name string) (*types.Plugin, []byte, error) { + return &types.Plugin{ + ID: "5724e2c8652da337ab2eedd19fc6fc0ec908e4bd907c7421bf6a8dfc70c4c078", + Name: "foo/bar", + Enabled: false, + PluginReference: "localhost:5000/foo/bar:v0.1.0", + }, []byte{}, nil + }, + }) + cmd := newUpgradeCommand(cli) + // need to set a remote address that does not match the plugin + // reference sent by the `pluginInspectFunc` + cmd.SetArgs([]string{"foo/bar", "localhost:5000/foo/bar:v1.0.0"}) + test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, command.ErrPromptTerminated) + }) + golden.Assert(t, cli.OutBuffer().String(), "plugin-upgrade-terminate.golden") +} diff --git a/cli/command/system/client_test.go b/cli/command/system/client_test.go index a275426fe1..01dc4b3d25 100644 --- a/cli/command/system/client_test.go +++ b/cli/command/system/client_test.go @@ -5,15 +5,18 @@ import ( "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/events" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/client" ) type fakeClient struct { client.Client - version string - serverVersion func(ctx context.Context) (types.Version, error) - eventsFn func(context.Context, types.EventsOptions) (<-chan events.Message, <-chan error) + version string + serverVersion func(ctx context.Context) (types.Version, error) + eventsFn func(context.Context, types.EventsOptions) (<-chan events.Message, <-chan error) + containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) + networkPruneFunc func(ctx context.Context, pruneFilter filters.Args) (types.NetworksPruneReport, error) } func (cli *fakeClient) ServerVersion(ctx context.Context) (types.Version, error) { @@ -27,3 +30,17 @@ func (cli *fakeClient) ClientVersion() string { func (cli *fakeClient) Events(ctx context.Context, opts types.EventsOptions) (<-chan events.Message, <-chan error) { return cli.eventsFn(ctx, opts) } + +func (cli *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) { + if cli.containerPruneFunc != nil { + return cli.containerPruneFunc(ctx, pruneFilters) + } + return types.ContainersPruneReport{}, nil +} + +func (cli *fakeClient) NetworksPrune(ctx context.Context, pruneFilter filters.Args) (types.NetworksPruneReport, error) { + if cli.networkPruneFunc != nil { + return cli.networkPruneFunc(ctx, pruneFilter) + } + return types.NetworksPruneReport{}, nil +} diff --git a/cli/command/system/prune.go b/cli/command/system/prune.go index d91c450bf6..b3200ecfad 100644 --- a/cli/command/system/prune.go +++ b/cli/command/system/prune.go @@ -74,8 +74,10 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) if options.pruneVolumes && options.filter.Value().Contains("until") { return fmt.Errorf(`ERROR: The "until" filter is not supported with "--volumes"`) } - if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), confirmationMessage(dockerCli, options)) { - return nil + if !options.force { + if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), confirmationMessage(dockerCli, options)); !r || err != nil { + return err + } } pruneFuncs := []func(ctx context.Context, dockerCli command.Cli, all bool, filter opts.FilterOpt) (uint64, string, error){ container.RunPrune, diff --git a/cli/command/system/prune_test.go b/cli/command/system/prune_test.go index ff0f7ae699..30cfb727a2 100644 --- a/cli/command/system/prune_test.go +++ b/cli/command/system/prune_test.go @@ -1,10 +1,15 @@ package system import ( + "context" "testing" + "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/config/configfile" "github.com/docker/cli/internal/test" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/filters" + "github.com/pkg/errors" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" ) @@ -49,3 +54,23 @@ func TestPrunePromptFilters(t *testing.T) { Are you sure you want to continue? [y/N] ` assert.Check(t, is.Equal(expected, cli.OutBuffer().String())) } + +func TestSystemPrunePromptTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{ + containerPruneFunc: func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) { + return types.ContainersPruneReport{}, errors.New("fakeClient containerPruneFunc should not be called") + }, + networkPruneFunc: func(ctx context.Context, pruneFilters filters.Args) (types.NetworksPruneReport, error) { + return types.NetworksPruneReport{}, errors.New("fakeClient networkPruneFunc should not be called") + }, + }) + + cmd := newPruneCommand(cli) + test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, command.ErrPromptTerminated) + }) +} diff --git a/cli/command/trust/revoke.go b/cli/command/trust/revoke.go index 0c52aaeebd..ac391a69c2 100644 --- a/cli/command/trust/revoke.go +++ b/cli/command/trust/revoke.go @@ -3,7 +3,6 @@ package trust import ( "context" "fmt" - "os" "github.com/docker/cli/cli" "github.com/docker/cli/cli/command" @@ -44,7 +43,11 @@ func revokeTrust(ctx context.Context, dockerCLI command.Cli, remote string, opti return fmt.Errorf("cannot use a digest reference for IMAGE:TAG") } if imgRefAndAuth.Tag() == "" && !options.forceYes { - deleteRemote := command.PromptForConfirmation(os.Stdin, dockerCLI.Out(), fmt.Sprintf("Please confirm you would like to delete all signature data for %s?", remote)) + deleteRemote, err := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), fmt.Sprintf("Please confirm you would like to delete all signature data for %s?", remote)) + if err != nil { + fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n") + return errors.Wrap(err, "aborting action") + } if !deleteRemote { fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n") return nil diff --git a/cli/command/trust/revoke_test.go b/cli/command/trust/revoke_test.go index ca20189adb..a48e9898d5 100644 --- a/cli/command/trust/revoke_test.go +++ b/cli/command/trust/revoke_test.go @@ -1,9 +1,11 @@ package trust import ( + "context" "io" "testing" + "github.com/docker/cli/cli/command" "github.com/docker/cli/cli/trust" "github.com/docker/cli/internal/test" "github.com/docker/cli/internal/test/notary" @@ -12,6 +14,7 @@ import ( "github.com/theupdateframework/notary/trustpinning" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" + "gotest.tools/v3/golden" ) func TestTrustRevokeCommandErrors(t *testing.T) { @@ -148,3 +151,18 @@ func TestGetSignableRolesForTargetAndRemoveError(t *testing.T) { err = getSignableRolesForTargetAndRemove(target, notaryRepo) assert.Error(t, err, "client is offline") } + +func TestRevokeTrustPromptTermination(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{}) + cmd := newRevokeCommand(cli) + cmd.SetArgs([]string{"example/trust-demo"}) + test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) { + t.Helper() + assert.ErrorIs(t, err, command.ErrPromptTerminated) + }) + assert.Equal(t, cli.ErrBuffer().String(), "") + golden.Assert(t, cli.OutBuffer().String(), "trust-revoke-prompt-termination.golden") +} diff --git a/cli/command/trust/signer_remove.go b/cli/command/trust/signer_remove.go index bef3d3b271..06deaa7e0d 100644 --- a/cli/command/trust/signer_remove.go +++ b/cli/command/trust/signer_remove.go @@ -3,7 +3,6 @@ package trust import ( "context" "fmt" - "os" "strings" "github.com/docker/cli/cli" @@ -76,6 +75,22 @@ func isLastSignerForReleases(roleWithSig data.Role, allRoles []client.RoleWithSi return counter < releasesRoleWithSigs.Threshold, nil } +func maybePromptForSignerRemoval(ctx context.Context, dockerCLI command.Cli, repoName, signerName string, isLastSigner, forceYes bool) (bool, error) { + if isLastSigner && !forceYes { + message := fmt.Sprintf("The signer \"%s\" signed the last released version of %s. "+ + "Removing this signer will make %s unpullable. "+ + "Are you sure you want to continue?", + signerName, repoName, repoName, + ) + removeSigner, err := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), message) + if err != nil { + return false, err + } + return removeSigner, nil + } + return false, nil +} + // removeSingleSigner attempts to remove a single signer and returns whether signer removal happened. // The signer not being removed doesn't necessarily raise an error e.g. user choosing "No" when prompted for confirmation. func removeSingleSigner(ctx context.Context, dockerCLI command.Cli, repoName, signerName string, forceYes bool) (bool, error) { @@ -110,28 +125,26 @@ func removeSingleSigner(ctx context.Context, dockerCLI command.Cli, repoName, si if err != nil { return false, err } - if ok, err := isLastSignerForReleases(role, allRoles); ok && !forceYes { - removeSigner := command.PromptForConfirmation(os.Stdin, dockerCLI.Out(), fmt.Sprintf("The signer \"%s\" signed the last released version of %s. "+ - "Removing this signer will make %s unpullable. "+ - "Are you sure you want to continue?", - signerName, repoName, repoName, - )) - if !removeSigner { - fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n") - return false, nil - } - } else if err != nil { - return false, err - } - if err = notaryRepo.RemoveDelegationKeys(releasesRoleTUFName, role.KeyIDs); err != nil { - return false, err - } - if err = notaryRepo.RemoveDelegationRole(signerDelegation); err != nil { + isLastSigner, err := isLastSignerForReleases(role, allRoles) + if err != nil { return false, err } - if err = notaryRepo.Publish(); err != nil { + ok, err := maybePromptForSignerRemoval(ctx, dockerCLI, repoName, signerName, isLastSigner, forceYes) + if err != nil || !ok { + fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n") + return false, err + } + + if err := notaryRepo.RemoveDelegationKeys(releasesRoleTUFName, role.KeyIDs); err != nil { + return false, err + } + if err := notaryRepo.RemoveDelegationRole(signerDelegation); err != nil { + return false, err + } + + if err := notaryRepo.Publish(); err != nil { return false, err } diff --git a/cli/command/trust/signer_remove_test.go b/cli/command/trust/signer_remove_test.go index 05fa679715..9b54bcaf2f 100644 --- a/cli/command/trust/signer_remove_test.go +++ b/cli/command/trust/signer_remove_test.go @@ -111,7 +111,8 @@ func TestIsLastSignerForReleases(t *testing.T) { releaserole.Name = releasesRoleTUFName releaserole.Threshold = 1 allrole := []client.RoleWithSignatures{releaserole} - lastsigner, _ := isLastSignerForReleases(role, allrole) + lastsigner, err := isLastSignerForReleases(role, allrole) + assert.Error(t, err, "all signed tags are currently revoked, use docker trust sign to fix") assert.Check(t, is.Equal(false, lastsigner)) role.KeyIDs = []string{"deadbeef"} @@ -120,13 +121,15 @@ func TestIsLastSignerForReleases(t *testing.T) { releaserole.Signatures = []data.Signature{sig} releaserole.Threshold = 1 allrole = []client.RoleWithSignatures{releaserole} - lastsigner, _ = isLastSignerForReleases(role, allrole) + lastsigner, err = isLastSignerForReleases(role, allrole) + assert.NilError(t, err) assert.Check(t, is.Equal(true, lastsigner)) sig.KeyID = "8badf00d" releaserole.Signatures = []data.Signature{sig} releaserole.Threshold = 1 allrole = []client.RoleWithSignatures{releaserole} - lastsigner, _ = isLastSignerForReleases(role, allrole) + lastsigner, err = isLastSignerForReleases(role, allrole) + assert.NilError(t, err) assert.Check(t, is.Equal(false, lastsigner)) } diff --git a/cli/command/trust/testdata/trust-revoke-prompt-termination.golden b/cli/command/trust/testdata/trust-revoke-prompt-termination.golden new file mode 100644 index 0000000000..bf854da789 --- /dev/null +++ b/cli/command/trust/testdata/trust-revoke-prompt-termination.golden @@ -0,0 +1,2 @@ +Please confirm you would like to delete all signature data for example/trust-demo? [y/N] +Aborting action. diff --git a/cli/command/utils.go b/cli/command/utils.go index 5c83efc986..df1f876faf 100644 --- a/cli/command/utils.go +++ b/cli/command/utils.go @@ -5,12 +5,15 @@ package command import ( "bufio" + "context" "fmt" "io" "os" + "os/signal" "path/filepath" "runtime" "strings" + "syscall" "github.com/docker/cli/cli/streams" "github.com/docker/docker/api/types/filters" @@ -72,12 +75,21 @@ func PrettyPrint(i any) string { } } -// PromptForConfirmation requests and checks confirmation from user. -// This will display the provided message followed by ' [y/N] '. If -// the user input 'y' or 'Y' it returns true other false. If no -// message is provided "Are you sure you want to proceed? [y/N] " -// will be used instead. -func PromptForConfirmation(ins io.Reader, outs io.Writer, message string) bool { +type PromptError error + +var ErrPromptTerminated = PromptError(errors.New("prompt terminated")) + +// PromptForConfirmation requests and checks confirmation from the user. +// This will display the provided message followed by ' [y/N] '. If the user +// input 'y' or 'Y' it returns true otherwise false. If no message is provided, +// "Are you sure you want to proceed? [y/N] " will be used instead. +// +// If the user terminates the CLI with SIGINT or SIGTERM while the prompt is +// active, the prompt will return false with an ErrPromptTerminated error. +// When the prompt returns an error, the caller should propagate the error up +// the stack and close the io.Reader used for the prompt which will prevent the +// background goroutine from blocking indefinitely. +func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, message string) (bool, error) { if message == "" { message = "Are you sure you want to proceed?" } @@ -90,9 +102,31 @@ func PromptForConfirmation(ins io.Reader, outs io.Writer, message string) bool { ins = streams.NewIn(os.Stdin) } - reader := bufio.NewReader(ins) - answer, _, _ := reader.ReadLine() - return strings.ToLower(string(answer)) == "y" + result := make(chan bool) + + // Catch the termination signal and exit the prompt gracefully. + // The caller is responsible for properly handling the termination. + notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM) + defer notifyCancel() + + go func() { + var res bool + scanner := bufio.NewScanner(ins) + if scanner.Scan() { + answer := strings.TrimSpace(scanner.Text()) + if strings.EqualFold(answer, "y") { + res = true + } + } + result <- res + }() + + select { + case <-notifyCtx.Done(): + return false, ErrPromptTerminated + case r := <-result: + return r, nil + } } // PruneFilters returns consolidated prune filters obtained from config.json and cli diff --git a/cli/command/utils_test.go b/cli/command/utils_test.go index 5ca5d0b3bb..869b33b644 100644 --- a/cli/command/utils_test.go +++ b/cli/command/utils_test.go @@ -1,36 +1,46 @@ -package command +package command_test import ( + "bufio" + "bytes" + "context" + "fmt" + "io" "os" "path/filepath" + "strings" + "syscall" "testing" + "time" + "github.com/docker/cli/cli/command" + "github.com/docker/cli/internal/test" "github.com/pkg/errors" "gotest.tools/v3/assert" ) func TestStringSliceReplaceAt(t *testing.T) { - out, ok := StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, -1) + out, ok := command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, -1) assert.Assert(t, ok) assert.DeepEqual(t, []string{"abc", "baz", "bax"}, out) - out, ok = StringSliceReplaceAt([]string{"foo"}, []string{"foo", "bar"}, []string{"baz"}, -1) + out, ok = command.StringSliceReplaceAt([]string{"foo"}, []string{"foo", "bar"}, []string{"baz"}, -1) assert.Assert(t, !ok) assert.DeepEqual(t, []string{"foo"}, out) - out, ok = StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0) + out, ok = command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0) assert.Assert(t, !ok) assert.DeepEqual(t, []string{"abc", "foo", "bar", "bax"}, out) - out, ok = StringSliceReplaceAt([]string{"foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0) + out, ok = command.StringSliceReplaceAt([]string{"foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0) assert.Assert(t, ok) assert.DeepEqual(t, []string{"baz", "bax"}, out) - out, ok = StringSliceReplaceAt([]string{"abc", "foo", "bar", "baz"}, []string{"foo", "bar"}, nil, -1) + out, ok = command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "baz"}, []string{"foo", "bar"}, nil, -1) assert.Assert(t, ok) assert.DeepEqual(t, []string{"abc", "baz"}, out) - out, ok = StringSliceReplaceAt([]string{"foo"}, nil, []string{"baz"}, -1) + out, ok = command.StringSliceReplaceAt([]string{"foo"}, nil, []string{"baz"}, -1) assert.Assert(t, !ok) assert.DeepEqual(t, []string{"foo"}, out) } @@ -59,7 +69,7 @@ func TestValidateOutputPath(t *testing.T) { for _, testcase := range testcases { t.Run(testcase.path, func(t *testing.T) { - err := ValidateOutputPath(testcase.path) + err := command.ValidateOutputPath(testcase.path) if testcase.err == nil { assert.NilError(t, err) } else { @@ -68,3 +78,177 @@ func TestValidateOutputPath(t *testing.T) { }) } } + +func TestPromptForConfirmation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + type promptResult struct { + result bool + err error + } + + buf := new(bytes.Buffer) + bufioWriter := bufio.NewWriter(buf) + + var ( + promptWriter *io.PipeWriter + promptReader *io.PipeReader + ) + + defer func() { + if promptWriter != nil { + promptWriter.Close() + } + if promptReader != nil { + promptReader.Close() + } + }() + + for _, tc := range []struct { + desc string + f func(*testing.T, context.Context, chan promptResult) + }{ + {"SIGINT", func(t *testing.T, ctx context.Context, c chan promptResult) { + t.Helper() + + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + + select { + case <-ctx.Done(): + t.Fatal("PromptForConfirmation did not return after SIGINT") + case r := <-c: + assert.Check(t, !r.result) + assert.ErrorContains(t, r.err, "prompt terminated") + } + }}, + {"no", func(t *testing.T, ctx context.Context, c chan promptResult) { + t.Helper() + + _, err := fmt.Fprint(promptWriter, "n\n") + assert.NilError(t, err) + + select { + case <-ctx.Done(): + t.Fatal("PromptForConfirmation did not return after user input `n`") + case r := <-c: + assert.Check(t, !r.result) + assert.NilError(t, r.err) + } + }}, + {"yes", func(t *testing.T, ctx context.Context, c chan promptResult) { + t.Helper() + + _, err := fmt.Fprint(promptWriter, "y\n") + assert.NilError(t, err) + + select { + case <-ctx.Done(): + t.Fatal("PromptForConfirmation did not return after user input `y`") + case r := <-c: + assert.Check(t, r.result) + assert.NilError(t, r.err) + } + }}, + {"any", func(t *testing.T, ctx context.Context, c chan promptResult) { + t.Helper() + + _, err := fmt.Fprint(promptWriter, "a\n") + assert.NilError(t, err) + + select { + case <-ctx.Done(): + t.Fatal("PromptForConfirmation did not return after user input `a`") + case r := <-c: + assert.Check(t, !r.result) + assert.NilError(t, r.err) + } + }}, + {"with space", func(t *testing.T, ctx context.Context, c chan promptResult) { + t.Helper() + + _, err := fmt.Fprint(promptWriter, " y\n") + assert.NilError(t, err) + + select { + case <-ctx.Done(): + t.Fatal("PromptForConfirmation did not return after user input ` y`") + case r := <-c: + assert.Check(t, r.result) + assert.NilError(t, r.err) + } + }}, + {"reader closed", func(t *testing.T, ctx context.Context, c chan promptResult) { + t.Helper() + + assert.NilError(t, promptReader.Close()) + + select { + case <-ctx.Done(): + t.Fatal("PromptForConfirmation did not return after promptReader was closed") + case r := <-c: + assert.Check(t, !r.result) + assert.NilError(t, r.err) + } + }}, + } { + t.Run("case="+tc.desc, func(t *testing.T) { + buf.Reset() + promptReader, promptWriter = io.Pipe() + + wroteHook := make(chan struct{}, 1) + defer close(wroteHook) + promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) { + wroteHook <- struct{}{} + }) + + result := make(chan promptResult, 1) + defer close(result) + go func() { + r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "") + result <- promptResult{r, err} + }() + + // wait for the Prompt to write to the buffer + pollForPromptOutput(ctx, t, wroteHook) + drainChannel(ctx, wroteHook) + + assert.NilError(t, bufioWriter.Flush()) + assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]") + + resultCtx, resultCancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer resultCancel() + + tc.f(t, resultCtx, result) + }) + } +} + +func drainChannel(ctx context.Context, ch <-chan struct{}) { + go func() { + for { + select { + case <-ctx.Done(): + return + case <-ch: + } + } + }() +} + +func pollForPromptOutput(ctx context.Context, t *testing.T, wroteHook <-chan struct{}) { + t.Helper() + + ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + for { + select { + case <-ctx.Done(): + t.Fatal("Prompt output was not written to before ctx was cancelled") + return + case <-wroteHook: + return + } + } +} diff --git a/cli/command/volume/prune.go b/cli/command/volume/prune.go index 0756f8fbb7..eff271665b 100644 --- a/cli/command/volume/prune.go +++ b/cli/command/volume/prune.go @@ -80,8 +80,10 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) // API < v1.42 removes all volumes (anonymous and named) by default. warning = allVolumesWarning } - if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { - return 0, "", errdefs.Cancelled(errors.New("user cancelled operation")) + if !options.force { + if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil { + return 0, "", errdefs.Cancelled(errors.New("user cancelled operation")) + } } report, err := dockerCli.Client().VolumesPrune(ctx, pruneFilters) diff --git a/cli/command/volume/prune_test.go b/cli/command/volume/prune_test.go index ec0b188894..03fbe9dcd0 100644 --- a/cli/command/volume/prune_test.go +++ b/cli/command/volume/prune_test.go @@ -1,6 +1,7 @@ package volume import ( + "context" "fmt" "io" "runtime" @@ -183,3 +184,19 @@ func simplePruneFunc(filters.Args) (types.VolumesPruneReport, error) { SpaceReclaimed: 2000, }, nil } + +func TestVolumePrunePromptTerminate(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + cli := test.NewFakeCli(&fakeClient{ + volumePruneFunc: func(filter filters.Args) (types.VolumesPruneReport, error) { + return types.VolumesPruneReport{}, errors.New("fakeClient volumePruneFunc should not be called") + }, + }) + + cmd := NewPruneCommand(cli) + test.TerminatePrompt(ctx, t, cmd, cli, nil) + + golden.Assert(t, cli.OutBuffer().String(), "volume-prune-terminate.golden") +} diff --git a/cli/command/volume/testdata/volume-prune-terminate.golden b/cli/command/volume/testdata/volume-prune-terminate.golden new file mode 100644 index 0000000000..8e918f9c9e --- /dev/null +++ b/cli/command/volume/testdata/volume-prune-terminate.golden @@ -0,0 +1,2 @@ +WARNING! This will remove anonymous local volumes not used by at least one container. +Are you sure you want to continue? [y/N] diff --git a/internal/test/cmd.go b/internal/test/cmd.go new file mode 100644 index 0000000000..562542e6f8 --- /dev/null +++ b/internal/test/cmd.go @@ -0,0 +1,83 @@ +package test + +import ( + "context" + "os" + "syscall" + "testing" + "time" + + "github.com/docker/cli/cli/streams" + "github.com/spf13/cobra" + "gotest.tools/v3/assert" +) + +func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli *FakeCli, assertFunc func(*testing.T, error)) { + t.Helper() + + errChan := make(chan error) + defer close(errChan) + + // wrap the out stream to detect when the prompt is ready + writerHookChan := make(chan struct{}) + defer close(writerHookChan) + + outStream := streams.NewOut(NewWriterWithHook(cli.OutBuffer(), func(p []byte) { + writerHookChan <- struct{}{} + })) + cli.SetOut(outStream) + + r, _, err := os.Pipe() + assert.NilError(t, err) + cli.SetIn(streams.NewIn(r)) + + go func() { + errChan <- cmd.ExecuteContext(ctx) + }() + + writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer writeCancel() + + // wait for the prompt to be ready + select { + case <-writeCtx.Done(): + t.Fatalf("command %s did not write prompt to stdout", cmd.Name()) + case <-writerHookChan: + // drain the channel for future buffer writes + go func() { + for { + select { + case <-ctx.Done(): + return + case <-writerHookChan: + } + } + }() + } + + assert.Check(t, cli.OutBuffer().Len() > 0) + + // a small delay to ensure the plugin is prompting + time.Sleep(100 * time.Microsecond) + + errCtx, errCancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer errCancel() + + // sigint and sigterm are caught by the prompt + // this allows us to gracefully exit the prompt with a 0 exit code + syscall.Kill(syscall.Getpid(), syscall.SIGINT) + + select { + case <-errCtx.Done(): + t.Logf("command stdout:\n%s\n", cli.OutBuffer().String()) + t.Logf("command stderr:\n%s\n", cli.ErrBuffer().String()) + t.Fatalf("command %s did not return after SIGINT", cmd.Name()) + case err := <-errChan: + if assertFunc != nil { + assertFunc(t, err) + } else { + assert.NilError(t, err) + assert.Equal(t, cli.ErrBuffer().String(), "") + } + } +} diff --git a/internal/test/writer.go b/internal/test/writer.go new file mode 100644 index 0000000000..c4ceede21e --- /dev/null +++ b/internal/test/writer.go @@ -0,0 +1,30 @@ +package test + +import ( + "io" +) + +// WriterWithHook is an io.Writer that calls a hook function +// after every write. +// This is useful in testing to wait for a write to complete, +// or to check what was written. +// To create a WriterWithHook use the NewWriterWithHook function. +type WriterWithHook struct { + actualWriter io.Writer + hook func([]byte) +} + +// Write writes p to the actual writer and then calls the hook function. +func (w *WriterWithHook) Write(p []byte) (n int, err error) { + defer w.hook(p) + return w.actualWriter.Write(p) +} + +var _ io.Writer = (*WriterWithHook)(nil) + +// NewWriterWithHook returns a new WriterWithHook that still writes to the actualWriter +// but also calls the hook function after every write. +// The hook function is useful for testing, or waiting for a write to complete. +func NewWriterWithHook(actualWriter io.Writer, hook func([]byte)) *WriterWithHook { + return &WriterWithHook{actualWriter: actualWriter, hook: hook} +}