@@ -177,7 +177,7 @@ def from_tensordict_pair(
177
177
collate_fn : Callable [[Any ], Any ] | None = None ,
178
178
write_fn : Callable [[Any , Any ], Any ] | None = None ,
179
179
consolidated : bool | None = None ,
180
- ):
180
+ ) -> TensorDictMap :
181
181
"""Creates a new TensorDictStorage from a pair of tensordicts (source and dest) using pre-defined rules of thumb.
182
182
183
183
Args:
@@ -308,7 +308,23 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
308
308
if not self ._has_lazy_out_keys ():
309
309
# TODO: make this work with pytrees and avoid calling select if keys match
310
310
value = value .select (* self .out_keys , strict = False )
311
+ item , value = self ._maybe_add_batch (item , value )
312
+ index = self ._to_index (item , extend = True )
313
+ if index .unique ().numel () < index .numel ():
314
+ # If multiple values point to the same place in the storage, we cannot process them by batch
315
+ # There could be a better way to deal with this, using unique ids.
316
+ vals = []
317
+ for it , val in zip (item .split (1 ), value .split (1 )):
318
+ self [it ] = val
319
+ vals .append (val )
320
+ # __setitem__ may affect the content of the input data
321
+ value .update (TensorDictBase .lazy_stack (vals ))
322
+ return
311
323
if self .write_fn is not None :
324
+ # We use this block in the following context: the value written in the storage is already present,
325
+ # but it needs to be updated.
326
+ # We first check if the value is already there using `contains`. If so, we pass the new value and the
327
+ # previous one to write_fn. The values that are not present are passed alone.
312
328
if len (self ):
313
329
modifiable = self .contains (item )
314
330
if modifiable .any ():
@@ -322,8 +338,6 @@ def __setitem__(self, item: TensorDictBase, value: TensorDictBase):
322
338
value = self .write_fn (value )
323
339
else :
324
340
value = self .write_fn (value )
325
- item , value = self ._maybe_add_batch (item , value )
326
- index = self ._to_index (item , extend = True )
327
341
self .storage .set (index , value )
328
342
329
343
def __len__ (self ):
0 commit comments