// Mirror of https://github.com/docker/cli.git (Go, 130 lines, 3.6 KiB).

package socket
|
||
|
|
||
|
import (
|
||
|
"io/fs"
|
||
|
"net"
|
||
|
"os"
|
||
|
"runtime"
|
||
|
"testing"
|
||
|
"time"
|
||
|
|
||
|
"gotest.tools/v3/assert"
|
||
|
"gotest.tools/v3/poll"
|
||
|
)
|
||
|
|
||
|
func TestSetupConn(t *testing.T) {
|
||
|
t.Run("updates conn when connected", func(t *testing.T) {
|
||
|
var conn *net.UnixConn
|
||
|
listener, err := SetupConn(&conn)
|
||
|
assert.NilError(t, err)
|
||
|
assert.Check(t, listener != nil, "returned nil listener but no error")
|
||
|
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
|
||
|
assert.NilError(t, err, "failed to resolve listener address")
|
||
|
|
||
|
_, err = net.DialUnix("unix", nil, addr)
|
||
|
assert.NilError(t, err, "failed to dial returned listener")
|
||
|
|
||
|
pollConnNotNil(t, &conn)
|
||
|
})
|
||
|
|
||
|
t.Run("allows reconnects", func(t *testing.T) {
|
||
|
var conn *net.UnixConn
|
||
|
listener, err := SetupConn(&conn)
|
||
|
assert.NilError(t, err)
|
||
|
assert.Check(t, listener != nil, "returned nil listener but no error")
|
||
|
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
|
||
|
assert.NilError(t, err, "failed to resolve listener address")
|
||
|
|
||
|
otherConn, err := net.DialUnix("unix", nil, addr)
|
||
|
assert.NilError(t, err, "failed to dial returned listener")
|
||
|
|
||
|
otherConn.Close()
|
||
|
|
||
|
_, err = net.DialUnix("unix", nil, addr)
|
||
|
assert.NilError(t, err, "failed to redial listener")
|
||
|
})
|
||
|
|
||
|
t.Run("does not leak sockets to local directory", func(t *testing.T) {
|
||
|
var conn *net.UnixConn
|
||
|
listener, err := SetupConn(&conn)
|
||
|
assert.NilError(t, err)
|
||
|
assert.Check(t, listener != nil, "returned nil listener but no error")
|
||
|
checkDirClean(t)
|
||
|
|
||
|
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
|
||
|
assert.NilError(t, err, "failed to resolve listener address")
|
||
|
_, err = net.DialUnix("unix", nil, addr)
|
||
|
assert.NilError(t, err, "failed to dial returned listener")
|
||
|
checkDirClean(t)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func checkDirClean(t *testing.T) {
|
||
|
t.Helper()
|
||
|
|
||
|
files, err := os.ReadDir(".")
|
||
|
assert.NilError(t, err, "failed to list files in dir to check for leaked sockets")
|
||
|
|
||
|
for _, f := range files {
|
||
|
info, err := f.Info()
|
||
|
assert.NilError(t, err, "failed to check file info")
|
||
|
if info.Mode().Type() == fs.ModeSocket {
|
||
|
t.Fatal("found socket in a local directory")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestConnectAndWait(t *testing.T) {
|
||
|
t.Run("calls cancel func on EOF", func(t *testing.T) {
|
||
|
var conn *net.UnixConn
|
||
|
listener, err := SetupConn(&conn)
|
||
|
assert.NilError(t, err, "failed to setup listener")
|
||
|
|
||
|
done := make(chan struct{})
|
||
|
t.Setenv(EnvKey, listener.Addr().String())
|
||
|
cancelFunc := func() {
|
||
|
done <- struct{}{}
|
||
|
}
|
||
|
ConnectAndWait(cancelFunc)
|
||
|
pollConnNotNil(t, &conn)
|
||
|
conn.Close()
|
||
|
|
||
|
select {
|
||
|
case <-done:
|
||
|
case <-time.After(10 * time.Millisecond):
|
||
|
t.Fatal("cancel function not closed after 10ms")
|
||
|
}
|
||
|
})
|
||
|
|
||
|
t.Run("connect goroutine exits after EOF", func(t *testing.T) {
|
||
|
var conn *net.UnixConn
|
||
|
listener, err := SetupConn(&conn)
|
||
|
assert.NilError(t, err, "failed to setup listener")
|
||
|
t.Setenv(EnvKey, listener.Addr().String())
|
||
|
numGoroutines := runtime.NumGoroutine()
|
||
|
|
||
|
ConnectAndWait(func() {})
|
||
|
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
|
||
|
|
||
|
pollConnNotNil(t, &conn)
|
||
|
conn.Close()
|
||
|
poll.WaitOn(t, func(t poll.LogT) poll.Result {
|
||
|
if runtime.NumGoroutine() > numGoroutines+1 {
|
||
|
return poll.Continue("waiting for connect goroutine to exit")
|
||
|
}
|
||
|
return poll.Success()
|
||
|
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
|
||
|
t.Helper()
|
||
|
|
||
|
poll.WaitOn(t, func(t poll.LogT) poll.Result {
|
||
|
if *conn == nil {
|
||
|
return poll.Continue("waiting for conn to not be nil")
|
||
|
}
|
||
|
return poll.Success()
|
||
|
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
|
||
|
}
|