Skip to content

Commit 404981f

Browse files
yashk2810froystig
authored andcommitted
Remove eval_shape as a method on Traced and rather use .out_info because .trace already evals.
Forward `jit(f).eval_shape(*args)` to `jit(f).trace(*args).out_info` Co-authored-by: Roy Frostig <[email protected]> PiperOrigin-RevId: 786896497
1 parent dfa6584 commit 404981f

File tree

3 files changed

+15
-18
lines changed

3 files changed

+15
-18
lines changed

jax/_src/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2842,10 +2842,10 @@ def eval_shape(fun, *args, **kwargs):
28422842
float32
28432843
"""
28442844
if type(fun) is xc._xla.PjitFunction:
2845-
return fun.trace(*args, **kwargs).eval_shape() # type: ignore
2845+
return fun.trace(*args, **kwargs).out_info # type: ignore
28462846
try: hash(fun)
28472847
except TypeError: fun = partial(fun)
2848-
return jit(fun).trace(*args, **kwargs).eval_shape()
2848+
return jit(fun).trace(*args, **kwargs).out_info
28492849

28502850

28512851
def named_call(

jax/_src/pjit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def jit_lower(jit_func, *args, **kwargs):
313313

314314
@api_boundary
315315
def jit_eval_shape(jit_func, *args, **kwargs):
316-
return jit_trace(jit_func, *args, **kwargs).eval_shape()
316+
return jit_trace(jit_func, *args, **kwargs).out_info
317317

318318
def jit_evict_fn(self):
319319
self._clear_cache()

jax/_src/stages.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,18 @@ def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
709709

710710
@property
711711
def out_info(self):
712-
return self.eval_shape()
712+
out_shardings = [None if isinstance(s, UnspecifiedValue) else s
713+
for s in self._params_out_shardings]
714+
out = []
715+
for a, out_s in zip(self.jaxpr.out_avals, out_shardings):
716+
s = (a.sharding if a.sharding.mesh._are_all_axes_explicit else out_s
717+
if out_s is None else out_s)
718+
# TODO(yashkatariya): Add `Layout` to SDS.
719+
out.append(
720+
core.ShapeDtypeStruct(
721+
a.shape, a.dtype, sharding=s, weak_type=a.weak_type,
722+
vma=(a.vma if config._check_vma.value else None)))
723+
return tree_util.tree_unflatten(self._out_tree, out)
713724

714725
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
715726
_private_parameters: mlir.LoweringParameters | None = None):
@@ -727,20 +738,6 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
727738
raise ValueError(msg) from None
728739
return Lowered(lowering, self.args_info, self._out_tree)
729740

730-
def eval_shape(self):
731-
out_shardings = [None if isinstance(s, UnspecifiedValue) else s
732-
for s in self._params_out_shardings]
733-
out = []
734-
for a, out_s in zip(self.jaxpr.out_avals, out_shardings):
735-
s = (a.sharding if a.sharding.mesh._are_all_axes_explicit else out_s
736-
if out_s is None else out_s)
737-
# TODO(yashkatariya): Add `Layout` to SDS.
738-
out.append(
739-
core.ShapeDtypeStruct(
740-
a.shape, a.dtype, sharding=s, weak_type=a.weak_type,
741-
vma=(a.vma if config._check_vma.value else None)))
742-
return tree_util.tree_unflatten(self._out_tree, out)
743-
744741

745742
@runtime_checkable
746743
class Wrapped(Protocol):

0 commit comments

Comments
 (0)