Commit de10969

Update (base update)
[ghstack-poisoned]
1 parent 87fed0e commit de10969

2 files changed: +5 -12

torchrl/envs/transforms/llm.py (+1, -7)

```diff
@@ -10,13 +10,7 @@
 from typing import Any, Callable, Iterable, Literal
 
 import torch
-from tensordict import (
-    lazy_stack,
-    NestedKey,
-    TensorDict,
-    TensorDictBase,
-    unravel_key,
-)
+from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key
 from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams
 from tensordict.utils import _zip_strict, is_seq_of_nested_key
 from torch import nn
```

torchrl/modules/llm/vllm_policy.py (+4, -5)

```diff
@@ -244,11 +244,10 @@ def to_list(tokens, attention_mask):
         )
 
         if generate_kwargs is None:
-            generate_kwargs = {
-                "detokenize": False,
-                "prompt_logprobs": not generate,
-                "logprobs": return_log_probs,
-            }
+            generate_kwargs = {}
+        generate_kwargs.setdefault("detokenize", False)
+        generate_kwargs.setdefault("prompt_logprobs", not generate)
+        generate_kwargs.setdefault("logprobs", return_log_probs)
         if not generate:
             generate_kwargs["max_tokens"] = 1
         sampling_params = SamplingParams(**generate_kwargs)
```
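The change matters when a caller passes its own `generate_kwargs`: the old code applied the defaults only when the argument was `None`, whereas `setdefault` now fills in any missing entries while leaving caller-supplied values untouched. A minimal sketch of that merging behavior, with hypothetical input values that are not part of the diff:

```python
# Hypothetical inputs for illustration only.
generate = True
return_log_probs = False

# Caller overrides one option; the others should still get defaults.
generate_kwargs = {"logprobs": True}

if generate_kwargs is None:
    generate_kwargs = {}
generate_kwargs.setdefault("detokenize", False)              # filled in
generate_kwargs.setdefault("prompt_logprobs", not generate)  # filled in
generate_kwargs.setdefault("logprobs", return_log_probs)     # caller's value kept

print(generate_kwargs)
# {'logprobs': True, 'detokenize': False, 'prompt_logprobs': False}
```

Under the old dict-literal version, the same call would have skipped the `None` branch and left `detokenize` and `prompt_logprobs` unset entirely.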
