Commit ad8e654

[Feature] Tokenizer transform
ghstack-source-id: ea80c1e
Pull Request resolved: #2701
1 parent 9e51c8b commit ad8e654

File tree

3 files changed: +77 -5 lines changed

torchrl/envs/__init__.py (+1)

@@ -94,6 +94,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,

torchrl/envs/transforms/__init__.py (+1)

@@ -55,6 +55,7 @@
     TargetReturn,
     TensorDictPrimer,
     TimeMaxPool,
+    Tokenizer,
     ToTensorImage,
     TrajCounter,
     Transform,

torchrl/envs/transforms/transforms.py (+75 -5)

@@ -4426,8 +4426,8 @@ class UnaryTransform(Transform):
     Args:
         in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
         out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
-        in_keys_inv (sequence of NestedKey): the keys of inputs to the unary operation during inverse call.
-        out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation during inverse call.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the unary operation during inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation during inverse call.

     Keyword Args:
         fn (Callable): the function to use as the unary operation. If it accepts
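The hunk above only touches the UnaryTransform docstring; for reference, a minimal usage sketch follows (not from the commit). The key names, the direct call on a TensorDict, and importing UnaryTransform from torchrl.envs.transforms are assumptions; the inverse keys are omitted, which is exactly what marking them optional documents.

import torch
from tensordict import TensorDict
from torchrl.envs.transforms import UnaryTransform

# fn receives the value stored at each in_key; its output is written at the matching out_key.
t = UnaryTransform(
    in_keys=["observation"],
    out_keys=["observation_log1p"],
    fn=lambda x: torch.log1p(x.abs()),
)
td = TensorDict({"observation": torch.randn(3)}, batch_size=[])
td = t(td)  # adds "observation_log1p" alongside "observation"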
@@ -4569,7 +4569,6 @@ def transform_input_spec(self, input_spec: Composite) -> Composite:
             input_spec["full_state_spec"],
             test_input_spec,
         )
-        print(input_spec)
         return input_spec

     def transform_output_spec(self, output_spec: Composite) -> Composite:
@@ -4649,8 +4648,8 @@ class Hash(UnaryTransform):
     Args:
         in_keys (sequence of NestedKey): the keys of the values to hash.
        out_keys (sequence of NestedKey): the keys of the resulting hashes.
-        in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
-        out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
+        in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call.

     Keyword Args:
         hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
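For orientation, a hedged sketch of the Hash transform documented above, passing the reproducible_hash classmethod (visible in the next hunk) explicitly. The key names, the direct call, and the import path are illustrative assumptions, and the now-optional inverse keys are left out.

from tensordict import NonTensorData, TensorDict
from torchrl.envs.transforms import Hash

# reproducible_hash maps a string to a torch.uint8 tensor (see the frombuffer line below).
h = Hash(in_keys=["text"], out_keys=["text_hash"], hash_fn=Hash.reproducible_hash)
td = TensorDict({"text": NonTensorData("hello")}, batch_size=[])
td = h(td)
print(td["text_hash"])  # uint8 hash tensor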
@@ -4801,6 +4800,77 @@ def reproducible_hash(cls, string, seed=None):
         return torch.frombuffer(hash_bytes, dtype=torch.uint8)


+class Tokenizer(UnaryTransform):
+    r"""Applies a tokenization operation on the specified inputs.
+
+    Args:
+        in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
+        out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
+        in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call.
+        out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call.
+
+    Keyword Args:
+        tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
+            "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
+            pre-trained tokenizer.
+        use_raw_nontensor (bool, optional): if ``False``, data is extracted from
+            :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
+            function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
+            inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
+        additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
+    """
+
+    def __init__(
+        self,
+        in_keys: Sequence[NestedKey],
+        out_keys: Sequence[NestedKey],
+        in_keys_inv: Sequence[NestedKey] | None = None,
+        out_keys_inv: Sequence[NestedKey] | None = None,
+        *,
+        tokenizer: "transformers.PretrainedTokenizerBase" = None,  # noqa: F821
+        use_raw_nontensor: bool = False,
+        additional_tokens: List[str] | None = None,
+    ):
+        if tokenizer is None:
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+        elif isinstance(tokenizer, str):
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+
+        self.tokenizer = tokenizer
+        if additional_tokens:
+            self.tokenizer.add_tokens(additional_tokens)
+        super().__init__(
+            in_keys=in_keys,
+            out_keys=out_keys,
+            in_keys_inv=in_keys_inv,
+            out_keys_inv=out_keys_inv,
+            fn=self.call_tokenizer_fn,
+            use_raw_nontensor=use_raw_nontensor,
+        )
+
+    @property
+    def device(self):
+        if "_device" in self.__dict__:
+            return self._device
+        parent = self.parent
+        if parent is None:
+            return None
+        device = parent.device
+        self._device = device
+        return device
+
+    def call_tokenizer_fn(self, value: str | List[str]):
+        device = self.device
+        out = self.tokenizer.encode(value, return_tensors="pt")
+        if device is not None and out.device != device:
+            out = out.to(device)
+        return out
+
+
 class Stack(Transform):
     """Stacks tensors and tensordicts.
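Finally, a hedged usage sketch of the new Tokenizer transform (not part of the commit). The key names, the environment-free call, and the sample string are illustrative; the default "bert-base-uncased" tokenizer requires the transformers package to be installed.

from tensordict import NonTensorData, TensorDict
from torchrl.envs.transforms import Tokenizer

# With use_raw_nontensor=False (the default), the string is unwrapped from NonTensorData
# before call_tokenizer_fn runs tokenizer.encode(..., return_tensors="pt") on it.
tok = Tokenizer(in_keys=["text"], out_keys=["tokens"])
td = TensorDict({"text": NonTensorData("the cat sat on the mat")}, batch_size=[])
td = tok(td)
print(td["tokens"])  # token-id tensor of shape [1, seq_len]

When the transform is appended to an environment instead (e.g. TransformedEnv(base_env, Tokenizer(in_keys=["text"], out_keys=["tokens"]))), the device property defined above resolves and caches the parent environment's device, so the token tensors are moved there.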

0 commit comments
