184
184
< div class ="pytorch-left-menu-search ">
185
185
186
186
< div class ="version ">
187
- < a href ='https://pytorch.org/docs/versions.html '> master (1.9.0a0+git33aed5f ) ▼</ a >
187
+ < a href ='https://pytorch.org/docs/versions.html '> master (1.9.0a0+git435d257 ) ▼</ a >
188
188
</ div >
189
189
190
190
377
377
378
378
< h1 > Source code for torch._utils</ h1 > < div class ="highlight "> < pre >
379
379
< span > </ span > < span class ="kn "> import</ span > < span class ="nn "> torch</ span >
380
- < span class ="kn "> from</ span > < span class ="nn "> typing</ span > < span class ="kn "> import</ span > < span class ="n "> Optional</ span > < span class ="p "> ,</ span > < span class ="n "> List</ span > < span class ="p "> ,</ span > < span class ="n "> DefaultDict</ span >
380
+ < span class ="kn "> from</ span > < span class ="nn "> typing</ span > < span class ="kn "> import</ span > < span class ="n "> Optional</ span > < span class ="p "> ,</ span > < span class ="n "> List</ span > < span class ="p "> ,</ span > < span class ="n "> DefaultDict</ span > < span class =" p " > , </ span > < span class =" n " > Any </ span >
381
381
< span class ="kn "> import</ span > < span class ="nn "> warnings</ span >
382
382
< span class ="kn "> from</ span > < span class ="nn "> collections</ span > < span class ="kn "> import</ span > < span class ="n "> defaultdict</ span >
383
383
< span class ="kn "> import</ span > < span class ="nn "> sys</ span >
@@ -841,8 +841,17 @@ <h1>Source code for torch._utils</h1><div class="highlight"><pre>
841
841
< span class ="c1 "> # all device properties</ span >
842
842
< span class ="k "> return</ span > < span class ="p "> [</ span > < span class ="n "> _get_device_attr</ span > < span class ="p "> (</ span > < span class ="k "> lambda</ span > < span class ="n "> m</ span > < span class ="p "> :</ span > < span class ="n "> m</ span > < span class ="o "> .</ span > < span class ="n "> get_device_properties</ span > < span class ="p "> (</ span > < span class ="n "> i</ span > < span class ="p "> ))</ span > < span class ="k "> for</ span > < span class ="n "> i</ span > < span class ="ow "> in</ span > < span class ="n "> device_ids</ span > < span class ="p "> ]</ span >
843
843
844
+ < span class ="k "> def</ span > < span class ="nf "> get_current_device_index</ span > < span class ="p "> ()</ span > < span class ="o "> -></ span > < span class ="nb "> int</ span > < span class ="p "> :</ span >
845
+ < span class ="sa "> r</ span > < span class ="sd "> """Checks if there are CUDA devices available and</ span >
846
+ < span class ="sd "> returns the device index of the current default CUDA device.</ span >
847
+ < span class ="sd "> Returns -1 in case there are no CUDA devices available.</ span >
848
+ < span class ="sd "> Arguments: ``None``</ span >
849
+ < span class ="sd "> """</ span >
850
+ < span class ="k "> if</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> cuda</ span > < span class ="o "> .</ span > < span class ="n "> device_count</ span > < span class ="p "> ()</ span > < span class ="o "> ></ span > < span class ="mi "> 0</ span > < span class ="p "> :</ span >
851
+ < span class ="k "> return</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> cuda</ span > < span class ="o "> .</ span > < span class ="n "> current_device</ span > < span class ="p "> ()</ span >
852
+ < span class ="k "> return</ span > < span class ="o "> -</ span > < span class ="mi "> 1</ span >
844
853
845
- < span class ="k "> def</ span > < span class ="nf "> _get_device_index</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="n "> optional</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> allow_cpu</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> int</ span > < span class ="p "> :</ span >
854
+ < span class ="k "> def</ span > < span class ="nf "> _get_device_index</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> : </ span > < span class =" n " > Any </ span > < span class =" p " > ,</ span > < span class ="n "> optional</ span > < span class ="p " > : </ span > < span class =" nb " > bool </ span > < span class =" o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> allow_cpu</ span > < span class ="p " > : </ span > < span class =" nb " > bool </ span > < span class =" o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> int</ span > < span class ="p "> :</ span >
846
855
< span class ="sa "> r</ span > < span class ="sd "> """Gets the device index from :attr:`device`, which can be a torch.device</ span >
847
856
< span class ="sd "> object, a Python integer, or ``None``.</ span >
848
857
@@ -860,8 +869,7 @@ <h1>Source code for torch._utils</h1><div class="highlight"><pre>
860
869
< span class ="sd "> """</ span >
861
870
< span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="nb "> str</ span > < span class ="p "> ):</ span >
862
871
< span class ="n "> device</ span > < span class ="o "> =</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> )</ span >
863
- < span class ="n "> device_idx</ span > < span class ="p "> :</ span > < span class ="n "> Optional</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ]</ span >
864
- < span class ="n "> device_idx</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
872
+ < span class ="n "> device_idx</ span > < span class ="p "> :</ span > < span class ="n "> Optional</ span > < span class ="p "> [</ span > < span class ="nb "> int</ span > < span class ="p "> ]</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
865
873
< span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ,</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> device</ span > < span class ="p "> ):</ span >
866
874
< span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> allow_cpu</ span > < span class ="ow "> and</ span > < span class ="n "> device</ span > < span class ="o "> .</ span > < span class ="n "> type</ span > < span class ="o "> ==</ span > < span class ="s1 "> 'cpu'</ span > < span class ="p "> :</ span >
867
875
< span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="s1 "> 'Expected a non cpu device, but got: </ span > < span class ="si "> {}</ span > < span class ="s1 "> '</ span > < span class ="o "> .</ span > < span class ="n "> format</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ))</ span >
@@ -870,7 +878,15 @@ <h1>Source code for torch._utils</h1><div class="highlight"><pre>
870
878
< span class ="n "> device_idx</ span > < span class ="o "> =</ span > < span class ="n "> device</ span >
871
879
< span class ="k "> if</ span > < span class ="n "> device_idx</ span > < span class ="ow "> is</ span > < span class ="kc "> None</ span > < span class ="p "> :</ span >
872
880
< span class ="k "> if</ span > < span class ="n "> optional</ span > < span class ="p "> :</ span >
873
- < span class ="n "> device_idx</ span > < span class ="o "> =</ span > < span class ="n "> _get_current_device_index</ span > < span class ="p "> ()</ span >
881
+ < span class ="c1 "> # The eager API _get_current_device_index uses `lambda` functions which are</ span >
882
+ < span class ="c1 "> # not supported in JIT and hence not scriptable. The JIT equivalent API to get</ span >
883
+ < span class ="c1 "> # the current device index is `get_current_device_index()` which can</ span >
884
+ < span class ="c1 "> # be scripted. We use is_scripting to check the mode we are in and call the</ span >
885
+ < span class ="c1 "> # appropriate API.</ span >
886
+ < span class ="k "> if</ span > < span class ="n "> torch</ span > < span class ="o "> .</ span > < span class ="n "> jit</ span > < span class ="o "> .</ span > < span class ="n "> is_scripting</ span > < span class ="p "> ():</ span >
887
+ < span class ="n "> device_idx</ span > < span class ="o "> =</ span > < span class ="n "> get_current_device_index</ span > < span class ="p "> ()</ span >
888
+ < span class ="k "> else</ span > < span class ="p "> :</ span >
889
+ < span class ="n "> device_idx</ span > < span class ="o "> =</ span > < span class ="n "> _get_current_device_index</ span > < span class ="p "> ()</ span >
874
890
< span class ="k "> else</ span > < span class ="p "> :</ span >
875
891
< span class ="k "> raise</ span > < span class ="ne "> ValueError</ span > < span class ="p "> (</ span > < span class ="s1 "> 'Expected a torch.device with a specified index '</ span >
876
892
< span class ="s1 "> 'or an integer, but got:</ span > < span class ="si "> {}</ span > < span class ="s1 "> '</ span > < span class ="o "> .</ span > < span class ="n "> format</ span > < span class ="p "> (</ span > < span class ="n "> device</ span > < span class ="p "> ))</ span >
0 commit comments