Merge pull request #4905 from cpuguy83/plugin_notify_conn_cleanup

plugin: closer-based plugin notification socket
This commit is contained in:
Bjorn Neergaard 2024-03-21 21:52:30 -06:00 committed by GitHub
commit 318911b404
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 244 additions and 139 deletions

View File

@ -7,24 +7,104 @@ import (
"io" "io"
"net" "net"
"os" "os"
"runtime"
"sync"
) )
// EnvKey represents the well-known environment variable used to pass the plugin being // EnvKey represents the well-known environment variable used to pass the plugin being
// executed the socket name it should listen on to coordinate with the host CLI. // executed the socket name it should listen on to coordinate with the host CLI.
const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET" const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"
// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections // NewPluginServer creates a plugin server that listens on a new Unix domain socket.
// and update the conn pointer, and returns the listener for the socket (which the caller // `h` is called for each new connection to the socket in a goroutine.
// is responsible for closing when it's no longer needed). func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) { l, err := listen("docker_cli_" + randomID())
listener, err := listen("docker_cli_" + randomID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
accept(listener, conn) if h == nil {
h = func(net.Conn) {}
}
return listener, nil pl := &PluginServer{
l: l,
h: h,
}
go func() {
defer pl.Close()
for {
err := pl.accept()
if err != nil {
return
}
}
}()
return pl, nil
}
type PluginServer struct {
mu sync.Mutex
conns []net.Conn
l *net.UnixListener
h func(net.Conn)
closed bool
}
func (pl *PluginServer) accept() error {
conn, err := pl.l.Accept()
if err != nil {
return err
}
pl.mu.Lock()
defer pl.mu.Unlock()
if pl.closed {
// handle potential race condition between Close and Accept
conn.Close()
return errors.New("plugin server is closed")
}
pl.conns = append(pl.conns, conn)
go pl.h(conn)
return nil
}
func (pl *PluginServer) Addr() net.Addr {
return pl.l.Addr()
}
// Close ensures that the server is no longer accepting new connections and closes all existing connections.
// Existing connections will receive [io.EOF].
func (pl *PluginServer) Close() error {
// Remove the listener socket, if it exists on the filesystem.
unlink(pl.l)
// Close connections first to ensure the connections get io.EOF instead of a connection reset.
pl.closeAllConns()
// Try to ensure that any active connections have a chance to receive io.EOF
runtime.Gosched()
return pl.l.Close()
}
func (pl *PluginServer) closeAllConns() {
pl.mu.Lock()
defer pl.mu.Unlock()
// Prevent new connections from being accepted
pl.closed = true
for _, conn := range pl.conns {
conn.Close()
}
pl.conns = nil
} }
func randomID() string { func randomID() string {
@ -35,18 +115,6 @@ func randomID() string {
return hex.EncodeToString(b) return hex.EncodeToString(b)
} }
func accept(listener *net.UnixListener, conn **net.UnixConn) {
go func() {
for {
// ignore error here, if we failed to accept a connection,
// conn is nil and we fallback to previous behavior
*conn, _ = listener.AcceptUnix()
// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
onAccept(*conn, listener)
}
}()
}
// ConnectAndWait connects to the socket passed via well-known env var, // ConnectAndWait connects to the socket passed via well-known env var,
// if present, and attempts to read from it until it receives an EOF, at which // if present, and attempts to read from it until it receives an EOF, at which
// point cb is called. // point cb is called.

View File

@ -0,0 +1,20 @@
//go:build windows || linux
package socket
import (
"net"
)
func listen(socketname string) (*net.UnixListener, error) {
// Create an abstract socket -- this socket can be opened by name, but is
// not present in the filesystem.
return net.ListenUnix("unix", &net.UnixAddr{
Name: "@" + socketname,
Net: "unix",
})
}
func unlink(listener *net.UnixListener) {
// Do nothing; the socket is not present in the filesystem.
}

View File

@ -1,19 +0,0 @@
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
func listen(socketname string) (*net.UnixListener, error) {
return net.ListenUnix("unix", &net.UnixAddr{
Name: filepath.Join(os.TempDir(), socketname),
Net: "unix",
})
}
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
syscall.Unlink(listener.Addr().String())
}

View File

@ -0,0 +1,25 @@
//go:build !windows && !linux
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
func listen(socketname string) (*net.UnixListener, error) {
// Because abstract sockets are unavailable, we create a socket in the
// system temporary directory instead.
return net.ListenUnix("unix", &net.UnixAddr{
Name: filepath.Join(os.TempDir(), socketname),
Net: "unix",
})
}
func unlink(listener *net.UnixListener) {
// unlink(2) is best effort here; if it fails, we may 'leak' a socket
// into the filesystem, but this is unlikely and overall harmless.
_ = syscall.Unlink(listener.Addr().String())
}

View File

@ -1,20 +0,0 @@
//go:build !darwin && !openbsd
package socket
import (
"net"
)
func listen(socketname string) (*net.UnixListener, error) {
return net.ListenUnix("unix", &net.UnixAddr{
Name: "@" + socketname,
Net: "unix",
})
}
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
// do nothing
// while on darwin and OpenBSD we would unlink here;
// on non-darwin the socket is abstract and not present on the filesystem
}

View File

@ -1,19 +0,0 @@
package socket
import (
"net"
"os"
"path/filepath"
"syscall"
)
func listen(socketname string) (*net.UnixListener, error) {
return net.ListenUnix("unix", &net.UnixAddr{
Name: filepath.Join(os.TempDir(), socketname),
Net: "unix",
})
}
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
syscall.Unlink(listener.Addr().String())
}

View File

@ -1,11 +1,14 @@
package socket package socket
import ( import (
"errors"
"io"
"io/fs" "io/fs"
"net" "net"
"os" "os"
"runtime" "runtime"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -13,54 +16,110 @@ import (
"gotest.tools/v3/poll" "gotest.tools/v3/poll"
) )
func TestSetupConn(t *testing.T) { func TestPluginServer(t *testing.T) {
t.Run("updates conn when connected", func(t *testing.T) { t.Run("connection closes with EOF when server closes", func(t *testing.T) {
var conn *net.UnixConn called := make(chan struct{})
listener, err := SetupConn(&conn) srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
assert.NilError(t, err) assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error") assert.Assert(t, srv != nil, "returned nil server but no error")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
_, err = net.DialUnix("unix", nil, addr) addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to dial returned listener") assert.NilError(t, err, "failed to resolve server address")
pollConnNotNil(t, &conn) conn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned server")
defer conn.Close()
done := make(chan error, 1)
go func() {
_, err := conn.Read(make([]byte, 1))
done <- err
}()
select {
case <-called:
case <-time.After(10 * time.Millisecond):
t.Fatal("handler not called")
}
srv.Close()
select {
case err := <-done:
if !errors.Is(err, io.EOF) {
t.Fatalf("exepcted EOF error, got: %v", err)
}
case <-time.After(10 * time.Millisecond):
}
}) })
t.Run("allows reconnects", func(t *testing.T) { t.Run("allows reconnects", func(t *testing.T) {
var conn *net.UnixConn var calls int32
listener, err := SetupConn(&conn) h := func(_ net.Conn) {
atomic.AddInt32(&calls, 1)
}
srv, err := NewPluginServer(h)
assert.NilError(t, err) assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error") defer srv.Close()
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address") assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
waitForCalls := func(n int) {
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if atomic.LoadInt32(&calls) == int32(n) {
return poll.Success()
}
return poll.Continue("waiting for handler to be called")
})
}
otherConn, err := net.DialUnix("unix", nil, addr) otherConn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener") assert.NilError(t, err, "failed to dial returned server")
otherConn.Close() otherConn.Close()
waitForCalls(1)
_, err = net.DialUnix("unix", nil, addr) conn, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial listener") assert.NilError(t, err, "failed to redial server")
defer conn.Close()
waitForCalls(2)
// and again but don't close the existing connection
conn2, err := net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to redial server")
defer conn2.Close()
waitForCalls(3)
srv.Close()
// now make sure we get EOF on the existing connections
buf := make([]byte, 1)
_, err = conn.Read(buf)
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
_, err = conn2.Read(buf)
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
}) })
t.Run("does not leak sockets to local directory", func(t *testing.T) { t.Run("does not leak sockets to local directory", func(t *testing.T) {
var conn *net.UnixConn srv, err := NewPluginServer(nil)
listener, err := SetupConn(&conn)
assert.NilError(t, err) assert.NilError(t, err)
assert.Check(t, listener != nil, "returned nil listener but no error") assert.Check(t, srv != nil, "returned nil server but no error")
checkDirNoPluginSocket(t) checkDirNoNewPluginServer(t)
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
assert.NilError(t, err, "failed to resolve server address")
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
assert.NilError(t, err, "failed to resolve listener address")
_, err = net.DialUnix("unix", nil, addr) _, err = net.DialUnix("unix", nil, addr)
assert.NilError(t, err, "failed to dial returned listener") assert.NilError(t, err, "failed to dial returned server")
checkDirNoPluginSocket(t) checkDirNoNewPluginServer(t)
}) })
} }
func checkDirNoPluginSocket(t *testing.T) { func checkDirNoNewPluginServer(t *testing.T) {
t.Helper() t.Helper()
files, err := os.ReadDir(".") files, err := os.ReadDir(".")
@ -78,18 +137,24 @@ func checkDirNoPluginSocket(t *testing.T) {
func TestConnectAndWait(t *testing.T) { func TestConnectAndWait(t *testing.T) {
t.Run("calls cancel func on EOF", func(t *testing.T) { t.Run("calls cancel func on EOF", func(t *testing.T) {
var conn *net.UnixConn srv, err := NewPluginServer(nil)
listener, err := SetupConn(&conn) assert.NilError(t, err, "failed to setup server")
assert.NilError(t, err, "failed to setup listener") defer srv.Close()
done := make(chan struct{}) done := make(chan struct{})
t.Setenv(EnvKey, listener.Addr().String()) t.Setenv(EnvKey, srv.Addr().String())
cancelFunc := func() { cancelFunc := func() {
done <- struct{}{} done <- struct{}{}
} }
ConnectAndWait(cancelFunc) ConnectAndWait(cancelFunc)
pollConnNotNil(t, &conn)
conn.Close() select {
case <-done:
t.Fatal("unexpectedly done")
default:
}
srv.Close()
select { select {
case <-done: case <-done:
@ -101,17 +166,19 @@ func TestConnectAndWait(t *testing.T) {
// TODO: this test cannot be executed with `t.Parallel()`, due to // TODO: this test cannot be executed with `t.Parallel()`, due to
// relying on goroutine numbers to ensure correct behaviour // relying on goroutine numbers to ensure correct behaviour
t.Run("connect goroutine exits after EOF", func(t *testing.T) { t.Run("connect goroutine exits after EOF", func(t *testing.T) {
var conn *net.UnixConn srv, err := NewPluginServer(nil)
listener, err := SetupConn(&conn) assert.NilError(t, err, "failed to setup server")
assert.NilError(t, err, "failed to setup listener")
t.Setenv(EnvKey, listener.Addr().String()) defer srv.Close()
t.Setenv(EnvKey, srv.Addr().String())
numGoroutines := runtime.NumGoroutine() numGoroutines := runtime.NumGoroutine()
ConnectAndWait(func() {}) ConnectAndWait(func() {})
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1) assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
pollConnNotNil(t, &conn) srv.Close()
conn.Close()
poll.WaitOn(t, func(t poll.LogT) poll.Result { poll.WaitOn(t, func(t poll.LogT) poll.Result {
if runtime.NumGoroutine() > numGoroutines+1 { if runtime.NumGoroutine() > numGoroutines+1 {
return poll.Continue("waiting for connect goroutine to exit") return poll.Continue("waiting for connect goroutine to exit")
@ -120,14 +187,3 @@ func TestConnectAndWait(t *testing.T) {
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond)) }, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
}) })
} }
func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
t.Helper()
poll.WaitOn(t, func(t poll.LogT) poll.Result {
if *conn == nil {
return poll.Continue("waiting for conn to not be nil")
}
return poll.Success()
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
}

View File

@ -2,7 +2,6 @@ package main
import ( import (
"fmt" "fmt"
"net"
"os" "os"
"os/exec" "os/exec"
"os/signal" "os/signal"
@ -222,11 +221,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
} }
// Establish the plugin socket, adding it to the environment under a well-known key if successful. // Establish the plugin socket, adding it to the environment under a well-known key if successful.
var conn *net.UnixConn srv, err := socket.NewPluginServer(nil)
listener, err := socket.SetupConn(&conn)
if err == nil { if err == nil {
envs = append(envs, socket.EnvKey+"="+listener.Addr().String()) envs = append(envs, socket.EnvKey+"="+srv.Addr().String())
defer listener.Close()
} }
plugincmd.Env = append(envs, plugincmd.Env...) plugincmd.Env = append(envs, plugincmd.Env...)
@ -247,12 +244,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
// receive signals due to sharing a pgid with the parent CLI // receive signals due to sharing a pgid with the parent CLI
continue continue
} }
if conn != nil {
if err := conn.Close(); err != nil { srv.Close()
_, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err)
}
conn = nil
}
retries++ retries++
if retries >= exitLimit { if retries >= exitLimit {
_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries) _, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)