2024-01-12 13:17:03 -05:00
|
|
|
package socket
|
|
|
|
|
|
|
|
import (
|
2024-02-12 07:26:54 -05:00
|
|
|
"crypto/rand"
|
|
|
|
"encoding/hex"
|
2024-01-12 13:17:03 -05:00
|
|
|
"errors"
|
|
|
|
"io"
|
|
|
|
"net"
|
|
|
|
"os"
|
2024-02-29 16:33:03 -05:00
|
|
|
"runtime"
|
|
|
|
"sync"
|
2024-01-12 13:17:03 -05:00
|
|
|
)
|
|
|
|
|
2024-03-22 00:34:39 -04:00
|
|
|
const (
	// EnvKey is the name of the well-known environment variable used to hand
	// a plugin the address of the Unix socket it should connect to in order
	// to coordinate with the host CLI.
	EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"
)
|
|
|
|
|
2024-03-22 00:34:39 -04:00
|
|
|
// NewPluginServer creates a plugin server that listens on a new Unix domain
|
|
|
|
// socket. h is called for each new connection to the socket in a goroutine.
|
2024-02-29 16:33:03 -05:00
|
|
|
func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
|
2024-03-22 10:55:18 -04:00
|
|
|
// Listen on a Unix socket, with the address being platform-dependent.
|
|
|
|
// When a non-abstract address is used, Go will unlink(2) the socket
|
|
|
|
// for us once the listener is closed, as documented in
|
|
|
|
// [net.UnixListener.SetUnlinkOnClose].
|
|
|
|
l, err := net.ListenUnix("unix", &net.UnixAddr{
|
|
|
|
Name: socketName("docker_cli_" + randomID()),
|
|
|
|
Net: "unix",
|
|
|
|
})
|
2024-01-12 13:17:03 -05:00
|
|
|
if err != nil {
|
2024-01-15 09:29:48 -05:00
|
|
|
return nil, err
|
2024-01-12 13:17:03 -05:00
|
|
|
}
|
|
|
|
|
2024-02-29 16:33:03 -05:00
|
|
|
if h == nil {
|
|
|
|
h = func(net.Conn) {}
|
|
|
|
}
|
|
|
|
|
|
|
|
pl := &PluginServer{
|
|
|
|
l: l,
|
|
|
|
h: h,
|
|
|
|
}
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
defer pl.Close()
|
|
|
|
for {
|
|
|
|
err := pl.accept()
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
|
|
|
|
return pl, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
// PluginServer accepts connections on a Unix domain socket, keeps track of
// them, and hands each one to a handler. Create one with NewPluginServer.
type PluginServer struct {
	// mu guards conns and closed.
	mu sync.Mutex

	// conns holds every accepted connection so they can all be closed later.
	conns []net.Conn

	// l is the listener new connections are accepted from.
	l *net.UnixListener

	// h is invoked in its own goroutine for each accepted connection.
	h func(net.Conn)

	// closed is set once the server starts shutting down; connections
	// accepted after that point are closed instead of tracked.
	closed bool
}
|
|
|
|
|
|
|
|
func (pl *PluginServer) accept() error {
|
|
|
|
conn, err := pl.l.Accept()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
pl.mu.Lock()
|
|
|
|
defer pl.mu.Unlock()
|
|
|
|
|
|
|
|
if pl.closed {
|
2024-03-22 00:34:39 -04:00
|
|
|
// Handle potential race between Close and accept.
|
2024-02-29 16:33:03 -05:00
|
|
|
conn.Close()
|
|
|
|
return errors.New("plugin server is closed")
|
|
|
|
}
|
2024-01-12 13:17:03 -05:00
|
|
|
|
2024-02-29 16:33:03 -05:00
|
|
|
pl.conns = append(pl.conns, conn)
|
|
|
|
|
|
|
|
go pl.h(conn)
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2024-03-22 00:34:39 -04:00
|
|
|
// Addr returns the [net.Addr] of the underlying [net.Listener].
|
2024-02-29 16:33:03 -05:00
|
|
|
func (pl *PluginServer) Addr() net.Addr {
|
|
|
|
return pl.l.Addr()
|
|
|
|
}
|
|
|
|
|
2024-03-22 00:34:39 -04:00
|
|
|
// Close ensures that the server is no longer accepting new connections and
|
|
|
|
// closes all existing connections. Existing connections will receive [io.EOF].
|
|
|
|
//
|
|
|
|
// The error value is that of the underlying [net.Listner.Close] call.
|
2024-02-29 16:33:03 -05:00
|
|
|
func (pl *PluginServer) Close() error {
|
2024-03-22 00:34:39 -04:00
|
|
|
// Close connections first to ensure the connections get io.EOF instead
|
|
|
|
// of a connection reset.
|
2024-02-29 16:33:03 -05:00
|
|
|
pl.closeAllConns()
|
|
|
|
|
2024-03-22 00:34:39 -04:00
|
|
|
// Try to ensure that any active connections have a chance to receive
|
|
|
|
// io.EOF.
|
2024-02-29 16:33:03 -05:00
|
|
|
runtime.Gosched()
|
|
|
|
|
|
|
|
return pl.l.Close()
|
|
|
|
}
|
|
|
|
|
|
|
|
func (pl *PluginServer) closeAllConns() {
|
|
|
|
pl.mu.Lock()
|
|
|
|
defer pl.mu.Unlock()
|
|
|
|
|
2024-03-22 00:34:39 -04:00
|
|
|
// Prevent new connections from being accepted.
|
2024-02-29 16:33:03 -05:00
|
|
|
pl.closed = true
|
|
|
|
|
|
|
|
for _, conn := range pl.conns {
|
|
|
|
conn.Close()
|
|
|
|
}
|
|
|
|
|
|
|
|
pl.conns = nil
|
2024-01-12 13:17:03 -05:00
|
|
|
}
|
|
|
|
|
2024-02-12 07:26:54 -05:00
|
|
|
// randomID returns a 32-character lowercase hex string encoding 16 bytes read
// from crypto/rand. It panics if the randomness source reports an error.
func randomID() string {
	var buf [16]byte
	if _, err := rand.Read(buf[:]); err != nil {
		panic(err) // This shouldn't happen
	}
	return hex.EncodeToString(buf[:])
}
|
|
|
|
|
2024-01-12 13:17:03 -05:00
|
|
|
// ConnectAndWait connects to the socket passed via well-known env var,
|
|
|
|
// if present, and attempts to read from it until it receives an EOF, at which
|
|
|
|
// point cb is called.
|
|
|
|
func ConnectAndWait(cb func()) {
|
|
|
|
socketAddr, ok := os.LookupEnv(EnvKey)
|
|
|
|
if !ok {
|
|
|
|
// if a plugin compiled against a more recent version of docker/cli
|
|
|
|
// is executed by an older CLI binary, ignore missing environment
|
|
|
|
// variable and behave as usual
|
|
|
|
return
|
|
|
|
}
|
|
|
|
addr, err := net.ResolveUnixAddr("unix", socketAddr)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
conn, err := net.DialUnix("unix", nil, addr)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
b := make([]byte, 1)
|
|
|
|
for {
|
|
|
|
_, err := conn.Read(b)
|
|
|
|
if errors.Is(err, io.EOF) {
|
|
|
|
cb()
|
2024-01-19 20:06:43 -05:00
|
|
|
return
|
2024-01-12 13:17:03 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}()
|
|
|
|
}
|