mirror of https://github.com/docker/cli.git
190 lines
4.8 KiB
Go
190 lines
4.8 KiB
Go
package socket
|
|
|
|
import (
|
|
"errors"
|
|
"io"
|
|
"io/fs"
|
|
"net"
|
|
"os"
|
|
"runtime"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"gotest.tools/v3/assert"
|
|
"gotest.tools/v3/poll"
|
|
)
|
|
|
|
func TestPluginServer(t *testing.T) {
|
|
t.Run("connection closes with EOF when server closes", func(t *testing.T) {
|
|
called := make(chan struct{})
|
|
srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
|
|
assert.NilError(t, err)
|
|
assert.Assert(t, srv != nil, "returned nil server but no error")
|
|
|
|
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
|
|
assert.NilError(t, err, "failed to resolve server address")
|
|
|
|
conn, err := net.DialUnix("unix", nil, addr)
|
|
assert.NilError(t, err, "failed to dial returned server")
|
|
defer conn.Close()
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, err := conn.Read(make([]byte, 1))
|
|
done <- err
|
|
}()
|
|
|
|
select {
|
|
case <-called:
|
|
case <-time.After(10 * time.Millisecond):
|
|
t.Fatal("handler not called")
|
|
}
|
|
|
|
srv.Close()
|
|
|
|
select {
|
|
case err := <-done:
|
|
if !errors.Is(err, io.EOF) {
|
|
t.Fatalf("exepcted EOF error, got: %v", err)
|
|
}
|
|
case <-time.After(10 * time.Millisecond):
|
|
}
|
|
})
|
|
|
|
t.Run("allows reconnects", func(t *testing.T) {
|
|
var calls int32
|
|
h := func(_ net.Conn) {
|
|
atomic.AddInt32(&calls, 1)
|
|
}
|
|
|
|
srv, err := NewPluginServer(h)
|
|
assert.NilError(t, err)
|
|
defer srv.Close()
|
|
|
|
assert.Check(t, srv.Addr() != nil, "returned nil addr but no error")
|
|
|
|
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
|
|
assert.NilError(t, err, "failed to resolve server address")
|
|
|
|
waitForCalls := func(n int) {
|
|
poll.WaitOn(t, func(t poll.LogT) poll.Result {
|
|
if atomic.LoadInt32(&calls) == int32(n) {
|
|
return poll.Success()
|
|
}
|
|
return poll.Continue("waiting for handler to be called")
|
|
})
|
|
}
|
|
|
|
otherConn, err := net.DialUnix("unix", nil, addr)
|
|
assert.NilError(t, err, "failed to dial returned server")
|
|
otherConn.Close()
|
|
waitForCalls(1)
|
|
|
|
conn, err := net.DialUnix("unix", nil, addr)
|
|
assert.NilError(t, err, "failed to redial server")
|
|
defer conn.Close()
|
|
waitForCalls(2)
|
|
|
|
// and again but don't close the existing connection
|
|
conn2, err := net.DialUnix("unix", nil, addr)
|
|
assert.NilError(t, err, "failed to redial server")
|
|
defer conn2.Close()
|
|
waitForCalls(3)
|
|
|
|
srv.Close()
|
|
|
|
// now make sure we get EOF on the existing connections
|
|
buf := make([]byte, 1)
|
|
_, err = conn.Read(buf)
|
|
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
|
|
|
|
_, err = conn2.Read(buf)
|
|
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
|
|
})
|
|
|
|
t.Run("does not leak sockets to local directory", func(t *testing.T) {
|
|
srv, err := NewPluginServer(nil)
|
|
assert.NilError(t, err)
|
|
assert.Check(t, srv != nil, "returned nil server but no error")
|
|
checkDirNoNewPluginServer(t)
|
|
|
|
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
|
|
assert.NilError(t, err, "failed to resolve server address")
|
|
|
|
_, err = net.DialUnix("unix", nil, addr)
|
|
assert.NilError(t, err, "failed to dial returned server")
|
|
checkDirNoNewPluginServer(t)
|
|
})
|
|
}
|
|
|
|
func checkDirNoNewPluginServer(t *testing.T) {
|
|
t.Helper()
|
|
|
|
files, err := os.ReadDir(".")
|
|
assert.NilError(t, err, "failed to list files in dir to check for leaked sockets")
|
|
|
|
for _, f := range files {
|
|
info, err := f.Info()
|
|
assert.NilError(t, err, "failed to check file info")
|
|
// check for a socket with `docker_cli_` in the name (from `SetupConn()`)
|
|
if strings.Contains(f.Name(), "docker_cli_") && info.Mode().Type() == fs.ModeSocket {
|
|
t.Fatal("found socket in a local directory")
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestConnectAndWait(t *testing.T) {
|
|
t.Run("calls cancel func on EOF", func(t *testing.T) {
|
|
srv, err := NewPluginServer(nil)
|
|
assert.NilError(t, err, "failed to setup server")
|
|
defer srv.Close()
|
|
|
|
done := make(chan struct{})
|
|
t.Setenv(EnvKey, srv.Addr().String())
|
|
cancelFunc := func() {
|
|
done <- struct{}{}
|
|
}
|
|
ConnectAndWait(cancelFunc)
|
|
|
|
select {
|
|
case <-done:
|
|
t.Fatal("unexpectedly done")
|
|
default:
|
|
}
|
|
|
|
srv.Close()
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(10 * time.Millisecond):
|
|
t.Fatal("cancel function not closed after 10ms")
|
|
}
|
|
})
|
|
|
|
// TODO: this test cannot be executed with `t.Parallel()`, due to
|
|
// relying on goroutine numbers to ensure correct behaviour
|
|
t.Run("connect goroutine exits after EOF", func(t *testing.T) {
|
|
srv, err := NewPluginServer(nil)
|
|
assert.NilError(t, err, "failed to setup server")
|
|
|
|
defer srv.Close()
|
|
|
|
t.Setenv(EnvKey, srv.Addr().String())
|
|
numGoroutines := runtime.NumGoroutine()
|
|
|
|
ConnectAndWait(func() {})
|
|
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
|
|
|
|
srv.Close()
|
|
|
|
poll.WaitOn(t, func(t poll.LogT) poll.Result {
|
|
if runtime.NumGoroutine() > numGoroutines+1 {
|
|
return poll.Continue("waiting for connect goroutine to exit")
|
|
}
|
|
return poll.Success()
|
|
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
|
|
})
|
|
}
|