Merge pull request #4888 from Benehiko/fix-prompt-termination

fix: cli prompt termination exit code
This commit is contained in:
Bjorn Neergaard 2024-03-04 07:56:38 -07:00 committed by GitHub
commit 181575bf55
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 739 additions and 63 deletions

View File

@ -0,0 +1,20 @@
package builder
import (
"context"
"github.com/docker/docker/api/types"
"github.com/docker/docker/client"
)
type fakeClient struct {
client.Client
builderPruneFunc func(ctx context.Context, opts types.BuildCachePruneOptions) (*types.BuildCachePruneReport, error)
}
func (c *fakeClient) BuildCachePrune(ctx context.Context, opts types.BuildCachePruneOptions) (*types.BuildCachePruneReport, error) {
if c.builderPruneFunc != nil {
return c.builderPruneFunc(ctx, opts)
}
return nil, nil
}

View File

@ -66,8 +66,10 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions)
if options.all { if options.all {
warning = allCacheWarning warning = allCacheWarning
} }
if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { if !options.force {
return 0, "", nil if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil {
return 0, "", err
}
} }
report, err := dockerCli.Client().BuildCachePrune(ctx, types.BuildCachePruneOptions{ report, err := dockerCli.Client().BuildCachePrune(ctx, types.BuildCachePruneOptions{

View File

@ -0,0 +1,28 @@
package builder
import (
"context"
"errors"
"testing"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/internal/test"
"github.com/docker/docker/api/types"
"gotest.tools/v3/assert"
)
func TestBuilderPromptTermination(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{
builderPruneFunc: func(ctx context.Context, opts types.BuildCachePruneOptions) (*types.BuildCachePruneReport, error) {
return nil, errors.New("fakeClient builderPruneFunc should not be called")
},
})
cmd := NewPruneCommand(cli)
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.ErrorIs(t, err, command.ErrPromptTerminated)
})
}

View File

@ -6,6 +6,7 @@ import (
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/network" "github.com/docker/docker/api/types/network"
"github.com/docker/docker/api/types/system" "github.com/docker/docker/api/types/system"
@ -35,6 +36,7 @@ type fakeClient struct {
containerExecResizeFunc func(id string, options container.ResizeOptions) error containerExecResizeFunc func(id string, options container.ResizeOptions) error
containerRemoveFunc func(ctx context.Context, containerID string, options container.RemoveOptions) error containerRemoveFunc func(ctx context.Context, containerID string, options container.RemoveOptions) error
containerKillFunc func(ctx context.Context, containerID, signal string) error containerKillFunc func(ctx context.Context, containerID, signal string) error
containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error)
Version string Version string
} }
@ -164,3 +166,10 @@ func (f *fakeClient) ContainerKill(ctx context.Context, containerID, signal stri
} }
return nil return nil
} }
func (f *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) {
if f.containerPruneFunc != nil {
return f.containerPruneFunc(ctx, pruneFilters)
}
return types.ContainersPruneReport{}, nil
}

View File

@ -53,8 +53,10 @@ Are you sure you want to continue?`
func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) (spaceReclaimed uint64, output string, err error) { func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) (spaceReclaimed uint64, output string, err error) {
pruneFilters := command.PruneFilters(dockerCli, options.filter.Value()) pruneFilters := command.PruneFilters(dockerCli, options.filter.Value())
if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { if !options.force {
return 0, "", nil if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil {
return 0, "", err
}
} }
report, err := dockerCli.Client().ContainersPrune(ctx, pruneFilters) report, err := dockerCli.Client().ContainersPrune(ctx, pruneFilters)

View File

@ -0,0 +1,29 @@
package container
import (
"context"
"testing"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/internal/test"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/pkg/errors"
"gotest.tools/v3/assert"
)
func TestContainerPrunePromptTermination(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{
containerPruneFunc: func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) {
return types.ContainersPruneReport{}, errors.New("fakeClient containerPruneFunc should not be called")
},
})
cmd := NewPruneCommand(cli)
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.ErrorIs(t, err, command.ErrPromptTerminated)
})
}

View File

@ -67,8 +67,10 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions)
if options.all { if options.all {
warning = allImageWarning warning = allImageWarning
} }
if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { if !options.force {
return 0, "", nil if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil {
return 0, "", err
}
} }
report, err := dockerCli.Client().ImagesPrune(ctx, pruneFilters) report, err := dockerCli.Client().ImagesPrune(ctx, pruneFilters)

View File

@ -1,10 +1,12 @@
package image package image
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"testing" "testing"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/internal/test" "github.com/docker/cli/internal/test"
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/filters"
@ -101,3 +103,19 @@ func TestNewPruneCommandSuccess(t *testing.T) {
golden.Assert(t, cli.OutBuffer().String(), fmt.Sprintf("prune-command-success.%s.golden", tc.name)) golden.Assert(t, cli.OutBuffer().String(), fmt.Sprintf("prune-command-success.%s.golden", tc.name))
} }
} }
func TestPrunePromptTermination(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{
imagesPruneFunc: func(pruneFilter filters.Args) (types.ImagesPruneReport, error) {
return types.ImagesPruneReport{}, errors.New("fakeClient imagesPruneFunc should not be called")
},
})
cmd := NewPruneCommand(cli)
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.ErrorIs(t, err, command.ErrPromptTerminated)
})
}

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network" "github.com/docker/docker/api/types/network"
"github.com/docker/docker/client" "github.com/docker/docker/client"
) )
@ -15,6 +16,8 @@ type fakeClient struct {
networkDisconnectFunc func(ctx context.Context, networkID, container string, force bool) error networkDisconnectFunc func(ctx context.Context, networkID, container string, force bool) error
networkRemoveFunc func(ctx context.Context, networkID string) error networkRemoveFunc func(ctx context.Context, networkID string) error
networkListFunc func(ctx context.Context, options types.NetworkListOptions) ([]types.NetworkResource, error) networkListFunc func(ctx context.Context, options types.NetworkListOptions) ([]types.NetworkResource, error)
networkPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.NetworksPruneReport, error)
networkInspectFunc func(ctx context.Context, networkID string, options types.NetworkInspectOptions) (types.NetworkResource, []byte, error)
} }
func (c *fakeClient) NetworkCreate(ctx context.Context, name string, options types.NetworkCreate) (types.NetworkCreateResponse, error) { func (c *fakeClient) NetworkCreate(ctx context.Context, name string, options types.NetworkCreate) (types.NetworkCreateResponse, error) {
@ -52,6 +55,16 @@ func (c *fakeClient) NetworkRemove(ctx context.Context, networkID string) error
return nil return nil
} }
func (c *fakeClient) NetworkInspectWithRaw(context.Context, string, types.NetworkInspectOptions) (types.NetworkResource, []byte, error) { func (c *fakeClient) NetworkInspectWithRaw(ctx context.Context, networkID string, opts types.NetworkInspectOptions) (types.NetworkResource, []byte, error) {
if c.networkInspectFunc != nil {
return c.networkInspectFunc(ctx, networkID, opts)
}
return types.NetworkResource{}, nil, nil return types.NetworkResource{}, nil, nil
} }
func (c *fakeClient) NetworksPrune(ctx context.Context, pruneFilter filters.Args) (types.NetworksPruneReport, error) {
if c.networkPruneFunc != nil {
return c.networkPruneFunc(ctx, pruneFilter)
}
return types.NetworksPruneReport{}, nil
}

View File

@ -49,8 +49,10 @@ Are you sure you want to continue?`
func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) (output string, err error) { func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions) (output string, err error) {
pruneFilters := command.PruneFilters(dockerCli, options.filter.Value()) pruneFilters := command.PruneFilters(dockerCli, options.filter.Value())
if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { if !options.force {
return "", nil if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil {
return "", err
}
} }
report, err := dockerCli.Client().NetworksPrune(ctx, pruneFilters) report, err := dockerCli.Client().NetworksPrune(ctx, pruneFilters)

View File

@ -0,0 +1,29 @@
package network
import (
"context"
"testing"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/internal/test"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/pkg/errors"
"gotest.tools/v3/assert"
)
func TestNetworkPrunePromptTermination(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{
networkPruneFunc: func(ctx context.Context, pruneFilters filters.Args) (types.NetworksPruneReport, error) {
return types.NetworksPruneReport{}, errors.New("fakeClient networkPruneFunc should not be called")
},
})
cmd := NewPruneCommand(cli)
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.ErrorIs(t, err, command.ErrPromptTerminated)
})
}

View File

@ -46,10 +46,15 @@ func runRemove(ctx context.Context, dockerCli command.Cli, networks []string, op
status := 0 status := 0
for _, name := range networks { for _, name := range networks {
if nw, _, err := client.NetworkInspectWithRaw(ctx, name, types.NetworkInspectOptions{}); err == nil && nw, _, err := client.NetworkInspectWithRaw(ctx, name, types.NetworkInspectOptions{})
nw.Ingress && if err == nil && nw.Ingress {
!command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), ingressWarning) { r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), ingressWarning)
continue if err != nil {
return err
}
if !r {
continue
}
} }
if err := client.NetworkRemove(ctx, name); err != nil { if err := client.NetworkRemove(ctx, name); err != nil {
if opts.force && errdefs.IsNotFound(err) { if opts.force && errdefs.IsNotFound(err) {

View File

@ -5,7 +5,9 @@ import (
"io" "io"
"testing" "testing"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/internal/test" "github.com/docker/cli/internal/test"
"github.com/docker/docker/api/types"
"github.com/docker/docker/errdefs" "github.com/docker/docker/errdefs"
"github.com/pkg/errors" "github.com/pkg/errors"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
@ -94,3 +96,27 @@ func TestNetworkRemoveForce(t *testing.T) {
}) })
} }
} }
func TestNetworkRemovePromptTermination(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{
networkRemoveFunc: func(ctx context.Context, networkID string) error {
return errors.New("fakeClient networkRemoveFunc should not be called")
},
networkInspectFunc: func(ctx context.Context, networkID string, options types.NetworkInspectOptions) (types.NetworkResource, []byte, error) {
return types.NetworkResource{
ID: "existing-network",
Name: "existing-network",
Ingress: true,
}, nil, nil
},
})
cmd := newRemoveCommand(cli)
cmd.SetArgs([]string{"existing-network"})
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.ErrorIs(t, err, command.ErrPromptTerminated)
})
}

View File

@ -19,6 +19,7 @@ type fakeClient struct {
pluginInstallFunc func(name string, options types.PluginInstallOptions) (io.ReadCloser, error) pluginInstallFunc func(name string, options types.PluginInstallOptions) (io.ReadCloser, error)
pluginListFunc func(filter filters.Args) (types.PluginsListResponse, error) pluginListFunc func(filter filters.Args) (types.PluginsListResponse, error)
pluginInspectFunc func(name string) (*types.Plugin, []byte, error) pluginInspectFunc func(name string) (*types.Plugin, []byte, error)
pluginUpgradeFunc func(name string, options types.PluginInstallOptions) (io.ReadCloser, error)
} }
func (c *fakeClient) PluginCreate(_ context.Context, createContext io.Reader, createOptions types.PluginCreateOptions) error { func (c *fakeClient) PluginCreate(_ context.Context, createContext io.Reader, createOptions types.PluginCreateOptions) error {
@ -75,3 +76,10 @@ func (c *fakeClient) PluginInspectWithRaw(_ context.Context, name string) (*type
func (c *fakeClient) Info(context.Context) (system.Info, error) { func (c *fakeClient) Info(context.Context) (system.Info, error) {
return system.Info{}, nil return system.Info{}, nil
} }
func (c *fakeClient) PluginUpgrade(ctx context.Context, name string, options types.PluginInstallOptions) (io.ReadCloser, error) {
if c.pluginUpgradeFunc != nil {
return c.pluginUpgradeFunc(name, options)
}
return nil, nil
}

View File

@ -142,6 +142,7 @@ func acceptPrivileges(dockerCli command.Cli, name string) func(privileges types.
for _, privilege := range privileges { for _, privilege := range privileges {
fmt.Fprintf(dockerCli.Out(), " - %s: %v\n", privilege.Name, privilege.Value) fmt.Fprintf(dockerCli.Out(), " - %s: %v\n", privilege.Name, privilege.Value)
} }
return command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), "Do you grant the above permissions?"), nil ctx := context.TODO()
return command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), "Do you grant the above permissions?")
} }
} }

View File

@ -0,0 +1,2 @@
Upgrading plugin foo/bar from localhost:5000/foo/bar:v0.1.0 to localhost:5000/foo/bar:v1.0.0
Plugin images do not match, are you sure? [y/N]

View File

@ -63,7 +63,10 @@ func runUpgrade(ctx context.Context, dockerCli command.Cli, opts pluginOptions)
fmt.Fprintf(dockerCli.Out(), "Upgrading plugin %s from %s to %s\n", p.Name, reference.FamiliarString(old), reference.FamiliarString(remote)) fmt.Fprintf(dockerCli.Out(), "Upgrading plugin %s from %s to %s\n", p.Name, reference.FamiliarString(old), reference.FamiliarString(remote))
if !opts.skipRemoteCheck && remote.String() != old.String() { if !opts.skipRemoteCheck && remote.String() != old.String() {
if !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), "Plugin images do not match, are you sure?") { if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), "Plugin images do not match, are you sure?"); !r || err != nil {
if err != nil {
return errors.Wrap(err, "canceling upgrade request")
}
return errors.New("canceling upgrade request") return errors.New("canceling upgrade request")
} }
} }

View File

@ -0,0 +1,42 @@
package plugin
import (
"context"
"io"
"testing"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/internal/test"
"github.com/docker/docker/api/types"
"github.com/pkg/errors"
"gotest.tools/v3/assert"
"gotest.tools/v3/golden"
)
func TestUpgradePromptTermination(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{
pluginUpgradeFunc: func(name string, options types.PluginInstallOptions) (io.ReadCloser, error) {
return nil, errors.New("should not be called")
},
pluginInspectFunc: func(name string) (*types.Plugin, []byte, error) {
return &types.Plugin{
ID: "5724e2c8652da337ab2eedd19fc6fc0ec908e4bd907c7421bf6a8dfc70c4c078",
Name: "foo/bar",
Enabled: false,
PluginReference: "localhost:5000/foo/bar:v0.1.0",
}, []byte{}, nil
},
})
cmd := newUpgradeCommand(cli)
// need to set a remote address that does not match the plugin
// reference sent by the `pluginInspectFunc`
cmd.SetArgs([]string{"foo/bar", "localhost:5000/foo/bar:v1.0.0"})
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.ErrorIs(t, err, command.ErrPromptTerminated)
})
golden.Assert(t, cli.OutBuffer().String(), "plugin-upgrade-terminate.golden")
}

View File

@ -5,15 +5,18 @@ import (
"github.com/docker/docker/api/types" "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client" "github.com/docker/docker/client"
) )
type fakeClient struct { type fakeClient struct {
client.Client client.Client
version string version string
serverVersion func(ctx context.Context) (types.Version, error) serverVersion func(ctx context.Context) (types.Version, error)
eventsFn func(context.Context, types.EventsOptions) (<-chan events.Message, <-chan error) eventsFn func(context.Context, types.EventsOptions) (<-chan events.Message, <-chan error)
containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error)
networkPruneFunc func(ctx context.Context, pruneFilter filters.Args) (types.NetworksPruneReport, error)
} }
func (cli *fakeClient) ServerVersion(ctx context.Context) (types.Version, error) { func (cli *fakeClient) ServerVersion(ctx context.Context) (types.Version, error) {
@ -27,3 +30,17 @@ func (cli *fakeClient) ClientVersion() string {
func (cli *fakeClient) Events(ctx context.Context, opts types.EventsOptions) (<-chan events.Message, <-chan error) { func (cli *fakeClient) Events(ctx context.Context, opts types.EventsOptions) (<-chan events.Message, <-chan error) {
return cli.eventsFn(ctx, opts) return cli.eventsFn(ctx, opts)
} }
func (cli *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) {
if cli.containerPruneFunc != nil {
return cli.containerPruneFunc(ctx, pruneFilters)
}
return types.ContainersPruneReport{}, nil
}
func (cli *fakeClient) NetworksPrune(ctx context.Context, pruneFilter filters.Args) (types.NetworksPruneReport, error) {
if cli.networkPruneFunc != nil {
return cli.networkPruneFunc(ctx, pruneFilter)
}
return types.NetworksPruneReport{}, nil
}

View File

@ -74,8 +74,10 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions)
if options.pruneVolumes && options.filter.Value().Contains("until") { if options.pruneVolumes && options.filter.Value().Contains("until") {
return fmt.Errorf(`ERROR: The "until" filter is not supported with "--volumes"`) return fmt.Errorf(`ERROR: The "until" filter is not supported with "--volumes"`)
} }
if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), confirmationMessage(dockerCli, options)) { if !options.force {
return nil if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), confirmationMessage(dockerCli, options)); !r || err != nil {
return err
}
} }
pruneFuncs := []func(ctx context.Context, dockerCli command.Cli, all bool, filter opts.FilterOpt) (uint64, string, error){ pruneFuncs := []func(ctx context.Context, dockerCli command.Cli, all bool, filter opts.FilterOpt) (uint64, string, error){
container.RunPrune, container.RunPrune,

View File

@ -1,10 +1,15 @@
package system package system
import ( import (
"context"
"testing" "testing"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/cli/config/configfile" "github.com/docker/cli/cli/config/configfile"
"github.com/docker/cli/internal/test" "github.com/docker/cli/internal/test"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/filters"
"github.com/pkg/errors"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp" is "gotest.tools/v3/assert/cmp"
) )
@ -49,3 +54,23 @@ func TestPrunePromptFilters(t *testing.T) {
Are you sure you want to continue? [y/N] ` Are you sure you want to continue? [y/N] `
assert.Check(t, is.Equal(expected, cli.OutBuffer().String())) assert.Check(t, is.Equal(expected, cli.OutBuffer().String()))
} }
func TestSystemPrunePromptTermination(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{
containerPruneFunc: func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error) {
return types.ContainersPruneReport{}, errors.New("fakeClient containerPruneFunc should not be called")
},
networkPruneFunc: func(ctx context.Context, pruneFilters filters.Args) (types.NetworksPruneReport, error) {
return types.NetworksPruneReport{}, errors.New("fakeClient networkPruneFunc should not be called")
},
})
cmd := newPruneCommand(cli)
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.ErrorIs(t, err, command.ErrPromptTerminated)
})
}

View File

@ -3,7 +3,6 @@ package trust
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"github.com/docker/cli/cli" "github.com/docker/cli/cli"
"github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command"
@ -44,7 +43,11 @@ func revokeTrust(ctx context.Context, dockerCLI command.Cli, remote string, opti
return fmt.Errorf("cannot use a digest reference for IMAGE:TAG") return fmt.Errorf("cannot use a digest reference for IMAGE:TAG")
} }
if imgRefAndAuth.Tag() == "" && !options.forceYes { if imgRefAndAuth.Tag() == "" && !options.forceYes {
deleteRemote := command.PromptForConfirmation(os.Stdin, dockerCLI.Out(), fmt.Sprintf("Please confirm you would like to delete all signature data for %s?", remote)) deleteRemote, err := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), fmt.Sprintf("Please confirm you would like to delete all signature data for %s?", remote))
if err != nil {
fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n")
return errors.Wrap(err, "aborting action")
}
if !deleteRemote { if !deleteRemote {
fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n") fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n")
return nil return nil

View File

@ -1,9 +1,11 @@
package trust package trust
import ( import (
"context"
"io" "io"
"testing" "testing"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/cli/trust" "github.com/docker/cli/cli/trust"
"github.com/docker/cli/internal/test" "github.com/docker/cli/internal/test"
"github.com/docker/cli/internal/test/notary" "github.com/docker/cli/internal/test/notary"
@ -12,6 +14,7 @@ import (
"github.com/theupdateframework/notary/trustpinning" "github.com/theupdateframework/notary/trustpinning"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp" is "gotest.tools/v3/assert/cmp"
"gotest.tools/v3/golden"
) )
func TestTrustRevokeCommandErrors(t *testing.T) { func TestTrustRevokeCommandErrors(t *testing.T) {
@ -148,3 +151,18 @@ func TestGetSignableRolesForTargetAndRemoveError(t *testing.T) {
err = getSignableRolesForTargetAndRemove(target, notaryRepo) err = getSignableRolesForTargetAndRemove(target, notaryRepo)
assert.Error(t, err, "client is offline") assert.Error(t, err, "client is offline")
} }
func TestRevokeTrustPromptTermination(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{})
cmd := newRevokeCommand(cli)
cmd.SetArgs([]string{"example/trust-demo"})
test.TerminatePrompt(ctx, t, cmd, cli, func(t *testing.T, err error) {
t.Helper()
assert.ErrorIs(t, err, command.ErrPromptTerminated)
})
assert.Equal(t, cli.ErrBuffer().String(), "")
golden.Assert(t, cli.OutBuffer().String(), "trust-revoke-prompt-termination.golden")
}

View File

@ -3,7 +3,6 @@ package trust
import ( import (
"context" "context"
"fmt" "fmt"
"os"
"strings" "strings"
"github.com/docker/cli/cli" "github.com/docker/cli/cli"
@ -76,6 +75,22 @@ func isLastSignerForReleases(roleWithSig data.Role, allRoles []client.RoleWithSi
return counter < releasesRoleWithSigs.Threshold, nil return counter < releasesRoleWithSigs.Threshold, nil
} }
func maybePromptForSignerRemoval(ctx context.Context, dockerCLI command.Cli, repoName, signerName string, isLastSigner, forceYes bool) (bool, error) {
if isLastSigner && !forceYes {
message := fmt.Sprintf("The signer \"%s\" signed the last released version of %s. "+
"Removing this signer will make %s unpullable. "+
"Are you sure you want to continue?",
signerName, repoName, repoName,
)
removeSigner, err := command.PromptForConfirmation(ctx, dockerCLI.In(), dockerCLI.Out(), message)
if err != nil {
return false, err
}
return removeSigner, nil
}
return false, nil
}
// removeSingleSigner attempts to remove a single signer and returns whether signer removal happened. // removeSingleSigner attempts to remove a single signer and returns whether signer removal happened.
// The signer not being removed doesn't necessarily raise an error e.g. user choosing "No" when prompted for confirmation. // The signer not being removed doesn't necessarily raise an error e.g. user choosing "No" when prompted for confirmation.
func removeSingleSigner(ctx context.Context, dockerCLI command.Cli, repoName, signerName string, forceYes bool) (bool, error) { func removeSingleSigner(ctx context.Context, dockerCLI command.Cli, repoName, signerName string, forceYes bool) (bool, error) {
@ -110,28 +125,26 @@ func removeSingleSigner(ctx context.Context, dockerCLI command.Cli, repoName, si
if err != nil { if err != nil {
return false, err return false, err
} }
if ok, err := isLastSignerForReleases(role, allRoles); ok && !forceYes {
removeSigner := command.PromptForConfirmation(os.Stdin, dockerCLI.Out(), fmt.Sprintf("The signer \"%s\" signed the last released version of %s. "+
"Removing this signer will make %s unpullable. "+
"Are you sure you want to continue?",
signerName, repoName, repoName,
))
if !removeSigner { isLastSigner, err := isLastSignerForReleases(role, allRoles)
fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n") if err != nil {
return false, nil
}
} else if err != nil {
return false, err
}
if err = notaryRepo.RemoveDelegationKeys(releasesRoleTUFName, role.KeyIDs); err != nil {
return false, err
}
if err = notaryRepo.RemoveDelegationRole(signerDelegation); err != nil {
return false, err return false, err
} }
if err = notaryRepo.Publish(); err != nil { ok, err := maybePromptForSignerRemoval(ctx, dockerCLI, repoName, signerName, isLastSigner, forceYes)
if err != nil || !ok {
fmt.Fprintf(dockerCLI.Out(), "\nAborting action.\n")
return false, err
}
if err := notaryRepo.RemoveDelegationKeys(releasesRoleTUFName, role.KeyIDs); err != nil {
return false, err
}
if err := notaryRepo.RemoveDelegationRole(signerDelegation); err != nil {
return false, err
}
if err := notaryRepo.Publish(); err != nil {
return false, err return false, err
} }

View File

@ -111,7 +111,8 @@ func TestIsLastSignerForReleases(t *testing.T) {
releaserole.Name = releasesRoleTUFName releaserole.Name = releasesRoleTUFName
releaserole.Threshold = 1 releaserole.Threshold = 1
allrole := []client.RoleWithSignatures{releaserole} allrole := []client.RoleWithSignatures{releaserole}
lastsigner, _ := isLastSignerForReleases(role, allrole) lastsigner, err := isLastSignerForReleases(role, allrole)
assert.Error(t, err, "all signed tags are currently revoked, use docker trust sign to fix")
assert.Check(t, is.Equal(false, lastsigner)) assert.Check(t, is.Equal(false, lastsigner))
role.KeyIDs = []string{"deadbeef"} role.KeyIDs = []string{"deadbeef"}
@ -120,13 +121,15 @@ func TestIsLastSignerForReleases(t *testing.T) {
releaserole.Signatures = []data.Signature{sig} releaserole.Signatures = []data.Signature{sig}
releaserole.Threshold = 1 releaserole.Threshold = 1
allrole = []client.RoleWithSignatures{releaserole} allrole = []client.RoleWithSignatures{releaserole}
lastsigner, _ = isLastSignerForReleases(role, allrole) lastsigner, err = isLastSignerForReleases(role, allrole)
assert.NilError(t, err)
assert.Check(t, is.Equal(true, lastsigner)) assert.Check(t, is.Equal(true, lastsigner))
sig.KeyID = "8badf00d" sig.KeyID = "8badf00d"
releaserole.Signatures = []data.Signature{sig} releaserole.Signatures = []data.Signature{sig}
releaserole.Threshold = 1 releaserole.Threshold = 1
allrole = []client.RoleWithSignatures{releaserole} allrole = []client.RoleWithSignatures{releaserole}
lastsigner, _ = isLastSignerForReleases(role, allrole) lastsigner, err = isLastSignerForReleases(role, allrole)
assert.NilError(t, err)
assert.Check(t, is.Equal(false, lastsigner)) assert.Check(t, is.Equal(false, lastsigner))
} }

View File

@ -0,0 +1,2 @@
Please confirm you would like to delete all signature data for example/trust-demo? [y/N]
Aborting action.

View File

@ -5,12 +5,15 @@ package command
import ( import (
"bufio" "bufio"
"context"
"fmt" "fmt"
"io" "io"
"os" "os"
"os/signal"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"syscall"
"github.com/docker/cli/cli/streams" "github.com/docker/cli/cli/streams"
"github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/filters"
@ -72,12 +75,21 @@ func PrettyPrint(i any) string {
} }
} }
// PromptForConfirmation requests and checks confirmation from user. type PromptError error
// This will display the provided message followed by ' [y/N] '. If
// the user input 'y' or 'Y' it returns true other false. If no var ErrPromptTerminated = PromptError(errors.New("prompt terminated"))
// message is provided "Are you sure you want to proceed? [y/N] "
// will be used instead. // PromptForConfirmation requests and checks confirmation from the user.
func PromptForConfirmation(ins io.Reader, outs io.Writer, message string) bool { // This will display the provided message followed by ' [y/N] '. If the user
// input 'y' or 'Y' it returns true otherwise false. If no message is provided,
// "Are you sure you want to proceed? [y/N] " will be used instead.
//
// If the user terminates the CLI with SIGINT or SIGTERM while the prompt is
// active, the prompt will return false with an ErrPromptTerminated error.
// When the prompt returns an error, the caller should propagate the error up
// the stack and close the io.Reader used for the prompt which will prevent the
// background goroutine from blocking indefinitely.
func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, message string) (bool, error) {
if message == "" { if message == "" {
message = "Are you sure you want to proceed?" message = "Are you sure you want to proceed?"
} }
@ -90,9 +102,31 @@ func PromptForConfirmation(ins io.Reader, outs io.Writer, message string) bool {
ins = streams.NewIn(os.Stdin) ins = streams.NewIn(os.Stdin)
} }
reader := bufio.NewReader(ins) result := make(chan bool)
answer, _, _ := reader.ReadLine()
return strings.ToLower(string(answer)) == "y" // Catch the termination signal and exit the prompt gracefully.
// The caller is responsible for properly handling the termination.
notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
defer notifyCancel()
go func() {
var res bool
scanner := bufio.NewScanner(ins)
if scanner.Scan() {
answer := strings.TrimSpace(scanner.Text())
if strings.EqualFold(answer, "y") {
res = true
}
}
result <- res
}()
select {
case <-notifyCtx.Done():
return false, ErrPromptTerminated
case r := <-result:
return r, nil
}
} }
// PruneFilters returns consolidated prune filters obtained from config.json and cli // PruneFilters returns consolidated prune filters obtained from config.json and cli

View File

@ -1,36 +1,46 @@
package command package command_test
import ( import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"syscall"
"testing" "testing"
"time"
"github.com/docker/cli/cli/command"
"github.com/docker/cli/internal/test"
"github.com/pkg/errors" "github.com/pkg/errors"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
) )
func TestStringSliceReplaceAt(t *testing.T) { func TestStringSliceReplaceAt(t *testing.T) {
out, ok := StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, -1) out, ok := command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, -1)
assert.Assert(t, ok) assert.Assert(t, ok)
assert.DeepEqual(t, []string{"abc", "baz", "bax"}, out) assert.DeepEqual(t, []string{"abc", "baz", "bax"}, out)
out, ok = StringSliceReplaceAt([]string{"foo"}, []string{"foo", "bar"}, []string{"baz"}, -1) out, ok = command.StringSliceReplaceAt([]string{"foo"}, []string{"foo", "bar"}, []string{"baz"}, -1)
assert.Assert(t, !ok) assert.Assert(t, !ok)
assert.DeepEqual(t, []string{"foo"}, out) assert.DeepEqual(t, []string{"foo"}, out)
out, ok = StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0) out, ok = command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0)
assert.Assert(t, !ok) assert.Assert(t, !ok)
assert.DeepEqual(t, []string{"abc", "foo", "bar", "bax"}, out) assert.DeepEqual(t, []string{"abc", "foo", "bar", "bax"}, out)
out, ok = StringSliceReplaceAt([]string{"foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0) out, ok = command.StringSliceReplaceAt([]string{"foo", "bar", "bax"}, []string{"foo", "bar"}, []string{"baz"}, 0)
assert.Assert(t, ok) assert.Assert(t, ok)
assert.DeepEqual(t, []string{"baz", "bax"}, out) assert.DeepEqual(t, []string{"baz", "bax"}, out)
out, ok = StringSliceReplaceAt([]string{"abc", "foo", "bar", "baz"}, []string{"foo", "bar"}, nil, -1) out, ok = command.StringSliceReplaceAt([]string{"abc", "foo", "bar", "baz"}, []string{"foo", "bar"}, nil, -1)
assert.Assert(t, ok) assert.Assert(t, ok)
assert.DeepEqual(t, []string{"abc", "baz"}, out) assert.DeepEqual(t, []string{"abc", "baz"}, out)
out, ok = StringSliceReplaceAt([]string{"foo"}, nil, []string{"baz"}, -1) out, ok = command.StringSliceReplaceAt([]string{"foo"}, nil, []string{"baz"}, -1)
assert.Assert(t, !ok) assert.Assert(t, !ok)
assert.DeepEqual(t, []string{"foo"}, out) assert.DeepEqual(t, []string{"foo"}, out)
} }
@ -59,7 +69,7 @@ func TestValidateOutputPath(t *testing.T) {
for _, testcase := range testcases { for _, testcase := range testcases {
t.Run(testcase.path, func(t *testing.T) { t.Run(testcase.path, func(t *testing.T) {
err := ValidateOutputPath(testcase.path) err := command.ValidateOutputPath(testcase.path)
if testcase.err == nil { if testcase.err == nil {
assert.NilError(t, err) assert.NilError(t, err)
} else { } else {
@ -68,3 +78,177 @@ func TestValidateOutputPath(t *testing.T) {
}) })
} }
} }
func TestPromptForConfirmation(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
type promptResult struct {
result bool
err error
}
buf := new(bytes.Buffer)
bufioWriter := bufio.NewWriter(buf)
var (
promptWriter *io.PipeWriter
promptReader *io.PipeReader
)
defer func() {
if promptWriter != nil {
promptWriter.Close()
}
if promptReader != nil {
promptReader.Close()
}
}()
for _, tc := range []struct {
desc string
f func(*testing.T, context.Context, chan promptResult)
}{
{"SIGINT", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after SIGINT")
case r := <-c:
assert.Check(t, !r.result)
assert.ErrorContains(t, r.err, "prompt terminated")
}
}},
{"no", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()
_, err := fmt.Fprint(promptWriter, "n\n")
assert.NilError(t, err)
select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after user input `n`")
case r := <-c:
assert.Check(t, !r.result)
assert.NilError(t, r.err)
}
}},
{"yes", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()
_, err := fmt.Fprint(promptWriter, "y\n")
assert.NilError(t, err)
select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after user input `y`")
case r := <-c:
assert.Check(t, r.result)
assert.NilError(t, r.err)
}
}},
{"any", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()
_, err := fmt.Fprint(promptWriter, "a\n")
assert.NilError(t, err)
select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after user input `a`")
case r := <-c:
assert.Check(t, !r.result)
assert.NilError(t, r.err)
}
}},
{"with space", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()
_, err := fmt.Fprint(promptWriter, " y\n")
assert.NilError(t, err)
select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after user input ` y`")
case r := <-c:
assert.Check(t, r.result)
assert.NilError(t, r.err)
}
}},
{"reader closed", func(t *testing.T, ctx context.Context, c chan promptResult) {
t.Helper()
assert.NilError(t, promptReader.Close())
select {
case <-ctx.Done():
t.Fatal("PromptForConfirmation did not return after promptReader was closed")
case r := <-c:
assert.Check(t, !r.result)
assert.NilError(t, r.err)
}
}},
} {
t.Run("case="+tc.desc, func(t *testing.T) {
buf.Reset()
promptReader, promptWriter = io.Pipe()
wroteHook := make(chan struct{}, 1)
defer close(wroteHook)
promptOut := test.NewWriterWithHook(bufioWriter, func(p []byte) {
wroteHook <- struct{}{}
})
result := make(chan promptResult, 1)
defer close(result)
go func() {
r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
result <- promptResult{r, err}
}()
// wait for the Prompt to write to the buffer
pollForPromptOutput(ctx, t, wroteHook)
drainChannel(ctx, wroteHook)
assert.NilError(t, bufioWriter.Flush())
assert.Equal(t, strings.TrimSpace(buf.String()), "Are you sure you want to proceed? [y/N]")
resultCtx, resultCancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer resultCancel()
tc.f(t, resultCtx, result)
})
}
}
func drainChannel(ctx context.Context, ch <-chan struct{}) {
go func() {
for {
select {
case <-ctx.Done():
return
case <-ch:
}
}
}()
}
func pollForPromptOutput(ctx context.Context, t *testing.T, wroteHook <-chan struct{}) {
t.Helper()
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
for {
select {
case <-ctx.Done():
t.Fatal("Prompt output was not written to before ctx was cancelled")
return
case <-wroteHook:
return
}
}
}

View File

@ -80,8 +80,10 @@ func runPrune(ctx context.Context, dockerCli command.Cli, options pruneOptions)
// API < v1.42 removes all volumes (anonymous and named) by default. // API < v1.42 removes all volumes (anonymous and named) by default.
warning = allVolumesWarning warning = allVolumesWarning
} }
if !options.force && !command.PromptForConfirmation(dockerCli.In(), dockerCli.Out(), warning) { if !options.force {
return 0, "", errdefs.Cancelled(errors.New("user cancelled operation")) if r, err := command.PromptForConfirmation(ctx, dockerCli.In(), dockerCli.Out(), warning); !r || err != nil {
return 0, "", errdefs.Cancelled(errors.New("user cancelled operation"))
}
} }
report, err := dockerCli.Client().VolumesPrune(ctx, pruneFilters) report, err := dockerCli.Client().VolumesPrune(ctx, pruneFilters)

View File

@ -1,6 +1,7 @@
package volume package volume
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"runtime" "runtime"
@ -183,3 +184,19 @@ func simplePruneFunc(filters.Args) (types.VolumesPruneReport, error) {
SpaceReclaimed: 2000, SpaceReclaimed: 2000,
}, nil }, nil
} }
func TestVolumePrunePromptTerminate(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
cli := test.NewFakeCli(&fakeClient{
volumePruneFunc: func(filter filters.Args) (types.VolumesPruneReport, error) {
return types.VolumesPruneReport{}, errors.New("fakeClient volumePruneFunc should not be called")
},
})
cmd := NewPruneCommand(cli)
test.TerminatePrompt(ctx, t, cmd, cli, nil)
golden.Assert(t, cli.OutBuffer().String(), "volume-prune-terminate.golden")
}

View File

@ -0,0 +1,2 @@
WARNING! This will remove anonymous local volumes not used by at least one container.
Are you sure you want to continue? [y/N]

83
internal/test/cmd.go Normal file
View File

@ -0,0 +1,83 @@
package test
import (
"context"
"os"
"syscall"
"testing"
"time"
"github.com/docker/cli/cli/streams"
"github.com/spf13/cobra"
"gotest.tools/v3/assert"
)
func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli *FakeCli, assertFunc func(*testing.T, error)) {
t.Helper()
errChan := make(chan error)
defer close(errChan)
// wrap the out stream to detect when the prompt is ready
writerHookChan := make(chan struct{})
defer close(writerHookChan)
outStream := streams.NewOut(NewWriterWithHook(cli.OutBuffer(), func(p []byte) {
writerHookChan <- struct{}{}
}))
cli.SetOut(outStream)
r, _, err := os.Pipe()
assert.NilError(t, err)
cli.SetIn(streams.NewIn(r))
go func() {
errChan <- cmd.ExecuteContext(ctx)
}()
writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer writeCancel()
// wait for the prompt to be ready
select {
case <-writeCtx.Done():
t.Fatalf("command %s did not write prompt to stdout", cmd.Name())
case <-writerHookChan:
// drain the channel for future buffer writes
go func() {
for {
select {
case <-ctx.Done():
return
case <-writerHookChan:
}
}
}()
}
assert.Check(t, cli.OutBuffer().Len() > 0)
// a small delay to ensure the plugin is prompting
time.Sleep(100 * time.Microsecond)
errCtx, errCancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer errCancel()
// sigint and sigterm are caught by the prompt
// this allows us to gracefully exit the prompt with a 0 exit code
syscall.Kill(syscall.Getpid(), syscall.SIGINT)
select {
case <-errCtx.Done():
t.Logf("command stdout:\n%s\n", cli.OutBuffer().String())
t.Logf("command stderr:\n%s\n", cli.ErrBuffer().String())
t.Fatalf("command %s did not return after SIGINT", cmd.Name())
case err := <-errChan:
if assertFunc != nil {
assertFunc(t, err)
} else {
assert.NilError(t, err)
assert.Equal(t, cli.ErrBuffer().String(), "")
}
}
}

30
internal/test/writer.go Normal file
View File

@ -0,0 +1,30 @@
package test
import (
"io"
)
// WriterWithHook is an io.Writer that calls a hook function
// after every write.
// This is useful in testing to wait for a write to complete,
// or to check what was written.
// To create a WriterWithHook use the NewWriterWithHook function.
type WriterWithHook struct {
actualWriter io.Writer
hook func([]byte)
}
// Write writes p to the actual writer and then calls the hook function.
func (w *WriterWithHook) Write(p []byte) (n int, err error) {
defer w.hook(p)
return w.actualWriter.Write(p)
}
var _ io.Writer = (*WriterWithHook)(nil)
// NewWriterWithHook returns a new WriterWithHook that still writes to the actualWriter
// but also calls the hook function after every write.
// The hook function is useful for testing, or waiting for a write to complete.
func NewWriterWithHook(actualWriter io.Writer, hook func([]byte)) *WriterWithHook {
return &WriterWithHook{actualWriter: actualWriter, hook: hook}
}