-
Notifications
You must be signed in to change notification settings - Fork 365
[Feature] Set padded token log-prob to 0.0 #2856
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2856
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_simple | 0.6396s | 0.5396s | 1.8531 Ops/s | 1.8217 Ops/s | |
test_transformed | 1.1475s | 1.0539s | 0.9488 Ops/s | 0.9469 Ops/s | |
test_serial | 1.6144s | 1.5334s | 0.6521 Ops/s | 0.6313 Ops/s | |
test_parallel | 1.3998s | 1.3050s | 0.7663 Ops/s | 0.7308 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.6384ms | 30.8999μs | 32.3626 KOps/s | 32.8228 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 47.6090μs | 17.6924μs | 56.5214 KOps/s | 56.8585 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 71.0130μs | 17.0609μs | 58.6136 KOps/s | 58.0981 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 42.2190μs | 9.9564μs | 100.4382 KOps/s | 101.1605 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 0.1081ms | 32.0936μs | 31.1589 KOps/s | 30.2442 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 0.1128ms | 19.4855μs | 51.3203 KOps/s | 50.6915 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 72.6260μs | 18.8793μs | 52.9680 KOps/s | 51.9875 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 82.3840μs | 11.8694μs | 84.2506 KOps/s | 85.4716 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 89.4760μs | 33.8452μs | 29.5463 KOps/s | 28.6813 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 77.8750μs | 21.5751μs | 46.3496 KOps/s | 46.4426 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 57.7780μs | 19.1793μs | 52.1395 KOps/s | 51.9869 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 58.7490μs | 11.9239μs | 83.8648 KOps/s | 84.9303 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 96.9210μs | 36.0785μs | 27.7173 KOps/s | 27.3632 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 72.1050μs | 23.3566μs | 42.8144 KOps/s | 43.4046 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 76.2730μs | 20.8539μs | 47.9526 KOps/s | 47.1709 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 43.0600μs | 13.7275μs | 72.8465 KOps/s | 74.4400 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 94.1050μs | 34.0977μs | 29.3275 KOps/s | 28.6962 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 61.4250μs | 21.5644μs | 46.3727 KOps/s | 45.9075 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 0.6245ms | 21.8820μs | 45.6998 KOps/s | 45.3283 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 67.4970μs | 13.2876μs | 75.2579 KOps/s | 75.6258 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 2.9815ms | 35.6174μs | 28.0761 KOps/s | 26.8581 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 45.0240μs | 23.1844μs | 43.1325 KOps/s | 42.5265 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 70.2910μs | 23.3893μs | 42.7545 KOps/s | 41.5201 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 66.9650μs | 14.9950μs | 66.6889 KOps/s | 66.1609 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 0.1092ms | 37.4681μs | 26.6894 KOps/s | 25.7689 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 67.4680μs | 25.2203μs | 39.6507 KOps/s | 39.6751 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 69.9010μs | 24.0064μs | 41.6556 KOps/s | 41.8379 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 73.3170μs | 15.0276μs | 66.5442 KOps/s | 67.0357 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 84.2280μs | 38.9785μs | 25.6552 KOps/s | 25.0054 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 88.6450μs | 26.4994μs | 37.7367 KOps/s | 37.2416 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 90.6490μs | 25.0003μs | 39.9994 KOps/s | 39.2191 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 93.1030μs | 16.5919μs | 60.2703 KOps/s | 59.9750 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 12.6139ms | 9.9639ms | 100.3619 Ops/s | 98.8159 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 28.7121ms | 26.6455ms | 37.5297 Ops/s | 36.4322 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2723ms | 0.2182ms | 4.5832 KOps/s | 4.9846 KOps/s | |
test_values[td1_return_estimate-False-False] | 27.8808ms | 24.6078ms | 40.6376 Ops/s | 41.0399 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 28.9743ms | 26.6104ms | 37.5793 Ops/s | 37.2887 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 38.3114ms | 35.1169ms | 28.4763 Ops/s | 28.3160 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 28.8595ms | 26.6138ms | 37.5745 Ops/s | 37.4501 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 13.6271ms | 8.6437ms | 115.6911 Ops/s | 116.5174 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 3.5061ms | 1.9409ms | 515.2139 Ops/s | 501.7411 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5368ms | 0.3735ms | 2.6774 KOps/s | 2.6057 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 49.1064ms | 44.3859ms | 22.5297 Ops/s | 22.0475 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.5476ms | 3.6251ms | 275.8544 Ops/s | 272.9738 Ops/s | |
test_dqn_speed[False-None] | 2.0839ms | 1.4482ms | 690.5013 Ops/s | 677.3085 Ops/s | |
test_dqn_speed[False-backward] | 2.0482ms | 1.9539ms | 511.7979 Ops/s | 506.3893 Ops/s | |
test_dqn_speed[True-None] | 0.7513ms | 0.5702ms | 1.7538 KOps/s | 1.7284 KOps/s | |
test_dqn_speed[True-backward] | 1.1563ms | 1.0271ms | 973.6220 Ops/s | 951.2423 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.8342ms | 0.5743ms | 1.7413 KOps/s | 1.7249 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.1388ms | 1.0376ms | 963.7672 Ops/s | 929.7910 Ops/s | |
test_ddpg_speed[False-None] | 4.5428ms | 3.0983ms | 322.7605 Ops/s | 328.8437 Ops/s | |
test_ddpg_speed[False-backward] | 5.2820ms | 4.3354ms | 230.6580 Ops/s | 235.8324 Ops/s | |
test_ddpg_speed[True-None] | 2.1061ms | 1.4835ms | 674.0916 Ops/s | 661.8541 Ops/s | |
test_ddpg_speed[True-backward] | 3.2537ms | 2.9183ms | 342.6646 Ops/s | 394.9085 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 2.1702ms | 1.4621ms | 683.9649 Ops/s | 654.7552 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 3.3119ms | 2.6045ms | 383.9551 Ops/s | 407.0529 Ops/s | |
test_sac_speed[False-None] | 10.1616ms | 8.8696ms | 112.7447 Ops/s | 118.3627 Ops/s | |
test_sac_speed[False-backward] | 13.0036ms | 11.6314ms | 85.9743 Ops/s | 86.0017 Ops/s | |
test_sac_speed[True-None] | 4.4455ms | 2.7828ms | 359.3547 Ops/s | 331.8597 Ops/s | |
test_sac_speed[True-backward] | 5.2254ms | 4.3573ms | 229.5004 Ops/s | 229.2183 Ops/s | |
test_sac_speed[reduce-overhead-None] | 3.5202ms | 2.6618ms | 375.6867 Ops/s | 377.3870 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 4.9063ms | 4.6259ms | 216.1749 Ops/s | 227.7042 Ops/s | |
test_redq_speed[False-None] | 20.0727ms | 13.8158ms | 72.3810 Ops/s | 74.2658 Ops/s | |
test_redq_speed[False-backward] | 28.3719ms | 23.9404ms | 41.7705 Ops/s | 43.5306 Ops/s | |
test_redq_speed[True-None] | 7.9284ms | 7.2626ms | 137.6908 Ops/s | 138.9004 Ops/s | |
test_redq_speed[True-backward] | 17.0921ms | 15.3929ms | 64.9651 Ops/s | 62.1706 Ops/s | |
test_redq_speed[reduce-overhead-None] | 8.3532ms | 7.5659ms | 132.1723 Ops/s | 134.4694 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 16.1303ms | 15.4557ms | 64.7010 Ops/s | 63.6621 Ops/s | |
test_redq_deprec_speed[False-None] | 14.8064ms | 13.8996ms | 71.9447 Ops/s | 70.5992 Ops/s | |
test_redq_deprec_speed[False-backward] | 21.4012ms | 19.7281ms | 50.6892 Ops/s | 49.7091 Ops/s | |
test_redq_deprec_speed[True-None] | 7.4574ms | 6.4210ms | 155.7400 Ops/s | 155.4640 Ops/s | |
test_redq_deprec_speed[True-backward] | 12.1567ms | 11.2375ms | 88.9880 Ops/s | 90.9988 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 6.8972ms | 6.3181ms | 158.2758 Ops/s | 192.9707 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 12.3227ms | 11.7484ms | 85.1179 Ops/s | 101.0037 Ops/s | |
test_td3_speed[False-None] | 9.5729ms | 9.1031ms | 109.8525 Ops/s | 122.3196 Ops/s | |
test_td3_speed[False-backward] | 13.3362ms | 12.4113ms | 80.5720 Ops/s | 94.6116 Ops/s | |
test_td3_speed[True-None] | 3.3749ms | 3.0591ms | 326.8923 Ops/s | 434.7736 Ops/s | |
test_td3_speed[True-backward] | 5.7932ms | 5.3713ms | 186.1738 Ops/s | 247.2456 Ops/s | |
test_td3_speed[reduce-overhead-None] | 3.3625ms | 2.9552ms | 338.3866 Ops/s | 423.5566 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 5.0570ms | 4.6162ms | 216.6287 Ops/s | 239.2127 Ops/s | |
test_cql_speed[False-None] | 41.4512ms | 38.9019ms | 25.7057 Ops/s | 26.4504 Ops/s | |
test_cql_speed[False-backward] | 58.0712ms | 50.5153ms | 19.7960 Ops/s | 20.6553 Ops/s | |
test_cql_speed[True-None] | 24.6280ms | 22.9956ms | 43.4866 Ops/s | 44.6014 Ops/s | |
test_cql_speed[True-backward] | 32.3346ms | 30.4260ms | 32.8666 Ops/s | 33.3788 Ops/s | |
test_cql_speed[reduce-overhead-None] | 24.1577ms | 22.9207ms | 43.6287 Ops/s | 43.4242 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 31.9104ms | 29.5290ms | 33.8651 Ops/s | 32.6192 Ops/s | |
test_a2c_speed[False-None] | 8.3630ms | 7.3484ms | 136.0843 Ops/s | 130.0680 Ops/s | |
test_a2c_speed[False-backward] | 18.1305ms | 14.8752ms | 67.2261 Ops/s | 64.8883 Ops/s | |
test_a2c_speed[True-None] | 5.3689ms | 4.8925ms | 204.3959 Ops/s | 208.8834 Ops/s | |
test_a2c_speed[True-backward] | 13.5770ms | 11.6590ms | 85.7709 Ops/s | 87.2403 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 6.4750ms | 4.9461ms | 202.1793 Ops/s | 214.4145 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 12.6908ms | 11.9906ms | 83.3990 Ops/s | 90.4536 Ops/s | |
test_ppo_speed[False-None] | 8.9192ms | 8.0434ms | 124.3252 Ops/s | 131.8977 Ops/s | |
test_ppo_speed[False-backward] | 17.4710ms | 15.6697ms | 63.8174 Ops/s | 65.0016 Ops/s | |
test_ppo_speed[True-None] | 5.4619ms | 5.0690ms | 197.2776 Ops/s | 198.9756 Ops/s | |
test_ppo_speed[True-backward] | 12.9264ms | 11.7748ms | 84.9270 Ops/s | 90.1325 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 6.2975ms | 5.1777ms | 193.1359 Ops/s | 197.7797 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 12.2503ms | 11.4628ms | 87.2384 Ops/s | 86.9105 Ops/s | |
test_reinforce_speed[False-None] | 8.2056ms | 6.8040ms | 146.9714 Ops/s | 149.9828 Ops/s | |
test_reinforce_speed[False-backward] | 11.4936ms | 10.3173ms | 96.9246 Ops/s | 98.9830 Ops/s | |
test_reinforce_speed[True-None] | 4.8246ms | 4.2371ms | 236.0083 Ops/s | 236.9776 Ops/s | |
test_reinforce_speed[True-backward] | 11.4313ms | 10.4785ms | 95.4333 Ops/s | 99.4020 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 4.9261ms | 4.3294ms | 230.9773 Ops/s | 243.6787 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 11.5664ms | 10.7132ms | 93.3431 Ops/s | 95.7398 Ops/s | |
test_iql_speed[False-None] | 39.3405ms | 33.6797ms | 29.6915 Ops/s | 29.7977 Ops/s | |
test_iql_speed[False-backward] | 52.0234ms | 46.9800ms | 21.2856 Ops/s | 21.4298 Ops/s | |
test_iql_speed[True-None] | 18.9072ms | 16.8004ms | 59.5223 Ops/s | 59.9102 Ops/s | |
test_iql_speed[True-backward] | 29.6624ms | 28.6957ms | 34.8484 Ops/s | 35.6765 Ops/s | |
test_iql_speed[reduce-overhead-None] | 18.0376ms | 16.8045ms | 59.5079 Ops/s | 60.5897 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 31.2731ms | 29.0536ms | 34.4191 Ops/s | 35.7479 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 8.7441ms | 5.6636ms | 176.5672 Ops/s | 195.2484 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.9465ms | 0.5934ms | 1.6852 KOps/s | 1.8208 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.8859ms | 0.5615ms | 1.7809 KOps/s | 1.8525 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.9372ms | 5.4748ms | 182.6564 Ops/s | 199.8850 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 1.1713ms | 0.5782ms | 1.7295 KOps/s | 1.7990 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.8712ms | 0.5419ms | 1.8452 KOps/s | 1.9209 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 2.6514ms | 1.8635ms | 536.6295 Ops/s | 574.1030 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.7729ms | 1.7394ms | 574.9191 Ops/s | 603.2377 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 9.5291ms | 6.1031ms | 163.8513 Ops/s | 196.9971 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.8566ms | 0.7448ms | 1.3427 KOps/s | 1.4415 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 1.2373ms | 0.6953ms | 1.4383 KOps/s | 1.5148 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 8.2402ms | 5.6539ms | 176.8703 Ops/s | 202.1862 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 4.0150ms | 0.5974ms | 1.6739 KOps/s | 1.7795 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.9007ms | 0.5733ms | 1.7443 KOps/s | 1.9000 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.8801ms | 5.3277ms | 187.6999 Ops/s | 205.5357 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 3.0938ms | 0.6028ms | 1.6590 KOps/s | 1.7911 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 1.0062ms | 0.5656ms | 1.7680 KOps/s | 1.9016 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 7.4523ms | 5.9063ms | 169.3110 Ops/s | 199.1045 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.3595ms | 0.7643ms | 1.3083 KOps/s | 1.4373 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.9960ms | 0.7177ms | 1.3933 KOps/s | 1.4691 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 12.2319ms | 5.5502ms | 180.1725 Ops/s | 219.3705 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 9.6959ms | 2.8043ms | 356.5994 Ops/s | 448.4662 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 2.3456ms | 1.4943ms | 669.2033 Ops/s | 700.9793 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 10.4393ms | 5.5467ms | 180.2883 Ops/s | 227.2149 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 9.0326ms | 2.7717ms | 360.7880 Ops/s | 419.0649 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 6.6280ms | 1.5800ms | 632.9234 Ops/s | 684.6056 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.9763s | 24.9228ms | 40.1238 Ops/s | 231.9270 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 11.5722ms | 2.9464ms | 339.3950 Ops/s | 382.3630 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 5.8538ms | 1.6439ms | 608.3050 Ops/s | 617.1232 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 65.0975ms | 51.5330ms | 19.4050 Ops/s | 20.1999 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 16.4064ms | 14.9211ms | 67.0192 Ops/s | 67.5640 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 67.7353ms | 52.0536ms | 19.2110 Ops/s | 19.4649 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 18.1680ms | 15.3029ms | 65.3470 Ops/s | 65.8082 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 69.1049ms | 52.7258ms | 18.9661 Ops/s | 19.8737 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 18.3987ms | 16.5552ms | 60.4039 Ops/s | 61.4564 Ops/s |
torch.where(padded_values), | ||
) | ||
lps = tokens_response_td["log_probs"] | ||
lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To me 1 makes more sense, it's the log prob of prob = 1 (you don't do the ratio of log-probs but their difference).
Any value would do, but I don't see why choosing p=2.71 is superior to 1
Stack from ghstack (oldest at bottom):