diff --git a/sockets/sockets.go b/sockets/sockets.go index 2e9e9006..5bd3ca4f 100644 --- a/sockets/sockets.go +++ b/sockets/sockets.go @@ -2,8 +2,13 @@ package sockets import ( + "context" "errors" + "fmt" + "net" "net/http" + "syscall" + "time" ) // ErrProtocolNotAvailable is returned when a given transport protocol is not provided by the operating system. @@ -24,3 +29,23 @@ func ConfigureTransport(tr *http.Transport, proto, addr string) error { } return nil } + +const ( + defaultTimeout = 10 * time.Second + maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path) +) + +func configureUnixTransport(tr *http.Transport, proto, addr string) error { + if len(addr) > maxUnixSocketPathSize { + return fmt.Errorf("Unix socket path %q is too long", addr) + } + // No need for compression in local communications. + tr.DisableCompression = true + dialer := &net.Dialer{ + Timeout: defaultTimeout, + } + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + return dialer.DialContext(ctx, proto, addr) + } + return nil +} diff --git a/sockets/sockets_unix.go b/sockets/sockets_unix.go index 10d76342..f5c7f059 100644 --- a/sockets/sockets_unix.go +++ b/sockets/sockets_unix.go @@ -1,36 +1,15 @@ +//go:build !windows // +build !windows package sockets import ( - "context" - "fmt" "net" "net/http" "syscall" "time" ) -const ( - defaultTimeout = 10 * time.Second - maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path) -) - -func configureUnixTransport(tr *http.Transport, proto, addr string) error { - if len(addr) > maxUnixSocketPathSize { - return fmt.Errorf("Unix socket path %q is too long", addr) - } - // No need for compression in local communications. - tr.DisableCompression = true - dialer := &net.Dialer{ - Timeout: defaultTimeout, - } - tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { - return dialer.DialContext(ctx, proto, addr) - } - return nil -} - func configureNpipeTransport(tr *http.Transport, proto, addr string) error { return ErrProtocolNotAvailable } diff --git a/sockets/sockets_windows.go b/sockets/sockets_windows.go index 7acafc5a..d4f2e788 100644 --- a/sockets/sockets_windows.go +++ b/sockets/sockets_windows.go @@ -9,10 +9,6 @@ import ( "github.com/Microsoft/go-winio" ) -func configureUnixTransport(tr *http.Transport, proto, addr string) error { - return ErrProtocolNotAvailable -} - func configureNpipeTransport(tr *http.Transport, proto, addr string) error { // No need for compression in local communications. tr.DisableCompression = true diff --git a/sockets/unix_socket.go b/sockets/unix_socket.go index e7591e6e..9799a26c 100644 --- a/sockets/unix_socket.go +++ b/sockets/unix_socket.go @@ -1,5 +1,3 @@ -// +build !windows - /* Package sockets is a simple unix domain socket wrapper. @@ -103,9 +101,9 @@ func NewUnixSocketWithOpts(path string, opts ...SockOption) (net.Listener, error // We don't use "defer" here, to reset the umask to its original value as soon // as possible. Ideally we'd be able to detect if WithChmod() was passed as // an option, and skip changing umask if default permissions are used. - origUmask := syscall.Umask(0777) + origUmask := umask(0777) l, err := net.Listen("unix", path) - syscall.Umask(origUmask) + umask(origUmask) if err != nil { return nil, err } diff --git a/sockets/unix_socket_test.go b/sockets/unix_socket_test.go index 8957efd3..cffb6b36 100644 --- a/sockets/unix_socket_test.go +++ b/sockets/unix_socket_test.go @@ -1,12 +1,10 @@ -// +build !windows - package sockets import ( "fmt" + "io/ioutil" "net" "os" - "syscall" "testing" ) @@ -52,26 +50,15 @@ func TestNewUnixSocket(t *testing.T) { } func TestUnixSocketWithOpts(t *testing.T) { - uid, gid := os.Getuid(), os.Getgid() - perms := os.FileMode(0660) - path := "/tmp/test.sock" - echoStr := "hello" - l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms)) + socketFile, err := ioutil.TempFile("", "test*.sock") if err != nil { t.Fatal(err) } + defer socketFile.Close() + + l := createTestUnixSocket(t, socketFile.Name()) defer l.Close() - p, err := os.Stat(path) - if err != nil { - t.Fatal(err) - } - if p.Mode().Perm() != perms { - t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm()) - } - if stat, ok := p.Sys().(*syscall.Stat_t); ok { - if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) { - t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid) - } - } - runTest(t, path, l, echoStr) + + echoStr := "hello" + runTest(t, socketFile.Name(), l, echoStr) } diff --git a/sockets/unix_socket_test_unix.go b/sockets/unix_socket_test_unix.go new file mode 100644 index 00000000..424b408b --- /dev/null +++ b/sockets/unix_socket_test_unix.go @@ -0,0 +1,33 @@ +//go:build !windows +// +build !windows + +package sockets + +import ( + "net" + "os" + "syscall" + "testing" +) + +func createTestUnixSocket(t *testing.T, path string) (listener net.Listener) { + uid, gid := os.Getuid(), os.Getgid() + perms := os.FileMode(0660) + l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms)) + if err != nil { + t.Fatal(err) + } + p, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if p.Mode().Perm() != perms { + t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm()) + } + if stat, ok := p.Sys().(*syscall.Stat_t); ok { + if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) { + t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid) + } + } + return l +} diff --git a/sockets/unix_socket_test_windows.go b/sockets/unix_socket_test_windows.go new file mode 100644 index 00000000..e68aca0b --- /dev/null +++ b/sockets/unix_socket_test_windows.go @@ -0,0 +1,14 @@ +package sockets + +import ( + "net" + "testing" +) + +func createTestUnixSocket(t *testing.T, path string) (listener net.Listener) { + l, err := NewUnixSocketWithOpts(path) + if err != nil { + t.Fatal(err) + } + return l +} diff --git a/sockets/unix_socket_unix.go b/sockets/unix_socket_unix.go new file mode 100644 index 00000000..0e5b668b --- /dev/null +++ b/sockets/unix_socket_unix.go @@ -0,0 +1,10 @@ +//go:build !windows +// +build !windows + +package sockets + +import "syscall" + +func umask(newmask int) (oldmask int) { + return syscall.Umask(0777) +} diff --git a/sockets/unix_socket_windows.go b/sockets/unix_socket_windows.go new file mode 100644 index 00000000..e8b295a6 --- /dev/null +++ b/sockets/unix_socket_windows.go @@ -0,0 +1,5 @@ +package sockets + +func umask(newmask int) (oldmask int) { + return newmask +}