[Feature] policy factory for collectors #2841
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/2841
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV
There is 1 currently active SEV. If your PR is affected, please view it below:
❌ 5 New Failures, 4 Unrelated Failures
As of commit 6951230 with merge base 9cd95d5:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.6168s | 0.5364s | 1.8643 Ops/s | 1.8402 Ops/s | |
test_transformed | 1.1388s | 1.0652s | 0.9388 Ops/s | 0.9272 Ops/s | |
test_serial | 1.5549s | 1.5506s | 0.6449 Ops/s | 0.6196 Ops/s | |
test_parallel | 1.3039s | 1.2941s | 0.7727 Ops/s | 0.7675 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.2112ms | 30.9048μs | 32.3574 KOps/s | 33.0984 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 53.0590μs | 18.2166μs | 54.8950 KOps/s | 56.0533 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 54.3020μs | 17.6048μs | 56.8027 KOps/s | 59.0467 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 37.3700μs | 10.3838μs | 96.3034 KOps/s | 101.1763 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 68.6390μs | 32.9828μs | 30.3188 KOps/s | 31.0836 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 94.6500μs | 19.6530μs | 50.8828 KOps/s | 50.6151 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 59.5610μs | 19.3250μs | 51.7463 KOps/s | 52.1443 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 52.1870μs | 11.9653μs | 83.5748 KOps/s | 82.1184 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 77.3850μs | 34.7336μs | 28.7906 KOps/s | 29.6396 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 0.5809ms | 21.7425μs | 45.9929 KOps/s | 46.5983 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 72.0450μs | 19.1872μs | 52.1182 KOps/s | 52.6884 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 44.5430μs | 12.1198μs | 82.5096 KOps/s | 84.7906 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 81.3320μs | 36.2677μs | 27.5728 KOps/s | 28.4275 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 85.0290μs | 23.5005μs | 42.5523 KOps/s | 43.0688 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 65.1610μs | 20.9472μs | 47.7390 KOps/s | 48.5967 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 54.0510μs | 13.8511μs | 72.1966 KOps/s | 74.9867 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 83.1050μs | 34.4334μs | 29.0416 KOps/s | 29.5097 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 58.7990μs | 21.6838μs | 46.1174 KOps/s | 46.5345 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 2.2487ms | 21.9673μs | 45.5222 KOps/s | 46.5148 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 90.9170μs | 13.5355μs | 73.8797 KOps/s | 76.1682 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.1073ms | 35.9854μs | 27.7890 KOps/s | 28.0984 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 59.2710μs | 23.3225μs | 42.8771 KOps/s | 42.7980 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 72.3850μs | 23.5743μs | 42.4191 KOps/s | 42.5288 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 82.9790μs | 15.0230μs | 66.5647 KOps/s | 67.4905 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 86.0910μs | 37.8597μs | 26.4133 KOps/s | 26.9477 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 92.7930μs | 25.1859μs | 39.7047 KOps/s | 39.9755 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.1117ms | 23.3999μs | 42.7352 KOps/s | 43.9294 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 0.6188ms | 15.1590μs | 65.9675 KOps/s | 67.5333 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 83.1250μs | 39.3688μs | 25.4008 KOps/s | 26.0655 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 90.7030μs | 26.5320μs | 37.6904 KOps/s | 37.3539 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 70.6420μs | 24.9583μs | 40.0668 KOps/s | 40.7611 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 51.4560μs | 16.7528μs | 59.6916 KOps/s | 60.3846 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 10.2404ms | 9.8667ms | 101.3515 Ops/s | 99.8699 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 30.7996ms | 26.3815ms | 37.9053 Ops/s | 40.4089 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.2633ms | 0.1982ms | 5.0461 KOps/s | 5.5241 KOps/s | |
test_values[td1_return_estimate-False-False] | 27.0106ms | 24.4332ms | 40.9279 Ops/s | 40.6845 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 28.3246ms | 26.3249ms | 37.9869 Ops/s | 40.1344 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 37.7766ms | 35.5396ms | 28.1376 Ops/s | 28.1838 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 28.6761ms | 26.3916ms | 37.8908 Ops/s | 39.8019 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 12.4807ms | 8.7409ms | 114.4048 Ops/s | 116.6840 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 2.1924ms | 1.8689ms | 535.0707 Ops/s | 540.4710 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 0.5909ms | 0.3729ms | 2.6820 KOps/s | 2.6277 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 44.8035ms | 42.9561ms | 23.2796 Ops/s | 23.3060 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 4.6800ms | 3.4768ms | 287.6225 Ops/s | 286.0824 Ops/s | |
test_dqn_speed[False-None] | 5.4244ms | 1.4247ms | 701.8964 Ops/s | 676.8976 Ops/s | |
test_dqn_speed[False-backward] | 2.4914ms | 1.9464ms | 513.7782 Ops/s | 506.1853 Ops/s | |
test_dqn_speed[True-None] | 0.7091ms | 0.5591ms | 1.7887 KOps/s | 1.7348 KOps/s | |
test_dqn_speed[True-backward] | 1.0572ms | 0.9856ms | 1.0146 KOps/s | 984.5944 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.9278ms | 0.5660ms | 1.7669 KOps/s | 1.7328 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.0766ms | 0.9933ms | 1.0068 KOps/s | 984.1469 Ops/s | |
test_ddpg_speed[False-None] | 3.6522ms | 3.0001ms | 333.3257 Ops/s | 329.5932 Ops/s | |
test_ddpg_speed[False-backward] | 4.2034ms | 4.0878ms | 244.6282 Ops/s | 237.6997 Ops/s | |
test_ddpg_speed[True-None] | 1.6697ms | 1.4561ms | 686.7721 Ops/s | 677.4429 Ops/s | |
test_ddpg_speed[True-backward] | 2.8142ms | 2.6715ms | 374.3285 Ops/s | 411.5227 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 2.0924ms | 1.5229ms | 656.6218 Ops/s | 677.9380 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.6334ms | 2.4057ms | 415.6780 Ops/s | 426.7485 Ops/s | |
test_sac_speed[False-None] | 12.6826ms | 8.3527ms | 119.7212 Ops/s | 121.4429 Ops/s | |
test_sac_speed[False-backward] | 12.6386ms | 11.2302ms | 89.0456 Ops/s | 88.8924 Ops/s | |
test_sac_speed[True-None] | 4.7441ms | 2.8686ms | 348.6023 Ops/s | 386.6432 Ops/s | |
test_sac_speed[True-backward] | 5.4297ms | 5.2110ms | 191.9016 Ops/s | 229.6601 Ops/s | |
test_sac_speed[reduce-overhead-None] | 4.2439ms | 3.3342ms | 299.9232 Ops/s | 385.2004 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 6.8296ms | 4.9868ms | 200.5287 Ops/s | 218.0360 Ops/s | |
test_redq_speed[False-None] | 21.7096ms | 14.5451ms | 68.7515 Ops/s | 69.7009 Ops/s | |
test_redq_speed[False-backward] | 33.7493ms | 25.1084ms | 39.8274 Ops/s | 41.3938 Ops/s | |
test_redq_speed[True-None] | 8.8763ms | 8.2010ms | 121.9367 Ops/s | 143.1708 Ops/s | |
test_redq_speed[True-backward] | 16.0074ms | 15.2826ms | 65.4339 Ops/s | 64.4639 Ops/s | |
test_redq_speed[reduce-overhead-None] | 8.6644ms | 8.1491ms | 122.7126 Ops/s | 144.4884 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 16.3009ms | 15.9746ms | 62.5994 Ops/s | 66.5982 Ops/s | |
test_redq_deprec_speed[False-None] | 16.9903ms | 15.2886ms | 65.4083 Ops/s | 71.4860 Ops/s | |
test_redq_deprec_speed[False-backward] | 28.2176ms | 22.0343ms | 45.3839 Ops/s | 50.9140 Ops/s | |
test_redq_deprec_speed[True-None] | 7.3966ms | 5.7893ms | 172.7313 Ops/s | 176.3297 Ops/s | |
test_redq_deprec_speed[True-backward] | 11.5790ms | 10.9593ms | 91.2464 Ops/s | 88.8997 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 6.9307ms | 5.7256ms | 174.6553 Ops/s | 171.1011 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 11.1122ms | 10.8543ms | 92.1294 Ops/s | 89.5219 Ops/s | |
test_td3_speed[False-None] | 8.9276ms | 8.3480ms | 119.7885 Ops/s | 112.6410 Ops/s | |
test_td3_speed[False-backward] | 12.2937ms | 10.8795ms | 91.9157 Ops/s | 90.0249 Ops/s | |
test_td3_speed[True-None] | 2.8488ms | 2.4182ms | 413.5373 Ops/s | 427.9432 Ops/s | |
test_td3_speed[True-backward] | 4.1379ms | 4.0468ms | 247.1106 Ops/s | 247.5414 Ops/s | |
test_td3_speed[reduce-overhead-None] | 3.0576ms | 2.5197ms | 396.8724 Ops/s | 431.7185 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 4.7746ms | 4.6159ms | 216.6446 Ops/s | 229.1710 Ops/s | |
test_cql_speed[False-None] | 39.7393ms | 37.9241ms | 26.3685 Ops/s | 25.4123 Ops/s | |
test_cql_speed[False-backward] | 50.2972ms | 48.5276ms | 20.6068 Ops/s | 19.9436 Ops/s | |
test_cql_speed[True-None] | 24.0369ms | 22.8639ms | 43.7371 Ops/s | 43.4735 Ops/s | |
test_cql_speed[True-backward] | 31.0101ms | 29.6818ms | 33.6907 Ops/s | 33.0412 Ops/s | |
test_cql_speed[reduce-overhead-None] | 25.5136ms | 22.9608ms | 43.5524 Ops/s | 43.7527 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 31.5359ms | 30.2070ms | 33.1049 Ops/s | 33.5155 Ops/s | |
test_a2c_speed[False-None] | 10.3003ms | 7.6337ms | 130.9981 Ops/s | 134.1216 Ops/s | |
test_a2c_speed[False-backward] | 17.9571ms | 15.2254ms | 65.6799 Ops/s | 67.5035 Ops/s | |
test_a2c_speed[True-None] | 7.0290ms | 4.9535ms | 201.8770 Ops/s | 210.7018 Ops/s | |
test_a2c_speed[True-backward] | 12.9078ms | 11.6140ms | 86.1029 Ops/s | 89.2133 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 5.2857ms | 4.7235ms | 211.7067 Ops/s | 213.9131 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 12.9199ms | 11.5422ms | 86.6389 Ops/s | 89.5758 Ops/s | |
test_ppo_speed[False-None] | 9.2568ms | 7.7562ms | 128.9286 Ops/s | 129.5437 Ops/s | |
test_ppo_speed[False-backward] | 17.7039ms | 15.3529ms | 65.1343 Ops/s | 66.6908 Ops/s | |
test_ppo_speed[True-None] | 5.9585ms | 5.1290ms | 194.9697 Ops/s | 195.1346 Ops/s | |
test_ppo_speed[True-backward] | 13.4509ms | 11.5074ms | 86.9007 Ops/s | 89.5727 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 6.4508ms | 5.1925ms | 192.5854 Ops/s | 192.6685 Ops/s | |
test_ppo_speed[reduce-overhead-backward] | 11.7231ms | 11.4135ms | 87.6156 Ops/s | 89.4482 Ops/s | |
test_reinforce_speed[False-None] | 8.3868ms | 6.6684ms | 149.9605 Ops/s | 150.1868 Ops/s | |
test_reinforce_speed[False-backward] | 10.3774ms | 9.9409ms | 100.5942 Ops/s | 100.3231 Ops/s | |
test_reinforce_speed[True-None] | 4.4569ms | 4.1013ms | 243.8249 Ops/s | 244.7678 Ops/s | |
test_reinforce_speed[True-backward] | 11.9280ms | 10.6016ms | 94.3250 Ops/s | 99.0508 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 5.0543ms | 4.1901ms | 238.6574 Ops/s | 239.9774 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 11.9962ms | 10.3423ms | 96.6904 Ops/s | 95.1445 Ops/s | |
test_iql_speed[False-None] | 39.4692ms | 33.7048ms | 29.6694 Ops/s | 29.2201 Ops/s | |
test_iql_speed[False-backward] | 48.9738ms | 46.7393ms | 21.3953 Ops/s | 21.5388 Ops/s | |
test_iql_speed[True-None] | 19.8701ms | 16.4402ms | 60.8266 Ops/s | 62.8998 Ops/s | |
test_iql_speed[True-backward] | 29.4113ms | 27.7687ms | 36.0118 Ops/s | 36.1236 Ops/s | |
test_iql_speed[reduce-overhead-None] | 17.6198ms | 15.9609ms | 62.6530 Ops/s | 61.3443 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 29.1834ms | 27.1744ms | 36.7994 Ops/s | 36.7113 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.3554ms | 4.8553ms | 205.9615 Ops/s | 204.0116 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.8132ms | 0.5212ms | 1.9186 KOps/s | 1.8570 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.7398ms | 0.4998ms | 2.0009 KOps/s | 1.9272 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.4179ms | 4.6624ms | 214.4796 Ops/s | 205.0639 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.0411ms | 0.5133ms | 1.9482 KOps/s | 1.8781 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7344ms | 0.4925ms | 2.0305 KOps/s | 1.9555 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 3.6555ms | 1.6972ms | 589.2184 Ops/s | 570.6394 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 2.2142ms | 1.5814ms | 632.3435 Ops/s | 618.9509 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.1076ms | 4.7901ms | 208.7628 Ops/s | 207.2566 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.5473ms | 0.6562ms | 1.5240 KOps/s | 1.4668 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8351ms | 0.6279ms | 1.5927 KOps/s | 1.5400 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.4739ms | 4.6489ms | 215.1041 Ops/s | 215.5000 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.3540ms | 0.5225ms | 1.9140 KOps/s | 1.8438 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6670ms | 0.4965ms | 2.0140 KOps/s | 1.9418 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 7.3795ms | 4.5123ms | 221.6177 Ops/s | 208.5939 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.7813s | 1.6165ms | 618.6293 Ops/s | 1.9017 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.7785ms | 0.4922ms | 2.0315 KOps/s | 1.9660 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 5.0813ms | 4.7616ms | 210.0136 Ops/s | 202.3336 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 3.0145ms | 0.6570ms | 1.5221 KOps/s | 1.4518 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.8452ms | 0.6324ms | 1.5812 KOps/s | 1.4941 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 5.9896ms | 4.3449ms | 230.1534 Ops/s | 238.2516 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 8.5835ms | 2.3303ms | 429.1267 Ops/s | 422.7340 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 1.8658ms | 1.3351ms | 749.0165 Ops/s | 725.0440 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 5.6721ms | 4.2883ms | 233.1932 Ops/s | 23.1368 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 5.6119ms | 2.3345ms | 428.3539 Ops/s | 410.3633 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 6.6658ms | 1.4917ms | 670.3713 Ops/s | 735.5777 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.7479s | 19.5165ms | 51.2387 Ops/s | 213.7766 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.4342ms | 2.6552ms | 376.6191 Ops/s | 405.3066 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 6.4685ms | 1.6968ms | 589.3466 Ops/s | 619.7286 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 61.4778ms | 50.9773ms | 19.6166 Ops/s | 19.2576 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.9689ms | 14.9552ms | 66.8665 Ops/s | 67.6098 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 59.7307ms | 50.6432ms | 19.7460 Ops/s | 18.7342 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 15.9822ms | 14.7521ms | 67.7871 Ops/s | 67.0906 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 60.8427ms | 50.6363ms | 19.7487 Ops/s | 19.0349 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 16.9117ms | 16.0328ms | 62.3722 Ops/s | 61.8679 Ops/s |

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_simple | 0.9319s | 0.8449s | 1.1836 Ops/s | 1.2236 Ops/s | |
test_transformed | 1.5728s | 1.4754s | 0.6778 Ops/s | 0.6792 Ops/s | |
test_serial | 2.5213s | 2.4028s | 0.4162 Ops/s | 0.4217 Ops/s | |
test_parallel | 2.2163s | 2.0564s | 0.4863 Ops/s | 0.5157 Ops/s | |
test_step_mdp_speed[True-True-True-True-True] | 0.4266ms | 38.9991μs | 25.6416 KOps/s | 24.9945 KOps/s | |
test_step_mdp_speed[True-True-True-True-False] | 50.1010μs | 23.3589μs | 42.8103 KOps/s | 43.1504 KOps/s | |
test_step_mdp_speed[True-True-True-False-True] | 0.4165ms | 22.0534μs | 45.3444 KOps/s | 45.5802 KOps/s | |
test_step_mdp_speed[True-True-True-False-False] | 0.4006ms | 12.9362μs | 77.3024 KOps/s | 76.3279 KOps/s | |
test_step_mdp_speed[True-True-False-True-True] | 68.8710μs | 42.8591μs | 23.3323 KOps/s | 23.2182 KOps/s | |
test_step_mdp_speed[True-True-False-True-False] | 0.4141ms | 25.6089μs | 39.0489 KOps/s | 38.7532 KOps/s | |
test_step_mdp_speed[True-True-False-False-True] | 0.4192ms | 24.8332μs | 40.2687 KOps/s | 40.6925 KOps/s | |
test_step_mdp_speed[True-True-False-False-False] | 0.4073ms | 15.3698μs | 65.0628 KOps/s | 65.3421 KOps/s | |
test_step_mdp_speed[True-False-True-True-True] | 76.4910μs | 45.2996μs | 22.0753 KOps/s | 22.1781 KOps/s | |
test_step_mdp_speed[True-False-True-True-False] | 0.4135ms | 28.0860μs | 35.6049 KOps/s | 35.8458 KOps/s | |
test_step_mdp_speed[True-False-True-False-True] | 0.4151ms | 24.4381μs | 40.9198 KOps/s | 40.6004 KOps/s | |
test_step_mdp_speed[True-False-True-False-False] | 44.8410μs | 15.1891μs | 65.8369 KOps/s | 64.4168 KOps/s | |
test_step_mdp_speed[True-False-False-True-True] | 74.5510μs | 46.7220μs | 21.4032 KOps/s | 21.9605 KOps/s | |
test_step_mdp_speed[True-False-False-True-False] | 0.4182ms | 29.6741μs | 33.6994 KOps/s | 32.4166 KOps/s | |
test_step_mdp_speed[True-False-False-False-True] | 0.4262ms | 27.2023μs | 36.7615 KOps/s | 36.9075 KOps/s | |
test_step_mdp_speed[True-False-False-False-False] | 0.4139ms | 17.6011μs | 56.8148 KOps/s | 56.1711 KOps/s | |
test_step_mdp_speed[False-True-True-True-True] | 87.2310μs | 44.9288μs | 22.2574 KOps/s | 22.2341 KOps/s | |
test_step_mdp_speed[False-True-True-True-False] | 0.4119ms | 28.5358μs | 35.0437 KOps/s | 34.8683 KOps/s | |
test_step_mdp_speed[False-True-True-False-True] | 2.6301ms | 28.9104μs | 34.5896 KOps/s | 34.8928 KOps/s | |
test_step_mdp_speed[False-True-True-False-False] | 0.4102ms | 17.2672μs | 57.9132 KOps/s | 58.3639 KOps/s | |
test_step_mdp_speed[False-True-False-True-True] | 0.4441ms | 47.2138μs | 21.1803 KOps/s | 20.7962 KOps/s | |
test_step_mdp_speed[False-True-False-True-False] | 63.0600μs | 30.7809μs | 32.4877 KOps/s | 31.9277 KOps/s | |
test_step_mdp_speed[False-True-False-False-True] | 0.4352ms | 30.4246μs | 32.8681 KOps/s | 32.2435 KOps/s | |
test_step_mdp_speed[False-True-False-False-False] | 0.4144ms | 19.0976μs | 52.3626 KOps/s | 50.6594 KOps/s | |
test_step_mdp_speed[False-False-True-True-True] | 0.4432ms | 50.0173μs | 19.9931 KOps/s | 19.9619 KOps/s | |
test_step_mdp_speed[False-False-True-True-False] | 71.7610μs | 33.0776μs | 30.2319 KOps/s | 29.6815 KOps/s | |
test_step_mdp_speed[False-False-True-False-True] | 0.4249ms | 30.5460μs | 32.7376 KOps/s | 32.1497 KOps/s | |
test_step_mdp_speed[False-False-True-False-False] | 55.0100μs | 19.2462μs | 51.9582 KOps/s | 49.9641 KOps/s | |
test_step_mdp_speed[False-False-False-True-True] | 0.4441ms | 50.9670μs | 19.6205 KOps/s | 19.2864 KOps/s | |
test_step_mdp_speed[False-False-False-True-False] | 0.4274ms | 34.9474μs | 28.6144 KOps/s | 28.5300 KOps/s | |
test_step_mdp_speed[False-False-False-False-True] | 63.2910μs | 32.5089μs | 30.7608 KOps/s | 30.6456 KOps/s | |
test_step_mdp_speed[False-False-False-False-False] | 0.4175ms | 21.7372μs | 46.0040 KOps/s | 46.6497 KOps/s | |
test_values[generalized_advantage_estimate-True-True] | 28.0619ms | 26.8934ms | 37.1839 Ops/s | 38.9250 Ops/s | |
test_values[vec_generalized_advantage_estimate-True-True] | 96.7784ms | 2.8525ms | 350.5648 Ops/s | 339.6788 Ops/s | |
test_values[td0_return_estimate-False-False] | 0.1063ms | 81.4144μs | 12.2828 KOps/s | 12.4390 KOps/s | |
test_values[td1_return_estimate-False-False] | 60.6445ms | 58.3765ms | 17.1302 Ops/s | 17.7987 Ops/s | |
test_values[vec_td1_return_estimate-False-False] | 1.3894ms | 1.1042ms | 905.6245 Ops/s | 913.3104 Ops/s | |
test_values[td_lambda_return_estimate-True-False] | 95.8431ms | 91.4364ms | 10.9366 Ops/s | 11.2526 Ops/s | |
test_values[vec_td_lambda_return_estimate-True-False] | 1.4183ms | 1.1092ms | 901.5208 Ops/s | 918.3534 Ops/s | |
test_gae_speed[generalized_advantage_estimate-False-1-512] | 27.8295ms | 26.8001ms | 37.3132 Ops/s | 39.3055 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-1-512] | 1.0403ms | 0.7740ms | 1.2919 KOps/s | 1.3102 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-1-512] | 1.0971ms | 0.6932ms | 1.4426 KOps/s | 1.4859 KOps/s | |
test_gae_speed[vec_generalized_advantage_estimate-True-32-512] | 1.9008ms | 1.5025ms | 665.5560 Ops/s | 661.0875 Ops/s | |
test_gae_speed[vec_generalized_advantage_estimate-False-32-512] | 0.7685ms | 0.7075ms | 1.4135 KOps/s | 1.3962 KOps/s | |
test_dqn_speed[False-None] | 1.9557ms | 1.5616ms | 640.3526 Ops/s | 648.1299 Ops/s | |
test_dqn_speed[False-backward] | 2.5355ms | 2.1647ms | 461.9605 Ops/s | 464.1822 Ops/s | |
test_dqn_speed[True-None] | 0.6559ms | 0.5553ms | 1.8010 KOps/s | 1.7912 KOps/s | |
test_dqn_speed[True-backward] | 1.2948ms | 1.2345ms | 810.0390 Ops/s | 883.6467 Ops/s | |
test_dqn_speed[reduce-overhead-None] | 0.7208ms | 0.5747ms | 1.7402 KOps/s | 1.7513 KOps/s | |
test_dqn_speed[reduce-overhead-backward] | 1.1258ms | 1.0773ms | 928.2238 Ops/s | 1.0262 KOps/s | |
test_ddpg_speed[False-None] | 3.1820ms | 2.8640ms | 349.1591 Ops/s | 351.1755 Ops/s | |
test_ddpg_speed[False-backward] | 4.7359ms | 4.2895ms | 233.1253 Ops/s | 241.7340 Ops/s | |
test_ddpg_speed[True-None] | 1.4699ms | 1.3486ms | 741.5240 Ops/s | 746.1879 Ops/s | |
test_ddpg_speed[True-backward] | 2.7063ms | 2.6139ms | 382.5756 Ops/s | 403.4276 Ops/s | |
test_ddpg_speed[reduce-overhead-None] | 1.4712ms | 1.3560ms | 737.4757 Ops/s | 738.9496 Ops/s | |
test_ddpg_speed[reduce-overhead-backward] | 2.1000ms | 2.0566ms | 486.2438 Ops/s | 521.6373 Ops/s | |
test_sac_speed[False-None] | 8.6040ms | 8.1389ms | 122.8660 Ops/s | 122.4865 Ops/s | |
test_sac_speed[False-backward] | 11.6949ms | 11.3561ms | 88.0582 Ops/s | 89.4200 Ops/s | |
test_sac_speed[True-None] | 1.9854ms | 1.8405ms | 543.3330 Ops/s | 541.6580 Ops/s | |
test_sac_speed[True-backward] | 4.2506ms | 3.8034ms | 262.9245 Ops/s | 274.5324 Ops/s | |
test_sac_speed[reduce-overhead-None] | 20.9945ms | 11.9836ms | 83.4476 Ops/s | 86.1558 Ops/s | |
test_sac_speed[reduce-overhead-backward] | 1.7485ms | 1.7044ms | 586.7193 Ops/s | 573.4926 Ops/s | |
test_redq_speed[False-None] | 8.1681ms | 7.6926ms | 129.9953 Ops/s | 128.2631 Ops/s | |
test_redq_speed[False-backward] | 12.2292ms | 11.9936ms | 83.3777 Ops/s | 83.0653 Ops/s | |
test_redq_speed[True-None] | 2.5551ms | 2.3650ms | 422.8290 Ops/s | 424.2939 Ops/s | |
test_redq_speed[True-backward] | 4.5527ms | 4.1138ms | 243.0826 Ops/s | 242.1178 Ops/s | |
test_redq_speed[reduce-overhead-None] | 2.4968ms | 2.3755ms | 420.9569 Ops/s | 420.1687 Ops/s | |
test_redq_speed[reduce-overhead-backward] | 4.6368ms | 4.3099ms | 232.0225 Ops/s | 232.0808 Ops/s | |
test_redq_deprec_speed[False-None] | 9.7894ms | 9.2377ms | 108.2519 Ops/s | 108.9767 Ops/s | |
test_redq_deprec_speed[False-backward] | 13.2560ms | 12.4242ms | 80.4880 Ops/s | 79.7549 Ops/s | |
test_redq_deprec_speed[True-None] | 2.7210ms | 2.6601ms | 375.9198 Ops/s | 371.2723 Ops/s | |
test_redq_deprec_speed[True-backward] | 5.0461ms | 4.5672ms | 218.9514 Ops/s | 211.2266 Ops/s | |
test_redq_deprec_speed[reduce-overhead-None] | 3.1861ms | 2.7045ms | 369.7497 Ops/s | 372.6811 Ops/s | |
test_redq_deprec_speed[reduce-overhead-backward] | 4.9563ms | 4.5411ms | 220.2122 Ops/s | 217.3233 Ops/s | |
test_td3_speed[False-None] | 8.5082ms | 8.1904ms | 122.0945 Ops/s | 124.9748 Ops/s | |
test_td3_speed[False-backward] | 11.4198ms | 10.7176ms | 93.3046 Ops/s | 92.9915 Ops/s | |
test_td3_speed[True-None] | 1.6971ms | 1.6557ms | 603.9624 Ops/s | 605.4262 Ops/s | |
test_td3_speed[True-backward] | 3.4420ms | 3.3919ms | 294.8181 Ops/s | 294.4295 Ops/s | |
test_td3_speed[reduce-overhead-None] | 77.5254ms | 26.5458ms | 37.6707 Ops/s | 37.8769 Ops/s | |
test_td3_speed[reduce-overhead-backward] | 1.5407ms | 1.4885ms | 671.8315 Ops/s | 666.5142 Ops/s | |
test_cql_speed[False-None] | 17.6115ms | 17.0992ms | 58.4823 Ops/s | 58.7421 Ops/s | |
test_cql_speed[False-backward] | 23.5664ms | 22.6307ms | 44.1878 Ops/s | 43.9023 Ops/s | |
test_cql_speed[True-None] | 3.5188ms | 3.2772ms | 305.1405 Ops/s | 303.7337 Ops/s | |
test_cql_speed[True-backward] | 6.1753ms | 5.7515ms | 173.8679 Ops/s | 171.6878 Ops/s | |
test_cql_speed[reduce-overhead-None] | 0.5911s | 16.3408ms | 61.1965 Ops/s | 75.4564 Ops/s | |
test_cql_speed[reduce-overhead-backward] | 2.0746ms | 2.0038ms | 499.0491 Ops/s | 501.1841 Ops/s | |
test_a2c_speed[False-None] | 3.4593ms | 3.2150ms | 311.0454 Ops/s | 307.8765 Ops/s | |
test_a2c_speed[False-backward] | 7.2667ms | 6.4120ms | 155.9577 Ops/s | 153.0857 Ops/s | |
test_a2c_speed[True-None] | 1.4397ms | 1.3482ms | 741.7510 Ops/s | 736.1607 Ops/s | |
test_a2c_speed[True-backward] | 3.2956ms | 3.0849ms | 324.1618 Ops/s | 333.9120 Ops/s | |
test_a2c_speed[reduce-overhead-None] | 16.1220ms | 9.1058ms | 109.8204 Ops/s | 109.5962 Ops/s | |
test_a2c_speed[reduce-overhead-backward] | 1.6784ms | 1.6126ms | 620.1306 Ops/s | 673.2005 Ops/s | |
test_ppo_speed[False-None] | 3.8427ms | 3.7298ms | 268.1113 Ops/s | 263.6238 Ops/s | |
test_ppo_speed[False-backward] | 7.5794ms | 7.1511ms | 139.8385 Ops/s | 142.0220 Ops/s | |
test_ppo_speed[True-None] | 1.5319ms | 1.4213ms | 703.5763 Ops/s | 695.2069 Ops/s | |
test_ppo_speed[True-backward] | 3.3765ms | 3.2590ms | 306.8450 Ops/s | 317.2860 Ops/s | |
test_ppo_speed[reduce-overhead-None] | 1.4082ms | 1.0031ms | 996.8827 Ops/s | 1.0258 KOps/s | |
test_ppo_speed[reduce-overhead-backward] | 1.6011ms | 1.5572ms | 642.1974 Ops/s | 683.8987 Ops/s | |
test_reinforce_speed[False-None] | 2.4093ms | 2.2772ms | 439.1388 Ops/s | 435.2413 Ops/s | |
test_reinforce_speed[False-backward] | 3.8110ms | 3.4087ms | 293.3671 Ops/s | 298.3421 Ops/s | |
test_reinforce_speed[True-None] | 1.3863ms | 1.2882ms | 776.2822 Ops/s | 767.0978 Ops/s | |
test_reinforce_speed[True-backward] | 3.2017ms | 3.1222ms | 320.2825 Ops/s | 338.1853 Ops/s | |
test_reinforce_speed[reduce-overhead-None] | 19.5414ms | 10.4833ms | 95.3897 Ops/s | 94.4040 Ops/s | |
test_reinforce_speed[reduce-overhead-backward] | 1.7207ms | 1.6379ms | 610.5198 Ops/s | 593.7734 Ops/s | |
test_iql_speed[False-None] | 9.9055ms | 9.2950ms | 107.5845 Ops/s | 105.7575 Ops/s | |
test_iql_speed[False-backward] | 13.7131ms | 13.2697ms | 75.3598 Ops/s | 73.6826 Ops/s | |
test_iql_speed[True-None] | 2.4358ms | 2.2219ms | 450.0603 Ops/s | 439.8459 Ops/s | |
test_iql_speed[True-backward] | 5.4091ms | 5.0188ms | 199.2494 Ops/s | 196.5463 Ops/s | |
test_iql_speed[reduce-overhead-None] | 0.5269s | 13.2313ms | 75.5783 Ops/s | 88.1345 Ops/s | |
test_iql_speed[reduce-overhead-backward] | 2.1035ms | 2.0561ms | 486.3565 Ops/s | 477.5068 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 7.6701ms | 6.2685ms | 159.5274 Ops/s | 158.1670 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 0.4999ms | 0.2731ms | 3.6616 KOps/s | 2.9449 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.4761ms | 0.2482ms | 4.0293 KOps/s | 2.9652 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 6.2248ms | 5.9196ms | 168.9296 Ops/s | 166.9355 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 2.0591ms | 0.3220ms | 3.1051 KOps/s | 3.6434 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.5386ms | 0.2865ms | 3.4906 KOps/s | 3.0479 KOps/s | |
test_rb_sample[TensorDictReplayBuffer-LazyMemmapStorage-sampler6-10000] | 1.5313ms | 1.3362ms | 748.3775 Ops/s | 774.2265 Ops/s | |
test_rb_sample[TensorDictReplayBuffer-LazyTensorStorage-sampler7-10000] | 1.4391ms | 1.2315ms | 811.9869 Ops/s | 830.4183 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.2569ms | 6.1183ms | 163.4444 Ops/s | 160.9955 Ops/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.1057ms | 0.4533ms | 2.2060 KOps/s | 2.3845 KOps/s | |
test_rb_sample[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.7352ms | 0.4443ms | 2.2506 KOps/s | 2.1668 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-RandomSampler-4000] | 6.0885ms | 5.9917ms | 166.8972 Ops/s | 164.0742 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-10000] | 1.6649ms | 0.3045ms | 3.2845 KOps/s | 3.6346 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-10000] | 0.6239ms | 0.3056ms | 3.2725 KOps/s | 4.0416 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-4000] | 10.1546ms | 6.0528ms | 165.2127 Ops/s | 165.8854 Ops/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-10000] | 0.6121ms | 0.2870ms | 3.4838 KOps/s | 3.3366 KOps/s | |
test_rb_iterate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-10000] | 0.4858ms | 0.2506ms | 3.9897 KOps/s | 3.5378 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-ListStorage-None-4000] | 6.3138ms | 6.1020ms | 163.8814 Ops/s | 161.9272 Ops/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-10000] | 1.9044ms | 0.4815ms | 2.0769 KOps/s | 2.4106 KOps/s | |
test_rb_iterate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-10000] | 0.6689ms | 0.4625ms | 2.1622 KOps/s | 2.5615 KOps/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-RandomSampler-400] | 7.0766ms | 5.4692ms | 182.8416 Ops/s | 177.9948 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-RandomSampler-400] | 10.1882ms | 2.0968ms | 476.9237 Ops/s | 445.3565 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-RandomSampler-400] | 7.0856ms | 1.2311ms | 812.3123 Ops/s | 846.7540 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-ListStorage-SamplerWithoutReplacement-400] | 7.2501ms | 5.6106ms | 178.2342 Ops/s | 176.2945 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyMemmapStorage-SamplerWithoutReplacement-400] | 10.6180ms | 2.1045ms | 475.1678 Ops/s | 435.4696 Ops/s | |
test_rb_populate[TensorDictReplayBuffer-LazyTensorStorage-SamplerWithoutReplacement-400] | 3.4910ms | 1.1777ms | 849.0857 Ops/s | 843.6090 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-ListStorage-None-400] | 0.5181s | 16.0395ms | 62.3462 Ops/s | 30.4778 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyMemmapStorage-None-400] | 8.3912ms | 2.1450ms | 466.2020 Ops/s | 442.8126 Ops/s | |
test_rb_populate[TensorDictPrioritizedReplayBuffer-LazyTensorStorage-None-400] | 8.8148ms | 1.3883ms | 720.2917 Ops/s | 779.5482 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-True] | 56.6271ms | 54.8310ms | 18.2379 Ops/s | 18.3542 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-10000-10000-100-False] | 19.1151ms | 17.2712ms | 57.8999 Ops/s | 59.2282 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-True] | 56.7215ms | 54.9943ms | 18.1837 Ops/s | 18.3002 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-100000-10000-100-False] | 19.1236ms | 17.3303ms | 57.7025 Ops/s | 58.3269 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-True] | 55.4740ms | 54.0691ms | 18.4949 Ops/s | 18.3844 Ops/s | |
test_rb_extend_sample[ReplayBuffer-LazyTensorStorage-RandomSampler-1000000-10000-100-False] | 22.2163ms | 20.5849ms | 48.5793 Ops/s | 53.6438 Ops/s |
torchrl/collectors/collectors.py
Outdated
raise RuntimeError("Non-terminating collectors do not have a length")

@classmethod
def from_policy_factory(
This is a bit convoluted.
We could instead add one more kwarg to the constructors, but since there are already many kwargs and many collector subclasses, I don't know if that's the most economical solution.
Happy to make this a kwarg that is mutually exclusive with the policy arg if you think that's better suited.
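For illustration, the alternative mentioned above (a keyword argument that is mutually exclusive with the policy argument) could look roughly like the sketch below. `KwargCollector` and the `policy_factory` keyword are hypothetical names used only for this example; this is not the torchrl constructor.

```python
from typing import Callable, Optional


class KwargCollector:
    """Hypothetical stand-in for a collector; not the torchrl class."""

    def __init__(
        self,
        policy: Optional[Callable] = None,
        *,
        policy_factory: Optional[Callable[[], Callable]] = None,
    ):
        # At most one of the two arguments may be supplied.
        if policy is not None and policy_factory is not None:
            raise TypeError("policy and policy_factory are mutually exclusive")
        # Build the policy here when only a factory was given.
        if policy is None and policy_factory is not None:
            policy = policy_factory()
        self.policy = policy
```

The cost of this design is the one raised in the comment: every collector subclass constructor would need the extra keyword and the same check.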
I think this is ok.
Just a question: do you think that giving the CustomCollectorCls a `__name__` that depends on its `cls` would be helpful for debugging, or would that make it harder to find where the custom wrapper class was defined 😅
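A minimal sketch of what a `cls`-dependent name could look like, assuming the wrapper is built dynamically inside the classmethod. `BaseCollector`, the `WithPolicyFactory` suffix, and the choice to return a class rather than an instance are all assumptions made for this example, not the PR's implementation.

```python
class BaseCollector:
    """Hypothetical stand-in for a collector class."""

    def __init__(self, policy):
        self.policy = policy

    @classmethod
    def from_policy_factory(cls, policy_factory):
        # The wrapper's __init__ builds the policy, then defers to cls.
        def __init__(self, *args, **kwargs):
            cls.__init__(self, policy_factory(), *args, **kwargs)

        # Derive the wrapper name from the class it wraps.
        name = f"{cls.__name__}WithPolicyFactory"
        return type(name, (cls,), {"__init__": __init__, "__qualname__": name})


Wrapped = BaseCollector.from_policy_factory(lambda: "pi")
print(Wrapped.__name__)  # BaseCollectorWithPolicyFactory
```

With this naming, a traceback or `repr` shows which collector was wrapped, but the name no longer matches any `class` statement in the source, which is the trade-off the question points at.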
To double-check my understanding: the change is that the policy is now instantiated at collector init time rather than beforehand, which is helpful for distributed settings?
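If that reading is right, the timing difference can be sketched as follows. `make_policy`, `LazyCollector`, and `build_log` are hypothetical names for this example only; the point is that creating the factory builds nothing, and the policy only comes into existence when the collector is constructed.

```python
from functools import partial

build_log = []  # records each time a policy is actually built


def make_policy(scale):
    """Hypothetical policy builder; logs when it runs."""
    build_log.append(scale)
    return lambda obs: obs * scale


class LazyCollector:
    """Hypothetical collector that builds its policy at init time."""

    def __init__(self, policy_factory):
        self.policy = policy_factory()


# Creating the factory does not build a policy.
factory = partial(make_policy, scale=3)
built_before_init = list(build_log)

# The policy is built only here, inside the collector's __init__.
collector = LazyCollector(factory)
```

In a distributed setup this would mean only the lightweight factory needs to reach the worker, and the policy is created wherever the collector is created.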
if sampler is SamplerWithoutReplacement:
    assert sample["a"].unique().numel() == sample.numel()

# class CustomCollectorCls(SyncDataCollector):
nit: remove the commented-out code
Stack from ghstack (oldest at bottom):