242242 < div class ="pytorch-left-menu-search ">
243243
244244 < div class ="version ">
245- < a href ='https://pytorch.org/docs/versions.html '> main (2.2.0a0+git5a96a42 ) ▼</ a >
245+ < a href ='https://pytorch.org/docs/versions.html '> main (2.2.0a0+gite7f12b1 ) ▼</ a >
246246 </ div >
247247
248248
@@ -529,6 +529,7 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
529529 < span class ="s1 "> 'set_warn_always'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'is_warn_always_enabled'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'SymInt'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'SymFloat'</ span > < span class ="p "> ,</ span >
530530 < span class ="s1 "> 'SymBool'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_not'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'unravel_index'</ span > < span class ="p "> ,</ span >
531531 < span class ="s1 "> 'sym_int'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_float'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_max'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_min'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'sym_ite'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'compile'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'vmap'</ span > < span class ="p "> ,</ span >
532+ < span class ="s1 "> 'sym_sqrt'</ span > < span class ="p "> ,</ span >
532533 < span class ="s1 "> 'export'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'autocast'</ span > < span class ="p "> ,</ span > < span class ="s1 "> 'cond'</ span > < span class ="p "> ,</ span >
533534< span class ="p "> ]</ span >
534535
@@ -887,8 +888,15 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
887888< span class ="sd "> Args:</ span >
888889< span class ="sd "> a (SymBool or bool): Object to negate</ span >
889890< span class ="sd "> """</ span >
891+ < span class ="kn "> import</ span > < span class ="nn "> sympy</ span >
892+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
893+
894+ < span class ="k "> if</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
895+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_not</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,),</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
890896 < span class ="k "> if</ span > < span class ="nb "> hasattr</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="s1 "> '__sym_not__'</ span > < span class ="p "> ):</ span >
891897 < span class ="k "> return</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> __sym_not__</ span > < span class ="p "> ()</ span >
898+ < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> sympy</ span > < span class ="o "> .</ span > < span class ="n "> Basic</ span > < span class ="p "> ):</ span >
899+ < span class ="k "> return</ span > < span class ="o "> ~</ span > < span class ="n "> a</ span > < span class ="c1 "> # type: ignore[operator]</ span >
892900 < span class ="k "> return</ span > < span class ="ow "> not</ span > < span class ="n "> a</ span > </ div >
893901
894902< div class ="viewcode-block " id ="sym_float "> < a class ="viewcode-back " href ="../generated/torch.sym_float.html#torch.sym_float "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sym_float</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
@@ -897,6 +905,10 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
897905< span class ="sd "> Args:</ span >
898906< span class ="sd "> a (SymInt, SymFloat, or object): Object to cast</ span >
899907< span class ="sd "> """</ span >
908+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
909+
910+ < span class ="k "> if</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
911+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_float</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,),</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
900912 < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> ):</ span >
901913 < span class ="k "> return</ span > < span class ="n "> a</ span >
902914 < span class ="k "> elif</ span > < span class ="nb "> hasattr</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="s1 "> '__sym_float__'</ span > < span class ="p "> ):</ span >
@@ -910,6 +922,10 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
910922< span class ="sd "> Args:</ span >
911923< span class ="sd "> a (SymInt, SymFloat, or object): Object to cast</ span >
912924< span class ="sd "> """</ span >
925+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
926+
927+ < span class ="k "> if</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
928+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_int</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,),</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
913929 < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> SymInt</ span > < span class ="p "> ):</ span >
914930 < span class ="k "> return</ span > < span class ="n "> a</ span >
915931 < span class ="k "> elif</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> ):</ span >
@@ -918,6 +934,10 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
918934
919935< div class ="viewcode-block " id ="sym_max "> < a class ="viewcode-back " href ="../generated/torch.sym_max.html#torch.sym_max "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sym_max</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> ):</ span >
920936< span class ="w "> </ span > < span class ="sd "> """ SymInt-aware utility for max()."""</ span >
937+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
938+
939+ < span class ="k "> if</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ((</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )):</ span >
940+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_max</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> ),</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )</ span >
921941 < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymInt</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> )):</ span >
922942 < span class ="k "> return</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> __sym_max__</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> )</ span >
923943 < span class ="k "> elif</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymInt</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> )):</ span >
@@ -929,13 +949,31 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
929949
930950< div class ="viewcode-block " id ="sym_min "> < a class ="viewcode-back " href ="../generated/torch.sym_min.html#torch.sym_min "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sym_min</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> ):</ span >
931951< span class ="w "> </ span > < span class ="sd "> """ SymInt-aware utility for max()."""</ span >
952+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
953+
954+ < span class ="k "> if</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ((</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )):</ span >
955+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_min</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> ),</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )</ span >
932956 < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymInt</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> )):</ span >
933957 < span class ="k "> return</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> __sym_min__</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> )</ span >
934958 < span class ="k "> elif</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymInt</ span > < span class ="p "> ,</ span > < span class ="n "> SymFloat</ span > < span class ="p "> )):</ span >
935959 < span class ="k "> return</ span > < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> __sym_min__</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
936960 < span class ="k "> return</ span > < span class ="n "> builtins</ span > < span class ="o "> .</ span > < span class ="n "> min</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="n "> b</ span > < span class ="p "> )</ span > < span class ="c1 "> # type: ignore[operator]</ span > </ div >
937961
962+ < span class ="c1 "> # Drop in replacement for math.sqrt</ span >
963+ < span class ="k "> def</ span > < span class ="nf "> sym_sqrt</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
964+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
965+
966+ < span class ="k "> if</ span > < span class ="n "> has_torch_function_unary</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ):</ span >
967+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_sqrt</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,),</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
968+ < span class ="k "> if</ span > < span class ="nb "> hasattr</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> ,</ span > < span class ="s2 "> "__sym_sqrt__"</ span > < span class ="p "> ):</ span >
969+ < span class ="k "> return</ span > < span class ="n "> a</ span > < span class ="o "> .</ span > < span class ="n "> __sym_sqrt__</ span > < span class ="p "> ()</ span >
970+ < span class ="k "> return</ span > < span class ="n "> math</ span > < span class ="o "> .</ span > < span class ="n "> sqrt</ span > < span class ="p "> (</ span > < span class ="n "> a</ span > < span class ="p "> )</ span >
971+
938972< div class ="viewcode-block " id ="sym_ite "> < a class ="viewcode-back " href ="../generated/torch.sym_ite.html#torch.sym_ite "> [docs]</ a > < span class ="k "> def</ span > < span class ="nf "> sym_ite</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> ):</ span >
973+ < span class ="kn "> from</ span > < span class ="nn "> .overrides</ span > < span class ="kn "> import</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ,</ span > < span class ="n "> handle_torch_function</ span >
974+
975+ < span class ="k "> if</ span > < span class ="n "> has_torch_function</ span > < span class ="p "> ((</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> )):</ span >
976+ < span class ="k "> return</ span > < span class ="n "> handle_torch_function</ span > < span class ="p "> (</ span > < span class ="n "> sym_ite</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> ),</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> )</ span >
939977 < span class ="k "> assert</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="p "> (</ span > < span class ="n "> SymBool</ span > < span class ="p "> ,</ span > < span class ="n "> builtins</ span > < span class ="o "> .</ span > < span class ="n "> bool</ span > < span class ="p "> ))</ span > < span class ="ow "> and</ span > < span class ="nb "> type</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> )</ span > < span class ="o "> ==</ span > < span class ="nb "> type</ span > < span class ="p "> (</ span > < span class ="n "> f</ span > < span class ="p "> )</ span >
940978 < span class ="k "> if</ span > < span class ="nb "> isinstance</ span > < span class ="p "> (</ span > < span class ="n "> b</ span > < span class ="p "> ,</ span > < span class ="n "> SymBool</ span > < span class ="p "> ):</ span >
941979 < span class ="k "> return</ span > < span class ="n "> b</ span > < span class ="o "> .</ span > < span class ="n "> __sym_ite__</ span > < span class ="p "> (</ span > < span class ="n "> t</ span > < span class ="p "> ,</ span > < span class ="n "> f</ span > < span class ="p "> )</ span >
0 commit comments