From 2c082d04f6d5eb0c8c0d9e8b59d287e88afa9798 Mon Sep 17 00:00:00 2001 From: Harshal Date: Sat, 15 Nov 2025 19:29:27 -0500 Subject: [PATCH 1/3] Add support for output dicts --- CMakeLists.txt | 1 - model_repository/dict_model/1/model.pt | Bin 0 -> 14949 bytes model_repository/dict_model/config.pbtxt | 24 +++++++++++++++++ src/model_instance_state.cc | 26 ++++++++++++++++-- src/model_instance_state.hh | 4 +++ test_client.py | 25 +++++++++++++++++ test_model.py | 33 +++++++++++++++++++++++ 7 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 model_repository/dict_model/1/model.pt create mode 100644 model_repository/dict_model/config.pbtxt create mode 100644 test_client.py create mode 100644 test_model.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b0e399..fac70e7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -166,7 +166,6 @@ set(PT_LIBS if (${TRITON_PYTORCH_NVSHMEM}) set(PT_LIBS ${PT_LIBS} - "libtorch_nvshmem.so" ) endif() # TRITON_PYTORCH_NVSHMEM diff --git a/model_repository/dict_model/1/model.pt b/model_repository/dict_model/1/model.pt new file mode 100644 index 0000000000000000000000000000000000000000..e01de04248728871644e12fba5b390820acdd50c GIT binary patch literal 14949 zcmdVB2{@JA_cwmb9CJd0q`6dPA@|x!r9p!Us7L38On>gjpZGko9g@AqE+>-z8O+Rky#z1CiP?fqG6?azH53v(f14o6&^ z^Oq*YQR1xhadGw1ba4uF($q3D?ZdJ8OUvJ_Ldsu@(L+oo?P@oqa7|m=bDqtuIe}=#D%<$u7#mr+3Qdatx*AEYIKT#OI(nvJL{2Kz zNDL+o?iWCEejF`-X-5-l!pMReC30s_GCB2GoSw7Lh8^9N#-~ftsex4%y^)Ygiv_~4 zV2vs@9nb-{^5Q@zB#tOON+oj)yh#doG0iRTqmtI~^pSKN*g8+9OHG`qg@ixZ7RJMO z*`C-vHv-0xximLtD_rc}L~z_kDD-r|nzc=k^CX5$?RNy7Z-x^UaSv)WZY|l+v4pg$ z3X!P==Ef%*Cea&gBNIR33;L;qkepX8WJ}{-`q5w+CSKB~q0={#YWqQ~&&p^pPQJ-| zx$PKiJ(EaHEwsq|YcRxP8*Bx^O?RI!*A>My3V5+Z2JniL`#w2IWJ6~Jz~l8o|*qK3gKbV{iOJ)UPw zX~$;h=lOzts1Qz`HFIgcqcCW`*g#rSr0A`P1S%F54hzl&5=#;k8bhzNZz8u9MbXEun-!ed4b>LmRMKz}jSdwhOO>;zt=)CJ zzI!S4|B{E1`gWvM!;91p;gTU=#(|D$7}?ykj;z`+4SY1-!ibl)Sm0I-K7)hEO4o&O z|D_>m8I?>QI#3t zILyPDnmDA9^3&;5uF8R4bdV%3$o0~Pkv?W`2`E$YPOmbTzGA&mHZiXfi)D|nVg^&ob`h>m~u2>Q+Prqu_va6^qD z(LAd{KIarc@V*;txL`89C^L{DBW;&VfY}`^v&m%|sUco? zuyZEyPKqP?UW3Vs*E{IDWHi!|Oi$z#AQ4xjMC-6oflMb?|W`l)i(>V!YB*rd_sHl~L(lRbt z7!m`+Dq_eby&cr7G?boR=!a(wchV`R^{Mvqbhuu*9U7W;k`;Wc>$cp2ZKt-9%O<0t zV{$D$Y`dEdHPhIetc{| zwis_A9nYrF`e!HD?V0Ih-@<8-By<9O%u;B?hs6-3u!l}EOrZBR`qNDwmwD%>E&-#4 z82Td5gkGLL92Dc!=)n@8r!`JtzPAzC)SXPs+@_Ir>vzz#!3n^tNG4+HyD3VHCT9h+ zX#VMVdQ&hPW^T*J$4ARSJ2sg-36h}CbOM+)i+9p0;V5$MdOR5{IGDOe&8IpajM zjUeeW`cRLXmwwbBU-_!?xDHwkF zpiNVk406qa;MOG6o;rsb=~PTvBqND?i+Vb%L9y5vP;&9CsRa1gWC$n~ktc zd^FD7l|l>Bcao|H&Sc8=O{DS8JK&vv!_Mde^pv>)Wv(i8oM#@)k%*;ZYtKMzWj*^= zX9v|!Fd|3$D??3GB7Kx7P41|L`8_o1zF#>fSXThbm*|1`;Gz|=% z!U%bA>AW|d#wYjh!fcNae0-z*DpcY<1w09KL>%2 zh||aSfw@(Fxazn9lv`YZ=nHXhuYVIP-f;`}2)=;Pb$W1h!ccT-kR=ukd+=hRCsawV z#j6GqG$L#{M1_sOrDv)zL$45bPS!)If^1x>FbH*w8gc*56i7HIOJ;4J51sG2VUn0W z+%l-Z>Xr>S>5xA}3Qb}(>bCIG^xW`D=^Gg4pN~3c^3k#I1ZFJvW&AHSK#0FJT(#W- z4bxwu^t!voUq+WeZv1VgJohQH&+imJpCT*U_Z`lZO=J?RBjN3fqZs;|4y$N+5mMdM zm?yr+0Ijt_eEfd=aMhFX*{Oy?GM8`~_XC@~WHS3!;T>EZd>dERMq%-jHK?IA1sn?v zu(NnP=GBPN^7mVr{M%h{bj@wBN;{4;I}ruCi}`Wz3*K(YGCt&R5szONr?0in;<*EN z(f!m@$k)w)19|&E{?;rk*185k!ar<04rn%(ztr+w#_9MCud>U-E>RbYvm=S#?Cj}$ZOVsQ~N#hVWLkb z%nF!+at4o#wcc)rDOUx_*^d{nc(fcD-#-Jr*F`}t=OWInmmvAPd+=N9KpZqm4ZA|G zq3^5dIF{Fr0WE`YR8Kx$bGO2lU~ z&l~cw;BX!4XWPP+H`|%mnI)L4CQUCs(gKsBBe14%04vyZ3%{srVxEsphQ3m1a7bYy zWVhYK?b}}A-n4_L87V}U@0$&8lH_1%TrOi>R{+bn12IvJixUl_LE*hVI0-iJ&PSYq z6K|y0_tX`+HkC|)ts}#Zsbf#6zCph;#kjQV0yGcl&t^|~4;ootnNgRLFfFSV;to#- z**j9Sea1EPpYDT^IxjG5dMe9SmV<8RYHU(JirWi{CO(Mm4~ee6*ihPyJ2ee?tH-o} zohFBV2{;1z>$>sex+=ElVg>FB9Dv&z$G}C_g^3P~gf)}>SV@83@Vv}%_@Li{B{OX? z;gu`nGVu*cJ~aTnl&e@jA_48yGuaaLGZ=Bsfzi-kk4x?ffV`_PdHVSVYuI0uo*b-% zEr~oNX9`)p@={prE=NRO=OIt_6>r;!co@?B2~MSiGbQU6qO+v{hN~<Qh4L`=cz&Sei@&3(D&{#Bu`C8u1n9h(g7VCcv z+>YJ>&&a_rG2|*t+r1NBjoQVS>1lwR%_aU`m7)=imH2RwA5WbP#g~s&VQB1N@Wqm`C=eQl zwT@Ec&9rVLN58t)r?bYj%^-v zn=fKUI^X4m`;KS&%-RdrR0WAvz)2{sj58j_9m?vKi4oUNi*dV)H9Wkwl)aP2V=QM3 zMsE3HX5}COR(j9Q6U(5+Zo(1Qvlpy!pg&o>Rax7yzB<-*HPBMUE;{k4GLJ&coJm#aL9<~q2N{Q$QH=V!$!?a zusgdLoUemNrU1+6N?dJtof#ez&G>a&U}tj#ZuVUTeHHfLD}?}b(|ZguTZHJAJ_gJysa#{aCZ?O@}Wj8;wOS?`Dp@EiUT9hORx(bsG7w9o_4NaGipI+T4Laf96wFG*%KN|OmC z5_no=Z`tKLf^2@ZD1EzUB`zr1jpE+vs3^`MS{0YrYBmOP`?!GTkyp@_R9vBLAQA+knwD_m%~iG!aM;TBF6PU8KBJS$71nUg=T>IDV(azQrie;`dNFNead zTjy|isT4W0{Sa7e>;yI{2Asa+0CDVtrt(i?xNZ^@JeZ&(R0%=hXP!)r+vj zF&aPF8$)oAFiwzsgVndvLC*Xbj9NXNRcda6UAu)yIW{8u{0l2IEDZI!Wr(slzkYe) z4uY4T;gXq&Jd2^XvCZ%_+PAjgO7{!swpY^lSsM$-7!HZOzaM0RB(O+;;Qg#wU=M9j zMV;7-eK+C5q6hfmOC_pE9At`0XP~CMIfSmHm=u}}4mpje!Im=8M+A` zhWh;N$g2@2ZZ#Js&Tn1CXbQ{H>J(AZRmFkYtSro3EkHM&^D`XR?ucGvb6JD;@3605 zH0sxIXmxoDnmX@5sWo_boIzzUs zHR=+6X)I<9XUJifT|4aQe-Ko44uQkbAQYD$4~J*P~G1IUeeHyhG%^OOv zao!7-qo2e!Uwi<46c&JA`fI%RNeI{F?tqQg1&CegF7|<01>Un2C0cNhF;kul_u5{g zslqFW+i3;1>IH1s0t+~BbJ4_2HlkGT0Y4_yf59h$0yNY5BdDY_V0gVJI#*@kXrz#< z*cbc)A}4ZBrD38;8XBy3M9WGMdTeYOGCDzUzcC4tb7YClg*a5-KZUt4^d@^E?KAf2 z%42uiKf>pWglJ-ADK1(Zi}%HZNQ?LvsJtqHEe$MGrfoqeJH*^jyJFlwdj*<4{|pt( zb5vie2wm^Cp-N^41b51k!(%$}aeW80biTnHK5iFm7e|?NRe0Oy6f#|NaAmU-x+}iK z3y-?5mDg}xZ#d8eZO`zZ)Rf>&bGP@DIwyxv2Givcidfryu$rlNPm0p{aYE?vgw{HjBI zpd+~oQ-+;o?u_UJ^RQd&)^Itxu%n%)5-3D9=~0|}_A{i^3DYMPJK@;A2wb=FE3OEa zB3ByyA?Q}WvK=}1*%a}+5P0D=w#0CVL-&20x%npCIk^?j1gFD|8|NXaP>?i?kT6bO zxq(d&+m7vzk*Tsh1Of7bV75Y)7rbZ-DBDCM+xi?g-Vz{TifeeQCL^rw7lt2HW7)7U zDQfxN0rnq{1!MPxyv^nkFynSQ3P2;iTDAjEIP#362Jv~xU>3LPb>pV3qV(*;G>kbs z1^cahh%4`=F$tDNxOn+6Sgj~PR}{^InL_oyfOpI~%giu`Y zz82ritpqk-52LqUh4Up=#;)$RU=Vp1Rxe(Sup<(;wzr{7%* zAhzDw$*lc&0*8`7_JY$jqsYE#_)u{HYCIK(m>f}h(5)W05em>qWXT?hBe-nf3*NKx ztJr=b3ib(iK<3?8m|y;dr*ppm96!tOY(Gd4oAX&vw@#nUYT5CD6H^`IXDaC+4%d@B~o*vPMB3s;O`9Vae^j~eG;>DQ0E z=E}D?DK{FMIKyGt#LJMb*#fm!yo`C5W$ConniExbhhn;C8Msc8qQ~>A*~=lNIBB5_ zIVZ0S)BT2F!pSoB_CsUPG)TaW8&9Ls`_fbHokw8kdUZ9zy&t%GSI6eK}c!80!HB)1m(t zyny_1e$ICiC*t|TVC|qr6i5>wT7z?NL&0pY$hgAV^Lbw3{Ta}qH3LpP6#+*_Rd!Zg zJM6tD$Qan?!G1;?HeC~=R+rxa9TSZ`JjPpxid_~x9L9g(F@EBI^!{i3&yvh8czk&^ z>f64DNl%L*BRm)HURcB|FzLfec#L9$Th;MUOd3qrR)oep4LFd^&r>lEAm1bpwmKp1 zR4>QJHi5i`4R6tBVkjQ6-UIVmKf@^hp|IcJ8Mp~dgn$#ZFege3?p7M$ljw3hygZi~ zF!(4)y*mV!5n_ZHwH2PeO=X+g?m*1yw{YM80Dj!@7H(XO#oscu(5|TrV%GfzH?Q3W z!*&VO*m(>GoD-*=_v?^ne-Jz(@8HxfDWcEK#@bsUyn;x@i3`nTsjt*UlwR$EhnESH z#r#@eg!KjHT)r?(zamQ1mJ~wZob?zrPY5=z2*iZR3|Q0(;_5CNEMekNo42QIYFaGb zyL=wDKNca1T{(~vxeI&;Jcn^ovP4h680?PU0n622VRPJJ$Qn?Op7*$9TCxZYPvj7% zt`!iF7z}T|zQf`+cizPFSCH_}6wQT_#9(wPta)|_YL6A;m<>WCOt}N+9Q}$%EFLls zR``R1x(4i?6NFnWc_0vJhYpTQ7^8cF)Ya%1obr{TS2Q1iT1=I3xmy^>oL0nvyYFM9 zSPmYl$OfVm0DDV>X-nHHC}}*7>mP~H#nJ8fcAF$ETJr#{L%H-w(@Ul`(*;g(Z87zc zADFGW1*R>#*+a=MjcwxR!|kA##tvgfLP7)!H)_Y@RXrYbxJ5wh`3h#KlnDJG%B4dh z*I{^wJ+yO2Kmxgh%f{#8`J>HfF<=|Y7(QZ8=G?&<4>HiCq7-Iih(dA7VYqgWLoQ8{ zp#pYS;n{uuJ(Ab`;F=48l5}Zm*_6o?>>OMs{Dn(FFBB~|L}IL(HB5OZM%6#;fuYqi zuwKgybI;l0+0?zfXX&ZXmtO-IhjHjZKIXdTIkS7RMc}nWA}SoLVLT^kfY$g}43*A9 zugMS4BiM;eI9~=z+&0jiwH!E;3&DCwGH;`pzOm$?w>WFc6|}A3Lal`u<;b@}P7aGf ztHlU;oCME4a``z~9W92BgJ8YS_%Vk==UN@ZBsQCwA%7h0zkY^`Z3!?}a4zVG-iJq~ z$!yF8WRfD5@I+t6;|b<7@A&9EW=6^#s5Yr$Jf6hBnBfsnb4G&Pv(MtaIUbHigDFFc zrRf`qqo@@kuKunrJ^tYO zNsX}#wDgm<;y-7=S?1vs;3wnGAKP8b{kv{$KUsIKxei~Y+&7hcg?x4H-=*Z;1!CO= z=dN1j*XMV~==dr8XUAwb{8*HgPTp=_t`1s%xTn@9Ys8J$1nI4o%Hlf*ejy@4--C z{&OOKeNaT1GwgQ^IrBHY<`2g^X!r)3nGXI>VA`{Tx1wCDZ`3=c@@%4X;gLO;YNcXIpHfA%e+U#dNd+hrmb9qJajbz^qcIma7 zFKTQ)nk2jGOw8Q0dvP#o_Kv8mX4;PK`4N4WM9{nOS>*mh50iR;(Q^#{h4ryHnG zeNF{>TrZkF8ZLaqyE%M|X#W;JamO6<@=0=021nlvd$QC= z^?p|4iEtbFtnh|oW?4!#(~|V8Yi?hz8d*Qd=kBY4pA%JKGcWI2#F~;LT2h_t)k`%r z^{%R^yRV7=C+ClbDlN6qFWuLyo>TSZQq$O2uX)uo2B_wSXX`CDeD!+hE0LgMM&`5Z zjLiFf4+tTNXDXUnGeI1eawdQ(Ys+jzSd!b|6Hk|2u z`FY^Hxw2Pfo@AW&FjS~XIG9|d_VH9sqB1=28#%7@VP*cU(mT()pLtbB8NA=s|5!q- z-WF@Y0f+ks?u#EM@pOI2n2~zRxrAyB%&$f0NLxZ`(-&*QIM$G@(fCOFzj<_XjIe2)Fs>R91*Jh=zXhMYsmD&PO>=vF9Fpznp)!>_&%uj|*)dof>+LpDyhAZv5xSW5f~rZ(Z0M z7D7GF#GU5z?rZgpyp#0m0+U>v?&?;Ta!@z0-`3@v*yP!h)pxxRwf0^XbM~ln)yhQ@ z!>u0p+#enqZ#8u2v@-v~02=Uom;cR*77l0KsO)=hQ_=@*lfCeGP2d`K-uWiu>AV4f zM?8h@oLR8Mv7nEr^NM3B-1Z#Fk#4@{XP@6ylj*v0Uj!X4{6`SpVJkz02?`-G}B zr(EysjJnyCP}+2HllIO&Bh^#6?@WEGbN%-1v48QjYHz3V#hScW=1zq{=jFl`)vCe+ zW#;N+J#1?(N<2?))2O@`U43X#VvFWk<9O?G0m&5J$ z41RQSbK*puq9U<*y8)x#-91w5SU={=VbOP2>s4uPhe@W1)ff4&BkVn144E0crbEea zx4*o+#^Ki8Z@O}>1V=Z@3yQaW46tt2ZC+6^O-$cH;dpJI6(%=aOTrh|4G^p}P%XRs zMp-!OYm?E62j#ELJ{QD4Eh_!W-vDCbdJ2VOv-z80sK~!v-zf7TXLx?IuWfiAe6m@36sN*iVx@|;;845P+sv(To?dI+&87^g>^!%XcjL50 zlDe4C`lFKe2?-y*uIj6E<~|B2v<2RHkKd+uM|HSv;oTO?udh|Nu68NvU(@7v zBNJT?N_)MDw~w(us5pODZu$aS?LIrl)4pWK*5+mXuO1rkyq<6pTw88UTOUGioxRkw zLq=qbs78EeUb^zj7r_^ft4lMZj&@rPcKVT#nYLB`(sv((P2 zZyY#EU{ggO`7zga1?+jfC!lEOr;*(PoLk2IWll_Q`|KMy>55r-`TY=`+jos_H)KXE z9oR0KaZ1p6msY<8iIoc{ne5Ruid?Wbgl?837Elpv&?Sx$f0(9mjBoU`P`UE?4nuZagXHneuHyIr1q{EJxEsC<)Qa!Cg3mFq^0 zk@aa7tWwYzdhUQCQ=HL#t+ZXscq%8%ZsYhZ>qK+5o@&n&U9nH(x@kvFSxi!`j?LuE z_vfdqF71vwC{|oqP{_l2<_BZ&B{kQ${AMl%H z{r?aAGA0hsS|4=Ghz{>7qFv zKWZNPpfGkLhZD42?91UHM#lY~C)VF;L47#UV?N=Q^uGTTdIms`d^Wj@yeBf(C>jb9b-j>_lLLU{C=^d7>tt@}<8aKmw-bRnTCwC803!PMWI;j%AB%0X;S_tXasiaOnTvZmTRM*GG ztOvd)`S6{zv3OVZ;=89kE3STnuWS!|^GxDSX7z&a-~V7hq1S1|J0Q@>J22ooqw;U( zk5Iue-_DvVH6@nx5ClE_J72^3>rs=V$T9joLh$(%XIB>o4{tu%643MG1pWU8^?ycz zf98giIg)=j*7qri{TPkU|9od!R(J$zIQV)vdwRKY{at%V6<3GAU|;?pCpSJ3#HWFJ z$d=_k{yk)mpMX1O>FEE=(Fp#+(Fpx5Md^&ImB{-9`L|Cyo@ z|3=YBxLQf@ZJg~95NJEuJ9w!dH_CFARS$#KL*&>71p1Ft`{T={ev(m^ep25S&#zgv zT!6nndip;*RGPoI-%mV2uKodhp2*CU!_npTY_Pfvw?+KZ5~rvCvjSPpEW2nOyS2${JEz8%{woTz(6n89vdRze*Yc@t*8I94Jv;yh8o{~78Bs=@8RU-vDT@F z_;T=Y`EU6{EnR~NIuo?@we+-fjf@Np^tB8pXib=)t3QGNM&CeJ$G}iWPus|_$HIjR zSKb)Sx2~sKm<#kc>yKaV^nPYns^^QJ+IP2q{~gb-lD(f={aI3tkHB7%+P_NnzOVJO zWY0e8C8_hPWbZpnKTG=d|4CB!SIOQS^3RfL1Ada!`&F_xBmA@E8r7d9^?#M@&4bDO zw_JkMDc$XTSc_9RBtY`0K2C z7qI3(4QHtS|6sU(o@eix?dN&=kNaN^+*7ChI?vvj{^xnVAOHVgxL@b_XZ-4&{}q06 zu(sBp1NSN`zFW?}k{_ggZHhmAu=5WEzZ4BWRQw>dd#ACVvC7f@VZ%=<{sjD=6X~7G zq<^wwzRnLNzvsVS$M2ni_>T(oY6*Xc{*n2974DrJRCIEn2UV7N|9fx^<4k`+W!Oa CGV!JW literal 0 HcmV?d00001 diff --git a/model_repository/dict_model/config.pbtxt b/model_repository/dict_model/config.pbtxt new file mode 100644 index 0000000..c39e708 --- /dev/null +++ b/model_repository/dict_model/config.pbtxt @@ -0,0 +1,24 @@ +name: "dict_model" +platform: "pytorch_libtorch" +max_batch_size: 8 + +input [ + { + name: "INPUT__0" + data_type: TYPE_FP32 + dims: [ 10 ] + } +] + +output [ + { + name: "logits" + data_type: TYPE_FP32 + dims: [ 20 ] + }, + { + name: "embeddings" + data_type: TYPE_FP32 + dims: [ 5 ] + } +] diff --git a/src/model_instance_state.cc b/src/model_instance_state.cc index d634f3b..de2603c 100644 --- a/src/model_instance_state.cc +++ b/src/model_instance_state.cc @@ -51,7 +51,7 @@ namespace triton::backend::pytorch { ModelInstanceState::ModelInstanceState( ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) : BackendModelInstance(model_state, triton_model_instance), - model_state_(model_state), device_(torch::kCPU), is_dict_input_(false), + model_state_(model_state), device_(torch::kCPU), is_dict_input_(false), is_dict_output_(false), device_cnt_(0) { if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { @@ -345,6 +345,18 @@ ModelInstanceState::Execute( list_output.elementType()->str() + "]"); } output_tensors->push_back(model_outputs_); + } else if (model_outputs_.isGenericDict()) { + is_dict_output_ = true; + auto dict_output = model_outputs_.toGenericDict(); + output_dict_key_to_index_.clear(); + + int index = 0; + for (auto it = dict_output.begin(); it != dict_output.end(); ++it) { + std::string key = it->key().toStringRef(); + output_tensors->push_back(it->value()); + output_dict_key_to_index_[key] = index; + index++; + } } else { throw std::invalid_argument( "output must be of type Tensor, List[str] or Tuple containing one of " @@ -872,7 +884,17 @@ ModelInstanceState::ReadOutputTensors( // The serialized string buffer must be valid until output copies are done std::vector> string_buffer; for (auto& output : model_state_->ModelOutputs()) { - int op_index = output_index_map_[output.first]; + // Use dict key mapping if available + int op_index; + if (is_dict_output_) { + auto it = output_dict_key_to_index_.find(output.first); + if (it == output_dict_key_to_index_.end()) { + continue; // Skip outputs not in dict + } + op_index = it->second; + } else { + op_index = output_index_map_[output.first]; + } auto name = output.first; auto output_tensor_pair = output.second; diff --git a/src/model_instance_state.hh b/src/model_instance_state.hh index b495510..143e4d5 100644 --- a/src/model_instance_state.hh +++ b/src/model_instance_state.hh @@ -73,6 +73,10 @@ class ModelInstanceState : public BackendModelInstance { // Map from configuration name for an output to the index of // that output in the model. std::unordered_map output_index_map_; + + // If the output is a dictionary of tensors. + std::unordered_map output_dict_key_to_index_; + bool is_dict_output_; std::unordered_map output_dtype_map_; // If the input to the tensor is a dictionary of tensors. diff --git a/test_client.py b/test_client.py new file mode 100644 index 0000000..f11cb15 --- /dev/null +++ b/test_client.py @@ -0,0 +1,25 @@ +# test_client.py +import tritonclient.http as httpclient +import numpy as np + +# Create client +client = httpclient.InferenceServerClient(url="localhost:8000") + +# Prepare input +input_data = np.random.randn(5, 10).astype(np.float32) +inputs = [httpclient.InferInput("INPUT__0", input_data.shape, "FP32")] +inputs[0].set_data_from_numpy(input_data) + +# Request outputs by dict key names +outputs = [ + httpclient.InferRequestedOutput("logits"), + httpclient.InferRequestedOutput("embeddings") +] + +# Infer +results = client.infer("dict_model", inputs, outputs=outputs) + +# Check output names +print("Output names:", results.get_response()) +print("Logits shape:", results.as_numpy("logits").shape) +print("Embeddings shape:", results.as_numpy("embeddings").shape) diff --git a/test_model.py b/test_model.py new file mode 100644 index 0000000..e2f5b68 --- /dev/null +++ b/test_model.py @@ -0,0 +1,33 @@ +# test_model.py +import torch +import torch.nn as nn + +class DictOutputModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(10, 50) + self.fc2 = nn.Linear(50, 20) + self.fc3 = nn.Linear(50, 5) + + def forward(self, x): + features = self.fc1(x) + logits = self.fc2(features) + embeddings = self.fc3(features) + + # Return dictionary + return { + "logits": logits, + "embeddings": embeddings + } + +# Create and save model +model = DictOutputModel() +model.eval() + +# Trace with example input +example_input = torch.randn(1, 10) +traced_model = torch.jit.trace(model, example_input, strict=False) + +# Save +torch.jit.save(traced_model, "model.pt") +print("Model saved!") From e8cf1e5f37577d43caf9514af56615c9ec09321a Mon Sep 17 00:00:00 2001 From: Harshal Date: Sun, 16 Nov 2025 13:42:18 -0500 Subject: [PATCH 2/3] Add cached output validation --- src/model_instance_state.cc | 73 ++++++++++++++++++++++++++++--------- src/model_instance_state.hh | 13 ++++++- 2 files changed, 67 insertions(+), 19 deletions(-) diff --git a/src/model_instance_state.cc b/src/model_instance_state.cc index de2603c..effffe9 100644 --- a/src/model_instance_state.cc +++ b/src/model_instance_state.cc @@ -51,7 +51,8 @@ namespace triton::backend::pytorch { ModelInstanceState::ModelInstanceState( ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) : BackendModelInstance(model_state, triton_model_instance), - model_state_(model_state), device_(torch::kCPU), is_dict_input_(false), is_dict_output_(false), + model_state_(model_state), device_(torch::kCPU), is_dict_input_(false) + dict_output_validated_(false), device_cnt_(0) { if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) { @@ -149,6 +150,47 @@ ModelInstanceState::ModelInstanceState( THROW_IF_BACKEND_INSTANCE_ERROR(ValidateOutputs()); } +TRITONSERVER_Error* +ModelInstanceState::ValidateAndCacheDictOutput( + const c10::Dict& dict_output) +{ + if (dict_output_validated_.load(std::memory_order_acquire)) { + return nullptr; + } + std::lock_guard lock(dict_validation_mutex_); + if (dict_output_validated_.load(std::memory_order_acquire)) { + return nullptr; + } + if (dict_output.size() == 0) { + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, "Empty dict"); + } + std::vector temp_keys; + std::unordered_map temp_index; + size_t idx = 0; + for (auto it = dict_output.begin(); it != dict_output.end(); ++it) { + std::string key = it->key().toStringRef(); + if (!it->value().isTensor()) { + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, "Not tensor"); + } + temp_keys.push_back(key); + temp_index[key] = idx++; + } + std::vector missing; + for (auto& output : model_state_->ModelOutputs()) { + if (temp_index.find(output.first) == temp_index.end()) { + missing.push_back(output.first); + } + } + if (!missing.empty()) { + return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG, "Missing keys"); + } + output_dict_keys_ = std::move(temp_keys); + output_dict_key_to_index_ = std::move(temp_index); + dict_output_validated_.store(true, std::memory_order_release); + return nullptr; +} + + ModelInstanceState::~ModelInstanceState() { torch_model_.reset(); @@ -346,16 +388,16 @@ ModelInstanceState::Execute( } output_tensors->push_back(model_outputs_); } else if (model_outputs_.isGenericDict()) { - is_dict_output_ = true; auto dict_output = model_outputs_.toGenericDict(); - output_dict_key_to_index_.clear(); - - int index = 0; - for (auto it = dict_output.begin(); it != dict_output.end(); ++it) { - std::string key = it->key().toStringRef(); - output_tensors->push_back(it->value()); - output_dict_key_to_index_[key] = index; - index++; + if (!dict_output_validated_.load(std::memory_order_acquire)) { + TRITONSERVER_Error* err = ValidateAndCacheDictOutput(dict_output); + if (err != nullptr) { + SendErrorForResponses(responses, request_count, err); + return; + } + } + for (const auto& key : output_dict_keys_) { + output_tensors->push_back(dict_output.at(key)); } } else { throw std::invalid_argument( @@ -885,15 +927,12 @@ ModelInstanceState::ReadOutputTensors( std::vector> string_buffer; for (auto& output : model_state_->ModelOutputs()) { // Use dict key mapping if available - int op_index; - if (is_dict_output_) { + int op_index = output_index_map_[output.first]; + if (dict_output_validated_.load(std::memory_order_acquire)) { auto it = output_dict_key_to_index_.find(output.first); - if (it == output_dict_key_to_index_.end()) { - continue; // Skip outputs not in dict + if (it != output_dict_key_to_index_.end()) { + op_index = it->second; } - op_index = it->second; - } else { - op_index = output_index_map_[output.first]; } auto name = output.first; auto output_tensor_pair = output.second; diff --git a/src/model_instance_state.hh b/src/model_instance_state.hh index 143e4d5..092fa46 100644 --- a/src/model_instance_state.hh +++ b/src/model_instance_state.hh @@ -26,6 +26,9 @@ #pragma once +#include +#include + #include #include @@ -75,8 +78,11 @@ class ModelInstanceState : public BackendModelInstance { std::unordered_map output_index_map_; // If the output is a dictionary of tensors. - std::unordered_map output_dict_key_to_index_; - bool is_dict_output_; + std::atomic dict_output_validated_; + std::mutex dict_validation_mutex_; + std::vector output_dict_keys_; + std::unordered_map output_dict_key_to_index_; + std::unordered_map output_dtype_map_; // If the input to the tensor is a dictionary of tensors. @@ -96,6 +102,9 @@ class ModelInstanceState : public BackendModelInstance { int device_cnt_; public: + TRITONSERVER_Error* ValidateAndCacheDictOutput( + const c10::Dict& dict_output); + virtual ~ModelInstanceState(); // Clear CUDA cache From c527344722878d83c6d91aa9a22ebc1b78221c7e Mon Sep 17 00:00:00 2001 From: Harshal Chaudhari Date: Sun, 16 Nov 2025 17:00:21 -0500 Subject: [PATCH 3/3] Fix formatting --- src/model_instance_state.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model_instance_state.cc b/src/model_instance_state.cc index effffe9..f738980 100644 --- a/src/model_instance_state.cc +++ b/src/model_instance_state.cc @@ -51,7 +51,7 @@ namespace triton::backend::pytorch { ModelInstanceState::ModelInstanceState( ModelState* model_state, TRITONBACKEND_ModelInstance* triton_model_instance) : BackendModelInstance(model_state, triton_model_instance), - model_state_(model_state), device_(torch::kCPU), is_dict_input_(false) + model_state_(model_state), device_(torch::kCPU), is_dict_input_(false), dict_output_validated_(false), device_cnt_(0) {