Skip to content

Commit 1fca477

Browse files
authored
Merge branch 'main' into 2.7-RC-TEST
2 parents eff088b + 02d2519 commit 1fca477

File tree

4 files changed

+5
-2
lines changed

4 files changed

+5
-2
lines changed

.jenkins/validate_tutorials_built.py

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"intermediate_source/flask_rest_api_tutorial",
5151
"intermediate_source/text_to_speech_with_torchaudio",
5252
"intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release.
53+
"advanced_source/semi_structured_sparse" # reenable after 3303 is fixed.
5354
]
5455

5556
def tutorial_source_dirs() -> List[Path]:

advanced_source/coding_ddpg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1040,7 +1040,7 @@ def ceil_div(x, y):
10401040

10411041
###############################################################################
10421042
# let's use the TD(lambda) estimator!
1043-
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda)
1043+
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda, device=device)
10441044

10451045
###############################################################################
10461046
# .. note::

advanced_source/semi_structured_sparse.py

+2
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@
210210
SparseSemiStructuredTensor._FORCE_CUTLASS = True
211211
torch.manual_seed(100)
212212

213+
# Set default device to "cuda:0"
214+
torch.set_default_device(torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
213215

214216
######################################################################
215217
# We’ll also need to define some helper functions that are specific to the

intermediate_source/reinforcement_ppo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@
551551
#
552552

553553
advantage_module = GAE(
554-
gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
554+
gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True, device=device,
555555
)
556556

557557
loss_module = ClipPPOLoss(

0 commit comments

Comments
 (0)