Skip to content

Commit 903f940

Browse files
authored
Merge pull request #94 from lxzan/testing
Export GetSharding Method
2 parents 2042b4f + f322320 commit 903f940

3 files changed

Lines changed: 116 additions & 61 deletions

File tree

examples/chatroom/main.go

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ func main() {
2323
var upgrader = gws.NewUpgrader(handler, &gws.ServerOption{
2424
PermessageDeflate: gws.PermessageDeflate{
2525
Enabled: true,
26-
ServerContextTakeover: false,
27-
ClientContextTakeover: false,
26+
ServerContextTakeover: true,
27+
ClientContextTakeover: true,
2828
},
2929

3030
// 在querystring里面传入用户名
@@ -36,7 +36,7 @@ func main() {
3636
return false
3737
}
3838
session.Store("name", name)
39-
session.Store("key", r.Header.Get("Sec-WebSocket-Key"))
39+
session.Store("websocketKey", r.Header.Get("Sec-WebSocket-Key"))
4040
return true
4141
},
4242
})
@@ -59,42 +59,46 @@ func main() {
5959
}
6060
}
6161

62+
func MustLoad[T any](session gws.SessionStorage, key string) (v T) {
63+
if value, exist := session.Load(key); exist {
64+
v = value.(T)
65+
}
66+
return
67+
}
68+
6269
func NewWebSocket() *WebSocket {
63-
return &WebSocket{sessions: gws.NewConcurrentMap[string, *gws.Conn](16)}
70+
return &WebSocket{
71+
sessions: gws.NewConcurrentMap[string, *gws.Conn](16, 128),
72+
}
6473
}
6574

6675
type WebSocket struct {
6776
sessions *gws.ConcurrentMap[string, *gws.Conn] // 使用内置的ConcurrentMap存储连接, 可以减少锁冲突
6877
}
6978

70-
func (c *WebSocket) getName(socket *gws.Conn) string {
71-
name, _ := socket.Session().Load("name")
72-
return name.(string)
73-
}
74-
75-
func (c *WebSocket) getKey(socket *gws.Conn) string {
76-
name, _ := socket.Session().Load("key")
77-
return name.(string)
78-
}
79-
8079
func (c *WebSocket) OnOpen(socket *gws.Conn) {
81-
name := c.getName(socket)
80+
name := MustLoad[string](socket.Session(), "name")
8281
if conn, ok := c.sessions.Load(name); ok {
83-
conn.WriteClose(1000, []byte("connection replaced"))
82+
conn.WriteClose(1000, []byte("connection is replaced"))
8483
}
85-
socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
84+
_ = socket.SetDeadline(time.Now().Add(PingInterval + HeartbeatWaitTimeout))
8685
c.sessions.Store(name, socket)
8786
log.Printf("%s connected\n", name)
8887
}
8988

9089
func (c *WebSocket) OnClose(socket *gws.Conn, err error) {
91-
name := c.getName(socket)
92-
key := c.getKey(socket)
93-
if mSocket, ok := c.sessions.Load(name); ok {
94-
if mKey := c.getKey(mSocket); mKey == key {
95-
c.sessions.Delete(name)
90+
name := MustLoad[string](socket.Session(), "name")
91+
sharding := c.sessions.GetSharding(name)
92+
sharding.Lock()
93+
defer sharding.Unlock()
94+
95+
if conn, ok := sharding.Load(name); ok {
96+
key0 := MustLoad[string](socket.Session(), "websocketKey")
97+
if key1 := MustLoad[string](conn.Session(), "websocketKey"); key1 == key0 {
98+
sharding.Delete(name)
9699
}
97100
}
101+
98102
log.Printf("onerror, name=%s, msg=%s\n", name, err.Error())
99103
}
100104

@@ -114,14 +118,14 @@ func (c *WebSocket) OnMessage(socket *gws.Conn, message *gws.Message) {
114118
defer message.Close()
115119

116120
// chrome websocket不支持ping方法, 所以在text frame里面模拟ping
117-
if b := message.Data.Bytes(); len(b) == 4 && string(b) == "ping" {
121+
if b := message.Bytes(); len(b) == 4 && string(b) == "ping" {
118122
c.OnPing(socket, nil)
119123
return
120124
}
121125

122126
var input = &Input{}
123-
_ = json.Unmarshal(message.Data.Bytes(), input)
127+
_ = json.Unmarshal(message.Bytes(), input)
124128
if conn, ok := c.sessions.Load(input.To); ok {
125-
conn.WriteMessage(gws.OpcodeText, message.Data.Bytes())
129+
_ = conn.WriteMessage(gws.OpcodeText, message.Bytes())
126130
}
127131
}

session_storage.go

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -60,80 +60,121 @@ func (c *smap) Range(f func(key string, value any) bool) {
6060

6161
type (
6262
ConcurrentMap[K comparable, V any] struct {
63-
hasher maphash.Hasher[K]
64-
sharding uint64
65-
buckets []*bucket[K, V]
66-
}
67-
68-
bucket[K comparable, V any] struct {
69-
sync.Mutex
70-
m map[K]V
63+
hasher maphash.Hasher[K]
64+
num uint64
65+
shardings []*Map[K, V]
7166
}
7267
)
7368

74-
func NewConcurrentMap[K comparable, V any](sharding uint64) *ConcurrentMap[K, V] {
75-
sharding = internal.SelectValue(sharding == 0, 16, sharding)
76-
sharding = internal.ToBinaryNumber(sharding)
69+
// NewConcurrentMap create a new concurrency-safe map
70+
// arg0 represents the number of shardings; arg1 represents the initialized capacity of a sharding.
71+
func NewConcurrentMap[K comparable, V any](size ...uint64) *ConcurrentMap[K, V] {
72+
size = append(size, 0, 0)
73+
num, capacity := size[0], size[1]
74+
num = internal.ToBinaryNumber(internal.SelectValue(num <= 0, 16, num))
7775
var cm = &ConcurrentMap[K, V]{
78-
hasher: maphash.NewHasher[K](),
79-
sharding: sharding,
80-
buckets: make([]*bucket[K, V], sharding),
76+
hasher: maphash.NewHasher[K](),
77+
num: num,
78+
shardings: make([]*Map[K, V], num),
8179
}
82-
for i, _ := range cm.buckets {
83-
cm.buckets[i] = &bucket[K, V]{m: make(map[K]V)}
80+
for i, _ := range cm.shardings {
81+
cm.shardings[i] = NewMap[K, V](int(capacity))
8482
}
8583
return cm
8684
}
8785

88-
func (c *ConcurrentMap[K, V]) getBucket(key K) *bucket[K, V] {
86+
// GetSharding returns a map sharding for a key
87+
// the operations inside the sharding is lockless, and need to be locked manually.
88+
func (c *ConcurrentMap[K, V]) GetSharding(key K) *Map[K, V] {
8989
var hashCode = c.hasher.Hash(key)
90-
var index = hashCode & (c.sharding - 1)
91-
return c.buckets[index]
90+
var index = hashCode & (c.num - 1)
91+
return c.shardings[index]
9292
}
9393

94+
// Len returns the number of elements in the map
9495
func (c *ConcurrentMap[K, V]) Len() int {
9596
var length = 0
96-
for _, b := range c.buckets {
97+
for _, b := range c.shardings {
9798
b.Lock()
98-
length += len(b.m)
99+
length += b.Len()
99100
b.Unlock()
100101
}
101102
return length
102103
}
103104

104-
func (c *ConcurrentMap[K, V]) Load(key K) (value V, exist bool) {
105-
var b = c.getBucket(key)
105+
// Load returns the value stored in the map for a key, or nil if no
106+
// value is present.
107+
// The ok result indicates whether value was found in the map.
108+
func (c *ConcurrentMap[K, V]) Load(key K) (value V, ok bool) {
109+
var b = c.GetSharding(key)
106110
b.Lock()
107-
value, exist = b.m[key]
111+
value, ok = b.Load(key)
108112
b.Unlock()
109113
return
110114
}
111115

116+
// Delete deletes the value for a key.
112117
func (c *ConcurrentMap[K, V]) Delete(key K) {
113-
var b = c.getBucket(key)
118+
var b = c.GetSharding(key)
114119
b.Lock()
115-
delete(b.m, key)
120+
b.Delete(key)
116121
b.Unlock()
117122
}
118123

124+
// Store sets the value for a key.
119125
func (c *ConcurrentMap[K, V]) Store(key K, value V) {
120-
var b = c.getBucket(key)
126+
var b = c.GetSharding(key)
121127
b.Lock()
122-
b.m[key] = value
128+
b.Store(key, value)
123129
b.Unlock()
124130
}
125131

126132
// Range calls f sequentially for each key and value present in the map.
127133
// If f returns false, range stops the iteration.
128134
func (c *ConcurrentMap[K, V]) Range(f func(key K, value V) bool) {
129-
for _, b := range c.buckets {
135+
var next = true
136+
var cb = func(k K, v V) bool {
137+
next = f(k, v)
138+
return next
139+
}
140+
for i := uint64(0); i < c.num && next; i++ {
141+
var b = c.shardings[i]
130142
b.Lock()
131-
for k, v := range b.m {
132-
if !f(k, v) {
133-
b.Unlock()
134-
return
135-
}
136-
}
143+
b.Range(cb)
137144
b.Unlock()
138145
}
139146
}
147+
148+
type Map[K comparable, V any] struct {
149+
sync.Mutex
150+
m map[K]V
151+
}
152+
153+
func NewMap[K comparable, V any](size ...int) *Map[K, V] {
154+
var capacity = 0
155+
if len(size) > 0 {
156+
capacity = size[0]
157+
}
158+
c := new(Map[K, V])
159+
c.m = make(map[K]V, capacity)
160+
return c
161+
}
162+
163+
func (c *Map[K, V]) Len() int { return len(c.m) }
164+
165+
func (c *Map[K, V]) Load(key K) (value V, ok bool) {
166+
value, ok = c.m[key]
167+
return
168+
}
169+
170+
func (c *Map[K, V]) Delete(key K) { delete(c.m, key) }
171+
172+
func (c *Map[K, V]) Store(key K, value V) { c.m[key] = value }
173+
174+
func (c *Map[K, V]) Range(f func(K, V) bool) {
175+
for k, v := range c.m {
176+
if !f(k, v) {
177+
return
178+
}
179+
}
180+
}

session_storage_test.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ func TestMap_Range(t *testing.T) {
9999
func TestConcurrentMap(t *testing.T) {
100100
var as = assert.New(t)
101101
var m1 = make(map[string]any)
102-
var m2 = NewConcurrentMap[string, uint32](5)
102+
var m2 = NewConcurrentMap[string, uint32]()
103+
as.Equal(m2.num, uint64(16))
103104
var count = internal.AlphabetNumeric.Intn(1000)
104105
for i := 0; i < count; i++ {
105106
var key = string(internal.AlphabetNumeric.Generate(10))
@@ -123,6 +124,15 @@ func TestConcurrentMap(t *testing.T) {
123124
as.Equal(v, v1)
124125
}
125126
as.Equal(len(m1), m2.Len())
127+
128+
t.Run("", func(t *testing.T) {
129+
var sum = 0
130+
var cm = NewConcurrentMap[string, int](8, 8)
131+
for _, item := range cm.shardings {
132+
sum += len(item.m)
133+
}
134+
assert.Equal(t, sum, 0)
135+
})
126136
}
127137

128138
func TestConcurrentMap_Range(t *testing.T) {

0 commit comments

Comments
 (0)