#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
-
-# TODO: lazy imports
+from __future__ import annotations

import torch

-import transformers
from tensordict import NestedKey, TensorDictBase
from tensordict.nn import (
    TensorDictModule as Mod,
)
from tensordict.tensorclass import NonTensorData, NonTensorStack
from torchrl.data.llm import LLMData
-from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel


def _maybe_clear_device(td):

@@ -107,11 +104,12 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:


def from_hf_transformers(
-    model: transformers.modeling_utils.PreTrainedModel,
+    model: transformers.modeling_utils.PreTrainedModel,  # noqa
    *,
    generate: bool = True,
    return_log_probs: bool = True,
-    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer | None = None,
+    tokenizer: transformers.tokenization_utils.PreTrainedTokenizer
+    | None = None,  # noqa
    from_text: bool = False,
    device: torch.device | None = None,
    kwargs: dict | None = None,

@@ -404,6 +402,9 @@ def remove_input_seq(tokens_in, tokens_out):


if __name__ == "__main__":
+    import transformers
+    from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
+
    max_seq_length = 50000

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
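
Note on the pattern this diff applies: `from __future__ import annotations` makes the type hints on `from_hf_transformers` behave as strings that are never evaluated at runtime, so the module no longer needs a top-level `import transformers`; the `# noqa` markers silence the linter's undefined-name warnings, and the concrete imports move into the `__main__` block where they are actually used. A minimal, self-contained sketch of the same lazy-import idiom (the `describe` helper is illustrative only, not part of the torchrl file):

# Lazy-import sketch, assuming `transformers` is installed when the script
# is actually run. Postponed annotation evaluation (PEP 563) keeps the hint
# below as a string, so importing this module never pulls in transformers.
from __future__ import annotations


def describe(model: transformers.modeling_utils.PreTrainedModel) -> str:  # noqa
    # Hypothetical helper: the annotation is never evaluated, so no
    # module-level `import transformers` is required.
    return type(model).__name__


if __name__ == "__main__":
    # The heavyweight import only happens when the file is run as a script,
    # mirroring the diff above.
    from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

    model = GPT2LMHeadModel(GPT2Config())
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print(describe(model), len(tokenizer))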