From d68cc0e8d0164ab5b3f8cd47b333891604f4c4c9 Mon Sep 17 00:00:00 2001 From: Brian Goff Date: Thu, 29 Feb 2024 21:33:03 +0000 Subject: [PATCH] plugin: closer-based plugin notification socket This changes things to rely on a plugin server that manages all connections made to the server. An optional handler can be passed into the server when the caller wants to do extra things with the connection. It is the caller's responsibility to close the server. When the server is closed, all existing connections are closed first (and new connections are prevented). Now the signal loop only needs to close the server and not deal with `net.Conn`s directly (or double-indirects, as was the case before this change). The socket, when present in the filesystem, is no longer unlinked eagerly, as reconnections require it to be present for the lifecycle of the plugin server. Co-authored-by: Bjorn Neergaard Signed-off-by: Brian Goff Signed-off-by: Bjorn Neergaard --- cli-plugins/socket/socket.go | 106 +++++++++++++--- cli-plugins/socket/socket_abstract.go | 20 +++ cli-plugins/socket/socket_darwin.go | 19 --- cli-plugins/socket/socket_noabstract.go | 25 ++++ cli-plugins/socket/socket_nodarwin.go | 20 --- cli-plugins/socket/socket_openbsd.go | 19 --- cli-plugins/socket/socket_test.go | 158 ++++++++++++++++-------- cmd/docker/docker.go | 16 +-- 8 files changed, 244 insertions(+), 139 deletions(-) create mode 100644 cli-plugins/socket/socket_abstract.go delete mode 100644 cli-plugins/socket/socket_darwin.go create mode 100644 cli-plugins/socket/socket_noabstract.go delete mode 100644 cli-plugins/socket/socket_nodarwin.go delete mode 100644 cli-plugins/socket/socket_openbsd.go diff --git a/cli-plugins/socket/socket.go b/cli-plugins/socket/socket.go index 67ba11562e..2ad3a183d7 100644 --- a/cli-plugins/socket/socket.go +++ b/cli-plugins/socket/socket.go @@ -7,24 +7,104 @@ import ( "io" "net" "os" + "runtime" + "sync" ) // EnvKey represents the well-known environment variable used to pass the 
plugin being // executed the socket name it should listen on to coordinate with the host CLI. const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET" -// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections -// and update the conn pointer, and returns the listener for the socket (which the caller -// is responsible for closing when it's no longer needed). -func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) { - listener, err := listen("docker_cli_" + randomID()) +// NewPluginServer creates a plugin server that listens on a new Unix domain socket. +// `h` is called for each new connection to the socket in a goroutine. +func NewPluginServer(h func(net.Conn)) (*PluginServer, error) { + l, err := listen("docker_cli_" + randomID()) if err != nil { return nil, err } - accept(listener, conn) + if h == nil { + h = func(net.Conn) {} + } - return listener, nil + pl := &PluginServer{ + l: l, + h: h, + } + + go func() { + defer pl.Close() + for { + err := pl.accept() + if err != nil { + return + } + } + }() + + return pl, nil +} + +type PluginServer struct { + mu sync.Mutex + conns []net.Conn + l *net.UnixListener + h func(net.Conn) + closed bool +} + +func (pl *PluginServer) accept() error { + conn, err := pl.l.Accept() + if err != nil { + return err + } + + pl.mu.Lock() + defer pl.mu.Unlock() + + if pl.closed { + // handle potential race condition between Close and Accept + conn.Close() + return errors.New("plugin server is closed") + } + + pl.conns = append(pl.conns, conn) + + go pl.h(conn) + return nil +} + +func (pl *PluginServer) Addr() net.Addr { + return pl.l.Addr() +} + +// Close ensures that the server is no longer accepting new connections and closes all existing connections. +// Existing connections will receive [io.EOF]. +func (pl *PluginServer) Close() error { + // Remove the listener socket, if it exists on the filesystem. 
+ unlink(pl.l) + + // Close connections first to ensure the connections get io.EOF instead of a connection reset. + pl.closeAllConns() + + // Try to ensure that any active connections have a chance to receive io.EOF + runtime.Gosched() + + return pl.l.Close() +} + +func (pl *PluginServer) closeAllConns() { + pl.mu.Lock() + defer pl.mu.Unlock() + + // Prevent new connections from being accepted + pl.closed = true + + for _, conn := range pl.conns { + conn.Close() + } + + pl.conns = nil } func randomID() string { @@ -35,18 +115,6 @@ func randomID() string { return hex.EncodeToString(b) } -func accept(listener *net.UnixListener, conn **net.UnixConn) { - go func() { - for { - // ignore error here, if we failed to accept a connection, - // conn is nil and we fallback to previous behavior - *conn, _ = listener.AcceptUnix() - // perform any platform-specific actions on accept (e.g. unlink non-abstract sockets) - onAccept(*conn, listener) - } - }() -} - // ConnectAndWait connects to the socket passed via well-known env var, // if present, and attempts to read from it until it receives an EOF, at which // point cb is called. diff --git a/cli-plugins/socket/socket_abstract.go b/cli-plugins/socket/socket_abstract.go new file mode 100644 index 0000000000..ce7b429036 --- /dev/null +++ b/cli-plugins/socket/socket_abstract.go @@ -0,0 +1,20 @@ +//go:build windows || linux + +package socket + +import ( + "net" +) + +func listen(socketname string) (*net.UnixListener, error) { + // Create an abstract socket -- this socket can be opened by name, but is + // not present in the filesystem. + return net.ListenUnix("unix", &net.UnixAddr{ + Name: "@" + socketname, + Net: "unix", + }) +} + +func unlink(listener *net.UnixListener) { + // Do nothing; the socket is not present in the filesystem. 
+} diff --git a/cli-plugins/socket/socket_darwin.go b/cli-plugins/socket/socket_darwin.go deleted file mode 100644 index 17ab6aa69e..0000000000 --- a/cli-plugins/socket/socket_darwin.go +++ /dev/null @@ -1,19 +0,0 @@ -package socket - -import ( - "net" - "os" - "path/filepath" - "syscall" -) - -func listen(socketname string) (*net.UnixListener, error) { - return net.ListenUnix("unix", &net.UnixAddr{ - Name: filepath.Join(os.TempDir(), socketname), - Net: "unix", - }) -} - -func onAccept(conn *net.UnixConn, listener *net.UnixListener) { - syscall.Unlink(listener.Addr().String()) -} diff --git a/cli-plugins/socket/socket_noabstract.go b/cli-plugins/socket/socket_noabstract.go new file mode 100644 index 0000000000..fbc948e6dc --- /dev/null +++ b/cli-plugins/socket/socket_noabstract.go @@ -0,0 +1,25 @@ +//go:build !windows && !linux + +package socket + +import ( + "net" + "os" + "path/filepath" + "syscall" +) + +func listen(socketname string) (*net.UnixListener, error) { + // Because abstract sockets are unavailable, we create a socket in the + // system temporary directory instead. + return net.ListenUnix("unix", &net.UnixAddr{ + Name: filepath.Join(os.TempDir(), socketname), + Net: "unix", + }) +} + +func unlink(listener *net.UnixListener) { + // unlink(2) is best effort here; if it fails, we may 'leak' a socket + // into the filesystem, but this is unlikely and overall harmless. 
+ _ = syscall.Unlink(listener.Addr().String()) +} diff --git a/cli-plugins/socket/socket_nodarwin.go b/cli-plugins/socket/socket_nodarwin.go deleted file mode 100644 index aa6065ecb4..0000000000 --- a/cli-plugins/socket/socket_nodarwin.go +++ /dev/null @@ -1,20 +0,0 @@ -//go:build !darwin && !openbsd - -package socket - -import ( - "net" -) - -func listen(socketname string) (*net.UnixListener, error) { - return net.ListenUnix("unix", &net.UnixAddr{ - Name: "@" + socketname, - Net: "unix", - }) -} - -func onAccept(conn *net.UnixConn, listener *net.UnixListener) { - // do nothing - // while on darwin and OpenBSD we would unlink here; - // on non-darwin the socket is abstract and not present on the filesystem -} diff --git a/cli-plugins/socket/socket_openbsd.go b/cli-plugins/socket/socket_openbsd.go deleted file mode 100644 index 17ab6aa69e..0000000000 --- a/cli-plugins/socket/socket_openbsd.go +++ /dev/null @@ -1,19 +0,0 @@ -package socket - -import ( - "net" - "os" - "path/filepath" - "syscall" -) - -func listen(socketname string) (*net.UnixListener, error) { - return net.ListenUnix("unix", &net.UnixAddr{ - Name: filepath.Join(os.TempDir(), socketname), - Net: "unix", - }) -} - -func onAccept(conn *net.UnixConn, listener *net.UnixListener) { - syscall.Unlink(listener.Addr().String()) -} diff --git a/cli-plugins/socket/socket_test.go b/cli-plugins/socket/socket_test.go index 409eb68948..df7b351181 100644 --- a/cli-plugins/socket/socket_test.go +++ b/cli-plugins/socket/socket_test.go @@ -1,11 +1,14 @@ package socket import ( + "errors" + "io" "io/fs" "net" "os" "runtime" "strings" + "sync/atomic" "testing" "time" @@ -13,54 +16,110 @@ import ( "gotest.tools/v3/poll" ) -func TestSetupConn(t *testing.T) { - t.Run("updates conn when connected", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) +func TestPluginServer(t *testing.T) { + t.Run("connection closes with EOF when server closes", func(t *testing.T) { + called := make(chan struct{}) 
+ srv, err := NewPluginServer(func(_ net.Conn) { close(called) }) assert.NilError(t, err) - assert.Check(t, listener != nil, "returned nil listener but no error") - addr, err := net.ResolveUnixAddr("unix", listener.Addr().String()) - assert.NilError(t, err, "failed to resolve listener address") + assert.Assert(t, srv != nil, "returned nil server but no error") - _, err = net.DialUnix("unix", nil, addr) - assert.NilError(t, err, "failed to dial returned listener") + addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) + assert.NilError(t, err, "failed to resolve server address") - pollConnNotNil(t, &conn) + conn, err := net.DialUnix("unix", nil, addr) + assert.NilError(t, err, "failed to dial returned server") + defer conn.Close() + + done := make(chan error, 1) + go func() { + _, err := conn.Read(make([]byte, 1)) + done <- err + }() + + select { + case <-called: + case <-time.After(10 * time.Millisecond): + t.Fatal("handler not called") + } + + srv.Close() + + select { + case err := <-done: + if !errors.Is(err, io.EOF) { + t.Fatalf("exepcted EOF error, got: %v", err) + } + case <-time.After(10 * time.Millisecond): + } }) t.Run("allows reconnects", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) + var calls int32 + h := func(_ net.Conn) { + atomic.AddInt32(&calls, 1) + } + + srv, err := NewPluginServer(h) assert.NilError(t, err) - assert.Check(t, listener != nil, "returned nil listener but no error") - addr, err := net.ResolveUnixAddr("unix", listener.Addr().String()) - assert.NilError(t, err, "failed to resolve listener address") + defer srv.Close() + + assert.Check(t, srv.Addr() != nil, "returned nil addr but no error") + + addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) + assert.NilError(t, err, "failed to resolve server address") + + waitForCalls := func(n int) { + poll.WaitOn(t, func(t poll.LogT) poll.Result { + if atomic.LoadInt32(&calls) == int32(n) { + return poll.Success() + } + return 
poll.Continue("waiting for handler to be called") + }) + } otherConn, err := net.DialUnix("unix", nil, addr) - assert.NilError(t, err, "failed to dial returned listener") - + assert.NilError(t, err, "failed to dial returned server") otherConn.Close() + waitForCalls(1) - _, err = net.DialUnix("unix", nil, addr) - assert.NilError(t, err, "failed to redial listener") + conn, err := net.DialUnix("unix", nil, addr) + assert.NilError(t, err, "failed to redial server") + defer conn.Close() + waitForCalls(2) + + // and again but don't close the existing connection + conn2, err := net.DialUnix("unix", nil, addr) + assert.NilError(t, err, "failed to redial server") + defer conn2.Close() + waitForCalls(3) + + srv.Close() + + // now make sure we get EOF on the existing connections + buf := make([]byte, 1) + _, err = conn.Read(buf) + assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err) + + _, err = conn2.Read(buf) + assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err) }) t.Run("does not leak sockets to local directory", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) + srv, err := NewPluginServer(nil) assert.NilError(t, err) - assert.Check(t, listener != nil, "returned nil listener but no error") - checkDirNoPluginSocket(t) + assert.Check(t, srv != nil, "returned nil server but no error") + checkDirNoNewPluginServer(t) + + addr, err := net.ResolveUnixAddr("unix", srv.Addr().String()) + assert.NilError(t, err, "failed to resolve server address") - addr, err := net.ResolveUnixAddr("unix", listener.Addr().String()) - assert.NilError(t, err, "failed to resolve listener address") _, err = net.DialUnix("unix", nil, addr) - assert.NilError(t, err, "failed to dial returned listener") - checkDirNoPluginSocket(t) + assert.NilError(t, err, "failed to dial returned server") + checkDirNoNewPluginServer(t) }) } -func checkDirNoPluginSocket(t *testing.T) { +func checkDirNoNewPluginServer(t *testing.T) { t.Helper() files, err := 
os.ReadDir(".") @@ -78,18 +137,24 @@ func checkDirNoPluginSocket(t *testing.T) { func TestConnectAndWait(t *testing.T) { t.Run("calls cancel func on EOF", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) - assert.NilError(t, err, "failed to setup listener") + srv, err := NewPluginServer(nil) + assert.NilError(t, err, "failed to setup server") + defer srv.Close() done := make(chan struct{}) - t.Setenv(EnvKey, listener.Addr().String()) + t.Setenv(EnvKey, srv.Addr().String()) cancelFunc := func() { done <- struct{}{} } ConnectAndWait(cancelFunc) - pollConnNotNil(t, &conn) - conn.Close() + + select { + case <-done: + t.Fatal("unexpectedly done") + default: + } + + srv.Close() select { case <-done: @@ -101,17 +166,19 @@ func TestConnectAndWait(t *testing.T) { // TODO: this test cannot be executed with `t.Parallel()`, due to // relying on goroutine numbers to ensure correct behaviour t.Run("connect goroutine exits after EOF", func(t *testing.T) { - var conn *net.UnixConn - listener, err := SetupConn(&conn) - assert.NilError(t, err, "failed to setup listener") - t.Setenv(EnvKey, listener.Addr().String()) + srv, err := NewPluginServer(nil) + assert.NilError(t, err, "failed to setup server") + + defer srv.Close() + + t.Setenv(EnvKey, srv.Addr().String()) numGoroutines := runtime.NumGoroutine() ConnectAndWait(func() {}) assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1) - pollConnNotNil(t, &conn) - conn.Close() + srv.Close() + poll.WaitOn(t, func(t poll.LogT) poll.Result { if runtime.NumGoroutine() > numGoroutines+1 { return poll.Continue("waiting for connect goroutine to exit") @@ -120,14 +187,3 @@ func TestConnectAndWait(t *testing.T) { }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) }) } - -func pollConnNotNil(t *testing.T, conn **net.UnixConn) { - t.Helper() - - poll.WaitOn(t, func(t poll.LogT) poll.Result { - if *conn == nil { - return poll.Continue("waiting for conn to not be nil") - } - return 
poll.Success() - }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) -} diff --git a/cmd/docker/docker.go b/cmd/docker/docker.go index cfc53a6fa1..3a46aeb46d 100644 --- a/cmd/docker/docker.go +++ b/cmd/docker/docker.go @@ -2,7 +2,6 @@ package main import ( "fmt" - "net" "os" "os/exec" "os/signal" @@ -222,11 +221,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, } // Establish the plugin socket, adding it to the environment under a well-known key if successful. - var conn *net.UnixConn - listener, err := socket.SetupConn(&conn) + srv, err := socket.NewPluginServer(nil) if err == nil { - envs = append(envs, socket.EnvKey+"="+listener.Addr().String()) - defer listener.Close() + envs = append(envs, socket.EnvKey+"="+srv.Addr().String()) } plugincmd.Env = append(envs, plugincmd.Env...) @@ -247,12 +244,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string, // receive signals due to sharing a pgid with the parent CLI continue } - if conn != nil { - if err := conn.Close(); err != nil { - _, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err) - } - conn = nil - } + + srv.Close() + retries++ if retries >= exitLimit { _, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)