@@ -280,6 +280,8 @@ class PrimIDs(Enum):
280
280
SINK = auto ()
281
281
# Tensor Subclasses methods
282
282
TENSOR_SUBCLASS_CTOR = auto ()
283
+ FLATTEN_TENSOR_SUBCLASS = auto ()
284
+ UNFLATTEN_TENSOR_SUBCLASS = auto ()
283
285
284
286
285
287
class OpTags (Enum ):
@@ -4098,7 +4100,7 @@ def check_types(coll):
4098
4100
return tuple (types_set )
4099
4101
4100
4102
4101
- def filter_types (types : tuple [Any , ...]) -> tuple [Any , ...]:
4103
+ def filter_types_for_tensor_wrapper_subclass (types : tuple [Any , ...]) -> tuple [Any , ...]:
4102
4104
return tuple (
4103
4105
filter (
4104
4106
lambda t : (
@@ -4170,7 +4172,7 @@ def printer_of_tensor_subclass_ctor(
4170
4172
filtered_types = (cls ,)
4171
4173
if non_tensors :
4172
4174
types = get_nested_types ([t .obj if isinstance (t , codeutils .ContextObject ) else t for t in non_tensors ])
4173
- filtered_types += filter_types (types )
4175
+ filtered_types += filter_types_for_tensor_wrapper_subclass (types )
4174
4176
new_imports = {t .__name__ : t for t in filtered_types }
4175
4177
bsym ._import_ctx .update (new_imports )
4176
4178
return s
@@ -4183,7 +4185,7 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
4183
4185
filtered_types : tuple [Any , ...] = (cls ,)
4184
4186
if non_tensors :
4185
4187
types = get_nested_types (non_tensors )
4186
- filtered_types += filter_types (types )
4188
+ filtered_types += filter_types_for_tensor_wrapper_subclass (types )
4187
4189
new_imports = {t .__name__ : t for t in filtered_types }
4188
4190
bsym ._import_ctx .update (new_imports )
4189
4191
@@ -4195,3 +4197,163 @@ def bind_postprocess_of_tensor_subclass_ctor(bsym: BoundSymbol) -> None:
4195
4197
python_printer = printer_of_tensor_subclass_ctor ,
4196
4198
_bind_postprocess = bind_postprocess_of_tensor_subclass_ctor ,
4197
4199
)
4200
+
4201
+
4202
+ def printer_of_tensor_subclass_flatten (
4203
+ bsym : BoundSymbol ,
4204
+ out_printables : Any ,
4205
+ arg_printables : Sequence [Printable ],
4206
+ kwarg_printables : dict [str , Printable ],
4207
+ ) -> str | Iterable [str ]:
4208
+ from itertools import chain
4209
+
4210
+ arg_str = (
4211
+ ""
4212
+ if (arg_printables is None or len (arg_printables ) == 0 )
4213
+ else ", " .join (codeutils .prettyprint (x ) for x in arg_printables )
4214
+ )
4215
+
4216
+ result_str : str
4217
+ if bsym .output is None or (baseutils .is_collection (bsym .output ) and len (bsym .output ) == 0 ):
4218
+ result_str = ""
4219
+ else :
4220
+ result_str = f"{ codeutils .prettyprint (out_printables , literals_as_underscores = True )} = "
4221
+
4222
+ # Creates a comment describing the output
4223
+ comment_str = ""
4224
+ if isinstance (bsym .output , Proxy ):
4225
+ comment_str = f" # { codeutils .prettyprint (out_printables , with_type = True )} "
4226
+
4227
+ s = f"{ result_str } { arg_str } .__tensor_flatten__(){ comment_str } "
4228
+
4229
+ if bsym .header :
4230
+ header_lines = (
4231
+ bsym .header
4232
+ if isinstance (bsym .header , Sequence ) and not isinstance (bsym .header , str )
4233
+ else bsym .header .splitlines ()
4234
+ )
4235
+ header_lines = (f"# { line } " for line in header_lines )
4236
+ return chain (header_lines , [s ])
4237
+
4238
+ return s
4239
+
4240
+
4241
+ # NOTE(crcrpar): The behavior is different from PyTorch `subclass_tensor.__tensor_flatten__()`
4242
+ # that returns a list of tensor attr names and a dict of const metadata. In Thunder traces,
4243
+ # const values could be obviated and actual tensor proxies would be more useful
4244
+ # than tensor attr names.
4245
+ def flatten_tensor_subclass_meta (t : SubclassTensorProxy ) -> tuple [TensorProxy , ...]:
4246
+ tensor_attr_names , metadata = t .__tensor_flatten__ ()
4247
+ tensors = tuple (getattr (t , name ) for name in tensor_attr_names )
4248
+ return tensors
4249
+
4250
+
4251
+ flatten_tensor_subclass = make_prim (
4252
+ PrimIDs .FLATTEN_TENSOR_SUBCLASS ,
4253
+ "flatten_tensor_subclass" ,
4254
+ meta = flatten_tensor_subclass_meta ,
4255
+ python_printer = printer_of_tensor_subclass_flatten ,
4256
+ )
4257
+
4258
+
4259
+ def printer_of_unflatten_tensor_subclass (
4260
+ bsym : BoundSymbol ,
4261
+ out_printables : Any ,
4262
+ arg_printables : Sequence [Printable ],
4263
+ kwarg_printables : dict [str , Printable ],
4264
+ ) -> str | Iterable [str ]:
4265
+ from itertools import chain
4266
+
4267
+ wrapped_cls : ContextObject | torch ._C ._TensorMeta = arg_printables [0 ]
4268
+ if isinstance (wrapped_cls , torch ._C ._TensorMeta ):
4269
+ cls = wrapped_cls
4270
+ else :
4271
+ cls : torch ._C ._TensorMeta = wrapped_cls .obj
4272
+
4273
+ arg_str = (
4274
+ ""
4275
+ if (arg_printables is None or len (arg_printables ) == 0 )
4276
+ else ", " .join (codeutils .prettyprint (x ) for x in arg_printables [1 :])
4277
+ )
4278
+ kwarg_str : str
4279
+
4280
+ if len (kwarg_printables ) == 0 :
4281
+ kwarg_str = ""
4282
+ else :
4283
+ kwarg_str = ", " .join (f"{ k } ={ codeutils .prettyprint (v )} " for k , v in kwarg_printables .items ())
4284
+
4285
+ result_str : str
4286
+ if bsym .output is None or (baseutils .is_collection (bsym .output ) and len (bsym .output ) == 0 ):
4287
+ result_str = ""
4288
+ else :
4289
+ result_str = f"{ codeutils .prettyprint (out_printables , literals_as_underscores = True )} = "
4290
+
4291
+ # Creates a comment describing the output
4292
+ comment_str = ""
4293
+ if isinstance (bsym .output , Proxy ):
4294
+ comment_str = f" # { codeutils .prettyprint (out_printables , with_type = True )} "
4295
+
4296
+ s = f"{ result_str } { cls .__name__ } .__tensor_unflatten__({ arg_str } { ', ' if (len (arg_str ) > 0 and len (kwarg_str ) > 0 ) else '' } { kwarg_str } ){ comment_str } "
4297
+
4298
+ if bsym .header :
4299
+ header_lines = (
4300
+ bsym .header
4301
+ if isinstance (bsym .header , Sequence ) and not isinstance (bsym .header , str )
4302
+ else bsym .header .splitlines ()
4303
+ )
4304
+ header_lines = (f"# { line } " for line in header_lines )
4305
+ return chain (header_lines , [s ])
4306
+
4307
+ return s
4308
+
4309
+
4310
+ def bind_postprocess_of_unflatten_tensor_subclass (bsym : BoundSymbol ) -> None :
4311
+ cls = bsym .args [0 ]
4312
+ inner_tensors = bsym .args [1 ]
4313
+ metadata = bsym .args [2 ]
4314
+
4315
+ filtered_types : tuple [Any , ...] = (cls ,)
4316
+ if metadata :
4317
+ types = get_nested_types (list (metadata .values ()))
4318
+ filtered_types += filter_types_for_tensor_wrapper_subclass (types )
4319
+ new_imports = {t .__name__ : t for t in filtered_types }
4320
+ bsym ._import_ctx .update (new_imports )
4321
+
4322
+
4323
+ def unflatten_tensor_subclass_meta (
4324
+ tensor_subclass_type ,
4325
+ inner_tensors : dict [str , TensorProxy ],
4326
+ metadata : dict [str , Any ],
4327
+ ) -> SubclassTensorProxy :
4328
+ first_tensor : TensorProxy = list (inner_tensors .values ())[0 ]
4329
+ a = SubclassTensorProxy (
4330
+ shape = first_tensor .shape ,
4331
+ device = first_tensor .device ,
4332
+ dtype = first_tensor .dtype ,
4333
+ requires_grad = first_tensor .requires_grad ,
4334
+ tensors = list (inner_tensors .values ()),
4335
+ non_tensors = list (metadata .values ()),
4336
+ subclass_type = tensor_subclass_type ,
4337
+ )
4338
+ for name , value in inner_tensors .items ():
4339
+ setattr (a , name , value )
4340
+ for name , value in metadata .items ():
4341
+ setattr (a , name , value )
4342
+ return a
4343
+
4344
+
4345
+ def unflatten_tensor_subclass_python_impl (
4346
+ tensor_subclass_type ,
4347
+ inner_tensors : dict [str , TensorProxy ],
4348
+ metadata : dict [str , Any ],
4349
+ ) -> torch .Tensor :
4350
+ return tensor_subclass_type .__tensor_unflatten__ (inner_tensors , metadata , - 1 , - 1 )
4351
+
4352
+
4353
+ unflatten_tensor_subclass = make_prim (
4354
+ PrimIDs .UNFLATTEN_TENSOR_SUBCLASS ,
4355
+ "unflatten_tensor_subclass" ,
4356
+ meta = unflatten_tensor_subclass_meta ,
4357
+ python_printer = printer_of_unflatten_tensor_subclass ,
4358
+ _bind_postprocess = bind_postprocess_of_unflatten_tensor_subclass ,
4359
+ )
0 commit comments