@@ -709,7 +709,18 @@ def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
709
709
710
710
@property
711
711
def out_info (self ):
712
- return self .eval_shape ()
712
+ out_shardings = [None if isinstance (s , UnspecifiedValue ) else s
713
+ for s in self ._params_out_shardings ]
714
+ out = []
715
+ for a , out_s in zip (self .jaxpr .out_avals , out_shardings ):
716
+ s = (a .sharding if a .sharding .mesh ._are_all_axes_explicit else out_s
717
+ if out_s is None else out_s )
718
+ # TODO(yashkatariya): Add `Layout` to SDS.
719
+ out .append (
720
+ core .ShapeDtypeStruct (
721
+ a .shape , a .dtype , sharding = s , weak_type = a .weak_type ,
722
+ vma = (a .vma if config ._check_vma .value else None )))
723
+ return tree_util .tree_unflatten (self ._out_tree , out )
713
724
714
725
def lower (self , * , lowering_platforms : tuple [str , ...] | None = None ,
715
726
_private_parameters : mlir .LoweringParameters | None = None ):
@@ -727,20 +738,6 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
727
738
raise ValueError (msg ) from None
728
739
return Lowered (lowering , self .args_info , self ._out_tree )
729
740
730
- def eval_shape (self ):
731
- out_shardings = [None if isinstance (s , UnspecifiedValue ) else s
732
- for s in self ._params_out_shardings ]
733
- out = []
734
- for a , out_s in zip (self .jaxpr .out_avals , out_shardings ):
735
- s = (a .sharding if a .sharding .mesh ._are_all_axes_explicit else out_s
736
- if out_s is None else out_s )
737
- # TODO(yashkatariya): Add `Layout` to SDS.
738
- out .append (
739
- core .ShapeDtypeStruct (
740
- a .shape , a .dtype , sharding = s , weak_type = a .weak_type ,
741
- vma = (a .vma if config ._check_vma .value else None )))
742
- return tree_util .tree_unflatten (self ._out_tree , out )
743
-
744
741
745
742
@runtime_checkable
746
743
class Wrapped (Protocol ):
0 commit comments