feat: global signal handling with context cancellation

Signed-off-by: Alano Terblanche <18033717+Benehiko@users.noreply.github.com>

Author: Alano Terblanche
Date:   2024-04-08 10:11:09 +02:00
Parent: 8b924a5119
Commit: 3f0d90a2a9
GPG Key ID: 0E8FACD1BA98DE27 (no known key found for this signature in database)

9 changed files with 161 additions and 50 deletions

View File

@@ -105,7 +105,12 @@ func RunAttach(ctx context.Context, dockerCLI command.Cli, containerID string, o
 	if opts.Proxy && !c.Config.Tty {
 		sigc := notifyAllSignals()
-		go ForwardAllSignals(ctx, apiClient, containerID, sigc)
+		// since we're explicitly setting up signal handling here, and the daemon will
+		// get notified independently of the clients ctx cancellation, we use this context
+		// but without cancellation to avoid ForwardAllSignals from returning
+		// before all signals are forwarded.
+		bgCtx := context.WithoutCancel(ctx)
+		go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
 		defer signal.StopCatch(sigc)
 	}
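
For context: context.WithoutCancel (Go 1.21+) derives a context that keeps the parent's values but never observes its cancellation. A minimal standalone sketch, separate from this commit:

	package main

	import (
		"context"
		"fmt"
	)

	func main() {
		parent, cancel := context.WithCancel(context.Background())
		detached := context.WithoutCancel(parent) // requires Go 1.21+

		cancel()

		fmt.Println(parent.Err())   // context canceled
		fmt.Println(detached.Err()) // <nil>: cancellation does not propagate
	}

This is why ForwardAllSignals can keep draining sigc even after the CLI's root context is cancelled by the first termination signal.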

View File

@@ -37,6 +37,7 @@ type fakeClient struct {
 	containerRemoveFunc func(ctx context.Context, containerID string, options container.RemoveOptions) error
 	containerKillFunc func(ctx context.Context, containerID, signal string) error
 	containerPruneFunc func(ctx context.Context, pruneFilters filters.Args) (types.ContainersPruneReport, error)
+	containerAttachFunc func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error)
 	Version string
 }
@@ -173,3 +174,10 @@ func (f *fakeClient) ContainersPrune(ctx context.Context, pruneFilters filters.A
 	}
 	return types.ContainersPruneReport{}, nil
 }
+
+func (f *fakeClient) ContainerAttach(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
+	if f.containerAttachFunc != nil {
+		return f.containerAttachFunc(ctx, containerID, options)
+	}
+	return types.HijackedResponse{}, nil
+}
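
The fake follows the function-field pattern: each method checks an optional hook and falls back to a harmless default, so a test overrides only the calls it exercises. A compact sketch of the pattern with hypothetical names, not the CLI's types:

	package main

	import "fmt"

	// fake with an optional hook per method
	type fakeGreeter struct {
		greetFunc func(name string) string
	}

	func (f *fakeGreeter) Greet(name string) string {
		if f.greetFunc != nil {
			return f.greetFunc(name) // test-provided behavior
		}
		return "hello " + name // harmless default
	}

	func main() {
		plain := &fakeGreeter{}
		stubbed := &fakeGreeter{greetFunc: func(string) string { return "stubbed" }}
		fmt.Println(plain.Greet("world"), "/", stubbed.Greet("world"))
	}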

View File

@@ -150,7 +150,12 @@ func runContainer(ctx context.Context, dockerCli command.Cli, runOpts *runOption
 	}
 	if runOpts.sigProxy {
 		sigc := notifyAllSignals()
-		go ForwardAllSignals(ctx, apiClient, containerID, sigc)
+		// since we're explicitly setting up signal handling here, and the daemon will
+		// get notified independently of the clients ctx cancellation, we use this context
+		// but without cancellation to avoid ForwardAllSignals from returning
+		// before all signals are forwarded.
+		bgCtx := context.WithoutCancel(ctx)
+		go ForwardAllSignals(bgCtx, apiClient, containerID, sigc)
 		defer signal.StopCatch(sigc)
 	}

View File

@@ -5,11 +5,18 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
+	"os/signal"
+	"syscall"
 	"testing"
+	"time"

+	"github.com/creack/pty"
 	"github.com/docker/cli/cli"
+	"github.com/docker/cli/cli/streams"
 	"github.com/docker/cli/internal/test"
 	"github.com/docker/cli/internal/test/notary"
+	"github.com/docker/docker/api/types"
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/api/types/network"
 	specs "github.com/opencontainers/image-spec/specs-go/v1"
@@ -32,6 +39,68 @@ func TestRunLabel(t *testing.T) {
 	assert.NilError(t, cmd.Execute())
 }

+func TestRunAttachTermination(t *testing.T) {
+	p, tty, err := pty.Open()
+	assert.NilError(t, err)
+	defer func() {
+		_ = tty.Close()
+		_ = p.Close()
+	}()
+
+	killCh := make(chan struct{})
+	attachCh := make(chan struct{})
+	fakeCLI := test.NewFakeCli(&fakeClient{
+		createContainerFunc: func(_ *container.Config, _ *container.HostConfig, _ *network.NetworkingConfig, _ *specs.Platform, _ string) (container.CreateResponse, error) {
+			return container.CreateResponse{
+				ID: "id",
+			}, nil
+		},
+		containerKillFunc: func(ctx context.Context, containerID, signal string) error {
+			killCh <- struct{}{}
+			return nil
+		},
+		containerAttachFunc: func(ctx context.Context, containerID string, options container.AttachOptions) (types.HijackedResponse, error) {
+			server, client := net.Pipe()
+			t.Cleanup(func() {
+				_ = server.Close()
+			})
+			attachCh <- struct{}{}
+			return types.NewHijackedResponse(client, types.MediaTypeRawStream), nil
+		},
+		Version: "1.36",
+	}, func(fc *test.FakeCli) {
+		fc.SetOut(streams.NewOut(tty))
+		fc.SetIn(streams.NewIn(tty))
+	})
+
+	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM)
+	defer cancel()
+
+	assert.Equal(t, fakeCLI.In().IsTerminal(), true)
+	assert.Equal(t, fakeCLI.Out().IsTerminal(), true)
+
+	cmd := NewRunCommand(fakeCLI)
+	cmd.SetArgs([]string{"-it", "busybox"})
+	cmd.SilenceUsage = true
+	go func() {
+		assert.ErrorIs(t, cmd.ExecuteContext(ctx), context.Canceled)
+	}()
+
+	select {
+	case <-time.After(5 * time.Second):
+		t.Fatal("containerAttachFunc was not called before the 5 second timeout")
+	case <-attachCh:
+	}
+
+	assert.NilError(t, syscall.Kill(syscall.Getpid(), syscall.SIGTERM))
+	select {
+	case <-time.After(5 * time.Second):
+		cancel()
+		t.Fatal("containerKillFunc was not called before the 5 second timeout")
+	case <-killCh:
+	}
+}
+
 func TestRunCommandWithContentTrustErrors(t *testing.T) {
 	testCases := []struct {
 		name string
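
The attach fake above hands back one end of net.Pipe as the hijacked connection. A minimal sketch, independent of the CLI, of how net.Pipe provides a synchronous in-memory net.Conn pair for this kind of test double:

	package main

	import (
		"fmt"
		"net"
	)

	func main() {
		// net.Pipe returns two connected, in-memory net.Conns;
		// whatever is written to one end can be read from the other.
		server, client := net.Pipe()
		defer server.Close()
		defer client.Close()

		go func() {
			_, _ = server.Write([]byte("hello from the fake daemon"))
			_ = server.Close()
		}()

		buf := make([]byte, 64)
		n, _ := client.Read(buf)
		fmt.Println(string(buf[:n]))
	}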

View File

@@ -87,7 +87,8 @@ func RunStart(ctx context.Context, dockerCli command.Cli, opts *StartOptions) er
 		// We always use c.ID instead of container to maintain consistency during `docker start`
 		if !c.Config.Tty {
 			sigc := notifyAllSignals()
-			go ForwardAllSignals(ctx, dockerCli.Client(), c.ID, sigc)
+			bgCtx := context.WithoutCancel(ctx)
+			go ForwardAllSignals(bgCtx, dockerCli.Client(), c.ID, sigc)
 			defer signal.StopCatch(sigc)
 		}

View File

@@ -9,11 +9,9 @@ import (
 	"fmt"
 	"io"
 	"os"
-	"os/signal"
 	"path/filepath"
 	"runtime"
 	"strings"
-	"syscall"

 	"github.com/docker/cli/cli/streams"
 	"github.com/docker/docker/api/types/filters"
@@ -103,11 +101,6 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
 	result := make(chan bool)

-	// Catch the termination signal and exit the prompt gracefully.
-	// The caller is responsible for properly handling the termination.
-	notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
-	defer notifyCancel()
-
 	go func() {
 		var res bool
 		scanner := bufio.NewScanner(ins)
@@ -121,8 +114,7 @@ func PromptForConfirmation(ctx context.Context, ins io.Reader, outs io.Writer, m
 	}()

 	select {
-	case <-notifyCtx.Done():
-		// print a newline on termination
+	case <-ctx.Done():
 		_, _ = fmt.Fprintln(outs, "")
 		return false, ErrPromptTerminated
 	case r := <-result:
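
With NotifyContext moved out of PromptForConfirmation, the prompt now only watches ctx.Done(); a caller that wants SIGINT/SIGTERM to abort the prompt wraps its own context. A sketch of such a hypothetical caller, not from this commit (Unix-only because of syscall.Kill):

	package main

	import (
		"context"
		"fmt"
		"os/signal"
		"syscall"
		"time"
	)

	func main() {
		// the caller, not the prompt, decides which signals cancel it
		ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
		defer stop()

		// simulate the user pressing Ctrl-C while the prompt is waiting
		go func() {
			time.Sleep(100 * time.Millisecond)
			_ = syscall.Kill(syscall.Getpid(), syscall.SIGINT)
		}()

		<-ctx.Done() // PromptForConfirmation would now return ErrPromptTerminated
		fmt.Println("prompt cancelled:", ctx.Err())
	}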

View File

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"io"
 	"os"
+	"os/signal"
 	"path/filepath"
 	"strings"
 	"syscall"
@@ -135,6 +136,9 @@ func TestPromptForConfirmation(t *testing.T) {
 		}, promptResult{false, nil}},
 	} {
 		t.Run("case="+tc.desc, func(t *testing.T) {
+			notifyCtx, notifyCancel := signal.NotifyContext(ctx, syscall.SIGINT, syscall.SIGTERM)
+			t.Cleanup(notifyCancel)
+
 			buf.Reset()
 			promptReader, promptWriter = io.Pipe()
@@ -145,7 +149,7 @@ func TestPromptForConfirmation(t *testing.T) {
 			result := make(chan promptResult, 1)
 			go func() {
-				r, err := command.PromptForConfirmation(ctx, promptReader, promptOut, "")
+				r, err := command.PromptForConfirmation(notifyCtx, promptReader, promptOut, "")
 				result <- promptResult{r, err}
 			}()

View File

@@ -28,12 +28,20 @@ import (
 )

 func main() {
-	ctx := context.Background()
+	statusCode := dockerMain()
+	if statusCode != 0 {
+		os.Exit(statusCode)
+	}
+}
+
+func dockerMain() int {
+	ctx, cancelNotify := signal.NotifyContext(context.Background(), platformsignals.TerminationSignals...)
+	defer cancelNotify()
+
 	dockerCli, err := command.NewDockerCli(command.WithBaseContext(ctx))
 	if err != nil {
 		fmt.Fprintln(os.Stderr, err)
-		os.Exit(1)
+		return 1
 	}
 	logrus.SetOutput(dockerCli.Err())
 	otel.SetErrorHandler(debug.OTELErrorHandler)
@@ -46,16 +54,17 @@ func main() {
 			// StatusError should only be used for errors, and all errors should
 			// have a non-zero exit status, so never exit with 0
 			if sterr.StatusCode == 0 {
-				os.Exit(1)
+				return 1
 			}
-			os.Exit(sterr.StatusCode)
+			return sterr.StatusCode
 		}
 		if errdefs.IsCancelled(err) {
-			os.Exit(0)
+			return 0
 		}
 		fmt.Fprintln(dockerCli.Err(), err)
-		os.Exit(1)
+		return 1
 	}
+	return 0
 }

 func newDockerCommand(dockerCli *command.DockerCli) *cli.TopLevelCommand {
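
The main/dockerMain split matters because os.Exit terminates the process without running deferred calls; returning a status code first lets defers (such as cancelNotify) complete before main exits. A tiny sketch of the pattern, not tied to this codebase:

	package main

	import (
		"fmt"
		"os"
	)

	func run() int {
		defer fmt.Println("cleanup runs") // would be skipped if run() called os.Exit directly
		return 2
	}

	func main() {
		if code := run(); code != 0 {
			os.Exit(code) // exit only after all of run's defers have completed
		}
	}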
@@ -224,7 +233,7 @@ func setValidateArgs(dockerCli command.Cli, cmd *cobra.Command) {
 	})
 }

-func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
+func tryPluginRun(ctx context.Context, dockerCli command.Cli, cmd *cobra.Command, subcommand string, envs []string) error {
 	plugincmd, err := pluginmanager.PluginRunCommand(dockerCli, subcommand, cmd)
 	if err != nil {
 		return err
@@ -242,40 +251,56 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
 	// Background signal handling logic: block on the signals channel, and
 	// notify the plugin via the PluginServer (or signal) as appropriate.
-	const exitLimit = 3
+	const exitLimit = 2

-	signals := make(chan os.Signal, exitLimit)
-	signal.Notify(signals, platformsignals.TerminationSignals...)
+	tryTerminatePlugin := func(force bool) {
+		// If stdin is a TTY, the kernel will forward
+		// signals to the subprocess because the shared
+		// pgid makes the TTY a controlling terminal.
+		//
+		// The plugin should have it's own copy of this
+		// termination logic, and exit after 3 retries
+		// on it's own.
+		if dockerCli.Out().IsTerminal() {
+			return
+		}
+
+		// Terminate the plugin server, which will
+		// close all connections with plugin
+		// subprocesses, and signal them to exit.
+		//
+		// Repeated invocations will result in EINVAL,
+		// or EBADF; but that is fine for our purposes.
+		_ = srv.Close()
+
+		// force the process to terminate if it hasn't already
+		if force {
+			_ = plugincmd.Process.Kill()
+			_, _ = fmt.Fprint(dockerCli.Err(), "got 3 SIGTERM/SIGINTs, forcefully exiting\n")
+			os.Exit(1)
+		}
+	}
+
 	go func() {
 		retries := 0
+		force := false
+		// catch the first signal through context cancellation
+		<-ctx.Done()
+		tryTerminatePlugin(force)
+
+		// register subsequent signals
+		signals := make(chan os.Signal, exitLimit)
+		signal.Notify(signals, platformsignals.TerminationSignals...)
+
 		for range signals {
-			// If stdin is a TTY, the kernel will forward
-			// signals to the subprocess because the shared
-			// pgid makes the TTY a controlling terminal.
-			//
-			// The plugin should have it's own copy of this
-			// termination logic, and exit after 3 retries
-			// on it's own.
-			if dockerCli.Out().IsTerminal() {
-				continue
-			}
-
-			// Terminate the plugin server, which will
-			// close all connections with plugin
-			// subprocesses, and signal them to exit.
-			//
-			// Repeated invocations will result in EINVAL,
-			// or EBADF; but that is fine for our purposes.
-			_ = srv.Close()
-
+			retries++
 			// If we're still running after 3 interruptions
 			// (SIGINT/SIGTERM), send a SIGKILL to the plugin as a
 			// final attempt to terminate, and exit.
-			retries++
 			if retries >= exitLimit {
-				_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)
-				_ = plugincmd.Process.Kill()
-				os.Exit(1)
+				force = true
 			}
+			tryTerminatePlugin(force)
 		}
 	}()
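
Net effect of this hunk: the first termination signal arrives as cancellation of the root context, and only the re-registered channel sees later signals; once exitLimit further signals arrive, force flips and the plugin process is killed (the initial signal plus exitLimit = 2 more matches the "got 3 SIGTERM/SIGINTs" message). A standalone sketch of the same escalation shape with hypothetical names, not the commit's code; run it and press Ctrl-C repeatedly:

	package main

	import (
		"context"
		"fmt"
		"os"
		"os/signal"
		"syscall"
	)

	func main() {
		ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
		defer stop()

		const exitLimit = 2 // further signals tolerated after the first one

		terminate := func(force bool) {
			fmt.Println("requesting graceful shutdown, force =", force)
			if force {
				os.Exit(1)
			}
		}

		<-ctx.Done() // first signal arrives as context cancellation
		terminate(false)

		// re-register so signals sent after the graceful request are seen here
		signals := make(chan os.Signal, exitLimit)
		signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)

		retries := 0
		for range signals {
			retries++
			terminate(retries >= exitLimit)
		}
	}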
@@ -338,7 +363,7 @@ func runDocker(ctx context.Context, dockerCli *command.DockerCli) error {
 	ccmd, _, err := cmd.Find(args)
 	subCommand = ccmd
 	if err != nil || pluginmanager.IsPluginCommand(ccmd) {
-		err := tryPluginRun(dockerCli, cmd, args[0], envs)
+		err := tryPluginRun(ctx, dockerCli, cmd, args[0], envs)
 		if err == nil {
 			if dockerCli.HooksEnabled() && dockerCli.Out().IsTerminal() && ccmd != nil {
 				pluginmanager.RunPluginHooks(ctx, dockerCli, cmd, ccmd, args)

View File

@@ -3,7 +3,6 @@ package test

 import (
 	"context"
 	"os"
-	"syscall"
 	"testing"
 	"time"
@@ -32,8 +31,11 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
 	assert.NilError(t, err)
 	cli.SetIn(streams.NewIn(r))

+	notifyCtx, notifyCancel := context.WithCancel(ctx)
+	t.Cleanup(notifyCancel)
+
 	go func() {
-		errChan <- cmd.ExecuteContext(ctx)
+		errChan <- cmd.ExecuteContext(notifyCtx)
 	}()

 	writeCtx, writeCancel := context.WithTimeout(ctx, 100*time.Millisecond)
@@ -66,7 +68,7 @@ func TerminatePrompt(ctx context.Context, t *testing.T, cmd *cobra.Command, cli
 	// sigint and sigterm are caught by the prompt
 	// this allows us to gracefully exit the prompt with a 0 exit code
-	syscall.Kill(syscall.Getpid(), syscall.SIGINT)
+	notifyCancel()

 	select {
 	case <-errCtx.Done():