Skip to content

Commit fd9c437

Browse files
Caner Gocmenfacebook-github-bot
Caner Gocmen
authored andcommitted
Add hashing for Topology (#3045)
Summary: Pull Request resolved: #3045 Adding a hashing function for Topology. We're using the `hashlib` library to get consistent hashes. Differential Revision: D76004583
1 parent 71db31d commit fd9c437

File tree

2 files changed

+89
-0
lines changed

2 files changed

+89
-0
lines changed

torchrec/distributed/planner/tests/test_types.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
ParameterConstraints,
1919
Shard,
2020
ShardingOption,
21+
Topology,
2122
)
2223
from torchrec.distributed.types import (
2324
BoundsCheckMode,
@@ -214,6 +215,54 @@ def test_module_pooled_mch_ec(self) -> None:
214215
self.assertEqual(sharding_option.is_pooled, False)
215216

216217

218+
class TestTopologyHash(unittest.TestCase):
219+
def test_hash_equality(self) -> None:
220+
# Create two identical Topology instances
221+
topology1 = Topology(
222+
world_size=2,
223+
compute_device="cuda",
224+
hbm_cap=1024 * 1024 * 2,
225+
local_world_size=2,
226+
)
227+
228+
topology2 = Topology(
229+
world_size=2,
230+
compute_device="cuda",
231+
hbm_cap=1024 * 1024 * 2,
232+
local_world_size=2,
233+
)
234+
235+
# Verify that the hash values are equal
236+
self.assertEqual(
237+
topology1._hash(),
238+
topology2._hash(),
239+
"Hashes should be equal for identical Topology instances",
240+
)
241+
242+
def test_hash_inequality(self) -> None:
243+
# Create two different Topology instances
244+
topology1 = Topology(
245+
world_size=2,
246+
compute_device="cuda",
247+
hbm_cap=1024 * 1024 * 2,
248+
local_world_size=2,
249+
)
250+
251+
topology2 = Topology(
252+
world_size=4, # Different world_size
253+
compute_device="cuda",
254+
hbm_cap=1024 * 1024 * 2,
255+
local_world_size=2,
256+
)
257+
258+
# Verify that the hash values are different
259+
self.assertNotEqual(
260+
topology1._hash(),
261+
topology2._hash(),
262+
"Hashes should be different for different Topology instances",
263+
)
264+
265+
217266
class TestParameterConstraintsHash(unittest.TestCase):
218267

219268
def test_hash_equality(self) -> None:

torchrec/distributed/planner/types.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
# pyre-strict
99

1010
import abc
11+
import hashlib
12+
import pickle
1113
from copy import deepcopy
1214
from dataclasses import dataclass, field
1315
from enum import Enum
@@ -248,6 +250,10 @@ def get_bw(
248250

249251

250252
class Topology:
253+
"""
254+
Representation of a network of devices in a cluster.
255+
"""
256+
251257
def __init__(
252258
self,
253259
world_size: int,
@@ -396,6 +402,40 @@ def __repr__(self) -> str:
396402
topology_repr += str(self._comms_bandwidths) + "\n"
397403
return topology_repr
398404

405+
def _hash(self) -> str:
406+
"""
407+
Compute a consistent hash value for this Topology instance.
408+
409+
Returns:
410+
str: A hash value for this Topology instance.
411+
"""
412+
413+
# Compute hbms and ddrs from the decives
414+
hbms = [device.storage.hbm for device in self._devices]
415+
ddrs = [device.storage.ddr for device in self._devices]
416+
417+
# Combine all attributes into a hashable tuple
418+
hashable_list = [
419+
self._world_size,
420+
self._compute_device,
421+
hbms,
422+
ddrs,
423+
self._local_world_size,
424+
self._hbm_mem_bw,
425+
self._ddr_mem_bw,
426+
self._hbm_to_ddr_mem_bw,
427+
self._comms_bandwidths.intra_host_bw,
428+
self._comms_bandwidths.inter_host_bw,
429+
self._bwd_compute_multiplier,
430+
self._weighted_feature_bwd_compute_multiplier,
431+
self._uneven_sharding_perf_multiplier,
432+
]
433+
434+
serialized_list = str(hashable_list).encode("utf-8")
435+
hash_object = hashlib.sha256(serialized_list)
436+
hash_digest = hash_object.hexdigest()
437+
return hash_digest
438+
399439

400440
# ---- INPUT / OUTPUT ----- #
401441

0 commit comments

Comments
 (0)