@@ -4426,8 +4426,8 @@ class UnaryTransform(Transform):
4426
4426
Args:
4427
4427
in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
4428
4428
out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
4429
- in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call.
4430
- out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call.
4429
+ in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during inverse call.
4430
+ out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation during inverse call.
4431
4431
4432
4432
Keyword Args:
4433
4433
fn (Callable): the function to use as the unary operation. If it accepts
@@ -4569,7 +4569,6 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
4569
4569
input_spec ["full_state_spec" ],
4570
4570
test_input_spec ,
4571
4571
)
4572
- print (input_spec )
4573
4572
return input_spec
4574
4573
4575
4574
def transform_output_spec (self , output_spec : Composite ) -> Composite :
@@ -4649,8 +4648,8 @@ class Hash(UnaryTransform):
4649
4648
Args:
4650
4649
in_keys (sequence of NestedKey): the keys of the values to hash.
4651
4650
out_keys (sequence of NestedKey): the keys of the resulting hashes.
4652
- in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
4653
- out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
4651
+ in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call.
4652
+ out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call.
4654
4653
4655
4654
Keyword Args:
4656
4655
hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
@@ -4801,6 +4800,77 @@ def reproducible_hash(cls, string, seed=None):
4801
4800
return torch .frombuffer (hash_bytes , dtype = torch .uint8 )
4802
4801
4803
4802
4803
class Tokenizer(UnaryTransform):
    r"""Applies a tokenization operation on the specified inputs.

    Args:
        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization
            operation during inverse call.
        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the
            tokenization operation during inverse call.

    Keyword Args:
        tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
            "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
            pre-trained tokenizer.
        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
    """

    def __init__(
        self,
        in_keys: Sequence[NestedKey],
        out_keys: Sequence[NestedKey],
        in_keys_inv: Sequence[NestedKey] | None = None,
        out_keys_inv: Sequence[NestedKey] | None = None,
        *,
        tokenizer: "transformers.PretrainedTokenizerBase" = None,  # noqa: F821
        use_raw_nontensor: bool = False,
        additional_tokens: List[str] | None = None,
    ):
        # Resolve the tokenizer lazily so `transformers` is only required
        # when this transform is actually instantiated.
        if tokenizer is None:
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        elif isinstance(tokenizer, str):
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        self.tokenizer = tokenizer
        if additional_tokens:
            # Extend the vocabulary before any encoding happens.
            self.tokenizer.add_tokens(additional_tokens)
        super().__init__(
            in_keys=in_keys,
            out_keys=out_keys,
            in_keys_inv=in_keys_inv,
            out_keys_inv=out_keys_inv,
            fn=self.call_tokenizer_fn,
            use_raw_nontensor=use_raw_nontensor,
        )

    @property
    def device(self):
        """Device of the parent container, cached after first resolution.

        Returns ``None`` when the transform is not attached to a parent.
        """
        if "_device" in self.__dict__:
            return self._device
        parent = self.parent
        if parent is None:
            return None
        device = parent.device
        # Cache so subsequent calls skip the parent lookup.
        self._device = device
        return device

    def call_tokenizer_fn(self, value: str | List[str]):
        """Encode *value* with the tokenizer and return token ids as a tensor on this transform's device."""
        device = self.device
        out = self.tokenizer.encode(value, return_tensors="pt")
        if device is not None and out.device != device:
            out = out.to(device)
        return out
4804
4874
class Stack (Transform ):
4805
4875
"""Stacks tensors and tensordicts.
4806
4876
0 commit comments