mirror of https://github.com/docker/cli.git
Merge pull request #4905 from cpuguy83/plugin_notify_conn_cleanup
plugin: closer-based plugin notification socket
This commit is contained in:
commit
318911b404
|
@ -7,24 +7,104 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// EnvKey represents the well-known environment variable used to pass the plugin being
|
||||
// executed the socket name it should listen on to coordinate with the host CLI.
|
||||
const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"
|
||||
|
||||
// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections
|
||||
// and update the conn pointer, and returns the listener for the socket (which the caller
|
||||
// is responsible for closing when it's no longer needed).
|
||||
func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) {
|
||||
listener, err := listen("docker_cli_" + randomID())
|
||||
// NewPluginServer creates a plugin server that listens on a new Unix domain socket.
|
||||
// `h` is called for each new connection to the socket in a goroutine.
|
||||
func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
|
||||
l, err := listen("docker_cli_" + randomID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accept(listener, conn)
|
||||
if h == nil {
|
||||
h = func(net.Conn) {}
|
||||
}
|
||||
|
||||
return listener, nil
|
||||
pl := &PluginServer{
|
||||
l: l,
|
||||
h: h,
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer pl.Close()
|
||||
for {
|
||||
err := pl.accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return pl, nil
|
||||
}
|
||||
|
||||
type PluginServer struct {
|
||||
mu sync.Mutex
|
||||
conns []net.Conn
|
||||
l *net.UnixListener
|
||||
h func(net.Conn)
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (pl *PluginServer) accept() error {
|
||||
conn, err := pl.l.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pl.mu.Lock()
|
||||
defer pl.mu.Unlock()
|
||||
|
||||
if pl.closed {
|
||||
// handle potential race condition between Close and Accept
|
||||
conn.Close()
|
||||
return errors.New("plugin server is closed")
|
||||
}
|
||||
|
||||
pl.conns = append(pl.conns, conn)
|
||||
|
||||
go pl.h(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pl *PluginServer) Addr() net.Addr {
|
||||
return pl.l.Addr()
|
||||
}
|
||||
|
||||
// Close ensures that the server is no longer accepting new connections and closes all existing connections.
|
||||
// Existing connections will receive [io.EOF].
|
||||
func (pl *PluginServer) Close() error {
|
||||
// Remove the listener socket, if it exists on the filesystem.
|
||||
unlink(pl.l)
|
||||
|
||||
// Close connections first to ensure the connections get io.EOF instead of a connection reset.
|
||||
pl.closeAllConns()
|
||||
|
||||
// Try to ensure that any active connections have a chance to receive io.EOF
|
||||
runtime.Gosched()
|
||||
|
||||
return pl.l.Close()
|
||||
}
|
||||
|
||||
func (pl *PluginServer) closeAllConns() {
|
||||
pl.mu.Lock()
|
||||
defer pl.mu.Unlock()
|
||||
|
||||
// Prevent new connections from being accepted
|
||||
pl.closed = true
|
||||
|
||||
for _, conn := range pl.conns {
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
pl.conns = nil
|
||||
}
|
||||
|
||||
func randomID() string {
|
||||
|
@ -35,18 +115,6 @@ func randomID() string {
|
|||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
func accept(listener *net.UnixListener, conn **net.UnixConn) {
|
||||
go func() {
|
||||
for {
|
||||
// ignore error here, if we failed to accept a connection,
|
||||
// conn is nil and we fallback to previous behavior
|
||||
*conn, _ = listener.AcceptUnix()
|
||||
// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
|
||||
onAccept(*conn, listener)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// ConnectAndWait connects to the socket passed via well-known env var,
|
||||
// if present, and attempts to read from it until it receives an EOF, at which
|
||||
// point cb is called.
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
//go:build windows || linux
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func listen(socketname string) (*net.UnixListener, error) {
|
||||
// Create an abstract socket -- this socket can be opened by name, but is
|
||||
// not present in the filesystem.
|
||||
return net.ListenUnix("unix", &net.UnixAddr{
|
||||
Name: "@" + socketname,
|
||||
Net: "unix",
|
||||
})
|
||||
}
|
||||
|
||||
func unlink(listener *net.UnixListener) {
|
||||
// Do nothing; the socket is not present in the filesystem.
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func listen(socketname string) (*net.UnixListener, error) {
|
||||
return net.ListenUnix("unix", &net.UnixAddr{
|
||||
Name: filepath.Join(os.TempDir(), socketname),
|
||||
Net: "unix",
|
||||
})
|
||||
}
|
||||
|
||||
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
|
||||
syscall.Unlink(listener.Addr().String())
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
//go:build !windows && !linux
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func listen(socketname string) (*net.UnixListener, error) {
|
||||
// Because abstract sockets are unavailable, we create a socket in the
|
||||
// system temporary directory instead.
|
||||
return net.ListenUnix("unix", &net.UnixAddr{
|
||||
Name: filepath.Join(os.TempDir(), socketname),
|
||||
Net: "unix",
|
||||
})
|
||||
}
|
||||
|
||||
func unlink(listener *net.UnixListener) {
|
||||
// unlink(2) is best effort here; if it fails, we may 'leak' a socket
|
||||
// into the filesystem, but this is unlikely and overall harmless.
|
||||
_ = syscall.Unlink(listener.Addr().String())
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
//go:build !darwin && !openbsd
|
||||
|
||||
package socket
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
func listen(socketname string) (*net.UnixListener, error) {
|
||||
return net.ListenUnix("unix", &net.UnixAddr{
|
||||
Name: "@" + socketname,
|
||||
Net: "unix",
|
||||
})
|
||||
}
|
||||
|
||||
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
|
||||
// do nothing
|
||||
// while on darwin and OpenBSD we would unlink here;
|
||||
// on non-darwin the socket is abstract and not present on the filesystem
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func listen(socketname string) (*net.UnixListener, error) {
|
||||
return net.ListenUnix("unix", &net.UnixAddr{
|
||||
Name: filepath.Join(os.TempDir(), socketname),
|
||||
Net: "unix",
|
||||
})
|
||||
}
|
||||
|
||||
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
|
||||
syscall.Unlink(listener.Addr().String())
|
||||
}
|
|
@ -1,11 +1,14 @@
|
|||
package socket
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -13,54 +16,110 @@ import (
|
|||
"gotest.tools/v3/poll"
|
||||
)
|
||||
|
||||
func TestSetupConn(t *testing.T) {
|
||||
t.Run("updates conn when connected", func(t *testing.T) {
|
||||
var conn *net.UnixConn
|
||||
listener, err := SetupConn(&conn)
|
||||
func TestPluginServer(t *testing.T) {
|
||||
t.Run("connection closes with EOF when server closes", func(t *testing.T) {
|
||||
called := make(chan struct{})
|
||||
srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
|
||||
assert.NilError(t, err)
|
||||
assert.Check(t, listener != nil, "returned nil listener but no error")
|
||||
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
|
||||
assert.NilError(t, err, "failed to resolve listener address")
|
||||
assert.Assert(t, srv != nil, "returned nil server but no error")
|
||||
|
||||
_, err = net.DialUnix("unix", nil, addr)
|
||||
assert.NilError(t, err, "failed to dial returned listener")
|
||||
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
|
||||
assert.NilError(t, err, "failed to resolve server address")
|
||||
|
||||
pollConnNotNil(t, &conn)
|
||||
conn, err := net.DialUnix("unix", nil, addr)
|
||||
assert.NilError(t, err, "failed to dial returned server")
|
||||
defer conn.Close()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := conn.Read(make([]byte, 1))
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-called:
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
t.Fatal("handler not called")
|
||||
}
|
||||
|
||||
srv.Close()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if !errors.Is(err, io.EOF) {
|
||||
t.Fatalf("exepcted EOF error, got: %v", err)
|
||||
}
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allows reconnects", func(t *testing.T) {
|
||||
var conn *net.UnixConn
|
||||
listener, err := SetupConn(&conn)
|
||||
var calls int32
|
||||
h := func(_ net.Conn) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
}
|
||||
|
||||
srv, err := NewPluginServer(h)
|
||||
assert.NilError(t, err)
|
||||
assert.Check(t, listener != nil, "returned nil listener but no error")
|
||||
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
|
||||
assert.NilError(t, err, "failed to resolve listener address")
|
||||
defer srv.Close()
|
||||
|
||||
assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
|
||||
|
||||
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
|
||||
assert.NilError(t, err, "failed to resolve server address")
|
||||
|
||||
waitForCalls := func(n int) {
|
||||
poll.WaitOn(t, func(t poll.LogT) poll.Result {
|
||||
if atomic.LoadInt32(&calls) == int32(n) {
|
||||
return poll.Success()
|
||||
}
|
||||
return poll.Continue("waiting for handler to be called")
|
||||
})
|
||||
}
|
||||
|
||||
otherConn, err := net.DialUnix("unix", nil, addr)
|
||||
assert.NilError(t, err, "failed to dial returned listener")
|
||||
|
||||
assert.NilError(t, err, "failed to dial returned server")
|
||||
otherConn.Close()
|
||||
waitForCalls(1)
|
||||
|
||||
_, err = net.DialUnix("unix", nil, addr)
|
||||
assert.NilError(t, err, "failed to redial listener")
|
||||
conn, err := net.DialUnix("unix", nil, addr)
|
||||
assert.NilError(t, err, "failed to redial server")
|
||||
defer conn.Close()
|
||||
waitForCalls(2)
|
||||
|
||||
// and again but don't close the existing connection
|
||||
conn2, err := net.DialUnix("unix", nil, addr)
|
||||
assert.NilError(t, err, "failed to redial server")
|
||||
defer conn2.Close()
|
||||
waitForCalls(3)
|
||||
|
||||
srv.Close()
|
||||
|
||||
// now make sure we get EOF on the existing connections
|
||||
buf := make([]byte, 1)
|
||||
_, err = conn.Read(buf)
|
||||
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
|
||||
|
||||
_, err = conn2.Read(buf)
|
||||
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
|
||||
})
|
||||
|
||||
t.Run("does not leak sockets to local directory", func(t *testing.T) {
|
||||
var conn *net.UnixConn
|
||||
listener, err := SetupConn(&conn)
|
||||
srv, err := NewPluginServer(nil)
|
||||
assert.NilError(t, err)
|
||||
assert.Check(t, listener != nil, "returned nil listener but no error")
|
||||
checkDirNoPluginSocket(t)
|
||||
assert.Check(t, srv != nil, "returned nil server but no error")
|
||||
checkDirNoNewPluginServer(t)
|
||||
|
||||
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
|
||||
assert.NilError(t, err, "failed to resolve server address")
|
||||
|
||||
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
|
||||
assert.NilError(t, err, "failed to resolve listener address")
|
||||
_, err = net.DialUnix("unix", nil, addr)
|
||||
assert.NilError(t, err, "failed to dial returned listener")
|
||||
checkDirNoPluginSocket(t)
|
||||
assert.NilError(t, err, "failed to dial returned server")
|
||||
checkDirNoNewPluginServer(t)
|
||||
})
|
||||
}
|
||||
|
||||
func checkDirNoPluginSocket(t *testing.T) {
|
||||
func checkDirNoNewPluginServer(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
files, err := os.ReadDir(".")
|
||||
|
@ -78,18 +137,24 @@ func checkDirNoPluginSocket(t *testing.T) {
|
|||
|
||||
func TestConnectAndWait(t *testing.T) {
|
||||
t.Run("calls cancel func on EOF", func(t *testing.T) {
|
||||
var conn *net.UnixConn
|
||||
listener, err := SetupConn(&conn)
|
||||
assert.NilError(t, err, "failed to setup listener")
|
||||
srv, err := NewPluginServer(nil)
|
||||
assert.NilError(t, err, "failed to setup server")
|
||||
defer srv.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
t.Setenv(EnvKey, listener.Addr().String())
|
||||
t.Setenv(EnvKey, srv.Addr().String())
|
||||
cancelFunc := func() {
|
||||
done <- struct{}{}
|
||||
}
|
||||
ConnectAndWait(cancelFunc)
|
||||
pollConnNotNil(t, &conn)
|
||||
conn.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
t.Fatal("unexpectedly done")
|
||||
default:
|
||||
}
|
||||
|
||||
srv.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
|
@ -101,17 +166,19 @@ func TestConnectAndWait(t *testing.T) {
|
|||
// TODO: this test cannot be executed with `t.Parallel()`, due to
|
||||
// relying on goroutine numbers to ensure correct behaviour
|
||||
t.Run("connect goroutine exits after EOF", func(t *testing.T) {
|
||||
var conn *net.UnixConn
|
||||
listener, err := SetupConn(&conn)
|
||||
assert.NilError(t, err, "failed to setup listener")
|
||||
t.Setenv(EnvKey, listener.Addr().String())
|
||||
srv, err := NewPluginServer(nil)
|
||||
assert.NilError(t, err, "failed to setup server")
|
||||
|
||||
defer srv.Close()
|
||||
|
||||
t.Setenv(EnvKey, srv.Addr().String())
|
||||
numGoroutines := runtime.NumGoroutine()
|
||||
|
||||
ConnectAndWait(func() {})
|
||||
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
|
||||
|
||||
pollConnNotNil(t, &conn)
|
||||
conn.Close()
|
||||
srv.Close()
|
||||
|
||||
poll.WaitOn(t, func(t poll.LogT) poll.Result {
|
||||
if runtime.NumGoroutine() > numGoroutines+1 {
|
||||
return poll.Continue("waiting for connect goroutine to exit")
|
||||
|
@ -120,14 +187,3 @@ func TestConnectAndWait(t *testing.T) {
|
|||
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
|
||||
})
|
||||
}
|
||||
|
||||
func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
|
||||
t.Helper()
|
||||
|
||||
poll.WaitOn(t, func(t poll.LogT) poll.Result {
|
||||
if *conn == nil {
|
||||
return poll.Continue("waiting for conn to not be nil")
|
||||
}
|
||||
return poll.Success()
|
||||
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
|
||||
}
|
||||
|
|
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
|
@ -222,11 +221,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
|
|||
}
|
||||
|
||||
// Establish the plugin socket, adding it to the environment under a well-known key if successful.
|
||||
var conn *net.UnixConn
|
||||
listener, err := socket.SetupConn(&conn)
|
||||
srv, err := socket.NewPluginServer(nil)
|
||||
if err == nil {
|
||||
envs = append(envs, socket.EnvKey+"="+listener.Addr().String())
|
||||
defer listener.Close()
|
||||
envs = append(envs, socket.EnvKey+"="+srv.Addr().String())
|
||||
}
|
||||
|
||||
plugincmd.Env = append(envs, plugincmd.Env...)
|
||||
|
@ -247,12 +244,9 @@ func tryPluginRun(dockerCli command.Cli, cmd *cobra.Command, subcommand string,
|
|||
// receive signals due to sharing a pgid with the parent CLI
|
||||
continue
|
||||
}
|
||||
if conn != nil {
|
||||
if err := conn.Close(); err != nil {
|
||||
_, _ = fmt.Fprintf(dockerCli.Err(), "failed to signal plugin to close: %v\n", err)
|
||||
}
|
||||
conn = nil
|
||||
}
|
||||
|
||||
srv.Close()
|
||||
|
||||
retries++
|
||||
if retries >= exitLimit {
|
||||
_, _ = fmt.Fprintf(dockerCli.Err(), "got %d SIGTERM/SIGINTs, forcefully exiting\n", retries)
|
||||
|
|
Loading…
Reference in New Issue