From f49858d6192a77e3bcb6aabde27a06c225dadfcf Mon Sep 17 00:00:00 2001 From: Mark Puha Date: Mon, 24 Feb 2025 11:17:50 +0100 Subject: [PATCH] Solve race condition and reinit bug --- .gitignore | 2 +- Dockerfile | 2 +- Makefile | 10 +++--- README.md | 6 ++-- device/device.go | 20 +++++++---- device/internal/adapter/lua.go | 19 +++++++--- device/receive.go | 10 +++--- device/uapi.go | 65 ++++++++++++++++++---------------- ipc/uapi_unix.go | 2 +- ipc/uapi_windows.go | 2 +- 10 files changed, 78 insertions(+), 60 deletions(-) diff --git a/.gitignore b/.gitignore index cb8b7caa1..57fe9ed59 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1 @@ -euphoria +amneziawg-go diff --git a/Dockerfile b/Dockerfile index de4696e25..9b764b9ab 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,4 +17,4 @@ RUN apk --no-cache add iproute2 iptables bash && \ chmod +x /usr/bin/awg /usr/bin/awg-quick && \ ln -s /usr/bin/awg /usr/bin/wg && \ ln -s /usr/bin/awg-quick /usr/bin/wg-quick -COPY --from=euphoria /usr/bin/euphoria /usr/bin/euphoria +COPY --from=euphoria /usr/bin/euphoria /usr/bin/amneziawg-go diff --git a/Makefile b/Makefile index 664af718b..091ac4d06 100644 --- a/Makefile +++ b/Makefile @@ -14,18 +14,18 @@ generate-version-and-build: [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ echo "$$ver" > version.go && \ git update-index --assume-unchanged version.go || true - @$(MAKE) euphoria + @$(MAKE) amneziawg-go -euphoria: $(wildcard *.go) $(wildcard */*.go) +amneziawg-go: $(wildcard *.go) $(wildcard */*.go) go build -tags luajit -ldflags="-w -s" -trimpath -v -o "$@" -install: euphoria - @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/euphoria" +install: amneziawg-go + @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go" test: go test ./... clean: - rm -f euphoria + rm -f amneziawg-go .PHONY: all clean test install generate-version-and-build diff --git a/README.md b/README.md index fb35bbdd0..dbe00d685 100644 --- a/README.md +++ b/README.md @@ -11,15 +11,15 @@ As a result, AmneziaWG maintains high performance while adding an extra layer of Simply run: ``` -$ euphoria wg0 +$ amneziawg-go wg1 ``` -This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/euphoria/wg0.sock`, which will result in euphoria shutting down. +This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg-go/wg0.sock`, which will result in euphoria shutting down. To run euphoria without forking to the background, pass `-f` or `--foreground`: ``` -$ euphoria -f wg0 +$ amneziawg-go -f wg0 ``` When an interface is running, you may use [`euphoria-tools`](https://github.com/amnezia-vpn/euphoria-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands. diff --git a/device/device.go b/device/device.go index 263248855..76c743782 100644 --- a/device/device.go +++ b/device/device.go @@ -102,7 +102,10 @@ type awgType struct { aSecCfg aSecCfgType junkCreator junkCreator - codec *adapter.Lua + codec struct { + adapter *adapter.Lua + isOn bool + } } type aSecCfgType struct { @@ -434,9 +437,9 @@ func (device *Device) Close() { device.resetProtocol() - if device.awg.codec != nil { - device.awg.codec.Close() - device.awg.codec = nil + if device.awg.codec.adapter != nil { + device.awg.codec.adapter.Close() + device.awg.codec.adapter = nil } device.log.Verbosef("Device closed") close(device.closed) @@ -596,7 +599,10 @@ func (device *Device) resetProtocol() { } func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { - device.awg.codec = tempAwgType.codec + if tempAwgType.codec.isOn { + device.awg.codec.adapter = tempAwgType.codec.adapter + device.awg.codec.isOn = tempAwgType.codec.adapter != nil + } if !tempAwgType.aSecCfg.isSet { return nil @@ -836,13 +842,13 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) { } func (device *Device) isCodecActive() bool { - return device.awg.codec != nil + return device.awg.codec.adapter != nil } func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) { if device.isCodecActive() { var err error - packet, err = device.awg.codec.Generate(int64(msgType),packet) + packet, err = device.awg.codec.adapter.Generate(int64(msgType),packet) if err != nil { device.log.Errorf("%v - Failed to run codec generate: %v", device, err) return nil, err diff --git a/device/internal/adapter/lua.go b/device/internal/adapter/lua.go index 467734560..871528139 100644 --- a/device/internal/adapter/lua.go +++ b/device/internal/adapter/lua.go @@ -3,15 +3,16 @@ package adapter import ( "encoding/base64" "fmt" - "sync/atomic" + "sync" "github.com/aarzilli/golua/lua" ) type Lua struct { generateState *lua.State + mux sync.Mutex parseState *lua.State - packetCounter atomic.Int64 + packetCnt int64 base64LuaCode string } @@ -58,19 +59,26 @@ func (l *Lua) Close() { l.parseState.Close() } -// Only thread safe if used by wg packet creation which happens independably func (l *Lua) Generate( msgType int64, data []byte, ) ([]byte, error) { + l.mux.Lock() + defer l.mux.Unlock() + l.generateState.GetGlobal("d_gen") l.generateState.PushInteger(msgType) l.generateState.PushBytes(data) - l.generateState.PushInteger(l.packetCounter.Add(1)) + l.generateState.PushInteger(l.packetCnt) + l.packetCnt++ if err := l.generateState.Call(3, 1); err != nil { - return nil, fmt.Errorf("Error calling Lua function: %v\n", err) + return nil, fmt.Errorf( + "Error calling Lua function: %v\ntrace: %v", + err, + l.generateState.StackTrace(), + ) } result := l.generateState.ToBytes(-1) @@ -84,6 +92,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) { l.parseState.GetGlobal("d_parse") l.parseState.PushBytes(data) + if err := l.parseState.Call(1, 1); err != nil { return nil, fmt.Errorf("Error calling Lua function: %v\n", err) } diff --git a/device/receive.go b/device/receive.go index 2a3189d6d..c9699b49f 100644 --- a/device/receive.go +++ b/device/receive.go @@ -138,10 +138,7 @@ func (device *Device) RoutineReceiveIncoming( packet := bufsArrs[i][:size] if device.isCodecActive() { - realPacket, err := device.awg.codec.Parse(packet) - copy(packet, realPacket) - size = len(realPacket) - packet = bufsArrs[i][:size] + realPacket, err := device.awg.codec.adapter.Parse(packet) if err != nil { device.log.Verbosef( "Couldn't parse message; reason: %v", @@ -149,6 +146,9 @@ func (device *Device) RoutineReceiveIncoming( ) continue } + copy(packet, realPacket) + size = len(realPacket) + packet = bufsArrs[i][:size] } var msgType uint32 if device.isAdvancedSecurityOn() { @@ -166,7 +166,7 @@ func (device *Device) RoutineReceiveIncoming( } else { msgType = binary.LittleEndian.Uint32(packet[:4]) if msgType != MessageTransportType { - device.log.Verbosef("ASec: Received message with unknown type") + device.log.Verbosef("ASec: Received message with unknown type: %d", msgType) continue } } diff --git a/device/uapi.go b/device/uapi.go index cf377bae4..87a397639 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -98,8 +98,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error { sendf("fwmark=%d", device.net.fwmark) } - if device.awg.codec != nil { - sendf("lua_codec=%s", device.awg.codec.Base64LuaCode()) + if device.awg.codec.adapter != nil { + sendf("lua_codec=%s", device.awg.codec.adapter.Base64LuaCode()) } if device.isAdvancedSecurityOn() { if device.awg.aSecCfg.junkPacketCount != 0 { @@ -184,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { peer := new(ipcSetPeer) deviceConfig := true - tempAwgTpe := awgType{} + tempAwgType := awgType{} scanner := bufio.NewScanner(r) for scanner.Scan() { line := scanner.Text() if line == "" { // Blank line means terminate operation. - err := device.handlePostConfig(&tempAwgTpe) + err := device.handlePostConfig(&tempAwgType) if err != nil { return err } @@ -221,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { var err error if deviceConfig { - err = device.handleDeviceLine(key, value, &tempAwgTpe) + err = device.handleDeviceLine(key, value, &tempAwgType) } else { err = device.handlePeerLine(peer, key, value) } @@ -229,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return err } } - err = device.handlePostConfig(&tempAwgTpe) + err = device.handlePostConfig(&tempAwgType) if err != nil { return err } @@ -241,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) { return nil } -func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error { +func (device *Device) handleDeviceLine(key, value string, tempAwgType *awgType) error { switch key { case "private_key": var sk NoisePrivateKey @@ -289,21 +289,24 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e case "lua_codec": device.log.Verbosef("UAPI: Updating lua_codec") - var err error - tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{ - Base64LuaCode: value, - }) - if err != nil { - return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err) + if len(value) != 0 { + var err error + tempAwgType.codec.adapter, err = adapter.NewLua(adapter.LuaParams{ + Base64LuaCode: value, + }) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err) + } } + tempAwgType.codec.isOn = true case "jc": junkPacketCount, err := strconv.Atoi(value) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_count") - tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.junkPacketCount = junkPacketCount + tempAwgType.aSecCfg.isSet = true case "jmin": junkPacketMinSize, err := strconv.Atoi(value) @@ -311,8 +314,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_min_size") - tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.junkPacketMinSize = junkPacketMinSize + tempAwgType.aSecCfg.isSet = true case "jmax": junkPacketMaxSize, err := strconv.Atoi(value) @@ -320,8 +323,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err) } device.log.Verbosef("UAPI: Updating junk_packet_max_size") - tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.junkPacketMaxSize = junkPacketMaxSize + tempAwgType.aSecCfg.isSet = true case "s1": initPacketJunkSize, err := strconv.Atoi(value) @@ -329,8 +332,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating init_packet_junk_size") - tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.initPacketJunkSize = initPacketJunkSize + tempAwgType.aSecCfg.isSet = true case "s2": responsePacketJunkSize, err := strconv.Atoi(value) @@ -338,40 +341,40 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err) } device.log.Verbosef("UAPI: Updating response_packet_junk_size") - tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.responsePacketJunkSize = responsePacketJunkSize + tempAwgType.aSecCfg.isSet = true case "h1": initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err) } - tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader) + tempAwgType.aSecCfg.isSet = true case "h2": responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err) } - tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader) + tempAwgType.aSecCfg.isSet = true case "h3": underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err) } - tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader) + tempAwgType.aSecCfg.isSet = true case "h4": transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32) if err != nil { return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err) } - tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) - tempAwgTpe.aSecCfg.isSet = true + tempAwgType.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader) + tempAwgType.aSecCfg.isSet = true default: return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go index 4a5aff317..0da452a7e 100644 --- a/ipc/uapi_unix.go +++ b/ipc/uapi_unix.go @@ -26,7 +26,7 @@ const ( // socketDirectory is variable because it is modified by a linker // flag in wireguard-android. -var socketDirectory = "/var/run/euphoria" +var socketDirectory = "/var/run/amneziawg" func sockPath(iface string) string { return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index 829861642..42423c8c8 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -62,7 +62,7 @@ func init() { func UAPIListen(name string) (net.Listener, error) { listener, err := (&namedpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, - }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\Euphoria\` + name) + }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name) if err != nil { return nil, err }