[Feature] Transform for partial steps #2777
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2777
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures, 2 Pending, 1 Unrelated Failure
As of commit a2e1fc4 with merge base f1c42e0:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: cd6c967ac6d793e078cac90c340942f23ffb16f4 Pull Request resolved: #2777
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.6158s | 0.5164s | 1.9363 Ops/s | 1.9514 Ops/s | |
test_transformed | 1.1012s | 0.9969s | 1.0031 Ops/s | 1.0162 Ops/s | |
test_serial | 1.6071s | 1.5075s | 0.6634 Ops/s | 0.6526 Ops/s | |
test_parallel | 1.4286s | 1.3202s | 0.7575 Ops/s | 0.7574 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.5725ms | 29.9461μs | 33.3934 KOps/s | 32.8152 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 67.5960μs | 17.6433μs | 56.6786 KOps/s | 56.1144 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 75.5290μs | 16.6351μs | 60.1140 KOps/s | 58.9135 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 40.9470μs | 9.8898μs | 101.1142 KOps/s | 99.5451 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 98.0430μs | 31.4625μs | 31.7839 KOps/s | 31.3856 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 81.1720μs | 19.3818μs | 51.5949 KOps/s | 51.3335 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 53.0680μs | 18.6143μs | 53.7221 KOps/s | 52.5860 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 88.0550μs | 11.7318μs | 85.2387 KOps/s | 83.7487 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 79.6790μs | 33.3125μs | 30.0188 KOps/s | 29.6735 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 74.1380μs | 21.3545μs | 46.8286 KOps/s | 46.2652 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 0.1336ms | 18.6497μs | 53.6203 KOps/s | 53.2631 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 82.9750μs | 11.6135μs | 86.1070 KOps/s | 83.9357 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 93.3540μs | 34.8495μs | 28.6948 KOps/s | 28.1150 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 76.1120μs | 22.9917μs | 43.4939 KOps/s | 43.0291 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 85.0590μs | 20.2459μs | 49.3927 KOps/s | 48.6272 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 46.2260μs | 13.5178μs | 73.9767 KOps/s | 73.3348 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 0.1052ms | 33.4859μs | 29.8633 KOps/s | 29.7888 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 82.1730μs | 21.3047μs | 46.9380 KOps/s | 46.3011 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 0.5913ms | 21.3641μs | 46.8074 KOps/s | 46.2419 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 0.1061ms | 13.1813μs | 75.8652 KOps/s | 75.5135 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.1126ms | 34.8328μs | 28.7086 KOps/s | 28.1491 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 74.0580μs | 22.8856μs | 43.6957 KOps/s | 43.0543 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 2.7162ms | 22.9454μs | 43.5817 KOps/s | 43.1684 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 82.0730μs | 14.7357μs | 67.8625 KOps/s | 63.7157 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 95.2870μs | 36.8564μs | 27.1323 KOps/s | 26.5415 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 52.6780μs | 24.7906μs | 40.3378 KOps/s | 39.9591 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 82.0730μs | 22.9388μs | 43.5942 KOps/s | 42.7865 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 51.0650μs | 14.8101μs | 67.5214 KOps/s | 66.1021 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 98.4340μs | 38.4361μs | 26.0172 KOps/s | 25.6735 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 88.3440μs | 26.6103μs | 37.5794 KOps/s | 37.6760 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 64.7410μs | 24.3814μs | 41.0150 KOps/s | 40.6891 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 82.7140μs | 16.6236μs | 60.1556 KOps/s | 59.8695 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 12.9930ms | 10.0149ms | 99.8513 Ops/s | 101.7254 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 28.1301ms | 26.1884ms | 38.1849 Ops/s | 38.2905 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2628ms | 0.1967ms | 5.0831 KOps/s | 4.7735 KOps/s | |
test_values[td1_return_estimate-False-False] | 25.1824ms | 24.7653ms | 40.3791 Ops/s | 41.3686 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 27.8907ms | 26.4137ms | 37.8592 Ops/s | 38.0737 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 35.6425ms | 35.1729ms | 28.4310 Ops/s | 28.3544 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 27.5617ms | 26.3588ms | 37.9379 Ops/s | 38.1728 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 8.8930ms | 8.6477ms | 115.6382 Ops/s | 117.5123 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.8074ms | 1.9365ms | 516.3862 Ops/s | 512.5881 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.6770ms | 0.3695ms | 2.7066 KOps/s | 2.7029 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 42.7741ms | 41.8496ms | 23.8951 Ops/s | 23.4100 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 5.9577ms | 3.5851ms | 278.9329 Ops/s | 285.8542 Ops/s | |
test_dqn_speed[False-None] | 1.9173ms | 1.4128ms | 707.8076 Ops/s | 696.3408 Ops/s | |
test_dqn_speed[False-backward] | 1.9991ms | 1.9187ms | 521.1864 Ops/s | 517.6218 Ops/s | |
test_dqn_speed[True-None] | 0.7878ms | 0.5030ms | 1.9880 KOps/s | 2.0063 KOps/s | |
test_dqn_speed[True-backward] | 0.9765ms | 0.9312ms | 1.0739 KOps/s | 1.0678 KOps/s | |
test_dqn_speed[reduce-overhead-None] | 0.8229ms | 0.5029ms | 1.9885 KOps/s | 2.0198 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0061ms | 0.9335ms | 1.0713 KOps/s | 965.4549 Ops/s | |
test_ddpg_speed[False-None] | 4.3246ms | 3.0130ms | 331.8925 Ops/s | 339.4278 Ops/s | |
test_ddpg_speed[False-backward] | 4.4541ms | 4.1240ms | 242.4805 Ops/s | 240.6805 Ops/s | |
test_ddpg_speed[True-None] | 1.5274ms | 1.2593ms | 794.1036 Ops/s | 790.5024 Ops/s | |
test_ddpg_speed[True-backward] | 2.2746ms | 2.1636ms | 462.1905 Ops/s | 424.8496 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.5390ms | 1.2747ms | 784.5145 Ops/s | 782.5634 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.2481ms | 2.1498ms | 465.1677 Ops/s | 397.0072 Ops/s | |
test_sac_speed[False-None] | 9.4046ms | 8.6139ms | 116.0913 Ops/s | 112.8169 Ops/s | |
test_sac_speed[False-backward] | 13.2702ms | 11.6983ms | 85.4823 Ops/s | 85.6903 Ops/s | |
test_sac_speed[True-None] | 2.9698ms | 2.3202ms | 431.0065 Ops/s | 425.5733 Ops/s | |
test_sac_speed[True-backward] | 4.6754ms | 4.2780ms | 233.7565 Ops/s | 234.3388 Ops/s | |
test_sac_speed[reduce-overhead-None] | 3.0616ms | 2.3143ms | 432.1020 Ops/s | 439.4245 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 4.4695ms | 4.0749ms | 245.4023 Ops/s | 231.6226 Ops/s | |
test_redq_speed[False-None] | 21.7105ms | 14.1247ms | 70.7981 Ops/s | 44.1069 Ops/s | |
test_redq_speed[False-backward] | 28.5546ms | 23.4750ms | 42.5985 Ops/s | 38.8928 Ops/s | |
test_redq_speed[True-None] | 6.9007ms | 5.7849ms | 172.8635 Ops/s | 163.1112 Ops/s | |
test_redq_speed[True-backward] | 14.4091ms | 13.5611ms | 73.7402 Ops/s | 73.0197 Ops/s | |
test_redq_speed[reduce-overhead-None] | 7.7795ms | 6.1898ms | 161.5555 Ops/s | 160.6513 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 14.0332ms | 13.6136ms | 73.4562 Ops/s | 72.3556 Ops/s | |
test_redq_deprec_speed[False-None] | 15.7045ms | 14.0947ms | 70.9489 Ops/s | 69.8282 Ops/s | |
test_redq_deprec_speed[False-backward] | 21.2581ms | 20.0987ms | 49.7545 Ops/s | 48.7256 Ops/s | |
test_redq_deprec_speed[True-None] | 5.5154ms | 4.4693ms | 223.7480 Ops/s | 196.5869 Ops/s | |
test_redq_deprec_speed[True-backward] | 10.1029ms | 9.5273ms | 104.9617 Ops/s | 102.5082 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 5.3553ms | 4.6292ms | 216.0211 Ops/s | 201.5878 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 10.5307ms | 9.6938ms | 103.1586 Ops/s | 98.8246 Ops/s | |
test_td3_speed[False-None] | 9.1213ms | 8.5186ms | 117.3906 Ops/s | 113.3527 Ops/s | |
test_td3_speed[False-backward] | 12.3650ms | 11.6420ms | 85.8956 Ops/s | 86.5361 Ops/s | |
test_td3_speed[True-None] | 2.3487ms | 1.9316ms | 517.7163 Ops/s | 485.3659 Ops/s | |
test_td3_speed[True-backward] | 4.1774ms | 3.7144ms | 269.2204 Ops/s | 231.4543 Ops/s | |
test_td3_speed[reduce-overhead-None] | 2.3146ms | 1.9385ms | 515.8695 Ops/s | 458.1225 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.7259ms | 3.7052ms | 269.8896 Ops/s | 229.5230 Ops/s | |
test_cql_speed[False-None] | 40.2037ms | 37.4831ms | 26.6787 Ops/s | 26.0654 Ops/s | |
test_cql_speed[False-backward] | 52.1638ms | 49.1270ms | 20.3554 Ops/s | 20.2968 Ops/s | |
test_cql_speed[True-None] | 18.2500ms | 16.8993ms | 59.1742 Ops/s | 58.4800 Ops/s | |
test_cql_speed[True-backward] | 25.2514ms | 24.2121ms | 41.3016 Ops/s | 40.9804 Ops/s | |
test_cql_speed[reduce-overhead-None] | 17.7428ms | 17.0031ms | 58.8130 Ops/s | 58.4462 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 26.1318ms | 24.0266ms | 41.6206 Ops/s | 41.0456 Ops/s | |
test_a2c_speed[False-None] | 9.2314ms | 7.8685ms | 127.0883 Ops/s | 119.4653 Ops/s | |
test_a2c_speed[False-backward] | 16.3045ms | 15.3877ms | 64.9869 Ops/s | 63.0881 Ops/s | |
test_a2c_speed[True-None] | 4.5626ms | 4.0617ms | 246.1996 Ops/s | 235.6234 Ops/s | |
test_a2c_speed[True-backward] | 11.5150ms | 11.1903ms | 89.3631 Ops/s | 86.2838 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 4.7982ms | 4.1089ms | 243.3716 Ops/s | 242.1130 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 12.2835ms | 11.0832ms | 90.2266 Ops/s | 89.8618 Ops/s | |
test_ppo_speed[False-None] | 9.1831ms | 8.0722ms | 123.8824 Ops/s | 121.5473 Ops/s | |
test_ppo_speed[False-backward] | 16.2178ms | 15.6542ms | 63.8804 Ops/s | 61.7190 Ops/s | |
test_ppo_speed[True-None] | 4.9607ms | 4.4747ms | 223.4810 Ops/s | 211.5896 Ops/s | |
test_ppo_speed[True-backward] | 11.2125ms | 10.8270ms | 92.3615 Ops/s | 91.4193 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 5.3726ms | 4.5367ms | 220.4266 Ops/s | 216.6919 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 11.5022ms | 10.7840ms | 92.7301 Ops/s | 90.8362 Ops/s | |
test_reinforce_speed[False-None] | 7.9888ms | 6.8070ms | 146.9077 Ops/s | 139.1991 Ops/s | |
test_reinforce_speed[False-backward] | 11.4182ms | 10.5640ms | 94.6612 Ops/s | 93.4348 Ops/s | |
test_reinforce_speed[True-None] | 4.5961ms | 3.3494ms | 298.5604 Ops/s | 288.6687 Ops/s | |
test_reinforce_speed[True-backward] | 10.4033ms | 9.9213ms | 100.7933 Ops/s | 102.0729 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 4.2496ms | 3.5428ms | 282.2611 Ops/s | 271.7736 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 11.1994ms | 9.8103ms | 101.9334 Ops/s | 102.4306 Ops/s | |
test_iql_speed[False-None] | 38.6850ms | 34.2055ms | 29.2351 Ops/s | 29.2510 Ops/s | |
test_iql_speed[False-backward] | 51.4375ms | 47.2706ms | 21.1548 Ops/s | 20.5290 Ops/s | |
test_iql_speed[True-None] | 12.6373ms | 11.8898ms | 84.1055 Ops/s | 82.4669 Ops/s | |
test_iql_speed[True-backward] | 24.4971ms | 23.1719ms | 43.1557 Ops/s | 41.6544 Ops/s | |
test_iql_speed[reduce-overhead-None] | 13.1386ms | 11.8062ms | 84.7009 Ops/s | 82.0501 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 25.7915ms | 23.9010ms | 41.8393 Ops/s | 41.2613 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.0182ms | 5.4538ms | 183.3592 Ops/s | 177.1974 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8188ms | 0.5404ms | 1.8503 KOps/s | 1.8281 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.8372ms | 0.5353ms | 1.8680 KOps/s | 1.9004 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.0062ms | 5.0540ms | 197.8612 Ops/s | 183.8867 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.9942ms | 0.5377ms | 1.8598 KOps/s | 1.8495 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.8075ms | 0.5116ms | 1.9545 KOps/s | 1.9197 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.4019ms | 1.6978ms | 589.0109 Ops/s | 585.8443 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.9480ms | 1.5993ms | 625.2852 Ops/s | 621.9271 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 8.1560ms | 5.2856ms | 189.1924 Ops/s | 181.5720 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.6445ms | 0.6805ms | 1.4695 KOps/s | 1.4545 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9914ms | 0.6508ms | 1.5366 KOps/s | 1.5024 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 5.3820ms | 4.9928ms | 200.2887 Ops/s | 189.6148 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 2.9609ms | 0.5372ms | 1.8614 KOps/s | 1.8205 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7866ms | 0.5176ms | 1.9320 KOps/s | 1.8148 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 5.3978ms | 4.8991ms | 204.1182 Ops/s | 189.6205 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.2251ms | 0.5330ms | 1.8762 KOps/s | 1.8604 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7370ms | 0.5058ms | 1.9769 KOps/s | 1.9500 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.8597ms | 5.0734ms | 197.1077 Ops/s | 184.1117 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.4478ms | 0.6762ms | 1.4788 KOps/s | 1.4637 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9312ms | 0.6670ms | 1.4993 KOps/s | 1.4337 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 6.3107ms | 4.5043ms | 222.0099 Ops/s | 192.3086 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 7.9988ms | 2.3873ms | 418.8869 Ops/s | 401.9326 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 6.1279ms | 1.5215ms | 657.2425 Ops/s | 641.7531 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.5156s | 14.6757ms | 68.1399 Ops/s | 209.5169 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 8.3660ms | 2.4096ms | 415.0136 Ops/s | 391.2203 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 5.0371ms | 1.4368ms | 695.9975 Ops/s | 679.7641 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 6.0241ms | 4.5288ms | 220.8074 Ops/s | 28.1240 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.7653ms | 2.5224ms | 396.4503 Ops/s | 347.6888 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 2.0631ms | 1.4950ms | 668.9152 Ops/s | 542.8235 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 14.3769ms | 12.0321ms | 83.1108 Ops/s | 78.3856 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 16.9266ms | 14.7644ms | 67.7303 Ops/s | 65.1503 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 21.5666ms | 20.9713ms | 47.6842 Ops/s | 45.6904 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 16.0822ms | 14.9043ms | 67.0947 Ops/s | 65.4968 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 21.8267ms | 20.8458ms | 47.9712 Ops/s | 46.6458 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 17.8570ms | 16.0227ms | 62.4114 Ops/s | 60.0350 Ops/s |
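The Change column is empty in the table above, so the relative difference between "Ops" and "Ops on Repo HEAD" has to be worked out by hand. A minimal sketch of that arithmetic follows; the helper names `parse_ops` and `relative_change` are hypothetical and are not part of the benchmark tooling, and the only units handled are the two that appear in the table (`Ops/s` and `KOps/s`).

```python
# Hypothetical helpers for reading the benchmark table; not part of the CI tooling.
_UNIT_SCALE = {"Ops/s": 1.0, "KOps/s": 1e3}


def parse_ops(cell: str) -> float:
    """Convert a table cell such as '1.0713 KOps/s' to operations per second."""
    value, unit = cell.strip().split()
    return float(value) * _UNIT_SCALE[unit]


def relative_change(ops: str, ops_head: str) -> float:
    """Fractional throughput change of this PR relative to repo HEAD."""
    new, base = parse_ops(ops), parse_ops(ops_head)
    return (new - base) / base


# test_dqn_speed[reduce-overhead-backward]: the two cells use different units.
dqn_change = relative_change("1.0713 KOps/s", "965.4549 Ops/s")

# test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400]
rb_change = relative_change("68.1399 Ops/s", "209.5169 Ops/s")

print(f"{dqn_change:+.1%}  {rb_change:+.1%}")
```

A positive value means the PR is faster than HEAD on that benchmark. Large swings on rows whose Max is far above the Mean (for example the `test_rb_populate` row with a 0.5156s Max) may reflect a single slow run rather than a real regression.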
ghstack-source-id: 987728d0714cfd4947786b5c0fed396ee7cc6729 Pull Request resolved: #2777
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.8786s | 0.7956s | 1.2569 Ops/s | 1.2611 Ops/s | |
test_transformed | 1.4435s | 1.3576s | 0.7366 Ops/s | 0.7320 Ops/s | |
test_serial | 2.4104s | 2.3259s | 0.4299 Ops/s | 0.4408 Ops/s | |
test_parallel | 1.9286s | 1.8560s | 0.5388 Ops/s | 0.5380 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.1717ms | 40.7470μs | 24.5417 KOps/s | 24.8763 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 56.3310μs | 23.4802μs | 42.5891 KOps/s | 42.7677 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 49.6410μs | 22.4791μs | 44.4858 KOps/s | 44.6478 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 46.0510μs | 13.2520μs | 75.4606 KOps/s | 76.5435 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 79.0610μs | 42.7430μs | 23.3956 KOps/s | 22.7483 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 58.5910μs | 26.0256μs | 38.4237 KOps/s | 38.8652 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 54.3110μs | 24.9057μs | 40.1514 KOps/s | 40.2319 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 0.2137ms | 15.5685μs | 64.2321 KOps/s | 64.7202 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 80.7210μs | 44.8109μs | 22.3160 KOps/s | 21.5862 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 59.9410μs | 28.1089μs | 35.5759 KOps/s | 35.2387 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 55.3510μs | 24.5022μs | 40.8126 KOps/s | 39.1372 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 46.6810μs | 15.4697μs | 64.6424 KOps/s | 64.6161 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 84.3720μs | 47.1066μs | 21.2285 KOps/s | 21.0384 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 52.6100μs | 30.3699μs | 32.9273 KOps/s | 33.0631 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 59.0510μs | 27.0829μs | 36.9237 KOps/s | 36.6330 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 53.5710μs | 17.7425μs | 56.3620 KOps/s | 57.1185 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 78.3910μs | 44.8781μs | 22.2826 KOps/s | 21.8380 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 66.6310μs | 28.0238μs | 35.6840 KOps/s | 35.0886 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 62.4410μs | 28.1930μs | 35.4697 KOps/s | 35.1408 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 49.6300μs | 17.0215μs | 58.7493 KOps/s | 58.0958 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 89.8010μs | 46.7961μs | 21.3693 KOps/s | 21.0127 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 60.8110μs | 30.4100μs | 32.8839 KOps/s | 32.7891 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 3.2762ms | 31.1442μs | 32.1087 KOps/s | 31.7969 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 53.6510μs | 19.3140μs | 51.7758 KOps/s | 51.0578 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 85.5720μs | 49.7788μs | 20.0889 KOps/s | 20.0069 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 62.0410μs | 32.7174μs | 30.5648 KOps/s | 30.1534 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 62.1220μs | 30.7500μs | 32.5203 KOps/s | 32.2757 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 49.3410μs | 19.4076μs | 51.5262 KOps/s | 51.1985 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 87.1920μs | 51.6885μs | 19.3467 KOps/s | 19.3771 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 71.5010μs | 35.1028μs | 28.4878 KOps/s | 28.6307 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 65.9710μs | 32.6254μs | 30.6510 KOps/s | 30.8671 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 56.0910μs | 21.5171μs | 46.4747 KOps/s | 46.3575 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 25.5935ms | 24.4152ms | 40.9581 Ops/s | 41.0969 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 0.1005s | 2.9060ms | 344.1128 Ops/s | 347.5007 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1037ms | 78.5828μs | 12.7254 KOps/s | 12.7869 KOps/s | |
test_values[td1_return_estimate-False-False] | 57.3240ms | 56.2007ms | 17.7934 Ops/s | 18.5761 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3193ms | 1.0759ms | 929.4432 Ops/s | 930.2487 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 90.4161ms | 88.4887ms | 11.3009 Ops/s | 11.6552 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.2554ms | 1.0643ms | 939.6264 Ops/s | 932.3233 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 24.3847ms | 24.1652ms | 41.3819 Ops/s | 41.7860 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0312ms | 0.7461ms | 1.3403 KOps/s | 1.3502 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.8027ms | 0.6578ms | 1.5201 KOps/s | 1.5209 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.6255ms | 1.4761ms | 677.4494 Ops/s | 680.6258 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7129ms | 0.6714ms | 1.4894 KOps/s | 1.4857 KOps/s | |
test_dqn_speed[False-None] | 7.1951ms | 1.5125ms | 661.1410 Ops/s | 658.8248 Ops/s | |
test_dqn_speed[False-backward] | 2.1797ms | 2.0963ms | 477.0310 Ops/s | 479.5241 Ops/s | |
test_dqn_speed[True-None] | 0.7363ms | 0.5782ms | 1.7296 KOps/s | 1.7266 KOps/s | |
test_dqn_speed[True-backward] | 1.2112ms | 1.1438ms | 874.3110 Ops/s | 851.9236 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7466ms | 0.5966ms | 1.6761 KOps/s | 1.6529 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0345ms | 0.9867ms | 1.0134 KOps/s | 1.0101 KOps/s | |
test_ddpg_speed[False-None] | 3.1531ms | 2.8572ms | 349.9940 Ops/s | 349.5707 Ops/s | |
test_ddpg_speed[False-backward] | 4.4572ms | 4.0096ms | 249.4037 Ops/s | 245.6680 Ops/s | |
test_ddpg_speed[True-None] | 1.5387ms | 1.3823ms | 723.4245 Ops/s | 723.8293 Ops/s | |
test_ddpg_speed[True-backward] | 2.6547ms | 2.4598ms | 406.5379 Ops/s | 382.9922 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.5422ms | 1.3936ms | 717.5792 Ops/s | 718.8454 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 1.9687ms | 1.9122ms | 522.9517 Ops/s | 505.0686 Ops/s | |
test_sac_speed[False-None] | 8.2828ms | 7.8749ms | 126.9851 Ops/s | 124.5228 Ops/s | |
test_sac_speed[False-backward] | 11.3894ms | 10.6337ms | 94.0407 Ops/s | 91.9761 Ops/s | |
test_sac_speed[True-None] | 2.0964ms | 1.9064ms | 524.5466 Ops/s | 521.3288 Ops/s | |
test_sac_speed[True-backward] | 3.8161ms | 3.6237ms | 275.9580 Ops/s | 261.4267 Ops/s | |
test_sac_speed[reduce-overhead-None] | 17.9095ms | 10.9453ms | 91.3637 Ops/s | 91.0086 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.8246ms | 1.6547ms | 604.3317 Ops/s | 540.7872 Ops/s | |
test_redq_speed[False-None] | 8.0354ms | 7.4320ms | 134.5540 Ops/s | 133.5715 Ops/s | |
test_redq_speed[False-backward] | 11.4371ms | 10.9901ms | 90.9911 Ops/s | 87.0343 Ops/s | |
test_redq_speed[True-None] | 2.6744ms | 2.3804ms | 420.0915 Ops/s | 419.6582 Ops/s | |
test_redq_speed[True-backward] | 4.7103ms | 4.2522ms | 235.1713 Ops/s | 230.9268 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.5856ms | 2.4044ms | 415.9006 Ops/s | 414.1702 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.7164ms | 4.2999ms | 232.5653 Ops/s | 242.8797 Ops/s | |
test_redq_deprec_speed[False-None] | 9.4099ms | 8.9195ms | 112.1143 Ops/s | 112.7546 Ops/s | |
test_redq_deprec_speed[False-backward] | 12.6458ms | 12.0786ms | 82.7908 Ops/s | 85.4753 Ops/s | |
test_redq_deprec_speed[True-None] | 2.8574ms | 2.6987ms | 370.5467 Ops/s | 368.3570 Ops/s | |
test_redq_deprec_speed[True-backward] | 4.9892ms | 4.5516ms | 219.7015 Ops/s | 218.1203 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 2.9908ms | 2.7184ms | 367.8641 Ops/s | 367.9568 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.9209ms | 4.5369ms | 220.4141 Ops/s | 217.4761 Ops/s | |
test_td3_speed[False-None] | 8.0084ms | 7.8666ms | 127.1193 Ops/s | 127.1620 Ops/s | |
test_td3_speed[False-backward] | 10.9055ms | 10.3379ms | 96.7313 Ops/s | 97.1737 Ops/s | |
test_td3_speed[True-None] | 1.8555ms | 1.7474ms | 572.2841 Ops/s | 577.5517 Ops/s | |
test_td3_speed[True-backward] | 3.5264ms | 3.4342ms | 291.1868 Ops/s | 296.4637 Ops/s | |
test_td3_speed[reduce-overhead-None] | 71.1071ms | 27.5369ms | 36.3150 Ops/s | 36.5464 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.6810ms | 1.5488ms | 645.6573 Ops/s | 711.7019 Ops/s | |
test_cql_speed[False-None] | 16.9867ms | 16.5565ms | 60.3994 Ops/s | 59.1719 Ops/s | |
test_cql_speed[False-backward] | 22.3396ms | 21.8101ms | 45.8503 Ops/s | 45.9636 Ops/s | |
test_cql_speed[True-None] | 3.5338ms | 3.3647ms | 297.2055 Ops/s | 295.3280 Ops/s | |
test_cql_speed[True-backward] | 6.3735ms | 5.8032ms | 172.3182 Ops/s | 171.3373 Ops/s | |
test_cql_speed[reduce-overhead-None] | 18.9534ms | 13.1209ms | 76.2143 Ops/s | 75.9782 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 2.1990ms | 2.0520ms | 487.3279 Ops/s | 484.2529 Ops/s | |
test_a2c_speed[False-None] | 3.5929ms | 3.2231ms | 310.2562 Ops/s | 315.1967 Ops/s | |
test_a2c_speed[False-backward] | 6.8604ms | 6.2129ms | 160.9545 Ops/s | 159.6809 Ops/s | |
test_a2c_speed[True-None] | 1.5400ms | 1.3833ms | 722.8960 Ops/s | 717.2394 Ops/s | |
test_a2c_speed[True-backward] | 3.1485ms | 3.0820ms | 324.4597 Ops/s | 317.8666 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 14.7986ms | 8.6459ms | 115.6620 Ops/s | 117.2348 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.7406ms | 1.6225ms | 616.3181 Ops/s | 664.7709 Ops/s | |
test_ppo_speed[False-None] | 3.9404ms | 3.6673ms | 272.6767 Ops/s | 261.7582 Ops/s | |
test_ppo_speed[False-backward] | 7.3243ms | 6.8639ms | 145.6902 Ops/s | 147.9527 Ops/s | |
test_ppo_speed[True-None] | 1.6288ms | 1.4442ms | 692.4384 Ops/s | 671.9319 Ops/s | |
test_ppo_speed[True-backward] | 3.3998ms | 3.2581ms | 306.9252 Ops/s | 316.1991 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 1.1556ms | 0.9928ms | 1.0072 KOps/s | 1.0222 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.7949ms | 1.5849ms | 630.9453 Ops/s | 679.0158 Ops/s | |
test_reinforce_speed[False-None] | 2.3946ms | 2.2364ms | 447.1403 Ops/s | 440.3889 Ops/s | |
test_reinforce_speed[False-backward] | 3.7478ms | 3.3008ms | 302.9542 Ops/s | 307.1962 Ops/s | |
test_reinforce_speed[True-None] | 1.5124ms | 1.3362ms | 748.4079 Ops/s | 737.7495 Ops/s | |
test_reinforce_speed[True-backward] | 3.3241ms | 3.1225ms | 320.2519 Ops/s | 337.5356 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 17.1499ms | 9.3837ms | 106.5682 Ops/s | 106.0231 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.7549ms | 1.6672ms | 599.8191 Ops/s | 649.2735 Ops/s | |
test_iql_speed[False-None] | 9.5119ms | 9.0545ms | 110.4424 Ops/s | 108.8079 Ops/s | |
test_iql_speed[False-backward] | 13.4574ms | 12.8744ms | 77.6732 Ops/s | 78.2191 Ops/s | |
test_iql_speed[True-None] | 2.6281ms | 2.3234ms | 430.4037 Ops/s | 425.4990 Ops/s | |
test_iql_speed[True-backward] | 5.4419ms | 5.0405ms | 198.3936 Ops/s | 199.6933 Ops/s | |
test_iql_speed[reduce-overhead-None] | 0.4871s | 12.8873ms | 77.5960 Ops/s | 95.0508 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 2.2356ms | 2.1002ms | 476.1507 Ops/s | 462.9382 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.9727ms | 6.3813ms | 156.7085 Ops/s | 155.2952 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.5695ms | 0.3017ms | 3.3150 KOps/s | 3.0444 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6491ms | 0.2828ms | 3.5360 KOps/s | 3.4241 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.4525ms | 6.0635ms | 164.9220 Ops/s | 163.8898 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.6763ms | 0.2610ms | 3.8320 KOps/s | 2.8414 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4581ms | 0.2393ms | 4.1789 KOps/s | 4.1942 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.4467ms | 1.2349ms | 809.8113 Ops/s | 818.8823 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.3673ms | 1.1724ms | 852.9598 Ops/s | 878.0794 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4717ms | 6.2593ms | 159.7629 Ops/s | 158.9487 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.1412ms | 0.4691ms | 2.1319 KOps/s | 2.3784 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7355ms | 0.4515ms | 2.2150 KOps/s | 2.4432 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.3428ms | 6.1668ms | 162.1578 Ops/s | 161.7360 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.1466ms | 0.3180ms | 3.1450 KOps/s | 2.8127 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6292ms | 0.3131ms | 3.1939 KOps/s | 3.0488 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.4029ms | 6.1092ms | 163.6880 Ops/s | 162.3777 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.7247ms | 0.2831ms | 3.5322 KOps/s | 3.5260 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5883ms | 0.2633ms | 3.7983 KOps/s | 4.1269 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.4835ms | 6.2787ms | 159.2685 Ops/s | 158.9411 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 2.0382ms | 0.4603ms | 2.1725 KOps/s | 2.1948 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6277ms | 0.4273ms | 2.3401 KOps/s | 2.2457 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.1838ms | 5.5257ms | 180.9731 Ops/s | 180.0639 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 8.5506ms | 2.0437ms | 489.3119 Ops/s | 399.5050 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 7.1175ms | 1.2645ms | 790.8422 Ops/s | 842.2027 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 0.4592s | 14.7641ms | 67.7320 Ops/s | 182.1076 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 10.5196ms | 2.1050ms | 475.0706 Ops/s | 434.6802 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 6.0451ms | 1.1957ms | 836.3601 Ops/s | 738.7354 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 7.6051ms | 5.7491ms | 173.9405 Ops/s | 31.2084 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 5.8437ms | 2.1407ms | 467.1332 Ops/s | 431.1523 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 10.4025ms | 1.4345ms | 697.0908 Ops/s | 687.7299 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 13.4609ms | 13.1150ms | 76.2487 Ops/s | 72.4219 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 18.5415ms | 16.8237ms | 59.4400 Ops/s | 58.1836 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 18.6682ms | 18.1294ms | 55.1591 Ops/s | 54.2372 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 18.4210ms | 16.8962ms | 59.1849 Ops/s | 57.1666 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 18.0790ms | 17.6335ms | 56.7103 Ops/s | 55.0078 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 0.4046s | 25.9316ms | 38.5629 Ops/s | 54.7464 Ops/s |
    tensordict_in = tensordict_in[partial_steps]
else:
    if not partial_steps.any():
        next_tensordict = self._skip_tensordict(tensordict_in)
If we skip the step, would it also make sense to skip the transform on line 976?
next_tensordict = self.transform._step(tensordict, next_tensordict)
The setup I have is something like this:
- `td['no_step']` is the condition for `ConditionalSkip`
- `data_policy` generates a single token from `td['obs_tokens']`
- `append_transform` after `ConditionalSkip` (before in `.inv` but after in code) that overrides `_inv_step` to do the following:
  - Append the new token to `td['obs_tokens']`
  - Check whether the end token is reached; if so, check if the answer is a valid move
    - If valid, `td['no_step'] = False`
    - else reset `td['obs_tokens']` and `td['no_step'] = True`
  - If max steps has been reached, also reset
  - else `td['no_step'] = True`
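The `_inv_step` logic listed above can be sketched in plain Python. This is only an illustration under stated assumptions: it uses a plain dict instead of a tensordict, `END_TOKEN`, `MAX_TOKENS`, `update_after_token` and `is_valid_move` are hypothetical names, and it assumes that "also reset" on max steps means clearing the tokens and keeping `no_step` set.

```python
END_TOKEN = "<end>"  # hypothetical end-of-answer marker
MAX_TOKENS = 8       # hypothetical cap on tokens generated per answer


def update_after_token(td, token, is_valid_move):
    """Append `token` and decide whether the env should really step."""
    td["obs_tokens"] = td["obs_tokens"] + [token]
    if token == END_TOKEN:
        # The answer is everything generated before the end token.
        if is_valid_move(td["obs_tokens"][:-1]):
            td["no_step"] = False  # valid move: let the env step
        else:
            td["obs_tokens"] = []  # invalid answer: start over
            td["no_step"] = True
    elif len(td["obs_tokens"]) >= MAX_TOKENS:
        td["obs_tokens"] = []      # budget exhausted: reset as well (assumed)
        td["no_step"] = True
    else:
        td["no_step"] = True       # answer still incomplete: keep skipping
    return td


def is_e2e4(answer):
    # Toy validity check standing in for a real move validator.
    return answer == ["e2e4"]


td = {"obs_tokens": [], "no_step": True}
td = update_after_token(td, "e2e4", is_e2e4)
td = update_after_token(td, END_TOKEN, is_e2e4)
```

After the second call `td['no_step']` is `False`, so the environment would step on the next call.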
Thing is, I have a bunch of transforms that patch `_step` that should only really be applied if the environment step actually happened (e.g. scoring the move made, appending to the SAN history, etc.).
I was able to write some `if` conditions that make sure these transforms are appropriately skipped, but I think it would be convenient if either the transforms were skipped or there was a way to wrap a transform with this sort of conditional-ness
(e.g. provide a primitive such that one can wrap an existing torchrl Transform, e.g. Tokenizer, with some Conditional primitive).
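A rough picture of what such a wrapper could look like is sketched here. It is not torchrl code: `ConditionalTransform`, `ScoreMove`, `carry_keys` and the plain-dict data are hypothetical stand-ins, and only the `_step(tensordict, next_tensordict)` call shape and the `no_step` key are taken from this thread.

```python
class ScoreMove:
    """Hypothetical transform: counts the moves that were actually played."""

    def _step(self, td, next_td):
        next_td["n_moves"] = td.get("n_moves", 0) + 1
        return next_td


class ConditionalTransform:
    """Hypothetical wrapper: runs the wrapped `_step` only after a real env step."""

    def __init__(self, transform, skip_key="no_step", carry_keys=()):
        self.transform = transform
        self.skip_key = skip_key
        self.carry_keys = tuple(carry_keys)

    def _step(self, td, next_td):
        if td.get(self.skip_key, False):
            # Step was skipped: copy previous values forward so the output
            # has the same keys as it would after a non-skipped step.
            for key in self.carry_keys:
                next_td[key] = td.get(key, 0)
            return next_td
        return self.transform._step(td, next_td)


wrapped = ConditionalTransform(ScoreMove(), carry_keys=("n_moves",))
played = wrapped._step({"no_step": False, "n_moves": 2}, {})
skipped = wrapped._step({"no_step": True, "n_moves": 2}, {})
```

Carrying the old values forward on a skipped step is one possible design choice; it keeps the set of output keys the same whether or not the wrapped transform ran.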
> Thing is, I have a bunch of transforms that patch `_step` that should only really be applied if the environment step actually happened (e.g. scoring the move made, appending to the san history etc.)
Yes that makes sense!
We just need to be extra careful that skipping the transforms won't cause the output tensordict to be of a different format.
torchrl/envs/common.py
Outdated
next_tensordict.update(self.full_reward_spec.zero())
# Copy the data from tensordict in `next`
next_tensordict.update(
    tensordict.select(*next_tensordict.keys(True, True), strict=False)
It seems that this line will not select any key corresponding to a `NonTensorData` in `next_tensordict`. I believe the reason is that this line yields nothing, and since `leaves_only=True` the key itself will not be yielded either.
So the `NonTensorData` values will just keep their example value. Is this by design?
(Maybe passing `leaves_only=False` will be sufficient for my purposes, let me try.)
ghstack-source-id: 587f91e33dfe1d59b73c4b2f2f1c21760ee79d2e Pull Request resolved: #2777
Stack from ghstack (oldest at bottom):