plugin: closer-based plugin notification socket

This changes things to rely on a plugin server that manages all
connections made to the server.

An optional handler can be passed into the server when the caller wants
to do extra things with the connection.

It is the caller's responsibility to close the server.
When the server is closed, first all existing connections are closed
(and new connections are prevented).

Now the signal loop only needs to close the server and not deal with
`net.Conn`s directly (or double-indirects as the case was before this
change).

The socket, when present in the filesystem, is no longer unlinked
eagerly, as reconnections require it to be present for the lifecycle of
the plugin server.

Co-authored-by: Bjorn Neergaard <bjorn.neergaard@docker.com>
Signed-off-by: Brian Goff <cpuguy83@gmail.com>
Signed-off-by: Bjorn Neergaard <bjorn.neergaard@docker.com>
This commit is contained in:
Brian Goff 2024-02-29 21:33:03 +00:00 committed by Bjorn Neergaard
parent 4468148f37
commit d68cc0e8d0
No known key found for this signature in database
8 changed files with 244 additions and 139 deletions

View File

@ -7,24 +7,104 @@ import (
"io" "io"
"net" "net"
"os" "os"
"runtime"
"sync"
) )
// EnvKey represents the well-known environment variable used to pass the plugin being // EnvKey represents the well-known environment variable used to pass the plugin being
// executed the socket name it should listen on to coordinate with the host CLI. // executed the socket name it should listen on to coordinate with the host CLI.
const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET" const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"
// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections // NewPluginServer creates a plugin server that listens on a new Unix domain socket.
// and update the conn pointer, and returns the listener for the socket (which the caller // `h` is called for each new connection to the socket in a goroutine.
// is responsible for closing when it's no longer needed). func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) { l, err := listen("docker_cli_" + randomID())
listener, err := listen("docker_cli_" + randomID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
accept(listener, conn) if h == nil {
h = func(net.Conn) {}
}
return listener, nil pl := &PluginServer{
l: l,
h: h,
}
go func() {
defer pl.Close()
for {
err := pl.accept()
if err != nil {
return
}
}
}()
return pl, nil
}
type PluginServer struct {
mu sync.Mutex
conns []net.Conn
l *net.UnixListener
h func(net.Conn)
closed bool
}
func (pl *PluginServer) accept() error {
conn, err := pl.l.Accept()
if err != nil {
return err
}
pl.mu.Lock()
defer pl.mu.Unlock()
if pl.closed {
// handle potential race condition between Close and Accept
conn.Close()
return errors.New("plugin server is closed")
}
pl.conns = append(pl.conns, conn)
go pl.h(conn)
return nil
}
func (pl *PluginServer) Addr() net.Addr {
return pl.l.Addr()
}
// Close ensures that the server is no longer accepting new connections and closes all existing connections.
// Existing connections will receive [io.EOF].
func (pl *PluginServer) Close() error {
// Remove the listener socket, if it exists on the filesystem.
unlink(pl.l)
// Close connections first to ensure the connections get io.EOF instead of a connection reset.
pl.closeAllConns()
// Try to ensure that any active connections have a chance to receive io.EOF
runtime.Gosched()
return pl.l.Close()
}
func (pl *PluginServer) closeAllConns() {
pl.mu.Lock()
defer pl.mu.Unlock()
// Prevent new connections from being accepted
pl.closed = true
for _, conn := range pl.conns {
conn.Close()
}
pl.conns = nil
} }
func randomID() string { func randomID() string {
@ -35,18 +115,6 @@ func randomID() string {
return hex.EncodeToString(b) return hex.EncodeToString(b)
} }
func accept(listener *net.UnixListener, conn **net.UnixConn) {
go func() {
for {
// ignore error here, if we failed to accept a connection,
// conn is nil and we fallback to previous behavior
*conn, _ = listener.AcceptUnix()
// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
onAccept(*conn, listener)
}
}()
}
// ConnectAndWait connects to the socket passed via well-known env var, // ConnectAndWait connects to the socket passed via well-known env var,
// if present, and attempts to read from it until it receives an EOF, at which // if present, and attempts to read from it until it receives an EOF, at which
// point cb is called. // point cb is called.

View File

@ -0,0 +1,20 @@
//go:build windows || linux
package socket
import (
"net"
)
// listen opens an abstract-namespace Unix domain socket for the plugin
// coordination channel. The socket is reachable by name only and never
// appears in the filesystem, so no cleanup of a socket file is required.
func listen(socketname string) (*net.UnixListener, error) {
	// The leading "@" asks Go to bind in the abstract socket namespace.
	addr := &net.UnixAddr{
		Net:  "unix",
		Name: "@" + socketname,
	}
	return net.ListenUnix("unix", addr)
}
// unlink is a no-op on this platform: the listener is bound in the
// abstract socket namespace, so there is no filesystem entry to remove.
func unlink(_ *net.UnixListener) {
}

View File

@ -1,19 +0,0 @@
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
// listen binds a Unix domain socket named socketname inside the system
// temporary directory and returns the listener for it. The socket exists
// as a real file on disk for the lifetime of the listener.
func listen(socketname string) (*net.UnixListener, error) {
	sockPath := filepath.Join(os.TempDir(), socketname)
	return net.ListenUnix("unix", &net.UnixAddr{
		Net:  "unix",
		Name: sockPath,
	})
}
// onAccept runs platform-specific work for each accepted connection. On
// this platform the socket exists as a real file, so it is removed from
// the filesystem once a client has connected.
// The conn parameter is unused here but kept for interface parity with
// other platform implementations.
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
	// Best effort: explicitly discard the error. A failed unlink at worst
	// leaves a stale socket file behind, which is harmless.
	_ = syscall.Unlink(listener.Addr().String())
}

View File

@ -0,0 +1,25 @@
//go:build !windows && !linux
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
// listen creates the plugin coordination socket. Abstract sockets are not
// available on this platform, so a regular socket file is created in the
// system temporary directory instead.
func listen(socketname string) (*net.UnixListener, error) {
	addr := &net.UnixAddr{
		Net:  "unix",
		Name: filepath.Join(os.TempDir(), socketname),
	}
	return net.ListenUnix("unix", addr)
}
// unlink removes the listener's socket file from the filesystem. The
// removal is best effort: on failure a socket file may be 'leaked' into
// the temporary directory, which is unlikely and harmless, so the error
// is deliberately discarded.
func unlink(listener *net.UnixListener) {
	sockPath := listener.Addr().String()
	_ = syscall.Unlink(sockPath)
}

View File

@ -1,20 +0,0 @@
//go:build !darwin && !openbsd
package socket
import (
"net"
)
// listen opens an abstract-namespace Unix socket (the "@" prefix); the
// socket is addressable by name but has no filesystem presence.
func listen(socketname string) (*net.UnixListener, error) {
	abstract := &net.UnixAddr{
		Net:  "unix",
		Name: "@" + socketname,
	}
	return net.ListenUnix("unix", abstract)
}
// onAccept is invoked for every accepted connection. Nothing is needed
// here: on darwin and OpenBSD the socket file would be unlinked at this
// point, but elsewhere the socket lives in the abstract namespace and
// never exists on disk, so there is no file to remove.
func onAccept(_ *net.UnixConn, _ *net.UnixListener) {
}

View File

@ -1,19 +0,0 @@
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
// listen creates a filesystem-backed Unix socket under the system
// temporary directory and returns a listener bound to it.
func listen(socketname string) (*net.UnixListener, error) {
	name := filepath.Join(os.TempDir(), socketname)
	addr := &net.UnixAddr{Name: name, Net: "unix"}
	return net.ListenUnix("unix", addr)
}
// onAccept runs platform-specific work for each accepted connection. On
// this platform the socket exists as a real file, so it is removed from
// the filesystem once a client has connected.
// The conn parameter is unused here but kept for interface parity with
// other platform implementations.
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
	// Best effort: explicitly discard the error. A failed unlink at worst
	// leaves a stale socket file behind, which is harmless.
	_ = syscall.Unlink(listener.Addr().String())
}

View File

@ -1,11 +1,14 @@
package socket package socket
import ( import (
"errors"
"io"
"io/fs" "io/fs"
"net" "net"
"os" "os"
"runtime" "runtime"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -13,54 +16,110 @@ import (
"gotest.tools/v3/poll" "gotest.tools/v3/poll"
) )
func TestSetupConn(t *testing.T) { func TestPluginServer(t *testing.T) {
t.Run("updates conn when connected", func(t *testing.T) { t.Run("connection closes with EOF when server closes", func(t *testing.T) {
var conn *net.UnixConn called := make(chan struct{})
listener, err := SetupConn(&conn) srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
assert.NilError(t, err) assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error") assert.Assert(t, srv != nil, "returned nil server but no error")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
_, err = net.DialUnix("unix", nil, addr) addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to dial returned listener") assert.NilError(t, err, "failed to resolve server address")
pollConnNotNil(t, &conn) conn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned server")
defer conn.Close()
done := make(chan error, 1)
go func() {
_, err := conn.Read(make([]byte, 1))
done <- err
}()
select {
case <-called:
case <-time.After(10 * time.Millisecond):
t.Fatal("handler not called")
}
srv.Close()
select {
case err := <-done:
if !errors.Is(err, io.EOF) {
t.Fatalf("expected EOF error, got: %v", err)
}
case <-time.After(10 * time.Millisecond):
}
}) })
t.Run("allows reconnects", func(t *testing.T) { t.Run("allows reconnects", func(t *testing.T) {
var conn *net.UnixConn var calls int32
listener, err := SetupConn(&conn) h := func(_ net.Conn) {
atomic.AddInt32(&calls, 1)
}
srv, err := NewPluginServer(h)
assert.NilError(t, err) assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error") defer srv.Close()
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
otherConn, err := net.DialUnix("unix", nil, addr) assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
assert.NilError(t, err, "failed to dial returned listener")
otherConn.Close() addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
_, err = net.DialUnix("unix", nil, addr) waitForCalls := func(n int) {
assert.NilError(t, err, "failed to redial listener") poll.WaitOn(t, func(t poll.LogT) poll.Result {
}) if atomic.LoadInt32(&calls) == int32(n) {
return poll.Success()
t.Run("does not leak sockets to local directory", func(t *testing.T) { }
var conn *net.UnixConn return poll.Continue("waiting for handler to be called")
listener, err := SetupConn(&conn)
assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error")
checkDirNoPluginSocket(t)
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener")
checkDirNoPluginSocket(t)
}) })
} }
func checkDirNoPluginSocket(t *testing.T) { otherConn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned server")
otherConn.Close()
waitForCalls(1)
conn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial server")
defer conn.Close()
waitForCalls(2)
// and again but don't close the existing connection
conn2, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial server")
defer conn2.Close()
waitForCalls(3)
srv.Close()
// now make sure we get EOF on the existing connections
buf := make([]byte, 1)
_, err = conn.Read(buf)
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
_, err = conn2.Read(buf)
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
})
t.Run("does not leak sockets to local directory", func(t *testing.T) {
srv, err := NewPluginServer(nil)
assert.NilError(t, err)
assert.Check(t, srv != nil, "returned nil server but no error")
checkDirNoNewPluginServer(t)
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
_, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned server")
checkDirNoNewPluginServer(t)
})
}
func checkDirNoNewPluginServer(t *testing.T) {
t.Helper() t.Helper()
files, err := os.ReadDir(".") files, err := os.ReadDir(".")
@ -78,18 +137,24 @@ func checkDirNoPluginSocket(t *testing.T) {
func TestConnectAndWait(t *testing.T) { func TestConnectAndWait(t *testing.T) {
t.Run("calls cancel func on EOF", func(t *testing.T) { t.Run("calls cancel func on EOF", func(t *testing.T) {
var conn *net.UnixConn srv, err := NewPluginServer(nil)
listener, err := SetupConn(&conn) assert.NilError(t, err, "failed to setup server")
assert.NilError(t, err, "failed to setup listener") defer srv.Close()
done := make(chan struct{}) done := make(chan struct{})
t.Setenv(EnvKey, listener.Addr().String()) t.Setenv(EnvKey, srv.Addr().String())
cancelFunc := func() { cancelFunc := func() {
done <- struct{}{} done <- struct{}{}
} }
ConnectAndWait(cancelFunc) ConnectAndWait(cancelFunc)
pollConnNotNil(t, &conn)
conn.Close() select {
case <-done:
t.Fatal("unexpectedly done")
default:
}
srv.Close()
select { select {
case <-done: case <-done:
@ -101,17 +166,19 @@ func TestConnectAndWait(t *testing.T) {
// TODO: this test cannot be executed with `t.Parallel()`, due to // TODO: this test cannot be executed with `t.Parallel()`, due to
// relying on goroutine numbers to ensure correct behaviour // relying on goroutine numbers to ensure correct behaviour
t.Run("connect goroutine exits after EOF", func(t *testing.T) { t.Run("connect goroutine exits after EOF", func(t *testing.T) {
var conn *net.UnixConn srv, err := NewPluginServer(nil)
listener, err := SetupConn(&conn) assert.NilError(t, err, "failed to setup server")
assert.NilError(t, err, "failed to setup listener")
t.Setenv(EnvKey, listener.Addr().String()) defer srv.Close()
t.Setenv(EnvKey, srv.Addr().String())
numGoroutines := runtime.NumGoroutine() numGoroutines := runtime.NumGoroutine()
ConnectAndWait(func() {}) ConnectAndWait(func() {})
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1) assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
pollConnNotNil(t, &conn) srv.Close()
conn.Close()
poll.WaitOn(t, func(t poll.LogT) poll.Result { poll.WaitOn(t, func(t poll.LogT) poll.Result {
if runtime.NumGoroutine() > numGoroutines+1 { if runtime.NumGoroutine() > numGoroutines+1 {
return poll.Continue("waiting for connect goroutine to exit") return poll.Continue("waiting for connect goroutine to exit")
@ -120,14 +187,3 @@ func TestConnectAndWait(t *testing.T) {
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
}) })
} }
func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
t.Helper()
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if *conn == nil {
return poll.Continue("waiting for conn to not be nil")
}
return poll.Success()
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
}

View File

@ -2,7 +2,6 @@ package main
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"os/exec" "os/exec"
"os/signal" "os/signal"
@ -222,11 +221,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
} }
// Establish the plugin socket, adding it to the environment under a well-known key if successful. // Establish the plugin socket, adding it to the environment under a well-known key if successful.
var conn *net.UnixConn srv, err := socket.NewPluginServer(nil)
listener, err := socket.SetupConn(&conn)
if err == nil { if err == nil {
envs = append(envs, socket.EnvKey+"="+listener.Addr().String()) envs = append(envs, socket.EnvKey+"="+srv.Addr().String())
defer listener.Close()
} }
plugincmd.Env = append(envs, plugincmd.Env...) plugincmd.Env = append(envs, plugincmd.Env...)
@ -247,12 +244,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
// receive signals due to sharing a pgid with the parent CLI // receive signals due to sharing a pgid with the parent CLI
continue continue
} }
if conn != nil {
if err := conn.Close(); err != nil { srv.Close()
_, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err)
}
conn = nil
}
retries++ retries++
if retries >= exitLimit { if retries >= exitLimit {
_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries) _, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)