diff --git a/cli/command/container/run.go b/cli/command/container/run.go index 749071b173..562e802920 100644 --- a/cli/command/container/run.go +++ b/cli/command/container/run.go @@ -119,6 +119,8 @@ func runRun(ctx context.Context, dockerCli command.Cli, flags *pflag.FlagSet, ro //nolint:gocyclo func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOptions, copts *containerOptions, containerCfg *containerConfig) error { + ctx = context.WithoutCancel(ctx) + config := containerCfg.Config stdout, stderr := dockerCli.Out(), dockerCli.Err() apiClient := dockerCli.Client() @@ -178,6 +180,9 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption detachKeys = runOpts.detachKeys } + // ctx should not be cancellable here, as this would kill the stream to the container + // and we want to keep the stream open until the process in the container exits or until + // the user forcefully terminates the CLI. closeFn, err := attachContainer(ctx, dockerCli, containerID, &errCh, config, container.AttachOptions{ Stream: true, Stdin: config.AttachStdin, diff --git a/e2e/container/run_test.go b/e2e/container/run_test.go index fa8ea72b03..d4fea81769 100644 --- a/e2e/container/run_test.go +++ b/e2e/container/run_test.go @@ -1,8 +1,10 @@ package container import ( + "bytes" "fmt" "strings" + "syscall" "testing" "time" @@ -13,6 +15,7 @@ import ( is "gotest.tools/v3/assert/cmp" "gotest.tools/v3/golden" "gotest.tools/v3/icmd" + "gotest.tools/v3/poll" "gotest.tools/v3/skip" ) @@ -221,3 +224,26 @@ func TestMountSubvolume(t *testing.T) { }) } } + +func TestProcessTermination(t *testing.T) { + var out bytes.Buffer + cmd := icmd.Command("docker", "run", "--rm", "-i", fixtures.AlpineImage, + "sh", "-c", "echo 'starting trap'; trap 'echo got signal; exit 0;' TERM; while true; do sleep 10; done") + cmd.Stdout = &out + cmd.Stderr = &out + + result := icmd.StartCmd(cmd).Assert(t, icmd.Success) + + poll.WaitOn(t, func(t poll.LogT) poll.Result { + if strings.Contains(result.Stdout(), "starting trap") { + return poll.Success() + } + return poll.Continue("waiting for process to trap signal") + }, poll.WithDelay(1*time.Second), poll.WithTimeout(5*time.Second)) + + assert.NilError(t, result.Cmd.Process.Signal(syscall.SIGTERM)) + + icmd.WaitOnCmd(time.Second*10, result).Assert(t, icmd.Expected{ + ExitCode: 0, + }) +}