plugin: closer-based plugin notification socket

This changes things to rely on a plugin server that manages all
connections made to the server.

An optional handler can be passed into the server when the caller wants
to do extra things with the connection.

It is the caller's responsibility to close the server.
When the server is closed, first all existing connections are closed
(and new connections are prevented).

Now the signal loop only needs to close the server and not deal with
`net.Conn`'s directly (or double-indirects as the case was before this
change).

The socket, when present in the filesystem, is no longer unlinked
eagerly, as reconnections require it to be present for the lifecycle of
the plugin server.

Co-authored-by: Bjorn Neergaard <bjorn.neergaard@docker.com>
Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Signed-off-by: Bjorn Neergaard <bjorn.neergaard@docker.com>
This commit is contained in:
Brian Goff 2024-02-29 21:33:03 +00:00 committed by Bjorn Neergaard
parent 4468148f37
commit d68cc0e8d0
No known key found for this signature in database
8 changed files with 244 additions and 139 deletions

View File

@ -7,24 +7,104 @@ import (
"io"
"net"
"os"
"runtime"
"sync"
)
// EnvKey represents the well-known environment variable used to pass the plugin being
// executed the socket name it should listen on to coordinate with the host CLI.
const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"
// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections
// and update the conn pointer, and returns the listener for the socket (which the caller
// is responsible for closing when it's no longer needed).
func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) {
listener, err := listen("docker_cli_" + randomID())
// NewPluginServer creates a plugin server that listens on a new Unix domain socket.
// `h` is called for each new connection to the socket in a goroutine.
func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
l, err := listen("docker_cli_" + randomID())
if err != nil {
return nil, err
}
accept(listener, conn)
if h == nil {
h = func(net.Conn) {}
}
return listener, nil
pl := &PluginServer{
l: l,
h: h,
}
go func() {
defer pl.Close()
for {
err := pl.accept()
if err != nil {
return
}
}
}()
return pl, nil
}
type PluginServer struct {
mu sync.Mutex
conns []net.Conn
l *net.UnixListener
h func(net.Conn)
closed bool
}
func (pl *PluginServer) accept() error {
conn, err := pl.l.Accept()
if err != nil {
return err
}
pl.mu.Lock()
defer pl.mu.Unlock()
if pl.closed {
// handle potential race condition between Close and Accept
conn.Close()
return errors.New("plugin server is closed")
}
pl.conns = append(pl.conns, conn)
go pl.h(conn)
return nil
}
func (pl *PluginServer) Addr() net.Addr {
return pl.l.Addr()
}
// Close ensures that the server is no longer accepting new connections and closes all existing connections.
// Existing connections will receive [io.EOF].
func (pl *PluginServer) Close() error {
// Remove the listener socket, if it exists on the filesystem.
unlink(pl.l)
// Close connections first to ensure the connections get io.EOF instead of a connection reset.
pl.closeAllConns()
// Try to ensure that any active connections have a chance to receive io.EOF
runtime.Gosched()
return pl.l.Close()
}
func (pl *PluginServer) closeAllConns() {
pl.mu.Lock()
defer pl.mu.Unlock()
// Prevent new connections from being accepted
pl.closed = true
for _, conn := range pl.conns {
conn.Close()
}
pl.conns = nil
}
func randomID() string {
@ -35,18 +115,6 @@ func randomID() string {
return hex.EncodeToString(b)
}
func accept(listener *net.UnixListener, conn **net.UnixConn) {
go func() {
for {
// ignore error here, if we failed to accept a connection,
// conn is nil and we fallback to previous behavior
*conn, _ = listener.AcceptUnix()
// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
onAccept(*conn, listener)
}
}()
}
// ConnectAndWait connects to the socket passed via well-known env var,
// if present, and attempts to read from it until it receives an EOF, at which
// point cb is called.

View File

@ -0,0 +1,20 @@
//go:build windows || linux
package socket
import (
"net"
)
func listen(socketname string) (*net.UnixListener, error) {
// Create an abstract socket -- this socket can be opened by name, but is
// not present in the filesystem.
return net.ListenUnix("unix", &net.UnixAddr{
Name: "@" + socketname,
Net: "unix",
})
}
func unlink(listener *net.UnixListener) {
// Do nothing; the socket is not present in the filesystem.
}

View File

@ -1,19 +0,0 @@
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
func listen(socketname string) (*net.UnixListener, error) {
return net.ListenUnix("unix", &net.UnixAddr{
Name: filepath.Join(os.TempDir(), socketname),
Net: "unix",
})
}
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
syscall.Unlink(listener.Addr().String())
}

View File

@ -0,0 +1,25 @@
//go:build !windows && !linux
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
func listen(socketname string) (*net.UnixListener, error) {
// Because abstract sockets are unavailable, we create a socket in the
// system temporary directory instead.
return net.ListenUnix("unix", &net.UnixAddr{
Name: filepath.Join(os.TempDir(), socketname),
Net: "unix",
})
}
func unlink(listener *net.UnixListener) {
// unlink(2) is best effort here; if it fails, we may 'leak' a socket
// into the filesystem, but this is unlikely and overall harmless.
_ = syscall.Unlink(listener.Addr().String())
}

View File

@ -1,20 +0,0 @@
//go:build !darwin && !openbsd
package socket
import (
"net"
)
func listen(socketname string) (*net.UnixListener, error) {
return net.ListenUnix("unix", &net.UnixAddr{
Name: "@" + socketname,
Net: "unix",
})
}
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
// do nothing
// while on darwin and OpenBSD we would unlink here;
// on non-darwin the socket is abstract and not present on the filesystem
}

View File

@ -1,19 +0,0 @@
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
func listen(socketname string) (*net.UnixListener, error) {
return net.ListenUnix("unix", &net.UnixAddr{
Name: filepath.Join(os.TempDir(), socketname),
Net: "unix",
})
}
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
syscall.Unlink(listener.Addr().String())
}

View File

@ -1,11 +1,14 @@
package socket
import (
"errors"
"io"
"io/fs"
"net"
"os"
"runtime"
"strings"
"sync/atomic"
"testing"
"time"
@ -13,54 +16,110 @@ import (
"gotest.tools/v3/poll"
)
func TestSetupConn(t *testing.T) {
t.Run("updates conn when connected", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
func TestPluginServer(t *testing.T) {
t.Run("connection closes with EOF when server closes", func(t *testing.T) {
called := make(chan struct{})
srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
assert.Assert(t, srv != nil, "returned nil server but no error")
_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
pollConnNotNil(t, &conn)
conn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned server")
defer conn.Close()
done := make(chan error, 1)
go func() {
_, err := conn.Read(make([]byte, 1))
done <- err
}()
select {
case <-called:
case <-time.After(10 * time.Millisecond):
t.Fatal("handler not called")
}
srv.Close()
select {
case err := <-done:
if !errors.Is(err, io.EOF) {
t.Fatalf("exepcted EOF error, got: %v", err)
}
case <-time.After(10 * time.Millisecond):
}
})
t.Run("allows reconnects", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
var calls int32
h := func(_ net.Conn) {
atomic.AddInt32(&calls, 1)
}
srv, err := NewPluginServer(h)
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
defer srv.Close()
assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
waitForCalls := func(n int) {
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if atomic.LoadInt32(&calls) == int32(n) {
return poll.Success()
}
return poll.Continue("waiting for handler to be called")
})
}
otherConn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")
assert.NilError(t, err, "failed to dial returned server")
otherConn.Close()
waitForCalls(1)
_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial listener")
conn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial server")
defer conn.Close()
waitForCalls(2)
// and again but don't close the existing connection
conn2, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial server")
defer conn2.Close()
waitForCalls(3)
srv.Close()
// now make sure we get EOF on the existing connections
buf := make([]byte, 1)
_, err = conn.Read(buf)
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
_, err = conn2.Read(buf)
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
})
t.Run("does not leak sockets to local directory", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
srv, err := NewPluginServer(nil)
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
checkDirNoPluginSocket(t)
assert.Check(t, srv != nil, "returned nil server but no error")
checkDirNoNewPluginServer(t)
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")
checkDirNoPluginSocket(t)
assert.NilError(t, err, "failed to dial returned server")
checkDirNoNewPluginServer(t)
})
}
func checkDirNoPluginSocket(t *testing.T) {
func checkDirNoNewPluginServer(t *testing.T) {
t.Helper()
files, err := os.ReadDir(".")
@ -78,18 +137,24 @@ func checkDirNoPluginSocket(t *testing.T) {
func TestConnectAndWait(t *testing.T) {
t.Run("calls cancel func on EOF", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
assert.NilError(t, err, "failed to setup listener")
srv, err := NewPluginServer(nil)
assert.NilError(t, err, "failed to setup server")
defer srv.Close()
done := make(chan struct{})
t.Setenv(EnvKey, listener.Addr().String())
t.Setenv(EnvKey, srv.Addr().String())
cancelFunc := func() {
done <- struct{}{}
}
ConnectAndWait(cancelFunc)
pollConnNotNil(t, &conn)
conn.Close()
select {
case <-done:
t.Fatal("unexpectedly done")
default:
}
srv.Close()
select {
case <-done:
@ -101,17 +166,19 @@ func TestConnectAndWait(t *testing.T) {
// TODO: this test cannot be executed with `t.Parallel()`, due to
// relying on goroutine numbers to ensure correct behaviour
t.Run("connect goroutine exits after EOF", func(t *testing.T) {
var conn *net.UnixConn
listener, err := SetupConn(&conn)
assert.NilError(t, err, "failed to setup listener")
t.Setenv(EnvKey, listener.Addr().String())
srv, err := NewPluginServer(nil)
assert.NilError(t, err, "failed to setup server")
defer srv.Close()
t.Setenv(EnvKey, srv.Addr().String())
numGoroutines := runtime.NumGoroutine()
ConnectAndWait(func() {})
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
pollConnNotNil(t, &conn)
conn.Close()
srv.Close()
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if runtime.NumGoroutine() > numGoroutines+1 {
return poll.Continue("waiting for connect goroutine to exit")
@ -120,14 +187,3 @@ func TestConnectAndWait(t *testing.T) {
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
})
}
func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
t.Helper()
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if *conn == nil {
return poll.Continue("waiting for conn to not be nil")
}
return poll.Success()
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
}

View File

@ -2,7 +2,6 @@ package main
import (
"fmt"
"net"
"os"
"os/exec"
"os/signal"
@ -222,11 +221,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
}
// Establish the plugin socket, adding it to the environment under a well-known key if successful.
var conn *net.UnixConn
listener, err := socket.SetupConn(&conn)
srv, err := socket.NewPluginServer(nil)
if err == nil {
envs = append(envs, socket.EnvKey+"="+listener.Addr().String())
defer listener.Close()
envs = append(envs, socket.EnvKey+"="+srv.Addr().String())
}
plugincmd.Env = append(envs, plugincmd.Env...)
@ -247,12 +244,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
// receive signals due to sharing a pgid with the parent CLI
continue
}
if conn != nil {
if err := conn.Close(); err != nil {
_, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err)
}
conn = nil
}
srv.Close()
retries++
if retries >= exitLimit {
_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)