From 13bdb230ec694c9ab47cc282506e2a0078e82009 Mon Sep 17 00:00:00 2001 From: Sander Bruens Date: Wed, 29 Jan 2025 19:27:43 -0500 Subject: [PATCH] feat: add WebSocket support to the existing `outline-ss-server` (#225) This allows clients to connect to Shadowsocks over WebSockets. --- cmd/outline-ss-server/config.go | 195 +++++++++++-- cmd/outline-ss-server/config_example.yml | 12 + cmd/outline-ss-server/config_test.go | 339 +++++++++++++++-------- cmd/outline-ss-server/main.go | 128 ++++++++- go.mod | 1 + go.sum | 2 + 6 files changed, 531 insertions(+), 146 deletions(-) diff --git a/cmd/outline-ss-server/config.go b/cmd/outline-ss-server/config.go index 0ed7295d..120046ac 100644 --- a/cmd/outline-ss-server/config.go +++ b/cmd/outline-ss-server/config.go @@ -15,12 +15,21 @@ package main import ( + "errors" "fmt" "net" + "reflect" + "strings" + "github.com/go-viper/mapstructure/v2" "gopkg.in/yaml.v3" ) +type Validator interface { + // Validate checks that the type is valid. + validate() error +} + type ServiceConfig struct { Listeners []ListenerConfig Keys []KeyConfig @@ -29,15 +38,143 @@ type ServiceConfig struct { type ListenerType string -const listenerTypeTCP ListenerType = "tcp" +const ( + TCPListenerType = ListenerType("tcp") + UDPListenerType = ListenerType("udp") + WebsocketStreamListenerType = ListenerType("websocket-stream") + WebsocketPacketListenerType = ListenerType("websocket-packet") +) + +type WebServerConfig struct { + // Unique identifier of the web server to be referenced in Websocket connections. + ID string -const listenerTypeUDP ListenerType = "udp" + // List of listener addresses (e.g., ":8080", "localhost:8081"). Should be localhost for HTTP. + Listeners []string `yaml:"listen"` +} +// ListenerConfig holds the configuration for a listener. It supports different +// listener types, configured via the embedded type and unmarshalled based on +// the "type" field in the YAML/JSON configuration. Only one of the fields will +// be set, corresponding to the listener type. type ListenerConfig struct { - Type ListenerType + // TCP configuration for the listener. + TCP *TCPUDPConfig + // UDP configuration for the listener. + UDP *TCPUDPConfig + // Websocket stream configuration for the listener. + WebsocketStream *WebsocketConfig + // Websocket packet configuration for the listener. + WebsocketPacket *WebsocketConfig +} + +var _ Validator = (*ListenerConfig)(nil) +var _ yaml.Unmarshaler = (*ListenerConfig)(nil) + +// Define a map to associate listener types with [ListenerConfig] field names. +var listenerTypeMap = map[ListenerType]string{ + TCPListenerType: "TCP", + UDPListenerType: "UDP", + WebsocketStreamListenerType: "WebsocketStream", + WebsocketPacketListenerType: "WebsocketPacket", +} + +func (c *ListenerConfig) UnmarshalYAML(value *yaml.Node) error { + var raw map[string]interface{} + if err := value.Decode(&raw); err != nil { + return err + } + + // Remove the "type" field so we can decode directly into the target struct. + rawType, ok := raw["type"] + if !ok { + return errors.New("`type` field required") + } + lnTypeStr, ok := rawType.(string) + if !ok { + return fmt.Errorf("`type` is not a string, but %T", rawType) + } + lnType := ListenerType(lnTypeStr) + delete(raw, "type") + + fieldName, ok := listenerTypeMap[lnType] + if !ok { + return fmt.Errorf("invalid listener type: %v", lnType) + } + v := reflect.ValueOf(c).Elem() + field := v.FieldByName(fieldName) + if !field.IsValid() { + return fmt.Errorf("invalid field name: %s for type: %s", fieldName, lnType) + } + fieldType := field.Type() + if fieldType.Kind() != reflect.Ptr || fieldType.Elem().Kind() != reflect.Struct { + return fmt.Errorf("field %s is not a pointer to a struct", fieldName) + } + + configValue := reflect.New(fieldType.Elem()) + field.Set(configValue) + if err := mapstructure.Decode(raw, configValue.Interface()); err != nil { + return fmt.Errorf("failed to decode map: %w", err) + } + return nil +} + +func (c *ListenerConfig) validate() error { + v := reflect.ValueOf(c).Elem() + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.Kind() == reflect.Ptr && field.IsNil() { + continue + } + if validator, ok := field.Interface().(Validator); ok { + if err := validator.validate(); err != nil { + return fmt.Errorf("invalid config: %v", err) + } + } + } + return nil +} + +type TCPUDPConfig struct { + // Address for the TCP or UDP listener. Should be in the format host:port. Address string } +var _ Validator = (*TCPUDPConfig)(nil) + +func (c *TCPUDPConfig) validate() error { + if c.Address == "" { + return errors.New("`address` must be specified") + } + if err := validateAddress(c.Address); err != nil { + return fmt.Errorf("invalid address: %v", err) + } + return nil +} + +type WebsocketConfig struct { + // Web server unique identifier to use for the websocket connection. + WebServer string `mapstructure:"web_server"` + // Path for the websocket connection. + Path string +} + +var _ Validator = (*WebsocketConfig)(nil) + +func (c *WebsocketConfig) validate() error { + if c.WebServer == "" { + return errors.New("`web_server` must be specified") + } + if c.Path == "" { + return errors.New("`path` must be specified") + } + if !strings.HasPrefix(c.Path, "/") { + return errors.New("`path` must start with `/`") + } + return nil +} + type DialerConfig struct { Fwmark uint } @@ -53,7 +190,12 @@ type LegacyKeyServiceConfig struct { Port int } +type WebConfig struct { + Servers []WebServerConfig `yaml:"servers"` +} + type Config struct { + Web WebConfig Services []ServiceConfig // Deprecated: `keys` exists for backward compatibility. Prefer to configure @@ -61,32 +203,41 @@ type Config struct { Keys []LegacyKeyServiceConfig } -// Validate checks that the config is valid. -func (c *Config) Validate() error { - existingListeners := make(map[string]bool) - for _, serviceConfig := range c.Services { - for _, lnConfig := range serviceConfig.Listeners { - // TODO: Support more listener types. - if lnConfig.Type != listenerTypeTCP && lnConfig.Type != listenerTypeUDP { - return fmt.Errorf("unsupported listener type: %s", lnConfig.Type) - } - host, _, err := net.SplitHostPort(lnConfig.Address) - if err != nil { - return fmt.Errorf("invalid listener address `%s`: %v", lnConfig.Address, err) - } - if ip := net.ParseIP(host); ip == nil { - return fmt.Errorf("address must be IP, found: %s", host) +var _ Validator = (*Config)(nil) + +func (c *Config) validate() error { + for _, srv := range c.Web.Servers { + if srv.ID == "" { + return fmt.Errorf("web server must have an ID") + } + for _, addr := range srv.Listeners { + if err := validateAddress(addr); err != nil { + return fmt.Errorf("invalid listener for web server `%s`: %w", srv.ID, err) } - key := string(lnConfig.Type) + "/" + lnConfig.Address - if _, exists := existingListeners[key]; exists { - return fmt.Errorf("listener of type %s with address %s already exists.", lnConfig.Type, lnConfig.Address) + } + } + + for _, service := range c.Services { + for _, ln := range service.Listeners { + if err := ln.validate(); err != nil { + return fmt.Errorf("invalid listener: %v", err) } - existingListeners[key] = true } } return nil } +func validateAddress(addr string) error { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return err + } + if ip := net.ParseIP(host); ip == nil { + return fmt.Errorf("address must be IP, found: %s", host) + } + return nil +} + // readConfig attempts to read a config from a filename and parses it as a [Config]. func readConfig(configData []byte) (*Config, error) { config := Config{} diff --git a/cmd/outline-ss-server/config_example.yml b/cmd/outline-ss-server/config_example.yml index 8dddbb6f..d2edc75d 100644 --- a/cmd/outline-ss-server/config_example.yml +++ b/cmd/outline-ss-server/config_example.yml @@ -12,6 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +web: + servers: + - id: my_web_server + listen: + - "127.0.0.1:8000" + services: - listeners: # TODO(sbruens): Allow a string-based listener config, as a convenient short-form @@ -20,6 +26,12 @@ services: address: "[::]:9000" - type: udp address: "[::]:9000" + - type: websocket-stream + web_server: my_web_server + path: "/SECRET/tcp" # Prevent probing by serving under a secret path. + - type: websocket-packet + web_server: my_web_server + path: "/SECRET/udp" # Prevent probing by serving under a secret path. keys: - id: user-0 cipher: chacha20-ietf-poly1305 diff --git a/cmd/outline-ss-server/config_test.go b/cmd/outline-ss-server/config_test.go index f183ff5a..d5dec092 100644 --- a/cmd/outline-ss-server/config_test.go +++ b/cmd/outline-ss-server/config_test.go @@ -16,152 +16,257 @@ package main import ( "os" + "strings" "testing" "github.com/stretchr/testify/require" ) -func TestValidateConfigFails(t *testing.T) { - tests := []struct { - name string - cfg *Config - }{ - { - name: "WithUnknownListenerType", - cfg: &Config{ - Services: []ServiceConfig{ - ServiceConfig{ - Listeners: []ListenerConfig{ - ListenerConfig{Type: "foo", Address: "[::]:9000"}, - }, - }, - }, +func TestConfigValidate(t *testing.T) { + t.Run("InvalidConfig/InvalidListenerType", func(t *testing.T) { + yaml := ` +services: + - listeners: + - type: + - tcp + - udp + address: "[::]:9000" +` + + _, err := readConfig([]byte(yaml)) + + require.Error(t, err) + }) + + t.Run("InvalidConfig/UnknownListenerType", func(t *testing.T) { + yaml := ` +services: + - listeners: + - type: foo + address: "[::]:9000" +` + + _, err := readConfig([]byte(yaml)) + + require.Error(t, err) + }) + + t.Run("InvalidConfig", func(t *testing.T) { + tests := []struct { + name string + yaml string + errStr string + }{ + { + name: "MissingAddress", + yaml: ` +services: + - listeners: + - type: tcp +`, + errStr: "`address` must be specified", }, - }, - { - name: "WithInvalidListenerAddress", - cfg: &Config{ - Services: []ServiceConfig{ - ServiceConfig{ - Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeTCP, Address: "tcp/[::]:9000"}, - }, - }, - }, + { + name: "InvalidAddress", + yaml: ` +services: + - listeners: + - type: tcp + address: "tcp/[::]:9000" +`, + errStr: "invalid address", }, - }, - { - name: "WithHostnameAddress", - cfg: &Config{ - Services: []ServiceConfig{ - ServiceConfig{ - Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeTCP, Address: "example.com:9000"}, - }, - }, - }, + { + name: "HostnameAddress", + yaml: ` +services: + - listeners: + - type: tcp + address: "example.com:9000" +`, + errStr: "address must be IP", }, - }, - { - name: "WithDuplicateListeners", - cfg: &Config{ - Services: []ServiceConfig{ - ServiceConfig{ - Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, - }, - }, - ServiceConfig{ - Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, - }, - }, - }, + { + name: "WebServerMissingID", + yaml: ` +web: + servers: + - listen: + - "127.0.0.1:8000" +`, + errStr: "web server must have an ID", }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := tc.cfg.Validate() - require.Error(t, err) - }) - } + { + name: "WebServerInvalidAddress", + yaml: ` +web: + servers: + - id: foo + listen: + - ":invalid" +`, + errStr: "invalid listener for web server `foo`", + }, + { + name: "WebsocketMissingWebServer", + yaml: ` +services: + - listeners: + - type: websocket-stream + path: "/tcp" +`, + errStr: "`web_server` must be specified", + }, + { + name: "WebsocketMissingPath", + yaml: ` +services: + - listeners: + - type: websocket-stream + web_server: my_web_server +`, + errStr: "`path` must be specified", + }, + { + name: "WebsocketInvalidPath", + yaml: ` +services: + - listeners: + - type: websocket-stream + web_server: my_web_server + path: "tcp" +`, + errStr: "`path` must start with `/`", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cfg, err := readConfig([]byte(tc.yaml)) + require.NoError(t, err) + err = cfg.validate() + require.Error(t, err) + if !isStrInError(err, tc.errStr) { + t.Errorf("config validation error=`%v`, expected=`%v`", err, tc.errStr) + } + }) + } + }) + + t.Run("ValidConfig", func(t *testing.T) { + yaml := ` +web: + servers: + - id: my_web_server + listen: + - "127.0.0.1:8000" + +services: + - listeners: + - type: tcp + address: "[::]:9000" + - type: websocket-stream + web_server: my_web_server + path: "/tcp" + keys: + - id: user-0 + cipher: chacha20-ietf-poly1305 + secret: Secret0 +` + cfg, err := readConfig([]byte(yaml)) + require.NoError(t, err) + err = cfg.validate() + require.NoError(t, err) + }) } func TestReadConfig(t *testing.T) { - config, err := readConfigFile("./config_example.yml") - - require.NoError(t, err) - expected := Config{ - Services: []ServiceConfig{ - ServiceConfig{ - Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9000"}, - ListenerConfig{Type: listenerTypeUDP, Address: "[::]:9000"}, - }, - Keys: []KeyConfig{ - KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, - KeyConfig{"user-1", "chacha20-ietf-poly1305", "Secret1"}, + + t.Run("ExampleFile", func(t *testing.T) { + config, err := readConfigFile("./config_example.yml") + + require.NoError(t, err) + expected := Config{ + Web: WebConfig{ + Servers: []WebServerConfig{ + WebServerConfig{ID: "my_web_server", Listeners: []string{"127.0.0.1:8000"}}, }, }, - ServiceConfig{ - Listeners: []ListenerConfig{ - ListenerConfig{Type: listenerTypeTCP, Address: "[::]:9001"}, - ListenerConfig{Type: listenerTypeUDP, Address: "[::]:9001"}, + Services: []ServiceConfig{ + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{TCP: &TCPUDPConfig{Address: "[::]:9000"}}, + ListenerConfig{UDP: &TCPUDPConfig{Address: "[::]:9000"}}, + ListenerConfig{WebsocketStream: &WebsocketConfig{WebServer: "my_web_server", Path: "/SECRET/tcp"}}, + ListenerConfig{WebsocketPacket: &WebsocketConfig{WebServer: "my_web_server", Path: "/SECRET/udp"}}, + }, + Keys: []KeyConfig{ + KeyConfig{"user-0", "chacha20-ietf-poly1305", "Secret0"}, + KeyConfig{"user-1", "chacha20-ietf-poly1305", "Secret1"}, + }, }, - Keys: []KeyConfig{ - KeyConfig{"user-2", "chacha20-ietf-poly1305", "Secret2"}, + ServiceConfig{ + Listeners: []ListenerConfig{ + ListenerConfig{TCP: &TCPUDPConfig{Address: "[::]:9001"}}, + ListenerConfig{UDP: &TCPUDPConfig{Address: "[::]:9001"}}, + }, + Keys: []KeyConfig{ + KeyConfig{"user-2", "chacha20-ietf-poly1305", "Secret2"}, + }, }, }, - }, - } - require.Equal(t, expected, *config) -} + } + require.Equal(t, expected, *config) + }) -func TestReadConfigParsesDeprecatedFormat(t *testing.T) { - config, err := readConfigFile("./config_example.deprecated.yml") + t.Run("ParsesDeprecatedFormat", func(t *testing.T) { + config, err := readConfigFile("./config_example.deprecated.yml") - require.NoError(t, err) - expected := Config{ - Keys: []LegacyKeyServiceConfig{ - LegacyKeyServiceConfig{ - KeyConfig: KeyConfig{ID: "user-0", Cipher: "chacha20-ietf-poly1305", Secret: "Secret0"}, - Port: 9000, - }, - LegacyKeyServiceConfig{ - KeyConfig: KeyConfig{ID: "user-1", Cipher: "chacha20-ietf-poly1305", Secret: "Secret1"}, - Port: 9000, - }, - LegacyKeyServiceConfig{ - KeyConfig: KeyConfig{ID: "user-2", Cipher: "chacha20-ietf-poly1305", Secret: "Secret2"}, - Port: 9001, + require.NoError(t, err) + expected := Config{ + Keys: []LegacyKeyServiceConfig{ + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-0", Cipher: "chacha20-ietf-poly1305", Secret: "Secret0"}, + Port: 9000, + }, + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-1", Cipher: "chacha20-ietf-poly1305", Secret: "Secret1"}, + Port: 9000, + }, + LegacyKeyServiceConfig{ + KeyConfig: KeyConfig{ID: "user-2", Cipher: "chacha20-ietf-poly1305", Secret: "Secret2"}, + Port: 9001, + }, }, - }, - } - require.Equal(t, expected, *config) -} + } + require.Equal(t, expected, *config) + }) -func TestReadConfigFromEmptyFile(t *testing.T) { - file, _ := os.CreateTemp("", "empty.yaml") + t.Run("FromEmptyFile", func(t *testing.T) { + file, _ := os.CreateTemp("", "empty.yaml") - config, err := readConfigFile(file.Name()) + config, err := readConfigFile(file.Name()) - require.NoError(t, err) - require.ElementsMatch(t, Config{}, config) -} + require.NoError(t, err) + require.ElementsMatch(t, Config{}, config) + }) -func TestReadConfigFromIncorrectFormatFails(t *testing.T) { - file, _ := os.CreateTemp("", "empty.yaml") - file.WriteString("foo") + t.Run("FromIncorrectFormatFails", func(t *testing.T) { + file, _ := os.CreateTemp("", "empty.yaml") + file.WriteString("foo") - config, err := readConfigFile(file.Name()) + config, err := readConfigFile(file.Name()) - require.Error(t, err) - require.ElementsMatch(t, Config{}, config) + require.Error(t, err) + require.ElementsMatch(t, Config{}, config) + }) } func readConfigFile(filename string) (*Config, error) { configData, _ := os.ReadFile(filename) return readConfig(configData) } + +func isStrInError(err error, str string) bool { + return err != nil && strings.Contains(err.Error(), str) +} diff --git a/cmd/outline-ss-server/main.go b/cmd/outline-ss-server/main.go index e8203d42..143d0f2e 100644 --- a/cmd/outline-ss-server/main.go +++ b/cmd/outline-ss-server/main.go @@ -17,6 +17,7 @@ package main import ( "container/list" "context" + "errors" "flag" "fmt" "log/slog" @@ -33,6 +34,7 @@ import ( "github.com/lmittmann/tint" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" + "golang.org/x/net/websocket" "golang.org/x/term" "github.com/Jigsaw-Code/outline-ss-server/ipinfo" @@ -62,6 +64,16 @@ func init() { ) } +type HTTPStreamListener struct { + service.StreamListener +} + +var _ net.Listener = (*HTTPStreamListener)(nil) + +func (t *HTTPStreamListener) Accept() (net.Conn, error) { + return t.StreamListener.AcceptStream() +} + type OutlineServer struct { stopConfig func() error lnManager service.ListenerManager @@ -80,7 +92,7 @@ func (s *OutlineServer) loadConfig(filename string) error { if err != nil { return fmt.Errorf("failed to load config (%v): %w", filename, err) } - if err := config.Validate(); err != nil { + if err := config.validate(); err != nil { return fmt.Errorf("failed to validate config: %w", err) } @@ -201,6 +213,32 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { }() startErrCh <- func() error { + // Start configured web servers. + webServers := make(map[string]*http.ServeMux) + for _, srvConfig := range config.Web.Servers { + if _, exists := webServers[srvConfig.ID]; exists { + return fmt.Errorf("web server with ID `%s` already exists", srvConfig.ID) + } + mux := http.NewServeMux() + for _, addr := range srvConfig.Listeners { + server := &http.Server{Addr: addr, Handler: mux} + ln, err := lnSet.ListenStream(addr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + go func() { + defer server.Shutdown(context.Background()) + err := server.Serve(&HTTPStreamListener{ln}) + if err != nil && !errors.Is(http.ErrServerClosed, err) && !errors.Is(net.ErrClosed, err) { + slog.Error("Failed to run web server.", "err", err, "ID", srvConfig.ID) + } + }() + slog.Info("Web server started.", "ID", srvConfig.ID, "address", addr) + } + webServers[srvConfig.ID] = mux + } + + // Start legacy services. totalCipherCount := len(config.Keys) portCiphers := make(map[int]*list.List) // Values are *List of *CipherEntry. for _, keyConfig := range config.Keys { @@ -251,6 +289,7 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { }, s.serverMetrics) } + // Start services with listeners. for _, serviceConfig := range config.Services { ciphers, err := newCipherListFromConfig(serviceConfig) if err != nil { @@ -267,10 +306,9 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { if err != nil { return err } - for _, lnConfig := range serviceConfig.Listeners { - switch lnConfig.Type { - case listenerTypeTCP: - ln, err := lnSet.ListenStream(lnConfig.Address) + for _, cfg := range serviceConfig.Listeners { + if cfg.TCP != nil { + ln, err := lnSet.ListenStream(cfg.TCP.Address) if err != nil { return err } @@ -283,8 +321,8 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { go service.StreamServe(ln.AcceptStream, func(ctx context.Context, conn transport.StreamConn) { streamHandler.HandleStream(ctx, conn, s.serviceMetrics.AddOpenTCPConnection(conn)) }) - case listenerTypeUDP: - pc, err := lnSet.ListenPacket(lnConfig.Address) + } else if cfg.UDP != nil { + pc, err := lnSet.ListenPacket(cfg.UDP.Address) if err != nil { return err } @@ -297,6 +335,56 @@ func (s *OutlineServer) runConfig(config Config) (func() error, error) { go service.PacketServe(pc, func(ctx context.Context, conn net.Conn) { associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn)) }, s.serverMetrics) + } else if cfg.WebsocketStream != nil { + if _, exists := webServers[cfg.WebsocketStream.WebServer]; !exists { + return fmt.Errorf("websocket-stream listener references unknown web server `%s`", cfg.WebsocketStream.WebServer) + } + mux := webServers[cfg.WebsocketStream.WebServer] + // TODO: Support a "half-closed" state for WebSockets. + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := func(wsConn *websocket.Conn) { + defer wsConn.Close() + ctx, contextCancel := context.WithCancel(context.Background()) + defer contextCancel() + // TODO: Get the forwarded client address. + raddr, err := transport.MakeNetAddr("tcp", r.RemoteAddr) + if err != nil { + slog.Error("failed to determine client address", "err", err) + w.WriteHeader(http.StatusBadGateway) + return + } + conn := &streamConn{&replaceAddrConn{Conn: wsConn, raddr: raddr}} + streamHandler.HandleStream(ctx, conn, s.serviceMetrics.AddOpenTCPConnection(conn)) + } + websocket.Handler(handler).ServeHTTP(w, r) + }) + mux.Handle(cfg.WebsocketStream.Path, http.StripPrefix(cfg.WebsocketStream.Path, handler)) + slog.Info("WebSocket stream service started.", "ID", cfg.WebsocketStream.WebServer, "path", cfg.WebsocketStream.Path) + } else if cfg.WebsocketPacket != nil { + if _, exists := webServers[cfg.WebsocketPacket.WebServer]; !exists { + return fmt.Errorf("websocket-packet listener references unknown web server `%s`", cfg.WebsocketPacket.WebServer) + } + mux := webServers[cfg.WebsocketPacket.WebServer] + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handler := func(wsConn *websocket.Conn) { + defer wsConn.Close() + ctx, contextCancel := context.WithCancel(context.Background()) + defer contextCancel() + raddr, err := transport.MakeNetAddr("udp", r.RemoteAddr) + if err != nil { + slog.Error("failed to determine client address", "err", err) + w.WriteHeader(http.StatusBadGateway) + return + } + conn := &replaceAddrConn{Conn: wsConn, raddr: raddr} + associationHandler.HandleAssociation(ctx, conn, s.serviceMetrics.AddOpenUDPAssociation(conn)) + } + websocket.Handler(handler).ServeHTTP(w, r) + }) + mux.Handle(cfg.WebsocketPacket.Path, http.StripPrefix(cfg.WebsocketPacket.Path, handler)) + slog.Info("WebSocket packet service started.", "ID", cfg.WebsocketPacket.WebServer, "path", cfg.WebsocketPacket.Path) + } else { + return fmt.Errorf("unknown listener configuration: %v", cfg) } } totalCipherCount += len(serviceConfig.Keys) @@ -364,6 +452,32 @@ func RunOutlineServer(filename string, natTimeout time.Duration, serverMetrics * return server, nil } +// TODO: Create a dedicated `ClientConn` struct with `ClientAddr` and `Conn`. +// replaceAddrConn overrides [websocket.Conn]'s remote address handling. +type replaceAddrConn struct { + *websocket.Conn + raddr net.Addr +} + +func (c replaceAddrConn) RemoteAddr() net.Addr { + return c.raddr +} + +type streamConn struct { + net.Conn +} + +var _ transport.StreamConn = (*streamConn)(nil) + +// TODO: Support a "half-closed" state. +func (c *streamConn) CloseRead() error { + return c.Close() +} + +func (c *streamConn) CloseWrite() error { + return c.Close() +} + func main() { slog.SetDefault(slog.New(logHandler)) diff --git a/go.mod b/go.mod index 0018d9d0..338951fc 100644 --- a/go.mod +++ b/go.mod @@ -133,6 +133,7 @@ require ( github.com/go-openapi/validate v0.22.1 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/go-telegram-bot-api/telegram-bot-api v4.6.4+incompatible // indirect + github.com/go-viper/mapstructure/v2 v2.2.1 // indirect github.com/gobwas/glob v0.2.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt/v4 v4.5.1 // indirect diff --git a/go.sum b/go.sum index 110613aa..780a2180 100644 --- a/go.sum +++ b/go.sum @@ -1190,6 +1190,8 @@ github.com/go-telegram-bot-api/telegram-bot-api v4.6.4+incompatible/go.mod h1:qf github.com/go-test/deep v1.0.4/go.mod h1:wGDj63lr65AM2AQyKZd/NYHGb0R+1RLqB8NKt3aSFNA= github.com/go-test/deep v1.1.0 h1:WOcxcdHcvdgThNXjw0t76K42FXTU7HpNQWHpA2HHNlg= github.com/go-test/deep v1.1.0/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= +github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss= +github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/go-zookeeper/zk v1.0.2/go.mod h1:nOB03cncLtlp4t+UAkGSV+9beXP/akpekBwL+UX1Qcw= github.com/go-zookeeper/zk v1.0.3/go.mod h1:nOB03cncLtlp4t+UAkGSV+9beXP/akpekBwL+UX1Qcw= github.com/gobuffalo/attrs v0.0.0-20190224210810-a9411de4debd/go.mod h1:4duuawTqi2wkkpB4ePgWMaai6/Kc6WEz83bhFwpHzj0=