@@ -2457,6 +2457,7 @@ def __init__(
24572457 shape : Union [torch .Size , int ] = _DEFAULT_SHAPE ,
24582458 device : Optional [DEVICE_TYPING ] = None ,
24592459 dtype : torch .dtype | None = None ,
2460+ example_data : Any = None ,
24602461 ** kwargs ,
24612462 ):
24622463 if isinstance (shape , int ):
@@ -2467,6 +2468,7 @@ def __init__(
24672468 super ().__init__ (
24682469 shape = shape , space = None , device = device , dtype = dtype , domain = domain , ** kwargs
24692470 )
2471+ self .example_data = example_data
24702472
24712473 def cardinality (self ) -> Any :
24722474 raise RuntimeError ("Cannot enumerate a NonTensorSpec." )
@@ -2485,30 +2487,46 @@ def to(self, dest: Union[torch.dtype, DEVICE_TYPING]) -> NonTensor:
24852487 dest_device = torch .device (dest )
24862488 if dest_device == self .device and dest_dtype == self .dtype :
24872489 return self
2488- return self .__class__ (shape = self .shape , device = dest_device , dtype = None )
2490+ return self .__class__ (
2491+ shape = self .shape ,
2492+ device = dest_device ,
2493+ dtype = None ,
2494+ example_data = self .example_data ,
2495+ )
24892496
24902497 def clone (self ) -> NonTensor :
2491- return self .__class__ (shape = self .shape , device = self .device , dtype = self .dtype )
2498+ return self .__class__ (
2499+ shape = self .shape ,
2500+ device = self .device ,
2501+ dtype = self .dtype ,
2502+ example_data = self .example_data ,
2503+ )
24922504
24932505 def rand (self , shape = None ):
24942506 if shape is None :
24952507 shape = ()
24962508 return NonTensorData (
2497- data = None , batch_size = (* shape , * self ._safe_shape ), device = self .device
2509+ data = self .example_data ,
2510+ batch_size = (* shape , * self ._safe_shape ),
2511+ device = self .device ,
24982512 )
24992513
25002514 def zero (self , shape = None ):
25012515 if shape is None :
25022516 shape = ()
25032517 return NonTensorData (
2504- data = None , batch_size = (* shape , * self ._safe_shape ), device = self .device
2518+ data = self .example_data ,
2519+ batch_size = (* shape , * self ._safe_shape ),
2520+ device = self .device ,
25052521 )
25062522
25072523 def one (self , shape = None ):
25082524 if shape is None :
25092525 shape = ()
25102526 return NonTensorData (
2511- data = None , batch_size = (* shape , * self ._safe_shape ), device = self .device
2527+ data = self .example_data ,
2528+ batch_size = (* shape , * self ._safe_shape ),
2529+ device = self .device ,
25122530 )
25132531
25142532 def is_in (self , val : Any ) -> bool :
@@ -2533,23 +2551,36 @@ def expand(self, *shape):
25332551 raise ValueError (
25342552 f"The last elements of the expanded shape must match the current one. Got shape={ shape } while self.shape={ self .shape } ."
25352553 )
2536- return self .__class__ (shape = shape , device = self .device , dtype = None )
2554+ return self .__class__ (
2555+ shape = shape , device = self .device , dtype = None , example_data = self .example_data
2556+ )
25372557
25382558 def _reshape (self , shape ):
2539- return self .__class__ (shape = shape , device = self .device , dtype = self .dtype )
2559+ return self .__class__ (
2560+ shape = shape ,
2561+ device = self .device ,
2562+ dtype = self .dtype ,
2563+ example_data = self .example_data ,
2564+ )
25402565
25412566 def _unflatten (self , dim , sizes ):
25422567 shape = torch .zeros (self .shape , device = "meta" ).unflatten (dim , sizes ).shape
25432568 return self .__class__ (
25442569 shape = shape ,
25452570 device = self .device ,
25462571 dtype = self .dtype ,
2572+ example_data = self .example_data ,
25472573 )
25482574
25492575 def __getitem__ (self , idx : SHAPE_INDEX_TYPING ):
25502576 """Indexes the current TensorSpec based on the provided index."""
25512577 indexed_shape = _size (_shape_indexing (self .shape , idx ))
2552- return self .__class__ (shape = indexed_shape , device = self .device , dtype = self .dtype )
2578+ return self .__class__ (
2579+ shape = indexed_shape ,
2580+ device = self .device ,
2581+ dtype = self .dtype ,
2582+ example_data = self .example_data ,
2583+ )
25532584
25542585 def unbind (self , dim : int = 0 ):
25552586 orig_dim = dim
@@ -2565,6 +2596,7 @@ def unbind(self, dim: int = 0):
25652596 shape = shape ,
25662597 device = self .device ,
25672598 dtype = self .dtype ,
2599+ example_data = self .example_data ,
25682600 )
25692601 for i in range (self .shape [dim ])
25702602 )
0 commit comments