diff --git a/cli-plugins/socket/socket.go b/cli-plugins/socket/socket.go
index 67ba11562e..2ad3a183d7 100644
--- a/cli-plugins/socket/socket.go
+++ b/cli-plugins/socket/socket.go
@@ -7,24 +7,117 @@ import (
 	"io"
 	"net"
 	"os"
+	"runtime"
+	"sync"
 )
 
 // EnvKey represents the well-known environment variable used to pass the plugin being
 // executed the socket name it should listen on to coordinate with the host CLI.
 const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"
 
-// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections
-// and update the conn pointer, and returns the listener for the socket (which the caller
-// is responsible for closing when it's no longer needed).
-func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) {
-	listener, err := listen("docker_cli_" + randomID())
+// NewPluginServer creates a plugin server that listens on a new Unix domain socket.
+// `h` is called for each new connection to the socket in a goroutine.
+func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
+	l, err := listen("docker_cli_" + randomID())
 	if err != nil {
 		return nil, err
 	}
 
-	accept(listener, conn)
+	if h == nil {
+		h = func(net.Conn) {}
+	}
 
-	return listener, nil
+	pl := &PluginServer{
+		l: l,
+		h: h,
+	}
+
+	go func() {
+		defer pl.Close()
+		for {
+			err := pl.accept()
+			if err != nil {
+				return
+			}
+		}
+	}()
+
+	return pl, nil
+}
+
+// PluginServer accepts connections on a Unix domain socket and keeps track of
+// them so that Close can close them all.
+type PluginServer struct {
+	mu     sync.Mutex
+	conns  []net.Conn
+	l      *net.UnixListener
+	h      func(net.Conn)
+	closed bool
+}
+
+// accept waits for the next connection, records it, and runs the handler on
+// it in a new goroutine. It returns an error if accepting fails or if the
+// server has already been closed.
+func (pl *PluginServer) accept() error {
+	conn, err := pl.l.Accept()
+	if err != nil {
+		return err
+	}
+
+	pl.mu.Lock()
+	defer pl.mu.Unlock()
+
+	if pl.closed {
+		// handle potential race condition between Close and Accept
+		conn.Close()
+		return errors.New("plugin server is closed")
+	}
+
+	pl.conns = append(pl.conns, conn)
+
+	go pl.h(conn)
+	return nil
+}
+
+// Addr returns the address of the server's listener.
+func (pl *PluginServer) Addr() net.Addr {
+	return pl.l.Addr()
+}
+
+// Close ensures that the server is no longer accepting new connections and closes all existing connections.
+// Existing connections will receive [io.EOF].
+//
+// Close is a no-op on a nil *PluginServer, so callers that tolerated an error
+// from [NewPluginServer] can still call it unconditionally.
+func (pl *PluginServer) Close() error {
+	if pl == nil {
+		return nil
+	}
+
+	// Remove the listener socket, if it exists on the filesystem.
+	unlink(pl.l)
+
+	// Close connections first to ensure the connections get io.EOF instead of a connection reset.
+	pl.closeAllConns()
+
+	// Try to ensure that any active connections have a chance to receive io.EOF
+	runtime.Gosched()
+
+	return pl.l.Close()
+}
+
+func (pl *PluginServer) closeAllConns() {
+	pl.mu.Lock()
+	defer pl.mu.Unlock()
+
+	// Prevent new connections from being accepted
+	pl.closed = true
+
+	for _, conn := range pl.conns {
+		conn.Close()
+	}
+
+	pl.conns = nil
 }
 
 func randomID() string {
@@ -35,18 +128,6 @@ func randomID() string {
 	return hex.EncodeToString(b)
 }
 
-func accept(listener *net.UnixListener, conn **net.UnixConn) {
-	go func() {
-		for {
-			// ignore error here, if we failed to accept a connection,
-			// conn is nil and we fallback to previous behavior
-			*conn, _ = listener.AcceptUnix()
-			// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
-			onAccept(*conn, listener)
-		}
-	}()
-}
-
 // ConnectAndWait connects to the socket passed via well-known env var,
 // if present, and attempts to read from it until it receives an EOF, at which
 // point cb is called.
diff --git a/cli-plugins/socket/socket_abstract.go b/cli-plugins/socket/socket_abstract.go
new file mode 100644
index 0000000000..ce7b429036
--- /dev/null
+++ b/cli-plugins/socket/socket_abstract.go
@@ -0,0 +1,20 @@
+//go:build windows || linux
+
+package socket
+
+import (
+	"net"
+)
+
+func listen(socketname string) (*net.UnixListener, error) {
+	// Create an abstract socket -- this socket can be opened by name, but is
+	// not present in the filesystem.
+	return net.ListenUnix("unix", &net.UnixAddr{
+		Name: "@" + socketname,
+		Net:  "unix",
+	})
+}
+
+func unlink(listener *net.UnixListener) {
+	// Do nothing; the socket is not present in the filesystem.
+}
diff --git a/cli-plugins/socket/socket_darwin.go b/cli-plugins/socket/socket_darwin.go
deleted file mode 100644
index 17ab6aa69e..0000000000
--- a/cli-plugins/socket/socket_darwin.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package socket
-
-import (
-	"net"
-	"os"
-	"path/filepath"
-	"syscall"
-)
-
-func listen(socketname string) (*net.UnixListener, error) {
-	return net.ListenUnix("unix", &net.UnixAddr{
-		Name: filepath.Join(os.TempDir(), socketname),
-		Net:  "unix",
-	})
-}
-
-func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
-	syscall.Unlink(listener.Addr().String())
-}
diff --git a/cli-plugins/socket/socket_noabstract.go b/cli-plugins/socket/socket_noabstract.go
new file mode 100644
index 0000000000..fbc948e6dc
--- /dev/null
+++ b/cli-plugins/socket/socket_noabstract.go
@@ -0,0 +1,25 @@
+//go:build !windows && !linux
+
+package socket
+
+import (
+	"net"
+	"os"
+	"path/filepath"
+	"syscall"
+)
+
+func listen(socketname string) (*net.UnixListener, error) {
+	// Because abstract sockets are unavailable, we create a socket in the
+	// system temporary directory instead.
+	return net.ListenUnix("unix", &net.UnixAddr{
+		Name: filepath.Join(os.TempDir(), socketname),
+		Net:  "unix",
+	})
+}
+
+func unlink(listener *net.UnixListener) {
+	// unlink(2) is best effort here; if it fails, we may 'leak' a socket
+	// into the filesystem, but this is unlikely and overall harmless.
+	_ = syscall.Unlink(listener.Addr().String())
+}
diff --git a/cli-plugins/socket/socket_nodarwin.go b/cli-plugins/socket/socket_nodarwin.go
deleted file mode 100644
index aa6065ecb4..0000000000
--- a/cli-plugins/socket/socket_nodarwin.go
+++ /dev/null
@@ -1,20 +0,0 @@
-//go:build !darwin && !openbsd
-
-package socket
-
-import (
-	"net"
-)
-
-func listen(socketname string) (*net.UnixListener, error) {
-	return net.ListenUnix("unix", &net.UnixAddr{
-		Name: "@" + socketname,
-		Net:  "unix",
-	})
-}
-
-func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
-	// do nothing
-	// while on darwin and OpenBSD we would unlink here;
-	// on non-darwin the socket is abstract and not present on the filesystem
-}
diff --git a/cli-plugins/socket/socket_openbsd.go b/cli-plugins/socket/socket_openbsd.go
deleted file mode 100644
index 17ab6aa69e..0000000000
--- a/cli-plugins/socket/socket_openbsd.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package socket
-
-import (
-	"net"
-	"os"
-	"path/filepath"
-	"syscall"
-)
-
-func listen(socketname string) (*net.UnixListener, error) {
-	return net.ListenUnix("unix", &net.UnixAddr{
-		Name: filepath.Join(os.TempDir(), socketname),
-		Net:  "unix",
-	})
-}
-
-func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
-	syscall.Unlink(listener.Addr().String())
-}
diff --git a/cli-plugins/socket/socket_test.go b/cli-plugins/socket/socket_test.go
index 409eb68948..df7b351181 100644
--- a/cli-plugins/socket/socket_test.go
+++ b/cli-plugins/socket/socket_test.go
@@ -1,11 +1,14 @@
 package socket
 
 import (
+	"errors"
+	"io"
 	"io/fs"
 	"net"
 	"os"
 	"runtime"
 	"strings"
+	"sync/atomic"
 	"testing"
 	"time"
 
@@ -13,54 +16,111 @@ import (
 	"gotest.tools/v3/poll"
 )
 
-func TestSetupConn(t *testing.T) {
-	t.Run("updates conn when connected", func(t *testing.T) {
-		var conn *net.UnixConn
-		listener, err := SetupConn(&conn)
+func TestPluginServer(t *testing.T) {
+	t.Run("connection closes with EOF when server closes", func(t *testing.T) {
+		called := make(chan struct{})
+		srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
 		assert.NilError(t, err)
-		assert.Check(t, listener != nil, "returned nil listener but no error")
-		addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
-		assert.NilError(t, err, "failed to resolve listener address")
+		assert.Assert(t, srv != nil, "returned nil server but no error")
 
-		_, err = net.DialUnix("unix", nil, addr)
-		assert.NilError(t, err, "failed to dial returned listener")
+		addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
+		assert.NilError(t, err, "failed to resolve server address")
 
-		pollConnNotNil(t, &conn)
+		conn, err := net.DialUnix("unix", nil, addr)
+		assert.NilError(t, err, "failed to dial returned server")
+		defer conn.Close()
+
+		done := make(chan error, 1)
+		go func() {
+			_, err := conn.Read(make([]byte, 1))
+			done <- err
+		}()
+
+		select {
+		case <-called:
+		case <-time.After(10 * time.Millisecond):
+			t.Fatal("handler not called")
+		}
+
+		srv.Close()
+
+		select {
+		case err := <-done:
+			if !errors.Is(err, io.EOF) {
+				t.Fatalf("expected EOF error, got: %v", err)
+			}
+		case <-time.After(time.Second):
+			t.Fatal("timed out waiting for connection to receive EOF")
+		}
 	})
 
 	t.Run("allows reconnects", func(t *testing.T) {
-		var conn *net.UnixConn
-		listener, err := SetupConn(&conn)
+		var calls int32
+		h := func(_ net.Conn) {
+			atomic.AddInt32(&calls, 1)
+		}
+
+		srv, err := NewPluginServer(h)
 		assert.NilError(t, err)
-		assert.Check(t, listener != nil, "returned nil listener but no error")
-		addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
-		assert.NilError(t, err, "failed to resolve listener address")
+		defer srv.Close()
+
+		assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
+
+		addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
+		assert.NilError(t, err, "failed to resolve server address")
+
+		waitForCalls := func(n int) {
+			poll.WaitOn(t, func(t poll.LogT) poll.Result {
+				if atomic.LoadInt32(&calls) == int32(n) {
+					return poll.Success()
+				}
+				return poll.Continue("waiting for handler to be called")
+			})
+		}
 
 		otherConn, err := net.DialUnix("unix", nil, addr)
-		assert.NilError(t, err, "failed to dial returned listener")
-
+		assert.NilError(t, err, "failed to dial returned server")
 		otherConn.Close()
+		waitForCalls(1)
 
-		_, err = net.DialUnix("unix", nil, addr)
-		assert.NilError(t, err, "failed to redial listener")
+		conn, err := net.DialUnix("unix", nil, addr)
+		assert.NilError(t, err, "failed to redial server")
+		defer conn.Close()
+		waitForCalls(2)
+
+		// and again but don't close the existing connection
+		conn2, err := net.DialUnix("unix", nil, addr)
+		assert.NilError(t, err, "failed to redial server")
+		defer conn2.Close()
+		waitForCalls(3)
+
+		srv.Close()
+
+		// now make sure we get EOF on the existing connections
+		buf := make([]byte, 1)
+		_, err = conn.Read(buf)
+		assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
+
+		_, err = conn2.Read(buf)
+		assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
 	})
 
 	t.Run("does not leak sockets to local directory", func(t *testing.T) {
-		var conn *net.UnixConn
-		listener, err := SetupConn(&conn)
+		srv, err := NewPluginServer(nil)
 		assert.NilError(t, err)
-		assert.Check(t, listener != nil, "returned nil listener but no error")
-		checkDirNoPluginSocket(t)
+		assert.Check(t, srv != nil, "returned nil server but no error")
+		checkDirNoNewPluginServer(t)
+
+		addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
+		assert.NilError(t, err, "failed to resolve server address")
 
-		addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
-		assert.NilError(t, err, "failed to resolve listener address")
 		_, err = net.DialUnix("unix", nil, addr)
-		assert.NilError(t, err, "failed to dial returned listener")
-		checkDirNoPluginSocket(t)
+		assert.NilError(t, err, "failed to dial returned server")
+		checkDirNoNewPluginServer(t)
 	})
 }
 
-func checkDirNoPluginSocket(t *testing.T) {
+func checkDirNoNewPluginServer(t *testing.T) {
 	t.Helper()
 
 	files, err := os.ReadDir(".")
@@ -78,18 +138,24 @@ func checkDirNoPluginSocket(t *testing.T) {
 
 func TestConnectAndWait(t *testing.T) {
 	t.Run("calls cancel func on EOF", func(t *testing.T) {
-		var conn *net.UnixConn
-		listener, err := SetupConn(&conn)
-		assert.NilError(t, err, "failed to setup listener")
+		srv, err := NewPluginServer(nil)
+		assert.NilError(t, err, "failed to setup server")
+		defer srv.Close()
 
 		done := make(chan struct{})
-		t.Setenv(EnvKey, listener.Addr().String())
+		t.Setenv(EnvKey, srv.Addr().String())
 		cancelFunc := func() {
 			done <- struct{}{}
 		}
 		ConnectAndWait(cancelFunc)
-		pollConnNotNil(t, &conn)
-		conn.Close()
+
+		select {
+		case <-done:
+			t.Fatal("unexpectedly done")
+		default:
+		}
+
+		srv.Close()
 
 		select {
 		case <-done:
@@ -101,17 +167,19 @@ func TestConnectAndWait(t *testing.T) {
 	// TODO: this test cannot be executed with `t.Parallel()`, due to
 	// relying on goroutine numbers to ensure correct behaviour
 	t.Run("connect goroutine exits after EOF", func(t *testing.T) {
-		var conn *net.UnixConn
-		listener, err := SetupConn(&conn)
-		assert.NilError(t, err, "failed to setup listener")
-		t.Setenv(EnvKey, listener.Addr().String())
+		srv, err := NewPluginServer(nil)
+		assert.NilError(t, err, "failed to setup server")
+
+		defer srv.Close()
+
+		t.Setenv(EnvKey, srv.Addr().String())
 		numGoroutines := runtime.NumGoroutine()
 
 		ConnectAndWait(func() {})
 		assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
 
-		pollConnNotNil(t, &conn)
-		conn.Close()
+		srv.Close()
+
 		poll.WaitOn(t, func(t poll.LogT) poll.Result {
 			if runtime.NumGoroutine() > numGoroutines+1 {
 				return poll.Continue("waiting for connect goroutine to exit")
@@ -120,14 +188,3 @@ func TestConnectAndWait(t *testing.T) {
 		}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
 	})
 }
-
-func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
-	t.Helper()
-
-	poll.WaitOn(t, func(t poll.LogT) poll.Result {
-		if *conn == nil {
-			return poll.Continue("waiting for conn to not be nil")
-		}
-		return poll.Success()
-	}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
-}
diff --git a/cmd/docker/docker.go b/cmd/docker/docker.go
index cfc53a6fa1..3a46aeb46d 100644
--- a/cmd/docker/docker.go
+++ b/cmd/docker/docker.go
@@ -2,7 +2,6 @@ package main
 
 import (
 	"fmt"
-	"net"
 	"os"
 	"os/exec"
 	"os/signal"
@@ -222,11 +221,12 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
 	}
 
 	// Establish the plugin socket, adding it to the environment under a well-known key if successful.
-	var conn *net.UnixConn
-	listener, err := socket.SetupConn(&conn)
+	srv, err := socket.NewPluginServer(nil)
 	if err == nil {
-		envs = append(envs, socket.EnvKey+"="+listener.Addr().String())
-		defer listener.Close()
+		envs = append(envs, socket.EnvKey+"="+srv.Addr().String())
+		// Close the server once plugin execution is over so that the listener
+		// (and any socket file on the filesystem) is cleaned up on normal exit.
+		defer srv.Close()
 	}
 
 	plugincmd.Env = append(envs, plugincmd.Env...)
@@ -247,12 +247,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
 				// receive signals due to sharing a pgid with the parent CLI
 				continue
 			}
-			if conn != nil {
-				if err := conn.Close(); err != nil {
-					_, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err)
-				}
-				conn = nil
-			}
+
+			srv.Close()
+
 			retries++
 			if retries >= exitLimit {
 				_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)