test: cli force exit signal handler

Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com>
This commit is contained in:
Alano Terblanche 2024-06-20 14:32:38 +02:00
parent a4bfd8c744
commit 1322f585fe
No known key found for this signature in database
GPG Key ID: 0E8FACD1BA98DE27
2 changed files with 118 additions and 18 deletions

View File

@ -326,23 +326,23 @@ func tryPluginRun(ctx context.Context, dockerCli command.Cli, cmd *cobra.Command
return nil return nil
} }
// registerForceExitGoroutine registers a goroutine that will force exit the // forceExitAfter3TerminationSignals waits for the first termination signal
// process after 3 SIGTERM/SIGINT signals. // to be caught and the context to be marked as done, then registers a new
func registerForceExitGoroutine(ctx context.Context, w io.Writer) { // signal handler for subsequent signals. It forces the process to exit
// setup a signal handler to force exit after 3 SIGTERM/SIGINT // after 3 SIGTERM/SIGINT signals.
go func() { func forceExitAfter3TerminationSignals(ctx context.Context, w io.Writer) {
<-ctx.Done() // wait for the first signal to be caught and the context to be marked as done
sig := make(chan os.Signal, 2) <-ctx.Done()
signal.Notify(sig, platformsignals.TerminationSignals...) // register a new signal handler for subsequent signals
count := 0 sig := make(chan os.Signal, 2)
for range sig { signal.Notify(sig, platformsignals.TerminationSignals...)
count++
if count >= 2 { // once we have received a total of 3 signals we force exit the cli
_, _ = fmt.Fprint(w, "\ngot 3 SIGTERM/SIGINTs, forcefully exiting\n") for i := 0; i < 2; i++ {
os.Exit(1) <-sig
} }
} _, _ = fmt.Fprint(w, "\ngot 3 SIGTERM/SIGINTs, forcefully exiting\n")
}() os.Exit(1)
} }
//nolint:gocyclo //nolint:gocyclo
@ -406,7 +406,7 @@ func runDocker(ctx context.Context, dockerCli *command.DockerCli) error {
// This is a fallback for the case where the command does not exit // This is a fallback for the case where the command does not exit
// based on context cancellation. // based on context cancellation.
registerForceExitGoroutine(ctx, dockerCli.Err()) go forceExitAfter3TerminationSignals(ctx, dockerCli.Err())
// We've parsed global args already, so reset args to those // We've parsed global args already, so reset args to those
// which remain. // which remain.

View File

@ -1,17 +1,28 @@
package main package main
import ( import (
"bufio"
"bytes" "bytes"
"context" "context"
"fmt"
"io" "io"
"os" "os"
"os/exec"
"os/signal"
"strings"
"syscall"
"testing" "testing"
"time"
"github.com/docker/cli/cli/command" "github.com/docker/cli/cli/command"
"github.com/docker/cli/cli/debug" "github.com/docker/cli/cli/debug"
"github.com/docker/cli/cli/streams"
"github.com/docker/cli/cmd/docker/internal/signals"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp" is "gotest.tools/v3/assert/cmp"
"gotest.tools/v3/poll"
) )
func TestClientDebugEnabled(t *testing.T) { func TestClientDebugEnabled(t *testing.T) {
@ -75,3 +86,92 @@ func TestVersion(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
assert.Check(t, is.Contains(b.String(), "Docker version")) assert.Check(t, is.Contains(b.String(), "Docker version"))
} }
func TestFallbackForceExit(t *testing.T) {
longRunningCommand := cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error {
read, _, err := os.Pipe()
if err != nil {
return err
}
// wait until the parent process sends a signal to exit
_, _, err = bufio.NewReader(read).ReadLine()
return err
},
}
// This is the child process that will run the long running command
if os.Getenv("TEST_FALLBACK_FORCE_EXIT") == "1" {
fmt.Println("running long command")
ctx, cancel := signal.NotifyContext(context.Background(), signals.TerminationSignals...)
t.Cleanup(cancel)
longRunningCommand.SetErr(streams.NewOut(os.Stderr))
longRunningCommand.SetOut(streams.NewOut(os.Stdout))
go forceExitAfter3TerminationSignals(ctx, streams.NewOut(os.Stderr))
err := longRunningCommand.ExecuteContext(ctx)
if err != nil {
os.Exit(0)
}
return
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
// spawn the child process
cmd := exec.CommandContext(ctx, os.Args[0], "-test.run=TestFallbackForceExit")
cmd.Env = append(os.Environ(), "TEST_FALLBACK_FORCE_EXIT=1")
var buf strings.Builder
cmd.Stderr = &buf
cmd.Stdout = &buf
t.Cleanup(func() {
_ = cmd.Process.Kill()
})
assert.NilError(t, cmd.Start())
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if strings.Contains(buf.String(), "running long command") {
return poll.Success()
}
return poll.Continue("waiting for child process to start")
}, poll.WithTimeout(1*time.Second), poll.WithDelay(100*time.Millisecond))
for i := 0; i < 3; i++ {
cmd.Process.Signal(syscall.SIGINT)
time.Sleep(100 * time.Millisecond)
}
cmdErr := make(chan error, 1)
go func() {
cmdErr <- cmd.Wait()
}()
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if strings.Contains(buf.String(), "got 3 SIGTERM/SIGINTs, forcefully exiting") {
return poll.Success()
}
return poll.Continue("waiting for child process to exit")
},
poll.WithTimeout(1*time.Second), poll.WithDelay(100*time.Millisecond))
select {
case cmdErr := <-cmdErr:
assert.Error(t, cmdErr, "exit status 1")
exitErr, ok := cmdErr.(*exec.ExitError)
if !ok {
t.Fatalf("unexpected error type: %T", cmdErr)
}
if exitErr.Success() {
t.Fatalf("unexpected exit status: %v", exitErr)
}
case <-time.After(1 * time.Second):
t.Fatal("timed out waiting for child process to exit")
}
}