2023-07-19 18:33:50 -04:00
|
|
|
//go:build windows
|
|
|
|
|
|
|
|
package socket
|
|
|
|
|
|
|
|
import (
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"net"
|
|
|
|
"sync"
|
|
|
|
"syscall"
|
|
|
|
"unsafe"
|
|
|
|
|
|
|
|
"github.com/Microsoft/go-winio/pkg/guid"
|
|
|
|
"golang.org/x/sys/windows"
|
|
|
|
)
|
|
|
|
|
|
|
|
//go:generate go run github.com/Microsoft/go-winio/tools/mkwinsyscall -output zsyscall_windows.go socket.go
|
|
|
|
|
|
|
|
//sys getsockname(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getsockname
|
|
|
|
//sys getpeername(s windows.Handle, name unsafe.Pointer, namelen *int32) (err error) [failretval==socketError] = ws2_32.getpeername
|
|
|
|
//sys bind(s windows.Handle, name unsafe.Pointer, namelen int32) (err error) [failretval==socketError] = ws2_32.bind
|
|
|
|
|
|
|
|
// socketError is the failure return value named by the failretval clauses of
// the //sys directives in this file: all 32 low bits set, as a uintptr.
const socketError = uintptr(0xffffffff)
|
|
|
|
|
|
|
|
// Sentinel errors exported by this package. Compare against them with
// [errors.Is], since callers may receive them wrapped.
var (
	// todo(helsaawy): create custom error types to store the desired vs actual size and addr family?

	// ErrBufferSize is the sentinel error for a buffer size problem.
	ErrBufferSize = errors.New("buffer size")

	// ErrAddrFamily is the sentinel error for an address family problem.
	ErrAddrFamily = errors.New("address family")

	// ErrInvalidPointer is the sentinel error for an invalid pointer.
	ErrInvalidPointer = errors.New("invalid pointer")

	// ErrSocketClosed wraps [net.ErrClosed], so errors.Is(err, net.ErrClosed)
	// also matches it.
	ErrSocketClosed = fmt.Errorf("socket closed: %w", net.ErrClosed)
)
|
|
|
|
|
|
|
|
// todo(helsaawy): replace these with generics, ie: GetSockName[S RawSockaddr](s windows.Handle) (S, error)
|
|
|
|
|
|
|
|
// GetSockName writes the local address of socket s to the [RawSockaddr] rsa.
|
|
|
|
// If rsa is not large enough, the [windows.WSAEFAULT] is returned.
|
|
|
|
func GetSockName(s windows.Handle, rsa RawSockaddr) error {
|
|
|
|
ptr, l, err := rsa.Sockaddr()
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// although getsockname returns WSAEFAULT if the buffer is too small, it does not set
|
|
|
|
// &l to the correct size, so--apart from doubling the buffer repeatedly--there is no remedy
|
|
|
|
return getsockname(s, ptr, &l)
|
|
|
|
}
|
|
|
|
|
|
|
|
// GetPeerName returns the remote address the socket is connected to.
|
|
|
|
//
|
|
|
|
// See [GetSockName] for more information.
|
|
|
|
func GetPeerName(s windows.Handle, rsa RawSockaddr) error {
|
|
|
|
ptr, l, err := rsa.Sockaddr()
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return getpeername(s, ptr, &l)
|
|
|
|
}
|
|
|
|
|
|
|
|
func Bind(s windows.Handle, rsa RawSockaddr) (err error) {
|
|
|
|
ptr, l, err := rsa.Sockaddr()
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("could not retrieve socket pointer and size: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return bind(s, ptr, l)
|
|
|
|
}
|
|
|
|
|
|
|
|
// "golang.org/x/sys/windows".ConnectEx and .Bind only accept internal implementations of
// their sockaddr interface, so they cannot be used with HvsockAddr.
|
|
|
|
// Replicate functionality here from
|
|
|
|
// https://cs.opensource.google/go/x/sys/+/master:windows/syscall_windows.go
|
|
|
|
|
|
|
|
// The function pointers to `AcceptEx`, `ConnectEx` and `GetAcceptExSockaddrs` must be loaded at
|
|
|
|
// runtime via a WSAIoctl call:
|
|
|
|
// https://docs.microsoft.com/en-us/windows/win32/api/Mswsock/nc-mswsock-lpfn_connectex#remarks
|
|
|
|
|
|
|
|
type runtimeFunc struct {
|
|
|
|
id guid.GUID
|
|
|
|
once sync.Once
|
|
|
|
addr uintptr
|
|
|
|
err error
|
|
|
|
}
|
|
|
|
|
|
|
|
func (f *runtimeFunc) Load() error {
|
|
|
|
f.once.Do(func() {
|
|
|
|
var s windows.Handle
|
|
|
|
s, f.err = windows.Socket(windows.AF_INET, windows.SOCK_STREAM, windows.IPPROTO_TCP)
|
|
|
|
if f.err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer windows.CloseHandle(s) //nolint:errcheck
|
|
|
|
|
|
|
|
var n uint32
|
|
|
|
f.err = windows.WSAIoctl(s,
|
|
|
|
windows.SIO_GET_EXTENSION_FUNCTION_POINTER,
|
|
|
|
(*byte)(unsafe.Pointer(&f.id)),
|
|
|
|
uint32(unsafe.Sizeof(f.id)),
|
|
|
|
(*byte)(unsafe.Pointer(&f.addr)),
|
|
|
|
uint32(unsafe.Sizeof(f.addr)),
|
|
|
|
&n,
|
|
|
|
nil, // overlapped
|
|
|
|
0, // completionRoutine
|
|
|
|
)
|
|
|
|
})
|
|
|
|
return f.err
|
|
|
|
}
|
|
|
|
|
|
|
|
var (
|
|
|
|
// todo: add `AcceptEx` and `GetAcceptExSockaddrs`
|
|
|
|
WSAID_CONNECTEX = guid.GUID{ //revive:disable-line:var-naming ALL_CAPS
|
|
|
|
Data1: 0x25a207b9,
|
|
|
|
Data2: 0xddf3,
|
|
|
|
Data3: 0x4660,
|
|
|
|
Data4: [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
|
|
|
|
}
|
|
|
|
|
|
|
|
connectExFunc = runtimeFunc{id: WSAID_CONNECTEX}
|
|
|
|
)
|
|
|
|
|
|
|
|
func ConnectEx(
|
|
|
|
fd windows.Handle,
|
|
|
|
rsa RawSockaddr,
|
|
|
|
sendBuf *byte,
|
|
|
|
sendDataLen uint32,
|
|
|
|
bytesSent *uint32,
|
|
|
|
overlapped *windows.Overlapped,
|
|
|
|
) error {
|
|
|
|
if err := connectExFunc.Load(); err != nil {
|
|
|
|
return fmt.Errorf("failed to load ConnectEx function pointer: %w", err)
|
|
|
|
}
|
|
|
|
ptr, n, err := rsa.Sockaddr()
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
|
|
|
|
}
|
|
|
|
|
|
|
|
// BOOL LpfnConnectex(
|
|
|
|
// [in] SOCKET s,
|
|
|
|
// [in] const sockaddr *name,
|
|
|
|
// [in] int namelen,
|
|
|
|
// [in, optional] PVOID lpSendBuffer,
|
|
|
|
// [in] DWORD dwSendDataLength,
|
|
|
|
// [out] LPDWORD lpdwBytesSent,
|
|
|
|
// [in] LPOVERLAPPED lpOverlapped
|
|
|
|
// )
|
|
|
|
|
|
|
|
func connectEx(
|
|
|
|
s windows.Handle,
|
|
|
|
name unsafe.Pointer,
|
|
|
|
namelen int32,
|
|
|
|
sendBuf *byte,
|
|
|
|
sendDataLen uint32,
|
|
|
|
bytesSent *uint32,
|
|
|
|
overlapped *windows.Overlapped,
|
|
|
|
) (err error) {
|
2024-05-31 09:49:19 -04:00
|
|
|
r1, _, e1 := syscall.SyscallN(connectExFunc.addr,
|
2023-07-19 18:33:50 -04:00
|
|
|
uintptr(s),
|
|
|
|
uintptr(name),
|
|
|
|
uintptr(namelen),
|
|
|
|
uintptr(unsafe.Pointer(sendBuf)),
|
|
|
|
uintptr(sendDataLen),
|
|
|
|
uintptr(unsafe.Pointer(bytesSent)),
|
|
|
|
uintptr(unsafe.Pointer(overlapped)),
|
2024-05-31 09:49:19 -04:00
|
|
|
)
|
|
|
|
|
2023-07-19 18:33:50 -04:00
|
|
|
if r1 == 0 {
|
|
|
|
if e1 != 0 {
|
|
|
|
err = error(e1)
|
|
|
|
} else {
|
|
|
|
err = syscall.EINVAL
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return err
|
|
|
|
}
|