@@ -261,7 +261,7 @@ def _create_device_mesh_for_nd_torus(
261
261
list (enumerate (mesh_shape ))
262
262
):
263
263
# Preferentially map to more physical axes first for higher bandwidth.
264
- for num_axes in range (3 , 0 , - 1 ):
264
+ for num_axes in range (len ( physical_mesh . shape ) , 0 , - 1 ):
265
265
# Try assign to any subset of size num_axes. Generate all candidates.
266
266
indices_and_axes = itertools .combinations (
267
267
enumerate (assignable_physical_mesh ), num_axes
@@ -660,6 +660,16 @@ def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray:
660
660
coords [1 ] - min_coords [1 ],
661
661
d .core_on_chip - min_cores_per_chip ,
662
662
] = d
663
+ elif device_kind in (_TPU_7X ,):
664
+ out = np .empty (dims + (cores_per_chip ,), dtype = object )
665
+ for d in jax_devices :
666
+ coords = d .coords
667
+ out [
668
+ coords [0 ] - min_coords [0 ],
669
+ coords [1 ] - min_coords [1 ],
670
+ coords [2 ] - min_coords [2 ],
671
+ d .core_on_chip - min_cores_per_chip ,
672
+ ] = d
663
673
else :
664
674
out = np .empty (dims , dtype = object )
665
675
for d in jax_devices :
0 commit comments