Commit f03bd79
Explore new pipeline that overlaps optimizer with emb_lookup (#2916)
Summary:
Pull Request resolved: #2916
# context
* this workstream started from an training QPS optimization initiated from the PG side (see doc in the reference section), observing the embedding lookup can overlap with the optimizer.
* Embedding table weights are updated in the fused backward (fused-TBE), so the embedding lookup can start immediately after backward is completed without dependency on the optiimzer.
* we use a separate stream to run embedding lookup so that it can overlap with the previous optimizer (changed, see below)
* there is also an option of using data_dist stream for this embedding lookup, the output_dist won't be block but the start_sparse_data_dist would, which results a smaller mem footprint.
WARNING: This pipeline **DOES NOT** work for EBC/EC with feature processors because the embedding lookup is started immediately after TBE backward (where the embedding tables' weights have been updated)
# benchmark readings
* runtime: SemiSync < FusedSparseDist (lookup after opt) < FusedSparseDist (lookup before opt) < SparseDist
```
TrainPipelineSemiSync | Runtime (P90): 5447.42 ms | Peak Memory alloc (P90): 61.63 GB | Peak Memory reserved (P90): 64.31 GB
TrainPipelineFusedSparseDist | Runtime (P90): 5605.63 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.61 GB
TrainPipelineFusedSparseDist* | Runtime (P90): 5661.92 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.67 GB
TrainPipelineSparseDist | Runtime (P90): 6034.46 ms | Peak Memory alloc (P90): 51.80 GB | Peak Memory reserved (P90): 62.25 GB
* embedding_lookup_after_opt = False
```
* traces show that:
(1) the emb_lookup is right behind the TBE-bwd (on the same cuda stream)
(2) the output_dist is invoked right after each emb_lookup (there are two, one for unweighted ebc, one for weighted)
(3) the optimizer seems **NOT** overlap with emb_lookup kernel when `embedding_lookup_after_opt = False`
{F1977309185}
(4) the optimizer still does **NOT** overlap with emb_lookup kernel, but it fills in the gap between the `KJTTensorAwaitable.wait()` and the embedding lookup kernel when `embedding_lookup_after_opt = True`
{F1977309202}
(5) if use a separate stream for embedding lookup, so that the following `start_sparse_data_dist` can start immediately. however this causes extra memory consumption.
{F1977366363}
(6) if re-use the data_dist stream for embedding lookup, the following up `start_sparse_data_dist` will wait for embedding lookup to complete, the measured memory footprint is smaller
{F1977366349}
NOTE: Based on (5) and (6) we set `use_emb_lookup_stream = False` is the default behavior
# conclusions
* Based on a simple model (SparseNN), both "Fused Sparse Dist" pipeline and the "Semi Sync" pipeline are faster than the current default (commonly used) "Sparse Dist" pipeline, respectively -7% (fused sparse dist) and -10% (semi sync) in runtime.
* In a more realistic scenario, the optimizer step has a longer runtime footprint, which can amplify this optimization.
* The "Semi Sync" pipeline has a larger QPS win but it produces slightly different numerical training results, while the "Fused Sparse Dist" pipeline with a slight few QPS win should be numerically the same as the default pipeline.
* It would be the user's choice for which one to use.
# reference
* https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486
Reviewed By: dstaay-fb
Differential Revision: D64479105
fbshipit-source-id: c6bd1306823d02afafd7e2f1d9f95e1d068d6f611 parent f35befa commit f03bd79
File tree
4 files changed
+190
-2
lines changed- torchrec/distributed
- benchmark
- train_pipeline
4 files changed
+190
-2
lines changedLines changed: 2 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
42 | 42 | | |
43 | 43 | | |
44 | 44 | | |
| 45 | + | |
45 | 46 | | |
46 | 47 | | |
47 | 48 | | |
| |||
106 | 107 | | |
107 | 108 | | |
108 | 109 | | |
| 110 | + | |
109 | 111 | | |
110 | 112 | | |
111 | 113 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
15 | 15 | | |
16 | 16 | | |
17 | 17 | | |
| 18 | + | |
18 | 19 | | |
19 | 20 | | |
20 | 21 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
| 50 | + | |
50 | 51 | | |
51 | 52 | | |
52 | 53 | | |
| |||
540 | 541 | | |
541 | 542 | | |
542 | 543 | | |
543 | | - | |
544 | | - | |
| 544 | + | |
545 | 545 | | |
546 | 546 | | |
547 | 547 | | |
| |||
803 | 803 | | |
804 | 804 | | |
805 | 805 | | |
| 806 | + | |
| 807 | + | |
| 808 | + | |
| 809 | + | |
| 810 | + | |
| 811 | + | |
| 812 | + | |
| 813 | + | |
| 814 | + | |
| 815 | + | |
| 816 | + | |
| 817 | + | |
| 818 | + | |
| 819 | + | |
| 820 | + | |
| 821 | + | |
| 822 | + | |
| 823 | + | |
| 824 | + | |
| 825 | + | |
| 826 | + | |
| 827 | + | |
| 828 | + | |
| 829 | + | |
| 830 | + | |
| 831 | + | |
| 832 | + | |
| 833 | + | |
| 834 | + | |
| 835 | + | |
| 836 | + | |
| 837 | + | |
| 838 | + | |
| 839 | + | |
| 840 | + | |
| 841 | + | |
| 842 | + | |
| 843 | + | |
| 844 | + | |
| 845 | + | |
| 846 | + | |
| 847 | + | |
| 848 | + | |
| 849 | + | |
| 850 | + | |
| 851 | + | |
| 852 | + | |
| 853 | + | |
| 854 | + | |
| 855 | + | |
| 856 | + | |
| 857 | + | |
| 858 | + | |
| 859 | + | |
| 860 | + | |
| 861 | + | |
| 862 | + | |
| 863 | + | |
| 864 | + | |
| 865 | + | |
| 866 | + | |
| 867 | + | |
| 868 | + | |
| 869 | + | |
| 870 | + | |
| 871 | + | |
| 872 | + | |
| 873 | + | |
| 874 | + | |
| 875 | + | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
| 883 | + | |
| 884 | + | |
| 885 | + | |
| 886 | + | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
| 892 | + | |
| 893 | + | |
| 894 | + | |
| 895 | + | |
| 896 | + | |
| 897 | + | |
| 898 | + | |
| 899 | + | |
| 900 | + | |
| 901 | + | |
| 902 | + | |
| 903 | + | |
| 904 | + | |
| 905 | + | |
| 906 | + | |
| 907 | + | |
| 908 | + | |
| 909 | + | |
| 910 | + | |
| 911 | + | |
| 912 | + | |
| 913 | + | |
| 914 | + | |
| 915 | + | |
| 916 | + | |
| 917 | + | |
| 918 | + | |
| 919 | + | |
| 920 | + | |
| 921 | + | |
| 922 | + | |
| 923 | + | |
| 924 | + | |
| 925 | + | |
| 926 | + | |
| 927 | + | |
| 928 | + | |
| 929 | + | |
| 930 | + | |
| 931 | + | |
| 932 | + | |
| 933 | + | |
| 934 | + | |
| 935 | + | |
| 936 | + | |
| 937 | + | |
| 938 | + | |
| 939 | + | |
| 940 | + | |
| 941 | + | |
| 942 | + | |
| 943 | + | |
| 944 | + | |
| 945 | + | |
| 946 | + | |
| 947 | + | |
| 948 | + | |
| 949 | + | |
| 950 | + | |
| 951 | + | |
| 952 | + | |
| 953 | + | |
| 954 | + | |
| 955 | + | |
| 956 | + | |
| 957 | + | |
| 958 | + | |
| 959 | + | |
| 960 | + | |
| 961 | + | |
| 962 | + | |
| 963 | + | |
| 964 | + | |
| 965 | + | |
| 966 | + | |
| 967 | + | |
| 968 | + | |
| 969 | + | |
| 970 | + | |
| 971 | + | |
| 972 | + | |
| 973 | + | |
| 974 | + | |
| 975 | + | |
806 | 976 | | |
807 | 977 | | |
808 | 978 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
670 | 670 | | |
671 | 671 | | |
672 | 672 | | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
| 680 | + | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
| 684 | + | |
| 685 | + | |
| 686 | + | |
673 | 687 | | |
674 | 688 | | |
675 | 689 | | |
| |||
853 | 867 | | |
854 | 868 | | |
855 | 869 | | |
| 870 | + | |
856 | 871 | | |
857 | 872 | | |
858 | 873 | | |
| |||
0 commit comments