Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AF_UNIX sockets on Windows #98

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions sockets/sockets.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@
package sockets

import (
"context"
"errors"
"fmt"
"net"
"net/http"
"syscall"
"time"
)

// ErrProtocolNotAvailable is returned when a given transport protocol is not provided by the operating system.
Expand All @@ -24,3 +29,23 @@ func ConfigureTransport(tr *http.Transport, proto, addr string) error {
}
return nil
}

const (
defaultTimeout = 10 * time.Second
maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path)
)

func configureUnixTransport(tr *http.Transport, proto, addr string) error {
if len(addr) > maxUnixSocketPathSize {
return fmt.Errorf("Unix socket path %q is too long", addr)
}
// No need for compression in local communications.
tr.DisableCompression = true
dialer := &net.Dialer{
Timeout: defaultTimeout,
}
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, proto, addr)
}
return nil
}
23 changes: 1 addition & 22 deletions sockets/sockets_unix.go
Original file line number Diff line number Diff line change
@@ -1,36 +1,15 @@
//go:build !windows
// +build !windows

package sockets

import (
"context"
"fmt"
"net"
"net/http"
"syscall"
"time"
)

const (
defaultTimeout = 10 * time.Second
maxUnixSocketPathSize = len(syscall.RawSockaddrUnix{}.Path)
)

func configureUnixTransport(tr *http.Transport, proto, addr string) error {
if len(addr) > maxUnixSocketPathSize {
return fmt.Errorf("Unix socket path %q is too long", addr)
}
// No need for compression in local communications.
tr.DisableCompression = true
dialer := &net.Dialer{
Timeout: defaultTimeout,
}
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, proto, addr)
}
return nil
}

func configureNpipeTransport(tr *http.Transport, proto, addr string) error {
return ErrProtocolNotAvailable
}
Expand Down
4 changes: 0 additions & 4 deletions sockets/sockets_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ import (
"github.com/Microsoft/go-winio"
)

func configureUnixTransport(tr *http.Transport, proto, addr string) error {
return ErrProtocolNotAvailable
}

func configureNpipeTransport(tr *http.Transport, proto, addr string) error {
// No need for compression in local communications.
tr.DisableCompression = true
Expand Down
6 changes: 2 additions & 4 deletions sockets/unix_socket.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
// +build !windows

/*
Package sockets is a simple unix domain socket wrapper.
Expand Down Expand Up @@ -103,9 +101,9 @@ func NewUnixSocketWithOpts(path string, opts ...SockOption) (net.Listener, error
// We don't use "defer" here, to reset the umask to its original value as soon
// as possible. Ideally we'd be able to detect if WithChmod() was passed as
// an option, and skip changing umask if default permissions are used.
origUmask := syscall.Umask(0777)
origUmask := umask(0777)
l, err := net.Listen("unix", path)
syscall.Umask(origUmask)
umask(origUmask)
if err != nil {
return nil, err
}
Expand Down
29 changes: 8 additions & 21 deletions sockets/unix_socket_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
// +build !windows

package sockets

import (
"fmt"
"io/ioutil"
"net"
"os"
"syscall"
"testing"
)

Expand Down Expand Up @@ -52,26 +50,15 @@ func TestNewUnixSocket(t *testing.T) {
}

func TestUnixSocketWithOpts(t *testing.T) {
uid, gid := os.Getuid(), os.Getgid()
perms := os.FileMode(0660)
path := "/tmp/test.sock"
echoStr := "hello"
l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms))
socketFile, err := ioutil.TempFile("", "test*.sock")
if err != nil {
t.Fatal(err)
}
defer socketFile.Close()

l := createTestUnixSocket(t, socketFile.Name())
defer l.Close()
p, err := os.Stat(path)
if err != nil {
t.Fatal(err)
}
if p.Mode().Perm() != perms {
t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm())
}
if stat, ok := p.Sys().(*syscall.Stat_t); ok {
if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) {
t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid)
}
}
runTest(t, path, l, echoStr)

echoStr := "hello"
runTest(t, socketFile.Name(), l, echoStr)
}
33 changes: 33 additions & 0 deletions sockets/unix_socket_test_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//go:build !windows
// +build !windows

package sockets

import (
"net"
"os"
"syscall"
"testing"
)

func createTestUnixSocket(t *testing.T, path string) (listener net.Listener) {
uid, gid := os.Getuid(), os.Getgid()
perms := os.FileMode(0660)
l, err := NewUnixSocketWithOpts(path, WithChown(uid, gid), WithChmod(perms))
if err != nil {
t.Fatal(err)
}
p, err := os.Stat(path)
if err != nil {
t.Fatal(err)
}
if p.Mode().Perm() != perms {
t.Fatalf("unexpected file permissions: expected: %#o, got: %#o", perms, p.Mode().Perm())
}
if stat, ok := p.Sys().(*syscall.Stat_t); ok {
if stat.Uid != uint32(uid) || stat.Gid != uint32(gid) {
t.Fatalf("unexpected file ownership: expected: %d:%d, got: %d:%d", uid, gid, stat.Uid, stat.Gid)
}
}
return l
}
14 changes: 14 additions & 0 deletions sockets/unix_socket_test_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package sockets

import (
"net"
"testing"
)

func createTestUnixSocket(t *testing.T, path string) (listener net.Listener) {
l, err := NewUnixSocketWithOpts(path)
if err != nil {
t.Fatal(err)
}
return l
}
10 changes: 10 additions & 0 deletions sockets/unix_socket_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:build !windows
// +build !windows

package sockets

import "syscall"

func umask(newmask int) (oldmask int) {
return syscall.Umask(0777)
}
5 changes: 5 additions & 0 deletions sockets/unix_socket_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package sockets

func umask(newmask int) (oldmask int) {
return newmask
}