diff --git a/esync/srvsync/esync_server.go b/esync/srvsync/esync_server.go index b4bfeb8..2602648 100644 --- a/esync/srvsync/esync_server.go +++ b/esync/srvsync/esync_server.go @@ -9,6 +9,7 @@ import ( "github.com/yohamta/donburi" "github.com/yohamta/donburi/component" "golang.org/x/sync/errgroup" + "nhooyr.io/websocket" "reflect" "slices" "sync" @@ -74,15 +75,17 @@ func NetworkSync(world donburi.World, entity *donburi.Entity, components ...donb // This is done by serializing all the components of the entity, and preparing a network bundle for the clients. func DoSync() error { errs, _ := errgroup.WithContext(context.Background()) - for _, client := range router.Peers() { - snapshot := buildSnapshot(client, world) - client := client + router.PeerMap().Range(func(key *websocket.Conn, client *router.NetworkClient) bool { + snapshot := buildSnapshot(client, world) errs.Go(func() error { err := client.SendMessage(snapshot) return err }) - } + + return true + }) + return errs.Wait() } diff --git a/internal/syncx/map.go b/internal/syncx/map.go new file mode 100644 index 0000000..7f21fa8 --- /dev/null +++ b/internal/syncx/map.go @@ -0,0 +1,38 @@ +package syncx + +import "sync" + +type Map[K comparable, V any] struct { + m sync.Map +} + +func (m *Map[K, V]) Delete(key K) { + m.m.Delete(key) +} + +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + v, ok := m.m.Load(key) + if !ok { + return value, ok + } + return v.(V), ok +} + +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + v, loaded := m.m.LoadAndDelete(key) + if !loaded { + return value, loaded + } + return v.(V), loaded +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + a, loaded := m.m.LoadOrStore(key, value) + return a.(V), loaded +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value any) bool { return f(key.(K), value.(V)) }) +} + +func (m *Map[K, V]) Store(key K, value V) { m.m.Store(key, value) } diff --git a/internal/syncx/map_test.go b/internal/syncx/map_test.go new file mode 100644 index 0000000..8ae1d2d --- /dev/null +++ b/internal/syncx/map_test.go @@ -0,0 +1,46 @@ +package syncx_test + +import ( + "github.com/leap-fish/necs/internal/syncx" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestMap(t *testing.T) { + m := &syncx.Map[int, string]{} + // Test LoadAndStore + actual, loaded := m.LoadOrStore(1, "value1") + assert.False(t, loaded, "Expected loaded=false") + assert.Equal(t, "value1", actual, "Expected actual value 'value1'") + + // Test Load + actualValue, ok := m.Load(1) + assert.True(t, ok, "Expected ok=true") + assert.Equal(t, "value1", actualValue, "Expected actual value 'value1'") + + // Test Store + m.Store(2, "value2") + actualValue, ok = m.Load(2) + assert.True(t, ok, "Expected ok=true") + assert.Equal(t, "value2", actualValue, "Expected actual value 'value2'") + + // Test Delete + m.Delete(1) + _, ok = m.Load(1) + assert.False(t, ok, "Expected ok=false for key 1 after deletion") + + // Test LoadAndDelete + actualValue, loaded = m.LoadAndDelete(2) + assert.True(t, loaded, "Expected loaded=true") + assert.Equal(t, "value2", actualValue, "Expected actual value 'value2'") + + // Test Range + m.Store(3, "value3") + m.Store(4, "value4") + var count int + m.Range(func(key int, value string) bool { + count++ + return true + }) + assert.Equal(t, 2, count, "Expected 2 iterations") +} diff --git a/router/router.go b/router/router.go index e832a6f..f852ea4 100644 --- a/router/router.go +++ b/router/router.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "errors" "fmt" + "github.com/leap-fish/necs/internal/syncx" "github.com/leap-fish/necs/typeid" "github.com/leap-fish/necs/typemapper" "nhooyr.io/websocket" @@ -25,8 +26,8 @@ var ( callbacks = make(map[reflect.Type][]any) - connMap = map[*websocket.Conn]string{} - clientMap = map[*websocket.Conn]*NetworkClient{} + idMap = syncx.Map[*websocket.Conn, string]{} + clientMap = syncx.Map[*websocket.Conn, *NetworkClient]{} ) // On adds a callback to be called whenever the specified message type T is received. @@ -92,37 +93,49 @@ func ProcessMessage(sender *NetworkClient, msg []byte) error { } func Client(conn *websocket.Conn) *NetworkClient { - if _, ok := clientMap[conn]; ok { - return clientMap[conn] + client, ok := clientMap.Load(conn) + if ok { + return client } - clientMap[conn] = NewNetworkClient(context.Background(), conn) - return clientMap[conn] + clientMap.Store(conn, NewNetworkClient(context.Background(), conn)) + // Ignore because we know it's set + r, _ := clientMap.Load(conn) + return r } func Id(client *NetworkClient) string { - if _, ok := connMap[client.Conn]; ok { - return connMap[client.Conn] + id, ok := idMap.Load(client.Conn) + if ok { + return id } bytes := make([]byte, 16) _, _ = rand.Read(bytes) - id := fmt.Sprintf("%d:%x", len(connMap), bytes[:10]) + id = fmt.Sprintf("%x", bytes[:10]) - connMap[client.Conn] = id + idMap.Store(client.Conn, id) return id } +// Peers returns a new slice of NetworkClient pointers from the underlying map. +// Use PeerMap if you are able to as this avoids this kind of duplication. func Peers() []*NetworkClient { - peers := make([]*NetworkClient, 0, len(clientMap)) - for _, client := range clientMap { - peers = append(peers, client) - } + var peers []*NetworkClient + clientMap.Range(func(key *websocket.Conn, value *NetworkClient) bool { + peers = append(peers, value) + return true + }) return peers } +// PeerMap returns a pointer to the underlying peer map. +func PeerMap() *syncx.Map[*websocket.Conn, *NetworkClient] { + return &clientMap +} + func Broadcast(msg any) error { payload, err := Serialize(msg) if err != nil { @@ -164,8 +177,8 @@ func CallDisconnect(sender *websocket.Conn, err error) { go callback(client, err) } - delete(connMap, sender) - delete(clientMap, sender) + idMap.Delete(sender) + clientMap.Delete(sender) } func CallError(sender *websocket.Conn, err error) { @@ -181,4 +194,7 @@ func ResetRouter() { disconnectCallbacks = []func(sender *NetworkClient, err error){} errorCallbacks = []func(sender *NetworkClient, err error){} callbacks = make(map[reflect.Type][]any) + + idMap = syncx.Map[*websocket.Conn, string]{} + clientMap = syncx.Map[*websocket.Conn, *NetworkClient]{} }