Skip to content

Commit 428e789

Browse files
Caner Gocmenfacebook-github-bot
Caner Gocmen
authored andcommitted
Add hashing for Topology (#3045)
Summary: Adding a hashing function for Topology. We're using the `hashlib` library to get consistent hashes. Reviewed By: iamzainhuda Differential Revision: D76004583
1 parent c0670d4 commit 428e789

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

torchrec/distributed/planner/tests/test_types.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ParameterConstraints,
1919
Shard,
2020
ShardingOption,
21+
Topology,
2122
)
2223
from torchrec.distributed.types import (
2324
BoundsCheckMode,
@@ -214,6 +215,54 @@ def test_module_pooled_mch_ec(self) -> None:
214215
self.assertEqual(sharding_option.is_pooled, False)
215216

216217

218+
class TestTopologyHash(unittest.TestCase):
219+
def test_hash_equality(self) -> None:
220+
# Create two identical Topology instances
221+
topology1 = Topology(
222+
world_size=2,
223+
compute_device="cuda",
224+
hbm_cap=1024 * 1024 * 2,
225+
local_world_size=2,
226+
)
227+
228+
topology2 = Topology(
229+
world_size=2,
230+
compute_device="cuda",
231+
hbm_cap=1024 * 1024 * 2,
232+
local_world_size=2,
233+
)
234+
235+
# Verify that the hash values are equal
236+
self.assertEqual(
237+
topology1._hash(),
238+
topology2._hash(),
239+
"Hashes should be equal for identical Topology instances",
240+
)
241+
242+
def test_hash_inequality(self) -> None:
243+
# Create two different Topology instances
244+
topology1 = Topology(
245+
world_size=2,
246+
compute_device="cuda",
247+
hbm_cap=1024 * 1024 * 2,
248+
local_world_size=2,
249+
)
250+
251+
topology2 = Topology(
252+
world_size=4, # Different world_size
253+
compute_device="cuda",
254+
hbm_cap=1024 * 1024 * 2,
255+
local_world_size=2,
256+
)
257+
258+
# Verify that the hash values are different
259+
self.assertNotEqual(
260+
topology1._hash(),
261+
topology2._hash(),
262+
"Hashes should be different for different Topology instances",
263+
)
264+
265+
217266
class TestParameterConstraintsHash(unittest.TestCase):
218267

219268
def test_hash_equality(self) -> None:

torchrec/distributed/planner/types.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# pyre-strict
99

1010
import abc
11+
import hashlib
1112
from copy import deepcopy
1213
from dataclasses import dataclass, field
1314
from enum import Enum
@@ -248,6 +249,10 @@ def get_bw(
248249

249250

250251
class Topology:
252+
"""
253+
Representation of a network of devices in a cluster.
254+
"""
255+
251256
def __init__(
252257
self,
253258
world_size: int,
@@ -396,6 +401,40 @@ def __repr__(self) -> str:
396401
topology_repr += str(self._comms_bandwidths) + "\n"
397402
return topology_repr
398403

404+
def _hash(self) -> str:
405+
"""
406+
Compute a consistent hash value for this Topology instance.
407+
408+
Returns:
409+
str: A hash value for this Topology instance.
410+
"""
411+
412+
# Compute hbms and ddrs from the decives
413+
hbms = [device.storage.hbm for device in self._devices]
414+
ddrs = [device.storage.ddr for device in self._devices]
415+
416+
# Combine all attributes into a hashable tuple
417+
hashable_list = [
418+
self._world_size,
419+
self._compute_device,
420+
hbms,
421+
ddrs,
422+
self._local_world_size,
423+
self._hbm_mem_bw,
424+
self._ddr_mem_bw,
425+
self._hbm_to_ddr_mem_bw,
426+
self._comms_bandwidths.intra_host_bw,
427+
self._comms_bandwidths.inter_host_bw,
428+
self._bwd_compute_multiplier,
429+
self._weighted_feature_bwd_compute_multiplier,
430+
self._uneven_sharding_perf_multiplier,
431+
]
432+
433+
serialized_list = str(hashable_list).encode("utf-8")
434+
hash_object = hashlib.sha256(serialized_list)
435+
hash_digest = hash_object.hexdigest()
436+
return hash_digest
437+
399438

400439
# ---- INPUT / OUTPUT ----- #
401440

0 commit comments

Comments
 (0)