From f49858d6192a77e3bcb6aabde27a06c225dadfcf Mon Sep 17 00:00:00 2001
From: Mark Puha
Date: Mon, 24 Feb 2025 11:17:50 +0100
Subject: [PATCH] Solve race condition and reinit bug
---
.gitignore | 2 +-
Dockerfile | 2 +-
Makefile | 10 +++---
README.md | 6 ++--
device/device.go | 20 +++++++----
device/internal/adapter/lua.go | 19 +++++++---
device/receive.go | 10 +++---
device/uapi.go | 65 ++++++++++++++++++----------------
ipc/uapi_unix.go | 2 +-
ipc/uapi_windows.go | 2 +-
10 files changed, 78 insertions(+), 60 deletions(-)
diff --git a/.gitignore b/.gitignore
index cb8b7caa1..57fe9ed59 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1 @@
-euphoria
+amneziawg-go
diff --git a/Dockerfile b/Dockerfile
index de4696e25..9b764b9ab 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -17,4 +17,4 @@ RUN apk --no-cache add iproute2 iptables bash && \
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
ln -s /usr/bin/awg /usr/bin/wg && \
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
-COPY --from=euphoria /usr/bin/euphoria /usr/bin/euphoria
+COPY --from=euphoria /usr/bin/euphoria /usr/bin/amneziawg-go
diff --git a/Makefile b/Makefile
index 664af718b..091ac4d06 100644
--- a/Makefile
+++ b/Makefile
@@ -14,18 +14,18 @@ generate-version-and-build:
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > version.go && \
git update-index --assume-unchanged version.go || true
- @$(MAKE) euphoria
+ @$(MAKE) amneziawg-go
-euphoria: $(wildcard *.go) $(wildcard */*.go)
+amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
go build -tags luajit -ldflags="-w -s" -trimpath -v -o "$@"
-install: euphoria
- @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/euphoria"
+install: amneziawg-go
+ @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
test:
go test ./...
clean:
- rm -f euphoria
+ rm -f amneziawg-go
.PHONY: all clean test install generate-version-and-build
diff --git a/README.md b/README.md
index fb35bbdd0..dbe00d685 100644
--- a/README.md
+++ b/README.md
@@ -11,15 +11,15 @@ As a result, AmneziaWG maintains high performance while adding an extra layer of
Simply run:
```
-$ euphoria wg0
+$ amneziawg-go wg1
```
-This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/euphoria/wg0.sock`, which will result in euphoria shutting down.
+This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg-go/wg0.sock`, which will result in euphoria shutting down.
To run euphoria without forking to the background, pass `-f` or `--foreground`:
```
-$ euphoria -f wg0
+$ amneziawg-go -f wg0
```
When an interface is running, you may use [`euphoria-tools`](https://github.com/amnezia-vpn/euphoria-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
diff --git a/device/device.go b/device/device.go
index 263248855..76c743782 100644
--- a/device/device.go
+++ b/device/device.go
@@ -102,7 +102,10 @@ type awgType struct {
aSecCfg aSecCfgType
junkCreator junkCreator
- codec *adapter.Lua
+ codec struct {
+ adapter *adapter.Lua
+ isOn bool
+ }
}
type aSecCfgType struct {
@@ -434,9 +437,9 @@ func (device *Device) Close() {
device.resetProtocol()
- if device.awg.codec != nil {
- device.awg.codec.Close()
- device.awg.codec = nil
+ if device.awg.codec.adapter != nil {
+ device.awg.codec.adapter.Close()
+ device.awg.codec.adapter = nil
}
device.log.Verbosef("Device closed")
close(device.closed)
@@ -596,7 +599,10 @@ func (device *Device) resetProtocol() {
}
func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
- device.awg.codec = tempAwgType.codec
+ if tempAwgType.codec.isOn {
+ device.awg.codec.adapter = tempAwgType.codec.adapter
+ device.awg.codec.isOn = tempAwgType.codec.adapter != nil
+ }
if !tempAwgType.aSecCfg.isSet {
return nil
@@ -836,13 +842,13 @@ func (device *Device) handlePostConfig(tempAwgType *awgType) (err error) {
}
func (device *Device) isCodecActive() bool {
- return device.awg.codec != nil
+ return device.awg.codec.adapter != nil
}
func (device *Device) codecPacketIfActive(msgType uint32, packet []byte) ([]byte, error) {
if device.isCodecActive() {
var err error
- packet, err = device.awg.codec.Generate(int64(msgType),packet)
+ packet, err = device.awg.codec.adapter.Generate(int64(msgType),packet)
if err != nil {
device.log.Errorf("%v - Failed to run codec generate: %v", device, err)
return nil, err
diff --git a/device/internal/adapter/lua.go b/device/internal/adapter/lua.go
index 467734560..871528139 100644
--- a/device/internal/adapter/lua.go
+++ b/device/internal/adapter/lua.go
@@ -3,15 +3,16 @@ package adapter
import (
"encoding/base64"
"fmt"
- "sync/atomic"
+ "sync"
"github.com/aarzilli/golua/lua"
)
type Lua struct {
generateState *lua.State
+ mux sync.Mutex
parseState *lua.State
- packetCounter atomic.Int64
+ packetCnt int64
base64LuaCode string
}
@@ -58,19 +59,26 @@ func (l *Lua) Close() {
l.parseState.Close()
}
-// Only thread safe if used by wg packet creation which happens independably
func (l *Lua) Generate(
msgType int64,
data []byte,
) ([]byte, error) {
+ l.mux.Lock()
+ defer l.mux.Unlock()
+
l.generateState.GetGlobal("d_gen")
l.generateState.PushInteger(msgType)
l.generateState.PushBytes(data)
- l.generateState.PushInteger(l.packetCounter.Add(1))
+ l.generateState.PushInteger(l.packetCnt)
+ l.packetCnt++
if err := l.generateState.Call(3, 1); err != nil {
- return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
+ return nil, fmt.Errorf(
+ "Error calling Lua function: %v\ntrace: %v",
+ err,
+ l.generateState.StackTrace(),
+ )
}
result := l.generateState.ToBytes(-1)
@@ -84,6 +92,7 @@ func (l *Lua) Parse(data []byte) ([]byte, error) {
l.parseState.GetGlobal("d_parse")
l.parseState.PushBytes(data)
+
if err := l.parseState.Call(1, 1); err != nil {
return nil, fmt.Errorf("Error calling Lua function: %v\n", err)
}
diff --git a/device/receive.go b/device/receive.go
index 2a3189d6d..c9699b49f 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -138,10 +138,7 @@ func (device *Device) RoutineReceiveIncoming(
packet := bufsArrs[i][:size]
if device.isCodecActive() {
- realPacket, err := device.awg.codec.Parse(packet)
- copy(packet, realPacket)
- size = len(realPacket)
- packet = bufsArrs[i][:size]
+ realPacket, err := device.awg.codec.adapter.Parse(packet)
if err != nil {
device.log.Verbosef(
"Couldn't parse message; reason: %v",
@@ -149,6 +146,9 @@ func (device *Device) RoutineReceiveIncoming(
)
continue
}
+ copy(packet, realPacket)
+ size = len(realPacket)
+ packet = bufsArrs[i][:size]
}
var msgType uint32
if device.isAdvancedSecurityOn() {
@@ -166,7 +166,7 @@ func (device *Device) RoutineReceiveIncoming(
} else {
msgType = binary.LittleEndian.Uint32(packet[:4])
if msgType != MessageTransportType {
- device.log.Verbosef("ASec: Received message with unknown type")
+ device.log.Verbosef("ASec: Received message with unknown type: %d", msgType)
continue
}
}
diff --git a/device/uapi.go b/device/uapi.go
index cf377bae4..87a397639 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -98,8 +98,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("fwmark=%d", device.net.fwmark)
}
- if device.awg.codec != nil {
- sendf("lua_codec=%s", device.awg.codec.Base64LuaCode())
+ if device.awg.codec.adapter != nil {
+ sendf("lua_codec=%s", device.awg.codec.adapter.Base64LuaCode())
}
if device.isAdvancedSecurityOn() {
if device.awg.aSecCfg.junkPacketCount != 0 {
@@ -184,13 +184,13 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
peer := new(ipcSetPeer)
deviceConfig := true
- tempAwgTpe := awgType{}
+ tempAwgType := awgType{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line := scanner.Text()
if line == "" {
// Blank line means terminate operation.
- err := device.handlePostConfig(&tempAwgTpe)
+ err := device.handlePostConfig(&tempAwgType)
if err != nil {
return err
}
@@ -221,7 +221,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
var err error
if deviceConfig {
- err = device.handleDeviceLine(key, value, &tempAwgTpe)
+ err = device.handleDeviceLine(key, value, &tempAwgType)
} else {
err = device.handlePeerLine(peer, key, value)
}
@@ -229,7 +229,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return err
}
}
- err = device.handlePostConfig(&tempAwgTpe)
+ err = device.handlePostConfig(&tempAwgType)
if err != nil {
return err
}
@@ -241,7 +241,7 @@ func (device *Device) IpcSetOperation(r io.Reader) (err error) {
return nil
}
-func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) error {
+func (device *Device) handleDeviceLine(key, value string, tempAwgType *awgType) error {
switch key {
case "private_key":
var sk NoisePrivateKey
@@ -289,21 +289,24 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
case "lua_codec":
device.log.Verbosef("UAPI: Updating lua_codec")
- var err error
- tempAwgTpe.codec, err = adapter.NewLua(adapter.LuaParams{
- Base64LuaCode: value,
- })
- if err != nil {
- return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err)
+ if len(value) != 0 {
+ var err error
+ tempAwgType.codec.adapter, err = adapter.NewLua(adapter.LuaParams{
+ Base64LuaCode: value,
+ })
+ if err != nil {
+ return ipcErrorf(ipc.IpcErrorInvalid, "invalid lua_codec: %w", err)
+ }
}
+ tempAwgType.codec.isOn = true
case "jc":
junkPacketCount, err := strconv.Atoi(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_count")
- tempAwgTpe.aSecCfg.junkPacketCount = junkPacketCount
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.junkPacketCount = junkPacketCount
+ tempAwgType.aSecCfg.isSet = true
case "jmin":
junkPacketMinSize, err := strconv.Atoi(value)
@@ -311,8 +314,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
- tempAwgTpe.aSecCfg.junkPacketMinSize = junkPacketMinSize
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.junkPacketMinSize = junkPacketMinSize
+ tempAwgType.aSecCfg.isSet = true
case "jmax":
junkPacketMaxSize, err := strconv.Atoi(value)
@@ -320,8 +323,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
}
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
- tempAwgTpe.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.junkPacketMaxSize = junkPacketMaxSize
+ tempAwgType.aSecCfg.isSet = true
case "s1":
initPacketJunkSize, err := strconv.Atoi(value)
@@ -329,8 +332,8 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
- tempAwgTpe.aSecCfg.initPacketJunkSize = initPacketJunkSize
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.initPacketJunkSize = initPacketJunkSize
+ tempAwgType.aSecCfg.isSet = true
case "s2":
responsePacketJunkSize, err := strconv.Atoi(value)
@@ -338,40 +341,40 @@ func (device *Device) handleDeviceLine(key, value string, tempAwgTpe *awgType) e
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
}
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
- tempAwgTpe.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.responsePacketJunkSize = responsePacketJunkSize
+ tempAwgType.aSecCfg.isSet = true
case "h1":
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
}
- tempAwgTpe.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
+ tempAwgType.aSecCfg.isSet = true
case "h2":
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
}
- tempAwgTpe.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
+ tempAwgType.aSecCfg.isSet = true
case "h3":
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
}
- tempAwgTpe.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
+ tempAwgType.aSecCfg.isSet = true
case "h4":
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
}
- tempAwgTpe.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
- tempAwgTpe.aSecCfg.isSet = true
+ tempAwgType.aSecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
+ tempAwgType.aSecCfg.isSet = true
default:
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go
index 4a5aff317..0da452a7e 100644
--- a/ipc/uapi_unix.go
+++ b/ipc/uapi_unix.go
@@ -26,7 +26,7 @@ const (
// socketDirectory is variable because it is modified by a linker
// flag in wireguard-android.
-var socketDirectory = "/var/run/euphoria"
+var socketDirectory = "/var/run/amneziawg"
func sockPath(iface string) string {
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go
index 829861642..42423c8c8 100644
--- a/ipc/uapi_windows.go
+++ b/ipc/uapi_windows.go
@@ -62,7 +62,7 @@ func init() {
func UAPIListen(name string) (net.Listener, error) {
listener, err := (&namedpipe.ListenConfig{
SecurityDescriptor: UAPISecurityDescriptor,
- }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\Euphoria\` + name)
+ }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
if err != nil {
return nil, err
}