Skip to content

Commit 8a54c1b

Browse files
authored
Merge pull request #326 from wdvxdr1123/patch-simd-mask
use simd masking for amd64&arm64
2 parents 22c9092 + 2cd18b3 commit 8a54c1b

21 files changed

+621
-152
lines changed

.github/workflows/daily.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,5 @@ jobs:
5050
- run: AUTOBAHN=1 ./ci/test.sh
5151
- uses: actions/upload-artifact@v3
5252
with:
53-
name: coverage.html
53+
name: coverage-dev.html
5454
path: ./ci/out/coverage.html

README.md

+2-4
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ go get nhooyr.io/websocket
2626
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
2727
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections
2828
- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm)
29+
- WebSocket masking implemented in assembly for amd64 and arm64 [#326](https://github.com/nhooyr/websocket/issues/326)
2930

3031
## Roadmap
3132

@@ -36,8 +37,6 @@ See GitHub issues for minor issues but the major future enhancements are:
3637
- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267)
3738
- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246)
3839
- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209)
39-
- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16)
40-
- WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster
4140
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
4241
- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402)
4342

@@ -121,9 +120,8 @@ Advantages of nhooyr.io/websocket:
121120
- Gorilla requires registering a pong callback before sending a Ping
122121
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
123122
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage
124-
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
123+
- [3.5x](https://github.com/nhooyr/websocket/pull/326#issuecomment-1959470758) faster WebSocket masking implementation in assembly for amd64 and arm64 and [2x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster implementation in pure Go
125124
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
126-
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
127125
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
128126
- Gorilla only supports no context takeover mode
129127
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))

ci/bench.sh

+13-2
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,19 @@
22
set -eu
33
cd -- "$(dirname "$0")/.."
44

5-
go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" .
5+
go test --run=^$ --bench=. --benchmem "$@" ./...
6+
# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test
67
(
78
cd ./internal/thirdparty
8-
go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" .
9+
go test --run=^$ --bench=. --benchmem "$@" .
10+
11+
GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" .
12+
if [ "$#" -eq 0 ]; then
13+
if [ "${CI-}" ]; then
14+
sudo apt-get update
15+
sudo apt-get install -y qemu-user-static
16+
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
17+
fi
18+
qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem
19+
fi
920
)

ci/fmt.sh

+4
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,7 @@ npx [email protected] \
1818
$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html")
1919

2020
go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go
21+
22+
if [ "${CI-}" ]; then
23+
git diff --exit-code
24+
fi

ci/test.sh

+13
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,19 @@ cd -- "$(dirname "$0")/.."
1111
go test "$@" ./...
1212
)
1313

14+
(
15+
GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
16+
if [ "$#" -eq 0 ]; then
17+
if [ "${CI-}" ]; then
18+
sudo apt-get update
19+
sudo apt-get install -y qemu-user-static
20+
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
21+
fi
22+
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
23+
fi
24+
)
25+
26+
1427
go install github.com/agnivade/wasmbrowsertest@latest
1528
go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
1629
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof

frame.go

-123
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"fmt"
99
"io"
1010
"math"
11-
"math/bits"
1211

1312
"nhooyr.io/websocket/internal/errd"
1413
)
@@ -172,125 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
172171

173172
return nil
174173
}
175-
176-
// mask applies the WebSocket masking algorithm to p
177-
// with the given key.
178-
// See https://tools.ietf.org/html/rfc6455#section-5.3
179-
//
180-
// The returned value is the correctly rotated key to
181-
// to continue to mask/unmask the message.
182-
//
183-
// It is optimized for LittleEndian and expects the key
184-
// to be in little endian.
185-
//
186-
// See https://github.com/golang/go/issues/31586
187-
func mask(key uint32, b []byte) uint32 {
188-
if len(b) >= 8 {
189-
key64 := uint64(key)<<32 | uint64(key)
190-
191-
// At some point in the future we can clean these unrolled loops up.
192-
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
193-
194-
// Then we xor until b is less than 128 bytes.
195-
for len(b) >= 128 {
196-
v := binary.LittleEndian.Uint64(b)
197-
binary.LittleEndian.PutUint64(b, v^key64)
198-
v = binary.LittleEndian.Uint64(b[8:16])
199-
binary.LittleEndian.PutUint64(b[8:16], v^key64)
200-
v = binary.LittleEndian.Uint64(b[16:24])
201-
binary.LittleEndian.PutUint64(b[16:24], v^key64)
202-
v = binary.LittleEndian.Uint64(b[24:32])
203-
binary.LittleEndian.PutUint64(b[24:32], v^key64)
204-
v = binary.LittleEndian.Uint64(b[32:40])
205-
binary.LittleEndian.PutUint64(b[32:40], v^key64)
206-
v = binary.LittleEndian.Uint64(b[40:48])
207-
binary.LittleEndian.PutUint64(b[40:48], v^key64)
208-
v = binary.LittleEndian.Uint64(b[48:56])
209-
binary.LittleEndian.PutUint64(b[48:56], v^key64)
210-
v = binary.LittleEndian.Uint64(b[56:64])
211-
binary.LittleEndian.PutUint64(b[56:64], v^key64)
212-
v = binary.LittleEndian.Uint64(b[64:72])
213-
binary.LittleEndian.PutUint64(b[64:72], v^key64)
214-
v = binary.LittleEndian.Uint64(b[72:80])
215-
binary.LittleEndian.PutUint64(b[72:80], v^key64)
216-
v = binary.LittleEndian.Uint64(b[80:88])
217-
binary.LittleEndian.PutUint64(b[80:88], v^key64)
218-
v = binary.LittleEndian.Uint64(b[88:96])
219-
binary.LittleEndian.PutUint64(b[88:96], v^key64)
220-
v = binary.LittleEndian.Uint64(b[96:104])
221-
binary.LittleEndian.PutUint64(b[96:104], v^key64)
222-
v = binary.LittleEndian.Uint64(b[104:112])
223-
binary.LittleEndian.PutUint64(b[104:112], v^key64)
224-
v = binary.LittleEndian.Uint64(b[112:120])
225-
binary.LittleEndian.PutUint64(b[112:120], v^key64)
226-
v = binary.LittleEndian.Uint64(b[120:128])
227-
binary.LittleEndian.PutUint64(b[120:128], v^key64)
228-
b = b[128:]
229-
}
230-
231-
// Then we xor until b is less than 64 bytes.
232-
for len(b) >= 64 {
233-
v := binary.LittleEndian.Uint64(b)
234-
binary.LittleEndian.PutUint64(b, v^key64)
235-
v = binary.LittleEndian.Uint64(b[8:16])
236-
binary.LittleEndian.PutUint64(b[8:16], v^key64)
237-
v = binary.LittleEndian.Uint64(b[16:24])
238-
binary.LittleEndian.PutUint64(b[16:24], v^key64)
239-
v = binary.LittleEndian.Uint64(b[24:32])
240-
binary.LittleEndian.PutUint64(b[24:32], v^key64)
241-
v = binary.LittleEndian.Uint64(b[32:40])
242-
binary.LittleEndian.PutUint64(b[32:40], v^key64)
243-
v = binary.LittleEndian.Uint64(b[40:48])
244-
binary.LittleEndian.PutUint64(b[40:48], v^key64)
245-
v = binary.LittleEndian.Uint64(b[48:56])
246-
binary.LittleEndian.PutUint64(b[48:56], v^key64)
247-
v = binary.LittleEndian.Uint64(b[56:64])
248-
binary.LittleEndian.PutUint64(b[56:64], v^key64)
249-
b = b[64:]
250-
}
251-
252-
// Then we xor until b is less than 32 bytes.
253-
for len(b) >= 32 {
254-
v := binary.LittleEndian.Uint64(b)
255-
binary.LittleEndian.PutUint64(b, v^key64)
256-
v = binary.LittleEndian.Uint64(b[8:16])
257-
binary.LittleEndian.PutUint64(b[8:16], v^key64)
258-
v = binary.LittleEndian.Uint64(b[16:24])
259-
binary.LittleEndian.PutUint64(b[16:24], v^key64)
260-
v = binary.LittleEndian.Uint64(b[24:32])
261-
binary.LittleEndian.PutUint64(b[24:32], v^key64)
262-
b = b[32:]
263-
}
264-
265-
// Then we xor until b is less than 16 bytes.
266-
for len(b) >= 16 {
267-
v := binary.LittleEndian.Uint64(b)
268-
binary.LittleEndian.PutUint64(b, v^key64)
269-
v = binary.LittleEndian.Uint64(b[8:16])
270-
binary.LittleEndian.PutUint64(b[8:16], v^key64)
271-
b = b[16:]
272-
}
273-
274-
// Then we xor until b is less than 8 bytes.
275-
for len(b) >= 8 {
276-
v := binary.LittleEndian.Uint64(b)
277-
binary.LittleEndian.PutUint64(b, v^key64)
278-
b = b[8:]
279-
}
280-
}
281-
282-
// Then we xor until b is less than 4 bytes.
283-
for len(b) >= 4 {
284-
v := binary.LittleEndian.Uint32(b)
285-
binary.LittleEndian.PutUint32(b, v^key)
286-
b = b[4:]
287-
}
288-
289-
// xor remaining bytes.
290-
for i := range b {
291-
b[i] ^= byte(key)
292-
key = bits.RotateLeft32(key, -8)
293-
}
294-
295-
return key
296-
}

frame_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ func Test_mask(t *testing.T) {
9797
key := []byte{0xa, 0xb, 0xc, 0xff}
9898
key32 := binary.LittleEndian.Uint32(key)
9999
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
100-
gotKey32 := mask(key32, p)
100+
gotKey32 := mask(p, key32)
101101

102102
expP := []byte{0, 0, 0, 0x0d, 0x6}
103103
assert.Equal(t, "p", expP, p)

go.sum

Whitespace-only changes.

internal/thirdparty/frame_test.go

+49-15
Original file line numberDiff line numberDiff line change
@@ -2,41 +2,54 @@ package thirdparty
22

33
import (
44
"encoding/binary"
5+
"runtime"
56
"strconv"
67
"testing"
78
_ "unsafe"
89

910
"github.com/gobwas/ws"
1011
_ "github.com/gorilla/websocket"
12+
_ "github.com/lesismal/nbio/nbhttp/websocket"
1113

1214
_ "nhooyr.io/websocket"
1315
)
1416

15-
func basicMask(maskKey [4]byte, pos int, b []byte) int {
17+
func basicMask(b []byte, maskKey [4]byte, pos int) int {
1618
for i := range b {
1719
b[i] ^= maskKey[pos&3]
1820
pos++
1921
}
2022
return pos & 3
2123
}
2224

25+
//go:linkname maskGo nhooyr.io/websocket.maskGo
26+
func maskGo(b []byte, key32 uint32) int
27+
28+
//go:linkname maskAsm nhooyr.io/websocket.maskAsm
29+
func maskAsm(b *byte, len int, key32 uint32) uint32
30+
31+
//go:linkname nbioMaskBytes github.com/lesismal/nbio/nbhttp/websocket.maskXOR
32+
func nbioMaskBytes(b, key []byte) int
33+
2334
//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes
2435
func gorillaMaskBytes(key [4]byte, pos int, b []byte) int
2536

26-
//go:linkname mask nhooyr.io/websocket.mask
27-
func mask(key32 uint32, b []byte) int
28-
2937
func Benchmark_mask(b *testing.B) {
38+
b.Run(runtime.GOARCH, benchmark_mask)
39+
}
40+
41+
func benchmark_mask(b *testing.B) {
3042
sizes := []int{
31-
2,
32-
3,
33-
4,
3443
8,
3544
16,
3645
32,
3746
128,
47+
256,
3848
512,
49+
1024,
50+
2048,
3951
4096,
52+
8192,
4053
16384,
4154
}
4255

@@ -48,22 +61,34 @@ func Benchmark_mask(b *testing.B) {
4861
name: "basic",
4962
fn: func(b *testing.B, key [4]byte, p []byte) {
5063
for i := 0; i < b.N; i++ {
51-
basicMask(key, 0, p)
64+
basicMask(p, key, 0)
5265
}
5366
},
5467
},
5568

5669
{
57-
name: "nhooyr",
70+
name: "nhooyr-go",
71+
fn: func(b *testing.B, key [4]byte, p []byte) {
72+
key32 := binary.LittleEndian.Uint32(key[:])
73+
b.ResetTimer()
74+
75+
for i := 0; i < b.N; i++ {
76+
maskGo(p, key32)
77+
}
78+
},
79+
},
80+
{
81+
name: "wdvxdr1123-asm",
5882
fn: func(b *testing.B, key [4]byte, p []byte) {
5983
key32 := binary.LittleEndian.Uint32(key[:])
6084
b.ResetTimer()
6185

6286
for i := 0; i < b.N; i++ {
63-
mask(key32, p)
87+
maskAsm(&p[0], len(p), key32)
6488
}
6589
},
6690
},
91+
6792
{
6893
name: "gorilla",
6994
fn: func(b *testing.B, key [4]byte, p []byte) {
@@ -80,16 +105,25 @@ func Benchmark_mask(b *testing.B) {
80105
}
81106
},
82107
},
108+
{
109+
name: "nbio",
110+
fn: func(b *testing.B, key [4]byte, p []byte) {
111+
keyb := key[:]
112+
for i := 0; i < b.N; i++ {
113+
nbioMaskBytes(p, keyb)
114+
}
115+
},
116+
},
83117
}
84118

85119
key := [4]byte{1, 2, 3, 4}
86120

87-
for _, size := range sizes {
88-
p := make([]byte, size)
121+
for _, fn := range fns {
122+
b.Run(fn.name, func(b *testing.B) {
123+
for _, size := range sizes {
124+
p := make([]byte, size)
89125

90-
b.Run(strconv.Itoa(size), func(b *testing.B) {
91-
for _, fn := range fns {
92-
b.Run(fn.name, func(b *testing.B) {
126+
b.Run(strconv.Itoa(size), func(b *testing.B) {
93127
b.SetBytes(int64(size))
94128

95129
fn.fn(b, key, p)

internal/thirdparty/go.mod

+3-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ require (
88
github.com/gin-gonic/gin v1.9.1
99
github.com/gobwas/ws v1.3.0
1010
github.com/gorilla/websocket v1.5.0
11+
github.com/lesismal/nbio v1.3.18
1112
nhooyr.io/websocket v0.0.0-00010101000000-000000000000
1213
)
1314

@@ -25,6 +26,7 @@ require (
2526
github.com/json-iterator/go v1.1.12 // indirect
2627
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
2728
github.com/leodido/go-urn v1.2.4 // indirect
29+
github.com/lesismal/llib v1.1.12 // indirect
2830
github.com/mattn/go-isatty v0.0.19 // indirect
2931
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
3032
github.com/modern-go/reflect2 v1.0.2 // indirect
@@ -34,7 +36,7 @@ require (
3436
golang.org/x/arch v0.3.0 // indirect
3537
golang.org/x/crypto v0.9.0 // indirect
3638
golang.org/x/net v0.10.0 // indirect
37-
golang.org/x/sys v0.8.0 // indirect
39+
golang.org/x/sys v0.17.0 // indirect
3840
golang.org/x/text v0.9.0 // indirect
3941
google.golang.org/protobuf v1.30.0 // indirect
4042
gopkg.in/yaml.v3 v3.0.1 // indirect

0 commit comments

Comments
 (0)