Skip to content

Commit dfa6584

Browse files
pschuhGoogle-ML-Automation
authored andcommitted
Support _get_physical_tpu_mesh for TPU7X.
PiperOrigin-RevId: 786883468
1 parent 9c76aa5 commit dfa6584

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

jax/_src/mesh_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def _create_device_mesh_for_nd_torus(
261261
list(enumerate(mesh_shape))
262262
):
263263
# Preferentially map to more physical axes first for higher bandwidth.
264-
for num_axes in range(3, 0, -1):
264+
for num_axes in range(len(physical_mesh.shape), 0, -1):
265265
# Try assign to any subset of size num_axes. Generate all candidates.
266266
indices_and_axes = itertools.combinations(
267267
enumerate(assignable_physical_mesh), num_axes
@@ -660,6 +660,16 @@ def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray:
660660
coords[1] - min_coords[1],
661661
d.core_on_chip - min_cores_per_chip,
662662
] = d
663+
elif device_kind in (_TPU_7X,):
664+
out = np.empty(dims + (cores_per_chip,), dtype=object)
665+
for d in jax_devices:
666+
coords = d.coords
667+
out[
668+
coords[0] - min_coords[0],
669+
coords[1] - min_coords[1],
670+
coords[2] - min_coords[2],
671+
d.core_on_chip - min_cores_per_chip,
672+
] = d
663673
else:
664674
out = np.empty(dims, dtype=object)
665675
for d in jax_devices:

0 commit comments

Comments
 (0)