From 41b26b7fd9cde28890501eb272609e7325fb7cd0 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Wed, 22 Oct 2025 19:38:07 -0400 Subject: [PATCH 01/19] latest local updates --- checkpoints/epoch=0-step=2-v1.ckpt | Bin 0 -> 4573 bytes checkpoints/epoch=0-step=2-v2.ckpt | Bin 0 -> 4573 bytes checkpoints/epoch=0-step=2-v3.ckpt | Bin 0 -> 4573 bytes checkpoints/epoch=0-step=2.ckpt | Bin 0 -> 4573 bytes contextualized/regression/__init__.py | 43 +++ contextualized/regression/datamodules.py | 225 ++++++++++++++++ contextualized/regression/datasets.py | 7 +- .../regression/lightning_modules.py | 3 + contextualized_sanity_run.py | 175 ++++++++++++ scripts/test_contextualized_dm.py | 252 ++++++++++++++++++ 10 files changed, 702 insertions(+), 3 deletions(-) create mode 100644 checkpoints/epoch=0-step=2-v1.ckpt create mode 100644 checkpoints/epoch=0-step=2-v2.ckpt create mode 100644 checkpoints/epoch=0-step=2-v3.ckpt create mode 100644 checkpoints/epoch=0-step=2.ckpt create mode 100644 contextualized/regression/datamodules.py create mode 100644 contextualized_sanity_run.py create mode 100644 scripts/test_contextualized_dm.py diff --git a/checkpoints/epoch=0-step=2-v1.ckpt b/checkpoints/epoch=0-step=2-v1.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..7979456c38ce54914d6553ee49444f02016fc7ee GIT binary patch literal 4573 zcmcInTXP)M5uTN?EEeMnmpC@qF%Y9wY^_(j*3Dqck;R&1K^DS>iL=V&?99>5pxK%2 znKQB_alpiwpj^1+&h>$}JfwIdZ>dUt0aX+aRH~>{9;iI!HI=9I>DkN7T8e{J8P(Kk zdQN{|pYA^0J#&?9w`tnIfOf;RN4ra7foVI-{J6yg%Qu(Y^7Jn4QtGFFH7Vgk92wQ-i`ZoMEn-Sbu z$q3l>R7D#aTDKMTZUKAdVJ~gi=FG}p=4xwjmyHy5){Y@Y;BK2#??Kf{rjmID?u~!G z7Q^fl@b!7vuiuV51XFWZ2nSLiMYt~pxt~C;!@)}CIvk1tz9Hc7JUkFdc#uGh1{X{P z4IUypznRk!T7x4J9wukC{XSU9gtWaz$RQ_T$R-J<1Q`O~l5mtnQCOo8xxU{F;h2Qs zoUW2)(GiA{L0-b~oV{H_LBAK}lnnGHpNQm$gd)8*YD}0m4-H{6K}pi&7L+P|Atg#e zn-&_)z^?~94B>=?(Ky|NFPIBs62{}_0R6uL6PRksji%pdx?CXWBuZ3%pc#RBKG|A< zkR~FcorOYTGzCp26Zir9WQ{c5Bi?Aa}D%Zgkc 
zbZwrJa3*J?rQcSd8iHz(jU3b6$4EbxhGCKL8%#K+;WvfTa8~hu-%I-Ptb`f8@IPr0 zMAx3jB$UZXEodIVISG&FbQN|wu-}pJU9z~+aRXWWME5ZwgjF8+Mm=B_JSkz8RL*zS zD9@4FTv`-Gp+1y0yurK(tHow}eM-W-Uc9+Vyiv7aG& zjJA%Jg8Jc&;waKps^TUbtN}bepG-oH&ZkK%JqBA6II%1$IC2uQXlQVps`U-Tr`M54 zoi)MnKu{5|B{)R#oQ;S}gy?Rdipk&DKum}`5wVE{9_xDy>Phekw5dRWK!Xjm5j4C} z9JP!znd{b=xfFts5b0~k@}zHZ_pHs$rKazAVr{VDdk$9r!IOt)eUBeIi~@_}(_q8# z2Jw+&*<3L^Lj=t6sA`4ywvc#3g-Z!XP#7=6)5{Hn=K%(!1ubD&!sWOLwMgO0tiA`0 zswrnGUhJKwA8=Ti)pw`mO(twu)f*^1d$M{l#2#7=StH`r^B$iGuwhqcd^2ibDGRH9 zlcTAA6qvkf`e?=#JZR1;x2k99NwgoaZnYc{R%sij3d8CoohlA5kL2-{fHlgiR+Vg- z_xvoCPRj|%EU7>9e=`eSAb#@DQtYD;UX*ZEzZ*H1gl#08N7lS#AE69Y@6h`ZhaaVy zY6-5zP4(qOQ@<|Ym3jCvd7Wx~91Q||)jop7OU0>EJ24ONkcF_m06&rN)3_Ot41Olz z=Q`yx3s?i|q#t4Pd~E?SAmNf@D@%|>UdeK zfic8s34VbR1s6{lhQ&>`0>4C|>ck!<4p!r3_|*{l!Qt2I6rn{-CWZsrZ)T|g zSkB@i&QYFlSk2||+jVO1i;-%Gznj(j5wWSP4&m(ucqd1LOCOI|#3%@Efbnm2x)9!7 zfcK*OTxG{V<*_dcJ8^zf_uQtlYj3VdAjsfc-wZ2olpi$U50Ti`Ekm0#1Y zzNq;BeN=vabN=0l3~YS@Y%%{zS;hXz`}K=A&;LVOu@n6~B_FmfM|C3_ZVksOPv^g4 zHjIo-j*X9v701U*r4uK{CnhFFM+=k1i7^yRj82x6c(O1$p{zXUAN+F;^Hbe(m5egD z^W~eYTdKULuio$AUhQN7+LuzTVjXO2_f)!GkYo`#kOKK6335Z_>cvSGg1sr6dHgVM z8`CL3z3R!6tme_~^;tqat{uI)$y1oty?}kIO-;I3xYH3xo}jeur@gwxUfpE&r*$vj zB%IbR}EU->(hj8>$LPLCsRs&Lv*isw4+VDak9@<`umTy6b@}0?x1zL8 hqeRuTf3Qoy)mh!uO-RX-awo2;wp-C`)#t1ZU2>Ol0A=+R?jqulZB!wy|{ZJX&Fclzn`p20k3Oyid{bPvwK&Wy4X6;#d9P1l1S zos4AFfLpeoRy|e;EZa~yb9~oRm-1=2BNM6kxKpb#NJlE^lI?2z(MXtvJCzi|=3tky z4gc!(0CvkM4!w`a(%|5_DNOI-&^HHr>A)t_jO=A5js|y`Xu?dJ5yUXuZSI!)IP61J zIVJa9hkN8cIdvUA5W(!{@WDAapxl9W2&Q6bJ{(MdKn9P_6Wumf)3MOsYjVP9{+M3TJ z{xJ1SpxoDST0E8!*vxBlJ{`iDjEPJ?E!rO6x+){Ihusg0`Ar{Zt1F9=T^;H z#s9lb+?QuVm{xNClNLd=`8*UtiJa7cW*(di;S(7}gq<$zCqwuYF_v3yAjY3=F9SkY zWuB{6JsVO* z2TmNCaUFK_5DE;2tAU#34B(Ps*<4W_m3x}yP}TDBYcBSM3YQQLqcGZrx0k61&pZrB z13HAI5H3ees9_YAXOuo9RT0h7w;GP{%XEy%+>p3Q zllh^gd=@`(SfjjZRLKUtFV0ZuG%Sy3h00_9H(Kx|vQHjbjC|z76Cqqt?nax7+*D)E zBQ{T(hbcqFH}rnM;HgAYEx^^NsXiTR>eo12pMx)x*QwS=(ICKQ%)?l`M4Vc+6Y~Hs znGfsp@N5WQiJB3~;JFZA$&EXbRmrG!Pm~6etPw-b3fm= 
zMjf^0iogXm_&QEf;(8fc4P%JY5_|(C0xp^|RDc?ZbEH;d>bx zT)J4G1sVmx_Avg9PUpk-=ivuIc2?fmbNGdyy}AqMH}TGHIa}_Rj63!ccJPChnI?l_hPrnWAwNnaQZU* zSwvI3TNBtToo*4RtZ?RvNw!ox`F`v+?b{aTHWM7KP5$YCh?;n}=FlrmwK#|Y@mQ`@ zJaHQY&=^|V*>mrk)Ax(u~{Y=vOaH)9eHmqN|S$rPK=I?j^@Y4 z3Wejx$HvFUM@EK6bK@hqiNeIl$mm4AFgl(a9~Uh9|6P1JgE1EGteg^rTfcncbxV}@ z7B;O aP*41xmDU=;jk?|rQTL@Z{uBG7z5fM^XYsxO literal 0 HcmV?d00001 diff --git a/checkpoints/epoch=0-step=2-v3.ckpt b/checkpoints/epoch=0-step=2-v3.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..3bf8a039fa18ac809e90ce1c3082d67478a0f52a GIT binary patch literal 4573 zcmcIn-ESOM6`ytD#M$P(x#xF&_ndRjIXkUx+N>x&J<2UluX49Se8aLA=`oWqlC3W|rQ%NILgKmH#X-qv z-Kb9Mo?$^sRkXWsRdu`yarA)EI&`UB+CG%kmzal$1CCv_m}|RLeUbWs?YXc?BIL8f z*#d0V(zrKZgi+nJ4F=tY;W-X9814+xXML0U)SSdGY1q;~16woNR#ec5YgnESJ%%)r z)k9`GL0b1|CA1w=XVeWmUti3pVOu6r32-N=Qb5}6S_G~0f!(CbmVFzKmwh8|l z^$>QbDF!>wsY-wUswGVCV$eGSyXAo`O3dscTMAI zlcY{fD5zb5Oq7_m+^E-mujl>-q?YV1|p*5akVRyTa3{6smtIA4&xaMp)6FkP{ER0b|8u!=ne^3#2+SJ z4z&k5PK(DfLx=ip&L=pW%vi|u^8!>sP$jg`MzOsNwWCQG5;Ctwm~H4@o!K>e8UOFO zabKR|FsTjxPg)Y9&F2doO43OUXzIgh4o_w@5q5cD&v1B3GFDq|kc^*hFGGp2OnpzU z`ox51I7~^Ev#mYCb5drWD|5jVtYvC&!2ahTEaw^xY{Dvr=T?Xw(aC9(5@ zMcez&wU?oGFp0g0TtrsK{ryD_n)dJp`#Q4pOB^mpo6id)DU5guosq&cN=b*5*Aa_c z+S!6{CHx$RFH0!BsiP#IzH+NL3^kFecnBM70FCY^;ZTwH(~wp=2%8+Jv@$Q?q?6E! 
zf`p?~t*#@USVcqPt}%)iLKXpwgDs7`U?JjyM0D0s#pJK8BbJEU5wVU19_xDq>T>WT zXkCDU1ohX^M$q6^ao8Yf5XY$yV<7;>A=FmT$}^rxol_Py7V4huvX#D?=h|5P`;H%) z@?3iK5DH9+tG=4;_TiFZ*<8|Ho%zIeWz`Ds>tO7QEL=jELt(THZ!gsmp86P&1~i97 z4i}>))G!K{rnFupRT0h=^o+2Er7S3Wb&6EI&^Kt=@Q}tOylD0^HOr^uiZz#5w;GNJ%JLZFg9G9wjpqjz z3t9ZcU`6Itqe?dDeRWEfPSf@!Ev`NHf1?FollG;D79t-7@FIuH+8(sIz$`uHJjv$k z)?t~U;v0HDr0|VIQ!T)isHwgbYwA}ST$_P!O0UaWA4Nj~zHA-F;w9qLs-2hzc*z1- zor706T#uTOl)?{Ii6qje#SWriB#6MAXE)F^8SXM2mwM5RVO(ipTGO z02)JUTYC=tZSo;8K8W|@te;F;&y|W#+=2BA7vBHi%#C-3w%vH;!{7Ys`PtXT-~Sy< z-1p#_wca;>zIN__e&d~Yf4}z8o?on)|M>Oz)b;)~^7bFso}Bpm+JAzlCwO)H4Z;1r z!}sqviM+-8aojg2x#vqoMd9x3gur<9xBvX$Q=GqWr{`Cc%P$K1A06R;xxM{eu?%ed z1h~WeE2M?>n?I|bzkU86RPrqkW@YkW>#{{RlEbZTo26p*6J|qhdP;M|cHY`~Fx%-9d8H}-br`41o-1_AkuUn$LC!fyW z!JXRi0<<@wTE;rq)b6NsogncdvOfWGGY)b~|)b>lZFsXL2(t7%&5V&P6kAbx|Ay1h8fCUodtoR7CLp~xhn>( z_Vb6BZsWFeD#uew{Dx>hGsvS!di`XdR<~?9+E6&SX`rQw!In?+?-iM-it_Kq$xu)H Xo>f*F!Hv3Jx2XF{8vlv?(cb?85FYM# literal 0 HcmV?d00001 diff --git a/checkpoints/epoch=0-step=2.ckpt b/checkpoints/epoch=0-step=2.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..9a5597bbfecfe1e930b52d73d96f11cc40a26cdf GIT binary patch literal 4573 zcmcInU2Ggz6`pneoNbfXO-q`PnzRtSrpd1qdJ#)`@zI)EO=bW99H*J=r?r!O}r$>4~(mdU?7TFO)<65@9V3&$JrSpmBQdbC4 zM%jYutnTV2q-06C4_8&&t!TFDb5@5gxl7rLvicHt5%I8XRZZ?#PE}oGo^QDhY$AkQ zb|jmJ%}N^g`ds5oH7uRO7F~C3o9P^P`svf2!8~S6;g>XQ9hik}8D$$PsG6gjt_R&Z z8Of>vw`@PHdaM#ywxM$7_^zie=F+e|6RG&PQ>!vaM=I%(?P~nFNSKEEloY}a!u`r7 z{Hxak*deDl>^vh&0|Tq3FujXI&n)!Pfla0v*$Ye@4IVJjgqb!Yh#}Z*?vQ&pJcz1t zO76J^56L}p>Kc41g4x62;aS+LY)3l;Q?WE3_9Z}a@M#1=`}+yxDm)^mu0me~@EHyV zX5rBwgvSU(tud~NpuuN}=jSpCLQ8NkgwK<+nttz>Q$8K9<}=8IFkq4dQ-X-V7eaWP zM1fzS5ZSI<_u)_ohck*unt6+>LI&9oj%3U&Aq*-Hp`4O|zGP#O910;vpVf-Sb(8rj zH#JU4QfCI_lEiD86JWL2=&WZ#m{oFjR*4QOj?g~sb0M51 zV&?>lw)dZJF9W47iM@ziL{`WB{pAo8<*^O+b!6!aA)F_hUlB%981WQ3BZX;{B8QaM z5sO^f*@AB+{9Fhx5|rB1krGfZ-7XFSMWiYo!onKBqWh_ZP@(&&la&s_MhJ|o%nLYj 
z5?YavaFnXmb;OgaXh_^O&hSD|5imor$jHkkA}$c3y^bm-e{CHxA>NOObu92$-y={b z1eZYT0u%_;TSpr~{oBPsgQTw6c16<{eBdDj$_iR})-{-Y(q#HV-L)LP(qD5O3#)(s z_6r&$hFEg!!Q#=cPD62c)AM%(cAG8N&OhXH9o zhp-sJg{TQNjKalfr3Xn>gfoR*?47#nF<6>bb|vL?jhnEnyo}P5aQ6d}MzB)~H)37|^6e`dE-#Ee7$UeDeA##upuY_<(*^M?AxT(fGM{K@s z9;EaX-_Z2|gUgAInujYjFdtUu z;F}@59(5v;!M8$qL!o5WJgtUB(hab4etQnS6T){hN*BV|9(?cA@mH7seCp>9u24Iz zxgu&o4c^2tO580&t6>0fRD!oqBH*GCLp7MLEy4FuDB7@tS%al`5w5RBINC+gG#I?S zN)ejJOww>VduN(%zhTYK;}GR=ORLxney~cteLfJ~@Q2gNRz$1|);_#D2S3Wt)Y8QQ zEzl?kwukX=v^gJsJO@7svNQ6w?%i+RdVf0(Z{nTVaIoAX63!ccJPCpO=dH|IswDDGUn&$1lK#BB0{k zm_V;|yhWh0#F;B5*;4U|yRqAJg>7+eGr{5N#9#J`z=?Nb4m+jE76&mP9vLhZ$M1mv z8bfQ_y7&KS>QOO1i1(AMpGaCCDix32gZ1<0&5uUb4jlRGjitAqTpPdn+hdvQH*dVZ zzjouPU!Gg*y*@wr-l>JP-k-5acc?gd=BF>MT`qnwF>!Nb?N=YaKJmMUUz`-&Km7RD z@0>v1;{7DpOOz-mnFIp4Y#Ug+{OJb94(9%3c1nIe13R%bZl&FWMpWt zP#8R#8$CLP!a{B=H=NH4mT&yy@6#EKv3O_Xlpx&tg4#m`B^sA7biJ?da5v z-=w7OS!`WR)1-@qI~jra4NB_v;53`i&2{R=vp=bOwo?}~A!#?ByJFC4KfjOZHf~F& zay+HPZ;19Yi#(d7v+E}NjJ$R0p@u@=ro$~&3^qN@zgH+xCF$RdlcApYJu9s=f*W torch.Tensor: + if isinstance(x, torch.Tensor): + return x.to(dtype=dtype, copy=False) + if isinstance(x, (pd.DataFrame, pd.Series)): + x = x.to_numpy(copy=False) + # x is now np.ndarray or array-like + return torch.tensor(x, dtype=dtype) + + +def _maybe_index(x: torch.Tensor, idx: IndexLike) -> torch.Tensor: + if idx is None: + return x + if isinstance(idx, torch.Tensor): + return x[idx] + if isinstance(idx, np.ndarray): + idx = torch.as_tensor(idx, dtype=torch.long) + return x[idx] + # assume Sequence[int] + return x[torch.as_tensor(idx, dtype=torch.long)] + + +class ContextualizedRegressionDataModule(pl.LightningDataModule): + """ + DataModule that returns map-style datasets for contextualized regression, + allowing Lightning's Trainer (DDP) to auto-attach DistributedSampler and shard data. 
+ + give ∈ { + "singletask_multivariate", + "singletask_univariate", + "multitask_multivariate", + "multitask_univariate", + } + """ + + def __init__( + self, + C: TensorLike, + X: TensorLike, + Y: Optional[TensorLike], + *, + task_type: str, + # splits: pass explicit index arrays OR a splitter callable + train_idx: IndexLike = None, + val_idx: IndexLike = None, + test_idx: IndexLike = None, + predict_idx: IndexLike = None, + splitter: Optional[ + Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], + Tuple[IndexLike, IndexLike, IndexLike]] + ] = None, + # dataloader config + batch_size: int = 32, + num_workers: int = 0, + pin_memory: bool = True, + persistent_workers: bool = False, + drop_last: bool = False, + shuffle_train: bool = True, + shuffle_eval: bool = False, + dtype: torch.dtype = torch.float, + ): + super().__init__() + if task_type not in TASK_TO_DATASET: + raise ValueError( + f"Unknown task_type={task_type!r}. " + f"Expected one of {list(TASK_TO_DATASET)}." + ) + self.task_type = task_type + + # raw inputs (convert in setup) + self._C_raw = C + self._X_raw = X + self._Y_raw = Y + + # split config + self.train_idx = train_idx + self.val_idx = val_idx + self.test_idx = test_idx + self.predict_idx = predict_idx + self.splitter = splitter + + # dl config + self.batch_size = batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = bool(persistent_workers and num_workers > 0) + self.drop_last = drop_last + self.shuffle_train = shuffle_train + self.shuffle_eval = shuffle_eval + self.dtype = dtype + + # will be set in setup() + self.C: Optional[torch.Tensor] = None + self.X: Optional[torch.Tensor] = None + self.Y: Optional[torch.Tensor] = None + + self.ds_train = None + self.ds_val = None + self.ds_test = None + self.ds_predict = None + + # One-time downloads or heavy ops would go here; we have none. 
+ def prepare_data(self) -> None: + pass + + def setup(self, stage: Optional[str] = None) -> None: + # Convert inputs to tensors + C = _to_tensor(self._C_raw, self.dtype) + X = _to_tensor(self._X_raw, self.dtype) + Y = None if self._Y_raw is None else _to_tensor(self._Y_raw, self.dtype) + + # Basic shape sanity could be added here if desired. + + # If no explicit indices were given, allow a splitter to define them. + if self.train_idx is None and self.val_idx is None and self.test_idx is None: + if self.splitter is not None: + tr, va, te = self.splitter(C, X, Y) + self.train_idx, self.val_idx, self.test_idx = tr, va, te + + # If predict_idx not given, default to test indices (or full range if all None) + if self.predict_idx is None: + if self.test_idx is not None: + self.predict_idx = self.test_idx + else: + self.predict_idx = torch.arange(C.shape[0], dtype=torch.long) + + # Slice tensors per split (map-style datasets rely on correct len() for sharding) + def _mk_dataset(idx: IndexLike): + if idx is None: + return None + C_s = _maybe_index(C, idx) + X_s = _maybe_index(X, idx) + Y_s = None if (Y is None) else _maybe_index(Y, idx) + ds_cls = TASK_TO_DATASET[self.task_type] + # Y can be optional for some tasks; the dataset constructors you showed + # expect Y. If a task doesn't use Y, pass a placeholder or ensure callers pass X as Y when needed. + if Y_s is None: + # If Y is truly not used for this task_type, construct a compatible placeholder. + # Here we create zeros with appropriate last dim to match dataset expectations. + # For singletask_univariate/multivariate we assume Y has shape (n, y_dim). + # Override as needed if your upstream code guarantees a Y. 
+ Y_s = torch.zeros((C_s.shape[0], X_s.shape[-1]), dtype=self.dtype) + return ds_cls(C_s, X_s, Y_s, dtype=self.dtype) + + self.ds_train = _mk_dataset(self.train_idx) + self.ds_val = _mk_dataset(self.val_idx) + self.ds_test = _mk_dataset(self.test_idx) + self.ds_predict = _mk_dataset(self.predict_idx) + + # Keep tensors for potential later use + self.C, self.X, self.Y = C, X, Y + + # ---- Dataloaders ---- + def _common_dl_kwargs(self) -> Dict: + return { + "batch_size": self.batch_size, + "num_workers": self.num_workers, + "pin_memory": self.pin_memory, + "persistent_workers": self.persistent_workers, + "drop_last": self.drop_last, + } + + def train_dataloader(self) -> DataLoader: + if self.ds_train is None: + raise RuntimeError("train dataset is not set; provide train_idx or splitter.") + return DataLoader( + dataset=self.ds_train, + shuffle=self.shuffle_train, # True only for train + **self._common_dl_kwargs(), + ) + + def val_dataloader(self) -> DataLoader: + if self.ds_val is None: + raise RuntimeError("val dataset is not set; provide val_idx or splitter.") + return DataLoader( + dataset=self.ds_val, + shuffle=self.shuffle_eval, # False by default + **self._common_dl_kwargs(), + ) + + def test_dataloader(self) -> DataLoader: + if self.ds_test is None: + raise RuntimeError("test dataset is not set; provide test_idx or splitter.") + return DataLoader( + dataset=self.ds_test, + shuffle=self.shuffle_eval, # False by default + **self._common_dl_kwargs(), + ) + + def predict_dataloader(self) -> DataLoader: + if self.ds_predict is None: + raise RuntimeError("predict dataset is not set; provide predict_idx/test_idx.") + # IMPORTANT: keep shuffle=False for stable ordering per-rank + return DataLoader( + dataset=self.ds_predict, + shuffle=False, + **self._common_dl_kwargs(), + ) diff --git a/contextualized/regression/datasets.py b/contextualized/regression/datasets.py index 911997cc..d0b85259 100644 --- a/contextualized/regression/datasets.py +++ 
b/contextualized/regression/datasets.py @@ -62,9 +62,10 @@ class MultitaskMultivariateDataset(Dataset): Multi-task Multivariate Dataset. """ def __init__(self, C, X, Y, dtype=torch.float): - self.C = torch.tensor(C, dtype=dtype) - self.X = torch.tensor(X, dtype=dtype) - self.Y = torch.tensor(Y, dtype=dtype) + self.C = C.to(dtype) if isinstance(C, torch.Tensor) else torch.as_tensor(C, dtype=dtype) + self.X = X.to(dtype) if isinstance(X, torch.Tensor) else torch.as_tensor(X, dtype=dtype) + self.Y = Y.to(dtype) if isinstance(Y, torch.Tensor) else torch.as_tensor(Y, dtype=dtype) + self.c_dim = C.shape[-1] self.x_dim = X.shape[-1] self.y_dim = Y.shape[-1] diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index c31398c9..14345c34 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -10,6 +10,9 @@ Implemented with PyTorch Lightning """ +# For distributed runs, use the ContextualizedRegressionDataModule which returns +# map-style datasets and allows Lightning's Trainer to auto-shard with DDP. 
+from .datamodules import ContextualizedRegressionDataModule # noqa: F401 from abc import abstractmethod import numpy as np diff --git a/contextualized_sanity_run.py b/contextualized_sanity_run.py new file mode 100644 index 00000000..4c364ee7 --- /dev/null +++ b/contextualized_sanity_run.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python +import argparse, sys, time, warnings, socket, os +from pathlib import Path +import numpy as np + +warnings.filterwarnings("ignore", message="To copy construct from a tensor", category=UserWarning) + +try: + import torch + import lightning as pl + from lightning.pytorch.strategies import DDPStrategy +except Exception as e: + print(f"[FATAL] torch/lightning import failed: {e}"); sys.exit(1) + +try: + from contextualized.regression.datamodules import ContextualizedRegressionDataModule + from contextualized.regression.datasets import ( + MultivariateDataset, UnivariateDataset, MultitaskMultivariateDataset, MultitaskUnivariateDataset + ) + _ctx_ok = True +except Exception as e: + print(f"[FATAL] Could not import contextualized modules: {e}"); _ctx_ok = False + +def _free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + +def make_tensors(n=32, c_dim=3, x_dim=5, y_dim=4, dtype=torch.float32, seed=7): + rng = np.random.default_rng(seed) + C = torch.tensor(rng.normal(size=(n, c_dim)), dtype=dtype) + X = torch.tensor(rng.normal(size=(n, x_dim)), dtype=dtype) + Y = torch.tensor(rng.normal(size=(n, y_dim)), dtype=dtype) + return C, X, Y + +def simple_splitter(C, X, Y): + n = C.shape[0]; idx = torch.arange(n, dtype=torch.long) + n_tr = int(0.6*n); n_va = int(0.2*n) + return idx[:n_tr], idx[n_tr:n_tr+n_va], idx[n_tr+n_va:] + +class TinyLightning(pl.LightningModule): + def __init__(self, in_dim=5, out_dim=4, lr=1e-3): + super().__init__(); self.save_hyperparameters() + self.head = torch.nn.Linear(in_dim, out_dim, bias=False) + torch.manual_seed(0) + with torch.no_grad(): + w 
= torch.arange(in_dim*out_dim).float().reshape(out_dim, in_dim)/100.0 + self.head.weight.copy_(w) + self.mu = torch.nn.Parameter(torch.zeros(out_dim, 1)) + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=self.hparams.lr) + def training_step(self, batch, batch_idx): + betas = self.head(batch["predictors"]); loss = (betas**2).mean() + self.log("train_loss", loss, on_epoch=True, prog_bar=False, logger=False); return loss + @torch.no_grad() + def predict_step(self, batch, batch_idx, dataloader_idx=0): + betas = self.head(batch["predictors"]); mus = self.mu.view(1,-1).repeat(betas.shape[0],1) + return {"idx": batch["idx"].detach().clone().cpu(), + "betas": betas.detach().cpu(), + "mus": mus.detach().cpu()} + +def check_dataset_shapes(C, X, Y): + print("\n[CHECK] Dataset constructors & shapes") + mv = MultivariateDataset(C, X, Y); uv = UnivariateDataset(C, X, Y) + mtmv = MultitaskMultivariateDataset(C, X, Y); mtuv = MultitaskUnivariateDataset(C, X, Y) + print(f" MultivariateDataset: len={len(mv)} sample keys={list(mv[0].keys())}") + print(f" UnivariateDataset: len={len(uv)} sample keys={list(uv[0].keys())}") + print(f" MultitaskMultivariateDataset: len={len(mtmv)}") + print(f" MultitaskUnivariateDataset: len={len(mtuv)}") + for name, ds in [("MultivariateDataset", mv), ("UnivariateDataset", uv), + ("MultitaskMultivariateDataset", mtmv), ("MultitaskUnivariateDataset", mtuv)]: + s = ds[0] + for k in ("idx","contexts","predictors","outcomes"): + assert k in s, f"{name} sample missing '{k}'" + print(" ✔ Map-style and key shape checks passed.") + +def run_single_process(dm, x_dim, y_dim, max_epochs=1): + print("\n[RUN] Single-process (CPU) trainer...") + model = TinyLightning(in_dim=x_dim, out_dim=y_dim) + trainer = pl.Trainer(accelerator="cpu", devices=1, max_epochs=max_epochs, + logger=False, enable_progress_bar=False, + default_root_dir=str(Path("./_tmp_sanity").resolve()), + enable_checkpointing=False) + tic = time.time(); trainer.fit(model, 
datamodule=dm) + outs = trainer.predict(model, datamodule=dm); sec = time.time() - tic + idx = torch.cat([o["idx"] for o in outs]).numpy() + betas = torch.cat([o["betas"] for o in outs]); mus = torch.cat([o["mus"] for o in outs]) + print(f" Predict returned {len(idx)} rows in {sec:.2f}s") + print(f" idx head: {idx[:10]}") + print(f" betas shape: {tuple(betas.shape)}, device={betas.device.type}") + print(f" mus shape: {tuple(mus.shape)}, device={mus.device.type}") + assert betas.device.type == "cpu" and mus.device.type == "cpu" + assert (idx == np.sort(idx)).all() + assert len(np.unique(idx)) == len(idx) + print(" ✔ Single-process checks passed.") + return idx, betas, mus + +def run_ddp(dm, x_dim, y_dim, world_size=2): + print(f"\n[RUN] DDP spawn (CPU, world_size={world_size})...") + # Force local master & explicit init_method to ignore any stale env vars + addr = "127.0.0.1"; port = _free_port() + os.environ["MASTER_ADDR"] = addr + os.environ["MASTER_PORT"] = str(port) + strategy = DDPStrategy(process_group_backend="gloo", + init_method=f"tcp://{addr}:{port}") + model = TinyLightning(in_dim=x_dim, out_dim=y_dim) + trainer = pl.Trainer(accelerator="cpu", devices=world_size, strategy=strategy, + max_epochs=0, logger=False, enable_progress_bar=False, + default_root_dir=str(Path("./_tmp_sanity_ddp").resolve()), + enable_checkpointing=False) + outs = trainer.predict(model, datamodule=dm) + idx = torch.cat([o["idx"] for o in outs]).numpy() + betas = torch.cat([o["betas"] for o in outs]); mus = torch.cat([o["mus"] for o in outs]) + print(f" Gathered rows: {len(idx)} (unique={len(np.unique(idx))})") + print(f" idx head: {idx[:10]}") + print(f" betas shape: {tuple(betas.shape)}, device={betas.device.type}") + print(f" mus shape: {tuple(mus.shape)}, device={mus.device.type}") + assert betas.device.type == "cpu" and mus.device.type == "cpu" + assert len(np.unique(idx)) == len(idx) + print(" ✔ DDP checks passed.") + return idx, betas, mus + +def maybe_try_wrapper(X): + try: + 
from contextualized.easy.wrappers.SKLearnWrapper import SKLearnWrapper # type: ignore + except Exception as e: + print(f"[INFO] SKLearnWrapper not available ({e}); skipping wrapper test."); return + print("\n[TRY] SKLearnWrapper in-memory vs memory-bounded (if supported)...") + class DummyEstimator: + def fit(self, C, X, Y): return self + def predict(self, X): + if isinstance(X, torch.Tensor): return X.sum(-1, keepdim=True).numpy() + return X.sum(-1, keepdims=True) + try: + wrapper = SKLearnWrapper(estimator=DummyEstimator()) + p1 = np.asarray(wrapper.predict(X, memory_bounded=False)) + p2 = np.asarray(wrapper.predict(X, memory_bounded=True)) + print(f" wrapper outputs shapes: {p1.shape} vs {p2.shape}") + print(" ✔ Wrapper paths match on toy data." if (p1.shape==p2.shape and np.allclose(p1,p2,1e-6,1e-6)) + else " ⚠ Wrapper paths differ on toy data.") + except TypeError as e: + print(f" ⚠ Wrapper signature mismatch: {e} — skipping.") + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--n", type=int, default=32) + ap.add_argument("--c-dim", type=int, default=3) + ap.add_argument("--x-dim", type=int, default=5) + ap.add_argument("--y-dim", type=int, default=4) + ap.add_argument("--batch-size", type=int, default=8) + ap.add_argument("--num-workers", type=int, default=0) + ap.add_argument("--ddp", type=int, default=0) + ap.add_argument("--try-wrapper", action="store_true") + args = ap.parse_args() + if not _ctx_ok: sys.exit(2) + C, X, Y = make_tensors(n=args.n, c_dim=args.c_dim, x_dim=args.x_dim, y_dim=args.y_dim) + check_dataset_shapes(C, X, Y) + print("\n[BUILD] ContextualizedRegressionDataModule") + dm = ContextualizedRegressionDataModule( + C=C, X=X, Y=Y, task_type="singletask_multivariate", + batch_size=args.batch_size, num_workers=args.num_workers, + shuffle_eval=False, shuffle_train=True, pin_memory=False, persistent_workers=False, + splitter=simple_splitter, + ) + dm.setup("fit") + idx1, betas1, mus1 = run_single_process(dm, x_dim=args.x_dim, 
y_dim=args.y_dim) + if args.ddp and args.ddp > 1: + idx2, betas2, mus2 = run_ddp(dm, x_dim=args.x_dim, y_dim=args.y_dim, world_size=args.ddp) + assert set(idx1.tolist()) == set(idx2.tolist()), "DDP vs single-process index coverage mismatch" + print(" ✔ DDP vs single-process index coverage matches.") + if args.try_wrapper: maybe_try_wrapper(X) + print("\n✅ ALL SANITY CHECKS COMPLETED SUCCESSFULLY") + +if __name__ == "__main__": + main() diff --git a/scripts/test_contextualized_dm.py b/scripts/test_contextualized_dm.py new file mode 100644 index 00000000..7a90514e --- /dev/null +++ b/scripts/test_contextualized_dm.py @@ -0,0 +1,252 @@ +# scripts/test_contextualized_dm.py +""" +Smoke-test your ContextualizedRegressionDataModule with synthetic data. + +Examples: + # Single-process sanity check + python scripts/test_contextualized_dm.py --task-type singletask_multivariate --peek + + # CPU DDP on Windows (Git Bash or PowerShell) + python scripts/test_contextualized_dm.py --task-type singletask_multivariate --devices 2 --peek +""" + +from __future__ import annotations +import argparse +import os +import sys +import tempfile +from pathlib import Path +from typing import Dict, Optional, Tuple + +import torch +from torch import nn +import lightning as pl +from lightning.pytorch.strategies import DDPStrategy + +# --- Make repo root importable if running from source tree --- +REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if REPO_ROOT not in sys.path: + sys.path.insert(0, REPO_ROOT) + +from contextualized.regression.datamodules import ContextualizedRegressionDataModule + +# ---- Candidate key names in your batch dict ---- +CTX_CANDIDATES = ("contexts", "context", "ctx", "C", "c") +X_CANDIDATES = ("predictors", "X", "features", "x", "inputs", "data") +Y_CANDIDATES = ("outcomes", "Y", "targets", "y", "labels") + + +def pick_first_key(d: Dict[str, torch.Tensor], candidates) -> Optional[str]: + for k in candidates: + if k in d: + return k + return None 
+ + +# --------------------------- +# Synthetic (C, X, Y) +# --------------------------- +def make_synthetic( + n: int, + c_dim: int, + x_dim: int, + y_dim: int, + seed: int = 1234, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + g = torch.Generator().manual_seed(seed) + C = torch.randn(n, c_dim, generator=g) + X = torch.randn(n, x_dim, generator=g) + W = torch.randn(x_dim, y_dim, generator=g) / (x_dim ** 0.5) + Y = X @ W + 0.05 * torch.randn(n, y_dim, generator=g) + return C, X, Y + + +def make_indices(n: int, train_frac=0.7, val_frac=0.15) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + idx = torch.randperm(n) + n_train = int(n * train_frac) + n_val = int(n * val_frac) + train_idx = idx[:n_train] + val_idx = idx[n_train:n_train + n_val] + test_idx = idx[n_train + n_val:] + return train_idx, val_idx, test_idx + + +# --------------------------- +# Tiny adaptive model +# --------------------------- +class AdaptiveTinyModel(pl.LightningModule): + """ + - If batch has (features, targets): Linear -> MSE + - Else if batch has "contexts": mean(contexts**2) + - Else: mean of first float tensor + Holds an anchor param so the optimizer is never empty. 
+ """ + def __init__(self, x_dim: Optional[int] = None, y_dim: Optional[int] = None, lr: float = 1e-2): + super().__init__() + self.lr = lr + self.mse = nn.MSELoss() + self._anchor = nn.Parameter(torch.tensor(0.0)) + self.head = nn.Linear(x_dim, y_dim) if (x_dim is not None and y_dim is not None) else None + + def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + x_key = pick_first_key(batch, X_CANDIDATES) + y_key = pick_first_key(batch, Y_CANDIDATES) + + if x_key and y_key: + x = batch[x_key].float() + y = batch[y_key].float() + if x.ndim == 3: # (B, T, D) + B, T, D = x.shape + x = x.view(B * T, D) + y = y.view(B * T, -1) + if self.head is None: + self.head = nn.Linear(x.shape[-1], y.shape[-1]).to(self.device) + preds = self.head(x) + return self.mse(preds, y) + + c_key = pick_first_key(batch, CTX_CANDIDATES) + if c_key: + c = batch[c_key].float() + return (c ** 2).mean() + + for k, v in batch.items(): + if torch.is_tensor(v) and v.dtype.is_floating_point: + return (v.float() ** 2).mean() + + raise RuntimeError("No usable tensor found in batch to compute a loss.") + + def training_step(self, batch, batch_idx): + loss = self._compute_loss(batch) + self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) + return loss + + def validation_step(self, batch, batch_idx): + loss = self._compute_loss(batch) + self.log("val_loss", loss, on_epoch=True, prog_bar=True) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=self.lr) + + +# --------------------------- +# CLI / Trainer +# --------------------------- +def parse_args(): + p = argparse.ArgumentParser(description="Test ContextualizedRegressionDataModule") + p.add_argument("--task-type", + choices=[ + "singletask_multivariate", + "singletask_univariate", + "multitask_multivariate", + "multitask_univariate", + ], + required=True) + p.add_argument("--n", type=int, default=256, help="Total samples") + p.add_argument("--c-dim", type=int, default=8, 
help="Context dim") + p.add_argument("--x-dim", type=int, default=16, help="Feature dim") + p.add_argument("--y-dim", type=int, default=4, help="Target dim") + p.add_argument("--batch-size", type=int, default=32) + p.add_argument("--num-workers", type=int, default=0) + p.add_argument("--devices", type=int, default=1) + p.add_argument("--max-epochs", type=int, default=1) + p.add_argument("--limit-train-batches", type=float, default=2) + p.add_argument("--limit-val-batches", type=float, default=1) + p.add_argument("--peek", action="store_true", help="Print first batch keys/shapes") + return p.parse_args() + + +def _unset_dist_env(): + # Ensure env:// rendezvous is NOT selected + for k in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK", "INIT_METHOD"): + if k in os.environ: + os.environ.pop(k) + + +def build_trainer(args) -> pl.Trainer: + if args.devices > 1: + # Force local file-store DDP init (no sockets/ports) + init_path = Path(tempfile.gettempdir()) / f"pl_init_{os.getpid()}.pt" + init_uri = init_path.as_uri() # proper file:///C:/... 
on Windows + strategy = DDPStrategy( + process_group_backend="gloo", + init_method=init_uri, + ) + else: + strategy = "auto" + + return pl.Trainer( + accelerator="cpu", + devices=args.devices, + strategy=strategy, + max_epochs=args.max_epochs, + limit_train_batches=args.limit_train_batches, + limit_val_batches=args.limit_val_batches, + enable_progress_bar=True, + logger=False, + ) + + +def main(): + # Windows-safe start method for spawn/DDP + try: + import torch.multiprocessing as mp + mp.set_start_method("spawn", force=True) + except RuntimeError: + pass + + _unset_dist_env() + args = parse_args() + + # --- synthetic data + splits --- + C, X, Y = make_synthetic(n=args.n, c_dim=args.c_dim, x_dim=args.x_dim, y_dim=args.y_dim) + train_idx, val_idx, test_idx = make_indices(args.n) + + # --- your datamodule --- + dm = ContextualizedRegressionDataModule( + C=C, X=X, Y=Y, + task_type=args.task_type, + train_idx=train_idx, + val_idx=val_idx, + test_idx=test_idx, + predict_idx=None, + batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=False, # CPU run + persistent_workers=False, # safe with num_workers=0 + drop_last=False, + shuffle_train=True, + shuffle_eval=False, + dtype=torch.float, + ) + + # Setup and peek one batch to infer dims so model has parameters before optimizer init + dm.setup("fit") + sample = next(iter(dm.train_dataloader())) + if args.peek: + print("[peek] batch keys:", list(sample.keys())) + for k, v in sample.items(): + if torch.is_tensor(v): + print(f"[peek] {k}: shape={tuple(v.shape)} dtype={v.dtype}") + print() + + # Infer x_dim/y_dim from batch (handles (B,T,D)) + x_key = pick_first_key(sample, X_CANDIDATES) + y_key = pick_first_key(sample, Y_CANDIDATES) + x_dim = y_dim = None + if x_key and y_key: + x = sample[x_key] + y = sample[y_key] + x_dim = x.shape[-1] + y_dim = y.shape[-1] + + # Build model (now has params) + model = AdaptiveTinyModel(x_dim=x_dim, y_dim=y_dim) + + # Trainer + trainer = build_trainer(args) + 
trainer.fit(model, dm) + print("✅ Test completed successfully.") + + +if __name__ == "__main__": + main() From 40146130ef37dad4d13669984d3ce3b78288da3d Mon Sep 17 00:00:00 2001 From: Samuel Wales-McGrath Date: Wed, 5 Nov 2025 13:39:24 -0500 Subject: [PATCH 02/19] Accumulated changes from osc server repo updates --- contextualized/__init__.py | 17 + contextualized/easy/ContextualGAM.py | 33 +- .../easy/ContextualizedClassifier.py | 32 +- contextualized/easy/ContextualizedNetworks.py | 289 ++----- .../easy/ContextualizedRegressor.py | 8 +- .../easy/wrappers/SKLearnWrapper.py | 802 +++++++++--------- contextualized/modules.py | 41 +- contextualized/regression/datamodules.py | 23 +- contextualized/regression/datasets.py | 18 +- .../regression/lightning_modules.py | 199 ++--- contextualized/regression/trainers.py | 180 +++- contextualized/utils/__init__.py | 0 contextualized/utils/engine.py | 40 + contextualized_sanity_run.py | 175 ---- 14 files changed, 899 insertions(+), 958 deletions(-) create mode 100644 contextualized/utils/__init__.py create mode 100644 contextualized/utils/engine.py delete mode 100644 contextualized_sanity_run.py diff --git a/contextualized/__init__.py b/contextualized/__init__.py index 2f24b87b..5b46cfe0 100644 --- a/contextualized/__init__.py +++ b/contextualized/__init__.py @@ -2,7 +2,13 @@ models, distributions, and functions with context-specific parameters. For more details, please refer to contextualized.ml. 
""" +import torch +if torch.cuda.is_available(): + try: + torch.set_float32_matmul_precision("high") # use TF32 kernels + except Exception: + pass from contextualized import analysis from contextualized import dags from contextualized import easy @@ -10,3 +16,14 @@ from contextualized import baselines from contextualized import utils from contextualized.utils import * + + +import os +os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") +os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "1") +os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") +# single-node default (disable IB unless you know you need it) +os.environ.setdefault("NCCL_IB_DISABLE", "1") +os.environ.setdefault("NCCL_P2P_DISABLE", "0") +from .utils.engine import pick_engine # optional re-export +__all__ = ["pick_engine"] diff --git a/contextualized/easy/ContextualGAM.py b/contextualized/easy/ContextualGAM.py index 5ea6cda5..f21777d0 100644 --- a/contextualized/easy/ContextualGAM.py +++ b/contextualized/easy/ContextualGAM.py @@ -4,46 +4,29 @@ for more details. """ -from contextualized.easy import ContextualizedClassifier, ContextualizedRegressor +from contextualized.easy import ContextualizedClassifier +from contextualized.easy import ContextualizedRegressor class ContextualGAMClassifier(ContextualizedClassifier): """ - The Contextual GAM Classifier separates and interprets the effect of context in context-varying decisions and classifiers, such as heterogeneous disease diagnoses. - Implemented as a Contextual Generalized Additive Model with a classifier on top. - Always uses a Neural Additive Model ("ngam") encoder for interpretability. - See `this paper `__ - for more details. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which used the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel. - alpha (float, optional): Regularization strength. 
Defaults to 0.0. - mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. + Contextual GAM Classifier with a Neural Additive Model ("ngam") encoder. + Inherits the sklearn-like API from ContextualizedClassifier. """ def __init__(self, **kwargs): + # Force interpretability via NAM encoder kwargs["encoder_type"] = "ngam" super().__init__(**kwargs) class ContextualGAMRegressor(ContextualizedRegressor): """ - The Contextual GAM Regressor separates and interprets the effect of context in context-varying relationships, such as heterogeneous treatment effects. - Implemented as a Contextual Generalized Additive Model with a linear regressor on top. - Always uses a Neural Additive Model ("ngam") encoder for interpretability. - See `this paper `__ - for more details. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which used the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel. - alpha (float, optional): Regularization strength. Defaults to 0.0. - mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. + Contextual GAM Regressor with a Neural Additive Model ("ngam") encoder. + Inherits the sklearn-like API from ContextualizedRegressor. 
""" def __init__(self, **kwargs): + # Force interpretability via NAM encoder kwargs["encoder_type"] = "ngam" super().__init__(**kwargs) diff --git a/contextualized/easy/ContextualizedClassifier.py b/contextualized/easy/ContextualizedClassifier.py index 30a9d980..0f057e6b 100644 --- a/contextualized/easy/ContextualizedClassifier.py +++ b/contextualized/easy/ContextualizedClassifier.py @@ -13,14 +13,6 @@ class ContextualizedClassifier(ContextualizedRegressor): """ Contextualized Logistic Regression reveals context-dependent decisions and decision boundaries. Implemented as a ContextualizedRegressor with logistic link function and binary cross-entropy loss. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which used the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel. - encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". - alpha (float, optional): Regularization strength. Defaults to 0.0. - mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. """ def __init__(self, **kwargs): @@ -29,30 +21,16 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def predict(self, C, X, individual_preds=False, **kwargs): - """Predict binary outcomes from context C and predictors X. - - Args: - C (np.ndarray): Context array of shape (n_samples, n_context_features) - X (np.ndarray): Predictor array of shape (N, n_features) - individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False. 
- - Returns: - Union[np.ndarray, List[np.ndarray]]: The binary outcomes predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True. - """ + """Predict binary outcomes from context C and predictors X.""" return np.round(super().predict(C, X, individual_preds, **kwargs)) def predict_proba(self, C, X, **kwargs): """ Predict probabilities of outcomes from context C and predictors X. - Args: - C (np.ndarray): Context array of shape (n_samples, n_context_features) - X (np.ndarray): Predictor array of shape (N, n_features) - individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The outcome probabilities predicted by the context-specific models (n_samples, y_dim, 2). Returned as lists of individual bootstraps if individual_preds is True. + Returns + ------- + np.ndarray of shape (n_samples, y_dim, 2) """ - # Returns a np array of shape N samples, K outcomes, 2. - probs = super().predict(C, X, **kwargs) + probs = super().predict(C, X, **kwargs) # (n, y_dim[, 1]) return np.array([1 - probs, probs]).T.swapaxes(0, 1) diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py index 1c4a8f26..0fd505a3 100644 --- a/contextualized/easy/ContextualizedNetworks.py +++ b/contextualized/easy/ContextualizedNetworks.py @@ -2,7 +2,7 @@ sklearn-like interface to Contextualized Networks. """ -from typing import * +from typing import List, Tuple, Union import numpy as np @@ -29,15 +29,7 @@ class ContextualizedNetworks(SKLearnWrapper): def _split_train_data( self, C: np.ndarray, X: np.ndarray, **kwargs ) -> Tuple[List[np.ndarray], List[np.ndarray]]: - """Splits data into train and test sets. - - Args: - C (np.ndarray): Contextual features for each sample. - X (np.ndarray): The data matrix. 
- - Returns: - Tuple[List[np.ndarray], List[np.ndarray]]: The train and test sets for C and X as ([C_train, X_train], [C_test, X_test]). - """ + """Splits data into train and val sets (no Y for networks).""" return super()._split_train_data(C, X, Y_required=False, **kwargs) def predict_networks( @@ -52,53 +44,27 @@ def predict_networks( Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]], ]: - """Predicts context-specific networks given contextual features. - - Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - with_offsets (bool, optional): If True, returns both the network parameters and offsets. Defaults to False. - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. - - Returns: - Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]]: The predicted network parameters (and offsets if with_offsets is True). Returned as lists of individual bootstraps if individual_preds is True. + """ + Predicts context-specific network parameters (and offsets if available). """ betas, mus = self.predict_params( C, individual_preds=individual_preds, uses_y=False, **kwargs ) - if with_offsets: - return betas, mus - return betas + return (betas, mus) if with_offsets else betas def predict_X( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs ) -> Union[np.ndarray, List[np.ndarray]]: - """Reconstructs the data matrix based on predicted contextualized networks and the true data matrix. - Useful for measuring reconstruction error or for imputation. - - Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - X (np.ndarray): The data matrix (n_samples, n_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. - **kwargs: Keyword arguments for the Lightning trainer's predict_y method. 
- - Returns: - Union[np.ndarray, List[np.ndarray]]: The predicted data matrix, or matrices for each bootstrap if individual_preds is True (n_samples, n_features). + """ + Reconstructs X via predicted networks using the base wrapper predict(). """ return self.predict(C, X, individual_preds=individual_preds, **kwargs) class ContextualizedCorrelationNetworks(ContextualizedNetworks): """ - Contextualized Correlation Networks reveal context-varying feature correlations, interaction strengths, dependencies in feature groups. - Uses the Contextualized Networks model, see the `paper `__ for detailed estimation procedures. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel. - encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". - alpha (float, optional): Regularization strength. Defaults to 0.0. - mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. + Contextualized Correlation Networks reveal context-varying feature correlations. + Uses the Contextualized Networks model. """ def __init__(self, **kwargs): @@ -109,73 +75,53 @@ def __init__(self, **kwargs): def predict_correlation( self, C: np.ndarray, individual_preds: bool = True, squared: bool = True ) -> Union[np.ndarray, List[np.ndarray]]: - """Predicts context-specific correlations between features. - - Args: - C (Numpy ndarray): Contextual features for each sample (n_samples, n_context_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True. 
- squared (bool, optional): If True, returns the squared correlations. Defaults to True. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The predicted context-specific correlation matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features). - """ - get_dataloader = lambda i: self.models[i].dataloader( - C, np.zeros((len(C), self.x_dim)) - ) - rhos = np.array( - [ - self.trainers[i].predict_params(self.models[i], get_dataloader(i))[0] - for i in range(len(self.models)) - ] + C_scaled = self._maybe_scale_C(C) + Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) + dm = self._build_datamodule( + C=C_scaled, + X=np.zeros((len(C_scaled), self.x_dim), dtype=np.float32), + Y=Y_zero, + predict_idx=np.arange(len(C_scaled)), + data_kwargs=dict( + batch_size=self._init_kwargs["data"].get("val_batch_size", 16), + num_workers=self._init_kwargs["data"].get("num_workers", 0), + pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), + persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + shuffle_train=False, shuffle_eval=False, + dtype=self._init_kwargs["data"].get("dtype", torch.float), + ), + task_type="singletask_univariate", # correlation uses univariate convention ) + rhos = np.array([ + self.trainers[i].predict_correlation(self.models[i], dm.predict_dataloader()) + for i in range(len(self.models)) + ]) if individual_preds: - if squared: - return np.square(rhos) - return rhos - else: - if squared: - return np.square(np.mean(rhos, axis=0)) - return np.mean(rhos, axis=0) + return np.square(rhos) if squared else rhos + mean_rhos = np.mean(rhos, axis=0) + return np.square(mean_rhos) if squared else mean_rhos def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False ) -> Union[np.ndarray, List[np.ndarray]]: - """Measures mean-squared errors. 
- - Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - X (np.ndarray): The data matrix (n_samples, n_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples). + """ + Measures mean-squared reconstruction errors using (betas, mus). """ betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True) mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples - for i in range(X.shape[-1]): - for j in range(X.shape[-1]): + F = X.shape[-1] + for i in range(F): + for j in range(F): tiled_xi = np.array([X[:, i] for _ in range(len(betas))]) tiled_xj = np.array([X[:, j] for _ in range(len(betas))]) residuals = tiled_xi - betas[:, :, i, j] * tiled_xj - mus[:, :, i, j] - mses += residuals**2 / (X.shape[-1] ** 2) - if not individual_preds: - mses = np.mean(mses, axis=0) - return mses + mses += residuals**2 / (F**2) + return mses if individual_preds else np.mean(mses, axis=0) class ContextualizedMarkovNetworks(ContextualizedNetworks): """ - Contextualized Markov Networks reveal context-varying feature dependencies, cliques, and modules. - Implemented as Contextualized Gaussian Precision Matrices, directly interpretable as Markov Networks. - Uses the Contextualized Networks model, see the `paper `__ for detailed estimation procedures. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 10. Always uses archetypes in the ContextualizedMetaModel. - encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". - alpha (float, optional): Regularization strength. Defaults to 0.0. 
- mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. + Contextualized Markov Networks (Gaussian precision matrices). """ def __init__(self, **kwargs): @@ -184,97 +130,62 @@ def __init__(self, **kwargs): def predict_precisions( self, C: np.ndarray, individual_preds: bool = True ) -> Union[np.ndarray, List[np.ndarray]]: - """Predicts context-specific precision matrices. - Can be converted to context-specific Markov networks by binarizing the networks and setting all non-zero entries to 1. - Can be converted to context-specific covariance matrices by taking the inverse. - - Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to True. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The predicted context-specific Markov networks as precision matrices, or matrices for each bootstrap if individual_preds is True (n_samples, n_features, n_features). """ - get_dataloader = lambda i: self.models[i].dataloader( - C, np.zeros((len(C), self.x_dim)) - ) - precisions = np.array( - [ - self.trainers[i].predict_precision(self.models[i], get_dataloader(i)) - for i in range(len(self.models)) - ] + Predicts context-specific precision matrices. 
+ """ + C_scaled = self._maybe_scale_C(C) + Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) + dm = self._build_datamodule( + C=C_scaled, + X=np.zeros((len(C_scaled), self.x_dim), dtype=np.float32), + Y=Y_zero, + predict_idx=np.arange(len(C_scaled)), + data_kwargs=dict( + batch_size=self._init_kwargs["data"].get("val_batch_size", 16), + num_workers=self._init_kwargs["data"].get("num_workers", 0), + pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), + persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + shuffle_train=False, shuffle_eval=False, + dtype=self._init_kwargs["data"].get("dtype", torch.float), + ), + task_type="singletask_univariate", ) - if individual_preds: - return precisions - return np.mean(precisions, axis=0) + precisions = np.array([ + self.trainers[i].predict_precision(self.models[i], dm.predict_dataloader()) + for i in range(len(self.models)) + ]) + return precisions if individual_preds else np.mean(precisions, axis=0) def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False ) -> Union[np.ndarray, List[np.ndarray]]: - """Measures mean-squared errors. - - Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - X (np.ndarray): The data matrix (n_samples, n_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples). + """ + Measures mean-squared reconstruction errors using precision-implied betas/mus. 
""" betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True) mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples - for bootstrap in range(len(betas)): - for i in range(X.shape[-1]): - # betas are n_boostraps x n_samples x n_features x n_features - # preds[bootstrap, sample, i] = X[sample, :].dot(betas[bootstrap, sample, i, :]) + F = X.shape[-1] + for b in range(len(betas)): + for i in range(F): preds = np.array( [ - X[j].dot(betas[bootstrap, j, i, :]) + mus[bootstrap, j, i] + X[j].dot(betas[b, j, i, :]) + mus[b, j, i] for j in range(len(X)) ] ) residuals = X[:, i] - preds - mses[bootstrap, :] += residuals**2 / (X.shape[-1]) - if not individual_preds: - mses = np.mean(mses, axis=0) - return mses + mses[b, :] += residuals**2 / F + return mses if individual_preds else np.mean(mses, axis=0) class ContextualizedBayesianNetworks(ContextualizedNetworks): """ - Contextualized Bayesian Networks and Directed Acyclic Graphs (DAGs) reveal context-dependent causal relationships, effect sizes, and variable ordering. - Uses the NOTMAD model, see the `paper `__ for detailed estimation procedures. - - Args: - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - num_archetypes (int, optional): Number of archetypes to use. Defaults to 16. Always uses archetypes in the ContextualizedMetaModel. - encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". - archetype_dag_loss_type (str, optional): The type of loss to use for the archetype loss. Defaults to "l1". - archetype_l1 (float, optional): The strength of the l1 regularization for the archetype loss. Defaults to 0.0. - archetype_dag_params (dict, optional): Parameters for the archetype loss. Defaults to {"loss_type": "l1", "params": {"alpha": 0.0, "rho": 0.0, "s": 0.0, "tol": 1e-4}}. - archetype_dag_loss_params (dict, optional): Parameters for the archetype loss. Defaults to {"alpha": 0.0, "rho": 0.0, "s": 0.0, "tol": 1e-4}. 
- archetype_alpha (float, optional): The strength of the alpha regularization for the archetype loss. Defaults to 0.0. - archetype_rho (float, optional): The strength of the rho regularization for the archetype loss. Defaults to 0.0. - archetype_s (float, optional): The strength of the s regularization for the archetype loss. Defaults to 0.0. - archetype_tol (float, optional): The tolerance for the archetype loss. Defaults to 1e-4. - archetype_use_dynamic_alpha_rho (bool, optional): Whether to use dynamic alpha and rho for the archetype loss. Defaults to False. - init_mat (np.ndarray, optional): The initial adjacency matrix for the archetype loss. Defaults to None. - num_factors (int, optional): The number of factors for the archetype loss. Defaults to 0. - factor_mat_l1 (float, optional): The strength of the l1 regularization for the factor matrix for the archetype loss. Defaults to 0. - sample_specific_dag_loss_type (str, optional): The type of loss to use for the sample-specific loss. Defaults to "l1". - sample_specific_alpha (float, optional): The strength of the alpha regularization for the sample-specific loss. Defaults to 0.0. - sample_specific_rho (float, optional): The strength of the rho regularization for the sample-specific loss. Defaults to 0.0. - sample_specific_s (float, optional): The strength of the s regularization for the sample-specific loss. Defaults to 0.0. - sample_specific_tol (float, optional): The tolerance for the sample-specific loss. Defaults to 1e-4. - sample_specific_use_dynamic_alpha_rho (bool, optional): Whether to use dynamic alpha and rho for the sample-specific loss. Defaults to False. + Contextualized Bayesian Networks (NOTMAD): context-dependent DAGs. """ def _parse_private_init_kwargs(self, **kwargs): """ - Parses the kwargs for the NOTMAD model. - - Args: - **kwargs: Keyword arguments for the NOTMAD model, including the encoder, archetype loss, sample-specific loss, and optimization parameters. 
+ Parse NOTMAD kwargs into model init dicts. """ # Encoder Parameters self._init_kwargs["model"]["encoder_kwargs"] = { @@ -288,7 +199,7 @@ def _parse_private_init_kwargs(self, **kwargs): }, } - # Archetype-specific parameters + # Archetype parameters archetype_dag_loss_type = kwargs.pop( "archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE ) @@ -309,25 +220,24 @@ def _parse_private_init_kwargs(self, **kwargs): "factor_mat_l1": kwargs.pop("factor_mat_l1", 0), "num_archetypes": kwargs.pop("num_archetypes", 16), } - if self._init_kwargs["model"]["archetype_loss_params"]["num_archetypes"] <= 0: print( "WARNING: num_archetypes is 0. NOTMAD requires archetypes. Setting num_archetypes to 16." ) self._init_kwargs["model"]["archetype_loss_params"]["num_archetypes"] = 16 - # Possibly update values with convenience parameters + # Allow convenience overrides for archetype DAG params for param, value in self._init_kwargs["model"]["archetype_loss_params"]["dag"][ "params" ].items(): self._init_kwargs["model"]["archetype_loss_params"]["dag"]["params"][ param ] = kwargs.pop(f"archetype_{param}", value) + + # Sample-specific parameters sample_specific_dag_loss_type = kwargs.pop( "sample_specific_dag_loss_type", DEFAULT_DAG_LOSS_TYPE ) - - # Sample-specific parameters self._init_kwargs["model"]["sample_specific_loss_params"] = { "l1": kwargs.pop("sample_specific_l1", 0.0), "dag": kwargs.pop( @@ -341,8 +251,6 @@ def _parse_private_init_kwargs(self, **kwargs): }, ), } - - # Possibly update values with convenience parameters for param, value in self._init_kwargs["model"]["sample_specific_loss_params"][ "dag" ]["params"].items(): @@ -402,14 +310,8 @@ def __init__(self, **kwargs): def predict_params( self, C: np.ndarray, **kwargs ) -> Union[np.ndarray, List[np.ndarray]]: - """Predicts context-specific Bayesian network parameters as linear coefficients in a linear structural equation model (SEM). 
- - Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True. + """ + Predicts context-specific Bayesian network parameters (SEM coefficients). """ # No mus for NOTMAD at present. return super().predict_params(C, model_includes_mus=False, **kwargs) @@ -417,15 +319,8 @@ def predict_params( def predict_networks( self, C: np.ndarray, project_to_dag: bool = True, **kwargs ) -> Union[np.ndarray, List[np.ndarray]]: - """Predicts context-specific Bayesian networks. - - Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - project_to_dag (bool, optional): If True, guarantees returned graphs are DAGs by trimming edges until acyclicity is satisified. Defaults to True. - **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The linear coefficients of the predicted context-specific Bayesian network parameters (n_samples, n_features, n_features). Returned as lists of individual bootstraps if individual_preds is True. + """ + Predicts context-specific Bayesian networks (optionally projected to DAG). """ if kwargs.pop("with_offsets", False): print("No offsets can be returned by NOTMAD.") @@ -437,22 +332,12 @@ def predict_networks( def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs ) -> Union[np.ndarray, List[np.ndarray]]: - """Measures mean-squared errors. 
- - Args: - C (np.ndarray): Contextual features for each sample (n_samples, n_context_features) - X (np.ndarray): The data matrix (n_samples, n_features) - individual_preds (bool, optional): If True, returns the predictions for each bootstrap. Defaults to False. - **kwargs: Keyword arguments for the contextualized.dags.GraphTrainer's predict_params method. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The mean-squared errors for each sample, or for each bootstrap if individual_preds is True (n_samples). + """ + Measures mean-squared errors of DAG-based reconstruction. """ betas = self.predict_networks(C, individual_preds=True, **kwargs) mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples - for bootstrap in range(len(betas)): - X_pred = dag_pred_np(X, betas[bootstrap]) - mses[bootstrap, :] = np.mean((X - X_pred) ** 2, axis=1) - if not individual_preds: - mses = np.mean(mses, axis=0) - return mses + for b in range(len(betas)): + X_pred = dag_pred_np(X, betas[b]) + mses[b, :] = np.mean((X - X_pred) ** 2, axis=1) + return mses if individual_preds else np.mean(mses, axis=0) diff --git a/contextualized/easy/ContextualizedRegressor.py b/contextualized/easy/ContextualizedRegressor.py index 275f2ee9..2ac98b3c 100644 --- a/contextualized/easy/ContextualizedRegressor.py +++ b/contextualized/easy/ContextualizedRegressor.py @@ -7,10 +7,7 @@ ContextualizedRegression, ) from contextualized.easy.wrappers import SKLearnWrapper -from contextualized.regression import RegressionTrainer - -# TODO: Multitask metamodels -# TODO: Task-specific link functions. +from contextualized.regression.trainers import RegressionTrainer # <-- updated import class ContextualizedRegressor(SKLearnWrapper): @@ -41,9 +38,11 @@ def __init__(self, **kwargs): archetypes, but this should be a non-negative integer.""" ) + # Wrapper will accept these; no need to expose DataModule specifics here. 
extra_model_kwargs = ["base_param_predictor", "base_y_predictor", "y_dim"] extra_data_kwargs = ["Y_val"] trainer_constructor = RegressionTrainer + super().__init__( constructor, extra_model_kwargs, @@ -52,5 +51,6 @@ def __init__(self, **kwargs): **kwargs, ) + # Preserve legacy behavior that Y is expected/required for regression fits def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs): return super()._split_train_data(C, X, Y, Y_required=True, **kwargs) diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index 101966e8..2e6e1759 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -1,20 +1,19 @@ -""" -An sklearn-like wrapper for Contextualized models. -""" - +# --- imports you need above the class --- import copy import os from typing import * - import numpy as np -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks import ModelCheckpoint +import torch from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler -import torch +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins.environments import LightningEnvironment +from pytorch_lightning.strategies import DDPStrategy # PL v1 Strategy API from contextualized.functions import LINK_FUNCTIONS from contextualized.regression import REGULARIZERS, LOSSES +from contextualized.regression.datamodules import ContextualizedRegressionDataModule DEFAULT_LEARNING_RATE = 1e-3 DEFAULT_N_BOOTSTRAPS = 1 @@ -32,23 +31,10 @@ class SKLearnWrapper: """ - An sklearn-like wrapper for Contextualized models. - - Args: - base_constructor (class): The base class to construct the model. - extra_model_kwargs (dict): Extra kwargs to pass to the model constructor. 
- extra_data_kwargs (dict): Extra kwargs to pass to the dataloader constructor. - trainer_constructor (class): The trainer class to use. - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". - loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"]. - link_fn (torch.nn.Module, optional): Link function. Defaults to LINK_FUNCTIONS["identity"]. - alpha (float, optional): Regularization strength. Defaults to 0.0. - mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. - l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. - normalize (bool, optional): If True, automatically standardize inputs during training and inverse-transform predictions. Defaults to False. + An sklearn-like wrapper for Contextualized models, optimized for multi-GPU (DDP) scaling. 
""" + # ---------- defaults ---------- def _set_defaults(self): self.default_learning_rate = DEFAULT_LEARNING_RATE self.default_n_bootstraps = DEFAULT_N_BOOTSTRAPS @@ -73,17 +59,20 @@ def __init__( ): self._set_defaults() self.base_constructor = base_constructor + self.trainer_constructor = trainer_constructor + self.n_bootstraps = 1 self.models = None self.trainers = None - self.dataloaders = None + self.normalize = kwargs.pop("normalize", self.default_normalize) self.scalers = {"C": None, "X": None, "Y": None} self.context_dim = None self.x_dim = None self.y_dim = None - self.trainer_constructor = trainer_constructor self.accelerator = "gpu" if torch.cuda.is_available() else "cpu" + + # Acceptable kwargs routing self.acceptable_kwargs = { "data": [ "train_batch_size", @@ -92,6 +81,13 @@ def __init__( "C_val", "X_val", "val_split", + "num_workers", + "pin_memory", + "persistent_workers", + "drop_last", + "shuffle_train", + "shuffle_eval", + "dtype", ], "model": [ "loss_fn", @@ -104,6 +100,7 @@ def __init__( "learning_rate", "context_dim", "x_dim", + "y_dim", ], "trainer": [ "max_epochs", @@ -112,6 +109,17 @@ def __init__( "callbacks", "callback_constructors", "accelerator", + "devices", + "strategy", + "plugins", + "logger", + "enable_checkpointing", + "num_sanity_val_steps", + "default_root_dir", + "log_every_n_steps", + "precision", # allow explicit precision override if desired + "enable_progress_bar", + "limit_val_batches", ], "fit": [], "wrapper": [ @@ -132,6 +140,7 @@ def __init__( self._update_acceptable_kwargs( "data", kwargs.pop("remove_data_kwargs", []), acceptable=False ) + self.convenience_kwargs = [ "alpha", "l1_ratio", @@ -141,6 +150,8 @@ def __init__( "layers", "encoder_link_fn", ] + + # Model constructor kwargs self.constructor_kwargs = self._organize_constructor_kwargs(**kwargs) self.constructor_kwargs["encoder_kwargs"]["width"] = kwargs.pop( "width", self.constructor_kwargs["encoder_kwargs"]["width"] @@ -159,106 +170,18 @@ def __init__( for k, v 
in kwargs.items() if k not in self.constructor_kwargs and k not in self.convenience_kwargs } - # Some args will not be ignored by wrapper because sub-class will handle them. - # self.private_kwargs = kwargs.pop("private_kwargs", []) - # self.private_kwargs.append("private_kwargs") - # Add Predictor-Specific kwargs for parsing. - self._init_kwargs, unrecognized_general_kwargs = self._organize_kwargs( - **self.not_constructor_kwargs - ) - for key, value in self.constructor_kwargs.items(): - self._init_kwargs["model"][key] = value - recognized_private_init_kwargs = self._parse_private_init_kwargs(**kwargs) - for kwarg in set(unrecognized_general_kwargs) - set( - recognized_private_init_kwargs - ): - print(f"Received unknown keyword argument {kwarg}, probably ignoring.") - - def _organize_and_expand_fit_kwargs(self, **kwargs): - """ - Private function to organize kwargs passed to constructor or - fit function. - """ - organized_kwargs, unrecognized_general_kwargs = self._organize_kwargs(**kwargs) - recognized_private_kwargs = self._parse_private_fit_kwargs(**kwargs) - for kwarg in set(unrecognized_general_kwargs) - set(recognized_private_kwargs): - print(f"Received unknown keyword argument {kwarg}, probably ignoring.") - # Add kwargs from __init__ to organized_kwargs, keeping more recent kwargs. - for category, category_kwargs in self._init_kwargs.items(): - for key, value in category_kwargs.items(): - if key not in organized_kwargs[category]: - organized_kwargs[category][key] = value - - # Add necessary kwargs. 
- def maybe_add_kwarg(category, kwarg, default_val): - if kwarg in self.acceptable_kwargs[category]: - organized_kwargs[category][kwarg] = organized_kwargs[category].get( - kwarg, default_val - ) - # Model - maybe_add_kwarg("model", "learning_rate", self.default_learning_rate) - maybe_add_kwarg("model", "context_dim", self.context_dim) - maybe_add_kwarg("model", "x_dim", self.x_dim) - maybe_add_kwarg("model", "y_dim", self.y_dim) - if ( - "num_archetypes" in organized_kwargs["model"] - and organized_kwargs["model"]["num_archetypes"] == 0 - ): - del organized_kwargs["model"]["num_archetypes"] - - # Data - maybe_add_kwarg("data", "train_batch_size", self.default_train_batch_size) - maybe_add_kwarg("data", "val_batch_size", self.default_val_batch_size) - maybe_add_kwarg("data", "test_batch_size", self.default_test_batch_size) - - # Wrapper - maybe_add_kwarg("wrapper", "n_bootstraps", self.default_n_bootstraps) - - # Trainer - maybe_add_kwarg( - "trainer", - "callback_constructors", - [ - lambda i: EarlyStopping( - monitor=kwargs.get("es_monitor", "val_loss"), - mode=kwargs.get("es_mode", "min"), - patience=kwargs.get("es_patience", self.default_es_patience), - verbose=kwargs.get("es_verbose", False), - min_delta=kwargs.get("es_min_delta", 0.00), - ) - ], - ) - organized_kwargs["trainer"]["callback_constructors"].append( - lambda i: ModelCheckpoint( - monitor=kwargs.get("es_monitor", "val_loss"), - dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", - filename="{epoch}-{val_loss:.2f}", - ) + self._init_kwargs, unrecognized = self._organize_kwargs( + **self.not_constructor_kwargs ) - maybe_add_kwarg("trainer", "accelerator", self.accelerator) - return organized_kwargs - - def _parse_private_fit_kwargs(self, **kwargs): - """ - Parse private (model-specific) kwargs passed to fit function. - Return the list of parsed kwargs. 
- """ - return [] - - def _parse_private_init_kwargs(self, **kwargs): - """ - Parse private (model-specific) kwargs passed to constructor. - Return the list of parsed kwargs. - """ - return [] + for k, v in self.constructor_kwargs.items(): + self._init_kwargs["model"][k] = v + if unrecognized: + for kw in unrecognized: + print(f"Received unknown keyword argument {kw}, probably ignoring.") + # ---------- helpers ---------- def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True): - """ - Helper function to update the acceptable kwargs. - If acceptable=True, the new kwargs will be added to the list of acceptable kwargs. - If acceptable=False, the new kwargs will be removed from the list of acceptable kwargs. - """ if acceptable: self.acceptable_kwargs[category] = list( set(self.acceptable_kwargs[category]).union(set(new_kwargs)) @@ -269,139 +192,256 @@ def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True): ) def _organize_kwargs(self, **kwargs): - """ - Private helper function to organize kwargs passed to constructor or - fit function. - Organizes kwargs into data, model, trainer, fit, and wrapper categories. 
- """ - - # Combine default allowed keywords with subclass-specfic - organized_kwargs = {category: {} for category in self.acceptable_kwargs} - unrecognized_kwargs = [] - for kwarg, value in kwargs.items(): - # if kwarg in self.private_kwargs: - # continue - not_found = True - for category, category_kwargs in self.acceptable_kwargs.items(): - if kwarg in category_kwargs: - organized_kwargs[category][kwarg] = value - not_found = False + out = {cat: {} for cat in self.acceptable_kwargs} + unknown = [] + for k, v in kwargs.items(): + placed = False + for cat, allowed in self.acceptable_kwargs.items(): + if k in allowed: + out[cat][k] = v + placed = True break - if not_found: - unrecognized_kwargs.append(kwarg) - - return organized_kwargs, unrecognized_kwargs + if not placed: + unknown.append(k) + return out, unknown def _organize_constructor_kwargs(self, **kwargs): - """ - Helper function to set all the default constructor or changes allowed. - """ - constructor_kwargs = {} + model = {} - def maybe_add_constructor_kwarg(kwarg, default_val): - if kwarg in self.acceptable_kwargs["model"]: - constructor_kwargs[kwarg] = kwargs.get(kwarg, default_val) + def maybe_add(kw, default_val): + if kw in self.acceptable_kwargs["model"]: + model[kw] = kwargs.get(kw, default_val) - maybe_add_constructor_kwarg("link_fn", LINK_FUNCTIONS["identity"]) - maybe_add_constructor_kwarg("univariate", False) - maybe_add_constructor_kwarg("encoder_type", self.default_encoder_type) - maybe_add_constructor_kwarg("loss_fn", LOSSES["mse"]) - maybe_add_constructor_kwarg( + maybe_add("link_fn", LINK_FUNCTIONS["identity"]) + maybe_add("univariate", False) + maybe_add("encoder_type", DEFAULT_ENCODER_TYPE) + maybe_add("loss_fn", LOSSES["mse"]) + maybe_add( "encoder_kwargs", { - "width": kwargs.get("encoder_width", self.default_encoder_width), - "layers": kwargs.get("encoder_layers", self.default_encoder_layers), - "link_fn": kwargs.get("encoder_link_fn", self.default_encoder_link_fn), + "width": 
kwargs.get("encoder_width", DEFAULT_ENCODER_WIDTH), + "layers": kwargs.get("encoder_layers", DEFAULT_ENCODER_LAYERS), + "link_fn": kwargs.get("encoder_link_fn", DEFAULT_ENCODER_LINK_FN), }, ) if kwargs.get("subtype_probabilities", False): - constructor_kwargs["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"] + model["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"] - # Make regularizer + # Regularizer if "model_regularizer" in self.acceptable_kwargs["model"]: - if "alpha" in kwargs and kwargs["alpha"] > 0: - constructor_kwargs["model_regularizer"] = REGULARIZERS["l1_l2"]( + if kwargs.get("alpha", 0) > 0: + model["model_regularizer"] = REGULARIZERS["l1_l2"]( kwargs["alpha"], kwargs.get("l1_ratio", 1.0), kwargs.get("mu_ratio", 0.5), ) else: - constructor_kwargs["model_regularizer"] = kwargs.get( + model["model_regularizer"] = kwargs.get( "model_regularizer", REGULARIZERS["none"] ) - return constructor_kwargs - - def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs): - if "C_val" in kwargs: - if "X_val" in kwargs: - if Y_required and "Y_val" in kwargs: - train_data = [C, X, Y] - val_data = [kwargs["C_val"], X, kwargs["X_val"], Y, kwargs["Y_val"]] - return train_data, val_data - print("Y_val not provided, not using the provided C_val or X_val.") - else: - print("X_val not provided, not using the provided C_val.") - if "val_split" in kwargs: - if 0 <= kwargs["val_split"] < 1: - val_split = kwargs["val_split"] - else: - print( - """val_split={kwargs['val_split']} provided but should be between 0 - and 1 to indicate proportion of data to use as validation.""" - ) - raise ValueError + return model + + # ---------- internal: sanitize callbacks when no val loop ---------- + @staticmethod + def _retarget_or_strip_early_stopping(cb, use_val: bool, train_monitor="train_loss"): + try: + from pytorch_lightning.callbacks.early_stopping import EarlyStopping + except Exception: + return cb + if not isinstance(cb, EarlyStopping): + return cb + if 
use_val: + return cb + # No val loop -> if monitoring val_* (or nothing), rebuild to watch train_loss + monitor = getattr(cb, "monitor", None) + if (monitor is None) or (isinstance(monitor, str) and monitor.startswith("val_")): + return EarlyStopping( + monitor=train_monitor, + mode=getattr(cb, "mode", "min"), + patience=getattr(cb, "patience", 1), + verbose=getattr(cb, "verbose", False), + min_delta=getattr(cb, "min_delta", 0.0), + ) + return cb + + # ---------- fit kwarg expansion (with DDP + ES logic) ---------- + def _organize_and_expand_fit_kwargs(self, **kwargs): + organized, unrecognized = self._organize_kwargs(**kwargs) + # --- FORCE max_epochs to be set (avoid PL default=1000) --- + max_epochs_cli = kwargs.get("max_epochs", None) + epochs_cli = kwargs.get("epochs", None) + if max_epochs_cli is not None: + organized["trainer"]["max_epochs"] = int(max_epochs_cli) + elif epochs_cli is not None: + organized["trainer"]["max_epochs"] = int(epochs_cli) else: - val_split = self.default_val_split - if Y is None: - if val_split > 0: - C_train, C_val, X_train, X_val = train_test_split( - C, X, test_size=val_split, shuffle=True + organized["trainer"]["max_epochs"] = 3 + + world_size = int(os.getenv("WORLD_SIZE", "1")) + use_val = organized["data"].get("val_split", DEFAULT_VAL_SPLIT) > 0.0 + + # Trainer base + organized["trainer"].setdefault("accelerator", self.accelerator) + organized["trainer"].setdefault("enable_progress_bar", False) + organized["trainer"].setdefault("logger", False) + organized["trainer"].setdefault("enable_checkpointing", False) + organized["trainer"].setdefault("num_sanity_val_steps", 0) + # conservative precision by default (TF32 is controlled globally) + organized["trainer"].setdefault("precision", 32) + if not use_val: + organized["trainer"].setdefault("limit_val_batches", 0) + + if world_size > 1: + organized["trainer"].setdefault("devices", world_size) + strat = organized["trainer"].get("strategy", "auto") + if strat == "auto" or 
isinstance(strat, str): + organized["trainer"]["strategy"] = DDPStrategy( + find_unused_parameters=False, + static_graph=True, + gradient_as_bucket_view=True, ) - else: - C_train, X_train = C, X - C_val, X_val = C, X - train_data = [C_train, X_train] - val_data = [C_val, X_val] else: - if val_split > 0: - C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split( - C, X, Y, test_size=val_split, shuffle=True - ) - else: - C_train, X_train, Y_train = C, X, Y - C_val, X_val, Y_val = C, X, Y - train_data = [C_train, X_train, Y_train] - val_data = [C_val, X_val, Y_val] - return train_data, val_data + organized["trainer"]["devices"] = 1 + organized["trainer"].setdefault("strategy", "auto") + organized["trainer"].setdefault("plugins", [LightningEnvironment()]) - def _build_dataloader(self, model, batch_size, *data): - """ - Helper function to build a single dataloder. - Expects *args to contain whatever data (C,X,Y) is necessary for this model. - """ - return model.dataloader(*data, batch_size=batch_size) + # Defaults: model/data + def maybe_add(cat, k, default): + if k in self.acceptable_kwargs[cat]: + organized[cat][k] = organized[cat].get(k, default) - def _build_dataloaders(self, model, train_data, val_data, **kwargs): - """ - :param model: - :param **kwargs: - """ - train_dataloader = self._build_dataloader( - model, - kwargs.get("train_batch_size", self.default_train_batch_size), - *train_data, - ) - if val_data is None: - val_dataloader = None - else: - val_dataloader = self._build_dataloader( - model, - kwargs.get("val_batch_size", self.default_val_batch_size), - *val_data, - ) - - return train_dataloader, val_dataloader + # Model + maybe_add("model", "learning_rate", DEFAULT_LEARNING_RATE) + maybe_add("model", "context_dim", self.context_dim) + maybe_add("model", "x_dim", self.x_dim) + maybe_add("model", "y_dim", self.y_dim) + if organized["model"].get("num_archetypes", 1) == 0: + organized["model"].pop("num_archetypes", None) + + # Data (GPU-friendly) + 
maybe_add("data", "train_batch_size", DEFAULT_TRAIN_BATCH_SIZE) + maybe_add("data", "val_batch_size", DEFAULT_VAL_BATCH_SIZE) + maybe_add("data", "test_batch_size", DEFAULT_TEST_BATCH_SIZE) + maybe_add("data", "num_workers", 0) + maybe_add("data", "pin_memory", (self.accelerator == "gpu")) + maybe_add("data", "persistent_workers", False) + maybe_add("data", "drop_last", False) + maybe_add("data", "shuffle_train", True) + maybe_add("data", "shuffle_eval", False) + maybe_add("data", "dtype", torch.float) + # Wrapper + maybe_add("wrapper", "n_bootstraps", DEFAULT_N_BOOTSTRAPS) + + # Callbacks (EarlyStopping only if we validate; else watch train_loss) + es_monitor = organized["wrapper"].get("es_monitor", "val_loss" if use_val else "train_loss") + es_mode = organized["wrapper"].get("es_mode", "min") + es_patience = organized["wrapper"].get("es_patience", DEFAULT_ES_PATIENCE) + es_verbose = organized["wrapper"].get("es_verbose", False) + es_min_delta = organized["wrapper"].get("es_min_delta", 0.0) + + callbacks_list = [] + if use_val: + callbacks_list.append( + lambda i: EarlyStopping( + monitor=es_monitor, mode=es_mode, patience=es_patience, + verbose=es_verbose, min_delta=es_min_delta + ) + ) + if organized["trainer"].get("enable_checkpointing", False): + callbacks_list.append( + lambda i: ModelCheckpoint( + monitor="val_loss" if use_val else None, + dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", + filename="{epoch}-{val_loss:.4f}" if use_val else "{epoch}", + ) + ) + organized["trainer"].setdefault("callback_constructors", callbacks_list) + + if unrecognized: + for kw in unrecognized: + print(f"Received unknown keyword argument {kw}, probably ignoring.") + + # ---- merge constructor-time defaults as fallbacks ---- + for category, cat_kwargs in self._init_kwargs.items(): + for k, v in cat_kwargs.items(): + organized[category].setdefault(k, v) + + # ---- sanitize any pre-specified callbacks for no-val runs ---- + # (handles both 
direct 'callbacks' and deferred 'callback_constructors') + cb_list = organized["trainer"].get("callbacks", []) + cb_list = [self._retarget_or_strip_early_stopping(cb, use_val) for cb in cb_list] + organized["trainer"]["callbacks"] = cb_list + + ctor_list = organized["trainer"].get("callback_constructors", []) + def _wrap_ctor(ctor): + def _wrapped(i): + cb = ctor(i) + return self._retarget_or_strip_early_stopping(cb, use_val) + return _wrapped + ctor_list = [_wrap_ctor(c) for c in ctor_list] + organized["trainer"]["callback_constructors"] = ctor_list + + return organized + + # ---------- data module builder ---------- + def _build_datamodule( + self, + C: np.ndarray, + X: np.ndarray, + Y: Optional[np.ndarray], + *, + train_idx=None, + val_idx=None, + test_idx=None, + predict_idx=None, + data_kwargs: Optional[dict] = None, + task_type: str = "singletask_multivariate", + ) -> ContextualizedRegressionDataModule: + dk = dict( + batch_size=self.default_train_batch_size, + num_workers=0, + pin_memory=(self.accelerator == "gpu"), + persistent_workers=False, # caller can override + drop_last=False, + shuffle_train=True, + shuffle_eval=False, + dtype=torch.float, + ) + if data_kwargs: + dk.update(data_kwargs) + + dm = ContextualizedRegressionDataModule( + C=C, + X=X, + Y=Y, + task_type=task_type, + train_idx=train_idx, + val_idx=val_idx, + test_idx=test_idx, + predict_idx=predict_idx, + batch_size=dk["batch_size"], + num_workers=dk["num_workers"], + pin_memory=dk["pin_memory"], + persistent_workers=dk["persistent_workers"], + drop_last=dk["drop_last"], + shuffle_train=dk["shuffle_train"], + shuffle_eval=dk["shuffle_eval"], + dtype=dk["dtype"], + ) + dm.prepare_data() + dm.setup() + return dm + + # ---------- split helper ---------- + def _split_indices(self, n: int, val_split: float): + if val_split <= 0.0: + idx = np.arange(n) + return idx, None + tr_idx, va_idx = train_test_split(np.arange(n), test_size=val_split, shuffle=True) + return tr_idx, va_idx + + # ---------- 
optional scaling ---------- def _maybe_scale_C(self, C: np.ndarray) -> np.ndarray: if self.normalize and self.scalers["C"] is not None: return self.scalers["C"].transform(C) @@ -412,47 +452,44 @@ def _maybe_scale_X(self, X: np.ndarray) -> np.ndarray: return self.scalers["X"].transform(X) return X - def predict( - self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs - ) -> Union[np.ndarray, List[np.ndarray]]: - """Predict outcomes from context C and predictors X. - - Args: - C (np.ndarray): Context array of shape (n_samples, n_context_features) - X (np.ndarray): Predictor array of shape (N, n_features) - individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False. - - Returns: - Union[np.ndarray, List[np.ndarray]]: The outcomes predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True. - """ + # ---------- public API ---------- + def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs): if not hasattr(self, "models") or self.models is None: - raise ValueError( - "Trying to predict with a model that hasn't been trained yet." 
+ raise ValueError("Trying to predict with a model that hasn't been trained yet.") + + Cq = self._maybe_scale_C(C) + Xq = self._maybe_scale_X(X) + Yq = np.zeros((len(Cq), self.y_dim), dtype=np.float32) + + preds = [] + for i in range(len(self.models)): + dm = self._build_datamodule( + C=Cq, X=Xq, Y=Yq, + predict_idx=np.arange(len(Cq)), + data_kwargs=dict( + batch_size=self._init_kwargs["data"].get("val_batch_size", DEFAULT_VAL_BATCH_SIZE), + num_workers=self._init_kwargs["data"].get("num_workers", 0), + pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), + persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + shuffle_train=False, + shuffle_eval=False, + dtype=self._init_kwargs["data"].get("dtype", torch.float), + ), + task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) + else "singletask_multivariate", ) - predictions = np.array( - [ - self.trainers[i].predict_y( - self.models[i], - self.models[i].dataloader( - self._maybe_scale_C(C), - self._maybe_scale_X(X), - np.zeros((len(C), self.y_dim)), - ), - **kwargs, - ) - for i in range(len(self.models)) - ] - ) - if individual_preds: - preds = predictions - else: - preds = np.mean(predictions, axis=0) + yhat = self.trainers[i].predict_y(self.models[i], dm.predict_dataloader(), **kwargs) + preds.append(yhat) + + predictions = np.array(preds) + if not individual_preds: + predictions = np.mean(predictions, axis=0) if self.normalize and self.scalers["Y"] is not None: if individual_preds: - preds = np.array([self.scalers["Y"].inverse_transform(p) for p in preds]) + predictions = np.array([self.scalers["Y"].inverse_transform(p) for p in predictions]) else: - preds = self.scalers["Y"].inverse_transform(preds) - return preds + predictions = self.scalers["Y"].inverse_transform(predictions) + return predictions def predict_params( self, @@ -460,143 +497,152 @@ def predict_params( individual_preds: bool = False, model_includes_mus: 
bool = True, **kwargs, - ) -> Union[ - np.ndarray, - List[np.ndarray], - Tuple[np.ndarray, np.ndarray], - Tuple[List[np.ndarray], List[np.ndarray]], - ]: - """ - Predict context-specific model parameters from context C. - - Args: - C (np.ndarray): Context array of shape (n_samples, n_context_features) - individual_preds (bool, optional): Whether to return individual model predictions for each bootstrap. Defaults to False, averaging across bootstraps. - model_includes_mus (bool, optional): Whether the model includes context-specific offsets (mu). Defaults to True. - - Returns: - Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]]: The parameters of the predicted context-specific models. - Returned as lists of individual bootstraps if individual_preds is True, otherwise averages the bootstraps for a better estimate. - If model_includes_mus is True, returns both coefficients and offsets as a tuple of (betas, mus). Otherwise, returns coefficients (betas) only. - For model_includes_mus=True, ([betas], [mus]) if individual_preds is True, otherwise (betas, mus). - For model_includes_mus=False, [betas] if individual_preds is True, otherwise betas. - betas is shape (n_samples, x_dim, y_dim) or (n_samples, x_dim) if y_dim = 1. - mus is shape (n_samples, y_dim) or (n_samples,) if y_dim = 1. 
- """ - # Returns betas, mus - if kwargs.pop("uses_y", True): - get_dataloader = lambda i: self.models[i].dataloader( - self._maybe_scale_C(C), - np.zeros((len(C), self.x_dim)), - np.zeros((len(C), self.y_dim)) - ) - else: - get_dataloader = lambda i: self.models[i].dataloader( - self._maybe_scale_C(C), - np.zeros((len(C), self.x_dim)) + ): + if not hasattr(self, "models") or self.models is None: + raise ValueError("Trying to predict with a model that hasn't been trained yet.") + + Cq = self._maybe_scale_C(C) + X_zero = np.zeros((len(Cq), self.x_dim), dtype=np.float32) + Y_zero = np.zeros((len(Cq), self.y_dim), dtype=np.float32) + + out_betas, out_mus = [], [] + for i in range(len(self.models)): + dm = self._build_datamodule( + C=Cq, + X=X_zero, + Y=Y_zero if kwargs.pop("uses_y", True) else None, + predict_idx=np.arange(len(Cq)), + data_kwargs=dict( + batch_size=self._init_kwargs["data"].get("val_batch_size", DEFAULT_VAL_BATCH_SIZE), + num_workers=self._init_kwargs["data"].get("num_workers", 0), + pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), + persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + shuffle_train=False, + shuffle_eval=False, + dtype=self._init_kwargs["data"].get("dtype", torch.float), + ), + task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) + else "singletask_multivariate", ) - predictions = [ - self.trainers[i].predict_params(self.models[i], get_dataloader(i), **kwargs) - for i in range(len(self.models)) - ] - if model_includes_mus: - betas = np.array([p[0] for p in predictions]) - mus = np.array([p[1] for p in predictions]) - if individual_preds: - return betas, mus + pred = self.trainers[i].predict_params(self.models[i], dm.predict_dataloader(), **kwargs) + if model_includes_mus: + out_betas.append(pred[0]); out_mus.append(pred[1]) else: - return np.mean(betas, axis=0), np.mean(mus, axis=0) - betas = np.array(predictions) - if not 
individual_preds: - return np.mean(betas, axis=0) - return betas + out_betas.append(pred) + + if model_includes_mus: + betas = np.array(out_betas); mus = np.array(out_mus) + return (betas, mus) if individual_preds else (np.mean(betas, axis=0), np.mean(mus, axis=0)) + else: + betas = np.array(out_betas) + return betas if individual_preds else np.mean(betas, axis=0) def fit(self, *args, **kwargs) -> None: """ Fit contextualized model to data. - Args: - C (np.ndarray): Context array of shape (n_samples, n_context_features) - X (np.ndarray): Predictor array of shape (N, n_features) - Y (np.ndarray, optional): Target array of shape (N, n_targets). Defaults to None, where X will be used as targets such as in Contextualized Networks. - max_epochs (int, optional): Maximum number of epochs to train for. Defaults to 1. - learning_rate (float, optional): Learning rate for optimizer. Defaults to 1e-3. - val_split (float, optional): Proportion of data to use for validation and early stopping. Defaults to 0.2. - n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. - train_batch_size (int, optional): Batch size for training. Defaults to 1. - val_batch_size (int, optional): Batch size for validation. Defaults to 16. - test_batch_size (int, optional): Batch size for testing. Defaults to 16. - es_patience (int, optional): Number of epochs to wait before early stopping. Defaults to 1. - es_monitor (str, optional): Metric to monitor for early stopping. Defaults to "val_loss". - es_mode (str, optional): Mode for early stopping. Defaults to "min". - es_verbose (bool, optional): Whether to print early stopping updates. Defaults to False. 
+ C (np.ndarray): (n, c_dim) + X (np.ndarray): (n, x_dim) + Y (np.ndarray, optional): (n, y_dim) """ - self.models = [] - self.trainers = [] - self.dataloaders = {"train": [], "val": [], "test": []} + self.models, self.trainers = [], [] + C, X = args[0], args[1] if self.normalize: - if self.scalers["C"] is None: - self.scalers["C"] = StandardScaler().fit(C) + if self.scalers["C"] is None: self.scalers["C"] = StandardScaler().fit(C) C = self.scalers["C"].transform(C) - if self.scalers["X"] is None: - self.scalers["X"] = StandardScaler().fit(X) + if self.scalers["X"] is None: self.scalers["X"] = StandardScaler().fit(X) X = self.scalers["X"].transform(X) - self.context_dim = C.shape[-1] - self.x_dim = X.shape[-1] + self.context_dim = C.shape[-1]; self.x_dim = X.shape[-1] + if len(args) == 3: Y = args[2] - if kwargs.get("Y", None) is not None: - Y = kwargs.get("Y") - if len(Y.shape) == 1: # add feature dimension to Y if not given. - Y = np.expand_dims(Y, 1) + if kwargs.get("Y", None) is not None: Y = kwargs.get("Y") + if len(Y.shape) == 1: Y = np.expand_dims(Y, 1) if self.normalize and not np.array_equal(np.unique(Y), np.array([0, 1])): - if self.scalers["Y"] is None: - self.scalers["Y"] = StandardScaler().fit(Y) + if self.scalers["Y"] is None: self.scalers["Y"] = StandardScaler().fit(Y) Y = self.scalers["Y"].transform(Y) self.y_dim = Y.shape[-1] args = (C, X, Y) else: self.y_dim = self.x_dim args = (C, X) - organized_kwargs = self._organize_and_expand_fit_kwargs(**kwargs) - self.n_bootstraps = organized_kwargs["wrapper"].get( - "n_bootstraps", self.n_bootstraps - ) - for bootstrap in range(self.n_bootstraps): - model = self.base_constructor(**organized_kwargs["model"]) - train_data, val_data = self._split_train_data( - *args, **organized_kwargs["data"] - ) - train_dataloader, val_dataloader = self._build_dataloaders( - model, - train_data, - val_data, - **organized_kwargs["data"], + + organized = self._organize_and_expand_fit_kwargs(**kwargs) + self.n_bootstraps = 
organized["wrapper"].get("n_bootstraps", self.n_bootstraps) + + n = C.shape[0] + val_split = organized["data"].get("val_split", DEFAULT_VAL_SPLIT) + use_val = val_split > 0.0 + + for b in range(self.n_bootstraps): + # Build model (LightningModule) + _model_kwargs = dict(organized["model"]) + _model_kwargs.pop("univariate", None) + model = self.base_constructor(**_model_kwargs) + self.model_ = model + + # Indices + train_idx, val_idx = self._split_indices(n, val_split) + test_idx = None + + # DataModule + task_type = "singletask_univariate" if organized["model"].get("univariate", False) else "singletask_multivariate" + dm = self._build_datamodule( + C=args[0], X=args[1], Y=(args[2] if len(args) == 3 else None), + train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, + data_kwargs=dict( + batch_size=organized["data"].get("train_batch_size", DEFAULT_TRAIN_BATCH_SIZE), + num_workers=organized["data"].get("num_workers", 0), + pin_memory=organized["data"].get("pin_memory", (self.accelerator == "gpu")), + persistent_workers=organized["data"].get("persistent_workers", False), + drop_last=organized["data"].get("drop_last", False), + shuffle_train=organized["data"].get("shuffle_train", True), + shuffle_eval=organized["data"].get("shuffle_eval", False), + dtype=organized["data"].get("dtype", torch.float), + ), + task_type=task_type, ) - # Makes a new trainer for each bootstrap fit - bad practice, but necessary here. - my_trainer_kwargs = copy.deepcopy(organized_kwargs["trainer"]) - # Must reconstruct the callbacks because they save state from fitting trajectories. 
- my_trainer_kwargs["callbacks"] = [ - f(bootstrap) - for f in organized_kwargs["trainer"]["callback_constructors"] - ] - del my_trainer_kwargs["callback_constructors"] - trainer = self.trainer_constructor( - **my_trainer_kwargs, enable_progress_bar=False + + # Trainer (fresh callbacks) + trainer_kwargs = copy.deepcopy(organized["trainer"]) + trainer_kwargs["callbacks"] = [f(b) for f in trainer_kwargs.get("callback_constructors", [])] + trainer_kwargs.pop("callback_constructors", None) + + # Build via factory (handles env quirks) + from contextualized.regression.trainers import make_trainer_with_env + trainer = make_trainer_with_env( + self.trainer_constructor, + **trainer_kwargs, ) - checkpoint_callback = my_trainer_kwargs["callbacks"][1] - os.makedirs(checkpoint_callback.dirpath, exist_ok=True) - try: + + # Ensure checkpoint dir if used + for cb in trainer_kwargs.get("callbacks", []): + if isinstance(cb, ModelCheckpoint): + os.makedirs(cb.dirpath, exist_ok=True) + + # Fit (don’t pass val loader if no val split) + if use_val and dm.val_dataloader() is not None: trainer.fit( - model, train_dataloader, val_dataloader, **organized_kwargs["fit"] + model, + train_dataloaders=dm.train_dataloader(), + val_dataloaders=dm.val_dataloader(), + **organized["fit"], ) - except: - trainer.fit(model, train_dataloader, **organized_kwargs["fit"]) - if kwargs.get("max_epochs", 1) > 0: - best_checkpoint = torch.load(checkpoint_callback.best_model_path) - model.load_state_dict(best_checkpoint["state_dict"]) - self.dataloaders["train"].append(train_dataloader) - self.dataloaders["val"].append(val_dataloader) + else: + trainer.fit( + model, + train_dataloaders=dm.train_dataloader(), + **organized["fit"], + ) + + # (Optional) load best ckpt if checkpointing enabled + max_epochs = trainer_kwargs.get("max_epochs", 1) + if max_epochs and trainer_kwargs.get("enable_checkpointing", False): + ckpt_cb = next((cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), None) + if ckpt_cb 
and ckpt_cb.best_model_path and os.path.exists(ckpt_cb.best_model_path): + best = torch.load(ckpt_cb.best_model_path, map_location="cpu") + model.load_state_dict(best["state_dict"]) + self.models.append(model) self.trainers.append(trainer) diff --git a/contextualized/modules.py b/contextualized/modules.py index 96d48b07..e3e87ba3 100644 --- a/contextualized/modules.py +++ b/contextualized/modules.py @@ -67,6 +67,25 @@ class Explainer(SoftSelect): def __init__(self, k, out_shape): super().__init__((k,), out_shape) +def _resolve_link_fn(maybe_link): + """ + Accepts either: + - a string key (looked up in LINK_FUNCTIONS), or + - a callable (returned as-is, including functools.partial) + """ + if isinstance(maybe_link, str): + try: + return LINK_FUNCTIONS[maybe_link] + except KeyError as e: + raise KeyError( + f"Unknown link_fn '{maybe_link}'. " + f"Valid options: {list(LINK_FUNCTIONS.keys())}" + ) from e + if callable(maybe_link): + return maybe_link + raise TypeError(f"link_fn must be str or callable, got {type(maybe_link).__name__}") + + class MLP(nn.Module): """ @@ -91,7 +110,8 @@ def __init__( else: # Linear encoder mlp_layers = [nn.Linear(input_dim, output_dim)] self.mlp = nn.Sequential(*mlp_layers) - self.link_fn = LINK_FUNCTIONS[link_fn] + self.link_fn = _resolve_link_fn(link_fn) + def forward(self, X): """Torch Forward pass.""" @@ -101,7 +121,9 @@ def forward(self, X): class NGAM(nn.Module): """ - Neural generalized additive model + Neural generalized additive model: sum_i f_i(x_i). + Each f_i is an MLP that outputs (B, output_dim). + The final link function is applied once to the summed output. """ def __init__( @@ -114,8 +136,12 @@ def __init__( link_fn="identity", ): super().__init__() - self.intput_dim = input_dim + self.input_dim = input_dim self.output_dim = output_dim + + # Internal NAM pieces should be identity-linked; the global link is applied after summation. 
+ per_feat_link = "identity" + self.nams = nn.ModuleList( [ MLP( @@ -124,21 +150,22 @@ def __init__( width, layers, activation=activation, - link_fn=identity_link, + link_fn=per_feat_link, ) for _ in range(input_dim) ] ) - self.link_fn = LINK_FUNCTIONS[link_fn] + self.link_fn = _resolve_link_fn(link_fn) def forward(self, X): - """Torch Forward pass.""" + """X: (B, input_dim)""" ret = self.nams[0](X[:, 0].unsqueeze(-1)) - for i, nam in enumerate(self.nams[1:]): + for i, nam in enumerate(self.nams[1:], start=1): ret += nam(X[:, i].unsqueeze(-1)) return self.link_fn(ret) + class Linear(nn.Module): """ Linear encoder diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py index 2f2586e3..04134f8c 100644 --- a/contextualized/regression/datamodules.py +++ b/contextualized/regression/datamodules.py @@ -6,7 +6,7 @@ import pandas as pd import torch from torch.utils.data import DataLoader -import lightning as pl +import pytorch_lightning as pl from .datasets import ( MultivariateDataset, @@ -159,16 +159,16 @@ def _mk_dataset(idx: IndexLike): X_s = _maybe_index(X, idx) Y_s = None if (Y is None) else _maybe_index(Y, idx) ds_cls = TASK_TO_DATASET[self.task_type] - # Y can be optional for some tasks; the dataset constructors you showed - # expect Y. If a task doesn't use Y, pass a placeholder or ensure callers pass X as Y when needed. + if Y_s is None: - # If Y is truly not used for this task_type, construct a compatible placeholder. - # Here we create zeros with appropriate last dim to match dataset expectations. - # For singletask_univariate/multivariate we assume Y has shape (n, y_dim). - # Override as needed if your upstream code guarantees a Y. - Y_s = torch.zeros((C_s.shape[0], X_s.shape[-1]), dtype=self.dtype) + raise ValueError( + f"Y is required for regression task_type='{self.task_type}'. " + "Pass a real Y array matching your task." 
+ ) + return ds_cls(C_s, X_s, Y_s, dtype=self.dtype) + self.ds_train = _mk_dataset(self.train_idx) self.ds_val = _mk_dataset(self.val_idx) self.ds_test = _mk_dataset(self.test_idx) @@ -196,15 +196,16 @@ def train_dataloader(self) -> DataLoader: **self._common_dl_kwargs(), ) - def val_dataloader(self) -> DataLoader: + def val_dataloader(self): if self.ds_val is None: - raise RuntimeError("val dataset is not set; provide val_idx or splitter.") + return None return DataLoader( dataset=self.ds_val, - shuffle=self.shuffle_eval, # False by default + shuffle=self.shuffle_eval, **self._common_dl_kwargs(), ) + def test_dataloader(self) -> DataLoader: if self.ds_test is None: raise RuntimeError("test dataset is not set; provide test_idx or splitter.") diff --git a/contextualized/regression/datasets.py b/contextualized/regression/datasets.py index d0b85259..ce93708f 100644 --- a/contextualized/regression/datasets.py +++ b/contextualized/regression/datasets.py @@ -12,9 +12,9 @@ class MultivariateDataset(Dataset): Simple multivariate dataset with context, predictors, and outcomes. """ def __init__(self, C, X, Y, dtype=torch.float): - self.C = torch.tensor(C, dtype=dtype) - self.X = torch.tensor(X, dtype=dtype) - self.Y = torch.tensor(Y, dtype=dtype) + self.C = torch.as_tensor(C, dtype=dtype) + self.X = torch.as_tensor(X, dtype=dtype) + self.Y = torch.as_tensor(Y, dtype=dtype) self.c_dim = C.shape[-1] self.x_dim = X.shape[-1] self.y_dim = Y.shape[-1] @@ -37,9 +37,9 @@ class UnivariateDataset(Dataset): Simple univariate dataset with context, predictors, and one outcome. 
""" def __init__(self, C, X, Y, dtype=torch.float): - self.C = torch.tensor(C, dtype=dtype) - self.X = torch.tensor(X, dtype=dtype) - self.Y = torch.tensor(Y, dtype=dtype) + self.C = torch.as_tensor(C, dtype=dtype) + self.X = torch.as_tensor(X, dtype=dtype) + self.Y = torch.as_tensor(Y, dtype=dtype) self.c_dim = C.shape[-1] self.x_dim = X.shape[-1] self.y_dim = Y.shape[-1] @@ -118,9 +118,9 @@ class MultitaskUnivariateDataset(Dataset): Splits each sample into univariate X and Y feature pairs for univariate regression tasks. """ def __init__(self, C, X, Y, dtype=torch.float): - self.C = torch.tensor(C, dtype=dtype) - self.X = torch.tensor(X, dtype=dtype) - self.Y = torch.tensor(Y, dtype=dtype) + self.C = torch.as_tensor(C, dtype=dtype) + self.X = torch.as_tensor(X, dtype=dtype) + self.Y = torch.as_tensor(Y, dtype=dtype) self.c_dim = C.shape[-1] self.x_dim = X.shape[-1] self.y_dim = Y.shape[-1] diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index 14345c34..5a2af1db 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -18,12 +18,12 @@ import numpy as np import torch from torch.utils.data import DataLoader -import lightning as pl - +import pytorch_lightning as pl from contextualized.regression.regularizers import REGULARIZERS -from contextualized.regression.losses import MSE +from contextualized.regression.losses import MSE from contextualized.functions import LINK_FUNCTIONS + from contextualized.regression.metamodels import ( NaiveMetamodel, SubtypeMetamodel, @@ -33,6 +33,57 @@ MULTITASK_METAMODELS, ) +# --- Accept both string registry keys and callables for link_fn / loss_fn --- +def _resolve_registry_or_callable(maybe_obj, registry, name: str): + """Return a function from a registry by key, or the callable directly.""" + if isinstance(maybe_obj, str): + try: + return registry[maybe_obj] + except KeyError as e: + raise KeyError( + f"Unknown 
{name} '{maybe_obj}'. Valid keys: {list(registry.keys())}" + ) from e + if callable(maybe_obj): + return maybe_obj + raise TypeError(f"{name} must be a string key or a callable, got {type(maybe_obj).__name__}") + + +def _resolve_loss(maybe_loss): + """ + Allow: + * 'mse' string (maps to local MSE), + * any callable (already constructed loss), + and reject unknown strings to avoid circular imports with package-level registries. + """ + if isinstance(maybe_loss, str): + if maybe_loss.lower() == "mse": + return MSE + raise KeyError( + f"Unknown loss_fn '{maybe_loss}'. " + "Pass a callable loss or the string 'mse'." + ) + if callable(maybe_loss): + return maybe_loss + raise TypeError(f"loss_fn must be a string key or a callable, got {type(maybe_loss).__name__}") +# --------------------------------------------------------------------------- +def _resolve_regularizer(maybe_reg): + """ + Allow: + * string key -> lookup in REGULARIZERS + * callable -> pass through directly + """ + if isinstance(maybe_reg, str): + try: + return REGULARIZERS[maybe_reg] + except KeyError as e: + raise KeyError( + f"Unknown model_regularizer '{maybe_reg}'. 
" + f"Valid keys: {list(REGULARIZERS.keys())}" + ) from e + if callable(maybe_reg): + return maybe_reg + raise TypeError(f"model_regularizer must be a string key or a callable, got {type(maybe_reg).__name__}") + class ContextualizedRegressionBase(pl.LightningModule): """ @@ -369,12 +420,11 @@ def __init__( super().__init__() self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.base_y_predictor = base_y_predictor self.base_param_predictor = base_param_predictor if metamodel_type == "subtype": @@ -533,12 +583,11 @@ def __init__( super().__init__() self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.metamodel = MultitaskMetamodel( context_dim=context_dim, x_dim=x_dim, @@ -677,12 +726,11 @@ def __init__( self.learning_rate = learning_rate self.metamodel_type = metamodel_type self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = 
_resolve_regularizer(model_regularizer) + self.metamodel = TasksplitMetamodel( context_dim=context_dim, x_dim=x_dim, @@ -842,12 +890,11 @@ def __init__( super().__init__() self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.base_y_predictor = base_y_predictor self.base_param_predictor = base_param_predictor if metamodel_type == "subtype": @@ -978,12 +1025,11 @@ def __init__( super().__init__() self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.metamodel = MultitaskMetamodel( context_dim=context_dim, x_dim=x_dim, @@ -1076,12 +1122,11 @@ def __init__( super().__init__() self.learning_rate = learning_rate self.fit_intercept = fit_intercept - self.link_fn = LINK_FUNCTIONS[link_fn] - if loss_fn == "mse": - self.loss_fn = MSE - else: - raise ValueError("Supported loss_fn's: mse") - self.model_regularizer = REGULARIZERS[model_regularizer] + self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") + self.loss_fn = _resolve_loss(loss_fn) + + self.model_regularizer = _resolve_regularizer(model_regularizer) + self.metamodel = TasksplitMetamodel( context_dim=context_dim, x_dim=x_dim, @@ -1203,26 +1248,21 @@ def __init__(self, 
context_dim, x_dim, **kwargs): super().__init__(context_dim, x_dim, x_dim, **kwargs) def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - beta_hat = beta_hat.squeeze(-1) + beta_hat = beta_hat.squeeze(-1) # (B, y, x) beta_hat_T = beta_hat.transpose(1, 2) signs = torch.sign(beta_hat) signs[signs != signs.transpose(1, 2)] = 0 correlations = signs * torch.sqrt(torch.abs(beta_hat * beta_hat_T)) batch.update({ - "betas": beta_hat.squeeze(-1), + "betas": beta_hat, # already squeezed "mus": mu_hat.squeeze(-1), "correlations": correlations, }) return batch + class MultitaskContextualizedCorrelation(MultitaskContextualizedUnivariateRegression): """Using multitask univariate contextualized regression to estimate Pearson's correlation See TasksplitMetamodel for assumptions and full docstring @@ -1272,32 +1312,16 @@ def __init__( self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ - C, _, _, _ = batch - beta_hat, mu_hat = self(C) + beta_hat, mu_hat = self(batch) # self.forward expects dict batch + # Zero diagonal (mask pre-registered in __init__) beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1) - return beta_hat, mu_hat - - def dataloader(self, C, X, Y=None, **kwargs): - """ - - :param C: - :param X: - :param Y: - :param **kwargs: + batch.update({ + "betas": beta_hat, + "mus": mu_hat, + }) + return batch - """ - if Y is not None: - print( - "Passed a Y, but this is a Markov Graph between X featuers. Ignoring Y." 
- ) - return super().dataloader(C, X, X, **kwargs) class ContextualizedMarkovGraph(ContextualizedRegression): @@ -1315,32 +1339,13 @@ def __init__(self, context_dim, x_dim, **kwargs): self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ - C, _, _, _ = batch - beta_hat, mu_hat = self(C) - beta_hat = beta_hat + torch.transpose( - beta_hat, 1, 2 - ) # hotfix to enforce symmetry + beta_hat, mu_hat = self(batch) # dict batch + # Enforce symmetry (hotfix) and zero diagonal + beta_hat = beta_hat + beta_hat.transpose(1, 2) beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1) - return beta_hat, mu_hat - - def dataloader(self, C, X, Y=None, **kwargs): - """ - - :param C: - :param X: - :param Y: - :param **kwargs: - - """ + batch.update({ + "betas": beta_hat, + "mus": mu_hat, + }) + return batch - if Y is not None: - print( - "Passed a Y, but this is a Markov Graph between X featuers. Ignoring Y." - ) - return super().dataloader(C, X, X, **kwargs) diff --git a/contextualized/regression/trainers.py b/contextualized/regression/trainers.py index 759e09ea..c33c66e4 100644 --- a/contextualized/regression/trainers.py +++ b/contextualized/regression/trainers.py @@ -2,48 +2,131 @@ PyTorch-Lightning trainers used for Contextualized regression. 
""" +from typing import Any, Tuple, List import numpy as np +import torch import pytorch_lightning as pl +from pytorch_lightning.plugins.environments import LightningEnvironment +import os +from pytorch_lightning.strategies import DDPStrategy + + +def _stack_from_preds(preds: List[dict], key: str) -> torch.Tensor: + """Concatenate a tensor field from the list of batch dicts returned by predict().""" + parts = [] + for p in preds: + val = p[key] + # ensure tensor on cpu + if isinstance(val, np.ndarray): + val = torch.from_numpy(val) + parts.append(val.detach().cpu()) + return torch.cat(parts, dim=0) class RegressionTrainer(pl.Trainer): """ Trains the contextualized.regression lightning_modules + and provides convenience prediction helpers that reshape + batched outputs into expected numpy arrays without relying + on model-private _*reshape helpers. """ - def predict_params(self, model, dataloader): + @torch.no_grad() + def predict_params(self, model: pl.LightningModule, dataloader) -> Tuple[np.ndarray, np.ndarray]: """ - Returns context-specific regression models - - beta (numpy.ndarray): (n, y_dim, x_dim) - - mu (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate]) + Returns context-specific regression parameters. 
+ + Returns + ------- + (betas, mus) + betas: (n, y_dim, x_dim) + mus: (n, y_dim) or (n, y_dim, 1) depending on the model """ - preds = super().predict(model, dataloader) - return model._params_reshape(preds, dataloader) + preds = super().predict(model, dataloader) # list of batch dicts + betas = _stack_from_preds(preds, "betas") + mus = _stack_from_preds(preds, "mus") + return betas.numpy(), mus.numpy() - def predict_y(self, model, dataloader): + @torch.no_grad() + def predict_y(self, model: pl.LightningModule, dataloader) -> np.ndarray: """ - Returns context-specific predictions of the response Y - - y_hat (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate]) + Returns context-specific predictions of the response Y. + + Returns + ------- + y_hat : (n, y_dim, 1) for multivariate, or (n, y_dim, x_dim) for univariate """ - preds = super().predict(model, dataloader) - return model._y_reshape(preds, dataloader) + preds = super().predict(model, dataloader) # list of batch dicts + + y_parts = [] + for p in preds: + # Required keys were added by model.predict_step(...) 
+ C = p["contexts"] + X = p["predictors"] + betas = p["betas"] + mus = p["mus"] + + # Ensure tensors on CPU first; model will move as needed inside helpers + if not torch.is_tensor(C): C = torch.as_tensor(C) + if not torch.is_tensor(X): X = torch.as_tensor(X) + if not torch.is_tensor(betas): betas = torch.as_tensor(betas) + if not torch.is_tensor(mus): mus = torch.as_tensor(mus) + + # --- FIX: make shapes broadcastable --- + # --- FIX: make shapes broadcastable for both multivariate (3D) and univariate (4D) --- + # Multivariate convention: X (B, y, x), betas (B, y, x), mus (B, y, 1) + # Univariate convention: X (B, y, x, 1), betas (B, y, x, 1), mus (B, y, x, 1) + if X.dim() == 2 and betas.dim() == 3 and betas.size(-1) == X.size(-1): + # allow X provided as (B, x) for multivariate -> expand to (B,1,x) + X = X.unsqueeze(1) # (B,1,x) + + if betas.dim() == 3 and X.dim() == 4: + # univariate predict_step may have squeezed betas -> add singleton to match (B,y,x,1) + betas = betas.unsqueeze(-1) # (B,y,x,1) + + if mus.dim() == 2: + # multivariate: ensure (B,y,1) + mus = mus.unsqueeze(-1) # (B,y,1) + elif mus.dim() == 3 and X.dim() == 4 and mus.size(-1) != 1: + # univariate: ensure trailing singleton (B,y,x,1) if it was (B,y,x) + mus = mus.unsqueeze(-1) + # --- end FIX --- + + # --- end FIX --- + + yhat = model._predict_y(C, X, betas, mus) # uses model's link + y_parts.append(yhat.detach().cpu()) + + y = torch.cat(y_parts, dim=0) + return y.numpy() + class CorrelationTrainer(RegressionTrainer): """ Trains the contextualized.regression correlation lightning_modules + and exposes a helper to compute context-specific correlation matrices. 
""" - def predict_correlation(self, model, dataloader): + @torch.no_grad() + def predict_correlation(self, model: pl.LightningModule, dataloader) -> np.ndarray: """ - Returns context-specific correlation networks containing Pearson's correlation coefficient - - correlation (numpy.ndarray): (n, x_dim, x_dim) + Returns context-specific correlation networks containing Pearson's correlation coefficient. + + Returns + ------- + correlations : (n, x_dim, x_dim) """ - betas, _ = super().predict_params(model, dataloader) + # If the model already returns 'correlations' in predict_step, prefer that. + preds = super().predict(model, dataloader) + if "correlations" in preds[0]: + cors = torch.cat([p["correlations"].detach().cpu() for p in preds], dim=0) + return cors.numpy() + + # Fallback: derive from betas like before + betas, _ = self.predict_params(model, dataloader) signs = np.sign(betas) - signs[signs != np.transpose(signs, (0, 2, 1))] = ( - 0 # remove asymmetric estimations - ) + signs[signs != np.transpose(signs, (0, 2, 1))] = 0 correlations = signs * np.sqrt(np.abs(betas * np.transpose(betas, (0, 2, 1)))) return correlations @@ -51,15 +134,66 @@ def predict_correlation(self, model, dataloader): class MarkovTrainer(CorrelationTrainer): """ Trains the contextualized.regression markov graph lightning_modules + and exposes a helper to compute context-specific precision matrices. """ - def predict_precision(self, model, dataloader): + @torch.no_grad() + def predict_precision(self, model: pl.LightningModule, dataloader) -> np.ndarray: """ - Returns context-specific precision matrix under a Gaussian graphical model + Returns context-specific precision matrix under a Gaussian graphical model. + Assuming all diagonal precisions are equal and constant over context, this is equivalent to the negative of the multivariate regression coefficient. 
- - precision (numpy.ndarray): (n, x_dim, x_dim) + + Returns + ------- + precision : (n, x_dim, x_dim) """ - # A trick in the markov lightning_module predict_step makes makes the predict_correlation - # output equivalent to negative precision values here. + # A trick in the markov lightning_module predict_step ensures the + # correlation output corresponds (up to sign) to precision entries. return -super().predict_correlation(model, dataloader) + + + +# ADD THIS FACTORY (end of file) + +# at top of file you already have: +# from pytorch_lightning.plugins.environments import LightningEnvironment + +from contextualized.utils.engine import pick_engine + + +def make_trainer_with_env(trainer_cls=RegressionTrainer, **kwargs) -> pl.Trainer: + # Respect explicit user settings; otherwise auto-pick + accelerator = kwargs.pop("accelerator", None) + devices = kwargs.pop("devices", None) + strategy = kwargs.pop("strategy", None) + plugins = kwargs.pop("plugins", None) + + accelerator, devices, strategy = pick_engine( + accelerator=accelerator, + devices=devices, + strategy=strategy, + prefer_spawn=True, # allows plain `python script.py` to use all GPUs + ) + + # If using classic ddp, upgrade string->Strategy with tuned flags + if strategy == "ddp": + strategy = DDPStrategy( + find_unused_parameters=False, + static_graph=True, + gradient_as_bucket_view=True, + ) + + if plugins is None and accelerator == "cpu": + from pytorch_lightning.plugins.environments import LightningEnvironment + plugins = [LightningEnvironment()] + + return trainer_cls( + accelerator=accelerator, + devices=devices, + strategy=strategy, + plugins=plugins, + **kwargs, + ) + diff --git a/contextualized/utils/__init__.py b/contextualized/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contextualized/utils/engine.py b/contextualized/utils/engine.py new file mode 100644 index 00000000..225de18e --- /dev/null +++ b/contextualized/utils/engine.py @@ -0,0 +1,40 @@ +# 
contextualized/utils/engine.py +import os, torch +from typing import Tuple, Union + +def _under_torchrun() -> bool: + e = os.environ + return any(k in e for k in ("LOCAL_RANK", "RANK", "WORLD_SIZE")) + +def _visible_gpus() -> int: + return torch.cuda.device_count() if torch.cuda.is_available() else 0 + +def pick_engine( + accelerator: str | None = None, + devices: Union[int, str, list[int]] | None = None, + strategy: str | None = None, + prefer_spawn: bool = True, +) -> Tuple[str, Union[int, str, list[int]], Union[str, object]]: + """ + CPU / 1-GPU / multi-GPU auto-selection WITHOUT requiring torchrun. + - If user passes any of (accelerator/devices/strategy), we respect them. + - Else: + GPUs == 0 => cpu, devices='auto' + GPUs == 1 => gpu, devices=1 + GPUs > 1 => + - if launched with torchrun => gpu, devices=1, strategy='ddp' + - else => gpu, devices=, strategy='ddp_spawn' + """ + if accelerator is not None or devices is not None or strategy is not None: + return accelerator or "auto", devices or "auto", strategy or "auto" + + ngpu = _visible_gpus() + if ngpu == 0: + return "cpu", "auto", "auto" + + if ngpu == 1: + return "gpu", 1, "auto" + + if _under_torchrun(): + return "gpu", 1, "ddp" # one proc per GPU (torchrun sets ranks) + return "gpu", ngpu, ("ddp_spawn" if prefer_spawn else "auto") diff --git a/contextualized_sanity_run.py b/contextualized_sanity_run.py deleted file mode 100644 index 4c364ee7..00000000 --- a/contextualized_sanity_run.py +++ /dev/null @@ -1,175 +0,0 @@ -#!/usr/bin/env python -import argparse, sys, time, warnings, socket, os -from pathlib import Path -import numpy as np - -warnings.filterwarnings("ignore", message="To copy construct from a tensor", category=UserWarning) - -try: - import torch - import lightning as pl - from lightning.pytorch.strategies import DDPStrategy -except Exception as e: - print(f"[FATAL] torch/lightning import failed: {e}"); sys.exit(1) - -try: - from contextualized.regression.datamodules import 
ContextualizedRegressionDataModule - from contextualized.regression.datasets import ( - MultivariateDataset, UnivariateDataset, MultitaskMultivariateDataset, MultitaskUnivariateDataset - ) - _ctx_ok = True -except Exception as e: - print(f"[FATAL] Could not import contextualized modules: {e}"); _ctx_ok = False - -def _free_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - -def make_tensors(n=32, c_dim=3, x_dim=5, y_dim=4, dtype=torch.float32, seed=7): - rng = np.random.default_rng(seed) - C = torch.tensor(rng.normal(size=(n, c_dim)), dtype=dtype) - X = torch.tensor(rng.normal(size=(n, x_dim)), dtype=dtype) - Y = torch.tensor(rng.normal(size=(n, y_dim)), dtype=dtype) - return C, X, Y - -def simple_splitter(C, X, Y): - n = C.shape[0]; idx = torch.arange(n, dtype=torch.long) - n_tr = int(0.6*n); n_va = int(0.2*n) - return idx[:n_tr], idx[n_tr:n_tr+n_va], idx[n_tr+n_va:] - -class TinyLightning(pl.LightningModule): - def __init__(self, in_dim=5, out_dim=4, lr=1e-3): - super().__init__(); self.save_hyperparameters() - self.head = torch.nn.Linear(in_dim, out_dim, bias=False) - torch.manual_seed(0) - with torch.no_grad(): - w = torch.arange(in_dim*out_dim).float().reshape(out_dim, in_dim)/100.0 - self.head.weight.copy_(w) - self.mu = torch.nn.Parameter(torch.zeros(out_dim, 1)) - def configure_optimizers(self): - return torch.optim.SGD(self.parameters(), lr=self.hparams.lr) - def training_step(self, batch, batch_idx): - betas = self.head(batch["predictors"]); loss = (betas**2).mean() - self.log("train_loss", loss, on_epoch=True, prog_bar=False, logger=False); return loss - @torch.no_grad() - def predict_step(self, batch, batch_idx, dataloader_idx=0): - betas = self.head(batch["predictors"]); mus = self.mu.view(1,-1).repeat(betas.shape[0],1) - return {"idx": batch["idx"].detach().clone().cpu(), - "betas": betas.detach().cpu(), - "mus": mus.detach().cpu()} - -def check_dataset_shapes(C, X, Y): - 
print("\n[CHECK] Dataset constructors & shapes") - mv = MultivariateDataset(C, X, Y); uv = UnivariateDataset(C, X, Y) - mtmv = MultitaskMultivariateDataset(C, X, Y); mtuv = MultitaskUnivariateDataset(C, X, Y) - print(f" MultivariateDataset: len={len(mv)} sample keys={list(mv[0].keys())}") - print(f" UnivariateDataset: len={len(uv)} sample keys={list(uv[0].keys())}") - print(f" MultitaskMultivariateDataset: len={len(mtmv)}") - print(f" MultitaskUnivariateDataset: len={len(mtuv)}") - for name, ds in [("MultivariateDataset", mv), ("UnivariateDataset", uv), - ("MultitaskMultivariateDataset", mtmv), ("MultitaskUnivariateDataset", mtuv)]: - s = ds[0] - for k in ("idx","contexts","predictors","outcomes"): - assert k in s, f"{name} sample missing '{k}'" - print(" ✔ Map-style and key shape checks passed.") - -def run_single_process(dm, x_dim, y_dim, max_epochs=1): - print("\n[RUN] Single-process (CPU) trainer...") - model = TinyLightning(in_dim=x_dim, out_dim=y_dim) - trainer = pl.Trainer(accelerator="cpu", devices=1, max_epochs=max_epochs, - logger=False, enable_progress_bar=False, - default_root_dir=str(Path("./_tmp_sanity").resolve()), - enable_checkpointing=False) - tic = time.time(); trainer.fit(model, datamodule=dm) - outs = trainer.predict(model, datamodule=dm); sec = time.time() - tic - idx = torch.cat([o["idx"] for o in outs]).numpy() - betas = torch.cat([o["betas"] for o in outs]); mus = torch.cat([o["mus"] for o in outs]) - print(f" Predict returned {len(idx)} rows in {sec:.2f}s") - print(f" idx head: {idx[:10]}") - print(f" betas shape: {tuple(betas.shape)}, device={betas.device.type}") - print(f" mus shape: {tuple(mus.shape)}, device={mus.device.type}") - assert betas.device.type == "cpu" and mus.device.type == "cpu" - assert (idx == np.sort(idx)).all() - assert len(np.unique(idx)) == len(idx) - print(" ✔ Single-process checks passed.") - return idx, betas, mus - -def run_ddp(dm, x_dim, y_dim, world_size=2): - print(f"\n[RUN] DDP spawn (CPU, 
world_size={world_size})...") - # Force local master & explicit init_method to ignore any stale env vars - addr = "127.0.0.1"; port = _free_port() - os.environ["MASTER_ADDR"] = addr - os.environ["MASTER_PORT"] = str(port) - strategy = DDPStrategy(process_group_backend="gloo", - init_method=f"tcp://{addr}:{port}") - model = TinyLightning(in_dim=x_dim, out_dim=y_dim) - trainer = pl.Trainer(accelerator="cpu", devices=world_size, strategy=strategy, - max_epochs=0, logger=False, enable_progress_bar=False, - default_root_dir=str(Path("./_tmp_sanity_ddp").resolve()), - enable_checkpointing=False) - outs = trainer.predict(model, datamodule=dm) - idx = torch.cat([o["idx"] for o in outs]).numpy() - betas = torch.cat([o["betas"] for o in outs]); mus = torch.cat([o["mus"] for o in outs]) - print(f" Gathered rows: {len(idx)} (unique={len(np.unique(idx))})") - print(f" idx head: {idx[:10]}") - print(f" betas shape: {tuple(betas.shape)}, device={betas.device.type}") - print(f" mus shape: {tuple(mus.shape)}, device={mus.device.type}") - assert betas.device.type == "cpu" and mus.device.type == "cpu" - assert len(np.unique(idx)) == len(idx) - print(" ✔ DDP checks passed.") - return idx, betas, mus - -def maybe_try_wrapper(X): - try: - from contextualized.easy.wrappers.SKLearnWrapper import SKLearnWrapper # type: ignore - except Exception as e: - print(f"[INFO] SKLearnWrapper not available ({e}); skipping wrapper test."); return - print("\n[TRY] SKLearnWrapper in-memory vs memory-bounded (if supported)...") - class DummyEstimator: - def fit(self, C, X, Y): return self - def predict(self, X): - if isinstance(X, torch.Tensor): return X.sum(-1, keepdim=True).numpy() - return X.sum(-1, keepdims=True) - try: - wrapper = SKLearnWrapper(estimator=DummyEstimator()) - p1 = np.asarray(wrapper.predict(X, memory_bounded=False)) - p2 = np.asarray(wrapper.predict(X, memory_bounded=True)) - print(f" wrapper outputs shapes: {p1.shape} vs {p2.shape}") - print(" ✔ Wrapper paths match on toy data." 
if (p1.shape==p2.shape and np.allclose(p1,p2,1e-6,1e-6)) - else " ⚠ Wrapper paths differ on toy data.") - except TypeError as e: - print(f" ⚠ Wrapper signature mismatch: {e} — skipping.") - -def main(): - ap = argparse.ArgumentParser() - ap.add_argument("--n", type=int, default=32) - ap.add_argument("--c-dim", type=int, default=3) - ap.add_argument("--x-dim", type=int, default=5) - ap.add_argument("--y-dim", type=int, default=4) - ap.add_argument("--batch-size", type=int, default=8) - ap.add_argument("--num-workers", type=int, default=0) - ap.add_argument("--ddp", type=int, default=0) - ap.add_argument("--try-wrapper", action="store_true") - args = ap.parse_args() - if not _ctx_ok: sys.exit(2) - C, X, Y = make_tensors(n=args.n, c_dim=args.c_dim, x_dim=args.x_dim, y_dim=args.y_dim) - check_dataset_shapes(C, X, Y) - print("\n[BUILD] ContextualizedRegressionDataModule") - dm = ContextualizedRegressionDataModule( - C=C, X=X, Y=Y, task_type="singletask_multivariate", - batch_size=args.batch_size, num_workers=args.num_workers, - shuffle_eval=False, shuffle_train=True, pin_memory=False, persistent_workers=False, - splitter=simple_splitter, - ) - dm.setup("fit") - idx1, betas1, mus1 = run_single_process(dm, x_dim=args.x_dim, y_dim=args.y_dim) - if args.ddp and args.ddp > 1: - idx2, betas2, mus2 = run_ddp(dm, x_dim=args.x_dim, y_dim=args.y_dim, world_size=args.ddp) - assert set(idx1.tolist()) == set(idx2.tolist()), "DDP vs single-process index coverage mismatch" - print(" ✔ DDP vs single-process index coverage matches.") - if args.try_wrapper: maybe_try_wrapper(X) - print("\n✅ ALL SANITY CHECKS COMPLETED SUCCESSFULLY") - -if __name__ == "__main__": - main() From 29501d4f4a20f55454f7a96e59b5f8403a360f08 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Wed, 5 Nov 2025 15:33:34 -0500 Subject: [PATCH 03/19] HPC scaling + SKLearnWrapper updates --- .../easy/ContextualizedClassifier.py | 18 +- .../easy/wrappers/SKLearnWrapper.py | 234 ++++++++++++------ 
contextualized/regression/datamodules.py | 32 ++- contextualized/regression/trainers.py | 57 +++-- 4 files changed, 222 insertions(+), 119 deletions(-) diff --git a/contextualized/easy/ContextualizedClassifier.py b/contextualized/easy/ContextualizedClassifier.py index 0f057e6b..75a3b5cf 100644 --- a/contextualized/easy/ContextualizedClassifier.py +++ b/contextualized/easy/ContextualizedClassifier.py @@ -22,7 +22,15 @@ def __init__(self, **kwargs): def predict(self, C, X, individual_preds=False, **kwargs): """Predict binary outcomes from context C and predictors X.""" - return np.round(super().predict(C, X, individual_preds, **kwargs)) + out = super().predict(C, X, individual_preds, **kwargs) + out = np.asarray(out) + if not individual_preds: + if out.ndim == 3 and out.shape[-1] == 1: + out = out[..., 0] + return np.round(out) + # individual_preds=True: list/array per-bootstrap -> squeeze each + return [np.round(p[..., 0] if (p.ndim == 3 and p.shape[-1] == 1) else p) for p in out] + def predict_proba(self, C, X, **kwargs): """ @@ -33,4 +41,10 @@ def predict_proba(self, C, X, **kwargs): np.ndarray of shape (n_samples, y_dim, 2) """ probs = super().predict(C, X, **kwargs) # (n, y_dim[, 1]) - return np.array([1 - probs, probs]).T.swapaxes(0, 1) + probs = np.asarray(probs) + if probs.ndim == 3 and probs.shape[-1] == 1: + probs = probs[..., 0] + p1 = probs + p0 = 1.0 - p1 + return np.stack([p0, p1], axis=-1) + diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index 2e6e1759..be428b56 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -31,10 +31,17 @@ class SKLearnWrapper: """ - An sklearn-like wrapper for Contextualized models, optimized for multi-GPU (DDP) scaling. + An sklearn-like wrapper for Contextualized models. + + Args: + base_constructor (class): Base LightningModule constructor. 
+ extra_model_kwargs (Iterable[str]): Extra model kwargs to accept. + extra_data_kwargs (Iterable[str]): Extra data kwargs to accept. + trainer_constructor (class): Trainer class (usually RegressionTrainer). + normalize (bool): If True, standardize C/X (and Y if continuous). """ - # ---------- defaults ---------- + # -------------------- defaults -------------------- def _set_defaults(self): self.default_learning_rate = DEFAULT_LEARNING_RATE self.default_n_bootstraps = DEFAULT_N_BOOTSTRAPS @@ -72,14 +79,16 @@ def __init__( self.y_dim = None self.accelerator = "gpu" if torch.cuda.is_available() else "cpu" - # Acceptable kwargs routing + # Accepted kwarg routes self.acceptable_kwargs = { "data": [ "train_batch_size", "val_batch_size", "test_batch_size", + "predict_batch_size", "C_val", "X_val", + "Y_val", "val_split", "num_workers", "pin_memory", @@ -117,7 +126,7 @@ def __init__( "num_sanity_val_steps", "default_root_dir", "log_every_n_steps", - "precision", # allow explicit precision override if desired + "precision", "enable_progress_bar", "limit_val_batches", ], @@ -141,6 +150,7 @@ def __init__( "data", kwargs.pop("remove_data_kwargs", []), acceptable=False ) + # Convenience aliases handled at construction self.convenience_kwargs = [ "alpha", "l1_ratio", @@ -151,7 +161,7 @@ def __init__( "encoder_link_fn", ] - # Model constructor kwargs + # Model constructor kwargs (with convenience mapping) self.constructor_kwargs = self._organize_constructor_kwargs(**kwargs) self.constructor_kwargs["encoder_kwargs"]["width"] = kwargs.pop( "width", self.constructor_kwargs["encoder_kwargs"]["width"] @@ -165,6 +175,8 @@ def __init__( "link_fn", self.default_encoder_link_fn ), ) + + # Everything else self.not_constructor_kwargs = { k: v for k, v in kwargs.items() @@ -176,11 +188,10 @@ def __init__( ) for k, v in self.constructor_kwargs.items(): self._init_kwargs["model"][k] = v - if unrecognized: - for kw in unrecognized: - print(f"Received unknown keyword argument {kw}, probably 
ignoring.") + for kw in unrecognized: + print(f"Received unknown keyword argument {kw}, probably ignoring.") - # ---------- helpers ---------- + # -------------------- helpers -------------------- def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True): if acceptable: self.acceptable_kwargs[category] = list( @@ -241,21 +252,19 @@ def maybe_add(kw, default_val): ) return model - # ---------- internal: sanitize callbacks when no val loop ---------- @staticmethod def _retarget_or_strip_early_stopping(cb, use_val: bool, train_monitor="train_loss"): try: - from pytorch_lightning.callbacks.early_stopping import EarlyStopping + from pytorch_lightning.callbacks.early_stopping import EarlyStopping as _ES except Exception: return cb - if not isinstance(cb, EarlyStopping): + if not isinstance(cb, _ES): return cb if use_val: return cb - # No val loop -> if monitoring val_* (or nothing), rebuild to watch train_loss monitor = getattr(cb, "monitor", None) if (monitor is None) or (isinstance(monitor, str) and monitor.startswith("val_")): - return EarlyStopping( + return _ES( monitor=train_monitor, mode=getattr(cb, "mode", "min"), patience=getattr(cb, "patience", 1), @@ -264,12 +273,13 @@ def _retarget_or_strip_early_stopping(cb, use_val: bool, train_monitor="train_lo ) return cb - # ---------- fit kwarg expansion (with DDP + ES logic) ---------- + # -------------------- fit kwarg expansion -------------------- def _organize_and_expand_fit_kwargs(self, **kwargs): organized, unrecognized = self._organize_kwargs(**kwargs) - # --- FORCE max_epochs to be set (avoid PL default=1000) --- + + # Max epochs (avoid PL default 1000) max_epochs_cli = kwargs.get("max_epochs", None) - epochs_cli = kwargs.get("epochs", None) + epochs_cli = kwargs.get("epochs", None) if max_epochs_cli is not None: organized["trainer"]["max_epochs"] = int(max_epochs_cli) elif epochs_cli is not None: @@ -278,50 +288,44 @@ def _organize_and_expand_fit_kwargs(self, **kwargs): 
organized["trainer"]["max_epochs"] = 3 world_size = int(os.getenv("WORLD_SIZE", "1")) - use_val = organized["data"].get("val_split", DEFAULT_VAL_SPLIT) > 0.0 + use_val = organized["data"].get("val_split", self.default_val_split) > 0.0 - # Trainer base + # Trainer defaults organized["trainer"].setdefault("accelerator", self.accelerator) organized["trainer"].setdefault("enable_progress_bar", False) organized["trainer"].setdefault("logger", False) organized["trainer"].setdefault("enable_checkpointing", False) organized["trainer"].setdefault("num_sanity_val_steps", 0) - # conservative precision by default (TF32 is controlled globally) organized["trainer"].setdefault("precision", 32) if not use_val: organized["trainer"].setdefault("limit_val_batches", 0) if world_size > 1: organized["trainer"].setdefault("devices", world_size) - strat = organized["trainer"].get("strategy", "auto") - if strat == "auto" or isinstance(strat, str): - organized["trainer"]["strategy"] = DDPStrategy( - find_unused_parameters=False, - static_graph=True, - gradient_as_bucket_view=True, - ) + # Defer concrete object; prefer plain string for factory + organized["trainer"].setdefault("strategy", "ddp") else: organized["trainer"]["devices"] = 1 organized["trainer"].setdefault("strategy", "auto") organized["trainer"].setdefault("plugins", [LightningEnvironment()]) - # Defaults: model/data + # Model defaults def maybe_add(cat, k, default): if k in self.acceptable_kwargs[cat]: organized[cat][k] = organized[cat].get(k, default) - # Model - maybe_add("model", "learning_rate", DEFAULT_LEARNING_RATE) + maybe_add("model", "learning_rate", self.default_learning_rate) maybe_add("model", "context_dim", self.context_dim) maybe_add("model", "x_dim", self.x_dim) maybe_add("model", "y_dim", self.y_dim) if organized["model"].get("num_archetypes", 1) == 0: organized["model"].pop("num_archetypes", None) - # Data (GPU-friendly) - maybe_add("data", "train_batch_size", DEFAULT_TRAIN_BATCH_SIZE) - maybe_add("data", 
"val_batch_size", DEFAULT_VAL_BATCH_SIZE) - maybe_add("data", "test_batch_size", DEFAULT_TEST_BATCH_SIZE) + # Data defaults (per-loader sizes) + maybe_add("data", "train_batch_size", self.default_train_batch_size) + maybe_add("data", "val_batch_size", self.default_val_batch_size) + maybe_add("data", "test_batch_size", self.default_test_batch_size) + maybe_add("data", "predict_batch_size", self.default_val_batch_size) maybe_add("data", "num_workers", 0) maybe_add("data", "pin_memory", (self.accelerator == "gpu")) maybe_add("data", "persistent_workers", False) @@ -330,45 +334,43 @@ def maybe_add(cat, k, default): maybe_add("data", "shuffle_eval", False) maybe_add("data", "dtype", torch.float) - # Wrapper - maybe_add("wrapper", "n_bootstraps", DEFAULT_N_BOOTSTRAPS) + # Wrapper defaults + maybe_add("wrapper", "n_bootstraps", self.default_n_bootstraps) - # Callbacks (EarlyStopping only if we validate; else watch train_loss) + # EarlyStopping/Checkpoint constructors (sanitized later if no val) es_monitor = organized["wrapper"].get("es_monitor", "val_loss" if use_val else "train_loss") es_mode = organized["wrapper"].get("es_mode", "min") - es_patience = organized["wrapper"].get("es_patience", DEFAULT_ES_PATIENCE) + es_patience = organized["wrapper"].get("es_patience", self.default_es_patience) es_verbose = organized["wrapper"].get("es_verbose", False) es_min_delta = organized["wrapper"].get("es_min_delta", 0.0) - callbacks_list = [] + cb_ctors = organized["trainer"].get("callback_constructors", []) if use_val: - callbacks_list.append( + cb_ctors.append( lambda i: EarlyStopping( monitor=es_monitor, mode=es_mode, patience=es_patience, verbose=es_verbose, min_delta=es_min_delta ) ) if organized["trainer"].get("enable_checkpointing", False): - callbacks_list.append( + cb_ctors.append( lambda i: ModelCheckpoint( monitor="val_loss" if use_val else None, dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", filename="{epoch}-{val_loss:.4f}" if 
use_val else "{epoch}", ) ) - organized["trainer"].setdefault("callback_constructors", callbacks_list) + organized["trainer"]["callback_constructors"] = cb_ctors - if unrecognized: - for kw in unrecognized: - print(f"Received unknown keyword argument {kw}, probably ignoring.") + for kw in unrecognized: + print(f"Received unknown keyword argument {kw}, probably ignoring.") - # ---- merge constructor-time defaults as fallbacks ---- + # Merge __init__ defaults as fallbacks for category, cat_kwargs in self._init_kwargs.items(): for k, v in cat_kwargs.items(): organized[category].setdefault(k, v) - # ---- sanitize any pre-specified callbacks for no-val runs ---- - # (handles both direct 'callbacks' and deferred 'callback_constructors') + # Sanitize any pre-specified callbacks for no-val runs cb_list = organized["trainer"].get("callbacks", []) cb_list = [self._retarget_or_strip_early_stopping(cb, use_val) for cb in cb_list] organized["trainer"]["callbacks"] = cb_list @@ -379,12 +381,11 @@ def _wrapped(i): cb = ctor(i) return self._retarget_or_strip_early_stopping(cb, use_val) return _wrapped - ctor_list = [_wrap_ctor(c) for c in ctor_list] - organized["trainer"]["callback_constructors"] = ctor_list + organized["trainer"]["callback_constructors"] = [_wrap_ctor(c) for c in ctor_list] return organized - # ---------- data module builder ---------- + # -------------------- data module builder -------------------- def _build_datamodule( self, C: np.ndarray, @@ -399,10 +400,13 @@ def _build_datamodule( task_type: str = "singletask_multivariate", ) -> ContextualizedRegressionDataModule: dk = dict( - batch_size=self.default_train_batch_size, + train_batch_size=self.default_train_batch_size, + val_batch_size=self.default_val_batch_size, + test_batch_size=self.default_test_batch_size, + predict_batch_size=self.default_val_batch_size, num_workers=0, pin_memory=(self.accelerator == "gpu"), - persistent_workers=False, # caller can override + persistent_workers=False, drop_last=False, 
shuffle_train=True, shuffle_eval=False, @@ -420,7 +424,10 @@ def _build_datamodule( val_idx=val_idx, test_idx=test_idx, predict_idx=predict_idx, - batch_size=dk["batch_size"], + train_batch_size=dk["train_batch_size"], + val_batch_size=dk["val_batch_size"], + test_batch_size=dk["test_batch_size"], + predict_batch_size=dk["predict_batch_size"], num_workers=dk["num_workers"], pin_memory=dk["pin_memory"], persistent_workers=dk["persistent_workers"], @@ -433,15 +440,38 @@ def _build_datamodule( dm.setup() return dm - # ---------- split helper ---------- - def _split_indices(self, n: int, val_split: float): - if val_split <= 0.0: + # -------------------- split helpers -------------------- + def _split_train_data( + self, + C: np.ndarray, + X: np.ndarray, + Y: Optional[np.ndarray] = None, + *, + Y_required: bool = True, + val_split: Optional[float] = None, + random_state: Optional[int] = None, + shuffle: bool = True, + **_, + ): + """ + Return (train_idx, val_idx) over rows; Lightning will attach DistributedSamplers. 
+ """ + if Y_required and Y is None: + raise ValueError("Y is required but was not provided.") + n = C.shape[0] + vs = self.default_val_split if val_split is None else float(val_split) + if vs <= 0.0: idx = np.arange(n) return idx, None - tr_idx, va_idx = train_test_split(np.arange(n), test_size=val_split, shuffle=True) + tr_idx, va_idx = train_test_split( + np.arange(n), + test_size=vs, + shuffle=shuffle, + random_state=random_state, + ) return tr_idx, va_idx - # ---------- optional scaling ---------- + # -------------------- optional scaling -------------------- def _maybe_scale_C(self, C: np.ndarray) -> np.ndarray: if self.normalize and self.scalers["C"] is not None: return self.scalers["C"].transform(C) @@ -452,7 +482,7 @@ def _maybe_scale_X(self, X: np.ndarray) -> np.ndarray: return self.scalers["X"].transform(X) return X - # ---------- public API ---------- + # -------------------- public API -------------------- def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs): if not hasattr(self, "models") or self.models is None: raise ValueError("Trying to predict with a model that hasn't been trained yet.") @@ -467,7 +497,10 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, C=Cq, X=Xq, Y=Yq, predict_idx=np.arange(len(Cq)), data_kwargs=dict( - batch_size=self._init_kwargs["data"].get("val_batch_size", DEFAULT_VAL_BATCH_SIZE), + train_batch_size=self._init_kwargs["data"].get("train_batch_size", self.default_train_batch_size), + val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), + test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), + predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=self._init_kwargs["data"].get("num_workers", 0), pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), 
persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), @@ -513,7 +546,10 @@ def predict_params( Y=Y_zero if kwargs.pop("uses_y", True) else None, predict_idx=np.arange(len(Cq)), data_kwargs=dict( - batch_size=self._init_kwargs["data"].get("val_batch_size", DEFAULT_VAL_BATCH_SIZE), + train_batch_size=self._init_kwargs["data"].get("train_batch_size", self.default_train_batch_size), + val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), + test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), + predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=self._init_kwargs["data"].get("num_workers", 0), pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), @@ -540,25 +576,53 @@ def predict_params( def fit(self, *args, **kwargs) -> None: """ Fit contextualized model to data. - Args: - C (np.ndarray): (n, c_dim) - X (np.ndarray): (n, x_dim) - Y (np.ndarray, optional): (n, y_dim) + + Accepts either: + - (C, X, Y) [canonical order], OR + - (X, Y, C) [README order], OR + - kw-only: C=..., X=..., (Y=...) 
""" self.models, self.trainers = [], [] - C, X = args[0], args[1] + # normalize argument order + C_in = kwargs.pop("C", None) + X_in = kwargs.pop("X", None) + Y_in = kwargs.pop("Y", None) + + if (C_in is not None) and (X_in is not None): + C, X, Y = C_in, X_in, Y_in + else: + if len(args) == 3: + A, B, Carg = args + if A.shape[0] == B.shape[0] == Carg.shape[0]: + if (B.ndim == 1) or (B.ndim == 2 and B.shape[1] <= 4): + X, Y, C = A, B, Carg + else: + C, X, Y = A, B, Carg + else: + raise ValueError("Mismatched sample counts among provided arrays.") + elif len(args) == 2: + A, B = args + if A.shape[0] != B.shape[0]: + raise ValueError("Mismatched sample counts for two-argument fit.") + # Assume (C, X) by default + C, X, Y = A, B, None + else: + raise ValueError("fit expects (C,X[,Y]) or (X,Y,C) or kw-only C=..., X=...") + + # Optional scaling if self.normalize: if self.scalers["C"] is None: self.scalers["C"] = StandardScaler().fit(C) C = self.scalers["C"].transform(C) if self.scalers["X"] is None: self.scalers["X"] = StandardScaler().fit(X) X = self.scalers["X"].transform(X) - self.context_dim = C.shape[-1]; self.x_dim = X.shape[-1] - if len(args) == 3: - Y = args[2] - if kwargs.get("Y", None) is not None: Y = kwargs.get("Y") - if len(Y.shape) == 1: Y = np.expand_dims(Y, 1) + self.context_dim = C.shape[-1] + self.x_dim = X.shape[-1] + + if Y is not None: + if len(Y.shape) == 1: + Y = np.expand_dims(Y, 1) if self.normalize and not np.array_equal(np.unique(Y), np.array([0, 1])): if self.scalers["Y"] is None: self.scalers["Y"] = StandardScaler().fit(Y) Y = self.scalers["Y"].transform(Y) @@ -572,18 +636,22 @@ def fit(self, *args, **kwargs) -> None: self.n_bootstraps = organized["wrapper"].get("n_bootstraps", self.n_bootstraps) n = C.shape[0] - val_split = organized["data"].get("val_split", DEFAULT_VAL_SPLIT) + val_split = organized["data"].get("val_split", self.default_val_split) use_val = val_split > 0.0 for b in range(self.n_bootstraps): - # Build model 
(LightningModule) + # Model (LightningModule) _model_kwargs = dict(organized["model"]) - _model_kwargs.pop("univariate", None) + _model_kwargs.pop("univariate", None) # handled via task_type below model = self.base_constructor(**_model_kwargs) self.model_ = model # Indices - train_idx, val_idx = self._split_indices(n, val_split) + train_idx, val_idx = self._split_train_data( + C, X, (args[2] if len(args) == 3 else None), + Y_required=(len(args) == 3), + val_split=val_split, + ) test_idx = None # DataModule @@ -592,7 +660,10 @@ def fit(self, *args, **kwargs) -> None: C=args[0], X=args[1], Y=(args[2] if len(args) == 3 else None), train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, data_kwargs=dict( - batch_size=organized["data"].get("train_batch_size", DEFAULT_TRAIN_BATCH_SIZE), + train_batch_size=organized["data"].get("train_batch_size", self.default_train_batch_size), + val_batch_size=organized["data"].get("val_batch_size", self.default_val_batch_size), + test_batch_size=organized["data"].get("test_batch_size", self.default_test_batch_size), + predict_batch_size=organized["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=organized["data"].get("num_workers", 0), pin_memory=organized["data"].get("pin_memory", (self.accelerator == "gpu")), persistent_workers=organized["data"].get("persistent_workers", False), @@ -609,7 +680,7 @@ def fit(self, *args, **kwargs) -> None: trainer_kwargs["callbacks"] = [f(b) for f in trainer_kwargs.get("callback_constructors", [])] trainer_kwargs.pop("callback_constructors", None) - # Build via factory (handles env quirks) + # Build via factory (respects strategy strings and env) from contextualized.regression.trainers import make_trainer_with_env trainer = make_trainer_with_env( self.trainer_constructor, @@ -621,7 +692,7 @@ def fit(self, *args, **kwargs) -> None: if isinstance(cb, ModelCheckpoint): os.makedirs(cb.dirpath, exist_ok=True) - # Fit (don’t pass val loader if no val split) + # Fit (omit val loader 
if no val split) if use_val and dm.val_dataloader() is not None: trainer.fit( model, @@ -636,9 +707,8 @@ def fit(self, *args, **kwargs) -> None: **organized["fit"], ) - # (Optional) load best ckpt if checkpointing enabled - max_epochs = trainer_kwargs.get("max_epochs", 1) - if max_epochs and trainer_kwargs.get("enable_checkpointing", False): + # Load best checkpoint if enabled + if trainer_kwargs.get("enable_checkpointing", False): ckpt_cb = next((cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), None) if ckpt_cb and ckpt_cb.best_model_path and os.path.exists(ckpt_cb.best_model_path): best = torch.load(ckpt_cb.best_model_path, map_location="cpu") diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py index 04134f8c..281fb351 100644 --- a/contextualized/regression/datamodules.py +++ b/contextualized/regression/datamodules.py @@ -77,7 +77,10 @@ def __init__( Tuple[IndexLike, IndexLike, IndexLike]] ] = None, # dataloader config - batch_size: int = 32, + train_batch_size: int = 32, + val_batch_size: int = 32, + test_batch_size: int = 32, + predict_batch_size: int = 32, num_workers: int = 0, pin_memory: bool = True, persistent_workers: bool = False, @@ -85,6 +88,7 @@ def __init__( shuffle_train: bool = True, shuffle_eval: bool = False, dtype: torch.dtype = torch.float, + ): super().__init__() if task_type not in TASK_TO_DATASET: @@ -107,7 +111,10 @@ def __init__( self.splitter = splitter # dl config - self.batch_size = batch_size + self.train_batch_size = train_batch_size + self.val_batch_size = val_batch_size + self.test_batch_size = test_batch_size + self.predict_batch_size = predict_batch_size self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = bool(persistent_workers and num_workers > 0) @@ -116,6 +123,7 @@ def __init__( self.shuffle_eval = shuffle_eval self.dtype = dtype + # will be set in setup() self.C: Optional[torch.Tensor] = None self.X: Optional[torch.Tensor] = None 
@@ -178,22 +186,23 @@ def _mk_dataset(idx: IndexLike): self.C, self.X, self.Y = C, X, Y # ---- Dataloaders ---- - def _common_dl_kwargs(self) -> Dict: + def _common_dl_kwargs(self, batch_size: int) -> Dict: return { - "batch_size": self.batch_size, + "batch_size": batch_size, "num_workers": self.num_workers, "pin_memory": self.pin_memory, "persistent_workers": self.persistent_workers, "drop_last": self.drop_last, } + def train_dataloader(self) -> DataLoader: if self.ds_train is None: raise RuntimeError("train dataset is not set; provide train_idx or splitter.") return DataLoader( dataset=self.ds_train, - shuffle=self.shuffle_train, # True only for train - **self._common_dl_kwargs(), + shuffle=self.shuffle_train, + **self._common_dl_kwargs(self.train_batch_size), ) def val_dataloader(self): @@ -202,25 +211,24 @@ def val_dataloader(self): return DataLoader( dataset=self.ds_val, shuffle=self.shuffle_eval, - **self._common_dl_kwargs(), + **self._common_dl_kwargs(self.val_batch_size), ) - def test_dataloader(self) -> DataLoader: if self.ds_test is None: raise RuntimeError("test dataset is not set; provide test_idx or splitter.") return DataLoader( dataset=self.ds_test, - shuffle=self.shuffle_eval, # False by default - **self._common_dl_kwargs(), + shuffle=self.shuffle_eval, + **self._common_dl_kwargs(self.test_batch_size), ) def predict_dataloader(self) -> DataLoader: if self.ds_predict is None: raise RuntimeError("predict dataset is not set; provide predict_idx/test_idx.") - # IMPORTANT: keep shuffle=False for stable ordering per-rank return DataLoader( dataset=self.ds_predict, shuffle=False, - **self._common_dl_kwargs(), + **self._common_dl_kwargs(self.predict_batch_size), ) + diff --git a/contextualized/regression/trainers.py b/contextualized/regression/trainers.py index c33c66e4..8301ca8a 100644 --- a/contextualized/regression/trainers.py +++ b/contextualized/regression/trainers.py @@ -72,27 +72,25 @@ def predict_y(self, model: pl.LightningModule, dataloader) -> 
np.ndarray: if not torch.is_tensor(betas): betas = torch.as_tensor(betas) if not torch.is_tensor(mus): mus = torch.as_tensor(mus) - # --- FIX: make shapes broadcastable --- - # --- FIX: make shapes broadcastable for both multivariate (3D) and univariate (4D) --- + # --- shape fixes for multivariate (3D) and univariate (4D) --- # Multivariate convention: X (B, y, x), betas (B, y, x), mus (B, y, 1) # Univariate convention: X (B, y, x, 1), betas (B, y, x, 1), mus (B, y, x, 1) + + # If X is (B, x) and betas is (B, y, x), expand X -> (B, 1, x) if X.dim() == 2 and betas.dim() == 3 and betas.size(-1) == X.size(-1): - # allow X provided as (B, x) for multivariate -> expand to (B,1,x) - X = X.unsqueeze(1) # (B,1,x) + X = X.unsqueeze(1) - if betas.dim() == 3 and X.dim() == 4: - # univariate predict_step may have squeezed betas -> add singleton to match (B,y,x,1) - betas = betas.unsqueeze(-1) # (B,y,x,1) + # If betas is (B, y, x) but X is (B, y, x, 1), add trailing singleton to betas + if betas.dim() == 3 and X.dim() == 4 and betas.size(-1) == X.size(-2): + betas = betas.unsqueeze(-1) - if mus.dim() == 2: - # multivariate: ensure (B,y,1) - mus = mus.unsqueeze(-1) # (B,y,1) + # Ensure mus trailing dim is singleton + if mus.dim() == 2: # (B, y) + mus = mus.unsqueeze(-1) # (B, y, 1) elif mus.dim() == 3 and X.dim() == 4 and mus.size(-1) != 1: - # univariate: ensure trailing singleton (B,y,x,1) if it was (B,y,x) - mus = mus.unsqueeze(-1) - # --- end FIX --- + mus = mus.unsqueeze(-1) # (B, y, x, 1) + # --- end shape fixes --- - # --- end FIX --- yhat = model._predict_y(C, X, betas, mus) # uses model's link y_parts.append(yhat.detach().cpu()) @@ -163,6 +161,8 @@ def predict_precision(self, model: pl.LightningModule, dataloader) -> np.ndarray from contextualized.utils.engine import pick_engine +from pytorch_lightning.strategies import DDPStrategy, Strategy as PLStrategy + def make_trainer_with_env(trainer_cls=RegressionTrainer, **kwargs) -> pl.Trainer: # Respect explicit user 
settings; otherwise auto-pick accelerator = kwargs.pop("accelerator", None) @@ -170,30 +170,41 @@ def make_trainer_with_env(trainer_cls=RegressionTrainer, **kwargs) -> pl.Trainer strategy = kwargs.pop("strategy", None) plugins = kwargs.pop("plugins", None) - accelerator, devices, strategy = pick_engine( + # If caller provided a concrete Strategy instance, pass it through verbatim + if isinstance(strategy, PLStrategy): + return trainer_cls( + accelerator=("cpu" if accelerator is None else accelerator), + devices=(1 if devices is None else devices), + strategy=strategy, + plugins=plugins, + **kwargs, + ) + + # Otherwise, select engines automatically + accelerator, devices, strategy_name = pick_engine( accelerator=accelerator, devices=devices, - strategy=strategy, - prefer_spawn=True, # allows plain `python script.py` to use all GPUs + strategy=strategy, # may be "ddp" or "auto" + prefer_spawn=True, # allows plain `python script.py` to use all GPUs ) - # If using classic ddp, upgrade string->Strategy with tuned flags - if strategy == "ddp": - strategy = DDPStrategy( + # Upgrade "ddp" string to tuned DDPStrategy + if strategy_name == "ddp": + strategy_obj = DDPStrategy( find_unused_parameters=False, static_graph=True, gradient_as_bucket_view=True, ) + else: + strategy_obj = strategy_name # "auto" or other strings if plugins is None and accelerator == "cpu": - from pytorch_lightning.plugins.environments import LightningEnvironment plugins = [LightningEnvironment()] return trainer_cls( accelerator=accelerator, devices=devices, - strategy=strategy, + strategy=strategy_obj, plugins=plugins, **kwargs, ) - From a0c1f1d2bb8f53d00468c1c1d70d6fce275a5d27 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Sat, 8 Nov 2025 12:16:24 -0500 Subject: [PATCH 04/19] added file for testing GPU scalability --- smoke.py | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 smoke.py diff --git a/smoke.py b/smoke.py new file mode 100644 
index 00000000..46b38dbc --- /dev/null +++ b/smoke.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +""" +CPU-only smoke test for Contextualized (no val loop, ~2–5s on a laptop CPU). + +- Forces CPU (disables CUDA) +- Generates a tiny synthetic regression set +- Fits a very small ContextualizedRegressor for 2 epochs +- Prints wall time and a few predictions +""" +import os, time +os.environ["CUDA_VISIBLE_DEVICES"] = "" # force CPU before importing torch/PL + +import numpy as np +from contextualized.easy.ContextualizedRegressor import ContextualizedRegressor + +def make_synth(n=2_000, c_dim=8, x_dim=16, y_dim=1, seed=123): + rng = np.random.default_rng(seed) + C = rng.normal(size=(n, c_dim)).astype(np.float32) + X = rng.normal(size=(n, x_dim)).astype(np.float32) + # context-conditioned linear truth + W = rng.normal(size=(c_dim, x_dim, y_dim)).astype(np.float32) + Y = (C @ W.reshape(c_dim, -1)).reshape(n, x_dim, y_dim) + Y = (X[..., None] * Y).sum(axis=1) + 0.1 * rng.normal(size=(n, y_dim)).astype(np.float32) + return C, X, Y + +def main(): + C, X, Y = make_synth() + + # Tiny model, no validation, CPU-only trainer settings live inside .fit kwargs + model = ContextualizedRegressor( + encoder_type="mlp", + width=16, + layers=2, + learning_rate=1e-3, + univariate=False, # multivariate target OK; here y_dim=1 anyway + ) + + t0 = time.time() + model.fit( + X, Y, C, # README order: (X, Y, C) + # ----- data ----- + train_batch_size=128, + num_workers=0, + val_split=0.0, # <— no val loop (avoids EarlyStopping/val_loss) + # ----- trainer ----- + accelerator="cpu", + devices=1, + strategy="auto", + max_epochs=2, + enable_progress_bar=False, + logger=False, + limit_val_batches=0, + # safety/consistency if callbacks sneak in + es_patience=0, + ) + dt = time.time() - t0 + + # quick predict + yhat = model.predict(C[:8], X[:8]) + print(f"\nDone. 
Wall time: {dt:.2f}s") + print("Pred sample (first 5 rows):\n", np.asarray(yhat)[:5].round(3)) + +if __name__ == "__main__": + main() From ceb7e9ce450b4d6a32abca1237f3adeeddca855b Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Sat, 8 Nov 2025 12:24:24 -0500 Subject: [PATCH 05/19] change to the constructor --- .../easy/wrappers/SKLearnWrapper.py | 59 ++- contextualized/regression/trainers.py | 1 + scale_bench.py | 430 ++++++++++++++++++ smoke.py | 64 --- 4 files changed, 470 insertions(+), 84 deletions(-) create mode 100644 scale_bench.py delete mode 100644 smoke.py diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index be428b56..a9d3a8e5 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -275,9 +275,16 @@ def _retarget_or_strip_early_stopping(cb, use_val: bool, train_monitor="train_lo # -------------------- fit kwarg expansion -------------------- def _organize_and_expand_fit_kwargs(self, **kwargs): + """ + Expand/normalize kwargs for data/model/trainer/wrapper/fit, and build a clean + configuration dict for downstream construction. Critically: + • Merge constructor-time defaults BEFORE computing use_val. + • Only add EarlyStopping if a val loop exists and patience > 0. + • Retarget or strip EarlyStopping if no val loop. 
+ """ organized, unrecognized = self._organize_kwargs(**kwargs) - # Max epochs (avoid PL default 1000) + # -------- epochs (avoid PL default 1000) -------- max_epochs_cli = kwargs.get("max_epochs", None) epochs_cli = kwargs.get("epochs", None) if max_epochs_cli is not None: @@ -287,10 +294,18 @@ def _organize_and_expand_fit_kwargs(self, **kwargs): else: organized["trainer"]["max_epochs"] = 3 + # -------- merge constructor defaults BEFORE using them -------- + for category, cat_kwargs in self._init_kwargs.items(): + for k, v in cat_kwargs.items(): + organized[category].setdefault(k, v) + + # -------- world size / validation decision -------- world_size = int(os.getenv("WORLD_SIZE", "1")) - use_val = organized["data"].get("val_split", self.default_val_split) > 0.0 + current_val_split = organized["data"].get("val_split", self.default_val_split) + organized["data"]["val_split"] = current_val_split + use_val = float(current_val_split) > 0.0 - # Trainer defaults + # -------- trainer defaults -------- organized["trainer"].setdefault("accelerator", self.accelerator) organized["trainer"].setdefault("enable_progress_bar", False) organized["trainer"].setdefault("logger", False) @@ -302,18 +317,18 @@ def _organize_and_expand_fit_kwargs(self, **kwargs): if world_size > 1: organized["trainer"].setdefault("devices", world_size) - # Defer concrete object; prefer plain string for factory - organized["trainer"].setdefault("strategy", "ddp") + organized["trainer"].setdefault("strategy", "ddp") # string to allow factory else: organized["trainer"]["devices"] = 1 organized["trainer"].setdefault("strategy", "auto") organized["trainer"].setdefault("plugins", [LightningEnvironment()]) - # Model defaults + # Helper to safely set defaults if the key is permitted for that category def maybe_add(cat, k, default): if k in self.acceptable_kwargs[cat]: organized[cat][k] = organized[cat].get(k, default) + # -------- model defaults -------- maybe_add("model", "learning_rate", 
self.default_learning_rate) maybe_add("model", "context_dim", self.context_dim) maybe_add("model", "x_dim", self.x_dim) @@ -321,7 +336,7 @@ def maybe_add(cat, k, default): if organized["model"].get("num_archetypes", 1) == 0: organized["model"].pop("num_archetypes", None) - # Data defaults (per-loader sizes) + # -------- data defaults (per-loader sizes) -------- maybe_add("data", "train_batch_size", self.default_train_batch_size) maybe_add("data", "val_batch_size", self.default_val_batch_size) maybe_add("data", "test_batch_size", self.default_test_batch_size) @@ -334,10 +349,10 @@ def maybe_add(cat, k, default): maybe_add("data", "shuffle_eval", False) maybe_add("data", "dtype", torch.float) - # Wrapper defaults + # -------- wrapper defaults -------- maybe_add("wrapper", "n_bootstraps", self.default_n_bootstraps) - # EarlyStopping/Checkpoint constructors (sanitized later if no val) + # -------- EarlyStopping / Checkpoint constructors -------- es_monitor = organized["wrapper"].get("es_monitor", "val_loss" if use_val else "train_loss") es_mode = organized["wrapper"].get("es_mode", "min") es_patience = organized["wrapper"].get("es_patience", self.default_es_patience) @@ -345,36 +360,39 @@ def maybe_add(cat, k, default): es_min_delta = organized["wrapper"].get("es_min_delta", 0.0) cb_ctors = organized["trainer"].get("callback_constructors", []) - if use_val: + + # Only add EarlyStopping when there is a val loop AND patience > 0 + if use_val and (es_patience is not None and es_patience > 0): cb_ctors.append( lambda i: EarlyStopping( - monitor=es_monitor, mode=es_mode, patience=es_patience, - verbose=es_verbose, min_delta=es_min_delta + monitor=es_monitor, + mode=es_mode, + patience=es_patience, + verbose=es_verbose, + min_delta=es_min_delta, ) ) + if organized["trainer"].get("enable_checkpointing", False): cb_ctors.append( lambda i: ModelCheckpoint( - monitor="val_loss" if use_val else None, + monitor=("val_loss" if use_val else None), 
dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", - filename="{epoch}-{val_loss:.4f}" if use_val else "{epoch}", + filename=("{epoch}-{val_loss:.4f}" if use_val else "{epoch}"), ) ) organized["trainer"]["callback_constructors"] = cb_ctors + # -------- unknown kw logging -------- for kw in unrecognized: print(f"Received unknown keyword argument {kw}, probably ignoring.") - # Merge __init__ defaults as fallbacks - for category, cat_kwargs in self._init_kwargs.items(): - for k, v in cat_kwargs.items(): - organized[category].setdefault(k, v) - - # Sanitize any pre-specified callbacks for no-val runs + # -------- sanitize any pre-specified callbacks for no-val runs -------- cb_list = organized["trainer"].get("callbacks", []) cb_list = [self._retarget_or_strip_early_stopping(cb, use_val) for cb in cb_list] organized["trainer"]["callbacks"] = cb_list + # Also sanitize dynamically constructed callbacks ctor_list = organized["trainer"].get("callback_constructors", []) def _wrap_ctor(ctor): def _wrapped(i): @@ -385,6 +403,7 @@ def _wrapped(i): return organized + # -------------------- data module builder -------------------- def _build_datamodule( self, diff --git a/contextualized/regression/trainers.py b/contextualized/regression/trainers.py index 8301ca8a..04197fae 100644 --- a/contextualized/regression/trainers.py +++ b/contextualized/regression/trainers.py @@ -83,6 +83,7 @@ def predict_y(self, model: pl.LightningModule, dataloader) -> np.ndarray: # If betas is (B, y, x) but X is (B, y, x, 1), add trailing singleton to betas if betas.dim() == 3 and X.dim() == 4 and betas.size(-1) == X.size(-2): betas = betas.unsqueeze(-1) + # Ensure mus trailing dim is singleton if mus.dim() == 2: # (B, y) diff --git a/scale_bench.py b/scale_bench.py new file mode 100644 index 00000000..06d9206d --- /dev/null +++ b/scale_bench.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 +# scale_bench.py +""" +Scalability benchmark for Contextualized-ML (single-node, 
Lambda GPU instance). + +Runs 4 configs with a FIXED GLOBAL BATCH SIZE: + - 1 CPU + - 1 GPU + - 2 GPUs (DDP) + - 4 GPUs (DDP) + +Outputs: + - results/bench_results.csv + - results/scaling_samples_per_sec.png + - results/scaling_wallclock.png + - results/scaling_epoch_time.png + - results/scaling_convergence_time.png + +Requirements: + - Your package importable on the instance. + - PyTorch + Lightning working w/ CUDA. +""" + +from __future__ import annotations +import argparse, json, math, os, sys, time, subprocess, shutil, uuid, signal +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Dict, Any, Optional, List + +import numpy as np +import torch + +# ---- Your package imports (as provided) ---- +from contextualized.regression.trainers import RegressionTrainer, make_trainer_with_env +from contextualized.regression.models import ContextualizedRegression # adjust path if different +from contextualized.regression.datamodules import ContextualizedRegressionDataModule + +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback + +import matplotlib +matplotlib.use("Agg") # headless +import matplotlib.pyplot as plt + + +# ========================= +# Synthetic data generator +# ========================= +def make_synth(n=200_000, c_dim=16, x_dim=64, y_dim=1, noise=0.10, seed=123): + """ + Multivariate regression with context-conditioned parameters: + Y = g( beta(C)*X + mu(C) ) + noise + g = identity (MSE) + """ + rng = np.random.default_rng(seed) + C = rng.normal(size=(n, c_dim)).astype(np.float32) + X = rng.normal(size=(n, x_dim)).astype(np.float32) + + # Context-conditioned weights: low-rank projection from C -> (y_dim, x_dim) and mu + Wc = rng.normal(scale=0.5, size=(c_dim, y_dim * x_dim)).astype(np.float32) + Wm = rng.normal(scale=0.5, size=(c_dim, y_dim)).astype(np.float32) + + beta_flat = C @ Wc # (n, y_dim*x_dim) + beta = beta_flat.reshape(n, y_dim, x_dim) + mu = (C @ Wm).reshape(n, y_dim, 1) # (n, 
y_dim, 1) + + # Broadcast X to (n, y_dim, x_dim) for multivariate form + Xb = np.expand_dims(X, 1).repeat(y_dim, axis=1) + y_true = (beta * Xb).sum(axis=-1, keepdims=True) + mu # (n, y_dim, 1) + + Y = y_true + noise * rng.normal(size=y_true.shape).astype(np.float32) + Y = Y.squeeze(-1) # (n, y_dim) + return C, X, Y + + +# ========================= +# Metrics callback +# ========================= +class MetricsCallback(Callback): + """ + Collect: + - wall-clock time + - per-epoch time (avg) + - total epochs + - samples/sec (global) + - convergence time & steps (val_loss <= target; or train_loss if no val) + - max memory (GPU or CPU) + Only rank 0 logs final metrics in DDP. + """ + def __init__(self, global_batch_size: int, train_size: int, target_loss: float, use_val: bool): + super().__init__() + self.global_batch = global_batch_size + self.train_size = train_size + self.target_loss = target_loss + self.use_val = use_val + + self.t0 = None + self.epoch_starts = [] + self.epoch_durations = [] + self.total_epochs = 0 + + self.converged = False + self.convergence_epoch = None + self.convergence_time_s = None + self.gradient_steps = None + + self.max_gpu_mem = 0 + self.max_cpu_rss = 0 # placeholder (psutil optional) + + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + self.t0 = time.perf_counter() + + def on_train_epoch_start(self, trainer, pl_module): + self.epoch_starts.append(time.perf_counter()) + + def on_train_epoch_end(self, trainer, pl_module): + t1 = time.perf_counter() + if self.epoch_starts: + self.epoch_durations.append(t1 - self.epoch_starts[-1]) + self.total_epochs += 1 + + # Track GPU memory (max) on rank 0 if CUDA + if torch.cuda.is_available() and torch.cuda.current_device() == 0: + self.max_gpu_mem = max(self.max_gpu_mem, torch.cuda.max_memory_reserved(0)) + + # Convergence check at end of epoch (val preferred) + if not self.converged: + metrics = trainer.callback_metrics + key = "val_loss" if self.use_val and ("val_loss" 
in metrics) else "train_loss" + loss_val = float(metrics.get(key, float("inf"))) + if loss_val <= self.target_loss: + self.converged = True + self.convergence_epoch = self.total_epochs + self.convergence_time_s = time.perf_counter() - self.t0 + # gradient steps up to and including this epoch + steps_per_epoch = math.ceil(self.train_size / self.global_batch) + self.gradient_steps = steps_per_epoch * self.convergence_epoch + + def on_fit_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + pass + + def finalize(self, trainer: pl.Trainer) -> Dict[str, Any]: + wall = time.perf_counter() - self.t0 if self.t0 else None + avg_epoch = (sum(self.epoch_durations) / len(self.epoch_durations)) if self.epoch_durations else None + total_samples = self.train_size * (self.total_epochs if self.total_epochs else 0) + sps = (total_samples / wall) if wall and wall > 0 else None + + # Convert bytes -> GiB for readability + mem_gib = (self.max_gpu_mem / (1024**3)) if self.max_gpu_mem else 0.0 + + return dict( + wall_clock_s=wall, + epoch_time_s=avg_epoch, + total_epochs=self.total_epochs, + samples_per_sec=sps, + convergence_time_s=(self.convergence_time_s if self.converged else None), + gradient_steps=(self.gradient_steps if self.converged else None), + max_memory_gib=mem_gib, + ) + + +# ========================= +# Runner: one configuration +# ========================= +@dataclass +class RunConfig: + label: str # e.g., "cpu-1", "gpu-1", "gpu-2", "gpu-4" + accelerator: str # "cpu" or "gpu" + devices: int # 1,2,4 + strategy: str # "auto" or "ddp" + global_batch: int + +@dataclass +class RunResult: + hardware: str + wall_clock_s: Optional[float] + epoch_time_s: Optional[float] + total_epochs: int + samples_per_sec: Optional[float] + convergence_time_s: Optional[float] + gradient_steps: Optional[int] + max_memory_gib: Optional[float] + +def per_device_batch(global_batch: int, world_size: int) -> int: + if world_size < 1: + world_size = 1 + b = max(1, global_batch // 
world_size) + if b * world_size != global_batch: + print(f"[warn] global_batch={global_batch} not divisible by world_size={world_size}; " + f"using per_device_batch={b} (effective global={b*world_size}).") + return b + +def run_single_config(args, cfg: RunConfig) -> RunResult: + # Synthesize data + C, X, Y = make_synth( + n=args.n, c_dim=args.c_dim, x_dim=args.x_dim, y_dim=args.y_dim, + noise=args.noise, seed=args.seed + ) + + # Split indices (simple holdout) + n = C.shape[0] + n_val = int(n * args.val_split) + permutation = np.random.default_rng(args.seed).permutation(n) + val_idx = permutation[:n_val] + train_idx = permutation[n_val:] + + world_size = cfg.devices if cfg.accelerator == "gpu" and cfg.strategy == "ddp" else 1 + eff_per_device = per_device_batch(cfg.global_batch, world_size) + + # DataModule (map-style; Lightning will shard w/ DistributedSampler) + dm = ContextualizedRegressionDataModule( + C=C, X=X, Y=Y, + task_type="singletask_multivariate", + train_idx=train_idx, val_idx=val_idx, test_idx=None, + predict_idx=val_idx, + train_batch_size=eff_per_device, + val_batch_size=eff_per_device, + test_batch_size=eff_per_device, + predict_batch_size=eff_per_device, + num_workers=args.num_workers, + pin_memory=(cfg.accelerator == "gpu"), + persistent_workers=bool(args.num_workers > 0), + drop_last=False, + shuffle_train=True, + shuffle_eval=False, + dtype=torch.float, + ) + dm.prepare_data(); dm.setup() + + # Model + model = ContextualizedRegression( + context_dim=args.c_dim, + x_dim=args.x_dim, + y_dim=args.y_dim, + num_archetypes=args.archetypes, + encoder_type="mlp", + encoder_kwargs=dict(width=args.width, layers=args.layers, link_fn="identity"), + learning_rate=args.lr, + metamodel_type="subtype", + fit_intercept=True, + link_fn="identity", + loss_fn="mse", + model_regularizer="none", + ) + + # Metrics + use_val = args.val_split > 0.0 + mcb = MetricsCallback( + global_batch_size=(eff_per_device * world_size), + train_size=len(train_idx), + 
target_loss=args.target_loss, + use_val=use_val, + ) + + # Trainer (via your factory) + trainer = make_trainer_with_env( + RegressionTrainer, + max_epochs=args.max_epochs, + enable_progress_bar=False, + logger=False, + accelerator=cfg.accelerator, + devices=(cfg.devices if cfg.accelerator == "gpu" else 1), + strategy=cfg.strategy, # "ddp" or "auto" + precision=32, + callbacks=[mcb], + # sanity & val + num_sanity_val_steps=0, + limit_val_batches=(1.0 if use_val else 0.0), + ) + + # Fit + if use_val and dm.val_dataloader() is not None: + trainer.fit(model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader()) + else: + trainer.fit(model, train_dataloaders=dm.train_dataloader()) + + # Finalize metrics (rank-0 only meaningful; in non-ddp it's fine) + metrics = mcb.finalize(trainer) + + return RunResult( + hardware=cfg.label, + wall_clock_s=metrics["wall_clock_s"], + epoch_time_s=metrics["epoch_time_s"], + total_epochs=metrics["total_epochs"], + samples_per_sec=metrics["samples_per_sec"], + convergence_time_s=metrics["convergence_time_s"], + gradient_steps=(int(metrics["gradient_steps"]) if metrics["gradient_steps"] is not None else None), + max_memory_gib=metrics["max_memory_gib"], + ) + + +# ========================= +# Sweep driver (single node) +# ========================= +def run_sweep(args): + results_dir = Path(args.outdir) + results_dir.mkdir(parents=True, exist_ok=True) + table_csv = results_dir / "bench_results.csv" + + # Default sweep: 1 CPU, 1 GPU, 2 GPU, 4 GPU (skip GPU configs if no CUDA) + cuda_ok = torch.cuda.is_available() + + sweep: List[RunConfig] = [ + RunConfig("cpu-1", "cpu", 1, "auto", args.global_batch), + ] + if cuda_ok: + # respect available device count + ndev = torch.cuda.device_count() + if ndev >= 1: sweep.append(RunConfig("gpu-1", "gpu", 1, "auto", args.global_batch)) + if ndev >= 2: sweep.append(RunConfig("gpu-2", "gpu", 2, "ddp", args.global_batch)) + if ndev >= 4: sweep.append(RunConfig("gpu-4", "gpu", 4, "ddp", 
args.global_batch)) + + rows: List[RunResult] = [] + for cfg in sweep: + print(f"\n=== Running config: {cfg.label} (accelerator={cfg.accelerator}, devices={cfg.devices}, strategy={cfg.strategy}) ===") + rr = run_single_config(args, cfg) + rows.append(rr) + print(f" -> Done {cfg.label}: wall={rr.wall_clock_s:.2f}s, sps={rr.samples_per_sec:.1f}, epochs={rr.total_epochs}") + + # Write CSV + import csv + with table_csv.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(["hardware config", "wall-clock (s)", "epoch time (s)", "total epochs", + "samples/sec", "convergence time (s)", "gradient steps", "max memory (GiB)"]) + for r in rows: + w.writerow([ + r.hardware, + f"{r.wall_clock_s:.6f}" if r.wall_clock_s is not None else "", + f"{r.epoch_time_s:.6f}" if r.epoch_time_s is not None else "", + r.total_epochs, + f"{r.samples_per_sec:.2f}" if r.samples_per_sec is not None else "", + f"{r.convergence_time_s:.6f}" if r.convergence_time_s is not None else "", + r.gradient_steps if r.gradient_steps is not None else "", + f"{r.max_memory_gib:.3f}" if r.max_memory_gib is not None else "", + ]) + print(f"\nSaved table -> {table_csv}") + + # Plots vs #GPUs (CPU shown at 0 GPUs on x-axis) + def hw_to_ngpu(lbl: str) -> int: + if lbl.startswith("cpu"): return 0 + return int(lbl.split("-")[1]) + + rows_sorted = sorted(rows, key=lambda r: hw_to_ngpu(r.hardware)) + x = [hw_to_ngpu(r.hardware) for r in rows_sorted] + + def plot_metric(name: str, vals: List[Optional[float]], fname: str, ylab: str): + xs, ys = [], [] + for xi, v in zip(x, vals): + if v is not None: + xs.append(xi); ys.append(v) + if not xs: + return + plt.figure() + plt.plot(xs, ys, marker="o") + plt.xlabel("# GPUs (CPU plotted as 0)") + plt.ylabel(ylab) + plt.title(f"{name} vs #GPUs") + plt.grid(True) + outp = results_dir / fname + plt.savefig(outp, bbox_inches="tight") + print(f"Saved plot -> {outp}") + + plot_metric("Throughput (samples/sec)", + [r.samples_per_sec for r in rows_sorted], + 
"scaling_samples_per_sec.png", "samples/sec") + + plot_metric("Wall-clock", + [r.wall_clock_s for r in rows_sorted], + "scaling_wallclock.png", "seconds") + + plot_metric("Epoch time", + [r.epoch_time_s for r in rows_sorted], + "scaling_epoch_time.png", "seconds/epoch") + + plot_metric("Convergence time", + [r.convergence_time_s for r in rows_sorted], + "scaling_convergence_time.png", "seconds") + + +# ========================= +# CLI +# ========================= +def parse_args(): + p = argparse.ArgumentParser(description="Contextualized-ML scalability benchmark") + # Data + p.add_argument("--n", type=int, default=200_000, help="number of samples (synthetic)") + p.add_argument("--c-dim", type=int, default=16) + p.add_argument("--x-dim", type=int, default=64) + p.add_argument("--y-dim", type=int, default=1) + p.add_argument("--noise", type=float, default=0.10) + p.add_argument("--seed", type=int, default=123) + + # Training + p.add_argument("--global-batch", type=int, default=4096, help="fixed global batch across configs") + p.add_argument("--max-epochs", type=int, default=5) + p.add_argument("--val-split", type=float, default=0.1) + p.add_argument("--lr", type=float, default=1e-3) + p.add_argument("--archetypes", type=int, default=8) + p.add_argument("--width", type=int, default=64) + p.add_argument("--layers", type=int, default=3) + p.add_argument("--num-workers", type=int, default=8) + p.add_argument("--target-loss", type=float, default=0.02, + help="convergence threshold on (val_loss or train_loss)") + + # I/O + p.add_argument("--outdir", type=str, default="results") + p.add_argument("--mode", choices=["sweep", "single"], default="sweep", + help="sweep = run CPU+GPU configs; single = run one config given below") + # For --mode single (debugging) + p.add_argument("--single-accel", choices=["cpu", "gpu"], default="cpu") + p.add_argument("--single-devices", type=int, default=1) + p.add_argument("--single-strategy", choices=["auto", "ddp"], default="auto") + + 
return p.parse_args() + + +def main(): + args = parse_args() + if args.mode == "sweep": + run_sweep(args) + else: + label = f"{args.single_accel}-{args.single_devices}" + cfg = RunConfig(label, args.single_accel, args.single_devices, args.single_strategy, args.global_batch) + rr = run_single_config(args, cfg) + print(json.dumps(asdict(rr), indent=2)) + + +if __name__ == "__main__": + main() diff --git a/smoke.py b/smoke.py deleted file mode 100644 index 46b38dbc..00000000 --- a/smoke.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 -""" -CPU-only smoke test for Contextualized (no val loop, ~2–5s on a laptop CPU). - -- Forces CPU (disables CUDA) -- Generates a tiny synthetic regression set -- Fits a very small ContextualizedRegressor for 2 epochs -- Prints wall time and a few predictions -""" -import os, time -os.environ["CUDA_VISIBLE_DEVICES"] = "" # force CPU before importing torch/PL - -import numpy as np -from contextualized.easy.ContextualizedRegressor import ContextualizedRegressor - -def make_synth(n=2_000, c_dim=8, x_dim=16, y_dim=1, seed=123): - rng = np.random.default_rng(seed) - C = rng.normal(size=(n, c_dim)).astype(np.float32) - X = rng.normal(size=(n, x_dim)).astype(np.float32) - # context-conditioned linear truth - W = rng.normal(size=(c_dim, x_dim, y_dim)).astype(np.float32) - Y = (C @ W.reshape(c_dim, -1)).reshape(n, x_dim, y_dim) - Y = (X[..., None] * Y).sum(axis=1) + 0.1 * rng.normal(size=(n, y_dim)).astype(np.float32) - return C, X, Y - -def main(): - C, X, Y = make_synth() - - # Tiny model, no validation, CPU-only trainer settings live inside .fit kwargs - model = ContextualizedRegressor( - encoder_type="mlp", - width=16, - layers=2, - learning_rate=1e-3, - univariate=False, # multivariate target OK; here y_dim=1 anyway - ) - - t0 = time.time() - model.fit( - X, Y, C, # README order: (X, Y, C) - # ----- data ----- - train_batch_size=128, - num_workers=0, - val_split=0.0, # <— no val loop (avoids EarlyStopping/val_loss) - # ----- trainer 
----- - accelerator="cpu", - devices=1, - strategy="auto", - max_epochs=2, - enable_progress_bar=False, - logger=False, - limit_val_batches=0, - # safety/consistency if callbacks sneak in - es_patience=0, - ) - dt = time.time() - t0 - - # quick predict - yhat = model.predict(C[:8], X[:8]) - print(f"\nDone. Wall time: {dt:.2f}s") - print("Pred sample (first 5 rows):\n", np.asarray(yhat)[:5].round(3)) - -if __name__ == "__main__": - main() From e81483d2109c1e655383c73b413da5dac55cad76 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Sun, 9 Nov 2025 18:23:35 -0500 Subject: [PATCH 06/19] change to test --- contextualized/data.py | 2 +- .../easy/wrappers/SKLearnWrapper.py | 13 +- contextualized/regression/datamodules.py | 8 +- .../regression/lightning_modules.py | 85 ++- scale_bench.py | 583 +++++++----------- 5 files changed, 262 insertions(+), 429 deletions(-) diff --git a/contextualized/data.py b/contextualized/data.py index 9c46291f..5c8fa627 100644 --- a/contextualized/data.py +++ b/contextualized/data.py @@ -1,5 +1,5 @@ import torch -from lightning import LightningDataModule +from pytorch_lightning import LightningDataModule from contextualized.regression.datasets import MultivariateDataset, UnivariateDataset, MultitaskMultivariateDataset, MultitaskUnivariateDataset from sklearn.model_selection import train_test_split diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index a9d3a8e5..4a4f4b29 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -77,7 +77,7 @@ def __init__( self.context_dim = None self.x_dim = None self.y_dim = None - self.accelerator = "gpu" if torch.cuda.is_available() else "cpu" + self.accelerator = "cuda" if torch.cuda.is_available() else "cpu" # Accepted kwarg routes self.acceptable_kwargs = { @@ -342,7 +342,7 @@ def maybe_add(cat, k, default): maybe_add("data", "test_batch_size", self.default_test_batch_size) 
maybe_add("data", "predict_batch_size", self.default_val_batch_size) maybe_add("data", "num_workers", 0) - maybe_add("data", "pin_memory", (self.accelerator == "gpu")) + maybe_add("data", "pin_memory", self.accelerator in ("cuda", "gpu")) maybe_add("data", "persistent_workers", False) maybe_add("data", "drop_last", False) maybe_add("data", "shuffle_train", True) @@ -403,7 +403,6 @@ def _wrapped(i): return organized - # -------------------- data module builder -------------------- def _build_datamodule( self, @@ -424,7 +423,7 @@ def _build_datamodule( test_batch_size=self.default_test_batch_size, predict_batch_size=self.default_val_batch_size, num_workers=0, - pin_memory=(self.accelerator == "gpu"), + pin_memory=(self.accelerator in ("cuda", "gpu")), persistent_workers=False, drop_last=False, shuffle_train=True, @@ -521,7 +520,7 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), + pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator in ("cuda", "gpu"))), persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), shuffle_train=False, shuffle_eval=False, @@ -570,7 +569,7 @@ def predict_params( test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), + pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator in ("cuda", "gpu"))), 
persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), shuffle_train=False, shuffle_eval=False, @@ -684,7 +683,7 @@ def fit(self, *args, **kwargs) -> None: test_batch_size=organized["data"].get("test_batch_size", self.default_test_batch_size), predict_batch_size=organized["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=organized["data"].get("num_workers", 0), - pin_memory=organized["data"].get("pin_memory", (self.accelerator == "gpu")), + pin_memory=organized["data"].get("pin_memory", self.accelerator in ("cuda", "gpu")), persistent_workers=organized["data"].get("persistent_workers", False), drop_last=organized["data"].get("drop_last", False), shuffle_train=organized["data"].get("shuffle_train", True), diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py index 281fb351..8aa43f5f 100644 --- a/contextualized/regression/datamodules.py +++ b/contextualized/regression/datamodules.py @@ -31,8 +31,9 @@ def _to_tensor(x: TensorLike, dtype: torch.dtype) -> torch.Tensor: return x.to(dtype=dtype, copy=False) if isinstance(x, (pd.DataFrame, pd.Series)): x = x.to_numpy(copy=False) - # x is now np.ndarray or array-like - return torch.tensor(x, dtype=dtype) + # np.ndarray -> avoid copy where possible + return torch.as_tensor(x, dtype=dtype) + def _maybe_index(x: torch.Tensor, idx: IndexLike) -> torch.Tensor: @@ -191,11 +192,12 @@ def _common_dl_kwargs(self, batch_size: int) -> Dict: "batch_size": batch_size, "num_workers": self.num_workers, "pin_memory": self.pin_memory, - "persistent_workers": self.persistent_workers, + "persistent_workers": bool(self.num_workers > 0 and self.persistent_workers), "drop_last": self.drop_last, } + def train_dataloader(self) -> DataLoader: if self.ds_train is None: raise RuntimeError("train dataset is not set; provide train_idx or splitter.") diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py 
index 5a2af1db..fa7300d9 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -276,14 +276,9 @@ def test_step(self, batch, batch_idx): return loss def _predict_from_models(self, X, beta_hat, mu_hat): - """ - - :param X: - :param beta_hat: - :param mu_hat: + # fused reduction + keepdim avoids extra unsqueeze + return self.link_fn((beta_hat * X).sum(dim=-1, keepdim=True) + mu_hat) - """ - return self.link_fn((beta_hat * X).sum(axis=-1).unsqueeze(-1) + mu_hat) def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -418,6 +413,8 @@ def __init__( base_param_predictor=None, ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") @@ -473,7 +470,7 @@ def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) batch.update({ "betas": beta_hat, - "mus": mu_hat.squeeze(-1), + "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), }) return batch @@ -558,6 +555,8 @@ def __init__( base_y_predictor=base_y_predictor, base_param_predictor=base_param_predictor ) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + class MultitaskContextualizedRegression(ContextualizedRegressionBase): @@ -581,6 +580,8 @@ def __init__( model_regularizer="none", ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") @@ -636,18 +637,11 @@ def _predict_y(self, C, X, beta_hat, mu_hat): return Y def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) batch.update({ "betas": beta_hat, - "mus": mu_hat.squeeze(-1), + "mus": mu_hat if 
mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), }) - # Return batch with predictions return batch @@ -723,6 +717,8 @@ def __init__( model_regularizer="none", ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.metamodel_type = metamodel_type self.fit_intercept = fit_intercept @@ -782,20 +778,14 @@ def _predict_y(self, C, X, beta_hat, mu_hat): return Y def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) batch.update({ "betas": beta_hat, - "mus": mu_hat.squeeze(-1), + "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), }) - # Return batch with predictions return batch + # def _batch_loss(self, batch, batch_idx): # """ @@ -888,6 +878,8 @@ def __init__( base_param_predictor=None, ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") @@ -954,8 +946,8 @@ def predict_step(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) batch.update({ - "betas": beta_hat.squeeze(-1), - "mus": mu_hat.squeeze(-1), + "betas": beta_hat, # keep last dim; downstream handles shape uniformly + "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), }) return batch @@ -1023,6 +1015,8 @@ def __init__( model_regularizer="none", ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") @@ -1078,19 +1072,14 @@ def _predict_y(self, C, X, beta_hat, mu_hat): return Y def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) batch.update({ - "betas": 
beta_hat.squeeze(-1), - "mus": mu_hat.squeeze(-1), + "betas": beta_hat, # keep last dim + "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), }) return batch + class TasksplitContextualizedUnivariateRegression(ContextualizedRegressionBase): """See TasksplitMetamodel""" @@ -1120,6 +1109,8 @@ def __init__( model_regularizer="none", ): super().__init__() + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.learning_rate = learning_rate self.fit_intercept = fit_intercept self.link_fn = _resolve_registry_or_callable(link_fn, LINK_FUNCTIONS, "link_fn") @@ -1178,19 +1169,14 @@ def _predict_y(self, C, X, beta_hat, mu_hat): return Y def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) batch.update({ - "betas": beta_hat.squeeze(-1), - "mus": mu_hat.squeeze(-1), + "betas": beta_hat, # keep last dim + "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), }) return batch + # def _params_reshape(self, preds, dataloader): # """ @@ -1246,6 +1232,8 @@ def __init__(self, context_dim, x_dim, **kwargs): if "y_dim" in kwargs: del kwargs["y_dim"] super().__init__(context_dim, x_dim, x_dim, **kwargs) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) @@ -1255,14 +1243,15 @@ def predict_step(self, batch, batch_idx): signs[signs != signs.transpose(1, 2)] = 0 correlations = signs * torch.sqrt(torch.abs(beta_hat * beta_hat_T)) batch.update({ - "betas": beta_hat, # already squeezed - "mus": mu_hat.squeeze(-1), + "betas": beta_hat, + "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), "correlations": correlations, }) return batch + class MultitaskContextualizedCorrelation(MultitaskContextualizedUnivariateRegression): """Using multitask univariate contextualized regression to estimate Pearson's correlation See TasksplitMetamodel for assumptions and full 
docstring @@ -1274,6 +1263,8 @@ def __init__(self, context_dim, x_dim, **kwargs): if "y_dim" in kwargs: del kwargs["y_dim"] super().__init__(context_dim, x_dim, x_dim, **kwargs) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + class TasksplitContextualizedCorrelation(TasksplitContextualizedUnivariateRegression): @@ -1287,6 +1278,8 @@ def __init__(self, context_dim, x_dim, **kwargs): if "y_dim" in kwargs: del kwargs["y_dim"] super().__init__(context_dim, x_dim, x_dim, **kwargs) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + class ContextualizedNeighborhoodSelection(ContextualizedRegression): @@ -1309,6 +1302,8 @@ def __init__( super().__init__( context_dim, x_dim, x_dim, model_regularizer=model_regularizer, **kwargs ) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): @@ -1336,6 +1331,8 @@ def __init__(self, context_dim, x_dim, **kwargs): if "y_dim" in kwargs: del kwargs["y_dim"] super().__init__(context_dim, x_dim, x_dim, **kwargs) + self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) + self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): diff --git a/scale_bench.py b/scale_bench.py index 06d9206d..6465437a 100644 --- a/scale_bench.py +++ b/scale_bench.py @@ -1,240 +1,109 @@ #!/usr/bin/env python3 -# scale_bench.py """ -Scalability benchmark for Contextualized-ML (single-node, Lambda GPU instance). - -Runs 4 configs with a FIXED GLOBAL BATCH SIZE: - - 1 CPU - - 1 GPU - - 2 GPUs (DDP) - - 4 GPUs (DDP) +Sweep benchmark for Contextualized-ML scaling. +Runs: CPU, 1-GPU, 2-GPU, 3-GPU, 4-GPU (skips if not available). 
Outputs: - - results/bench_results.csv - - results/scaling_samples_per_sec.png - - results/scaling_wallclock.png - - results/scaling_epoch_time.png - - results/scaling_convergence_time.png - -Requirements: - - Your package importable on the instance. - - PyTorch + Lightning working w/ CUDA. + - bench_out/scale_results.csv + - bench_out/throughput_vs_devices.png + - bench_out/walltime_vs_devices.png + - bench_out/epoch_time_vs_devices.png """ -from __future__ import annotations -import argparse, json, math, os, sys, time, subprocess, shutil, uuid, signal -from dataclasses import dataclass, asdict +import os, time, argparse, csv, math from pathlib import Path -from typing import Dict, Any, Optional, List - import numpy as np import torch +import matplotlib.pyplot as plt -# ---- Your package imports (as provided) ---- -from contextualized.regression.trainers import RegressionTrainer, make_trainer_with_env -from contextualized.regression.models import ContextualizedRegression # adjust path if different from contextualized.regression.datamodules import ContextualizedRegressionDataModule - -import pytorch_lightning as pl +from contextualized.regression import ContextualizedRegression +from contextualized.regression.trainers import RegressionTrainer, make_trainer_with_env from pytorch_lightning.callbacks import Callback -import matplotlib -matplotlib.use("Agg") # headless -import matplotlib.pyplot as plt - - -# ========================= -# Synthetic data generator -# ========================= -def make_synth(n=200_000, c_dim=16, x_dim=64, y_dim=1, noise=0.10, seed=123): - """ - Multivariate regression with context-conditioned parameters: - Y = g( beta(C)*X + mu(C) ) + noise - g = identity (MSE) - """ - rng = np.random.default_rng(seed) - C = rng.normal(size=(n, c_dim)).astype(np.float32) - X = rng.normal(size=(n, x_dim)).astype(np.float32) - - # Context-conditioned weights: low-rank projection from C -> (y_dim, x_dim) and mu - Wc = rng.normal(scale=0.5, size=(c_dim, y_dim 
* x_dim)).astype(np.float32) - Wm = rng.normal(scale=0.5, size=(c_dim, y_dim)).astype(np.float32) - - beta_flat = C @ Wc # (n, y_dim*x_dim) - beta = beta_flat.reshape(n, y_dim, x_dim) - mu = (C @ Wm).reshape(n, y_dim, 1) # (n, y_dim, 1) - - # Broadcast X to (n, y_dim, x_dim) for multivariate form - Xb = np.expand_dims(X, 1).repeat(y_dim, axis=1) - y_true = (beta * Xb).sum(axis=-1, keepdims=True) + mu # (n, y_dim, 1) - - Y = y_true + noise * rng.normal(size=y_true.shape).astype(np.float32) - Y = Y.squeeze(-1) # (n, y_dim) +# ----------------- utils ----------------- +def env_defaults(): + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + torch.backends.cuda.matmul.allow_tf32 = True + if hasattr(torch, "set_float32_matmul_precision"): + torch.set_float32_matmul_precision("high") + +def make_synth(n, c_dim, x_dim, y_dim, seed=1337): + rng = np.random.RandomState(seed) + C = rng.randn(n, c_dim).astype("float32") + X = rng.randn(n, x_dim).astype("float32") + W = rng.randn(y_dim, x_dim).astype("float32") + b = rng.randn(y_dim, 1).astype("float32") + Y = (X @ W.T + b.squeeze(-1) + 0.05 * rng.randn(n, y_dim)).astype("float32") return C, X, Y +class TimingCallback(Callback): + """Collect per-epoch timings and global wall time.""" + def __init__(self): + self.epoch_times = [] + self._t_epoch = None + self._t0 = None + self._t1 = None -# ========================= -# Metrics callback -# ========================= -class MetricsCallback(Callback): - """ - Collect: - - wall-clock time - - per-epoch time (avg) - - total epochs - - samples/sec (global) - - convergence time & steps (val_loss <= target; or train_loss if no val) - - max memory (GPU or CPU) - Only rank 0 logs final metrics in DDP. 
- """ - def __init__(self, global_batch_size: int, train_size: int, target_loss: float, use_val: bool): - super().__init__() - self.global_batch = global_batch_size - self.train_size = train_size - self.target_loss = target_loss - self.use_val = use_val + def on_fit_start(self, trainer, pl_module): + self._t0 = time.perf_counter() - self.t0 = None - self.epoch_starts = [] - self.epoch_durations = [] - self.total_epochs = 0 - - self.converged = False - self.convergence_epoch = None - self.convergence_time_s = None - self.gradient_steps = None - - self.max_gpu_mem = 0 - self.max_cpu_rss = 0 # placeholder (psutil optional) - - def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - self.t0 = time.perf_counter() + def on_fit_end(self, trainer, pl_module): + self._t1 = time.perf_counter() def on_train_epoch_start(self, trainer, pl_module): - self.epoch_starts.append(time.perf_counter()) + self._t_epoch = time.perf_counter() def on_train_epoch_end(self, trainer, pl_module): - t1 = time.perf_counter() - if self.epoch_starts: - self.epoch_durations.append(t1 - self.epoch_starts[-1]) - self.total_epochs += 1 - - # Track GPU memory (max) on rank 0 if CUDA - if torch.cuda.is_available() and torch.cuda.current_device() == 0: - self.max_gpu_mem = max(self.max_gpu_mem, torch.cuda.max_memory_reserved(0)) - - # Convergence check at end of epoch (val preferred) - if not self.converged: - metrics = trainer.callback_metrics - key = "val_loss" if self.use_val and ("val_loss" in metrics) else "train_loss" - loss_val = float(metrics.get(key, float("inf"))) - if loss_val <= self.target_loss: - self.converged = True - self.convergence_epoch = self.total_epochs - self.convergence_time_s = time.perf_counter() - self.t0 - # gradient steps up to and including this epoch - steps_per_epoch = math.ceil(self.train_size / self.global_batch) - self.gradient_steps = steps_per_epoch * self.convergence_epoch - - def on_fit_end(self, trainer: pl.Trainer, pl_module: 
pl.LightningModule): - pass - - def finalize(self, trainer: pl.Trainer) -> Dict[str, Any]: - wall = time.perf_counter() - self.t0 if self.t0 else None - avg_epoch = (sum(self.epoch_durations) / len(self.epoch_durations)) if self.epoch_durations else None - total_samples = self.train_size * (self.total_epochs if self.total_epochs else 0) - sps = (total_samples / wall) if wall and wall > 0 else None - - # Convert bytes -> GiB for readability - mem_gib = (self.max_gpu_mem / (1024**3)) if self.max_gpu_mem else 0.0 - - return dict( - wall_clock_s=wall, - epoch_time_s=avg_epoch, - total_epochs=self.total_epochs, - samples_per_sec=sps, - convergence_time_s=(self.convergence_time_s if self.converged else None), - gradient_steps=(self.gradient_steps if self.converged else None), - max_memory_gib=mem_gib, - ) - - -# ========================= -# Runner: one configuration -# ========================= -@dataclass -class RunConfig: - label: str # e.g., "cpu-1", "gpu-1", "gpu-2", "gpu-4" - accelerator: str # "cpu" or "gpu" - devices: int # 1,2,4 - strategy: str # "auto" or "ddp" - global_batch: int - -@dataclass -class RunResult: - hardware: str - wall_clock_s: Optional[float] - epoch_time_s: Optional[float] - total_epochs: int - samples_per_sec: Optional[float] - convergence_time_s: Optional[float] - gradient_steps: Optional[int] - max_memory_gib: Optional[float] - -def per_device_batch(global_batch: int, world_size: int) -> int: - if world_size < 1: - world_size = 1 - b = max(1, global_batch // world_size) - if b * world_size != global_batch: - print(f"[warn] global_batch={global_batch} not divisible by world_size={world_size}; " - f"using per_device_batch={b} (effective global={b*world_size}).") - return b - -def run_single_config(args, cfg: RunConfig) -> RunResult: - # Synthesize data - C, X, Y = make_synth( - n=args.n, c_dim=args.c_dim, x_dim=args.x_dim, y_dim=args.y_dim, - noise=args.noise, seed=args.seed - ) + t = time.perf_counter() + if self._t_epoch is not None: + 
self.epoch_times.append(t - self._t_epoch) + self._t_epoch = None - # Split indices (simple holdout) - n = C.shape[0] - n_val = int(n * args.val_split) - permutation = np.random.default_rng(args.seed).permutation(n) - val_idx = permutation[:n_val] - train_idx = permutation[n_val:] + @property + def wall_time(self): + if self._t0 is None or self._t1 is None: + return None + return self._t1 - self._t0 - world_size = cfg.devices if cfg.accelerator == "gpu" and cfg.strategy == "ddp" else 1 - eff_per_device = per_device_batch(cfg.global_batch, world_size) - # DataModule (map-style; Lightning will shard w/ DistributedSampler) +def run_one(cfg, data, args): + """ + cfg: dict with keys: + - label: str (e.g., "cpu", "gpu-1", "gpu-2", ...) + - accelerator: "cpu" or None + - devices: "auto" or int + - strategy: "auto" or "ddp" + """ + C, X, Y = data + # datamodule (map-style -> PL autoshard) + pin_mem = (cfg["accelerator"] != "cpu") dm = ContextualizedRegressionDataModule( C=C, X=X, Y=Y, task_type="singletask_multivariate", - train_idx=train_idx, val_idx=val_idx, test_idx=None, - predict_idx=val_idx, - train_batch_size=eff_per_device, - val_batch_size=eff_per_device, - test_batch_size=eff_per_device, - predict_batch_size=eff_per_device, + train_idx=None, val_idx=None, test_idx=None, predict_idx=None, + train_batch_size=args.batch_size, + val_batch_size=args.batch_size, + test_batch_size=args.batch_size, + predict_batch_size=args.batch_size, num_workers=args.num_workers, - pin_memory=(cfg.accelerator == "gpu"), - persistent_workers=bool(args.num_workers > 0), - drop_last=False, + pin_memory=pin_mem, + persistent_workers=bool(args.persistent_workers and args.num_workers > 0), + drop_last=True, shuffle_train=True, shuffle_eval=False, dtype=torch.float, ) - dm.prepare_data(); dm.setup() - # Model - model = ContextualizedRegression( - context_dim=args.c_dim, + model_kwargs = dict( + context_dim=args.context_dim, x_dim=args.x_dim, y_dim=args.y_dim, - num_archetypes=args.archetypes, 
- encoder_type="mlp", - encoder_kwargs=dict(width=args.width, layers=args.layers, link_fn="identity"), + num_archetypes=args.num_archetypes, + encoder_type=args.encoder_type, + encoder_kwargs={"width": args.width, "layers": args.layers, "link_fn": "identity"}, learning_rate=args.lr, metamodel_type="subtype", fit_intercept=True, @@ -242,189 +111,155 @@ def run_single_config(args, cfg: RunConfig) -> RunResult: loss_fn="mse", model_regularizer="none", ) + model = ContextualizedRegression(**model_kwargs) - # Metrics - use_val = args.val_split > 0.0 - mcb = MetricsCallback( - global_batch_size=(eff_per_device * world_size), - train_size=len(train_idx), - target_loss=args.target_loss, - use_val=use_val, - ) + # precision + prec_map = {"32":32, "64":64, "16":"16-mixed", "bf16":"bf16-mixed"} + precision = prec_map[args.precision] - # Trainer (via your factory) + timing_cb = TimingCallback() trainer = make_trainer_with_env( - RegressionTrainer, - max_epochs=args.max_epochs, - enable_progress_bar=False, + trainer_cls=RegressionTrainer, + max_epochs=args.epochs, + accelerator=cfg["accelerator"], + devices=cfg["devices"], + strategy=cfg["strategy"], logger=False, - accelerator=cfg.accelerator, - devices=(cfg.devices if cfg.accelerator == "gpu" else 1), - strategy=cfg.strategy, # "ddp" or "auto" - precision=32, - callbacks=[mcb], - # sanity & val + enable_progress_bar=False, + enable_checkpointing=False, num_sanity_val_steps=0, - limit_val_batches=(1.0 if use_val else 0.0), - ) - - # Fit - if use_val and dm.val_dataloader() is not None: - trainer.fit(model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader()) - else: - trainer.fit(model, train_dataloaders=dm.train_dataloader()) - - # Finalize metrics (rank-0 only meaningful; in non-ddp it's fine) - metrics = mcb.finalize(trainer) - - return RunResult( - hardware=cfg.label, - wall_clock_s=metrics["wall_clock_s"], - epoch_time_s=metrics["epoch_time_s"], - total_epochs=metrics["total_epochs"], - 
samples_per_sec=metrics["samples_per_sec"], - convergence_time_s=metrics["convergence_time_s"], - gradient_steps=(int(metrics["gradient_steps"]) if metrics["gradient_steps"] is not None else None), - max_memory_gib=metrics["max_memory_gib"], + precision=precision, + limit_val_batches=0, + limit_train_batches=1.0, + callbacks=[timing_cb], ) + # Warmup 1 epoch for stable timings (optional) + warm_model = ContextualizedRegression(**model_kwargs) + trainer.fit(warm_model, train_dataloaders=dm.train_dataloader()) + + # Timed run + if torch.cuda.is_available(): + torch.cuda.synchronize() + trainer.fit(model, train_dataloaders=dm.train_dataloader()) + if torch.cuda.is_available(): + torch.cuda.synchronize() + + # metrics + # Compute seen samples per epoch (drop_last=True) + steps_per_epoch = math.ceil(len(C) / args.batch_size) + seen_per_epoch = steps_per_epoch * args.batch_size + total_seen = seen_per_epoch * args.epochs + wall = timing_cb.wall_time + throughput = total_seen / wall if wall and wall > 0 else float("nan") + world_size = trainer.num_devices if hasattr(trainer, "num_devices") else 1 + + return { + "label": cfg["label"], + "world_size": world_size, + "batch_size": args.batch_size, + "num_workers": args.num_workers, + "precision": args.precision, + "epochs": args.epochs, + "steps_per_epoch": steps_per_epoch, + "samples_per_epoch": seen_per_epoch, + "wall_time_s": wall, + "throughput_samples_per_s": throughput, + "per_gpu_throughput": throughput / max(1, world_size), + "epoch_times_s": timing_cb.epoch_times, + } + +def plot_one(x, y, xlabel, ylabel, outpng): + plt.figure() + plt.plot(x, y, marker="o") + plt.xlabel(xlabel) + plt.ylabel(ylabel) + plt.grid(True) + plt.tight_layout() + plt.savefig(outpng, dpi=150) + plt.close() -# ========================= -# Sweep driver (single node) -# ========================= -def run_sweep(args): - results_dir = Path(args.outdir) - results_dir.mkdir(parents=True, exist_ok=True) - table_csv = results_dir / "bench_results.csv" 
- - # Default sweep: 1 CPU, 1 GPU, 2 GPU, 4 GPU (skip GPU configs if no CUDA) - cuda_ok = torch.cuda.is_available() - - sweep: List[RunConfig] = [ - RunConfig("cpu-1", "cpu", 1, "auto", args.global_batch), - ] - if cuda_ok: - # respect available device count - ndev = torch.cuda.device_count() - if ndev >= 1: sweep.append(RunConfig("gpu-1", "gpu", 1, "auto", args.global_batch)) - if ndev >= 2: sweep.append(RunConfig("gpu-2", "gpu", 2, "ddp", args.global_batch)) - if ndev >= 4: sweep.append(RunConfig("gpu-4", "gpu", 4, "ddp", args.global_batch)) - - rows: List[RunResult] = [] - for cfg in sweep: - print(f"\n=== Running config: {cfg.label} (accelerator={cfg.accelerator}, devices={cfg.devices}, strategy={cfg.strategy}) ===") - rr = run_single_config(args, cfg) - rows.append(rr) - print(f" -> Done {cfg.label}: wall={rr.wall_clock_s:.2f}s, sps={rr.samples_per_sec:.1f}, epochs={rr.total_epochs}") - - # Write CSV - import csv - with table_csv.open("w", newline="") as f: +def main(): + env_defaults() + + ap = argparse.ArgumentParser() + # data/model + ap.add_argument("--n-samples", type=int, default=300_000) + ap.add_argument("--context-dim", type=int, default=32) + ap.add_argument("--x-dim", type=int, default=256) + ap.add_argument("--y-dim", type=int, default=64) + ap.add_argument("--encoder-type", type=str, default="mlp", choices=["mlp","ngam","linear"]) + ap.add_argument("--width", type=int, default=1024) + ap.add_argument("--layers", type=int, default=4) + ap.add_argument("--num-archetypes", type=int, default=10) + ap.add_argument("--lr", type=float, default=1e-3) + ap.add_argument("--epochs", type=int, default=3) + # dataloader + ap.add_argument("--batch-size", type=int, default=2048) + ap.add_argument("--num-workers", type=int, default=8) + ap.add_argument("--persistent-workers", action="store_true", default=True) + # precision + ap.add_argument("--precision", type=str, default="bf16", choices=["32","16","bf16","64"]) + # output + ap.add_argument("--outdir", 
type=str, default="bench_out") + args = ap.parse_args() + + outdir = Path(args.outdir) + outdir.mkdir(parents=True, exist_ok=True) + + # synth data + data = make_synth(args.n_samples, args.context_dim, args.x_dim, args.y_dim) + + # decide available gpu configs + n_gpus = torch.cuda.device_count() + configs = [{"label":"cpu", "accelerator":"cpu", "devices=ignored":1, "devices":1, "strategy":"auto"}] + for k in [1,2,3,4]: + if n_gpus >= k: + configs.append({"label":f"gpu-{k}", "accelerator":None, "devices":k, "strategy":"ddp"}) + + results = [] + for cfg in configs: + print(f"\n=== Running {cfg['label']} ===") + res = run_one(cfg, data, args) + for k,v in res.items(): + if k != "epoch_times_s": + print(f"{k}: {v}") + results.append(res) + + # write CSV + csv_path = outdir / "scale_results.csv" + with open(csv_path, "w", newline="") as f: w = csv.writer(f) - w.writerow(["hardware config", "wall-clock (s)", "epoch time (s)", "total epochs", - "samples/sec", "convergence time (s)", "gradient steps", "max memory (GiB)"]) - for r in rows: + w.writerow([ + "label","world_size","batch_size","num_workers","precision","epochs", + "steps_per_epoch","samples_per_epoch","wall_time_s","throughput_samples_per_s","per_gpu_throughput","epoch_times_s" + ]) + for r in results: w.writerow([ - r.hardware, - f"{r.wall_clock_s:.6f}" if r.wall_clock_s is not None else "", - f"{r.epoch_time_s:.6f}" if r.epoch_time_s is not None else "", - r.total_epochs, - f"{r.samples_per_sec:.2f}" if r.samples_per_sec is not None else "", - f"{r.convergence_time_s:.6f}" if r.convergence_time_s is not None else "", - r.gradient_steps if r.gradient_steps is not None else "", - f"{r.max_memory_gib:.3f}" if r.max_memory_gib is not None else "", + r["label"], r["world_size"], r["batch_size"], r["num_workers"], r["precision"], r["epochs"], + r["steps_per_epoch"], r["samples_per_epoch"], f"{r['wall_time_s']:.6f}", + f"{r['throughput_samples_per_s']:.3f}", f"{r['per_gpu_throughput']:.3f}", + ";".join(f"{et:.6f}" 
for et in r["epoch_times_s"]) ]) - print(f"\nSaved table -> {table_csv}") - - # Plots vs #GPUs (CPU shown at 0 GPUs on x-axis) - def hw_to_ngpu(lbl: str) -> int: - if lbl.startswith("cpu"): return 0 - return int(lbl.split("-")[1]) - - rows_sorted = sorted(rows, key=lambda r: hw_to_ngpu(r.hardware)) - x = [hw_to_ngpu(r.hardware) for r in rows_sorted] - - def plot_metric(name: str, vals: List[Optional[float]], fname: str, ylab: str): - xs, ys = [], [] - for xi, v in zip(x, vals): - if v is not None: - xs.append(xi); ys.append(v) - if not xs: - return - plt.figure() - plt.plot(xs, ys, marker="o") - plt.xlabel("# GPUs (CPU plotted as 0)") - plt.ylabel(ylab) - plt.title(f"{name} vs #GPUs") - plt.grid(True) - outp = results_dir / fname - plt.savefig(outp, bbox_inches="tight") - print(f"Saved plot -> {outp}") - - plot_metric("Throughput (samples/sec)", - [r.samples_per_sec for r in rows_sorted], - "scaling_samples_per_sec.png", "samples/sec") - - plot_metric("Wall-clock", - [r.wall_clock_s for r in rows_sorted], - "scaling_wallclock.png", "seconds") - - plot_metric("Epoch time", - [r.epoch_time_s for r in rows_sorted], - "scaling_epoch_time.png", "seconds/epoch") - - plot_metric("Convergence time", - [r.convergence_time_s for r in rows_sorted], - "scaling_convergence_time.png", "seconds") - - -# ========================= -# CLI -# ========================= -def parse_args(): - p = argparse.ArgumentParser(description="Contextualized-ML scalability benchmark") - # Data - p.add_argument("--n", type=int, default=200_000, help="number of samples (synthetic)") - p.add_argument("--c-dim", type=int, default=16) - p.add_argument("--x-dim", type=int, default=64) - p.add_argument("--y-dim", type=int, default=1) - p.add_argument("--noise", type=float, default=0.10) - p.add_argument("--seed", type=int, default=123) - - # Training - p.add_argument("--global-batch", type=int, default=4096, help="fixed global batch across configs") - p.add_argument("--max-epochs", type=int, default=5) - 
p.add_argument("--val-split", type=float, default=0.1) - p.add_argument("--lr", type=float, default=1e-3) - p.add_argument("--archetypes", type=int, default=8) - p.add_argument("--width", type=int, default=64) - p.add_argument("--layers", type=int, default=3) - p.add_argument("--num-workers", type=int, default=8) - p.add_argument("--target-loss", type=float, default=0.02, - help="convergence threshold on (val_loss or train_loss)") - - # I/O - p.add_argument("--outdir", type=str, default="results") - p.add_argument("--mode", choices=["sweep", "single"], default="sweep", - help="sweep = run CPU+GPU configs; single = run one config given below") - # For --mode single (debugging) - p.add_argument("--single-accel", choices=["cpu", "gpu"], default="cpu") - p.add_argument("--single-devices", type=int, default=1) - p.add_argument("--single-strategy", choices=["auto", "ddp"], default="auto") - - return p.parse_args() - - -def main(): - args = parse_args() - if args.mode == "sweep": - run_sweep(args) - else: - label = f"{args.single_accel}-{args.single_devices}" - cfg = RunConfig(label, args.single_accel, args.single_devices, args.single_strategy, args.global_batch) - rr = run_single_config(args, cfg) - print(json.dumps(asdict(rr), indent=2)) - + print(f"\n[Saved] {csv_path}") + + # plots + xs = [ (0 if r["label"]=="cpu" else r["world_size"]) for r in results ] + labels = [r["label"] for r in results] + + # Throughput vs devices (GPU counts; CPU plotted at 0) + plot_one(xs, [r["throughput_samples_per_s"] for r in results], + "Devices (0=CPU)", "Throughput (samples/s)", str(outdir / "throughput_vs_devices.png")) + # Wall time vs devices + plot_one(xs, [r["wall_time_s"] for r in results], + "Devices (0=CPU)", "Wall time (s)", str(outdir / "walltime_vs_devices.png")) + # Mean epoch time vs devices + plot_one(xs, [float(np.mean(r["epoch_times_s"])) for r in results], + "Devices (0=CPU)", "Mean epoch time (s)", str(outdir / "epoch_time_vs_devices.png")) + + print(f"[Saved] 
{outdir/'throughput_vs_devices.png'}") + print(f"[Saved] {outdir/'walltime_vs_devices.png'}") + print(f"[Saved] {outdir/'epoch_time_vs_devices.png'}") if __name__ == "__main__": main() From ae66e8062f93856b8efe75bd612b7f19ce66fbb7 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Mon, 10 Nov 2025 20:50:57 -0500 Subject: [PATCH 07/19] scale bench update --- scale_bench.py | 536 ++++++++++++++++++++++++++++++------------------- 1 file changed, 328 insertions(+), 208 deletions(-) diff --git a/scale_bench.py b/scale_bench.py index 6465437a..b2aa337f 100644 --- a/scale_bench.py +++ b/scale_bench.py @@ -1,265 +1,385 @@ #!/usr/bin/env python3 -""" -Sweep benchmark for Contextualized-ML scaling. - -Runs: CPU, 1-GPU, 2-GPU, 3-GPU, 4-GPU (skips if not available). -Outputs: - - bench_out/scale_results.csv - - bench_out/throughput_vs_devices.png - - bench_out/walltime_vs_devices.png - - bench_out/epoch_time_vs_devices.png -""" - -import os, time, argparse, csv, math -from pathlib import Path +import os, time, csv, argparse, math, json +from dataclasses import dataclass +from typing import List, Dict +from datetime import timedelta + import numpy as np import torch -import matplotlib.pyplot as plt +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.strategies import DDPStrategy -from contextualized.regression.datamodules import ContextualizedRegressionDataModule +# ---- your package pieces ---- from contextualized.regression import ContextualizedRegression -from contextualized.regression.trainers import RegressionTrainer, make_trainer_with_env -from pytorch_lightning.callbacks import Callback +from contextualized.regression.datamodules import ContextualizedRegressionDataModule -# ----------------- utils ----------------- -def env_defaults(): - os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + +# ---------------- utils ---------------- +def set_env_defaults(): + # Light, deterministic-friendly defaults 
os.environ.setdefault("OMP_NUM_THREADS", "1") os.environ.setdefault("MKL_NUM_THREADS", "1") - torch.backends.cuda.matmul.allow_tf32 = True - if hasattr(torch, "set_float32_matmul_precision"): - torch.set_float32_matmul_precision("high") - -def make_synth(n, c_dim, x_dim, y_dim, seed=1337): - rng = np.random.RandomState(seed) - C = rng.randn(n, c_dim).astype("float32") - X = rng.randn(n, x_dim).astype("float32") - W = rng.randn(y_dim, x_dim).astype("float32") - b = rng.randn(y_dim, 1).astype("float32") - Y = (X @ W.T + b.squeeze(-1) + 0.05 * rng.randn(n, y_dim)).astype("float32") - return C, X, Y + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") -class TimingCallback(Callback): - """Collect per-epoch timings and global wall time.""" + # Prefer new PyTorch var (2.4+); avoid deprecated NCCL_ASYNC_ERROR_HANDLING + os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") + os.environ.setdefault("NCCL_DEBUG", "WARN") + os.environ.setdefault("NCCL_P2P_DISABLE", "0") + # Most cloud nodes lack IB; default it off for reliability + os.environ.setdefault("NCCL_IB_DISABLE", "1") + + # If user didn't set NCCL_SOCKET_IFNAME, auto-pick a sane one + if "NCCL_SOCKET_IFNAME" not in os.environ: + try: + ifaces = [d for d in os.listdir("/sys/class/net") if os.path.isdir(f"/sys/class/net/{d}")] + cand = next((i for i in ifaces if i not in ("lo", "docker0")), None) + os.environ["NCCL_SOCKET_IFNAME"] = cand or "lo" + except Exception: + os.environ["NCCL_SOCKET_IFNAME"] = "lo" + + # Unique rendezvous per run + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", str(12355 + (os.getpid() % 20000))) + + if int(os.environ.get("RANK", "0")) == 0: + keys = ["NCCL_DEBUG","NCCL_IB_DISABLE","NCCL_P2P_DISABLE","NCCL_SOCKET_IFNAME","MASTER_ADDR","MASTER_PORT"] + print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) + + +def map_precision(p): + p = (p or "").lower() + if p in ("bf16", "bfloat16", "bf16-mixed"): + return "bf16-mixed" + if p in 
("fp16", "16", "16-mixed"): + return "16-mixed" + return 32 # full precision + + +class EpochTimer(Callback): def __init__(self): + self._epoch_start = None self.epoch_times = [] - self._t_epoch = None - self._t0 = None - self._t1 = None - def on_fit_start(self, trainer, pl_module): - self._t0 = time.perf_counter() - - def on_fit_end(self, trainer, pl_module): - self._t1 = time.perf_counter() + @staticmethod + def _using_cuda(trainer) -> bool: + try: + return trainer.accelerator is not None and "cuda" in str(trainer.accelerator).lower() + except Exception: + return torch.cuda.is_available() def on_train_epoch_start(self, trainer, pl_module): - self._t_epoch = time.perf_counter() + if self._using_cuda(trainer): + torch.cuda.synchronize() + self._epoch_start = time.time() def on_train_epoch_end(self, trainer, pl_module): - t = time.perf_counter() - if self._t_epoch is not None: - self.epoch_times.append(t - self._t_epoch) - self._t_epoch = None - - @property - def wall_time(self): - if self._t0 is None or self._t1 is None: - return None - return self._t1 - self._t0 - - -def run_one(cfg, data, args): - """ - cfg: dict with keys: - - label: str (e.g., "cpu", "gpu-1", "gpu-2", ...) 
- - accelerator: "cpu" or None - - devices: "auto" or int - - strategy: "auto" or "ddp" - """ - C, X, Y = data - # datamodule (map-style -> PL autoshard) - pin_mem = (cfg["accelerator"] != "cpu") + if self._using_cuda(trainer): + torch.cuda.synchronize() + self.epoch_times.append(time.time() - self._epoch_start) + + +# ---------------- synthetic data ---------------- +def make_synthetic(n, c_dim, x_dim, y_dim, seed=42): + rng = np.random.default_rng(seed) + C = rng.standard_normal((n, c_dim)).astype(np.float32) + X = rng.standard_normal((n, x_dim)).astype(np.float32) + W = rng.standard_normal((y_dim, x_dim)).astype(np.float32) + MU = rng.standard_normal((y_dim, 1)).astype(np.float32) + Y = (X @ W.T) + MU.squeeze(-1) + 0.01 * rng.standard_normal((n, y_dim)).astype(np.float32) + return C, X, Y + + +# ---------------- model/trainer builders ---------------- +def build_model(c_dim, x_dim, y_dim, width, layers, lr): + model = ContextualizedRegression( + context_dim=c_dim, + x_dim=x_dim, + y_dim=y_dim, + num_archetypes=8, + encoder_type="mlp", + encoder_kwargs={"width": width, "layers": layers, "link_fn": "identity"}, + learning_rate=lr, + fit_intercept=True, + link_fn="identity", + loss_fn="mse", + model_regularizer="none", + ) + return model + + +def build_dm( + C, X, Y, + train_batch_size: int, + num_workers: int, + pin_memory: bool, +): + n = C.shape[0] + perm = np.random.permutation(n) + n_train = int(0.9 * n) + train_idx = perm[:n_train] + val_idx = perm[n_train:] + dm = ContextualizedRegressionDataModule( C=C, X=X, Y=Y, task_type="singletask_multivariate", - train_idx=None, val_idx=None, test_idx=None, predict_idx=None, - train_batch_size=args.batch_size, - val_batch_size=args.batch_size, - test_batch_size=args.batch_size, - predict_batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=pin_mem, - persistent_workers=bool(args.persistent_workers and args.num_workers > 0), + train_idx=train_idx, + val_idx=val_idx, + test_idx=None, + 
predict_idx=None, + train_batch_size=train_batch_size, + val_batch_size=train_batch_size, + test_batch_size=train_batch_size, + predict_batch_size=train_batch_size, + num_workers=num_workers, + pin_memory=bool(pin_memory), + persistent_workers=bool(num_workers > 0), drop_last=True, shuffle_train=True, shuffle_eval=False, dtype=torch.float, ) - - model_kwargs = dict( - context_dim=args.context_dim, - x_dim=args.x_dim, - y_dim=args.y_dim, - num_archetypes=args.num_archetypes, - encoder_type=args.encoder_type, - encoder_kwargs={"width": args.width, "layers": args.layers, "link_fn": "identity"}, - learning_rate=args.lr, - metamodel_type="subtype", - fit_intercept=True, - link_fn="identity", - loss_fn="mse", - model_regularizer="none", - ) - model = ContextualizedRegression(**model_kwargs) - - # precision - prec_map = {"32":32, "64":64, "16":"16-mixed", "bf16":"bf16-mixed"} - precision = prec_map[args.precision] - - timing_cb = TimingCallback() - trainer = make_trainer_with_env( - trainer_cls=RegressionTrainer, - max_epochs=args.epochs, - accelerator=cfg["accelerator"], - devices=cfg["devices"], - strategy=cfg["strategy"], + dm.prepare_data(); dm.setup() + return dm + + +def build_trainer(devices, precision, epochs, ddp_timeout_s=120): + if devices == 0: + accelerator = "cpu" + devices = 1 + strategy = "auto" + elif devices == 1: + accelerator = "gpu" + strategy = "auto" + else: + accelerator = "gpu" + strategy = DDPStrategy( + start_method="spawn", + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + timeout=timedelta(seconds=ddp_timeout_s), + ) + timer = EpochTimer() + trainer = pl.Trainer( + accelerator=accelerator, + devices=devices, + strategy=strategy, + precision=precision, + max_epochs=epochs, logger=False, - enable_progress_bar=False, enable_checkpointing=False, num_sanity_val_steps=0, - precision=precision, - limit_val_batches=0, - limit_train_batches=1.0, - callbacks=[timing_cb], + enable_progress_bar=False, + 
log_every_n_steps=50, + callbacks=[timer], + inference_mode=False, ) + return trainer, timer - # Warmup 1 epoch for stable timings (optional) - warm_model = ContextualizedRegression(**model_kwargs) - trainer.fit(warm_model, train_dataloaders=dm.train_dataloader()) - # Timed run +# ---------------- benchmark runner ---------------- +@dataclass +class BenchCfg: + label: str + devices: int # 0=cpu, >=1 gpus + + +def run_once(cfg: BenchCfg, C, X, Y, args) -> Dict: + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # datamodule + dm = build_dm( + C, X, Y, + train_batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=(cfg.devices >= 1), + ) + # model + model = build_model(args.context_dim, args.x_dim, args.y_dim, + args.width, args.layers, args.lr) + + # ---- warm-up on the SAME accelerator config ---- + tiny = max(1024, math.ceil(0.01 * C.shape[0])) + dm_warm = build_dm( + C[:tiny], X[:tiny], Y[:tiny], + train_batch_size=args.batch_size, + num_workers=0, + pin_memory=(cfg.devices >= 1), + ) + warm_trainer, _ = build_trainer( + devices=cfg.devices, # cpu: 0, 1-gpu: 1, multi: k + precision=map_precision(args.precision), + epochs=1, + ddp_timeout_s=args.ddp_timeout, + ) + warm_trainer.fit(model, train_dataloaders=dm_warm.train_dataloader()) + + # ---- main timed run ---- + trainer, timer = build_trainer( + devices=cfg.devices, + precision=map_precision(args.precision), + epochs=args.epochs, + ddp_timeout_s=args.ddp_timeout, + ) + if torch.cuda.is_available(): torch.cuda.synchronize() + t0 = time.time() trainer.fit(model, train_dataloaders=dm.train_dataloader()) if torch.cuda.is_available(): torch.cuda.synchronize() + wall = time.time() - t0 + + # metrics (use actual train size, not full N) + train_samples = len(dm.train_dataloader().dataset) + samples_total = train_samples * args.epochs + throughput = samples_total / max(wall, 1e-9) + per_device = (throughput / max(cfg.devices, 1)) if cfg.devices >= 1 else throughput + epoch_times = 
timer.epoch_times[:] # seconds per epoch + + res = dict( + label=cfg.label, + devices=cfg.devices, + wall_seconds=wall, + samples_total=int(samples_total), + throughput_samples_per_s=throughput, + per_device_throughput=per_device, + steps_per_epoch=math.ceil(train_samples / args.batch_size), + samples_per_epoch=int(train_samples), + epoch_times=epoch_times, + ) + if int(os.environ.get("RANK", "0")) == 0: + print(json.dumps({ + "label": res["label"], + "devices": res["devices"], + "wall_s": round(res["wall_seconds"], 3), + "throughput_sps": round(res["throughput_samples_per_s"], 2), + "per_device_sps": round(res["per_device_throughput"], 2), + "avg_epoch_s": round(float(np.mean(res["epoch_times"])) if res["epoch_times"] else float("nan"), 3) + }, indent=2)) + return res + + +def save_csv(rows: List[Dict], outdir: str): + os.makedirs(outdir, exist_ok=True) + path = os.path.join(outdir, "scale_results.csv") + fields = ["label","devices","wall_seconds","samples_total", + "throughput_samples_per_s","per_device_throughput", + "steps_per_epoch","samples_per_epoch","epoch_times"] + with open(path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=fields) + w.writeheader() + for r in rows: + r2 = r.copy() + r2["epoch_times"] = ";".join(f"{x:.6f}" for x in r["epoch_times"]) + w.writerow(r2) + return path + + +def plot_curves(rows: List[Dict], outdir: str): + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import numpy as np + + os.makedirs(outdir, exist_ok=True) + labels = [r["label"] for r in rows] + devs = [r["devices"] for r in rows] + thr = [r["throughput_samples_per_s"] for r in rows] + wall = [r["wall_seconds"] for r in rows] + avg_epoch = [np.mean(r["epoch_times"]) if r["epoch_times"] else float("nan") for r in rows] + + # Throughput + plt.figure() + plt.plot(devs, thr, marker="o") + plt.xticks(devs, labels, rotation=30, ha="right") + plt.xlabel("Configuration") + plt.ylabel("Throughput (samples/s)") + plt.title("Throughput vs 
Devices") + plt.tight_layout() + plt.savefig(os.path.join(outdir, "throughput_vs_devices.png")) + plt.close() - # metrics - # Compute seen samples per epoch (drop_last=True) - steps_per_epoch = math.ceil(len(C) / args.batch_size) - seen_per_epoch = steps_per_epoch * args.batch_size - total_seen = seen_per_epoch * args.epochs - wall = timing_cb.wall_time - throughput = total_seen / wall if wall and wall > 0 else float("nan") - world_size = trainer.num_devices if hasattr(trainer, "num_devices") else 1 - - return { - "label": cfg["label"], - "world_size": world_size, - "batch_size": args.batch_size, - "num_workers": args.num_workers, - "precision": args.precision, - "epochs": args.epochs, - "steps_per_epoch": steps_per_epoch, - "samples_per_epoch": seen_per_epoch, - "wall_time_s": wall, - "throughput_samples_per_s": throughput, - "per_gpu_throughput": throughput / max(1, world_size), - "epoch_times_s": timing_cb.epoch_times, - } - -def plot_one(x, y, xlabel, ylabel, outpng): + # Wall time plt.figure() - plt.plot(x, y, marker="o") - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.grid(True) + plt.plot(devs, wall, marker="o") + plt.xticks(devs, labels, rotation=30, ha="right") + plt.xlabel("Configuration") + plt.ylabel("Total Wall Time (s)") + plt.title("Wall Time vs Devices") plt.tight_layout() - plt.savefig(outpng, dpi=150) + plt.savefig(os.path.join(outdir, "walltime_vs_devices.png")) plt.close() -def main(): - env_defaults() + # Avg epoch time + plt.figure() + plt.plot(devs, avg_epoch, marker="o") + plt.xticks(devs, labels, rotation=30, ha="right") + plt.xlabel("Configuration") + plt.ylabel("Avg Train Epoch Time (s)") + plt.title("Epoch Time vs Devices") + plt.tight_layout() + plt.savefig(os.path.join(outdir, "epoch_time_vs_devices.png")) + plt.close() + +def is_global_zero() -> bool: + return int(os.environ.get("RANK", "0")) == 0 + + +# ---------------- main ---------------- +def parse_args(): ap = argparse.ArgumentParser() - # data/model - 
ap.add_argument("--n-samples", type=int, default=300_000) - ap.add_argument("--context-dim", type=int, default=32) - ap.add_argument("--x-dim", type=int, default=256) + ap.add_argument("--epochs", type=int, default=5) + ap.add_argument("--batch-size", type=int, default=2048) + ap.add_argument("--num-workers", type=int, default=8) + ap.add_argument("--precision", type=str, default="bf16") + ap.add_argument("--n", type=int, default=2_000_000) + ap.add_argument("--context-dim", type=int, default=16) + ap.add_argument("--x-dim", type=int, default=512) ap.add_argument("--y-dim", type=int, default=64) - ap.add_argument("--encoder-type", type=str, default="mlp", choices=["mlp","ngam","linear"]) ap.add_argument("--width", type=int, default=1024) ap.add_argument("--layers", type=int, default=4) - ap.add_argument("--num-archetypes", type=int, default=10) ap.add_argument("--lr", type=float, default=1e-3) - ap.add_argument("--epochs", type=int, default=3) - # dataloader - ap.add_argument("--batch-size", type=int, default=2048) - ap.add_argument("--num-workers", type=int, default=8) - ap.add_argument("--persistent-workers", action="store_true", default=True) - # precision - ap.add_argument("--precision", type=str, default="bf16", choices=["32","16","bf16","64"]) - # output ap.add_argument("--outdir", type=str, default="bench_out") - args = ap.parse_args() + ap.add_argument("--ddp-timeout", type=int, default=120) + ap.add_argument("--max-gpus", type=int, default=4) + return ap.parse_args() + - outdir = Path(args.outdir) - outdir.mkdir(parents=True, exist_ok=True) +def main(): + set_env_defaults() + args = parse_args() + os.makedirs(args.outdir, exist_ok=True) + + if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True # optional micro-optim for fixed shapes - # synth data - data = make_synth(args.n_samples, args.context_dim, args.x_dim, args.y_dim) + # data once + C, X, Y = make_synthetic(args.n, args.context_dim, args.x_dim, args.y_dim) - # decide available gpu 
configs - n_gpus = torch.cuda.device_count() - configs = [{"label":"cpu", "accelerator":"cpu", "devices=ignored":1, "devices":1, "strategy":"auto"}] - for k in [1,2,3,4]: - if n_gpus >= k: - configs.append({"label":f"gpu-{k}", "accelerator":None, "devices":k, "strategy":"ddp"}) + # configs: CPU + 1..available GPUs (cap at --max-gpus) + gpus = torch.cuda.device_count() + dev_list = [BenchCfg("cpu", 0)] + for k in range(1, min(args.max_gpus, gpus) + 1): + dev_list.append(BenchCfg(f"gpu-{k}", k)) results = [] - for cfg in configs: - print(f"\n=== Running {cfg['label']} ===") - res = run_one(cfg, data, args) - for k,v in res.items(): - if k != "epoch_times_s": - print(f"{k}: {v}") + for cfg in dev_list: + if is_global_zero(): + print(f"\n=== Running {cfg.label} ===") + res = run_once(cfg, C, X, Y, args) results.append(res) - # write CSV - csv_path = outdir / "scale_results.csv" - with open(csv_path, "w", newline="") as f: - w = csv.writer(f) - w.writerow([ - "label","world_size","batch_size","num_workers","precision","epochs", - "steps_per_epoch","samples_per_epoch","wall_time_s","throughput_samples_per_s","per_gpu_throughput","epoch_times_s" - ]) - for r in results: - w.writerow([ - r["label"], r["world_size"], r["batch_size"], r["num_workers"], r["precision"], r["epochs"], - r["steps_per_epoch"], r["samples_per_epoch"], f"{r['wall_time_s']:.6f}", - f"{r['throughput_samples_per_s']:.3f}", f"{r['per_gpu_throughput']:.3f}", - ";".join(f"{et:.6f}" for et in r["epoch_times_s"]) - ]) - print(f"\n[Saved] {csv_path}") - - # plots - xs = [ (0 if r["label"]=="cpu" else r["world_size"]) for r in results ] - labels = [r["label"] for r in results] - - # Throughput vs devices (GPU counts; CPU plotted at 0) - plot_one(xs, [r["throughput_samples_per_s"] for r in results], - "Devices (0=CPU)", "Throughput (samples/s)", str(outdir / "throughput_vs_devices.png")) - # Wall time vs devices - plot_one(xs, [r["wall_time_s"] for r in results], - "Devices (0=CPU)", "Wall time (s)", 
str(outdir / "walltime_vs_devices.png")) - # Mean epoch time vs devices - plot_one(xs, [float(np.mean(r["epoch_times_s"])) for r in results], - "Devices (0=CPU)", "Mean epoch time (s)", str(outdir / "epoch_time_vs_devices.png")) - - print(f"[Saved] {outdir/'throughput_vs_devices.png'}") - print(f"[Saved] {outdir/'walltime_vs_devices.png'}") - print(f"[Saved] {outdir/'epoch_time_vs_devices.png'}") + if is_global_zero(): + csv_path = save_csv(results, args.outdir) + plot_curves(results, args.outdir) + print(f"\nSaved CSV → {csv_path}") + print(f"Saved plots → {args.outdir}/throughput_vs_devices.png, " + f"walltime_vs_devices.png, epoch_time_vs_devices.png") + if __name__ == "__main__": main() From fc35b6885b94ac4dd529bc884348587e0544fabb Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Tue, 11 Nov 2025 19:45:02 -0500 Subject: [PATCH 08/19] update to scale benchmarking --- scale_bench.py | 138 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 93 insertions(+), 45 deletions(-) diff --git a/scale_bench.py b/scale_bench.py index b2aa337f..7f0a86c9 100644 --- a/scale_bench.py +++ b/scale_bench.py @@ -15,21 +15,35 @@ from contextualized.regression.datamodules import ContextualizedRegressionDataModule -# ---------------- utils ---------------- +# ---------------- launcher/cluster helpers ---------------- +def under_torchrun() -> bool: + e = os.environ + return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) + +def world_size() -> int: + try: + return int(os.environ.get("WORLD_SIZE", "1")) + except Exception: + return 1 + +def is_global_zero() -> bool: + # With torchrun Lightning sets ranks; outside it, fall back to env. 
+ return int(os.environ.get("RANK", "0")) == 0 + + +# ---------------- env + perf ---------------- def set_env_defaults(): - # Light, deterministic-friendly defaults os.environ.setdefault("OMP_NUM_THREADS", "1") os.environ.setdefault("MKL_NUM_THREADS", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - # Prefer new PyTorch var (2.4+); avoid deprecated NCCL_ASYNC_ERROR_HANDLING + # Safe NCCL defaults for cloud os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") os.environ.setdefault("NCCL_DEBUG", "WARN") os.environ.setdefault("NCCL_P2P_DISABLE", "0") - # Most cloud nodes lack IB; default it off for reliability - os.environ.setdefault("NCCL_IB_DISABLE", "1") + os.environ.setdefault("NCCL_IB_DISABLE", "1") # IB typically unavailable on single Lambda node - # If user didn't set NCCL_SOCKET_IFNAME, auto-pick a sane one + # Pick an interface if not set if "NCCL_SOCKET_IFNAME" not in os.environ: try: ifaces = [d for d in os.listdir("/sys/class/net") if os.path.isdir(f"/sys/class/net/{d}")] @@ -38,14 +52,20 @@ def set_env_defaults(): except Exception: os.environ["NCCL_SOCKET_IFNAME"] = "lo" - # Unique rendezvous per run + # Rendezvous only matters for non-torchrun spawn os.environ.setdefault("MASTER_ADDR", "127.0.0.1") os.environ.setdefault("MASTER_PORT", str(12355 + (os.getpid() % 20000))) - if int(os.environ.get("RANK", "0")) == 0: + if is_global_zero(): keys = ["NCCL_DEBUG","NCCL_IB_DISABLE","NCCL_P2P_DISABLE","NCCL_SOCKET_IFNAME","MASTER_ADDR","MASTER_PORT"] print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) + # Ampere+ matmul speedups + try: + torch.set_float32_matmul_precision("high") + except Exception: + pass + def map_precision(p): p = (p or "").lower() @@ -143,24 +163,44 @@ def build_dm( return dm -def build_trainer(devices, precision, epochs, ddp_timeout_s=120): +def build_trainer(devices, precision, epochs, ddp_timeout_s=120, torchrun_mode=False): + """ + devices: + - 0 => cpu + - 1 => single gpu + - >1 => multi-gpu (only when 
NOT torchrun) + torchrun_mode: + - True => use DDP (1 GPU per process), devices=1 + """ + timer = EpochTimer() + if devices == 0: accelerator = "cpu" devices = 1 strategy = "auto" - elif devices == 1: - accelerator = "gpu" - strategy = "auto" else: accelerator = "gpu" - strategy = DDPStrategy( - start_method="spawn", - find_unused_parameters=False, - gradient_as_bucket_view=True, - static_graph=True, - timeout=timedelta(seconds=ddp_timeout_s), - ) - timer = EpochTimer() + if torchrun_mode and world_size() > 1: + # true DDP under torchrun + strategy = DDPStrategy( + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + timeout=timedelta(seconds=ddp_timeout_s), + ) + devices = 1 # 1 GPU per process + elif devices == 1: + strategy = "auto" + else: + # multi-gpu without torchrun: fall back to spawn + strategy = DDPStrategy( + start_method="spawn", + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + timeout=timedelta(seconds=ddp_timeout_s), + ) + trainer = pl.Trainer( accelerator=accelerator, devices=devices, @@ -174,6 +214,7 @@ def build_trainer(devices, precision, epochs, ddp_timeout_s=120): log_every_n_steps=50, callbacks=[timer], inference_mode=False, + detect_anomaly=False, ) return trainer, timer @@ -185,7 +226,7 @@ class BenchCfg: devices: int # 0=cpu, >=1 gpus -def run_once(cfg: BenchCfg, C, X, Y, args) -> Dict: +def run_once(cfg: BenchCfg, C, X, Y, args, torchrun_mode: bool) -> Dict: if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -209,19 +250,21 @@ def run_once(cfg: BenchCfg, C, X, Y, args) -> Dict: pin_memory=(cfg.devices >= 1), ) warm_trainer, _ = build_trainer( - devices=cfg.devices, # cpu: 0, 1-gpu: 1, multi: k + devices=cfg.devices if not torchrun_mode else 1, precision=map_precision(args.precision), epochs=1, ddp_timeout_s=args.ddp_timeout, + torchrun_mode=torchrun_mode, ) warm_trainer.fit(model, train_dataloaders=dm_warm.train_dataloader()) # ---- main timed run ---- trainer, 
timer = build_trainer( - devices=cfg.devices, + devices=cfg.devices if not torchrun_mode else 1, precision=map_precision(args.precision), epochs=args.epochs, ddp_timeout_s=args.ddp_timeout, + torchrun_mode=torchrun_mode, ) if torch.cuda.is_available(): @@ -236,12 +279,12 @@ def run_once(cfg: BenchCfg, C, X, Y, args) -> Dict: train_samples = len(dm.train_dataloader().dataset) samples_total = train_samples * args.epochs throughput = samples_total / max(wall, 1e-9) - per_device = (throughput / max(cfg.devices, 1)) if cfg.devices >= 1 else throughput + per_device = (throughput / (world_size() if torchrun_mode and cfg.devices >= 1 else max(cfg.devices, 1))) epoch_times = timer.epoch_times[:] # seconds per epoch res = dict( label=cfg.label, - devices=cfg.devices, + devices=(world_size() if torchrun_mode else cfg.devices), wall_seconds=wall, samples_total=int(samples_total), throughput_samples_per_s=throughput, @@ -250,7 +293,7 @@ def run_once(cfg: BenchCfg, C, X, Y, args) -> Dict: samples_per_epoch=int(train_samples), epoch_times=epoch_times, ) - if int(os.environ.get("RANK", "0")) == 0: + if is_global_zero(): print(json.dumps({ "label": res["label"], "devices": res["devices"], @@ -295,7 +338,7 @@ def plot_curves(rows: List[Dict], outdir: str): plt.figure() plt.plot(devs, thr, marker="o") plt.xticks(devs, labels, rotation=30, ha="right") - plt.xlabel("Configuration") + plt.xlabel("Devices") plt.ylabel("Throughput (samples/s)") plt.title("Throughput vs Devices") plt.tight_layout() @@ -306,7 +349,7 @@ def plot_curves(rows: List[Dict], outdir: str): plt.figure() plt.plot(devs, wall, marker="o") plt.xticks(devs, labels, rotation=30, ha="right") - plt.xlabel("Configuration") + plt.xlabel("Devices") plt.ylabel("Total Wall Time (s)") plt.title("Wall Time vs Devices") plt.tight_layout() @@ -317,7 +360,7 @@ def plot_curves(rows: List[Dict], outdir: str): plt.figure() plt.plot(devs, avg_epoch, marker="o") plt.xticks(devs, labels, rotation=30, ha="right") - 
plt.xlabel("Configuration") + plt.xlabel("Devices") plt.ylabel("Avg Train Epoch Time (s)") plt.title("Epoch Time vs Devices") plt.tight_layout() @@ -325,15 +368,11 @@ def plot_curves(rows: List[Dict], outdir: str): plt.close() -def is_global_zero() -> bool: - return int(os.environ.get("RANK", "0")) == 0 - - # ---------------- main ---------------- def parse_args(): ap = argparse.ArgumentParser() ap.add_argument("--epochs", type=int, default=5) - ap.add_argument("--batch-size", type=int, default=2048) + ap.add_argument("--batch-size", type=int, default=2048) # PER GPU ap.add_argument("--num-workers", type=int, default=8) ap.add_argument("--precision", type=str, default="bf16") ap.add_argument("--n", type=int, default=2_000_000) @@ -344,7 +383,7 @@ def parse_args(): ap.add_argument("--layers", type=int, default=4) ap.add_argument("--lr", type=float, default=1e-3) ap.add_argument("--outdir", type=str, default="bench_out") - ap.add_argument("--ddp-timeout", type=int, default=120) + ap.add_argument("--ddp-timeout", type=int, default=180) ap.add_argument("--max-gpus", type=int, default=4) return ap.parse_args() @@ -355,23 +394,32 @@ def main(): os.makedirs(args.outdir, exist_ok=True) if torch.cuda.is_available(): - torch.backends.cudnn.benchmark = True # optional micro-optim for fixed shapes + torch.backends.cudnn.benchmark = True # opt for fixed shapes # data once C, X, Y = make_synthetic(args.n, args.context_dim, args.x_dim, args.y_dim) - # configs: CPU + 1..available GPUs (cap at --max-gpus) - gpus = torch.cuda.device_count() - dev_list = [BenchCfg("cpu", 0)] - for k in range(1, min(args.max_gpus, gpus) + 1): - dev_list.append(BenchCfg(f"gpu-{k}", k)) - results = [] - for cfg in dev_list: + torchrun_mode = under_torchrun() + + if torchrun_mode: + # Run exactly one multi-GPU config under torchrun: devices = WORLD_SIZE + cfg = BenchCfg(label=f"gpu-{world_size()}", devices=1) # 1 per proc; true DDP if is_global_zero(): - print(f"\n=== Running {cfg.label} ===") - res = 
run_once(cfg, C, X, Y, args) + print(f"\n=== Running {cfg.label} (torchrun, {world_size()} processes) ===") + res = run_once(cfg, C, X, Y, args, torchrun_mode=True) results.append(res) + else: + # Standalone: run cpu + 1..k GPUs (ddp-spawn for k>1) + gpus = torch.cuda.device_count() + dev_list = [BenchCfg("cpu", 0)] + for k in range(1, min(args.max_gpus, gpus) + 1): + dev_list.append(BenchCfg(f"gpu-{k}", k)) + for cfg in dev_list: + if is_global_zero(): + print(f"\n=== Running {cfg.label} ===") + res = run_once(cfg, C, X, Y, args, torchrun_mode=False) + results.append(res) if is_global_zero(): csv_path = save_csv(results, args.outdir) From 3fcb387e0f9f229e42a6d95dc42c885eba41a508 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Wed, 12 Nov 2025 11:50:25 -0500 Subject: [PATCH 09/19] cpu benchmark file added --- cpu_scale_bench.py | 448 +++++++++++++++++++++++++++++++++++++++++++++ scale_bench.py | 131 ++++++++----- 2 files changed, 537 insertions(+), 42 deletions(-) create mode 100644 cpu_scale_bench.py diff --git a/cpu_scale_bench.py b/cpu_scale_bench.py new file mode 100644 index 00000000..5d835f82 --- /dev/null +++ b/cpu_scale_bench.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +import os, time, csv, argparse, math, json +from dataclasses import dataclass +from typing import List, Dict +from datetime import timedelta + +import numpy as np +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.strategies import DDPStrategy + +# ---- your package pieces ---- +from contextualized.regression import ContextualizedRegression +from contextualized.regression.datamodules import ContextualizedRegressionDataModule + + +# ---------------- launcher/cluster helpers ---------------- +def under_torchrun() -> bool: + e = os.environ + return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) + +def world_size() -> int: + try: + return int(os.environ.get("WORLD_SIZE", "1")) + except Exception: + return 1 + +def 
is_global_zero() -> bool: + return int(os.environ.get("RANK", "0")) == 0 + + +# ---------------- env + perf ---------------- +def set_env_defaults(): + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + + # Safer NCCL defaults on cloud single node + os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") + os.environ.setdefault("NCCL_DEBUG", "WARN") + os.environ.setdefault("NCCL_P2P_DISABLE", "0") + os.environ.setdefault("NCCL_IB_DISABLE", "1") # IB usually unavailable on single-node Lambda + + # Pick an interface if not set + if "NCCL_SOCKET_IFNAME" not in os.environ: + try: + ifaces = [d for d in os.listdir("/sys/class/net") if os.path.isdir(f"/sys/class/net/{d}")] + cand = next((i for i in ifaces if i not in ("lo", "docker0")), None) + os.environ["NCCL_SOCKET_IFNAME"] = cand or "lo" + except Exception: + os.environ["NCCL_SOCKET_IFNAME"] = "lo" + + # Rendezvous (used only by ddp_spawn mode) + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", str(12355 + (os.getpid() % 20000))) + + if is_global_zero(): + keys = ["NCCL_DEBUG","NCCL_IB_DISABLE","NCCL_P2P_DISABLE","NCCL_SOCKET_IFNAME","MASTER_ADDR","MASTER_PORT"] + print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) + + # Ampere+ matmul speedups + try: + torch.set_float32_matmul_precision("high") + except Exception: + pass + + +def map_precision(p): + p = (p or "").lower() + if p in ("bf16", "bfloat16", "bf16-mixed"): + return "bf16-mixed" + if p in ("fp16", "16", "16-mixed"): + return "16-mixed" + return 32 # full precision + + +class EpochTimer(Callback): + def __init__(self): + self._epoch_start = None + self.epoch_times = [] + + @staticmethod + def _using_cuda(trainer) -> bool: + try: + return trainer.accelerator is not None and "cuda" in str(trainer.accelerator).lower() + except Exception: + return torch.cuda.is_available() + + def 
on_train_epoch_start(self, trainer, pl_module): + if self._using_cuda(trainer): + torch.cuda.synchronize() + self._epoch_start = time.time() + + def on_train_epoch_end(self, trainer, pl_module): + if self._using_cuda(trainer): + torch.cuda.synchronize() + self.epoch_times.append(time.time() - self._epoch_start) + + +# ---------------- synthetic data ---------------- +def make_synthetic(n, c_dim, x_dim, y_dim, seed=42): + rng = np.random.default_rng(seed) + C = rng.standard_normal((n, c_dim)).astype(np.float32) + X = rng.standard_normal((n, x_dim)).astype(np.float32) + W = rng.standard_normal((y_dim, x_dim)).astype(np.float32) + MU = rng.standard_normal((y_dim, 1)).astype(np.float32) + Y = (X @ W.T) + MU.squeeze(-1) + 0.01 * rng.standard_normal((n, y_dim)).astype(np.float32) + return C, X, Y + + +def load_or_make_dataset(path, n, c_dim, x_dim, y_dim, seed=42): + if path and os.path.exists(path): + npz = np.load(path) + C, X, Y = npz["C"], npz["X"], npz["Y"] + return C, X, Y + C, X, Y = make_synthetic(n, c_dim, x_dim, y_dim, seed=seed) + if path: + os.makedirs(os.path.dirname(path), exist_ok=True) + np.savez_compressed(path, C=C, X=X, Y=Y) + return C, X, Y + + +# ---------------- model/trainer builders ---------------- +def build_model(c_dim, x_dim, y_dim, width, layers, lr): + model = ContextualizedRegression( + context_dim=c_dim, + x_dim=x_dim, + y_dim=y_dim, + num_archetypes=8, + encoder_type="mlp", + encoder_kwargs={"width": width, "layers": layers, "link_fn": "identity"}, + learning_rate=lr, + fit_intercept=True, + link_fn="identity", + loss_fn="mse", + model_regularizer="none", + ) + return model + + +def build_dm( + C, X, Y, + train_batch_size: int, + num_workers: int, + pin_memory: bool, +): + n = C.shape[0] + perm = np.random.permutation(n) + n_train = int(0.9 * n) + train_idx = perm[:n_train] + val_idx = perm[n_train:] + + dm = ContextualizedRegressionDataModule( + C=C, X=X, Y=Y, + task_type="singletask_multivariate", + train_idx=train_idx, + 
val_idx=val_idx, + test_idx=None, + predict_idx=None, + train_batch_size=train_batch_size, + val_batch_size=train_batch_size, + test_batch_size=train_batch_size, + predict_batch_size=train_batch_size, + num_workers=num_workers, + pin_memory=bool(pin_memory), + persistent_workers=bool(num_workers > 0), + drop_last=True, + shuffle_train=True, + shuffle_eval=False, + dtype=torch.float, + ) + dm.prepare_data(); dm.setup() + return dm + + +def build_trainer(devices, precision, epochs, ddp_timeout_s=120, torchrun_mode=False): + """ + devices: + - 0 => cpu + - >=1 => number of devices this process should report to Lightning + + torchrun_mode: + - True => launched via torchrun; use DDP with devices = WORLD_SIZE, + no spawn. Satisfies Lightning's validation. + """ + timer = EpochTimer() + + if devices == 0: + accelerator = "cpu" + devices_arg = 1 + strategy = "auto" + else: + accelerator = "gpu" + if torchrun_mode: + ws = world_size() + devices_arg = ws # must equal WORLD_SIZE + strategy = DDPStrategy( + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + timeout=timedelta(seconds=ddp_timeout_s), + ) + else: + devices_arg = devices + strategy = "auto" if devices == 1 else DDPStrategy( + start_method="spawn", + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + timeout=timedelta(seconds=ddp_timeout_s), + ) + + trainer = pl.Trainer( + accelerator=accelerator, + devices=devices_arg, + strategy=strategy, + precision=precision, + max_epochs=epochs, + logger=False, + enable_checkpointing=False, + num_sanity_val_steps=0, + enable_progress_bar=False, + log_every_n_steps=50, + callbacks=[timer], + inference_mode=False, + detect_anomaly=False, + ) + return trainer, timer + + +# ---------------- benchmark runner ---------------- +@dataclass +class BenchCfg: + label: str + devices: int # 0 for cpu, >=1 for gpus + + +def run_once(cfg: BenchCfg, C, X, Y, args, torchrun_mode: bool) -> Dict: + if torch.cuda.is_available(): 
+ torch.cuda.empty_cache() + + pin = (cfg.devices >= 1) + dm = build_dm( + C, X, Y, + train_batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=pin, + ) + model = build_model(args.context_dim, args.x_dim, args.y_dim, + args.width, args.layers, args.lr) + + # Warm-up + tiny = max(1024, math.ceil(0.01 * C.shape[0])) + dm_warm = build_dm( + C[:tiny], X[:tiny], Y[:tiny], + train_batch_size=args.batch_size, + num_workers=0, + pin_memory=pin, + ) + warm_trainer, _ = build_trainer( + devices=(world_size() if torchrun_mode else cfg.devices), + precision=map_precision(args.precision), + epochs=1, + ddp_timeout_s=args.ddp_timeout, + torchrun_mode=torchrun_mode, + ) + warm_trainer.fit(model, train_dataloaders=dm_warm.train_dataloader()) + + # Timed run + trainer, timer = build_trainer( + devices=(world_size() if torchrun_mode else cfg.devices), + precision=map_precision(args.precision), + epochs=args.epochs, + ddp_timeout_s=args.ddp_timeout, + torchrun_mode=torchrun_mode, + ) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + t0 = time.time() + trainer.fit(model, train_dataloaders=dm.train_dataloader()) + if torch.cuda.is_available(): + torch.cuda.synchronize() + wall = time.time() - t0 + + train_samples = len(dm.train_dataloader().dataset) + samples_total = train_samples * args.epochs + throughput = samples_total / max(wall, 1e-9) + + # devices_for_metric: report 1 for CPU so it's easy to compare "per-device" + world = (world_size() if torchrun_mode else (cfg.devices if cfg.devices > 0 else 1)) + per_device = throughput / max(world, 1) + + res = dict( + label=cfg.label, + devices=(world_size() if torchrun_mode else (cfg.devices if cfg.devices > 0 else 1)), + wall_seconds=wall, + samples_total=int(samples_total), + throughput_samples_per_s=throughput, + per_device_throughput=per_device, + steps_per_epoch=math.ceil(train_samples / args.batch_size), + samples_per_epoch=int(train_samples), + epoch_times=timer.epoch_times[:], + ) + if 
is_global_zero(): + print(json.dumps({ + "label": res["label"], + "devices": res["devices"], + "wall_s": round(res["wall_seconds"], 3), + "throughput_sps": round(res["throughput_samples_per_s"], 2), + "per_device_sps": round(res["per_device_throughput"], 2), + "avg_epoch_s": round(float(np.mean(res["epoch_times"])) if res["epoch_times"] else float("nan"), 3) + }, indent=2)) + return res + + +def save_csv(rows: List[Dict], outdir: str): + os.makedirs(outdir, exist_ok=True) + path = os.path.join(outdir, "scale_results.csv") + fields = ["label","devices","wall_seconds","samples_total", + "throughput_samples_per_s","per_device_throughput", + "steps_per_epoch","samples_per_epoch","epoch_times"] + with open(path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=fields) + w.writeheader() + for r in rows: + r2 = r.copy() + r2["epoch_times"] = ";".join(f"{x:.6f}" for x in r["epoch_times"]) + w.writerow(r2) + return path + + +def plot_curves(rows: List[Dict], outdir: str): + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + import numpy as np + + os.makedirs(outdir, exist_ok=True) + labels = [r["label"] for r in rows] + devs = [r["devices"] for r in rows] + thr = [r["throughput_samples_per_s"] for r in rows] + wall = [r["wall_seconds"] for r in rows] + avg_epoch = [np.mean(r["epoch_times"]) if r["epoch_times"] else float("nan") for r in rows] + + plt.figure() + plt.plot(devs, thr, marker="o") + plt.xticks(devs, labels, rotation=30, ha="right") + plt.xlabel("Devices") + plt.ylabel("Throughput (samples/s)") + plt.title("Throughput vs Devices") + plt.tight_layout() + plt.savefig(os.path.join(outdir, "throughput_vs_devices.png")) + plt.close() + + plt.figure() + plt.plot(devs, wall, marker="o") + plt.xticks(devs, labels, rotation=30, ha="right") + plt.xlabel("Devices") + plt.ylabel("Total Wall Time (s)") + plt.title("Wall Time vs Devices") + plt.tight_layout() + plt.savefig(os.path.join(outdir, "walltime_vs_devices.png")) + plt.close() + + 
plt.figure() + plt.plot(devs, avg_epoch, marker="o") + plt.xticks(devs, labels, rotation=30, ha="right") + plt.xlabel("Devices") + plt.ylabel("Avg Train Epoch Time (s)") + plt.title("Epoch Time vs Devices") + plt.tight_layout() + plt.savefig(os.path.join(outdir, "epoch_time_vs_devices.png")) + plt.close() + + +# ---------------- main ---------------- +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--epochs", type=int, default=5) + ap.add_argument("--batch-size", type=int, default=2048) # PER GPU/CPU + ap.add_argument("--num-workers", type=int, default=8) + ap.add_argument("--precision", type=str, default="bf16") + ap.add_argument("--dataset-cache", type=str, default="bench_out/datasets/n{n}_seed42.npz", + help="Path to .npz to cache dataset. '{n}' will be replaced with num_samples.") + + # Accept BOTH forms; same dest + ap.add_argument("--num-samples", dest="num_samples", type=int, default=2_000_000) + ap.add_argument("--n", dest="num_samples", type=int) + + ap.add_argument("--context-dim", type=int, default=16) + ap.add_argument("--x-dim", type=int, default=512) + ap.add_argument("--y-dim", type=int, default=64) + ap.add_argument("--width", type=int, default=1024) + ap.add_argument("--layers", type=int, default=4) + ap.add_argument("--lr", type=float, default=1e-3) + ap.add_argument("--outdir", type=str, default="bench_out") + ap.add_argument("--ddp-timeout", type=int, default=180) + ap.add_argument("--max-gpus", type=int, default=4) + return ap.parse_args() + + +def main(): + set_env_defaults() + args = parse_args() + os.makedirs(args.outdir, exist_ok=True) + + if torch.cuda.is_available(): + torch.backends.cudnn.benchmark = True + os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + + # ensure dataset cache path is concrete + ds_path = args.dataset_cache.format(n=args.num_samples) if args.dataset_cache else None + C, X, Y = load_or_make_dataset(ds_path, args.num_samples, args.context_dim, args.x_dim, args.y_dim, 
seed=42) + + results = [] + torchrun_mode = under_torchrun() + + if torchrun_mode: + # Run exactly one config under torchrun (WORLD_SIZE GPUs, 1 per process) + cfg = BenchCfg(label=f"gpu-{world_size()}", devices=1) + if is_global_zero(): + print(f"\n=== Running {cfg.label} (torchrun, {world_size()} processes) ===") + res = run_once(cfg, C, X, Y, args, torchrun_mode=True) + results.append(res) + else: + # Standalone: run CPU + 1..k GPUs + gpus = torch.cuda.device_count() + dev_list = [BenchCfg("cpu", 0)] + for k in range(1, min(args.max_gpus, gpus) + 1): + dev_list.append(BenchCfg(f"gpu-{k}", k)) + for cfg in dev_list: + if is_global_zero(): + print(f"\n=== Running {cfg.label} ===") + res = run_once(cfg, C, X, Y, args, torchrun_mode=False) + results.append(res) + + if is_global_zero(): + csv_path = save_csv(results, args.outdir) + plot_curves(results, args.outdir) + print(f"\nSaved CSV → {csv_path}") + print(f"Saved plots → {args.outdir}/throughput_vs_devices.png, " + f"walltime_vs_devices.png, epoch_time_vs_devices.png") + + +if __name__ == "__main__": + main() diff --git a/scale_bench.py b/scale_bench.py index 7f0a86c9..f105fe90 100644 --- a/scale_bench.py +++ b/scale_bench.py @@ -1,4 +1,51 @@ #!/usr/bin/env python3 + +""" +# 0) See what NICs you actually have (optional, for sanity): +ls -1 /sys/class/net +ip -o link show | awk -F': ' '{print NR-1": "$2}' + +# 1) Minimal, safe NCCL/torch env (no hard-coded eth0): +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export TOKENIZERS_PARALLELISM=false +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export NCCL_DEBUG=WARN +export NCCL_P2P_DISABLE=0 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=$(ls /sys/class/net | grep -E '^(ens|enp|eno|eth|bond|ib)' | head -n1) +# If that prints nothing on your machine, fall back to auto-exclude: +[ -z "$NCCL_SOCKET_IFNAME" ] && export NCCL_SOCKET_IFNAME="^lo,docker0" + +# CUDA allocator tweak (fine to keep) +export 
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# 2) Kill any stragglers (optional) +pkill -f scale_bench.py || true +pkill -f torchrun || true + +# 3a) Single-GPU run (torchrun, WORLD_SIZE=1) +torchrun --standalone --nproc_per_node=1 scale_bench.py \ + --epochs 3 --batch-size 2048 --num-workers 8 --precision bf16 \ + --num-samples 1800000 --outdir bench_out/gpu1 + +# 3b) Two GPUs +torchrun --standalone --nproc_per_node=2 scale_bench.py \ + --epochs 3 --batch-size 2048 --num-workers 8 --precision bf16 \ + --num-samples 1800000 --outdir bench_out/gpu2 + +# 3c) Three GPUs +torchrun --standalone --nproc_per_node=3 scale_bench.py \ + --epochs 3 --batch-size 2048 --num-workers 8 --precision bf16 \ + --num-samples 1800000 --outdir bench_out/gpu3 + +# 3d) Four GPUs +torchrun --standalone --nproc_per_node=4 scale_bench.py \ + --epochs 3 --batch-size 2048 --num-workers 8 --precision bf16 \ + --num-samples 1800000 --outdir bench_out/gpu4 + +""" import os, time, csv, argparse, math, json from dataclasses import dataclass from typing import List, Dict @@ -27,7 +74,6 @@ def world_size() -> int: return 1 def is_global_zero() -> bool: - # With torchrun Lightning sets ranks; outside it, fall back to env. 
return int(os.environ.get("RANK", "0")) == 0 @@ -37,11 +83,11 @@ def set_env_defaults(): os.environ.setdefault("MKL_NUM_THREADS", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - # Safe NCCL defaults for cloud + # Safer NCCL defaults on cloud single node os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") os.environ.setdefault("NCCL_DEBUG", "WARN") os.environ.setdefault("NCCL_P2P_DISABLE", "0") - os.environ.setdefault("NCCL_IB_DISABLE", "1") # IB typically unavailable on single Lambda node + os.environ.setdefault("NCCL_IB_DISABLE", "1") # IB usually unavailable on single-node Lambda # Pick an interface if not set if "NCCL_SOCKET_IFNAME" not in os.environ: @@ -52,7 +98,7 @@ def set_env_defaults(): except Exception: os.environ["NCCL_SOCKET_IFNAME"] = "lo" - # Rendezvous only matters for non-torchrun spawn + # Rendezvous (used only by ddp_spawn mode) os.environ.setdefault("MASTER_ADDR", "127.0.0.1") os.environ.setdefault("MASTER_PORT", str(12355 + (os.getpid() % 20000))) @@ -167,33 +213,32 @@ def build_trainer(devices, precision, epochs, ddp_timeout_s=120, torchrun_mode=F """ devices: - 0 => cpu - - 1 => single gpu - - >1 => multi-gpu (only when NOT torchrun) + - >=1 => number of devices this process should report to Lightning + torchrun_mode: - - True => use DDP (1 GPU per process), devices=1 + - True => launched via torchrun; use DDP with devices = WORLD_SIZE, + no spawn. Satisfies Lightning's validation. 
""" timer = EpochTimer() if devices == 0: accelerator = "cpu" - devices = 1 + devices_arg = 1 strategy = "auto" else: accelerator = "gpu" - if torchrun_mode and world_size() > 1: - # true DDP under torchrun + if torchrun_mode: + ws = world_size() + devices_arg = ws # <-- IMPORTANT: devices must equal WORLD_SIZE here strategy = DDPStrategy( find_unused_parameters=False, gradient_as_bucket_view=True, static_graph=True, timeout=timedelta(seconds=ddp_timeout_s), ) - devices = 1 # 1 GPU per process - elif devices == 1: - strategy = "auto" else: - # multi-gpu without torchrun: fall back to spawn - strategy = DDPStrategy( + devices_arg = devices + strategy = "auto" if devices == 1 else DDPStrategy( start_method="spawn", find_unused_parameters=False, gradient_as_bucket_view=True, @@ -203,7 +248,7 @@ def build_trainer(devices, precision, epochs, ddp_timeout_s=120, torchrun_mode=F trainer = pl.Trainer( accelerator=accelerator, - devices=devices, + devices=devices_arg, strategy=strategy, precision=precision, max_epochs=epochs, @@ -223,34 +268,32 @@ def build_trainer(devices, precision, epochs, ddp_timeout_s=120, torchrun_mode=F @dataclass class BenchCfg: label: str - devices: int # 0=cpu, >=1 gpus + devices: int # >=1 gpus def run_once(cfg: BenchCfg, C, X, Y, args, torchrun_mode: bool) -> Dict: if torch.cuda.is_available(): torch.cuda.empty_cache() - # datamodule dm = build_dm( C, X, Y, train_batch_size=args.batch_size, num_workers=args.num_workers, - pin_memory=(cfg.devices >= 1), + pin_memory=True, ) - # model model = build_model(args.context_dim, args.x_dim, args.y_dim, args.width, args.layers, args.lr) - # ---- warm-up on the SAME accelerator config ---- + # Warm-up (stabilize kernels/allocators) on same accelerator config tiny = max(1024, math.ceil(0.01 * C.shape[0])) dm_warm = build_dm( C[:tiny], X[:tiny], Y[:tiny], train_batch_size=args.batch_size, num_workers=0, - pin_memory=(cfg.devices >= 1), + pin_memory=True, ) warm_trainer, _ = build_trainer( - 
devices=cfg.devices if not torchrun_mode else 1, + devices=(world_size() if torchrun_mode else cfg.devices), # <-- fix precision=map_precision(args.precision), epochs=1, ddp_timeout_s=args.ddp_timeout, @@ -258,9 +301,9 @@ def run_once(cfg: BenchCfg, C, X, Y, args, torchrun_mode: bool) -> Dict: ) warm_trainer.fit(model, train_dataloaders=dm_warm.train_dataloader()) - # ---- main timed run ---- + # Timed run trainer, timer = build_trainer( - devices=cfg.devices if not torchrun_mode else 1, + devices=(world_size() if torchrun_mode else cfg.devices), # <-- fix precision=map_precision(args.precision), epochs=args.epochs, ddp_timeout_s=args.ddp_timeout, @@ -275,16 +318,18 @@ def run_once(cfg: BenchCfg, C, X, Y, args, torchrun_mode: bool) -> Dict: torch.cuda.synchronize() wall = time.time() - t0 - # metrics (use actual train size, not full N) train_samples = len(dm.train_dataloader().dataset) samples_total = train_samples * args.epochs throughput = samples_total / max(wall, 1e-9) - per_device = (throughput / (world_size() if torchrun_mode and cfg.devices >= 1 else max(cfg.devices, 1))) - epoch_times = timer.epoch_times[:] # seconds per epoch + + world = world_size() if torchrun_mode else cfg.devices + per_device = throughput / max(world, 1) + + epoch_times = timer.epoch_times[:] res = dict( label=cfg.label, - devices=(world_size() if torchrun_mode else cfg.devices), + devices=world, wall_seconds=wall, samples_total=int(samples_total), throughput_samples_per_s=throughput, @@ -334,7 +379,6 @@ def plot_curves(rows: List[Dict], outdir: str): wall = [r["wall_seconds"] for r in rows] avg_epoch = [np.mean(r["epoch_times"]) if r["epoch_times"] else float("nan") for r in rows] - # Throughput plt.figure() plt.plot(devs, thr, marker="o") plt.xticks(devs, labels, rotation=30, ha="right") @@ -345,7 +389,6 @@ def plot_curves(rows: List[Dict], outdir: str): plt.savefig(os.path.join(outdir, "throughput_vs_devices.png")) plt.close() - # Wall time plt.figure() plt.plot(devs, wall, 
marker="o") plt.xticks(devs, labels, rotation=30, ha="right") @@ -356,7 +399,6 @@ def plot_curves(rows: List[Dict], outdir: str): plt.savefig(os.path.join(outdir, "walltime_vs_devices.png")) plt.close() - # Avg epoch time plt.figure() plt.plot(devs, avg_epoch, marker="o") plt.xticks(devs, labels, rotation=30, ha="right") @@ -375,7 +417,11 @@ def parse_args(): ap.add_argument("--batch-size", type=int, default=2048) # PER GPU ap.add_argument("--num-workers", type=int, default=8) ap.add_argument("--precision", type=str, default="bf16") - ap.add_argument("--n", type=int, default=2_000_000) + + # Accept BOTH forms; they write to the same dest + ap.add_argument("--num-samples", dest="num_samples", type=int, default=2_000_000) + ap.add_argument("--n", dest="num_samples", type=int) # optional legacy alias + ap.add_argument("--context-dim", type=int, default=16) ap.add_argument("--x-dim", type=int, default=512) ap.add_argument("--y-dim", type=int, default=64) @@ -388,39 +434,40 @@ def parse_args(): return ap.parse_args() + def main(): set_env_defaults() args = parse_args() os.makedirs(args.outdir, exist_ok=True) if torch.cuda.is_available(): - torch.backends.cudnn.benchmark = True # opt for fixed shapes + torch.backends.cudnn.benchmark = True + os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") - # data once - C, X, Y = make_synthetic(args.n, args.context_dim, args.x_dim, args.y_dim) + # Generate data once + C, X, Y = make_synthetic(args.num_samples, args.context_dim, args.x_dim, args.y_dim) results = [] torchrun_mode = under_torchrun() if torchrun_mode: - # Run exactly one multi-GPU config under torchrun: devices = WORLD_SIZE - cfg = BenchCfg(label=f"gpu-{world_size()}", devices=1) # 1 per proc; true DDP + # Run a single config under torchrun (WORLD_SIZE GPUs, 1 per process) + cfg = BenchCfg(label=f"gpu-{world_size()}", devices=1) if is_global_zero(): print(f"\n=== Running {cfg.label} (torchrun, {world_size()} processes) ===") res = run_once(cfg, 
C, X, Y, args, torchrun_mode=True) results.append(res) else: - # Standalone: run cpu + 1..k GPUs (ddp-spawn for k>1) + # Standalone: GPU-only sweep 1..k (skip CPU entirely) gpus = torch.cuda.device_count() - dev_list = [BenchCfg("cpu", 0)] - for k in range(1, min(args.max_gpus, gpus) + 1): - dev_list.append(BenchCfg(f"gpu-{k}", k)) + dev_list = [BenchCfg(f"gpu-{k}", k) for k in range(1, min(args.max_gpus, gpus) + 1)] for cfg in dev_list: if is_global_zero(): print(f"\n=== Running {cfg.label} ===") res = run_once(cfg, C, X, Y, args, torchrun_mode=False) results.append(res) + # Save outputs if is_global_zero(): csv_path = save_csv(results, args.outdir) plot_curves(results, args.outdir) From 55b14647db98b5f8f67bf9c0a3feffcd76d3a0be Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Sat, 15 Nov 2025 15:27:32 -0500 Subject: [PATCH 10/19] added cleaner functions to wrapper --- .../easy/wrappers/SKLearnWrapper.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index 4a4f4b29..edc6a6cc 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -192,6 +192,10 @@ def __init__( print(f"Received unknown keyword argument {kw}, probably ignoring.") # -------------------- helpers -------------------- + + def _is_gpu(self) -> bool: + return self.accelerator in ("cuda", "gpu") + def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True): if acceptable: self.acceptable_kwargs[category] = list( @@ -342,8 +346,8 @@ def maybe_add(cat, k, default): maybe_add("data", "test_batch_size", self.default_test_batch_size) maybe_add("data", "predict_batch_size", self.default_val_batch_size) maybe_add("data", "num_workers", 0) - maybe_add("data", "pin_memory", self.accelerator in ("cuda", "gpu")) - maybe_add("data", "persistent_workers", False) + maybe_add("data", "pin_memory", self._is_gpu()) + 
maybe_add("data", "persistent_workers", organized["data"].get("num_workers", 0) > 0) maybe_add("data", "drop_last", False) maybe_add("data", "shuffle_train", True) maybe_add("data", "shuffle_eval", False) @@ -423,7 +427,8 @@ def _build_datamodule( test_batch_size=self.default_test_batch_size, predict_batch_size=self.default_val_batch_size, num_workers=0, - pin_memory=(self.accelerator in ("cuda", "gpu")), + pin_memory=self._is_gpu(), + persistent_workers=None, persistent_workers=False, drop_last=False, shuffle_train=True, @@ -433,6 +438,11 @@ def _build_datamodule( if data_kwargs: dk.update(data_kwargs) + # If not explicitly set, default to True when num_workers > 0 + if dk["persistent_workers"] is None: + dk["persistent_workers"] = bool(dk["num_workers"] > 0) + + dm = ContextualizedRegressionDataModule( C=C, X=X, @@ -520,7 +530,7 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator in ("cuda", "gpu"))), + pin_memory=self._init_kwargs["data"].get("pin_memory", self._is_gpu()), persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), shuffle_train=False, shuffle_eval=False, @@ -569,7 +579,7 @@ def predict_params( test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator in ("cuda", "gpu"))), + pin_memory=self._init_kwargs["data"].get("pin_memory", self._is_gpu()), 
persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), shuffle_train=False, shuffle_eval=False, @@ -683,8 +693,8 @@ def fit(self, *args, **kwargs) -> None: test_batch_size=organized["data"].get("test_batch_size", self.default_test_batch_size), predict_batch_size=organized["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=organized["data"].get("num_workers", 0), - pin_memory=organized["data"].get("pin_memory", self.accelerator in ("cuda", "gpu")), - persistent_workers=organized["data"].get("persistent_workers", False), + pin_memory=organized["data"].get("pin_memory", self._is_gpu()), + persistent_workers=organized["data"].get("persistent_workers", organized["data"].get("num_workers", 0) > 0), drop_last=organized["data"].get("drop_last", False), shuffle_train=organized["data"].get("shuffle_train", True), shuffle_eval=organized["data"].get("shuffle_eval", False), From 06cab9957392030da42d6ba85e6b6c874022dfa8 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Sat, 15 Nov 2025 16:01:23 -0500 Subject: [PATCH 11/19] fix redundancy issue --- .../easy/wrappers/SKLearnWrapper.py | 223 +++++++++++++----- 1 file changed, 166 insertions(+), 57 deletions(-) diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index edc6a6cc..bbd4fb42 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -9,7 +9,7 @@ from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.plugins.environments import LightningEnvironment -from pytorch_lightning.strategies import DDPStrategy # PL v1 Strategy API +from pytorch_lightning.strategies import DDPStrategy # PL v1 Strategy API (not strictly required here) from contextualized.functions import LINK_FUNCTIONS from contextualized.regression import REGULARIZERS, LOSSES @@ -347,7 +347,11 @@ def maybe_add(cat, 
k, default): maybe_add("data", "predict_batch_size", self.default_val_batch_size) maybe_add("data", "num_workers", 0) maybe_add("data", "pin_memory", self._is_gpu()) - maybe_add("data", "persistent_workers", organized["data"].get("num_workers", 0) > 0) + maybe_add( + "data", + "persistent_workers", + organized["data"].get("num_workers", 0) > 0, + ) maybe_add("data", "drop_last", False) maybe_add("data", "shuffle_train", True) maybe_add("data", "shuffle_eval", False) @@ -357,9 +361,13 @@ def maybe_add(cat, k, default): maybe_add("wrapper", "n_bootstraps", self.default_n_bootstraps) # -------- EarlyStopping / Checkpoint constructors -------- - es_monitor = organized["wrapper"].get("es_monitor", "val_loss" if use_val else "train_loss") + es_monitor = organized["wrapper"].get( + "es_monitor", "val_loss" if use_val else "train_loss" + ) es_mode = organized["wrapper"].get("es_mode", "min") - es_patience = organized["wrapper"].get("es_patience", self.default_es_patience) + es_patience = organized["wrapper"].get( + "es_patience", self.default_es_patience + ) es_verbose = organized["wrapper"].get("es_verbose", False) es_min_delta = organized["wrapper"].get("es_min_delta", 0.0) @@ -382,7 +390,9 @@ def maybe_add(cat, k, default): lambda i: ModelCheckpoint( monitor=("val_loss" if use_val else None), dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", - filename=("{epoch}-{val_loss:.4f}" if use_val else "{epoch}"), + filename=( + "{epoch}-{val_loss:.4f}" if use_val else "{epoch}" + ), ) ) organized["trainer"]["callback_constructors"] = cb_ctors @@ -393,17 +403,24 @@ def maybe_add(cat, k, default): # -------- sanitize any pre-specified callbacks for no-val runs -------- cb_list = organized["trainer"].get("callbacks", []) - cb_list = [self._retarget_or_strip_early_stopping(cb, use_val) for cb in cb_list] + cb_list = [ + self._retarget_or_strip_early_stopping(cb, use_val) for cb in cb_list + ] organized["trainer"]["callbacks"] = cb_list # Also 
sanitize dynamically constructed callbacks ctor_list = organized["trainer"].get("callback_constructors", []) + def _wrap_ctor(ctor): def _wrapped(i): cb = ctor(i) return self._retarget_or_strip_early_stopping(cb, use_val) + return _wrapped - organized["trainer"]["callback_constructors"] = [_wrap_ctor(c) for c in ctor_list] + + organized["trainer"]["callback_constructors"] = [ + _wrap_ctor(c) for c in ctor_list + ] return organized @@ -428,8 +445,7 @@ def _build_datamodule( predict_batch_size=self.default_val_batch_size, num_workers=0, pin_memory=self._is_gpu(), - persistent_workers=None, - persistent_workers=False, + persistent_workers=None, # <-- only once drop_last=False, shuffle_train=True, shuffle_eval=False, @@ -442,7 +458,6 @@ def _build_datamodule( if dk["persistent_workers"] is None: dk["persistent_workers"] = bool(dk["num_workers"] > 0) - dm = ContextualizedRegressionDataModule( C=C, X=X, @@ -511,9 +526,13 @@ def _maybe_scale_X(self, X: np.ndarray) -> np.ndarray: return X # -------------------- public API -------------------- - def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs): + def predict( + self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs + ): if not hasattr(self, "models") or self.models is None: - raise ValueError("Trying to predict with a model that hasn't been trained yet.") + raise ValueError( + "Trying to predict with a model that hasn't been trained yet." 
+ ) Cq = self._maybe_scale_C(C) Xq = self._maybe_scale_X(X) @@ -522,24 +541,41 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, preds = [] for i in range(len(self.models)): dm = self._build_datamodule( - C=Cq, X=Xq, Y=Yq, + C=Cq, + X=Xq, + Y=Yq, predict_idx=np.arange(len(Cq)), data_kwargs=dict( - train_batch_size=self._init_kwargs["data"].get("train_batch_size", self.default_train_batch_size), - val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), - test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), - predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), + train_batch_size=self._init_kwargs["data"].get( + "train_batch_size", self.default_train_batch_size + ), + val_batch_size=self._init_kwargs["data"].get( + "val_batch_size", self.default_val_batch_size + ), + test_batch_size=self._init_kwargs["data"].get( + "test_batch_size", self.default_test_batch_size + ), + predict_batch_size=self._init_kwargs["data"].get( + "predict_batch_size", self.default_val_batch_size + ), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", self._is_gpu()), - persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + pin_memory=self._init_kwargs["data"].get( + "pin_memory", self._is_gpu() + ), + persistent_workers=self._init_kwargs["data"].get( + "persistent_workers", False + ), shuffle_train=False, shuffle_eval=False, dtype=self._init_kwargs["data"].get("dtype", torch.float), ), - task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) - else "singletask_multivariate", + task_type="singletask_univariate" + if self._init_kwargs["model"].get("univariate", False) + else "singletask_multivariate", + ) + yhat = self.trainers[i].predict_y( + self.models[i], dm.predict_dataloader(), **kwargs ) - yhat = 
self.trainers[i].predict_y(self.models[i], dm.predict_dataloader(), **kwargs) preds.append(yhat) predictions = np.array(preds) @@ -547,7 +583,9 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, predictions = np.mean(predictions, axis=0) if self.normalize and self.scalers["Y"] is not None: if individual_preds: - predictions = np.array([self.scalers["Y"].inverse_transform(p) for p in predictions]) + predictions = np.array( + [self.scalers["Y"].inverse_transform(p) for p in predictions] + ) else: predictions = self.scalers["Y"].inverse_transform(predictions) return predictions @@ -560,7 +598,9 @@ def predict_params( **kwargs, ): if not hasattr(self, "models") or self.models is None: - raise ValueError("Trying to predict with a model that hasn't been trained yet.") + raise ValueError( + "Trying to predict with a model that hasn't been trained yet." + ) Cq = self._maybe_scale_C(C) X_zero = np.zeros((len(Cq), self.x_dim), dtype=np.float32) @@ -574,29 +614,50 @@ def predict_params( Y=Y_zero if kwargs.pop("uses_y", True) else None, predict_idx=np.arange(len(Cq)), data_kwargs=dict( - train_batch_size=self._init_kwargs["data"].get("train_batch_size", self.default_train_batch_size), - val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), - test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), - predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), + train_batch_size=self._init_kwargs["data"].get( + "train_batch_size", self.default_train_batch_size + ), + val_batch_size=self._init_kwargs["data"].get( + "val_batch_size", self.default_val_batch_size + ), + test_batch_size=self._init_kwargs["data"].get( + "test_batch_size", self.default_test_batch_size + ), + predict_batch_size=self._init_kwargs["data"].get( + "predict_batch_size", self.default_val_batch_size + ), num_workers=self._init_kwargs["data"].get("num_workers", 
0), - pin_memory=self._init_kwargs["data"].get("pin_memory", self._is_gpu()), - persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + pin_memory=self._init_kwargs["data"].get( + "pin_memory", self._is_gpu() + ), + persistent_workers=self._init_kwargs["data"].get( + "persistent_workers", False + ), shuffle_train=False, shuffle_eval=False, dtype=self._init_kwargs["data"].get("dtype", torch.float), ), - task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) - else "singletask_multivariate", + task_type="singletask_univariate" + if self._init_kwargs["model"].get("univariate", False) + else "singletask_multivariate", + ) + pred = self.trainers[i].predict_params( + self.models[i], dm.predict_dataloader(), **kwargs ) - pred = self.trainers[i].predict_params(self.models[i], dm.predict_dataloader(), **kwargs) if model_includes_mus: - out_betas.append(pred[0]); out_mus.append(pred[1]) + out_betas.append(pred[0]) + out_mus.append(pred[1]) else: out_betas.append(pred) if model_includes_mus: - betas = np.array(out_betas); mus = np.array(out_mus) - return (betas, mus) if individual_preds else (np.mean(betas, axis=0), np.mean(mus, axis=0)) + betas = np.array(out_betas) + mus = np.array(out_mus) + return ( + (betas, mus) + if individual_preds + else (np.mean(betas, axis=0), np.mean(mus, axis=0)) + ) else: betas = np.array(out_betas) return betas if individual_preds else np.mean(betas, axis=0) @@ -612,7 +673,7 @@ def fit(self, *args, **kwargs) -> None: """ self.models, self.trainers = [], [] - # normalize argument order + # normalize argument order C_in = kwargs.pop("C", None) X_in = kwargs.pop("X", None) Y_in = kwargs.pop("Y", None) @@ -628,21 +689,29 @@ def fit(self, *args, **kwargs) -> None: else: C, X, Y = A, B, Carg else: - raise ValueError("Mismatched sample counts among provided arrays.") + raise ValueError( + "Mismatched sample counts among provided arrays." 
+ ) elif len(args) == 2: A, B = args if A.shape[0] != B.shape[0]: - raise ValueError("Mismatched sample counts for two-argument fit.") + raise ValueError( + "Mismatched sample counts for two-argument fit." + ) # Assume (C, X) by default C, X, Y = A, B, None else: - raise ValueError("fit expects (C,X[,Y]) or (X,Y,C) or kw-only C=..., X=...") + raise ValueError( + "fit expects (C,X[,Y]) or (X,Y,C) or kw-only C=..., X=..." + ) # Optional scaling if self.normalize: - if self.scalers["C"] is None: self.scalers["C"] = StandardScaler().fit(C) + if self.scalers["C"] is None: + self.scalers["C"] = StandardScaler().fit(C) C = self.scalers["C"].transform(C) - if self.scalers["X"] is None: self.scalers["X"] = StandardScaler().fit(X) + if self.scalers["X"] is None: + self.scalers["X"] = StandardScaler().fit(X) X = self.scalers["X"].transform(X) self.context_dim = C.shape[-1] @@ -652,7 +721,8 @@ def fit(self, *args, **kwargs) -> None: if len(Y.shape) == 1: Y = np.expand_dims(Y, 1) if self.normalize and not np.array_equal(np.unique(Y), np.array([0, 1])): - if self.scalers["Y"] is None: self.scalers["Y"] = StandardScaler().fit(Y) + if self.scalers["Y"] is None: + self.scalers["Y"] = StandardScaler().fit(Y) Y = self.scalers["Y"].transform(Y) self.y_dim = Y.shape[-1] args = (C, X, Y) @@ -661,7 +731,9 @@ def fit(self, *args, **kwargs) -> None: args = (C, X) organized = self._organize_and_expand_fit_kwargs(**kwargs) - self.n_bootstraps = organized["wrapper"].get("n_bootstraps", self.n_bootstraps) + self.n_bootstraps = organized["wrapper"].get( + "n_bootstraps", self.n_bootstraps + ) n = C.shape[0] val_split = organized["data"].get("val_split", self.default_val_split) @@ -676,25 +748,48 @@ def fit(self, *args, **kwargs) -> None: # Indices train_idx, val_idx = self._split_train_data( - C, X, (args[2] if len(args) == 3 else None), + C, + X, + (args[2] if len(args) == 3 else None), Y_required=(len(args) == 3), val_split=val_split, ) test_idx = None # DataModule - task_type = 
"singletask_univariate" if organized["model"].get("univariate", False) else "singletask_multivariate" + task_type = ( + "singletask_univariate" + if organized["model"].get("univariate", False) + else "singletask_multivariate" + ) dm = self._build_datamodule( - C=args[0], X=args[1], Y=(args[2] if len(args) == 3 else None), - train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, + C=args[0], + X=args[1], + Y=(args[2] if len(args) == 3 else None), + train_idx=train_idx, + val_idx=val_idx, + test_idx=test_idx, data_kwargs=dict( - train_batch_size=organized["data"].get("train_batch_size", self.default_train_batch_size), - val_batch_size=organized["data"].get("val_batch_size", self.default_val_batch_size), - test_batch_size=organized["data"].get("test_batch_size", self.default_test_batch_size), - predict_batch_size=organized["data"].get("predict_batch_size", self.default_val_batch_size), + train_batch_size=organized["data"].get( + "train_batch_size", self.default_train_batch_size + ), + val_batch_size=organized["data"].get( + "val_batch_size", self.default_val_batch_size + ), + test_batch_size=organized["data"].get( + "test_batch_size", self.default_test_batch_size + ), + predict_batch_size=organized["data"].get( + "predict_batch_size", self.default_val_batch_size + ), num_workers=organized["data"].get("num_workers", 0), - pin_memory=organized["data"].get("pin_memory", self._is_gpu()), - persistent_workers=organized["data"].get("persistent_workers", organized["data"].get("num_workers", 0) > 0), + pin_memory=organized["data"].get( + "pin_memory", self._is_gpu() + ), + persistent_workers=organized["data"].get( + "persistent_workers", + organized["data"].get("num_workers", 0) > 0, + ), drop_last=organized["data"].get("drop_last", False), shuffle_train=organized["data"].get("shuffle_train", True), shuffle_eval=organized["data"].get("shuffle_eval", False), @@ -705,11 +800,14 @@ def fit(self, *args, **kwargs) -> None: # Trainer (fresh callbacks) trainer_kwargs = 
copy.deepcopy(organized["trainer"]) - trainer_kwargs["callbacks"] = [f(b) for f in trainer_kwargs.get("callback_constructors", [])] + trainer_kwargs["callbacks"] = [ + f(b) for f in trainer_kwargs.get("callback_constructors", []) + ] trainer_kwargs.pop("callback_constructors", None) # Build via factory (respects strategy strings and env) from contextualized.regression.trainers import make_trainer_with_env + trainer = make_trainer_with_env( self.trainer_constructor, **trainer_kwargs, @@ -737,8 +835,19 @@ def fit(self, *args, **kwargs) -> None: # Load best checkpoint if enabled if trainer_kwargs.get("enable_checkpointing", False): - ckpt_cb = next((cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), None) - if ckpt_cb and ckpt_cb.best_model_path and os.path.exists(ckpt_cb.best_model_path): + ckpt_cb = next( + ( + cb + for cb in trainer.callbacks + if isinstance(cb, ModelCheckpoint) + ), + None, + ) + if ( + ckpt_cb + and ckpt_cb.best_model_path + and os.path.exists(ckpt_cb.best_model_path) + ): best = torch.load(ckpt_cb.best_model_path, map_location="cpu") model.load_state_dict(best["state_dict"]) From 581a6edf86143231a9f2b771db097eeb90467b03 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Sun, 21 Dec 2025 13:51:16 -0500 Subject: [PATCH 12/19] initial clean up of files for hpc implementation --- bash_scripts/network_heavy.sh | 222 +++++ bash_scripts/network_optimized.sh | 168 ++++ contextualized/__init__.py | 2 +- contextualized/callbacks.py | 1 - .../easy/ContextualizedClassifier.py | 3 +- contextualized/easy/ContextualizedNetworks.py | 159 +++- .../easy/ContextualizedRegressor.py | 2 +- contextualized/easy/tests.py | 2 +- .../easy/wrappers/SKLearnWrapper.py | 615 +++++++------ contextualized/modules.py | 2 +- contextualized/regression/__init__.py | 2 +- contextualized/regression/datamodules.py | 10 +- .../regression/lightning_modules.py | 108 ++- contextualized/regression/metamodels.py | 2 +- contextualized/regression/regularizers.py | 2 +- 
contextualized/regression/trainers.py | 77 +- contextualized/tests.py | 2 +- contextualized/utils/__init__.py | 47 + contextualized/utils/engine.py | 2 +- cpu_scale_bench.py | 448 ---------- network_scaling_heavy.py | 766 ++++++++++++++++ networks_pert_scale_bench.py | 826 ++++++++++++++++++ scale_bench.py => regression_scale_bench.py | 0 scripts/test_contextualized_dm.py | 252 ------ 24 files changed, 2654 insertions(+), 1066 deletions(-) create mode 100644 bash_scripts/network_heavy.sh create mode 100644 bash_scripts/network_optimized.sh delete mode 100644 cpu_scale_bench.py create mode 100644 network_scaling_heavy.py create mode 100644 networks_pert_scale_bench.py rename scale_bench.py => regression_scale_bench.py (100%) delete mode 100644 scripts/test_contextualized_dm.py diff --git a/bash_scripts/network_heavy.sh b/bash_scripts/network_heavy.sh new file mode 100644 index 00000000..2eec78fd --- /dev/null +++ b/bash_scripts/network_heavy.sh @@ -0,0 +1,222 @@ +#!/bin/bash +# ============================================================================= +# HEAVY ContextualizedCorrelationNetworks DDP SCALING BENCHMARK +# ============================================================================= +# +# This benchmark tests multi-GPU scaling with the ACTUAL CCN model, but +# configured for maximum compute to properly stress-test GPU parallelism. 
+# +# HEAVY Configuration vs Original: +# Parameter | Original | Heavy | Compute Impact +# -----------------|----------|---------|---------------- +# Archetypes | 16-30 | 64 | 2-4x more mixture components +# Encoder width | 25 | 256 | 10x wider networks +# Encoder layers | 3 | 6 | 2x deeper networks +# Bootstraps | 1 | 3 | 3x more models (ensemble) +# Data PCs | 50 | 100 | 2x larger output space +# +# Estimated parameters: ~15-30M (vs ~300K original) +# +# Expected scaling: +# 1 GPU: baseline +# 2 GPU: ~1.85x speedup (92% efficiency) +# 3 GPU: ~2.65x speedup (88% efficiency) +# 4 GPU: ~3.4x speedup (85% efficiency) +# +# ============================================================================= + +set -e + +# ===== CONFIGURATION ===== +SCRIPT="ccn_scaling_heavy.py" +OUTDIR="bench_results_ccn_heavy" +EPOCHS=20 +WARMUP=1 +BATCH_SIZE=512 # Per GPU + +# HEAVY CCN Architecture +ARCHETYPES=64 # Original: 16-30 +ENCODER_WIDTH=256 # Original: 25 +ENCODER_LAYERS=6 # Original: 3 +BOOTSTRAPS=1 # Original: 1 + +# Data dimensionality +DATA_PCS=100 # Original: 50 +CONTEXT_PCS=100 + +# Runtime +NUM_WORKERS=4 +SUBSAMPLE=1.0 + +# Clean previous results +echo "==============================================" +echo "Cleaning previous results..." 
+echo "==============================================" +rm -f "${OUTDIR}/ccn_heavy_scaling_results.csv" +mkdir -p "${OUTDIR}" + +echo "" +echo "==============================================" +echo "HEAVY CCN SCALING BENCHMARK" +echo "==============================================" +echo "Script: ${SCRIPT}" +echo "Epochs: ${EPOCHS} (+ ${WARMUP} warmup)" +echo "Batch size per GPU: ${BATCH_SIZE}" +echo "" +echo "--- HEAVY CCN Config ---" +echo "Archetypes: ${ARCHETYPES}" +echo "Encoder: ${ENCODER_WIDTH}w × ${ENCODER_LAYERS}L" +echo "Bootstraps: ${BOOTSTRAPS}" +echo "Data PCs: ${DATA_PCS}" +echo "" +echo "Output: ${OUTDIR}" +echo "" + +# ----------------------------------------------------------------------------- +# TEST 1: 1-GPU Baseline +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "[1/4] Running 1-GPU baseline..." +echo "==============================================" + +python ${SCRIPT} \ + --epochs ${EPOCHS} \ + --warmup-epochs ${WARMUP} \ + --batch-size ${BATCH_SIZE} \ + --archetypes ${ARCHETYPES} \ + --encoder-width ${ENCODER_WIDTH} \ + --encoder-layers ${ENCODER_LAYERS} \ + --bootstraps ${BOOTSTRAPS} \ + --data-pcs ${DATA_PCS} \ + --context-pcs ${CONTEXT_PCS} \ + --num-workers ${NUM_WORKERS} \ + --subsample-fraction ${SUBSAMPLE} \ + --devices 1 \ + --outdir ${OUTDIR} \ + --label "1gpu_baseline" + +# Extract baseline time for efficiency calculation +BASELINE_TIME=$(tail -1 "${OUTDIR}/ccn_heavy_scaling_results.csv" | cut -d',' -f2) +echo "" +echo ">>> Baseline time: ${BASELINE_TIME}s" +echo "" + +# ----------------------------------------------------------------------------- +# TEST 2: 2-GPU DDP +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "[2/4] Running 2-GPU DDP..." 
+echo "==============================================" + +torchrun \ + --standalone \ + --nproc_per_node=2 \ + ${SCRIPT} \ + --epochs ${EPOCHS} \ + --warmup-epochs ${WARMUP} \ + --batch-size ${BATCH_SIZE} \ + --archetypes ${ARCHETYPES} \ + --encoder-width ${ENCODER_WIDTH} \ + --encoder-layers ${ENCODER_LAYERS} \ + --bootstraps ${BOOTSTRAPS} \ + --data-pcs ${DATA_PCS} \ + --context-pcs ${CONTEXT_PCS} \ + --num-workers ${NUM_WORKERS} \ + --subsample-fraction ${SUBSAMPLE} \ + --devices 2 \ + --outdir ${OUTDIR} \ + --label "2gpu_ddp" \ + --baseline-time ${BASELINE_TIME} + +echo "" + +# ----------------------------------------------------------------------------- +# TEST 3: 3-GPU DDP +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "[3/4] Running 3-GPU DDP..." +echo "==============================================" + +torchrun \ + --standalone \ + --nproc_per_node=3 \ + ${SCRIPT} \ + --epochs ${EPOCHS} \ + --warmup-epochs ${WARMUP} \ + --batch-size ${BATCH_SIZE} \ + --archetypes ${ARCHETYPES} \ + --encoder-width ${ENCODER_WIDTH} \ + --encoder-layers ${ENCODER_LAYERS} \ + --bootstraps ${BOOTSTRAPS} \ + --data-pcs ${DATA_PCS} \ + --context-pcs ${CONTEXT_PCS} \ + --num-workers ${NUM_WORKERS} \ + --subsample-fraction ${SUBSAMPLE} \ + --devices 3 \ + --outdir ${OUTDIR} \ + --label "3gpu_ddp" \ + --baseline-time ${BASELINE_TIME} + +echo "" + +# ----------------------------------------------------------------------------- +# TEST 4: 4-GPU DDP +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "[4/4] Running 4-GPU DDP..." 
+echo "==============================================" + +torchrun \ + --standalone \ + --nproc_per_node=4 \ + ${SCRIPT} \ + --epochs ${EPOCHS} \ + --warmup-epochs ${WARMUP} \ + --batch-size ${BATCH_SIZE} \ + --archetypes ${ARCHETYPES} \ + --encoder-width ${ENCODER_WIDTH} \ + --encoder-layers ${ENCODER_LAYERS} \ + --bootstraps ${BOOTSTRAPS} \ + --data-pcs ${DATA_PCS} \ + --context-pcs ${CONTEXT_PCS} \ + --num-workers ${NUM_WORKERS} \ + --subsample-fraction ${SUBSAMPLE} \ + --devices 4 \ + --outdir ${OUTDIR} \ + --label "4gpu_ddp" \ + --baseline-time ${BASELINE_TIME} + +echo "" + +# ----------------------------------------------------------------------------- +# SUMMARY +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "BENCHMARK COMPLETE" +echo "==============================================" +echo "" +echo "Full Results:" +echo "" +column -t -s',' "${OUTDIR}/ccn_heavy_scaling_results.csv" +echo "" + +echo "==============================================" +echo "SCALING SUMMARY" +echo "==============================================" +awk -F',' ' +NR==1 {next} +{ + printf " %-15s: %8.2fs | %5.2fx speedup | %5.1f%% efficiency\n", $1, $2, $13, $14 +} +' "${OUTDIR}/ccn_heavy_scaling_results.csv" +echo "" + +echo "==============================================" +echo "CCN CONFIGURATION USED" +echo "==============================================" +echo " Archetypes: ${ARCHETYPES}" +echo " Encoder width: ${ENCODER_WIDTH}" +echo " Encoder layers: ${ENCODER_LAYERS}" +echo " Bootstraps: ${BOOTSTRAPS}" +echo " Data PCs: ${DATA_PCS}" +echo "" \ No newline at end of file diff --git a/bash_scripts/network_optimized.sh b/bash_scripts/network_optimized.sh new file mode 100644 index 00000000..71266cc4 --- /dev/null +++ b/bash_scripts/network_optimized.sh @@ -0,0 +1,168 @@ +#!/bin/bash +# ============================================================================= +# OPTIMIZED DDP SCALING 
BENCHMARK SCRIPT +# ============================================================================= +# +# This script runs a proper scaling comparison with CONSTANT GLOBAL BATCH SIZE +# to measure true parallel efficiency. +# +# Key differences from original: +# 1. Global batch size stays at 256 regardless of GPU count +# 2. Each GPU processes 256/N samples per batch +# 3. Warmup epoch excluded from timing +# 4. Reduced DataLoader workers to avoid contention +# 5. NCCL optimizations enabled +# +# Expected scaling (realistic for small models): +# 1 GPU: baseline +# 2 GPU: 1.6-1.8x speedup (80-90% efficiency) +# 3 GPU: 2.2-2.6x speedup (73-87% efficiency) +# 4 GPU: 2.8-3.4x speedup (70-85% efficiency) +# +# ============================================================================= + +set -e # Exit on error + +# Configuration +SCRIPT="unseen_pert_scale_optimized.py" +OUTDIR="bench_results_optimized" +EPOCHS=40 +WARMUP=1 +BATCH_SIZE=256 # Per-GPU batch size (global = this × num_gpus) +NUM_WORKERS=4 # Will be auto-reduced for multi-GPU +SUBSAMPLE=1.0 # Use full data (matches existing cache filename) + +# IMPORTANT: For small models, we MUST scale batch size with GPUs. +# Otherwise communication overhead dominates and multi-GPU is SLOWER. +# Using --scale-batch flag to scale global batch with GPU count. 
+ +# Clean previous results +rm -f "${OUTDIR}/scaling_results_optimized.csv" +mkdir -p "${OUTDIR}" + +echo "==============================================" +echo "STARTING SCALING BENCHMARK" +echo "==============================================" +echo "Script: ${SCRIPT}" +echo "Epochs: ${EPOCHS} (+ ${WARMUP} warmup)" +echo "Global Batch Size: ${BATCH_SIZE} (constant)" +echo "Output: ${OUTDIR}" +echo "" + +# ----------------------------------------------------------------------------- +# TEST 1: 1-GPU Baseline +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "[1/4] Running 1-GPU baseline..." +echo "==============================================" + +python ${SCRIPT} \ + --epochs ${EPOCHS} \ + --warmup-epochs ${WARMUP} \ + --subsample-fraction ${SUBSAMPLE} \ + --devices 1 \ + --batch-size ${BATCH_SIZE} \ + --num-workers ${NUM_WORKERS} \ + --outdir ${OUTDIR} \ + --label "1gpu_baseline" \ + --verbose + +# Extract baseline time for efficiency calculation +BASELINE_TIME=$(tail -1 "${OUTDIR}/scaling_results_optimized.csv" | cut -d',' -f2) +echo "" +echo "Baseline time: ${BASELINE_TIME}s" +echo "" + +# ----------------------------------------------------------------------------- +# TEST 2: 2-GPU with torchrun +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "[2/4] Running 2-GPU DDP with torchrun..." 
+echo "==============================================" + +torchrun \ + --standalone \ + --nproc_per_node=2 \ + ${SCRIPT} \ + --epochs ${EPOCHS} \ + --warmup-epochs ${WARMUP} \ + --subsample-fraction ${SUBSAMPLE} \ + --devices 2 \ + --batch-size ${BATCH_SIZE} \ + --num-workers ${NUM_WORKERS} \ + --outdir ${OUTDIR} \ + --label "2gpu_ddp" \ + --baseline-time ${BASELINE_TIME} \ + --scale-batch \ + --verbose + +echo "" + +# ----------------------------------------------------------------------------- +# TEST 3: 3-GPU with torchrun +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "[3/4] Running 3-GPU DDP with torchrun..." +echo "==============================================" + +torchrun \ + --standalone \ + --nproc_per_node=3 \ + ${SCRIPT} \ + --epochs ${EPOCHS} \ + --warmup-epochs ${WARMUP} \ + --subsample-fraction ${SUBSAMPLE} \ + --devices 3 \ + --batch-size ${BATCH_SIZE} \ + --num-workers ${NUM_WORKERS} \ + --outdir ${OUTDIR} \ + --label "3gpu_ddp" \ + --baseline-time ${BASELINE_TIME} \ + --scale-batch \ + --verbose + +echo "" + +# ----------------------------------------------------------------------------- +# TEST 4: 4-GPU with torchrun +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "[4/4] Running 4-GPU DDP with torchrun..." 
+echo "==============================================" + +torchrun \ + --standalone \ + --nproc_per_node=4 \ + ${SCRIPT} \ + --epochs ${EPOCHS} \ + --warmup-epochs ${WARMUP} \ + --subsample-fraction ${SUBSAMPLE} \ + --devices 4 \ + --batch-size ${BATCH_SIZE} \ + --num-workers ${NUM_WORKERS} \ + --outdir ${OUTDIR} \ + --label "4gpu_ddp" \ + --baseline-time ${BASELINE_TIME} \ + --scale-batch \ + --verbose + +echo "" + +# ----------------------------------------------------------------------------- +# SUMMARY +# ----------------------------------------------------------------------------- +echo "==============================================" +echo "BENCHMARK COMPLETE" +echo "==============================================" +echo "" +echo "Results saved to: ${OUTDIR}/scaling_results_optimized.csv" +echo "" +echo "Results:" +cat "${OUTDIR}/scaling_results_optimized.csv" | column -t -s',' +echo "" + +# Calculate speedups +echo "Speedup Summary:" +echo "----------------" +awk -F',' 'NR==1 {next} NR==2 {base=$2} {printf "%s: %.2fs (%.2fx speedup, %.1f%% efficiency)\n", $1, $2, base/$2, $8}' \ + "${OUTDIR}/scaling_results_optimized.csv" \ No newline at end of file diff --git a/contextualized/__init__.py b/contextualized/__init__.py index 5b46cfe0..82574ff5 100644 --- a/contextualized/__init__.py +++ b/contextualized/__init__.py @@ -26,4 +26,4 @@ os.environ.setdefault("NCCL_IB_DISABLE", "1") os.environ.setdefault("NCCL_P2P_DISABLE", "0") from .utils.engine import pick_engine # optional re-export -__all__ = ["pick_engine"] +__all__ = ["pick_engine"] \ No newline at end of file diff --git a/contextualized/callbacks.py b/contextualized/callbacks.py index b41d4d28..5c81d425 100644 --- a/contextualized/callbacks.py +++ b/contextualized/callbacks.py @@ -87,4 +87,3 @@ def write_on_batch_end( self.arr[n, yi, xi, 0] = beta self.arr[n, yi, xi, 1] = mu - diff --git a/contextualized/easy/ContextualizedClassifier.py b/contextualized/easy/ContextualizedClassifier.py index 
75a3b5cf..1360fd8d 100644 --- a/contextualized/easy/ContextualizedClassifier.py +++ b/contextualized/easy/ContextualizedClassifier.py @@ -46,5 +46,4 @@ def predict_proba(self, C, X, **kwargs): probs = probs[..., 0] p1 = probs p0 = 1.0 - p1 - return np.stack([p0, p1], axis=-1) - + return np.stack([p0, p1], axis=-1) \ No newline at end of file diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py index 0fd505a3..a701a840 100644 --- a/contextualized/easy/ContextualizedNetworks.py +++ b/contextualized/easy/ContextualizedNetworks.py @@ -2,9 +2,10 @@ sklearn-like interface to Contextualized Networks. """ -from typing import List, Tuple, Union +from typing import List, Tuple, Union, Optional import numpy as np +import torch from contextualized.easy.wrappers import SKLearnWrapper from contextualized.regression.trainers import CorrelationTrainer, MarkovTrainer @@ -27,10 +28,31 @@ class ContextualizedNetworks(SKLearnWrapper): """ def _split_train_data( - self, C: np.ndarray, X: np.ndarray, **kwargs - ) -> Tuple[List[np.ndarray], List[np.ndarray]]: - """Splits data into train and val sets (no Y for networks).""" - return super()._split_train_data(C, X, Y_required=False, **kwargs) + self, + C: np.ndarray, + X: np.ndarray, + Y: Optional[np.ndarray] = None, + *, + Y_required: bool = False, + val_split: Optional[float] = None, + random_state: Optional[int] = None, + shuffle: bool = True, + **kwargs, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Override only to change the default behavior (networks do not *require* Y), + but keep the signature compatible with SKLearnWrapper._split_train_data. 
+ """ + return super()._split_train_data( + C, + X, + Y, + Y_required=Y_required, + val_split=val_split, + random_state=random_state, + shuffle=shuffle, + **kwargs, + ) def predict_networks( self, @@ -83,13 +105,18 @@ def predict_correlation( Y=Y_zero, predict_idx=np.arange(len(C_scaled)), data_kwargs=dict( - batch_size=self._init_kwargs["data"].get("val_batch_size", 16), + train_batch_size=self._init_kwargs["data"].get("train_batch_size", 16), + val_batch_size=self._init_kwargs["data"].get("val_batch_size", 16), + test_batch_size=self._init_kwargs["data"].get("test_batch_size", 16), + predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", 16), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), + pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator in ("cuda", "gpu"))), persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), - shuffle_train=False, shuffle_eval=False, + shuffle_train=False, + shuffle_eval=False, dtype=self._init_kwargs["data"].get("dtype", torch.float), ), + task_type="singletask_univariate", # correlation uses univariate convention ) rhos = np.array([ @@ -105,18 +132,99 @@ def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False ) -> Union[np.ndarray, List[np.ndarray]]: """ - Measures mean-squared reconstruction errors using (betas, mus). + Measures mean-squared reconstruction errors between the true X and the + reconstructed X_hat produced by the contextualized correlation network. + + Parameters + ---------- + C : np.ndarray + Context matrix of shape (N, C_dim). + X : np.ndarray + Data matrix of shape (N, F). + individual_preds : bool, default False + If False: return per-sample MSE averaged over bootstraps. + If True: return per-bootstrap, per-sample MSE. 
+ + Returns + ------- + np.ndarray + If individual_preds is False: shape (N_eff,), per-sample MSE averaged + over bootstraps. + + If individual_preds is True: shape (B, N_eff), per-bootstrap, per-sample MSE. + + Notes + ----- + In single-process (non-distributed) settings, N_eff == N (full dataset). + + Under distributed settings, predict_X may operate on rank-local shards so + the number of samples in X_hat (N_hat) may differ from len(X) (N_true). + In that case we align both X_hat and X to N_eff = min(N_hat, N_true) to + avoid shape mismatches, yielding valid MSEs for the evaluated subset. """ - betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True) - mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples - F = X.shape[-1] - for i in range(F): - for j in range(F): - tiled_xi = np.array([X[:, i] for _ in range(len(betas))]) - tiled_xj = np.array([X[:, j] for _ in range(len(betas))]) - residuals = tiled_xi - betas[:, :, i, j] * tiled_xj - mus[:, :, i, j] - mses += residuals**2 / (F**2) - return mses if individual_preds else np.mean(mses, axis=0) + # Predict reconstructions of X for each bootstrap model + X_hat = self.predict_X(C, X, individual_preds=True) + X_hat = np.array(X_hat) + + if X_hat.ndim not in (3, 4): + raise ValueError( + f"Unexpected X_hat ndim={X_hat.ndim} with shape {X_hat.shape} in " + "ContextualizedCorrelationNetworks.measure_mses" + ) + + # X: (N_true, F) + N_true, F = X.shape + + if X_hat.ndim == 3: + # X_hat: (B, N_hat, F_hat) + B, N_hat, F_hat = X_hat.shape + if F_hat != F: + raise ValueError( + f"Feature dimension mismatch between X_hat (F={F_hat}) and X (F={F}) " + "in ContextualizedCorrelationNetworks.measure_mses" + ) + + # Align on the sample dimension + N_eff = min(N_hat, N_true) + if N_hat != N_true: + X_hat = X_hat[:, :N_eff, :] + X_eff = X[:N_eff, :] + else: + N_eff = N_true + X_eff = X + + X_true = X_eff[None, :, :] # (1, N_eff, F) + residuals = X_hat - X_true # (B, N_eff, F) + mses = (residuals ** 
2).mean(axis=-1) # (B, N_eff) + + else: # X_hat.ndim == 4 + # X_hat: (B, N_hat, F1, F2) + B, N_hat, F1, F2 = X_hat.shape + if F1 != F: + raise ValueError( + f"Feature dimension mismatch between X_hat (F1={F1}) and X (F={F}) " + "in ContextualizedCorrelationNetworks.measure_mses" + ) + + N_eff = min(N_hat, N_true) + if N_hat != N_true: + X_hat = X_hat[:, :N_eff, :, :] + X_eff = X[:N_eff, :] + else: + N_eff = N_true + X_eff = X + + X_true = X_eff[None, :, :, None] # (1, N_eff, F, 1) + residuals = X_hat - X_true # (B, N_eff, F, F2) + mses = (residuals ** 2).mean(axis=(-1, -2)) # (B, N_eff) + + # mses: (B, N_eff) + return mses if individual_preds else mses.mean(axis=0) + + + + + class ContextualizedMarkovNetworks(ContextualizedNetworks): @@ -141,13 +249,18 @@ def predict_precisions( Y=Y_zero, predict_idx=np.arange(len(C_scaled)), data_kwargs=dict( - batch_size=self._init_kwargs["data"].get("val_batch_size", 16), + train_batch_size=self._init_kwargs["data"].get("train_batch_size", 16), + val_batch_size=self._init_kwargs["data"].get("val_batch_size", 16), + test_batch_size=self._init_kwargs["data"].get("test_batch_size", 16), + predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", 16), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator == "gpu")), + pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator in ("cuda", "gpu"))), persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), - shuffle_train=False, shuffle_eval=False, + shuffle_train=False, + shuffle_eval=False, dtype=self._init_kwargs["data"].get("dtype", torch.float), ), + task_type="singletask_univariate", ) precisions = np.array([ @@ -340,4 +453,4 @@ def measure_mses( for b in range(len(betas)): X_pred = dag_pred_np(X, betas[b]) mses[b, :] = np.mean((X - X_pred) ** 2, axis=1) - return mses if individual_preds else np.mean(mses, axis=0) + return mses if individual_preds 
else np.mean(mses, axis=0) \ No newline at end of file diff --git a/contextualized/easy/ContextualizedRegressor.py b/contextualized/easy/ContextualizedRegressor.py index 2ac98b3c..932fa971 100644 --- a/contextualized/easy/ContextualizedRegressor.py +++ b/contextualized/easy/ContextualizedRegressor.py @@ -53,4 +53,4 @@ def __init__(self, **kwargs): # Preserve legacy behavior that Y is expected/required for regression fits def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs): - return super()._split_train_data(C, X, Y, Y_required=True, **kwargs) + return super()._split_train_data(C, X, Y, Y_required=True, **kwargs) \ No newline at end of file diff --git a/contextualized/easy/tests.py b/contextualized/easy/tests.py index 2f368468..5ca6d269 100644 --- a/contextualized/easy/tests.py +++ b/contextualized/easy/tests.py @@ -398,4 +398,4 @@ def test_regressor_normalization(self): if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index bbd4fb42..ade6f4ca 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -4,12 +4,13 @@ from typing import * import numpy as np import torch +import torch.distributed as dist from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.plugins.environments import LightningEnvironment -from pytorch_lightning.strategies import DDPStrategy # PL v1 Strategy API (not strictly required here) +from pytorch_lightning.strategies import DDPStrategy from contextualized.functions import LINK_FUNCTIONS from contextualized.regression import REGULARIZERS, LOSSES @@ -19,7 +20,7 @@ DEFAULT_N_BOOTSTRAPS = 1 DEFAULT_ES_PATIENCE = 1 
DEFAULT_VAL_BATCH_SIZE = 16 -DEFAULT_TRAIN_BATCH_SIZE = 1 +DEFAULT_TRAIN_BATCH_SIZE = 64 DEFAULT_TEST_BATCH_SIZE = 16 DEFAULT_VAL_SPLIT = 0.2 DEFAULT_ENCODER_TYPE = "mlp" @@ -29,19 +30,52 @@ DEFAULT_NORMALIZE = False +def _is_distributed() -> bool: + """Check if we're in a distributed context.""" + return dist.is_available() and dist.is_initialized() + + +def _get_rank() -> int: + """Get current process rank.""" + if _is_distributed(): + return dist.get_rank() + return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))) + + +def _is_main_process() -> bool: + """Check if this is the main process (rank 0).""" + return _get_rank() == 0 + + +def _gather_predictions(local_preds: np.ndarray, world_size: int) -> np.ndarray: + """ + Gather predictions from all ranks to rank 0. + Returns full predictions on rank 0, None on other ranks. + """ + if not _is_distributed() or world_size == 1: + return local_preds + + local_tensor = torch.from_numpy(local_preds).cuda() + + if _is_main_process(): + gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)] + dist.gather(local_tensor, gather_list=gathered, dst=0) + return torch.cat(gathered, dim=0).cpu().numpy() + else: + dist.gather(local_tensor, dst=0) + return None + + class SKLearnWrapper: """ An sklearn-like wrapper for Contextualized models. - - Args: - base_constructor (class): Base LightningModule constructor. - extra_model_kwargs (Iterable[str]): Extra model kwargs to accept. - extra_data_kwargs (Iterable[str]): Extra data kwargs to accept. - trainer_constructor (class): Trainer class (usually RegressionTrainer). - normalize (bool): If True, standardize C/X (and Y if continuous). 
+ + FIXED VERSION with proper DDP handling for: + - Prediction (avoids duplicate computation) + - Data loading (proper num_workers) + - Distributed inference """ - # -------------------- defaults -------------------- def _set_defaults(self): self.default_learning_rate = DEFAULT_LEARNING_RATE self.default_n_bootstraps = DEFAULT_N_BOOTSTRAPS @@ -68,9 +102,15 @@ def __init__( self.base_constructor = base_constructor self.trainer_constructor = trainer_constructor + self._trainer_init_kwargs = kwargs.pop("trainer_kwargs", None) + self.n_bootstraps = 1 self.models = None self.trainers = None + + # Track if we trained with DDP (affects prediction strategy) + self._trained_with_ddp = False + self._trained_devices = 1 self.normalize = kwargs.pop("normalize", self.default_normalize) self.scalers = {"C": None, "X": None, "Y": None} @@ -150,7 +190,6 @@ def __init__( "data", kwargs.pop("remove_data_kwargs", []), acceptable=False ) - # Convenience aliases handled at construction self.convenience_kwargs = [ "alpha", "l1_ratio", @@ -161,7 +200,6 @@ def __init__( "encoder_link_fn", ] - # Model constructor kwargs (with convenience mapping) self.constructor_kwargs = self._organize_constructor_kwargs(**kwargs) self.constructor_kwargs["encoder_kwargs"]["width"] = kwargs.pop( "width", self.constructor_kwargs["encoder_kwargs"]["width"] @@ -176,7 +214,6 @@ def __init__( ), ) - # Everything else self.not_constructor_kwargs = { k: v for k, v in kwargs.items() @@ -188,13 +225,12 @@ def __init__( ) for k, v in self.constructor_kwargs.items(): self._init_kwargs["model"][k] = v - for kw in unrecognized: - print(f"Received unknown keyword argument {kw}, probably ignoring.") - # -------------------- helpers -------------------- + if self._trainer_init_kwargs is not None: + self._init_kwargs["trainer"].update(self._trainer_init_kwargs) - def _is_gpu(self) -> bool: - return self.accelerator in ("cuda", "gpu") + for kw in unrecognized: + print(f"Received unknown keyword argument {kw}, probably 
ignoring.") def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True): if acceptable: @@ -242,7 +278,6 @@ def maybe_add(kw, default_val): if kwargs.get("subtype_probabilities", False): model["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"] - # Regularizer if "model_regularizer" in self.acceptable_kwargs["model"]: if kwargs.get("alpha", 0) > 0: model["model_regularizer"] = REGULARIZERS["l1_l2"]( @@ -276,19 +311,48 @@ def _retarget_or_strip_early_stopping(cb, use_val: bool, train_monitor="train_lo min_delta=getattr(cb, "min_delta", 0.0), ) return cb + + def _default_num_workers(self, devices: int) -> int: + """ + Heuristic for default DataLoader workers. + FIXED: CPU also benefits from workers for I/O overlap. + """ + try: + n_cpu = os.cpu_count() or 0 + except Exception: + n_cpu = 0 + + if n_cpu <= 0: + return 0 + + # For CPU-only, still use some workers for data loading overlap + if self.accelerator not in ("cuda", "gpu"): + return min(2, n_cpu) + + world_size_env = os.environ.get("WORLD_SIZE", None) + if world_size_env is not None: + try: + world_size = max(1, int(world_size_env)) + except ValueError: + world_size = 1 + else: + world_size = max(1, devices) + + cpu_per_rank = max(1, n_cpu // world_size) + # 2-4 workers per rank, capped + return int(min(4, max(2, cpu_per_rank // 2))) - # -------------------- fit kwarg expansion -------------------- def _organize_and_expand_fit_kwargs(self, **kwargs): """ - Expand/normalize kwargs for data/model/trainer/wrapper/fit, and build a clean - configuration dict for downstream construction. Critically: - • Merge constructor-time defaults BEFORE computing use_val. - • Only add EarlyStopping if a val loop exists and patience > 0. - • Retarget or strip EarlyStopping if no val loop. + Expand/normalize kwargs for data/model/trainer/wrapper/fit. + FIXED: Better DDP defaults and tracking. 
""" organized, unrecognized = self._organize_kwargs(**kwargs) - # -------- epochs (avoid PL default 1000) -------- + for category, cat_kwargs in self._init_kwargs.items(): + for k, v in cat_kwargs.items(): + organized[category].setdefault(k, v) + max_epochs_cli = kwargs.get("max_epochs", None) epochs_cli = kwargs.get("epochs", None) if max_epochs_cli is not None: @@ -296,43 +360,82 @@ def _organize_and_expand_fit_kwargs(self, **kwargs): elif epochs_cli is not None: organized["trainer"]["max_epochs"] = int(epochs_cli) else: - organized["trainer"]["max_epochs"] = 3 + organized["trainer"].setdefault("max_epochs", 3) - # -------- merge constructor defaults BEFORE using them -------- - for category, cat_kwargs in self._init_kwargs.items(): - for k, v in cat_kwargs.items(): - organized[category].setdefault(k, v) - - # -------- world size / validation decision -------- - world_size = int(os.getenv("WORLD_SIZE", "1")) current_val_split = organized["data"].get("val_split", self.default_val_split) organized["data"]["val_split"] = current_val_split use_val = float(current_val_split) > 0.0 - # -------- trainer defaults -------- organized["trainer"].setdefault("accelerator", self.accelerator) organized["trainer"].setdefault("enable_progress_bar", False) organized["trainer"].setdefault("logger", False) organized["trainer"].setdefault("enable_checkpointing", False) organized["trainer"].setdefault("num_sanity_val_steps", 0) - organized["trainer"].setdefault("precision", 32) + + # FIXED: Default to mixed precision on GPU + if self.accelerator in ("cuda", "gpu"): + organized["trainer"].setdefault("precision", "16-mixed") + else: + organized["trainer"].setdefault("precision", 32) + if not use_val: organized["trainer"].setdefault("limit_val_batches", 0) - if world_size > 1: - organized["trainer"].setdefault("devices", world_size) - organized["trainer"].setdefault("strategy", "ddp") # string to allow factory + world_size_env = int(os.environ.get("WORLD_SIZE", "1")) + if "devices" not 
in organized["trainer"]: + # When torchrun is active, devices must match world_size + organized["trainer"]["devices"] = world_size_env if world_size_env > 1 else 1 + + devices_cfg = organized["trainer"].get("devices", 1) + if isinstance(devices_cfg, int): + devices = devices_cfg + elif isinstance(devices_cfg, (list, tuple)): + devices = len(devices_cfg) else: - organized["trainer"]["devices"] = 1 - organized["trainer"].setdefault("strategy", "auto") - organized["trainer"].setdefault("plugins", [LightningEnvironment()]) + devices = 1 + + # Validate: if torchrun sets WORLD_SIZE > 1, devices must match + if world_size_env > 1 and devices != world_size_env: + if _is_main_process(): + print(f"[WARNING] torchrun WORLD_SIZE={world_size_env} but devices={devices}. " + f"Overriding devices to {world_size_env}.") + devices = world_size_env + organized["trainer"]["devices"] = devices + + # Track for prediction strategy + self._trained_devices = devices + self._trained_with_ddp = devices > 1 + + if "strategy" not in organized["trainer"]: + if devices > 1 or world_size_env > 1: + from datetime import timedelta + # Check if we're under torchrun (process group may already exist) + if world_size_env > 1: + # torchrun case: let Lightning use existing process group + organized["trainer"]["strategy"] = "ddp" + else: + # Lightning-spawned DDP case + organized["trainer"]["strategy"] = DDPStrategy( + process_group_backend="nccl" if torch.cuda.is_available() else "gloo", + find_unused_parameters=False, + broadcast_buffers=False, + timeout=timedelta(minutes=30), + ) + else: + organized["trainer"]["strategy"] = "auto" + + if ( + organized["trainer"].get("strategy") in ("auto", None) + and organized["trainer"].get("devices", 1) == 1 + and world_size_env == 1 # Not under torchrun + and "plugins" not in organized["trainer"] + ): + organized["trainer"]["plugins"] = [LightningEnvironment()] - # Helper to safely set defaults if the key is permitted for that category def maybe_add(cat, k, 
default): if k in self.acceptable_kwargs[cat]: organized[cat][k] = organized[cat].get(k, default) - # -------- model defaults -------- maybe_add("model", "learning_rate", self.default_learning_rate) maybe_add("model", "context_dim", self.context_dim) maybe_add("model", "x_dim", self.x_dim) @@ -340,40 +443,37 @@ def maybe_add(cat, k, default): if organized["model"].get("num_archetypes", 1) == 0: organized["model"].pop("num_archetypes", None) - # -------- data defaults (per-loader sizes) -------- maybe_add("data", "train_batch_size", self.default_train_batch_size) maybe_add("data", "val_batch_size", self.default_val_batch_size) maybe_add("data", "test_batch_size", self.default_test_batch_size) maybe_add("data", "predict_batch_size", self.default_val_batch_size) - maybe_add("data", "num_workers", 0) - maybe_add("data", "pin_memory", self._is_gpu()) - maybe_add( - "data", - "persistent_workers", - organized["data"].get("num_workers", 0) > 0, - ) - maybe_add("data", "drop_last", False) + + # FIXED: Better num_workers default + default_nw = self._default_num_workers(devices) + maybe_add("data", "num_workers", default_nw) + + maybe_add("data", "pin_memory", self.accelerator in ("cuda", "gpu")) + + persistent_default = organized["data"].get("num_workers", 0) > 0 + maybe_add("data", "persistent_workers", persistent_default) + + drop_last_default = devices > 1 + maybe_add("data", "drop_last", drop_last_default) + maybe_add("data", "shuffle_train", True) maybe_add("data", "shuffle_eval", False) maybe_add("data", "dtype", torch.float) - # -------- wrapper defaults -------- maybe_add("wrapper", "n_bootstraps", self.default_n_bootstraps) - # -------- EarlyStopping / Checkpoint constructors -------- - es_monitor = organized["wrapper"].get( - "es_monitor", "val_loss" if use_val else "train_loss" - ) + es_monitor = organized["wrapper"].get("es_monitor", "val_loss" if use_val else "train_loss") es_mode = organized["wrapper"].get("es_mode", "min") - es_patience = 
organized["wrapper"].get( - "es_patience", self.default_es_patience - ) + es_patience = organized["wrapper"].get("es_patience", self.default_es_patience) es_verbose = organized["wrapper"].get("es_verbose", False) es_min_delta = organized["wrapper"].get("es_min_delta", 0.0) cb_ctors = organized["trainer"].get("callback_constructors", []) - # Only add EarlyStopping when there is a val loop AND patience > 0 if use_val and (es_patience is not None and es_patience > 0): cb_ctors.append( lambda i: EarlyStopping( @@ -390,41 +490,30 @@ def maybe_add(cat, k, default): lambda i: ModelCheckpoint( monitor=("val_loss" if use_val else None), dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", - filename=( - "{epoch}-{val_loss:.4f}" if use_val else "{epoch}" - ), + filename=("{epoch}-{val_loss:.4f}" if use_val else "{epoch}"), ) ) organized["trainer"]["callback_constructors"] = cb_ctors - # -------- unknown kw logging -------- for kw in unrecognized: print(f"Received unknown keyword argument {kw}, probably ignoring.") - # -------- sanitize any pre-specified callbacks for no-val runs -------- cb_list = organized["trainer"].get("callbacks", []) - cb_list = [ - self._retarget_or_strip_early_stopping(cb, use_val) for cb in cb_list - ] + cb_list = [self._retarget_or_strip_early_stopping(cb, use_val) for cb in cb_list] organized["trainer"]["callbacks"] = cb_list - # Also sanitize dynamically constructed callbacks ctor_list = organized["trainer"].get("callback_constructors", []) def _wrap_ctor(ctor): def _wrapped(i): cb = ctor(i) return self._retarget_or_strip_early_stopping(cb, use_val) - return _wrapped - organized["trainer"]["callback_constructors"] = [ - _wrap_ctor(c) for c in ctor_list - ] + organized["trainer"]["callback_constructors"] = [_wrap_ctor(c) for c in ctor_list] return organized - # -------------------- data module builder -------------------- def _build_datamodule( self, C: np.ndarray, @@ -444,8 +533,8 @@ def _build_datamodule( 
test_batch_size=self.default_test_batch_size, predict_batch_size=self.default_val_batch_size, num_workers=0, - pin_memory=self._is_gpu(), - persistent_workers=None, # <-- only once + pin_memory=(self.accelerator in ("cuda", "gpu")), + persistent_workers=False, drop_last=False, shuffle_train=True, shuffle_eval=False, @@ -454,10 +543,6 @@ def _build_datamodule( if data_kwargs: dk.update(data_kwargs) - # If not explicitly set, default to True when num_workers > 0 - if dk["persistent_workers"] is None: - dk["persistent_workers"] = bool(dk["num_workers"] > 0) - dm = ContextualizedRegressionDataModule( C=C, X=X, @@ -479,11 +564,8 @@ def _build_datamodule( shuffle_eval=dk["shuffle_eval"], dtype=dk["dtype"], ) - dm.prepare_data() - dm.setup() return dm - # -------------------- split helpers -------------------- def _split_train_data( self, C: np.ndarray, @@ -496,9 +578,7 @@ def _split_train_data( shuffle: bool = True, **_, ): - """ - Return (train_idx, val_idx) over rows; Lightning will attach DistributedSamplers. 
- """ + """Return (train_idx, val_idx) over rows.""" if Y_required and Y is None: raise ValueError("Y is required but was not provided.") n = C.shape[0] @@ -506,6 +586,19 @@ def _split_train_data( if vs <= 0.0: idx = np.arange(n) return idx, None + + # FIXED: Handle small datasets + min_val_samples = max(1, int(n * vs)) + if min_val_samples < 2: + # Too small for validation split + idx = np.arange(n) + return idx, None + + # CRITICAL FIX: Use deterministic random_state for DDP + # All ranks MUST get the same train/val split + if random_state is None: + random_state = 42 # Fixed seed for reproducibility across ranks + tr_idx, va_idx = train_test_split( np.arange(n), test_size=vs, @@ -514,7 +607,6 @@ def _split_train_data( ) return tr_idx, va_idx - # -------------------- optional scaling -------------------- def _maybe_scale_C(self, C: np.ndarray) -> np.ndarray: if self.normalize and self.scalers["C"] is not None: return self.scalers["C"].transform(C) @@ -525,62 +617,88 @@ def _maybe_scale_X(self, X: np.ndarray) -> np.ndarray: return self.scalers["X"].transform(X) return X - # -------------------- public API -------------------- - def predict( - self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs - ): + def _get_inference_device(self) -> torch.device: + """ + Get the device to use for inference. + FIXED: Always use single device for prediction to avoid DDP complexity. + """ + if self.accelerator in ("cuda", "gpu") and torch.cuda.is_available(): + return torch.device("cuda:0") + return torch.device("cpu") + + def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs): + """ + FIXED: Proper single-device inference that works after DDP training. + """ if not hasattr(self, "models") or self.models is None: - raise ValueError( - "Trying to predict with a model that hasn't been trained yet." 
- ) + raise ValueError("Trying to predict with a model that hasn't been trained yet.") Cq = self._maybe_scale_C(C) Xq = self._maybe_scale_X(X) Yq = np.zeros((len(Cq), self.y_dim), dtype=np.float32) + # FIXED: Use single device for inference + device = self._get_inference_device() + + # Build dataloader without distributed sampler + dm = self._build_datamodule( + C=Cq, + X=Xq, + Y=Yq, + predict_idx=np.arange(len(Cq)), + data_kwargs=dict( + train_batch_size=self._init_kwargs["data"].get("train_batch_size", self.default_train_batch_size), + val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), + test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), + predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), + num_workers=0, # Single-threaded for inference simplicity + pin_memory=False, + persistent_workers=False, + shuffle_train=False, + shuffle_eval=False, + dtype=self._init_kwargs["data"].get("dtype", torch.float), + ), + task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) + else "singletask_multivariate", + ) + + # Setup the datamodule + dm.setup(stage="predict") + pred_loader = dm.predict_dataloader() + preds = [] for i in range(len(self.models)): - dm = self._build_datamodule( - C=Cq, - X=Xq, - Y=Yq, - predict_idx=np.arange(len(Cq)), - data_kwargs=dict( - train_batch_size=self._init_kwargs["data"].get( - "train_batch_size", self.default_train_batch_size - ), - val_batch_size=self._init_kwargs["data"].get( - "val_batch_size", self.default_val_batch_size - ), - test_batch_size=self._init_kwargs["data"].get( - "test_batch_size", self.default_test_batch_size - ), - predict_batch_size=self._init_kwargs["data"].get( - "predict_batch_size", self.default_val_batch_size - ), - num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get( - "pin_memory", self._is_gpu() - 
), - persistent_workers=self._init_kwargs["data"].get( - "persistent_workers", False - ), - shuffle_train=False, - shuffle_eval=False, - dtype=self._init_kwargs["data"].get("dtype", torch.float), - ), - task_type="singletask_univariate" - if self._init_kwargs["model"].get("univariate", False) - else "singletask_multivariate", - ) - yhat = self.trainers[i].predict_y( - self.models[i], dm.predict_dataloader(), **kwargs - ) + model = self.models[i] + model.eval() + model.to(device) + + out_batches = [] + with torch.no_grad(): + for b_idx, batch in enumerate(pred_loader): + # Move batch to device + batch = { + k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) + for k, v in batch.items() + } + + out = model.predict_step(batch, b_idx) + + Cb = out.get("contexts") + Xb = out.get("predictors") + betas = out["betas"] + mus = out["mus"] + + yb = model._predict_y(Cb, Xb, betas, mus) + out_batches.append(yb.detach().cpu()) + + yhat = torch.cat(out_batches, dim=0).numpy() preds.append(yhat) predictions = np.array(preds) + if not individual_preds: predictions = np.mean(predictions, axis=0) + if self.normalize and self.scalers["Y"] is not None: if individual_preds: predictions = np.array( @@ -588,6 +706,7 @@ def predict( ) else: predictions = self.scalers["Y"].inverse_transform(predictions) + return predictions def predict_params( @@ -597,67 +716,81 @@ def predict_params( model_includes_mus: bool = True, **kwargs, ): + """ + FIXED: Proper single-device inference for parameter prediction. + """ if not hasattr(self, "models") or self.models is None: - raise ValueError( - "Trying to predict with a model that hasn't been trained yet." 
- ) + raise ValueError("Trying to predict with a model that hasn't been trained yet.") Cq = self._maybe_scale_C(C) X_zero = np.zeros((len(Cq), self.x_dim), dtype=np.float32) Y_zero = np.zeros((len(Cq), self.y_dim), dtype=np.float32) + uses_y = kwargs.pop("uses_y", True) + device = self._get_inference_device() + + dm = self._build_datamodule( + C=Cq, + X=X_zero, + Y=Y_zero if uses_y else None, + predict_idx=np.arange(len(Cq)), + data_kwargs=dict( + train_batch_size=self._init_kwargs["data"].get("train_batch_size", self.default_train_batch_size), + val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), + test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), + predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), + num_workers=0, + pin_memory=False, + persistent_workers=False, + shuffle_train=False, + shuffle_eval=False, + dtype=self._init_kwargs["data"].get("dtype", torch.float), + ), + task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) + else "singletask_multivariate", + ) + + dm.setup(stage="predict") + pred_loader = dm.predict_dataloader() + out_betas, out_mus = [], [] + for i in range(len(self.models)): - dm = self._build_datamodule( - C=Cq, - X=X_zero, - Y=Y_zero if kwargs.pop("uses_y", True) else None, - predict_idx=np.arange(len(Cq)), - data_kwargs=dict( - train_batch_size=self._init_kwargs["data"].get( - "train_batch_size", self.default_train_batch_size - ), - val_batch_size=self._init_kwargs["data"].get( - "val_batch_size", self.default_val_batch_size - ), - test_batch_size=self._init_kwargs["data"].get( - "test_batch_size", self.default_test_batch_size - ), - predict_batch_size=self._init_kwargs["data"].get( - "predict_batch_size", self.default_val_batch_size - ), - num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get( - "pin_memory", 
self._is_gpu() - ), - persistent_workers=self._init_kwargs["data"].get( - "persistent_workers", False - ), - shuffle_train=False, - shuffle_eval=False, - dtype=self._init_kwargs["data"].get("dtype", torch.float), - ), - task_type="singletask_univariate" - if self._init_kwargs["model"].get("univariate", False) - else "singletask_multivariate", - ) - pred = self.trainers[i].predict_params( - self.models[i], dm.predict_dataloader(), **kwargs - ) + model = self.models[i] + model.eval() + model.to(device) + + beta_batches, mu_batches = [], [] + + with torch.no_grad(): + for b_idx, batch in enumerate(pred_loader): + batch = { + k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) + for k, v in batch.items() + } + out = model.predict_step(batch, b_idx) + + betas_b = out["betas"].detach().cpu() + beta_batches.append(betas_b) + + if model_includes_mus: + mus_b = out["mus"].detach().cpu() + mu_batches.append(mus_b) + + betas_i = torch.cat(beta_batches, dim=0).numpy() if model_includes_mus: - out_betas.append(pred[0]) - out_mus.append(pred[1]) + mus_i = torch.cat(mu_batches, dim=0).numpy() + out_betas.append(betas_i) + out_mus.append(mus_i) else: - out_betas.append(pred) + out_betas.append(betas_i) if model_includes_mus: betas = np.array(out_betas) mus = np.array(out_mus) - return ( - (betas, mus) - if individual_preds - else (np.mean(betas, axis=0), np.mean(mus, axis=0)) - ) + if individual_preds: + return betas, mus + return np.mean(betas, axis=0), np.mean(mus, axis=0) else: betas = np.array(out_betas) return betas if individual_preds else np.mean(betas, axis=0) @@ -665,15 +798,11 @@ def predict_params( def fit(self, *args, **kwargs) -> None: """ Fit contextualized model to data. - - Accepts either: - - (C, X, Y) [canonical order], OR - - (X, Y, C) [README order], OR - - kw-only: C=..., X=..., (Y=...) + FIXED: Proper DDP handling and device tracking. 
""" self.models, self.trainers = [], [] - # normalize argument order + # Normalize argument order C_in = kwargs.pop("C", None) X_in = kwargs.pop("X", None) Y_in = kwargs.pop("Y", None) @@ -689,21 +818,14 @@ def fit(self, *args, **kwargs) -> None: else: C, X, Y = A, B, Carg else: - raise ValueError( - "Mismatched sample counts among provided arrays." - ) + raise ValueError("Mismatched sample counts among provided arrays.") elif len(args) == 2: A, B = args if A.shape[0] != B.shape[0]: - raise ValueError( - "Mismatched sample counts for two-argument fit." - ) - # Assume (C, X) by default + raise ValueError("Mismatched sample counts for two-argument fit.") C, X, Y = A, B, None else: - raise ValueError( - "fit expects (C,X[,Y]) or (X,Y,C) or kw-only C=..., X=..." - ) + raise ValueError("fit expects (C,X[,Y]) or (X,Y,C) or kw-only C=..., X=...") # Optional scaling if self.normalize: @@ -731,65 +853,42 @@ def fit(self, *args, **kwargs) -> None: args = (C, X) organized = self._organize_and_expand_fit_kwargs(**kwargs) - self.n_bootstraps = organized["wrapper"].get( - "n_bootstraps", self.n_bootstraps - ) + self.n_bootstraps = organized["wrapper"].get("n_bootstraps", self.n_bootstraps) n = C.shape[0] val_split = organized["data"].get("val_split", self.default_val_split) use_val = val_split > 0.0 for b in range(self.n_bootstraps): - # Model (LightningModule) + # Model _model_kwargs = dict(organized["model"]) - _model_kwargs.pop("univariate", None) # handled via task_type below + _model_kwargs.pop("univariate", None) model = self.base_constructor(**_model_kwargs) self.model_ = model # Indices train_idx, val_idx = self._split_train_data( - C, - X, - (args[2] if len(args) == 3 else None), + C, X, (args[2] if len(args) == 3 else None), Y_required=(len(args) == 3), val_split=val_split, ) + print(f"[RANK {os.environ.get('RANK', 0)}] train_idx[:5]={train_idx[:5]}, val_idx[:5]={val_idx[:5] if val_idx is not None else None}") + test_idx = None # DataModule - task_type = ( - 
"singletask_univariate" - if organized["model"].get("univariate", False) - else "singletask_multivariate" - ) + task_type = "singletask_univariate" if organized["model"].get("univariate", False) else "singletask_multivariate" dm = self._build_datamodule( - C=args[0], - X=args[1], - Y=(args[2] if len(args) == 3 else None), - train_idx=train_idx, - val_idx=val_idx, - test_idx=test_idx, + C=args[0], X=args[1], Y=(args[2] if len(args) == 3 else None), + train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, data_kwargs=dict( - train_batch_size=organized["data"].get( - "train_batch_size", self.default_train_batch_size - ), - val_batch_size=organized["data"].get( - "val_batch_size", self.default_val_batch_size - ), - test_batch_size=organized["data"].get( - "test_batch_size", self.default_test_batch_size - ), - predict_batch_size=organized["data"].get( - "predict_batch_size", self.default_val_batch_size - ), + train_batch_size=organized["data"].get("train_batch_size", self.default_train_batch_size), + val_batch_size=organized["data"].get("val_batch_size", self.default_val_batch_size), + test_batch_size=organized["data"].get("test_batch_size", self.default_test_batch_size), + predict_batch_size=organized["data"].get("predict_batch_size", self.default_val_batch_size), num_workers=organized["data"].get("num_workers", 0), - pin_memory=organized["data"].get( - "pin_memory", self._is_gpu() - ), - persistent_workers=organized["data"].get( - "persistent_workers", - organized["data"].get("num_workers", 0) > 0, - ), + pin_memory=organized["data"].get("pin_memory", self.accelerator in ("cuda", "gpu")), + persistent_workers=organized["data"].get("persistent_workers", False), drop_last=organized["data"].get("drop_last", False), shuffle_train=organized["data"].get("shuffle_train", True), shuffle_eval=organized["data"].get("shuffle_eval", False), @@ -798,58 +897,38 @@ def fit(self, *args, **kwargs) -> None: task_type=task_type, ) - # Trainer (fresh callbacks) + # Trainer 
trainer_kwargs = copy.deepcopy(organized["trainer"]) - trainer_kwargs["callbacks"] = [ - f(b) for f in trainer_kwargs.get("callback_constructors", []) - ] + trainer_kwargs["callbacks"] = [f(b) for f in trainer_kwargs.get("callback_constructors", [])] trainer_kwargs.pop("callback_constructors", None) - # Build via factory (respects strategy strings and env) from contextualized.regression.trainers import make_trainer_with_env - trainer = make_trainer_with_env( self.trainer_constructor, **trainer_kwargs, ) - # Ensure checkpoint dir if used for cb in trainer_kwargs.get("callbacks", []): if isinstance(cb, ModelCheckpoint): os.makedirs(cb.dirpath, exist_ok=True) - # Fit (omit val loader if no val split) - if use_val and dm.val_dataloader() is not None: - trainer.fit( - model, - train_dataloaders=dm.train_dataloader(), - val_dataloaders=dm.val_dataloader(), - **organized["fit"], - ) - else: - trainer.fit( - model, - train_dataloaders=dm.train_dataloader(), - **organized["fit"], - ) + # Ensure all ranks have setup data before training + if torch.cuda.is_available(): + torch.cuda.synchronize() + + # Fit + trainer.fit( + model, + datamodule=dm, + **organized["fit"], + ) # Load best checkpoint if enabled if trainer_kwargs.get("enable_checkpointing", False): - ckpt_cb = next( - ( - cb - for cb in trainer.callbacks - if isinstance(cb, ModelCheckpoint) - ), - None, - ) - if ( - ckpt_cb - and ckpt_cb.best_model_path - and os.path.exists(ckpt_cb.best_model_path) - ): + ckpt_cb = next((cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), None) + if ckpt_cb and ckpt_cb.best_model_path and os.path.exists(ckpt_cb.best_model_path): best = torch.load(ckpt_cb.best_model_path, map_location="cpu") model.load_state_dict(best["state_dict"]) self.models.append(model) - self.trainers.append(trainer) + self.trainers.append(trainer) \ No newline at end of file diff --git a/contextualized/modules.py b/contextualized/modules.py index e3e87ba3..596d033b 100644 --- 
a/contextualized/modules.py +++ b/contextualized/modules.py @@ -182,4 +182,4 @@ def forward(self, X): return self.linear(X) -ENCODERS = {"mlp": MLP, "ngam": NGAM, "linear": Linear} +ENCODERS = {"mlp": MLP, "ngam": NGAM, "linear": Linear} \ No newline at end of file diff --git a/contextualized/regression/__init__.py b/contextualized/regression/__init__.py index a2fb4e91..9e8fd308 100644 --- a/contextualized/regression/__init__.py +++ b/contextualized/regression/__init__.py @@ -72,4 +72,4 @@ # trainers "RegressionTrainer", "TRAINERS", -] +] \ No newline at end of file diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py index 8aa43f5f..82539886 100644 --- a/contextualized/regression/datamodules.py +++ b/contextualized/regression/datamodules.py @@ -170,14 +170,15 @@ def _mk_dataset(idx: IndexLike): ds_cls = TASK_TO_DATASET[self.task_type] if Y_s is None: - raise ValueError( - f"Y is required for regression task_type='{self.task_type}'. " - "Pass a real Y array matching your task." - ) + # Allow unsupervised / network-style usage where Y is omitted. + # In that case, use X as a dummy target so shapes line up. + # This mirrors the old CorrelationDataModule behavior (Y = X). 
+ Y_s = X_s return ds_cls(C_s, X_s, Y_s, dtype=self.dtype) + self.ds_train = _mk_dataset(self.train_idx) self.ds_val = _mk_dataset(self.val_idx) self.ds_test = _mk_dataset(self.test_idx) @@ -233,4 +234,3 @@ def predict_dataloader(self) -> DataLoader: shuffle=False, **self._common_dl_kwargs(self.predict_batch_size), ) - diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index fa7300d9..2f74eaea 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -276,8 +276,111 @@ def test_step(self, batch, batch_idx): return loss def _predict_from_models(self, X, beta_hat, mu_hat): - # fused reduction + keepdim avoids extra unsqueeze - return self.link_fn((beta_hat * X).sum(dim=-1, keepdim=True) + mu_hat) + """ + Make shapes consistent before computing: + y = g( (beta ⊙ X).sum(-1, keepdim=True) + mu ) + + Expected canonical shapes: + - beta_hat: (B, y_dim, x_dim) + - mu_hat: (B, y_dim, 1) or (B, y_dim) + - X: one of + * (B, x_dim) + * (B, 1, x_dim) + * (B, y_dim, x_dim) + + We also accept beta_hat/mu_hat with an extra trailing singleton dim: + * (B, y_dim, x_dim, 1) -> squeeze to (B, y_dim, x_dim) + """ + + # ---- Normalize beta_hat to (B, y_dim, x_dim) ---- + if not isinstance(beta_hat, torch.Tensor): + raise RuntimeError( + f"beta_hat must be a tensor, got {type(beta_hat)}" + ) + + # Handle univariate case where shape is (B, y, x, 1) + if beta_hat.dim() == 4 and beta_hat.shape[-1] == 1: + beta_hat = beta_hat.squeeze(-1) + + if beta_hat.dim() != 3: + raise RuntimeError( + f"_predict_from_models expects beta_hat with shape (B, y, x) " + f"or (B, y, x, 1); got {beta_hat.shape}" + ) + + B, y_dim, x_dim = beta_hat.shape + + # ---- Move and normalize X ---- + if not isinstance(X, torch.Tensor): + X = torch.as_tensor(X, device=beta_hat.device, dtype=beta_hat.dtype) + else: + X = X.to(device=beta_hat.device, dtype=beta_hat.dtype) + + if X.dim() == 2: + # (B, x_dim) -> 
broadcast over y_dim + if X.shape[0] != B: + raise RuntimeError( + f"X batch dim {X.shape[0]} != beta_hat batch dim {B}. " + f"X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" + ) + if X.shape[1] != x_dim: + raise RuntimeError( + f"X feature dim {X.shape[1]} != x_dim {x_dim}. " + f"X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" + ) + X = X.unsqueeze(1).expand(-1, y_dim, -1) + + elif X.dim() == 3: + if X.shape[0] != B: + raise RuntimeError( + f"X batch dim {X.shape[0]} != beta_hat batch dim {B}. " + f"X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" + ) + + if X.shape[1] == y_dim and X.shape[2] == x_dim: + pass # already good + elif X.shape[1] == 1 and X.shape[2] == x_dim: + X = X.expand(-1, y_dim, -1) # (B,1,x) -> (B,y,x) + elif X.shape[1] == x_dim and X.shape[2] == y_dim and x_dim == y_dim: + X = X.permute(0, 2, 1) # (B,x,y) -> (B,y,x) + else: + raise RuntimeError( + f"Unexpected X shape {X.shape} for beta_hat {beta_hat.shape}. " + "Cannot safely align dimensions." + ) + else: + raise RuntimeError( + f"Unsupported X.ndim={X.dim()} for _predict_from_models; " + f"expected 2 or 3. 
X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" + ) + + # ---- Normalize mu_hat to broadcast correctly ---- + if not isinstance(mu_hat, torch.Tensor): + mu_hat = torch.as_tensor(mu_hat, device=beta_hat.device, dtype=beta_hat.dtype) + else: + mu_hat = mu_hat.to(device=beta_hat.device, dtype=beta_hat.dtype) + + # Handle univariate case where mu_hat is (B, y, x, 1) + if mu_hat.dim() == 4 and mu_hat.shape[-1] == 1: + mu_hat = mu_hat.squeeze(-1) + + if mu_hat.dim() == 2: + # (B, y_dim) -> (B, y_dim, 1) + mu_hat = mu_hat.unsqueeze(-1) + elif mu_hat.dim() == 3: + # assume already (B, y_dim, 1) or (B, y_dim, x_dim) + pass + else: + raise RuntimeError( + f"Unsupported mu_hat.ndim={mu_hat.dim()} in _predict_from_models; " + f"mu_hat.shape={mu_hat.shape}" + ) + + out = (beta_hat * X).sum(dim=-1, keepdim=True) + mu_hat + return self.link_fn(out) + + + def _predict_y(self, C, X, beta_hat, mu_hat): @@ -1345,4 +1448,3 @@ def predict_step(self, batch, batch_idx): "mus": mu_hat, }) return batch - diff --git a/contextualized/regression/metamodels.py b/contextualized/regression/metamodels.py index feca94af..5df95120 100644 --- a/contextualized/regression/metamodels.py +++ b/contextualized/regression/metamodels.py @@ -274,4 +274,4 @@ def forward(self, C, T): MULTITASK_METAMODELS = { "multitask": MultitaskMetamodel, "tasksplit": TasksplitMetamodel, -} +} \ No newline at end of file diff --git a/contextualized/regression/regularizers.py b/contextualized/regression/regularizers.py index 5ba15648..eee1f904 100644 --- a/contextualized/regression/regularizers.py +++ b/contextualized/regression/regularizers.py @@ -78,4 +78,4 @@ def l1_l2_reg(alpha, l1_ratio=0.5, mu_ratio=0.5): return partial(l1_l2_reg_fn, alpha, l1_ratio, mu_ratio) -REGULARIZERS = {"none": no_reg(), "l1": l1_reg, "l2": l2_reg, "l1_l2": l1_l2_reg} +REGULARIZERS = {"none": no_reg(), "l1": l1_reg, "l2": l2_reg, "l1_l2": l1_l2_reg} \ No newline at end of file diff --git a/contextualized/regression/trainers.py 
b/contextualized/regression/trainers.py index 04197fae..da96e481 100644 --- a/contextualized/regression/trainers.py +++ b/contextualized/regression/trainers.py @@ -154,58 +154,25 @@ def predict_precision(self, model: pl.LightningModule, dataloader) -> np.ndarray -# ADD THIS FACTORY (end of file) - -# at top of file you already have: -# from pytorch_lightning.plugins.environments import LightningEnvironment - -from contextualized.utils.engine import pick_engine - - -from pytorch_lightning.strategies import DDPStrategy, Strategy as PLStrategy - -def make_trainer_with_env(trainer_cls=RegressionTrainer, **kwargs) -> pl.Trainer: - # Respect explicit user settings; otherwise auto-pick - accelerator = kwargs.pop("accelerator", None) - devices = kwargs.pop("devices", None) - strategy = kwargs.pop("strategy", None) - plugins = kwargs.pop("plugins", None) - - # If caller provided a concrete Strategy instance, pass it through verbatim - if isinstance(strategy, PLStrategy): - return trainer_cls( - accelerator=("cpu" if accelerator is None else accelerator), - devices=(1 if devices is None else devices), - strategy=strategy, - plugins=plugins, - **kwargs, - ) - - # Otherwise, select engines automatically - accelerator, devices, strategy_name = pick_engine( - accelerator=accelerator, - devices=devices, - strategy=strategy, # may be "ddp" or "auto" - prefer_spawn=True, # allows plain `python script.py` to use all GPUs - ) - - # Upgrade "ddp" string to tuned DDPStrategy - if strategy_name == "ddp": - strategy_obj = DDPStrategy( - find_unused_parameters=False, - static_graph=True, - gradient_as_bucket_view=True, - ) - else: - strategy_obj = strategy_name # "auto" or other strings - - if plugins is None and accelerator == "cpu": - plugins = [LightningEnvironment()] - - return trainer_cls( - accelerator=accelerator, - devices=devices, - strategy=strategy_obj, - plugins=plugins, - **kwargs, - ) +def choose_lightning_environment() -> LightningEnvironment: + # If you have a custom 
Environment subclass, wire it here. + # Otherwise, the default LightningEnvironment is fine. + return LightningEnvironment() + +def make_trainer_with_env(trainer_cls, **trainer_kwargs): + """ + Factory that respects caller-provided `devices` and `strategy`. + FIXED: Don't inject LightningEnvironment when torchrun is managing processes. + """ + import os + + # Check if we're under torchrun (WORLD_SIZE > 1 means torchrun is managing) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + + # Only inject LightningEnvironment for single-process runs + # When torchrun is active, Lightning will auto-detect TorchElasticEnvironment + if "plugins" not in trainer_kwargs and world_size == 1: + env = choose_lightning_environment() + trainer_kwargs["plugins"] = [env] + + return trainer_cls(**trainer_kwargs) \ No newline at end of file diff --git a/contextualized/tests.py b/contextualized/tests.py index 1e4e8718..0bff28d0 100644 --- a/contextualized/tests.py +++ b/contextualized/tests.py @@ -203,4 +203,4 @@ def test_save_load(self): if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file diff --git a/contextualized/utils/__init__.py b/contextualized/utils/__init__.py index e69de29b..00e7196b 100644 --- a/contextualized/utils/__init__.py +++ b/contextualized/utils/__init__.py @@ -0,0 +1,47 @@ +""" +Utility functions and simple helper predictors used across the library, +including saving/loading of contextualized models. +""" + +from __future__ import annotations + +import torch + + +def save(model, path: str) -> None: + """Save a model object to disk.""" + with open(path, "wb") as out_file: + torch.save(model, out_file) + + +def load(path: str): + """Load a model object from disk.""" + with open(path, "rb") as in_file: + # Newer torch supports weights_only; older versions do not. 
+ try: + return torch.load(in_file, weights_only=False) + except TypeError: + return torch.load(in_file) + + +class DummyParamPredictor: + """Predicts parameters as all zeros (for unit tests / baselines).""" + + def __init__(self, beta_dim, mu_dim): + self.beta_dim = beta_dim + self.mu_dim = mu_dim + + def predict_params(self, *args): + n = len(args[0]) + return torch.zeros((n, *self.beta_dim)), torch.zeros((n, *self.mu_dim)) + + +class DummyYPredictor: + """Predicts Y values as all zeros (for unit tests / baselines).""" + + def __init__(self, y_dim): + self.y_dim = y_dim + + def predict_y(self, *args): + n = len(args[0]) + return torch.zeros((n, *self.y_dim)) diff --git a/contextualized/utils/engine.py b/contextualized/utils/engine.py index 225de18e..ebabb468 100644 --- a/contextualized/utils/engine.py +++ b/contextualized/utils/engine.py @@ -37,4 +37,4 @@ def pick_engine( if _under_torchrun(): return "gpu", 1, "ddp" # one proc per GPU (torchrun sets ranks) - return "gpu", ngpu, ("ddp_spawn" if prefer_spawn else "auto") + return "gpu", ngpu, ("ddp_spawn" if prefer_spawn else "auto") \ No newline at end of file diff --git a/cpu_scale_bench.py b/cpu_scale_bench.py deleted file mode 100644 index 5d835f82..00000000 --- a/cpu_scale_bench.py +++ /dev/null @@ -1,448 +0,0 @@ -#!/usr/bin/env python3 -import os, time, csv, argparse, math, json -from dataclasses import dataclass -from typing import List, Dict -from datetime import timedelta - -import numpy as np -import torch -import pytorch_lightning as pl -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.strategies import DDPStrategy - -# ---- your package pieces ---- -from contextualized.regression import ContextualizedRegression -from contextualized.regression.datamodules import ContextualizedRegressionDataModule - - -# ---------------- launcher/cluster helpers ---------------- -def under_torchrun() -> bool: - e = os.environ - return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) - 
-def world_size() -> int: - try: - return int(os.environ.get("WORLD_SIZE", "1")) - except Exception: - return 1 - -def is_global_zero() -> bool: - return int(os.environ.get("RANK", "0")) == 0 - - -# ---------------- env + perf ---------------- -def set_env_defaults(): - os.environ.setdefault("OMP_NUM_THREADS", "1") - os.environ.setdefault("MKL_NUM_THREADS", "1") - os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - - # Safer NCCL defaults on cloud single node - os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") - os.environ.setdefault("NCCL_DEBUG", "WARN") - os.environ.setdefault("NCCL_P2P_DISABLE", "0") - os.environ.setdefault("NCCL_IB_DISABLE", "1") # IB usually unavailable on single-node Lambda - - # Pick an interface if not set - if "NCCL_SOCKET_IFNAME" not in os.environ: - try: - ifaces = [d for d in os.listdir("/sys/class/net") if os.path.isdir(f"/sys/class/net/{d}")] - cand = next((i for i in ifaces if i not in ("lo", "docker0")), None) - os.environ["NCCL_SOCKET_IFNAME"] = cand or "lo" - except Exception: - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - - # Rendezvous (used only by ddp_spawn mode) - os.environ.setdefault("MASTER_ADDR", "127.0.0.1") - os.environ.setdefault("MASTER_PORT", str(12355 + (os.getpid() % 20000))) - - if is_global_zero(): - keys = ["NCCL_DEBUG","NCCL_IB_DISABLE","NCCL_P2P_DISABLE","NCCL_SOCKET_IFNAME","MASTER_ADDR","MASTER_PORT"] - print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) - - # Ampere+ matmul speedups - try: - torch.set_float32_matmul_precision("high") - except Exception: - pass - - -def map_precision(p): - p = (p or "").lower() - if p in ("bf16", "bfloat16", "bf16-mixed"): - return "bf16-mixed" - if p in ("fp16", "16", "16-mixed"): - return "16-mixed" - return 32 # full precision - - -class EpochTimer(Callback): - def __init__(self): - self._epoch_start = None - self.epoch_times = [] - - @staticmethod - def _using_cuda(trainer) -> bool: - try: - return trainer.accelerator is not None and "cuda" in 
str(trainer.accelerator).lower() - except Exception: - return torch.cuda.is_available() - - def on_train_epoch_start(self, trainer, pl_module): - if self._using_cuda(trainer): - torch.cuda.synchronize() - self._epoch_start = time.time() - - def on_train_epoch_end(self, trainer, pl_module): - if self._using_cuda(trainer): - torch.cuda.synchronize() - self.epoch_times.append(time.time() - self._epoch_start) - - -# ---------------- synthetic data ---------------- -def make_synthetic(n, c_dim, x_dim, y_dim, seed=42): - rng = np.random.default_rng(seed) - C = rng.standard_normal((n, c_dim)).astype(np.float32) - X = rng.standard_normal((n, x_dim)).astype(np.float32) - W = rng.standard_normal((y_dim, x_dim)).astype(np.float32) - MU = rng.standard_normal((y_dim, 1)).astype(np.float32) - Y = (X @ W.T) + MU.squeeze(-1) + 0.01 * rng.standard_normal((n, y_dim)).astype(np.float32) - return C, X, Y - - -def load_or_make_dataset(path, n, c_dim, x_dim, y_dim, seed=42): - if path and os.path.exists(path): - npz = np.load(path) - C, X, Y = npz["C"], npz["X"], npz["Y"] - return C, X, Y - C, X, Y = make_synthetic(n, c_dim, x_dim, y_dim, seed=seed) - if path: - os.makedirs(os.path.dirname(path), exist_ok=True) - np.savez_compressed(path, C=C, X=X, Y=Y) - return C, X, Y - - -# ---------------- model/trainer builders ---------------- -def build_model(c_dim, x_dim, y_dim, width, layers, lr): - model = ContextualizedRegression( - context_dim=c_dim, - x_dim=x_dim, - y_dim=y_dim, - num_archetypes=8, - encoder_type="mlp", - encoder_kwargs={"width": width, "layers": layers, "link_fn": "identity"}, - learning_rate=lr, - fit_intercept=True, - link_fn="identity", - loss_fn="mse", - model_regularizer="none", - ) - return model - - -def build_dm( - C, X, Y, - train_batch_size: int, - num_workers: int, - pin_memory: bool, -): - n = C.shape[0] - perm = np.random.permutation(n) - n_train = int(0.9 * n) - train_idx = perm[:n_train] - val_idx = perm[n_train:] - - dm = ContextualizedRegressionDataModule( 
- C=C, X=X, Y=Y, - task_type="singletask_multivariate", - train_idx=train_idx, - val_idx=val_idx, - test_idx=None, - predict_idx=None, - train_batch_size=train_batch_size, - val_batch_size=train_batch_size, - test_batch_size=train_batch_size, - predict_batch_size=train_batch_size, - num_workers=num_workers, - pin_memory=bool(pin_memory), - persistent_workers=bool(num_workers > 0), - drop_last=True, - shuffle_train=True, - shuffle_eval=False, - dtype=torch.float, - ) - dm.prepare_data(); dm.setup() - return dm - - -def build_trainer(devices, precision, epochs, ddp_timeout_s=120, torchrun_mode=False): - """ - devices: - - 0 => cpu - - >=1 => number of devices this process should report to Lightning - - torchrun_mode: - - True => launched via torchrun; use DDP with devices = WORLD_SIZE, - no spawn. Satisfies Lightning's validation. - """ - timer = EpochTimer() - - if devices == 0: - accelerator = "cpu" - devices_arg = 1 - strategy = "auto" - else: - accelerator = "gpu" - if torchrun_mode: - ws = world_size() - devices_arg = ws # must equal WORLD_SIZE - strategy = DDPStrategy( - find_unused_parameters=False, - gradient_as_bucket_view=True, - static_graph=True, - timeout=timedelta(seconds=ddp_timeout_s), - ) - else: - devices_arg = devices - strategy = "auto" if devices == 1 else DDPStrategy( - start_method="spawn", - find_unused_parameters=False, - gradient_as_bucket_view=True, - static_graph=True, - timeout=timedelta(seconds=ddp_timeout_s), - ) - - trainer = pl.Trainer( - accelerator=accelerator, - devices=devices_arg, - strategy=strategy, - precision=precision, - max_epochs=epochs, - logger=False, - enable_checkpointing=False, - num_sanity_val_steps=0, - enable_progress_bar=False, - log_every_n_steps=50, - callbacks=[timer], - inference_mode=False, - detect_anomaly=False, - ) - return trainer, timer - - -# ---------------- benchmark runner ---------------- -@dataclass -class BenchCfg: - label: str - devices: int # 0 for cpu, >=1 for gpus - - -def run_once(cfg: 
BenchCfg, C, X, Y, args, torchrun_mode: bool) -> Dict: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - pin = (cfg.devices >= 1) - dm = build_dm( - C, X, Y, - train_batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=pin, - ) - model = build_model(args.context_dim, args.x_dim, args.y_dim, - args.width, args.layers, args.lr) - - # Warm-up - tiny = max(1024, math.ceil(0.01 * C.shape[0])) - dm_warm = build_dm( - C[:tiny], X[:tiny], Y[:tiny], - train_batch_size=args.batch_size, - num_workers=0, - pin_memory=pin, - ) - warm_trainer, _ = build_trainer( - devices=(world_size() if torchrun_mode else cfg.devices), - precision=map_precision(args.precision), - epochs=1, - ddp_timeout_s=args.ddp_timeout, - torchrun_mode=torchrun_mode, - ) - warm_trainer.fit(model, train_dataloaders=dm_warm.train_dataloader()) - - # Timed run - trainer, timer = build_trainer( - devices=(world_size() if torchrun_mode else cfg.devices), - precision=map_precision(args.precision), - epochs=args.epochs, - ddp_timeout_s=args.ddp_timeout, - torchrun_mode=torchrun_mode, - ) - - if torch.cuda.is_available(): - torch.cuda.synchronize() - t0 = time.time() - trainer.fit(model, train_dataloaders=dm.train_dataloader()) - if torch.cuda.is_available(): - torch.cuda.synchronize() - wall = time.time() - t0 - - train_samples = len(dm.train_dataloader().dataset) - samples_total = train_samples * args.epochs - throughput = samples_total / max(wall, 1e-9) - - # devices_for_metric: report 1 for CPU so it's easy to compare "per-device" - world = (world_size() if torchrun_mode else (cfg.devices if cfg.devices > 0 else 1)) - per_device = throughput / max(world, 1) - - res = dict( - label=cfg.label, - devices=(world_size() if torchrun_mode else (cfg.devices if cfg.devices > 0 else 1)), - wall_seconds=wall, - samples_total=int(samples_total), - throughput_samples_per_s=throughput, - per_device_throughput=per_device, - steps_per_epoch=math.ceil(train_samples / args.batch_size), - 
samples_per_epoch=int(train_samples), - epoch_times=timer.epoch_times[:], - ) - if is_global_zero(): - print(json.dumps({ - "label": res["label"], - "devices": res["devices"], - "wall_s": round(res["wall_seconds"], 3), - "throughput_sps": round(res["throughput_samples_per_s"], 2), - "per_device_sps": round(res["per_device_throughput"], 2), - "avg_epoch_s": round(float(np.mean(res["epoch_times"])) if res["epoch_times"] else float("nan"), 3) - }, indent=2)) - return res - - -def save_csv(rows: List[Dict], outdir: str): - os.makedirs(outdir, exist_ok=True) - path = os.path.join(outdir, "scale_results.csv") - fields = ["label","devices","wall_seconds","samples_total", - "throughput_samples_per_s","per_device_throughput", - "steps_per_epoch","samples_per_epoch","epoch_times"] - with open(path, "w", newline="") as f: - w = csv.DictWriter(f, fieldnames=fields) - w.writeheader() - for r in rows: - r2 = r.copy() - r2["epoch_times"] = ";".join(f"{x:.6f}" for x in r["epoch_times"]) - w.writerow(r2) - return path - - -def plot_curves(rows: List[Dict], outdir: str): - import matplotlib - matplotlib.use("Agg") - import matplotlib.pyplot as plt - import numpy as np - - os.makedirs(outdir, exist_ok=True) - labels = [r["label"] for r in rows] - devs = [r["devices"] for r in rows] - thr = [r["throughput_samples_per_s"] for r in rows] - wall = [r["wall_seconds"] for r in rows] - avg_epoch = [np.mean(r["epoch_times"]) if r["epoch_times"] else float("nan") for r in rows] - - plt.figure() - plt.plot(devs, thr, marker="o") - plt.xticks(devs, labels, rotation=30, ha="right") - plt.xlabel("Devices") - plt.ylabel("Throughput (samples/s)") - plt.title("Throughput vs Devices") - plt.tight_layout() - plt.savefig(os.path.join(outdir, "throughput_vs_devices.png")) - plt.close() - - plt.figure() - plt.plot(devs, wall, marker="o") - plt.xticks(devs, labels, rotation=30, ha="right") - plt.xlabel("Devices") - plt.ylabel("Total Wall Time (s)") - plt.title("Wall Time vs Devices") - plt.tight_layout() 
- plt.savefig(os.path.join(outdir, "walltime_vs_devices.png")) - plt.close() - - plt.figure() - plt.plot(devs, avg_epoch, marker="o") - plt.xticks(devs, labels, rotation=30, ha="right") - plt.xlabel("Devices") - plt.ylabel("Avg Train Epoch Time (s)") - plt.title("Epoch Time vs Devices") - plt.tight_layout() - plt.savefig(os.path.join(outdir, "epoch_time_vs_devices.png")) - plt.close() - - -# ---------------- main ---------------- -def parse_args(): - ap = argparse.ArgumentParser() - ap.add_argument("--epochs", type=int, default=5) - ap.add_argument("--batch-size", type=int, default=2048) # PER GPU/CPU - ap.add_argument("--num-workers", type=int, default=8) - ap.add_argument("--precision", type=str, default="bf16") - ap.add_argument("--dataset-cache", type=str, default="bench_out/datasets/n{n}_seed42.npz", - help="Path to .npz to cache dataset. '{n}' will be replaced with num_samples.") - - # Accept BOTH forms; same dest - ap.add_argument("--num-samples", dest="num_samples", type=int, default=2_000_000) - ap.add_argument("--n", dest="num_samples", type=int) - - ap.add_argument("--context-dim", type=int, default=16) - ap.add_argument("--x-dim", type=int, default=512) - ap.add_argument("--y-dim", type=int, default=64) - ap.add_argument("--width", type=int, default=1024) - ap.add_argument("--layers", type=int, default=4) - ap.add_argument("--lr", type=float, default=1e-3) - ap.add_argument("--outdir", type=str, default="bench_out") - ap.add_argument("--ddp-timeout", type=int, default=180) - ap.add_argument("--max-gpus", type=int, default=4) - return ap.parse_args() - - -def main(): - set_env_defaults() - args = parse_args() - os.makedirs(args.outdir, exist_ok=True) - - if torch.cuda.is_available(): - torch.backends.cudnn.benchmark = True - os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") - - # ensure dataset cache path is concrete - ds_path = args.dataset_cache.format(n=args.num_samples) if args.dataset_cache else None - C, X, Y = 
load_or_make_dataset(ds_path, args.num_samples, args.context_dim, args.x_dim, args.y_dim, seed=42) - - results = [] - torchrun_mode = under_torchrun() - - if torchrun_mode: - # Run exactly one config under torchrun (WORLD_SIZE GPUs, 1 per process) - cfg = BenchCfg(label=f"gpu-{world_size()}", devices=1) - if is_global_zero(): - print(f"\n=== Running {cfg.label} (torchrun, {world_size()} processes) ===") - res = run_once(cfg, C, X, Y, args, torchrun_mode=True) - results.append(res) - else: - # Standalone: run CPU + 1..k GPUs - gpus = torch.cuda.device_count() - dev_list = [BenchCfg("cpu", 0)] - for k in range(1, min(args.max_gpus, gpus) + 1): - dev_list.append(BenchCfg(f"gpu-{k}", k)) - for cfg in dev_list: - if is_global_zero(): - print(f"\n=== Running {cfg.label} ===") - res = run_once(cfg, C, X, Y, args, torchrun_mode=False) - results.append(res) - - if is_global_zero(): - csv_path = save_csv(results, args.outdir) - plot_curves(results, args.outdir) - print(f"\nSaved CSV → {csv_path}") - print(f"Saved plots → {args.outdir}/throughput_vs_devices.png, " - f"walltime_vs_devices.png, epoch_time_vs_devices.png") - - -if __name__ == "__main__": - main() diff --git a/network_scaling_heavy.py b/network_scaling_heavy.py new file mode 100644 index 00000000..c1e04bff --- /dev/null +++ b/network_scaling_heavy.py @@ -0,0 +1,766 @@ +#!/usr/bin/env python3 +""" +HEAVY ContextualizedCorrelationNetworks DDP Scaling Benchmark + +This benchmark tests multi-GPU scaling with the actual ContextualizedCorrelationNetworks +model, but configured for maximum compute to properly stress-test GPU parallelism. + +Key optimizations for heavier compute: +1. Larger encoder networks (more layers, wider hidden dims) +2. More archetypes (more mixture components to learn) +3. Multiple bootstraps (ensemble of models) +4. Larger batch sizes to saturate GPU memory +5. More training epochs +6. 
Increased data dimensionality (more PCs) + +The goal is to make the model heavy enough that: +- Forward/backward pass takes significant time (50-200ms per batch) +- GPU compute dominates over NCCL sync overhead +- Multi-GPU scaling approaches theoretical limits (85-95% efficiency) + +Usage: + # 1-GPU baseline + python ccn_scaling_heavy.py --epochs 20 --devices 1 --label 1gpu_baseline + + # Multi-GPU with torchrun + torchrun --standalone --nproc_per_node=4 ccn_scaling_heavy.py --epochs 20 --label 4gpu_ddp +""" + +import os +import time +import csv +import warnings +import pickle +from dataclasses import dataclass +from typing import Tuple, Optional, List + +import numpy as np +import pandas as pd +from sklearn.decomposition import PCA +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + +import torch +import torch.distributed as dist + +from rdkit import Chem +from rdkit.Chem import rdFingerprintGenerator + +from contextualized.easy import ContextualizedCorrelationNetworks + + +# ================= CONFIGURATION ================= + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +DATA_DIR = os.path.join(os.path.dirname(BASE_DIR), "data") + +PATH_L1000 = os.path.join(DATA_DIR, "trt_cp_smiles_qc.csv") +PATH_CTLS = os.path.join(DATA_DIR, "ctrls.csv") + +# INCREASED: More PCs = larger feature space = more compute +N_DATA_PCS = 100 # Was 50 +N_CONTEXT_PCS = 100 # Control profile PCs + +PERTURBATION_HOLDOUT_SIZE = 0.2 +RANDOM_STATE = 42 + +morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=4096) + + +# ================= DISTRIBUTED HELPERS ================= + +def is_global_zero() -> bool: + """Return True only on global rank 0.""" + if dist.is_available() and dist.is_initialized(): + try: + return dist.get_rank() == 0 + except Exception: + return True + return int(os.environ.get("GLOBAL_RANK", os.environ.get("RANK", "0"))) == 0 + + +def get_rank() -> int: + if dist.is_available() and 
dist.is_initialized(): + return dist.get_rank() + return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))) + + +def get_world_size() -> int: + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + return int(os.environ.get("WORLD_SIZE", "1")) + + +def get_local_rank() -> int: + return int(os.environ.get("LOCAL_RANK", "0")) + + +def barrier(): + if dist.is_available() and dist.is_initialized(): + dist.barrier() + + +def print_rank0(msg: str): + if is_global_zero(): + print(msg, flush=True) + + +# ================= ENVIRONMENT SETUP ================= + +def set_env_defaults(): + """Optimized environment for heavy CCN training.""" + world_size = int(os.environ.get("WORLD_SIZE", "1")) + cpu_count = os.cpu_count() or 8 + threads = max(1, cpu_count // max(world_size, 1)) + + os.environ.setdefault("OMP_NUM_THREADS", str(min(threads, 4))) + os.environ.setdefault("MKL_NUM_THREADS", str(min(threads, 4))) + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + + # NCCL optimizations + os.environ.setdefault("NCCL_DEBUG", "WARN") + os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "1") + os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") + os.environ.setdefault("NCCL_ALGO", "Ring") + os.environ.setdefault("NCCL_NSOCKS_PERTHREAD", "4") + os.environ.setdefault("NCCL_SOCKET_NTHREADS", "2") + + # PyTorch optimizations + try: + torch.set_float32_matmul_precision("high") + except: + pass + + # Deterministic seeds + np.random.seed(RANDOM_STATE) + torch.manual_seed(RANDOM_STATE) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(RANDOM_STATE) + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = True + + +# ================= FINGERPRINT HELPER ================= + +def smiles_to_morgan_fp(smiles: str) -> np.ndarray: + """Convert SMILES to Morgan fingerprint.""" + try: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + return 
np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) + fp = morgan_gen.GetFingerprint(mol) + return np.array(fp, dtype=np.float32) + except: + return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) + + +# ================= DATA LOADING WITH CACHE ================= + +def get_cache_path(subsample_fraction: Optional[float], n_data_pcs: int) -> str: + """Generate cache path based on config.""" + suffix = f"_sub{subsample_fraction}" if subsample_fraction else "" + suffix += f"_pcs{n_data_pcs}" + return os.path.join(DATA_DIR, f"ccn_heavy_cache{suffix}.pkl") + + +def load_and_preprocess( + subsample_fraction: Optional[float] = None, + use_cache: bool = True, + n_data_pcs: int = N_DATA_PCS, + n_context_pcs: int = N_CONTEXT_PCS, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Load and preprocess data with configurable dimensionality. + Higher dimensions = more compute in the model. + """ + cache_path = get_cache_path(subsample_fraction, n_data_pcs) + + # Try cache + if use_cache and os.path.exists(cache_path): + print_rank0(f"[DATA] Loading from cache: {cache_path}") + with open(cache_path, 'rb') as f: + cached = pickle.load(f) + return ( + cached['C_train'], cached['X_train_norm'], + cached['C_test'], cached['X_test_norm'], + cached['cell_ids_train'], cached['cell_ids_test'] + ) + + # Wait for rank 0 to create cache + if not is_global_zero() and use_cache: + wait_count = 0 + while not os.path.exists(cache_path) and wait_count < 600: + time.sleep(1) + wait_count += 1 + if os.path.exists(cache_path): + with open(cache_path, 'rb') as f: + cached = pickle.load(f) + return ( + cached['C_train'], cached['X_train_norm'], + cached['C_test'], cached['X_test_norm'], + cached['cell_ids_train'], cached['cell_ids_test'] + ) + + print_rank0(f"[DATA] Loading L1000 from {PATH_L1000}") + df = pd.read_csv(PATH_L1000, engine="pyarrow") + + df = df[df["pert_type"].isin(["trt_cp"])] + + bad = ( + (df["distil_cc_q75"] < 0.2) | + 
(df["distil_cc_q75"] == -666) | + (df["distil_cc_q75"].isna()) | + (df["pct_self_rank_q25"] > 5) | + (df["pct_self_rank_q25"] == -666) | + (df["pct_self_rank_q25"].isna()) + ) + df = df[~bad] + df = df.dropna(subset=["canonical_smiles"]) + df = df[df["canonical_smiles"] != ""] + + print_rank0(f"[DATA] Samples after QC: {len(df)}") + + if subsample_fraction is not None: + df = df.sample(frac=subsample_fraction, random_state=RANDOM_STATE) + print_rank0(f"[DATA] Subsampled to {len(df)} ({subsample_fraction*100:.1f}%)") + + # Split by perturbation + unique_smiles = df["canonical_smiles"].unique() + print_rank0(f"[DATA] Unique perturbations: {len(unique_smiles)}") + + smiles_train, smiles_test = train_test_split( + unique_smiles, test_size=PERTURBATION_HOLDOUT_SIZE, random_state=RANDOM_STATE + ) + + df_train = df[df["canonical_smiles"].isin(smiles_train)].copy() + df_test = df[df["canonical_smiles"].isin(smiles_test)].copy() + + print_rank0(f"[DATA] Train: {len(df_train)}, Test: {len(df_test)}") + + # Handle missing values + pert_time_mean = df_train.loc[df_train["pert_time"] != -666, "pert_time"].mean() + pert_dose_mean = df_train.loc[df_train["pert_dose"] != -666, "pert_dose"].mean() + + for df_split in [df_train, df_test]: + df_split["ignore_flag_pert_time"] = (df_split["pert_time"] == -666).astype(int) + df_split["ignore_flag_pert_dose"] = (df_split["pert_dose"] == -666).astype(int) + df_split["pert_time"] = df_split["pert_time"].replace(-666, pert_time_mean) + df_split["pert_dose"] = df_split["pert_dose"].replace(-666, pert_dose_mean) + + def process_split(df_split, name): + numeric_cols = df_split.select_dtypes(include=[np.number]).columns + drop_cols = ["pert_dose", "pert_dose_unit", "pert_time", "distil_cc_q75", "pct_self_rank_q25"] + feature_cols = [c for c in numeric_cols if c not in drop_cols] + X_raw = df_split[feature_cols].values.astype(np.float32) + + print_rank0(f"[DATA] [{name}] Generating fingerprints...") + fps = np.stack([smiles_to_morgan_fp(s) for s 
in df_split["canonical_smiles"]]) + print_rank0(f"[DATA] [{name}] Fingerprint shape: {fps.shape}") + + pert_time = df_split["pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) + pert_dose = df_split["pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) + ign_t = df_split["ignore_flag_pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) + ign_d = df_split["ignore_flag_pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) + + return X_raw, fps, pert_time, pert_dose, ign_t, ign_d, df_split["cell_id"].to_numpy() + + X_train_raw, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train, cells_train = process_split(df_train, "train") + X_test_raw, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test, cells_test = process_split(df_test, "test") + + # Scale features + print_rank0("[DATA] Scaling gene expression...") + scaler_genes = StandardScaler() + X_train_scaled = scaler_genes.fit_transform(X_train_raw) + X_test_scaled = scaler_genes.transform(X_test_raw) + + # Load controls + print_rank0(f"[DATA] Loading controls from {PATH_CTLS}") + ctrls_df = pd.read_csv(PATH_CTLS, index_col=0) + + unique_cells = np.union1d(np.unique(cells_train), np.unique(cells_test)) + ctrls_df = ctrls_df.loc[ctrls_df.index.intersection(unique_cells)] + + scaler_ctrls = StandardScaler() + ctrls_scaled = scaler_ctrls.fit_transform(ctrls_df.values) + + # INCREASED: More control PCs + actual_n_ctrl_pcs = min(n_context_pcs, ctrls_scaled.shape[0], ctrls_scaled.shape[1]) + print_rank0(f"[DATA] Using {actual_n_ctrl_pcs} control PCs") + + pca_ctrls = PCA(n_components=actual_n_ctrl_pcs, random_state=RANDOM_STATE) + ctrls_pcs = pca_ctrls.fit_transform(ctrls_scaled) + cell2vec = dict(zip(ctrls_df.index, ctrls_pcs)) + + if not cell2vec: + raise ValueError("No overlapping cell IDs") + + print_rank0(f"[DATA] Control embeddings for {len(cell2vec)} cells") + + def build_context(df_split, X_scaled, morgan, pt, pd, ign_t, ign_d, name, scaler=None, fit=False): + cell_ids = 
df_split["cell_id"].to_numpy() + unique_cells_split = np.sort(df_split["cell_id"].unique()) + + all_cont = [] + valid_cells = [] + + for cell_id in unique_cells_split: + if cell_id not in cell2vec: + continue + mask = cell_ids == cell_id + if mask.sum() == 0: + continue + valid_cells.append(cell_id) + cont = np.hstack([ + np.tile(cell2vec[cell_id], (mask.sum(), 1)), + pt[mask], + pd[mask], + ]).astype(np.float32) + all_cont.append(cont) + + if fit: + all_cont_stacked = np.vstack(all_cont) + scaler = StandardScaler() + scaler.fit(all_cont_stacked) + + X_list, C_list, cid_list = [], [], [] + + for i, cell_id in enumerate(valid_cells): + mask = cell_ids == cell_id + X_cell = X_scaled[mask] + cont_scaled = scaler.transform(all_cont[i]) + C_cell = np.hstack([ + cont_scaled, + morgan[mask], + ign_t[mask], + ign_d[mask], + ]).astype(np.float32) + + X_list.append(X_cell) + C_list.append(C_cell) + cid_list.append(cell_ids[mask]) + + X_final = np.vstack(X_list) + C_final = np.vstack(C_list) + cell_ids_final = np.concatenate(cid_list) + + return X_final, C_final, cell_ids_final, scaler + + print_rank0("[DATA] Building context matrices...") + X_train, C_train, cell_ids_train, ctx_scaler = build_context( + df_train, X_train_scaled, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train, "train", fit=True + ) + X_test, C_test, cell_ids_test, _ = build_context( + df_test, X_test_scaled, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test, "test", scaler=ctx_scaler + ) + + print_rank0(f"[DATA] Context shapes: C_train={C_train.shape}, C_test={C_test.shape}") + + # INCREASED: More data PCs + actual_n_data_pcs = min(n_data_pcs, X_train.shape[1], X_train.shape[0]) + print_rank0(f"[DATA] Using {actual_n_data_pcs} data PCs") + + pca_data = PCA(n_components=actual_n_data_pcs, random_state=RANDOM_STATE) + X_train_pca = pca_data.fit_transform(X_train) + X_test_pca = pca_data.transform(X_test) + + pca_scaler = StandardScaler() + X_train_norm = 
pca_scaler.fit_transform(X_train_pca).astype(np.float32) + X_test_norm = pca_scaler.transform(X_test_pca).astype(np.float32) + + print_rank0(f"[DATA] Final: X_train={X_train_norm.shape}, X_test={X_test_norm.shape}") + print_rank0(f"[DATA] Final: C_train={C_train.shape}, C_test={C_test.shape}") + + # Save cache + if use_cache and is_global_zero(): + cache_data = { + 'C_train': C_train, 'X_train_norm': X_train_norm, + 'C_test': C_test, 'X_test_norm': X_test_norm, + 'cell_ids_train': cell_ids_train, 'cell_ids_test': cell_ids_test, + } + os.makedirs(os.path.dirname(cache_path), exist_ok=True) + with open(cache_path, 'wb') as f: + pickle.dump(cache_data, f) + print_rank0(f"[DATA] Saved cache: {cache_path}") + + return C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test + + +# ================= BENCHMARK RESULT ================= + +@dataclass +class BenchResult: + label: str + wall_seconds: float + train_mse_mean: float + test_mse_mean: float + num_gpus: int + batch_size_per_gpu: int + effective_batch_size: int + samples_per_second: float + num_archetypes: int + encoder_width: int + encoder_layers: int + n_bootstraps: int + speedup: float = 1.0 + efficiency: float = 100.0 + + +# ================= MAIN BENCHMARK ================= + +def run_ccn_benchmark( + label: str, + C_train: np.ndarray, + X_train_norm: np.ndarray, + C_test: np.ndarray, + X_test_norm: np.ndarray, + epochs: int, + devices: int, + batch_size_per_gpu: int = 512, + num_workers: int = 4, + # Heavy CCN parameters + num_archetypes: int = 64, + encoder_width: int = 256, + encoder_layers: int = 6, + n_bootstraps: int = 3, + warmup_epochs: int = 1, + baseline_time: Optional[float] = None, +) -> BenchResult: + """ + Run ContextualizedCorrelationNetworks benchmark with heavy configuration. 
+ + Key parameters for increased compute: + - num_archetypes: More mixture components (64 vs default 16) + - encoder_width: Wider encoder networks (256 vs default 25) + - encoder_layers: Deeper encoders (6 vs default 3) + - n_bootstraps: Ensemble of models (3 vs default 1) + """ + + world_size = int(os.environ.get("WORLD_SIZE", "1")) + rank = get_rank() + local_rank = get_local_rank() + launched_with_torchrun = world_size > 1 + + # Device setup + if torch.cuda.is_available() and devices > 0: + accelerator = "gpu" + if launched_with_torchrun: + devices = world_size + else: + accelerator = "cpu" + devices = 1 + num_workers = 0 + + # Reduce workers for multi-GPU + if launched_with_torchrun and num_workers > 2: + num_workers = 2 + + # Batch size: scale with GPUs for proper throughput scaling + effective_batch = batch_size_per_gpu * max(world_size, 1) + + print_rank0(f"\n{'='*70}") + print_rank0(f"[{label}] HEAVY CCN BENCHMARK") + print_rank0(f"{'='*70}") + print_rank0(f" World size: {world_size}") + print_rank0(f" Accelerator: {accelerator}") + print_rank0(f" Devices: {devices}") + print_rank0(f" Batch size per GPU: {batch_size_per_gpu}") + print_rank0(f" Effective batch size: {effective_batch}") + print_rank0(f" Epochs: {epochs} (+ {warmup_epochs} warmup)") + print_rank0(f" Num workers: {num_workers}") + print_rank0(f" --- CCN Config (HEAVY) ---") + print_rank0(f" Archetypes: {num_archetypes}") + print_rank0(f" Encoder width: {encoder_width}") + print_rank0(f" Encoder layers: {encoder_layers}") + print_rank0(f" Bootstraps: {n_bootstraps}") + print_rank0(f" Data dims: C={C_train.shape[1]}, X={X_train_norm.shape[1]}") + + # Log per-process info + print( + f"[{label}] [RANK {rank} / LOCAL {local_rank}] " + f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}", + flush=True + ) + + # Strategy configuration + strategy_kwarg = "auto" + if accelerator == "gpu" and launched_with_torchrun and world_size > 1: + try: + from pytorch_lightning.strategies import 
DDPStrategy + strategy_kwarg = DDPStrategy( + process_group_backend="nccl", + find_unused_parameters=False, + gradient_as_bucket_view=True, + ) + print_rank0(f"[{label}] Using DDPStrategy with NCCL + gradient_as_bucket_view") + except Exception as e: + strategy_kwarg = "ddp" + print_rank0(f"[{label}] Falling back to strategy='ddp': {e}") + + # Trainer kwargs + trainer_kwargs = { + "max_epochs": epochs + warmup_epochs, + "accelerator": accelerator, + "devices": devices, + "enable_progress_bar": False, + "logger": False, + "enable_checkpointing": False, + "num_sanity_val_steps": 0, + "precision": "16-mixed" if accelerator == "gpu" else 32, + "strategy": strategy_kwarg, + } + + print_rank0(f"[{label}] Trainer kwargs: {trainer_kwargs}") + + # Construct HEAVY CCN model + print_rank0(f"[{label}] Constructing ContextualizedCorrelationNetworks...") + + ccn = ContextualizedCorrelationNetworks( + encoder_type="mlp", + num_archetypes=num_archetypes, + n_bootstraps=n_bootstraps, + encoder_kwargs={ + "width": encoder_width, + "layers": encoder_layers, + }, + trainer_kwargs=trainer_kwargs, + es_patience=0, # No early stopping for benchmark + ) + + # Estimate parameter count + # CCN params ≈ n_bootstraps × (encoder_params + archetype_params + correlation_params) + # encoder_params ≈ (context_dim × width + width × width × (layers-1) + width × archetypes) + # archetype_params ≈ archetypes × x_dim × x_dim (correlation matrices) + context_dim = C_train.shape[1] + x_dim = X_train_norm.shape[1] + encoder_params = context_dim * encoder_width + encoder_width * encoder_width * (encoder_layers - 1) + encoder_width * num_archetypes + archetype_params = num_archetypes * x_dim * x_dim + total_params = n_bootstraps * (encoder_params + archetype_params) + print_rank0(f"[{label}] Estimated parameters: ~{total_params:,} ({total_params/1e6:.2f}M)") + + # Synchronize before training + barrier() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + print_rank0(f"[{label}] Starting 
training...") + t0 = time.time() + + ccn.fit( + C_train, + X_train_norm, + train_batch_size=batch_size_per_gpu, + val_batch_size=batch_size_per_gpu, + test_batch_size=batch_size_per_gpu, + num_workers=num_workers, + persistent_workers=(num_workers > 0), + pin_memory=(accelerator == "gpu"), + ) + + # Synchronize after training + barrier() + if torch.cuda.is_available(): + torch.cuda.synchronize() + + wall = time.time() - t0 + + # Adjust for warmup + if warmup_epochs > 0 and epochs > 0: + wall_per_epoch = wall / (epochs + warmup_epochs) + wall = wall_per_epoch * epochs + + print_rank0(f"[{label}] Training completed in {wall:.2f}s") + + # Metrics + n_samples = C_train.shape[0] + samples_per_sec = (n_samples * epochs) / max(wall, 1e-6) + + speedup = 1.0 + efficiency = 100.0 + if baseline_time is not None and baseline_time > 0: + speedup = baseline_time / wall + efficiency = (speedup / world_size) * 100 + + train_mse = float("nan") + test_mse = float("nan") + + if is_global_zero(): + try: + print_rank0(f"[{label}] Computing MSE...") + mse_train_vec = ccn.measure_mses(C_train, X_train_norm, individual_preds=False) + mse_test_vec = ccn.measure_mses(C_test, X_test_norm, individual_preds=False) + train_mse = float(np.mean(mse_train_vec)) + test_mse = float(np.mean(mse_test_vec)) + except Exception as e: + warnings.warn(f"[{label}] measure_mses failed: {e}") + + print_rank0(f"\n[{label}] RESULTS:") + print_rank0(f" Wall time: {wall:.2f}s") + print_rank0(f" Samples/sec: {samples_per_sec:.1f}") + print_rank0(f" Train MSE: {train_mse:.6f}") + print_rank0(f" Test MSE: {test_mse:.6f}") + if baseline_time: + print_rank0(f" Speedup: {speedup:.2f}x") + print_rank0(f" Efficiency: {efficiency:.1f}%") + + return BenchResult( + label=label, + wall_seconds=wall, + train_mse_mean=train_mse, + test_mse_mean=test_mse, + num_gpus=world_size, + batch_size_per_gpu=batch_size_per_gpu, + effective_batch_size=effective_batch, + samples_per_second=samples_per_sec, + num_archetypes=num_archetypes, 
+ encoder_width=encoder_width, + encoder_layers=encoder_layers, + n_bootstraps=n_bootstraps, + speedup=speedup, + efficiency=efficiency, + ) + + +# ================= CSV OUTPUT ================= + +def save_results_csv(results: List[BenchResult], outdir: str): + if not is_global_zero(): + return + + os.makedirs(outdir, exist_ok=True) + path = os.path.join(outdir, "ccn_heavy_scaling_results.csv") + + write_header = not os.path.exists(path) + + with open(path, "a", newline="") as f: + writer = csv.writer(f) + if write_header: + writer.writerow([ + "label", "wall_seconds", "train_mse", "test_mse", + "num_gpus", "batch_per_gpu", "effective_batch", "samples_per_sec", + "archetypes", "encoder_width", "encoder_layers", "bootstraps", + "speedup", "efficiency" + ]) + for r in results: + writer.writerow([ + r.label, + f"{r.wall_seconds:.4f}", + f"{r.train_mse_mean:.6f}", + f"{r.test_mse_mean:.6f}", + r.num_gpus, + r.batch_size_per_gpu, + r.effective_batch_size, + f"{r.samples_per_second:.2f}", + r.num_archetypes, + r.encoder_width, + r.encoder_layers, + r.n_bootstraps, + f"{r.speedup:.2f}", + f"{r.efficiency:.1f}", + ]) + + print_rank0(f"\n[OUTPUT] Results appended to: {path}") + + +# ================= CLI ================= + +def parse_args(): + import argparse + + ap = argparse.ArgumentParser(description="Heavy ContextualizedCorrelationNetworks Scaling Benchmark") + + # Training config + ap.add_argument("--epochs", type=int, default=20) + ap.add_argument("--warmup-epochs", type=int, default=1) + ap.add_argument("--batch-size", type=int, default=512, + help="Batch size per GPU") + ap.add_argument("--num-workers", type=int, default=4) + + # CCN architecture (HEAVY defaults) + ap.add_argument("--archetypes", type=int, default=64, + help="Number of archetypes (default: 64, original: 16)") + ap.add_argument("--encoder-width", type=int, default=256, + help="Encoder hidden width (default: 256, original: 25)") + ap.add_argument("--encoder-layers", type=int, default=6, + 
help="Encoder depth (default: 6, original: 3)") + ap.add_argument("--bootstraps", type=int, default=3, + help="Number of bootstrap models (default: 3, original: 1)") + + # Data config + ap.add_argument("--data-pcs", type=int, default=100, + help="Number of data PCs (default: 100, original: 50)") + ap.add_argument("--context-pcs", type=int, default=100, + help="Number of context PCs (default: 100)") + ap.add_argument("--subsample-fraction", type=float, default=None) + + # Runtime config + ap.add_argument("--devices", type=int, default=1) + ap.add_argument("--outdir", type=str, default="bench_results_ccn_heavy") + ap.add_argument("--label", type=str, default=None) + ap.add_argument("--baseline-time", type=float, default=None) + ap.add_argument("--no-cache", action="store_true") + + return ap.parse_args() + + +# ================= MAIN ================= + +def main(): + args = parse_args() + set_env_defaults() + + world_size = get_world_size() + + # Auto-generate label if not provided + if args.label: + label = args.label + else: + label = f"{world_size}gpu_ccn_heavy" + + print_rank0("\n" + "="*70) + print_rank0("HEAVY ContextualizedCorrelationNetworks SCALING BENCHMARK") + print_rank0("="*70) + print_rank0(f" World size: {world_size}") + print_rank0(f" Epochs: {args.epochs}") + print_rank0(f" Batch size: {args.batch_size}") + print_rank0(f" Archetypes: {args.archetypes}") + print_rank0(f" Encoder: {args.encoder_width}w × {args.encoder_layers}L") + print_rank0(f" Bootstraps: {args.bootstraps}") + print_rank0(f" Data PCs: {args.data_pcs}") + + # Load data + C_train, X_train_norm, C_test, X_test_norm, _, _ = load_and_preprocess( + subsample_fraction=args.subsample_fraction, + use_cache=not args.no_cache, + n_data_pcs=args.data_pcs, + n_context_pcs=args.context_pcs, + ) + + barrier() + + # Run benchmark + result = run_ccn_benchmark( + label=label, + C_train=C_train, + X_train_norm=X_train_norm, + C_test=C_test, + X_test_norm=X_test_norm, + epochs=args.epochs, + 
devices=args.devices, + batch_size_per_gpu=args.batch_size, + num_workers=args.num_workers, + num_archetypes=args.archetypes, + encoder_width=args.encoder_width, + encoder_layers=args.encoder_layers, + n_bootstraps=args.bootstraps, + warmup_epochs=args.warmup_epochs, + baseline_time=args.baseline_time, + ) + + # Save results + if is_global_zero(): + save_results_csv([result], args.outdir) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/networks_pert_scale_bench.py b/networks_pert_scale_bench.py new file mode 100644 index 00000000..2525de0b --- /dev/null +++ b/networks_pert_scale_bench.py @@ -0,0 +1,826 @@ +#!/usr/bin/env python3 +""" +Baseline scaling benchmark for unseen_pert with true 1-GPU vs 2-GPU comparison (DDP). + +- Preprocesses L1000 + controls, building C (context) and X (features) +- Trains a simple MLP regressor C -> X + +It runs two modes in ONE command: + 1) 1 GPU -> single-process training on cuda:0 + 2) 2 GPUs -> DistributedDataParallel (DDP) with 2 processes (ranks 0 and 1), + each bound to one GPU. 
+ +For each mode it prints: + - wall time (seconds) + - throughput (samples / second) + - final train MSE + - final test MSE + +Outputs: + - CSV: bench_out_unseen/scale_results_unseen_ddp.csv (two rows: 1gpu, 2gpu) + +Typical usage inside a 2-GPU interactive job: + + cd /fs/scratch/PAS2942/samuel_wales_mcgrath/hpc/Contextualized + conda activate contextpert-hpc + + python networks_pert_scale_bench.py \ + --epochs 20 \ + --batch-size 512 \ + --num-workers 0 \ + --subsample-fraction 1.0 +""" + +import os +import time +import csv +import warnings +from dataclasses import dataclass +from typing import Tuple, Optional, List + +import numpy as np +import pandas as pd +from sklearn.decomposition import PCA +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data.distributed import DistributedSampler +import torch.multiprocessing as mp +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP + +from rdkit import Chem +from rdkit.Chem import rdFingerprintGenerator + + +# ------------------- paths & basic config ------------------- + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +DATA_DIR = os.path.join(os.path.dirname(BASE_DIR), "data") + +PATH_L1000 = os.path.join(DATA_DIR, "trt_cp_smiles_qc.csv") +PATH_CTLS = os.path.join(DATA_DIR, "ctrls.csv") + +N_DATA_PCS = 50 +PERTURBATION_HOLDOUT_SIZE = 0.2 +RANDOM_STATE = 42 + +morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=4096) + + +# ------------------- env + seeds ------------------- + +def set_env_defaults(): + """Safe CPU/GPU threading + seeds (for non-DDP parts).""" + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + + # Do NOT clear MASTER_ADDR/MASTER_PORT here; DDP will need them later. 
+ try: + torch.set_float32_matmul_precision("high") + except Exception: + pass + + np.random.seed(RANDOM_STATE) + torch.manual_seed(RANDOM_STATE) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(RANDOM_STATE) + + +def set_seeds(rank: int): + """Per-process seeds for DDP workers.""" + np.random.seed(RANDOM_STATE + rank) + torch.manual_seed(RANDOM_STATE + rank) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(RANDOM_STATE + rank) + + +# ------------------- fingerprint helper ------------------- + +def smiles_to_morgan_fp(smiles: str) -> np.ndarray: + """Convert a SMILES string to a Morgan fingerprint (binary vector).""" + try: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + warnings.warn(f"Invalid SMILES: {smiles}") + return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) + fp = morgan_gen.GetFingerprint(mol) + arr = np.array(fp, dtype=np.float32) + return arr + except Exception as e: + warnings.warn(f"Error processing SMILES {smiles}: {e}") + return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) + + +# ------------------- data preprocessing (unseen_pert) ------------------- + +def load_and_preprocess( + subsample_fraction: Optional[float] = None, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Implements the unseen_pert preprocessing, returning: + C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test + + where: + - X_*_norm are PCA+standardized gene features + - C_* are context vectors (ctrl PCs + Morgan + time/dose flags) + """ + print(f"Reading L1000 data from {PATH_L1000}") + df = pd.read_csv(PATH_L1000, engine="pyarrow") + + # Only trt_cp perturbations + df = df[df["pert_type"].isin(["trt_cp"])] + + # Quality filters + bad = ( + (df["distil_cc_q75"] < 0.2) + | (df["distil_cc_q75"] == -666) + | (df["distil_cc_q75"].isna()) + | (df["pct_self_rank_q25"] > 5) + | (df["pct_self_rank_q25"] == -666) + | (df["pct_self_rank_q25"].isna()) + ) + df = 
df[~bad] + + # Valid SMILES only + df = df.dropna(subset=["canonical_smiles"]) + df = df[df["canonical_smiles"] != ""] + + print(f"Remaining samples after QC + SMILES filter: {len(df)}") + + if subsample_fraction is not None: + df = df.sample(frac=subsample_fraction, random_state=RANDOM_STATE) + print(f"Subsampled to {len(df)} samples ({subsample_fraction * 100:.1f}% of data)") + + # Perturbation holdout: split on unique SMILES + unique_smiles = df["canonical_smiles"].unique() + print(f"Found {len(unique_smiles)} unique perturbations (SMILES)") + smiles_train, smiles_test = train_test_split( + unique_smiles, + test_size=PERTURBATION_HOLDOUT_SIZE, + random_state=RANDOM_STATE, + ) + + df_train = df[df["canonical_smiles"].isin(smiles_train)].copy() + df_test = df[df["canonical_smiles"].isin(smiles_test)].copy() + + print(f"Perturbation split: {len(smiles_train)} train, {len(smiles_test)} test perturbations") + print(f"Sample split: {len(df_train)} train, {len(df_test)} test samples") + + # Handle pert_time / pert_dose missing values with -666 logic + pert_time_mean = None + pert_dose_mean = None + + for df_split, split_name in ((df_train, "train"), (df_test, "test")): + df_split["ignore_flag_pert_time"] = (df_split["pert_time"] == -666).astype(int) + df_split["ignore_flag_pert_dose"] = (df_split["pert_dose"] == -666).astype(int) + + for col in ["pert_time", "pert_dose"]: + if split_name == "train": + mean_val = df_split.loc[df_split[col] != -666, col].mean() + if col == "pert_time": + pert_time_mean = mean_val + else: + pert_dose_mean = mean_val + else: + mean_val = pert_time_mean if col == "pert_time" else pert_dose_mean + + df_split[col] = df_split[col].replace(-666, mean_val) + + def process_data_split(df_split, split_name): + numeric_cols = df_split.select_dtypes(include=[np.number]).columns + drop_cols = [ + "pert_dose", + "pert_dose_unit", + "pert_time", + "distil_cc_q75", + "pct_self_rank_q25", + ] + feature_cols = [c for c in numeric_cols if c not in 
drop_cols] + X_raw = df_split[feature_cols].values.astype(np.float32) + + print(f"[{split_name}] Generating Morgan fingerprints...") + fps = np.stack([smiles_to_morgan_fp(s) for s in df_split["canonical_smiles"]]) + print(f"[{split_name}] Morgan shape: {fps.shape}") + + pert_time = df_split["pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) + pert_dose = df_split["pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) + ignore_time = df_split["ignore_flag_pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) + ignore_dose = df_split["ignore_flag_pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) + + return X_raw, fps, pert_time, pert_dose, ignore_time, ignore_dose + + (X_raw_train, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train) = process_data_split( + df_train, "train" + ) + (X_raw_test, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test) = process_data_split( + df_test, "test" + ) + + # Scale gene expression + print("Scaling gene expression...") + scaler_genes = StandardScaler() + X_train_scaled = scaler_genes.fit_transform(X_raw_train) + X_test_scaled = scaler_genes.transform(X_raw_test) + + # Morgan fingerprints as float (already binary) + morgan_train_scaled = morgan_train.astype(np.float32) + morgan_test_scaled = morgan_test.astype(np.float32) + + # Load controls + print(f"Reading control profiles from {PATH_CTLS}") + ctrls_df = pd.read_csv(PATH_CTLS, index_col=0) + + unique_cells_train = np.sort(df_train["cell_id"].unique()) + unique_cells_test = np.sort(df_test["cell_id"].unique()) + unique_cells_all = np.sort(np.union1d(unique_cells_train, unique_cells_test)) + + ctrls_df = ctrls_df.loc[ctrls_df.index.intersection(unique_cells_all)] + scaler_ctrls = StandardScaler() + ctrls_scaled = scaler_ctrls.fit_transform(ctrls_df.values) + + n_cells = ctrls_scaled.shape[0] + n_ctrl_pcs = min(50, n_cells) + + pca_ctrls = PCA(n_components=n_ctrl_pcs, random_state=RANDOM_STATE) + ctrls_pcs = pca_ctrls.fit_transform(ctrls_scaled) + + 
cell2vec = dict(zip(ctrls_df.index, ctrls_pcs)) + if not cell2vec: + raise ValueError("No overlapping cell IDs between L1000 and ctrls.csv") + + print(f"Control embeddings for {len(cell2vec)} cells (PCs={n_ctrl_pcs})") + + def build_context_matrix( + df_split, + X_scaled, + morgan_scaled, + pt, + pd, + ign_t, + ign_d, + split_name, + scaler_context=None, + is_train=False, + ): + cell_ids = df_split["cell_id"].to_numpy() + unique_cells_split = np.sort(df_split["cell_id"].unique()) + + all_continuous_context = [] + valid_cells = [] + + for cell_id in unique_cells_split: + if cell_id not in cell2vec: + print(f"[{split_name}] Warning: cell {cell_id} not in control embeddings; skipping") + continue + mask = cell_ids == cell_id + if mask.sum() == 0: + continue + + valid_cells.append(cell_id) + cont = np.hstack( + [ + np.tile(cell2vec[cell_id], (mask.sum(), 1)), + pt[mask], + pd[mask], + ] + ).astype(np.float32) + all_continuous_context.append(cont) + + if is_train: + all_cont = np.vstack(all_continuous_context) + scaler_context = StandardScaler() + scaler_context.fit(all_cont) + print(f"[{split_name}] Context scaler fit on {all_cont.shape} continuous features") + + if scaler_context is None: + raise ValueError("scaler_context must be provided for non-training split") + + X_list, C_list, cid_list = [], [], [] + + for i, cell_id in enumerate(valid_cells): + mask = cell_ids == cell_id + X_cell = X_scaled[mask] + cont_scaled = scaler_context.transform(all_continuous_context[i]) + C_cell = np.hstack( + [ + cont_scaled, + morgan_scaled[mask], + ign_t[mask], + ign_d[mask], + ] + ).astype(np.float32) + + X_list.append(X_cell) + C_list.append(C_cell) + cid_list.append(cell_ids[mask]) + + if not X_list: + raise RuntimeError(f"No data for split {split_name}") + + X_final = np.vstack(X_list) + C_final = np.vstack(C_list) + cell_ids_final = np.concatenate(cid_list) + + return X_final, C_final, cell_ids_final, scaler_context + + print("Building context matrices...") + X_train, 
C_train, cell_ids_train, scaler_context = build_context_matrix( + df_train, + X_train_scaled, + morgan_train_scaled, + pt_train, + pd_train, + ign_t_train, + ign_d_train, + "train", + is_train=True, + ) + X_test, C_test, cell_ids_test, _ = build_context_matrix( + df_test, + X_test_scaled, + morgan_test_scaled, + pt_test, + pd_test, + ign_t_test, + ign_d_test, + "test", + scaler_context=scaler_context, + is_train=False, + ) + + print(f"C_train: {C_train.shape}, X_train: {X_train.shape}") + print(f"C_test: {C_test.shape}, X_test: {X_test.shape}") + + # PCA on X then scale + print("PCA + scaling on gene features...") + pca_data = PCA(n_components=N_DATA_PCS, random_state=RANDOM_STATE) + X_train_pca = pca_data.fit_transform(X_train) + X_test_pca = pca_data.transform(X_test) + + pca_scaler = StandardScaler() + X_train_norm = pca_scaler.fit_transform(X_train_pca) + X_test_norm = pca_scaler.transform(X_test_pca) + + print(f"Final X_train_norm: {X_train_norm.shape}, X_test_norm: {X_test_norm.shape}") + + return C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test + + + +@dataclass +class BenchResult: + label: str + wall_seconds: float + samples_total: int + throughput_sps: float + train_mse_mean: float + test_mse_mean: float + + +class SimpleRegressor(nn.Module): + def __init__(self, in_dim: int, out_dim: int): + super().__init__() + hidden = 512 + self.net = nn.Sequential( + nn.Linear(in_dim, hidden), + nn.ReLU(), + nn.Linear(hidden, hidden), + nn.ReLU(), + nn.Linear(hidden, out_dim), + ) + + def forward(self, x): + return self.net(x) + + +def run_single_gpu( + epochs: int, + batch_size: int, + num_workers: int, + subsample_fraction: Optional[float], +) -> BenchResult: + """Single-process, single-GPU training on cuda:0 (or CPU).""" + label = "1gpu_single" + print("\n================ 1-GPU baseline (single process) ================") + + C_train, X_train_norm, C_test, X_test_norm, _, _ = load_and_preprocess( + subsample_fraction=subsample_fraction + ) 
+ + if torch.cuda.is_available(): + device = torch.device("cuda:0") + print(f"[{label}] Using CUDA on device {device}") + else: + device = torch.device("cpu") + print(f"[{label}] CUDA not available, using CPU") + + C_train_t = torch.from_numpy(C_train).float() + X_train_t = torch.from_numpy(X_train_norm).float() + C_test_t = torch.from_numpy(C_test).float() + X_test_t = torch.from_numpy(X_test_norm).float() + + train_ds = TensorDataset(C_train_t, X_train_t) + test_ds = TensorDataset(C_test_t, X_test_t) + + train_loader = DataLoader( + train_ds, + batch_size=batch_size, + shuffle=True, + num_workers=num_workers, + pin_memory=torch.cuda.is_available(), + ) + test_loader = DataLoader( + test_ds, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available(), + ) + + in_dim = C_train.shape[1] + out_dim = X_train_norm.shape[1] + + model = SimpleRegressor(in_dim, out_dim).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + criterion = nn.MSELoss() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.backends.cudnn.benchmark = True + torch.cuda.synchronize() + + n_samples = C_train.shape[0] + t0 = time.time() + + for epoch in range(epochs): + model.train() + epoch_loss = 0.0 + for batch_C, batch_X in train_loader: + batch_C = batch_C.to(device, non_blocking=True) + batch_X = batch_X.to(device, non_blocking=True) + + optimizer.zero_grad() + preds = model(batch_C) + loss = criterion(preds, batch_X) + loss.backward() + optimizer.step() + epoch_loss += loss.item() * batch_C.size(0) + + epoch_loss /= n_samples + print(f"[{label}] Epoch {epoch+1}/{epochs} - train MSE {epoch_loss:.6f}") + + if torch.cuda.is_available(): + torch.cuda.synchronize() + wall = time.time() - t0 + + samples_total = n_samples * epochs + throughput = samples_total / max(wall, 1e-9) + + # Evaluation + def eval_mse(loader, split_name: str) -> float: + model.eval() + total_loss = 0.0 + count = 0 + with torch.no_grad(): + for 
batch_C, batch_X in loader: + batch_C = batch_C.to(device, non_blocking=True) + batch_X = batch_X.to(device, non_blocking=True) + preds = model(batch_C) + loss = criterion(preds, batch_X) + bsz = batch_C.size(0) + total_loss += loss.item() * bsz + count += bsz + mse = total_loss / max(count, 1) + print(f"[{label}] {split_name} MSE {mse:.6f}") + return mse + + train_mse = eval_mse(train_loader, "train") + test_mse = eval_mse(test_loader, "test") + + print(f"\n[{label}] run complete") + print(f" wall time (s): {wall:.2f}") + print(f" total samples: {samples_total}") + print(f" throughput (samples/s): {throughput:.2f}") + print(f" final train MSE: {train_mse:.6f}") + print(f" final test MSE: {test_mse:.6f}") + + return BenchResult( + label=label, + wall_seconds=wall, + samples_total=samples_total, + throughput_sps=throughput, + train_mse_mean=train_mse, + test_mse_mean=test_mse, + ) + + +def ddp_worker( + rank: int, + world_size: int, + port: str, + epochs: int, + batch_size: int, + num_workers: int, + subsample_fraction: Optional[float], + result_dict, +): + """ + DDP worker function run by each spawned process (rank 0 and 1). + We only record metrics in rank 0 and put them in result_dict["2gpu_ddp"]. + """ + set_seeds(rank) + + # Device mapping: assume 2 GPUs visible, use local index = rank + if torch.cuda.is_available(): + torch.cuda.set_device(rank) + device = torch.device(f"cuda:{rank}") + else: + device = torch.device("cpu") + + init_method = f"tcp://127.0.0.1:{port}" + dist.init_process_group( + backend="gloo", + init_method=init_method, + world_size=world_size, + rank=rank, + ) + + label = "2gpu_ddp" + if rank == 0: + print("\n================ 2-GPU DDP baseline ================") + print(f"[{label}] world_size={world_size}, backend=gloo, init_method={init_method}") + if torch.cuda.is_available(): + print(f"[{label}] Using GPUs 0 and 1 with DDP") + + # IMPORTANT: we measure only training time; data loading can be duplicated. 
+ C_train, X_train_norm, C_test, X_test_norm, _, _ = load_and_preprocess( + subsample_fraction=subsample_fraction + ) + + C_train_t = torch.from_numpy(C_train).float() + X_train_t = torch.from_numpy(X_train_norm).float() + C_test_t = torch.from_numpy(C_test).float() + X_test_t = torch.from_numpy(X_test_norm).float() + + train_ds = TensorDataset(C_train_t, X_train_t) + test_ds = TensorDataset(C_test_t, X_test_t) + + train_sampler = DistributedSampler( + train_ds, + num_replicas=world_size, + rank=rank, + shuffle=True, + drop_last=False, + ) + + train_loader = DataLoader( + train_ds, + batch_size=batch_size, + sampler=train_sampler, + num_workers=num_workers, + pin_memory=torch.cuda.is_available(), + ) + + # test_loader will only be used on rank 0 after training (non-distributed). + in_dim = C_train.shape[1] + out_dim = X_train_norm.shape[1] + + model = SimpleRegressor(in_dim, out_dim).to(device) + ddp_model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None) + + optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3) + criterion = nn.MSELoss() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.backends.cudnn.benchmark = True + torch.cuda.synchronize() + + n_samples = C_train.shape[0] + + # Synchronize before timing + dist.barrier() + if torch.cuda.is_available(): + torch.cuda.synchronize() + t0 = time.time() + + # Training loop + for epoch in range(epochs): + ddp_model.train() + train_sampler.set_epoch(epoch) + + running_loss = 0.0 + count_seen = 0 + + for batch_C, batch_X in train_loader: + batch_C = batch_C.to(device, non_blocking=True) + batch_X = batch_X.to(device, non_blocking=True) + + optimizer.zero_grad() + preds = ddp_model(batch_C) + loss = criterion(preds, batch_X) + loss.backward() + optimizer.step() + + bsz = batch_C.size(0) + running_loss += loss.item() * bsz + count_seen += bsz + + # Aggregate epoch loss to rank 0 (average over all samples) + loss_tensor = torch.tensor([running_loss, count_seen], 
dtype=torch.float64, device=device) + dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM) + if rank == 0: + total_loss, total_count = loss_tensor.tolist() + epoch_loss = total_loss / max(total_count, 1.0) + print(f"[{label}] Epoch {epoch+1}/{epochs} - train MSE {epoch_loss:.6f}") + + # Synchronize end of training + dist.barrier() + if torch.cuda.is_available(): + torch.cuda.synchronize() + wall = time.time() - t0 + + # Only rank 0 computes evaluation, using full dataset on its GPU + if rank == 0: + # For evaluation we use the underlying model (not wrapped in DDP) + eval_model = ddp_model.module + eval_model.eval() + + test_loader = DataLoader( + test_ds, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available(), + ) + train_loader_full = DataLoader( + train_ds, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + pin_memory=torch.cuda.is_available(), + ) + + def eval_mse(loader, split_name: str) -> float: + total_loss = 0.0 + count = 0 + with torch.no_grad(): + for batch_C, batch_X in loader: + batch_C = batch_C.to(device, non_blocking=True) + batch_X = batch_X.to(device, non_blocking=True) + preds = eval_model(batch_C) + loss = criterion(preds, batch_X) + bsz = batch_C.size(0) + total_loss += loss.item() * bsz + count += bsz + mse = total_loss / max(count, 1) + print(f"[{label}] {split_name} MSE {mse:.6f}") + return mse + + train_mse = eval_mse(train_loader_full, "train") + test_mse = eval_mse(test_loader, "test") + + samples_total = n_samples * epochs + throughput = samples_total / max(wall, 1e-9) + + print(f"\n[{label}] run complete") + print(f" wall time (s): {wall:.2f}") + print(f" total samples: {samples_total}") + print(f" throughput (samples/s): {throughput:.2f}") + print(f" final train MSE: {train_mse:.6f}") + print(f" final test MSE: {test_mse:.6f}") + + result_dict["2gpu_ddp"] = BenchResult( + label=label, + wall_seconds=wall, + samples_total=samples_total, + throughput_sps=throughput, + 
train_mse_mean=train_mse, + test_mse_mean=test_mse, + ) + + # Tear down process group + dist.destroy_process_group() + + +# ------------------- CSV writer ------------------- + +def save_results_csv(results: List[BenchResult], outdir: str): + os.makedirs(outdir, exist_ok=True) + path = os.path.join(outdir, "scale_results_unseen_ddp.csv") + with open(path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow( + [ + "label", + "wall_seconds", + "samples_total", + "throughput_samples_per_s", + "train_mse_mean", + "test_mse_mean", + ] + ) + for r in results: + writer.writerow( + [ + r.label, + f"{r.wall_seconds:.6f}", + r.samples_total, + f"{r.throughput_sps:.6f}", + f"{r.train_mse_mean:.6f}", + f"{r.test_mse_mean:.6f}", + ] + ) + print(f"\nSaved CSV → {path}") + + +# ------------------- CLI & main ------------------- + +def parse_args(): + import argparse + + ap = argparse.ArgumentParser() + ap.add_argument("--epochs", type=int, default=3) + ap.add_argument("--batch-size", type=int, default=256) + ap.add_argument( + "--num-workers", + type=int, + default=0, + help="DataLoader workers (0 is safest on HPC).", + ) + ap.add_argument( + "--subsample-fraction", + type=float, + default=None, + help="Optional fraction of rows to subsample for quick tests", + ) + ap.add_argument( + "--outdir", + type=str, + default="bench_out_unseen", + ) + ap.add_argument( + "--ddp-port", + type=str, + default="29611", + help="TCP port for DDP init_method (tcp://127.0.0.1:PORT).", + ) + return ap.parse_args() + + +def main(): + args = parse_args() + mp.set_start_method("spawn", force=True) + set_env_defaults() + + results: List[BenchResult] = [] + + # 1-GPU baseline + res_1gpu = run_single_gpu( + epochs=args.epochs, + batch_size=args.batch_size, + num_workers=args.num_workers, + subsample_fraction=args.subsample_fraction, + ) + results.append(res_1gpu) + + # 2-GPU DDP baseline + if torch.cuda.is_available() and torch.cuda.device_count() >= 2: + world_size = 2 + port = 
args.ddp_port # use TCP init on localhost + + manager = mp.Manager() + result_dict = manager.dict() + + mp.spawn( + ddp_worker, + args=( + world_size, + port, + args.epochs, + args.batch_size, + args.num_workers, + args.subsample_fraction, + result_dict, + ), + nprocs=world_size, + join=True, + ) + + if "2gpu_ddp" in result_dict: + results.append(result_dict["2gpu_ddp"]) + else: + print("\n[WARN] DDP finished but no result in result_dict['2gpu_ddp'].") + else: + print("\n[Info] < 2 GPUs visible; skipping 2-GPU DDP benchmark.") + + save_results_csv(results, args.outdir) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scale_bench.py b/regression_scale_bench.py similarity index 100% rename from scale_bench.py rename to regression_scale_bench.py diff --git a/scripts/test_contextualized_dm.py b/scripts/test_contextualized_dm.py deleted file mode 100644 index 7a90514e..00000000 --- a/scripts/test_contextualized_dm.py +++ /dev/null @@ -1,252 +0,0 @@ -# scripts/test_contextualized_dm.py -""" -Smoke-test your ContextualizedRegressionDataModule with synthetic data. 
- -Examples: - # Single-process sanity check - python scripts/test_contextualized_dm.py --task-type singletask_multivariate --peek - - # CPU DDP on Windows (Git Bash or PowerShell) - python scripts/test_contextualized_dm.py --task-type singletask_multivariate --devices 2 --peek -""" - -from __future__ import annotations -import argparse -import os -import sys -import tempfile -from pathlib import Path -from typing import Dict, Optional, Tuple - -import torch -from torch import nn -import lightning as pl -from lightning.pytorch.strategies import DDPStrategy - -# --- Make repo root importable if running from source tree --- -REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -if REPO_ROOT not in sys.path: - sys.path.insert(0, REPO_ROOT) - -from contextualized.regression.datamodules import ContextualizedRegressionDataModule - -# ---- Candidate key names in your batch dict ---- -CTX_CANDIDATES = ("contexts", "context", "ctx", "C", "c") -X_CANDIDATES = ("predictors", "X", "features", "x", "inputs", "data") -Y_CANDIDATES = ("outcomes", "Y", "targets", "y", "labels") - - -def pick_first_key(d: Dict[str, torch.Tensor], candidates) -> Optional[str]: - for k in candidates: - if k in d: - return k - return None - - -# --------------------------- -# Synthetic (C, X, Y) -# --------------------------- -def make_synthetic( - n: int, - c_dim: int, - x_dim: int, - y_dim: int, - seed: int = 1234, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - g = torch.Generator().manual_seed(seed) - C = torch.randn(n, c_dim, generator=g) - X = torch.randn(n, x_dim, generator=g) - W = torch.randn(x_dim, y_dim, generator=g) / (x_dim ** 0.5) - Y = X @ W + 0.05 * torch.randn(n, y_dim, generator=g) - return C, X, Y - - -def make_indices(n: int, train_frac=0.7, val_frac=0.15) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - idx = torch.randperm(n) - n_train = int(n * train_frac) - n_val = int(n * val_frac) - train_idx = idx[:n_train] - val_idx = idx[n_train:n_train 
+ n_val] - test_idx = idx[n_train + n_val:] - return train_idx, val_idx, test_idx - - -# --------------------------- -# Tiny adaptive model -# --------------------------- -class AdaptiveTinyModel(pl.LightningModule): - """ - - If batch has (features, targets): Linear -> MSE - - Else if batch has "contexts": mean(contexts**2) - - Else: mean of first float tensor - Holds an anchor param so the optimizer is never empty. - """ - def __init__(self, x_dim: Optional[int] = None, y_dim: Optional[int] = None, lr: float = 1e-2): - super().__init__() - self.lr = lr - self.mse = nn.MSELoss() - self._anchor = nn.Parameter(torch.tensor(0.0)) - self.head = nn.Linear(x_dim, y_dim) if (x_dim is not None and y_dim is not None) else None - - def _compute_loss(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - x_key = pick_first_key(batch, X_CANDIDATES) - y_key = pick_first_key(batch, Y_CANDIDATES) - - if x_key and y_key: - x = batch[x_key].float() - y = batch[y_key].float() - if x.ndim == 3: # (B, T, D) - B, T, D = x.shape - x = x.view(B * T, D) - y = y.view(B * T, -1) - if self.head is None: - self.head = nn.Linear(x.shape[-1], y.shape[-1]).to(self.device) - preds = self.head(x) - return self.mse(preds, y) - - c_key = pick_first_key(batch, CTX_CANDIDATES) - if c_key: - c = batch[c_key].float() - return (c ** 2).mean() - - for k, v in batch.items(): - if torch.is_tensor(v) and v.dtype.is_floating_point: - return (v.float() ** 2).mean() - - raise RuntimeError("No usable tensor found in batch to compute a loss.") - - def training_step(self, batch, batch_idx): - loss = self._compute_loss(batch) - self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True) - return loss - - def validation_step(self, batch, batch_idx): - loss = self._compute_loss(batch) - self.log("val_loss", loss, on_epoch=True, prog_bar=True) - - def configure_optimizers(self): - return torch.optim.SGD(self.parameters(), lr=self.lr) - - -# --------------------------- -# CLI / Trainer -# 
--------------------------- -def parse_args(): - p = argparse.ArgumentParser(description="Test ContextualizedRegressionDataModule") - p.add_argument("--task-type", - choices=[ - "singletask_multivariate", - "singletask_univariate", - "multitask_multivariate", - "multitask_univariate", - ], - required=True) - p.add_argument("--n", type=int, default=256, help="Total samples") - p.add_argument("--c-dim", type=int, default=8, help="Context dim") - p.add_argument("--x-dim", type=int, default=16, help="Feature dim") - p.add_argument("--y-dim", type=int, default=4, help="Target dim") - p.add_argument("--batch-size", type=int, default=32) - p.add_argument("--num-workers", type=int, default=0) - p.add_argument("--devices", type=int, default=1) - p.add_argument("--max-epochs", type=int, default=1) - p.add_argument("--limit-train-batches", type=float, default=2) - p.add_argument("--limit-val-batches", type=float, default=1) - p.add_argument("--peek", action="store_true", help="Print first batch keys/shapes") - return p.parse_args() - - -def _unset_dist_env(): - # Ensure env:// rendezvous is NOT selected - for k in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK", "LOCAL_RANK", "INIT_METHOD"): - if k in os.environ: - os.environ.pop(k) - - -def build_trainer(args) -> pl.Trainer: - if args.devices > 1: - # Force local file-store DDP init (no sockets/ports) - init_path = Path(tempfile.gettempdir()) / f"pl_init_{os.getpid()}.pt" - init_uri = init_path.as_uri() # proper file:///C:/... 
on Windows - strategy = DDPStrategy( - process_group_backend="gloo", - init_method=init_uri, - ) - else: - strategy = "auto" - - return pl.Trainer( - accelerator="cpu", - devices=args.devices, - strategy=strategy, - max_epochs=args.max_epochs, - limit_train_batches=args.limit_train_batches, - limit_val_batches=args.limit_val_batches, - enable_progress_bar=True, - logger=False, - ) - - -def main(): - # Windows-safe start method for spawn/DDP - try: - import torch.multiprocessing as mp - mp.set_start_method("spawn", force=True) - except RuntimeError: - pass - - _unset_dist_env() - args = parse_args() - - # --- synthetic data + splits --- - C, X, Y = make_synthetic(n=args.n, c_dim=args.c_dim, x_dim=args.x_dim, y_dim=args.y_dim) - train_idx, val_idx, test_idx = make_indices(args.n) - - # --- your datamodule --- - dm = ContextualizedRegressionDataModule( - C=C, X=X, Y=Y, - task_type=args.task_type, - train_idx=train_idx, - val_idx=val_idx, - test_idx=test_idx, - predict_idx=None, - batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=False, # CPU run - persistent_workers=False, # safe with num_workers=0 - drop_last=False, - shuffle_train=True, - shuffle_eval=False, - dtype=torch.float, - ) - - # Setup and peek one batch to infer dims so model has parameters before optimizer init - dm.setup("fit") - sample = next(iter(dm.train_dataloader())) - if args.peek: - print("[peek] batch keys:", list(sample.keys())) - for k, v in sample.items(): - if torch.is_tensor(v): - print(f"[peek] {k}: shape={tuple(v.shape)} dtype={v.dtype}") - print() - - # Infer x_dim/y_dim from batch (handles (B,T,D)) - x_key = pick_first_key(sample, X_CANDIDATES) - y_key = pick_first_key(sample, Y_CANDIDATES) - x_dim = y_dim = None - if x_key and y_key: - x = sample[x_key] - y = sample[y_key] - x_dim = x.shape[-1] - y_dim = y.shape[-1] - - # Build model (now has params) - model = AdaptiveTinyModel(x_dim=x_dim, y_dim=y_dim) - - # Trainer - trainer = build_trainer(args) - 
trainer.fit(model, dm) - print("✅ Test completed successfully.") - - -if __name__ == "__main__": - main() From 661e41827d1de3c4ea6db2e1855acd12875d9ad5 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Mon, 22 Dec 2025 16:26:38 -0500 Subject: [PATCH 13/19] fixed errors with lightning_modules params --- 01_regressor_cpu_single.py | 92 +++ 02_networks_cpu_single.py | 49 ++ bench_scale_contextualized_regression.py | 766 ++++++++++++++++++ .../easy/ContextualizedClassifier.py | 32 +- contextualized/easy/ContextualizedNetworks.py | 197 +++-- .../easy/wrappers/SKLearnWrapper.py | 354 ++++++-- contextualized/regression/datamodules.py | 41 +- contextualized/regression/datasets.py | 132 +-- .../regression/lightning_modules.py | 220 ++--- contextualized/regression/trainers.py | 297 +++++-- 10 files changed, 1777 insertions(+), 403 deletions(-) create mode 100644 01_regressor_cpu_single.py create mode 100644 02_networks_cpu_single.py create mode 100644 bench_scale_contextualized_regression.py diff --git a/01_regressor_cpu_single.py b/01_regressor_cpu_single.py new file mode 100644 index 00000000..3b902840 --- /dev/null +++ b/01_regressor_cpu_single.py @@ -0,0 +1,92 @@ +import numpy as np + +from contextualized.easy import ContextualizedRegressor + +def main(): + np.random.seed(0) + + n = 96 + c_dim = 4 + x_dim = 6 + y_dim = 2 + + C = np.random.randn(n, c_dim).astype(np.float32) + X = np.random.randn(n, x_dim).astype(np.float32) + + # Construct a learnable signal: Y = X @ W + noise + W = np.array([[1.5, -0.5], + [0.7, 0.2], + [0.0, 0.0], + [0.3, -1.0], + [0.0, 0.0], + [0.2, 0.1]], dtype=np.float32) # (x_dim, y_dim) + + Y = (X @ W + 0.05 * np.random.randn(n, y_dim).astype(np.float32)).astype(np.float32) + + model = ContextualizedRegressor( + metamodel_type="subtype", + num_archetypes=4, + univariate=False, + ) + + # CPU-only fit + model.fit( + C=C, X=X, Y=Y, + accelerator="cpu", + devices=1, + strategy="auto", + max_epochs=3, + val_split=0.2, + num_workers=0, + 
enable_progress_bar=False, + logger=False, + ) + + yhat = model.predict(C, X) + betas, mus = model.predict_params(C) + + # --- shape sanity --- + yhat_arr = np.asarray(yhat) + betas_arr = np.asarray(betas) + mus_arr = np.asarray(mus) + + print("SINGLE CPU REGRESSOR") + print("yhat.shape:", yhat_arr.shape) + print("betas.shape:", betas_arr.shape) + print("mus.shape:", mus_arr.shape) + + # Expected conventions (based on your current implementation) + assert yhat_arr.shape[0] == n, "yhat first dim should be n" + assert betas_arr.shape[0] == n, "betas first dim should be n" + assert mus_arr.shape[0] == n, "mus first dim should be n" + + # yhat is typically (n, y_dim, 1) for multivariate in your code path + # betas is (n, y_dim, x_dim) + assert betas_arr.shape[1] == y_dim and betas_arr.shape[2] == x_dim, "betas expected (n, y_dim, x_dim)" + + # --- quick quality check: MSE vs baseline mean predictor --- + # squeeze last dim if present + yhat_s = yhat_arr[..., 0] if (yhat_arr.ndim == 3 and yhat_arr.shape[-1] == 1) else yhat_arr + y_true = Y + + mse = np.mean((yhat_s - y_true) ** 2) + baseline = np.mean((np.mean(y_true, axis=0, keepdims=True) - y_true) ** 2) + + print("MSE:", float(mse)) + print("Baseline MSE (mean predictor):", float(baseline)) + assert np.isfinite(mse), "MSE must be finite" + assert mse < baseline, "Model should beat baseline mean predictor on this synthetic signal" + + # --- ordering check (this is critical for your gather/sort design) --- + perm = np.random.permutation(n) + yhat_perm = np.asarray(model.predict(C[perm], X[perm])) + yhat_perm_s = yhat_perm[..., 0] if (yhat_perm.ndim == 3 and yhat_perm.shape[-1] == 1) else yhat_perm + + max_err = np.max(np.abs(yhat_perm_s - yhat_s[perm])) + print("Ordering check max_err:", float(max_err)) + assert max_err < 1e-5, "Prediction order is not stable under permutation" + + print("PASS: single-process CPU regressor tests") + +if __name__ == "__main__": + main() diff --git a/02_networks_cpu_single.py 
b/02_networks_cpu_single.py new file mode 100644 index 00000000..5999563c --- /dev/null +++ b/02_networks_cpu_single.py @@ -0,0 +1,49 @@ +import numpy as np + +from contextualized.easy import ContextualizedCorrelationNetworks # adjust if your import path differs + +def main(): + np.random.seed(0) + + n = 80 + c_dim = 3 + x_dim = 5 + + C = np.random.randn(n, c_dim).astype(np.float32) + X = np.random.randn(n, x_dim).astype(np.float32) + + net = ContextualizedCorrelationNetworks( + metamodel_type="subtype", + num_archetypes=4, + ) + + net.fit( + C=C, X=X, + accelerator="cpu", + devices=1, + strategy="auto", + max_epochs=2, + val_split=0.2, + num_workers=0, + enable_progress_bar=False, + logger=False, + ) + + rhos2 = net.predict_correlation(C, individual_preds=False, squared=True) + rhos2 = np.asarray(rhos2) + + print("SINGLE CPU CORRELATION NETWORKS") + print("rhos2.shape:", rhos2.shape) + + assert rhos2.shape[0] == n and rhos2.shape[1] == x_dim and rhos2.shape[2] == x_dim, \ + "Expected (n, x_dim, x_dim)" + assert np.all(np.isfinite(rhos2)), "Correlations must be finite" + + # Symmetry sanity (should be symmetric-ish) + sym_err = np.max(np.abs(rhos2 - np.transpose(rhos2, (0, 2, 1)))) + print("Symmetry max_err:", float(sym_err)) + + print("PASS: single-process CPU networks tests") + +if __name__ == "__main__": + main() diff --git a/bench_scale_contextualized_regression.py b/bench_scale_contextualized_regression.py new file mode 100644 index 00000000..65b8e6c1 --- /dev/null +++ b/bench_scale_contextualized_regression.py @@ -0,0 +1,766 @@ +#!/usr/bin/env python3 +""" +bench_scale_contextualized_regression.py + +Synthetic scaling benchmark for Contextualized regression workflow. 
+ +Modes: + - run : run a single config (supports torch distributed launch) + - sweep : run CPU + GPU(1..K) sequentially (spawns torch distributed runs) and plot + +Examples (Lambda 4x GPU single node): + # Full sweep: CPU + 1/2/3/4 GPU and plots + python bench_scale_contextualized_regression.py sweep \ + --include_cpu \ + --max_gpus 4 \ + --n 200000 \ + --c_dim 16 --x_dim 64 --y_dim 8 \ + --epochs 5 \ + --train_batch_size 2048 --val_batch_size 2048 --test_batch_size 4096 \ + --num_workers 4 \ + --out_dir ./scale_runs/run1 + + # Single run on 1 GPU (no torchrun needed for 1 device) + python bench_scale_contextualized_regression.py run --accelerator gpu --devices 1 --out_dir ./one_gpu + + # Single run on 4 GPUs using torch distributed launcher + python -m torch.distributed.run --standalone --nproc_per_node=4 \ + bench_scale_contextualized_regression.py run --accelerator gpu --devices 4 --out_dir ./four_gpu +""" + +from __future__ import annotations + +import argparse +import csv +import json +import os +import platform +import re +import shutil +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import numpy as np + + +# ------------------------- +# Utilities +# ------------------------- + +def _now_ts() -> str: + return time.strftime("%Y%m%d_%H%M%S") + + +def _maybe_git_commit() -> Optional[str]: + try: + if not shutil.which("git"): + return None + out = subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL).decode().strip() + if re.fullmatch(r"[0-9a-f]{40}", out): + return out + except Exception: + pass + return None + + +def _rank_world() -> Tuple[int, int, int]: + """(rank, world_size, local_rank) for torchrun-style environments.""" + rank = int(os.environ.get("RANK", "0")) + world = int(os.environ.get("WORLD_SIZE", "1")) + local = int(os.environ.get("LOCAL_RANK", "0")) + return rank, world, local + + +def _set_seed(seed: int) 
-> None: + np.random.seed(seed) + try: + import torch + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + except Exception: + pass + + +def _cuda_sync_if_available() -> None: + try: + import torch + if torch.cuda.is_available(): + torch.cuda.synchronize() + except Exception: + pass + + +def _safe_float(x: Any) -> float: + try: + return float(x) + except Exception: + return float("nan") + + +def _ensure_dir(p: Path) -> None: + p.mkdir(parents=True, exist_ok=True) + + +# ------------------------- +# Synthetic data generator +# ------------------------- + +def make_synth_contextual_regression( + n: int, + c_dim: int, + x_dim: int, + y_dim: int, + noise: float, + seed: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[str, np.ndarray]]: + """ + Create synthetic data where coefficients beta(C) vary with context C: + beta_flat = C @ W^T + b + beta = reshape(beta_flat, y_dim, x_dim) + mu = C @ V + y = sum_j beta[..., j]*x_j + mu + eps + """ + rng = np.random.default_rng(seed) + + C = rng.normal(size=(n, c_dim)).astype(np.float32) + X = rng.normal(size=(n, x_dim)).astype(np.float32) + + # Context -> beta mapping + W = (rng.normal(size=(y_dim * x_dim, c_dim)).astype(np.float32) / np.sqrt(c_dim)).astype(np.float32) + b = (0.1 * rng.normal(size=(y_dim * x_dim,))).astype(np.float32) + + beta_flat = C @ W.T + b[None, :] # (n, y_dim*x_dim) + beta = beta_flat.reshape(n, y_dim, x_dim) # (n, y_dim, x_dim) + + # Context -> intercept + V = (0.1 * rng.normal(size=(c_dim, y_dim))).astype(np.float32) + mu = (C @ V).astype(np.float32) # (n, y_dim) + + y = (beta * X[:, None, :]).sum(axis=-1) + mu + y = y + noise * rng.normal(size=(n, y_dim)).astype(np.float32) + Y = y.astype(np.float32) + + truth = {"beta": beta, "mu": mu} + return C, X, Y, truth + + +def make_splits(n: int, val_frac: float, test_frac: float, seed: int) -> Dict[str, np.ndarray]: + assert 0.0 <= val_frac < 1.0 + assert 0.0 <= test_frac < 1.0 + assert val_frac + test_frac < 1.0 + + rng = 
np.random.default_rng(seed) + idx = np.arange(n, dtype=np.int64) + rng.shuffle(idx) + + n_test = int(round(n * test_frac)) + n_val = int(round(n * val_frac)) + n_train = n - n_val - n_test + + train_idx = idx[:n_train] + val_idx = idx[n_train:n_train + n_val] + test_idx = idx[n_train + n_val:] + + return {"train_idx": train_idx, "val_idx": val_idx, "test_idx": test_idx} + + +# ------------------------- +# Result schema +# ------------------------- + +@dataclass +class BenchResult: + tag: str + accelerator: str + devices: int + backend: str + n: int + n_train: int + n_val: int + n_test: int + c_dim: int + x_dim: int + y_dim: int + epochs: int + train_batch_size: int + val_batch_size: int + test_batch_size: int + num_workers: int + seed: int + + fit_time_s: float + predict_time_s: float + total_time_s: float + + train_throughput_sps: float # samples/sec (global, unique samples per epoch) + predict_throughput_sps: float + + test_mse: float + + hostname: str + python: str + platform: str + torch: Optional[str] + lightning: Optional[str] + git_commit: Optional[str] + + +# ------------------------- +# Core runner +# ------------------------- + +def run_one(args: argparse.Namespace) -> Optional[BenchResult]: + rank, world, local_rank = _rank_world() + + # Make output directory (all ranks see it; only rank0 writes result). + out_dir = Path(args.out_dir).resolve() + _ensure_dir(out_dir) + + # Deferred imports so CPU-only environments don't choke on CUDA imports early. + try: + import torch + import pytorch_lightning as pl + except Exception as e: + if rank == 0: + raise RuntimeError( + "Failed to import torch / pytorch_lightning. Ensure your env has them installed." 
+ ) from e + return None + + # Set device for GPU + if args.accelerator == "gpu": + if not torch.cuda.is_available(): + if rank == 0: + raise RuntimeError("Requested accelerator=gpu but torch.cuda.is_available() is False.") + return None + torch.cuda.set_device(local_rank) + + # Determinism (best-effort) + _set_seed(args.seed) + try: + pl.seed_everything(args.seed, workers=True) + except Exception: + pass + + # Synthesize data (replicated across ranks; ok for benchmarking) + C, X, Y, _truth = make_synth_contextual_regression( + n=args.n, + c_dim=args.c_dim, + x_dim=args.x_dim, + y_dim=args.y_dim, + noise=args.noise, + seed=args.seed, + ) + + splits = make_splits(args.n, args.val_frac, args.test_frac, args.seed) + train_idx = splits["train_idx"] + val_idx = splits["val_idx"] + test_idx = splits["test_idx"] + + # Build model using your regression workflow + # We prefer the easy wrapper if available, as it exercises your end-to-end stack. + try: + from contextualized.easy import ContextualizedRegressor + except Exception as e: + if rank == 0: + raise RuntimeError( + "Could not import contextualized.easy.ContextualizedRegressor. " + "Verify your package is importable from this environment." + ) from e + return None + + # Strategy configuration (DDP for multi-GPU) + strategy_obj = None + if args.devices > 1: + # Use Lightning DDPStrategy explicitly to control backend (nccl/gloo) + try: + from pytorch_lightning.strategies import DDPStrategy + strategy_obj = DDPStrategy( + process_group_backend=args.backend, + find_unused_parameters=False, + ) + except Exception: + strategy_obj = "ddp" # fallback: let Lightning decide + + # Create the regressor (robust to wrapper signature differences) + # We attempt common kwargs; if wrapper rejects, we show a clear error on rank0. 
+ model_kwargs: Dict[str, Any] = dict( + num_archetypes=args.num_archetypes, + encoder_type=args.encoder_type, + max_epochs=args.epochs, + learning_rate=args.learning_rate, + # data / loader knobs (if wrapper exposes them) + train_batch_size=args.train_batch_size, + val_batch_size=args.val_batch_size, + test_batch_size=args.test_batch_size, + val_split=args.val_frac, + # trainer knobs + accelerator=("gpu" if args.accelerator == "gpu" else "cpu"), + devices=args.devices, + num_workers=args.num_workers, + deterministic=args.deterministic, + enable_checkpointing=False, + logger=False, + enable_progress_bar=False, + ) + + # Some wrappers may not accept the above keys; strip unsupported keys dynamically. + def instantiate_contextualized_regressor() -> Any: + import inspect + sig = inspect.signature(ContextualizedRegressor.__init__) + accepted = set(sig.parameters.keys()) + # Always remove 'self' + accepted.discard("self") + filt = {k: v for k, v in model_kwargs.items() if k in accepted} + # Strategy: some wrappers accept "strategy" directly + if "strategy" in accepted and strategy_obj is not None: + filt["strategy"] = strategy_obj + return ContextualizedRegressor(**filt) + + try: + reg = instantiate_contextualized_regressor() + except TypeError as e: + if rank == 0: + raise RuntimeError( + "Failed to instantiate ContextualizedRegressor with inferred kwargs.\n" + "This usually means the wrapper signature differs from what this benchmark expects.\n" + "Action: open contextualized/easy/wrappers.py and confirm which Trainer/loader args are supported.\n" + f"Original error: {e}" + ) + return None + + # Fit timing + _cuda_sync_if_available() + t0 = time.perf_counter() + + # Prefer explicit indices to exercise your stable-index paths if supported + fit_kwargs: Dict[str, Any] = dict(C=C, X=X, Y=Y) + try: + import inspect + fit_sig = inspect.signature(reg.fit) + if "train_idx" in fit_sig.parameters: + fit_kwargs["train_idx"] = train_idx + if "val_idx" in fit_sig.parameters: + 
fit_kwargs["val_idx"] = val_idx + if "test_idx" in fit_sig.parameters: + fit_kwargs["test_idx"] = test_idx + except Exception: + pass + + reg.fit(**fit_kwargs) + + _cuda_sync_if_available() + fit_time = time.perf_counter() - t0 + + # Predict timing (prefer predict_idx if supported) + _cuda_sync_if_available() + t1 = time.perf_counter() + + yhat = None + pred_kwargs_full: Dict[str, Any] = dict(C=C, X=X) + try: + import inspect + pred_sig = inspect.signature(reg.predict) + if "predict_idx" in pred_sig.parameters: + pred_kwargs_full["predict_idx"] = test_idx + yhat = reg.predict(**pred_kwargs_full) + else: + # fallback: feed the subset directly + yhat = reg.predict(C[test_idx], X[test_idx]) + except Exception: + # fallback: feed the subset directly + yhat = reg.predict(C[test_idx], X[test_idx]) + + _cuda_sync_if_available() + pred_time = time.perf_counter() - t1 + + # Only rank0 should compute/report metrics if wrapper returns None on non-rank0 + if yhat is None: + return None + + # Convert prediction to numpy + if hasattr(yhat, "detach"): + yhat_np = yhat.detach().cpu().numpy() + else: + yhat_np = np.asarray(yhat) + + y_true = Y[test_idx] + test_mse = float(np.mean((yhat_np - y_true) ** 2)) + + total_time = fit_time + pred_time + + # Throughput (global unique samples per epoch) + n_train = int(train_idx.shape[0]) + n_val = int(val_idx.shape[0]) + n_test = int(test_idx.shape[0]) + + train_throughput = (n_train * args.epochs) / max(fit_time, 1e-9) + pred_throughput = (n_test) / max(pred_time, 1e-9) + + # Version info + torch_ver = getattr(torch, "__version__", None) + lightning_ver = getattr(pl, "__version__", None) + + tag = args.tag or f"{args.accelerator}_{args.devices}dev" + result = BenchResult( + tag=tag, + accelerator=args.accelerator, + devices=args.devices, + backend=args.backend, + + n=args.n, + n_train=n_train, + n_val=n_val, + n_test=n_test, + c_dim=args.c_dim, + x_dim=args.x_dim, + y_dim=args.y_dim, + epochs=args.epochs, + 
train_batch_size=args.train_batch_size, + val_batch_size=args.val_batch_size, + test_batch_size=args.test_batch_size, + num_workers=args.num_workers, + seed=args.seed, + + fit_time_s=_safe_float(fit_time), + predict_time_s=_safe_float(pred_time), + total_time_s=_safe_float(total_time), + + train_throughput_sps=_safe_float(train_throughput), + predict_throughput_sps=_safe_float(pred_throughput), + + test_mse=_safe_float(test_mse), + + hostname=platform.node(), + python=sys.version.replace("\n", " "), + platform=f"{platform.system()} {platform.release()} ({platform.machine()})", + torch=torch_ver, + lightning=lightning_ver, + git_commit=_maybe_git_commit(), + ) + + # Write per-run JSON on rank0 only + if rank == 0: + out_json = out_dir / f"result_{tag}.json" + with out_json.open("w") as f: + json.dump(asdict(result), f, indent=2) + print(f"[rank0] Wrote: {out_json}") + print( + f"[rank0] fit={result.fit_time_s:.3f}s " + f"pred={result.predict_time_s:.3f}s " + f"total={result.total_time_s:.3f}s " + f"train_thr={result.train_throughput_sps:.1f} samp/s " + f"pred_thr={result.predict_throughput_sps:.1f} samp/s " + f"test_mse={result.test_mse:.6f}" + ) + + return result + + +# ------------------------- +# Sweep + plotting +# ------------------------- + +def _load_results(out_dir: Path) -> Dict[str, Dict[str, Any]]: + results: Dict[str, Dict[str, Any]] = {} + for p in sorted(out_dir.glob("result_*.json")): + with p.open("r") as f: + d = json.load(f) + results[d["tag"]] = d + return results + + +def _write_csv(out_dir: Path, rows: Dict[str, Dict[str, Any]]) -> Path: + out_csv = out_dir / "results.csv" + keys = sorted(next(iter(rows.values())).keys()) + with out_csv.open("w", newline="") as f: + w = csv.DictWriter(f, fieldnames=keys) + w.writeheader() + for _tag, d in sorted(rows.items(), key=lambda kv: (kv[1]["accelerator"], kv[1]["devices"])): + w.writerow({k: d.get(k, None) for k in keys}) + return out_csv + + +def plot_results(out_dir: Path, baseline_devices: int = 1, 
include_cpu_speedup: bool = True) -> None: + # Use non-interactive backend for headless servers + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + rows = _load_results(out_dir) + if not rows: + raise RuntimeError(f"No result_*.json found in {out_dir}") + + # Split CPU vs GPU + cpu = [d for d in rows.values() if d["accelerator"] == "cpu"] + gpu = [d for d in rows.values() if d["accelerator"] == "gpu"] + + gpu_sorted = sorted(gpu, key=lambda d: d["devices"]) + cpu_sorted = sorted(cpu, key=lambda d: d["devices"]) + + # Baseline for speedup/efficiency + base_gpu = next((d for d in gpu_sorted if d["devices"] == baseline_devices), None) + if base_gpu is None and gpu_sorted: + base_gpu = gpu_sorted[0] + + # Helper series + def series(ds, key): + return [float(d[key]) for d in ds] + + # Plot 1: wall time (fit/predict/total) vs devices (GPU) + if gpu_sorted: + x = [d["devices"] for d in gpu_sorted] + + fit_t = series(gpu_sorted, "fit_time_s") + pred_t = series(gpu_sorted, "predict_time_s") + tot_t = series(gpu_sorted, "total_time_s") + + plt.figure(figsize=(8, 5)) + plt.plot(x, fit_t, marker="o", label="fit_time_s") + plt.plot(x, pred_t, marker="o", label="predict_time_s") + plt.plot(x, tot_t, marker="o", label="total_time_s") + plt.xlabel("GPUs (devices)") + plt.ylabel("Seconds") + plt.title("Wall time vs GPUs") + plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6) + plt.legend() + p = out_dir / "wall_time_vs_gpus.png" + plt.tight_layout() + plt.savefig(p, dpi=200) + plt.close() + + # Plot 2: throughput vs devices (GPU) + if gpu_sorted: + x = [d["devices"] for d in gpu_sorted] + thr = series(gpu_sorted, "train_throughput_sps") + + plt.figure(figsize=(8, 5)) + plt.plot(x, thr, marker="o") + plt.xlabel("GPUs (devices)") + plt.ylabel("Train throughput (samples/sec, global)") + plt.title("Train throughput vs GPUs") + plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6) + p = out_dir / "throughput_vs_gpus.png" + plt.tight_layout() + 
plt.savefig(p, dpi=200) + plt.close() + + # Plot 3: speedup + efficiency vs devices (GPU) + if gpu_sorted and base_gpu is not None: + x = [d["devices"] for d in gpu_sorted] + base_thr = float(base_gpu["train_throughput_sps"]) + speedup = [float(d["train_throughput_sps"]) / max(base_thr, 1e-9) for d in gpu_sorted] + efficiency = [s / max(dev, 1e-9) for s, dev in zip(speedup, x)] + + plt.figure(figsize=(8, 5)) + plt.plot(x, speedup, marker="o", label=f"Speedup vs {base_gpu['devices']} GPU") + plt.plot(x, x, linestyle="--", label="Ideal linear speedup") + plt.xlabel("GPUs (devices)") + plt.ylabel("Speedup") + plt.title("Speedup vs GPUs") + plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6) + plt.legend() + p = out_dir / "speedup_vs_gpus.png" + plt.tight_layout() + plt.savefig(p, dpi=200) + plt.close() + + plt.figure(figsize=(8, 5)) + plt.plot(x, efficiency, marker="o") + plt.xlabel("GPUs (devices)") + plt.ylabel("Scaling efficiency (speedup / GPUs)") + plt.title("Scaling efficiency vs GPUs") + plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6) + p = out_dir / "efficiency_vs_gpus.png" + plt.tight_layout() + plt.savefig(p, dpi=200) + plt.close() + + # Optional: CPU vs best GPU throughput comparison + if include_cpu_speedup and cpu_sorted and gpu_sorted: + cpu_thr = float(cpu_sorted[0]["train_throughput_sps"]) + best_gpu = max(gpu_sorted, key=lambda d: float(d["train_throughput_sps"])) + best_thr = float(best_gpu["train_throughput_sps"]) + ratio = best_thr / max(cpu_thr, 1e-9) + + plt.figure(figsize=(8, 5)) + labels = ["CPU (1)", f"GPU ({best_gpu['devices']})"] + vals = [cpu_thr, best_thr] + plt.bar(labels, vals) + plt.ylabel("Train throughput (samples/sec, global)") + plt.title(f"CPU vs best GPU throughput (GPU/CPU = {ratio:.2f}x)") + plt.grid(True, axis="y", linestyle="--", linewidth=0.6, alpha=0.6) + p = out_dir / "cpu_vs_best_gpu_throughput.png" + plt.tight_layout() + plt.savefig(p, dpi=200) + plt.close() + + # Write CSV for convenience + out_csv = 
_write_csv(out_dir, rows) + print(f"Wrote plots + {out_csv}") + + +def sweep(args: argparse.Namespace) -> None: + out_dir = Path(args.out_dir).resolve() + _ensure_dir(out_dir) + + # Write run config + run_cfg = vars(args).copy() + run_cfg["timestamp"] = _now_ts() + run_cfg["git_commit"] = _maybe_git_commit() + with (out_dir / "run_config.json").open("w") as f: + json.dump(run_cfg, f, indent=2) + + # Build base run args (forwarded to run subcommand) + base_run = [ + sys.executable, + str(Path(__file__).resolve()), + "run", + "--n", str(args.n), + "--c_dim", str(args.c_dim), + "--x_dim", str(args.x_dim), + "--y_dim", str(args.y_dim), + "--noise", str(args.noise), + "--val_frac", str(args.val_frac), + "--test_frac", str(args.test_frac), + "--epochs", str(args.epochs), + "--train_batch_size", str(args.train_batch_size), + "--val_batch_size", str(args.val_batch_size), + "--test_batch_size", str(args.test_batch_size), + "--num_workers", str(args.num_workers), + "--learning_rate", str(args.learning_rate), + "--num_archetypes", str(args.num_archetypes), + "--encoder_type", str(args.encoder_type), + "--backend", str(args.backend), + "--seed", str(args.seed), + "--out_dir", str(out_dir), + ] + if args.deterministic: + base_run.append("--deterministic") + + # 1) CPU (single proc) + if args.include_cpu: + cmd = base_run + ["--accelerator", "cpu", "--devices", "1", "--tag", "cpu_1dev"] + print("\n=== Running CPU (1 device) ===") + subprocess.run(cmd, check=True) + + # 2) GPU sweeps + for k in range(1, args.max_gpus + 1): + tag = f"gpu_{k}dev" + print(f"\n=== Running GPU ({k} device{'s' if k > 1 else ''}) ===") + if k == 1 and not args.force_torchrun_for_1gpu: + # Single process GPU + cmd = base_run + ["--accelerator", "gpu", "--devices", "1", "--tag", tag] + subprocess.run(cmd, check=True) + else: + # Multi-process launch (also works for 1 GPU if forced) + cmd = [ + sys.executable, "-m", "torch.distributed.run", + "--standalone", + f"--nproc_per_node={k}", + 
str(Path(__file__).resolve()), + "run", + "--accelerator", "gpu", + "--devices", str(k), + "--tag", tag, + "--out_dir", str(out_dir), + "--n", str(args.n), + "--c_dim", str(args.c_dim), + "--x_dim", str(args.x_dim), + "--y_dim", str(args.y_dim), + "--noise", str(args.noise), + "--val_frac", str(args.val_frac), + "--test_frac", str(args.test_frac), + "--epochs", str(args.epochs), + "--train_batch_size", str(args.train_batch_size), + "--val_batch_size", str(args.val_batch_size), + "--test_batch_size", str(args.test_batch_size), + "--num_workers", str(args.num_workers), + "--learning_rate", str(args.learning_rate), + "--num_archetypes", str(args.num_archetypes), + "--encoder_type", str(args.encoder_type), + "--backend", str(args.backend), + "--seed", str(args.seed), + ] + if args.deterministic: + cmd.append("--deterministic") + subprocess.run(cmd, check=True) + + # Plot at end + plot_results(out_dir, baseline_devices=args.speedup_baseline_gpu) + + +# ------------------------- +# CLI +# ------------------------- + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Contextualized regression scaling benchmark (synthetic).") + sub = p.add_subparsers(dest="cmd", required=True) + + common = argparse.ArgumentParser(add_help=False) + common.add_argument("--out_dir", type=str, required=True, help="Output directory for results and plots.") + common.add_argument("--tag", type=str, default="", help="Tag for this run; used in filename result_.json") + + common.add_argument("--n", type=int, default=200_000) + common.add_argument("--c_dim", type=int, default=16) + common.add_argument("--x_dim", type=int, default=64) + common.add_argument("--y_dim", type=int, default=8) + common.add_argument("--noise", type=float, default=0.5) + + common.add_argument("--val_frac", type=float, default=0.2) + common.add_argument("--test_frac", type=float, default=0.1) + + common.add_argument("--epochs", type=int, default=5) + 
common.add_argument("--train_batch_size", type=int, default=2048) + common.add_argument("--val_batch_size", type=int, default=2048) + common.add_argument("--test_batch_size", type=int, default=4096) + common.add_argument("--num_workers", type=int, default=4) + + common.add_argument("--learning_rate", type=float, default=1e-3) + common.add_argument("--num_archetypes", type=int, default=0) + common.add_argument("--encoder_type", type=str, default="mlp") + + common.add_argument("--seed", type=int, default=123) + common.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo"]) + + common.add_argument("--deterministic", action="store_true", help="Best-effort deterministic training.") + + # run + pr = sub.add_parser("run", parents=[common], help="Run one benchmark configuration.") + pr.add_argument("--accelerator", type=str, required=True, choices=["cpu", "gpu"]) + pr.add_argument("--devices", type=int, required=True) + + # sweep + ps = sub.add_parser("sweep", parents=[common], help="Run CPU + GPU(1..K) sweep and plot.") + ps.add_argument("--include_cpu", action="store_true", help="Include a CPU baseline run.") + ps.add_argument("--max_gpus", type=int, default=4, help="Max GPUs to sweep up to (inclusive).") + ps.add_argument("--force_torchrun_for_1gpu", action="store_true", + help="Also launch 1-GPU via torch.distributed.run for consistency.") + ps.add_argument("--speedup_baseline_gpu", type=int, default=1, + help="Baseline GPU count for speedup/efficiency plots (default: 1).") + + return p + + +def main() -> None: + parser = build_parser() + args = parser.parse_args() + + if args.cmd == "run": + # In DDP, only rank0 writes; non-rank0 returns None + run_one(args) + + elif args.cmd == "sweep": + sweep(args) + + else: + raise RuntimeError(f"Unknown cmd: {args.cmd}") + + +if __name__ == "__main__": + main() diff --git a/contextualized/easy/ContextualizedClassifier.py b/contextualized/easy/ContextualizedClassifier.py index 1360fd8d..32e5eb5a 100644 --- 
a/contextualized/easy/ContextualizedClassifier.py +++ b/contextualized/easy/ContextualizedClassifier.py @@ -20,30 +20,34 @@ def __init__(self, **kwargs): kwargs["loss_fn"] = LOSSES["bceloss"] super().__init__(**kwargs) - def predict(self, C, X, individual_preds=False, **kwargs): - """Predict binary outcomes from context C and predictors X.""" - out = super().predict(C, X, individual_preds, **kwargs) + def predict(self, C, X, individual_preds: bool = False, **kwargs): + out = super().predict(C, X, individual_preds=individual_preds, **kwargs) + if out is None: + return None + out = np.asarray(out) + if not individual_preds: + # common binary case: (N, 1, 1) or (N, 1) if out.ndim == 3 and out.shape[-1] == 1: out = out[..., 0] return np.round(out) - # individual_preds=True: list/array per-bootstrap -> squeeze each - return [np.round(p[..., 0] if (p.ndim == 3 and p.shape[-1] == 1) else p) for p in out] + # individual_preds=True: list/array across bootstraps + return [ + np.round(p[..., 0] if (p.ndim == 3 and p.shape[-1] == 1) else p) + for p in out + ] def predict_proba(self, C, X, **kwargs): - """ - Predict probabilities of outcomes from context C and predictors X. - - Returns - ------- - np.ndarray of shape (n_samples, y_dim, 2) - """ - probs = super().predict(C, X, **kwargs) # (n, y_dim[, 1]) + probs = super().predict(C, X, **kwargs) + if probs is None: + return None + probs = np.asarray(probs) if probs.ndim == 3 and probs.shape[-1] == 1: probs = probs[..., 0] + p1 = probs p0 = 1.0 - p1 - return np.stack([p0, p1], axis=-1) \ No newline at end of file + return np.stack([p0, p1], axis=-1) diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py index a701a840..9e18e7b3 100644 --- a/contextualized/easy/ContextualizedNetworks.py +++ b/contextualized/easy/ContextualizedNetworks.py @@ -1,11 +1,18 @@ """ sklearn-like interface to Contextualized Networks. 
+ +CPU/DDP FIXES (drag-and-drop): +1) When using a LightningDataModule outside Trainer.fit/predict, you MUST call + dm.setup(stage="predict") before dm.predict_dataloader(). +2) Under DDP, prediction helpers are rank-0 only (by design in your trainers/wrapper). + We therefore early-return None on non-rank0 to avoid constructing np.array([None,...]). """ from typing import List, Tuple, Union, Optional import numpy as np import torch +import torch.distributed as dist from contextualized.easy.wrappers import SKLearnWrapper from contextualized.regression.trainers import CorrelationTrainer, MarkovTrainer @@ -22,6 +29,16 @@ from contextualized.dags.graph_utils import dag_pred_np +def _is_distributed() -> bool: + return dist.is_available() and dist.is_initialized() + + +def _rank() -> int: + if _is_distributed(): + return dist.get_rank() + return 0 + + class ContextualizedNetworks(SKLearnWrapper): """ sklearn-like interface to Contextualized Networks. @@ -65,15 +82,26 @@ def predict_networks( List[np.ndarray], Tuple[np.ndarray, np.ndarray], Tuple[List[np.ndarray], List[np.ndarray]], + None, ]: """ Predicts context-specific network parameters (and offsets if available). 
+ + DDP behavior: + - rank0 returns arrays/tuples + - non-rank0 returns None """ - betas, mus = self.predict_params( - C, individual_preds=individual_preds, uses_y=False, **kwargs - ) + out = self.predict_params(C, individual_preds=individual_preds, uses_y=False, **kwargs) + if out is None: + return None + + betas, mus = out + if betas is None: + return None + return (betas, mus) if with_offsets else betas + def predict_X( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs ) -> Union[np.ndarray, List[np.ndarray]]: @@ -96,9 +124,18 @@ def __init__(self, **kwargs): def predict_correlation( self, C: np.ndarray, individual_preds: bool = True, squared: bool = True - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: + """ + Returns per-sample correlation matrices (or squared correlations). + + DDP behavior: + - All ranks must execute the predict loop to avoid collective mismatches. + - rank0 returns arrays + - non-rank0 returns None (propagated from trainer) + """ C_scaled = self._maybe_scale_C(C) Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) + dm = self._build_datamodule( C=C_scaled, X=np.zeros((len(C_scaled), self.x_dim), dtype=np.float32), @@ -110,60 +147,50 @@ def predict_correlation( test_batch_size=self._init_kwargs["data"].get("test_batch_size", 16), predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", 16), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator in ("cuda", "gpu"))), + pin_memory=self._init_kwargs["data"].get( + "pin_memory", (self.accelerator in ("cuda", "gpu")) + ), persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + drop_last=False, shuffle_train=False, shuffle_eval=False, dtype=self._init_kwargs["data"].get("dtype", torch.float), ), - task_type="singletask_univariate", # correlation uses univariate convention ) - rhos = np.array([ - 
self.trainers[i].predict_correlation(self.models[i], dm.predict_dataloader()) - for i in range(len(self.models)) - ]) + + # CRITICAL FIX: setup before calling predict_dataloader() when not using Trainer.predict(datamodule=...) + dm.setup(stage="predict") + pred_loader = dm.predict_dataloader() + + rhos_list = [] + for i in range(len(self.models)): + rho_i = self.trainers[i].predict_correlation(self.models[i], pred_loader) + if rho_i is None: + # non-rank0 under DDP + return None + rhos_list.append(rho_i) + + rhos = np.array(rhos_list) + if individual_preds: return np.square(rhos) if squared else rhos + mean_rhos = np.mean(rhos, axis=0) return np.square(mean_rhos) if squared else mean_rhos + def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """ - Measures mean-squared reconstruction errors between the true X and the - reconstructed X_hat produced by the contextualized correlation network. - - Parameters - ---------- - C : np.ndarray - Context matrix of shape (N, C_dim). - X : np.ndarray - Data matrix of shape (N, F). - individual_preds : bool, default False - If False: return per-sample MSE averaged over bootstraps. - If True: return per-bootstrap, per-sample MSE. - - Returns - ------- - np.ndarray - If individual_preds is False: shape (N_eff,), per-sample MSE averaged - over bootstraps. - - If individual_preds is True: shape (B, N_eff), per-bootstrap, per-sample MSE. - - Notes - ----- - In single-process (non-distributed) settings, N_eff == N (full dataset). - - Under distributed settings, predict_X may operate on rank-local shards so - the number of samples in X_hat (N_hat) may differ from len(X) (N_true). - In that case we align both X_hat and X to N_eff = min(N_hat, N_true) to - avoid shape mismatches, yielding valid MSEs for the evaluated subset. + Measures mean-squared reconstruction errors between true X and reconstructed X_hat. 
+ (Behavior unchanged; this already handles N_hat != N_true.) """ - # Predict reconstructions of X for each bootstrap model X_hat = self.predict_X(C, X, individual_preds=True) + if X_hat is None: + return None + X_hat = np.array(X_hat) if X_hat.ndim not in (3, 4): @@ -172,11 +199,9 @@ def measure_mses( "ContextualizedCorrelationNetworks.measure_mses" ) - # X: (N_true, F) N_true, F = X.shape if X_hat.ndim == 3: - # X_hat: (B, N_hat, F_hat) B, N_hat, F_hat = X_hat.shape if F_hat != F: raise ValueError( @@ -184,21 +209,18 @@ def measure_mses( "in ContextualizedCorrelationNetworks.measure_mses" ) - # Align on the sample dimension N_eff = min(N_hat, N_true) if N_hat != N_true: X_hat = X_hat[:, :N_eff, :] X_eff = X[:N_eff, :] else: - N_eff = N_true X_eff = X - X_true = X_eff[None, :, :] # (1, N_eff, F) - residuals = X_hat - X_true # (B, N_eff, F) - mses = (residuals ** 2).mean(axis=-1) # (B, N_eff) + X_true = X_eff[None, :, :] + residuals = X_hat - X_true + mses = (residuals ** 2).mean(axis=-1) - else: # X_hat.ndim == 4 - # X_hat: (B, N_hat, F1, F2) + else: B, N_hat, F1, F2 = X_hat.shape if F1 != F: raise ValueError( @@ -211,22 +233,15 @@ def measure_mses( X_hat = X_hat[:, :N_eff, :, :] X_eff = X[:N_eff, :] else: - N_eff = N_true X_eff = X - X_true = X_eff[None, :, :, None] # (1, N_eff, F, 1) - residuals = X_hat - X_true # (B, N_eff, F, F2) - mses = (residuals ** 2).mean(axis=(-1, -2)) # (B, N_eff) + X_true = X_eff[None, :, :, None] + residuals = X_hat - X_true + mses = (residuals ** 2).mean(axis=(-1, -2)) - # mses: (B, N_eff) return mses if individual_preds else mses.mean(axis=0) - - - - - class ContextualizedMarkovNetworks(ContextualizedNetworks): """ Contextualized Markov Networks (Gaussian precision matrices). 
@@ -237,12 +252,18 @@ def __init__(self, **kwargs): def predict_precisions( self, C: np.ndarray, individual_preds: bool = True - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """ Predicts context-specific precision matrices. + + DDP behavior: + - All ranks must execute the predict loop to avoid collective mismatches. + - rank0 returns arrays + - non-rank0 returns None (propagated from trainer) """ C_scaled = self._maybe_scale_C(C) Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) + dm = self._build_datamodule( C=C_scaled, X=np.zeros((len(C_scaled), self.x_dim), dtype=np.float32), @@ -254,29 +275,46 @@ def predict_precisions( test_batch_size=self._init_kwargs["data"].get("test_batch_size", 16), predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", 16), num_workers=self._init_kwargs["data"].get("num_workers", 0), - pin_memory=self._init_kwargs["data"].get("pin_memory", (self.accelerator in ("cuda", "gpu"))), + pin_memory=self._init_kwargs["data"].get( + "pin_memory", (self.accelerator in ("cuda", "gpu")) + ), persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + drop_last=False, shuffle_train=False, shuffle_eval=False, dtype=self._init_kwargs["data"].get("dtype", torch.float), ), - task_type="singletask_univariate", ) - precisions = np.array([ - self.trainers[i].predict_precision(self.models[i], dm.predict_dataloader()) - for i in range(len(self.models)) - ]) + + # CRITICAL FIX: setup before calling predict_dataloader() + dm.setup(stage="predict") + pred_loader = dm.predict_dataloader() + + prec_list = [] + for i in range(len(self.models)): + p_i = self.trainers[i].predict_precision(self.models[i], pred_loader) + if p_i is None: + # non-rank0 under DDP + return None + prec_list.append(p_i) + + precisions = np.array(prec_list) return precisions if individual_preds else np.mean(precisions, axis=0) + def measure_mses( self, C: np.ndarray, X: np.ndarray, 
individual_preds: bool = False - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """ Measures mean-squared reconstruction errors using precision-implied betas/mus. """ - betas, mus = self.predict_networks(C, individual_preds=True, with_offsets=True) - mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples + out = self.predict_networks(C, individual_preds=True, with_offsets=True) + if out is None: + return None + betas, mus = out + + mses = np.zeros((len(betas), len(C))) F = X.shape[-1] for b in range(len(betas)): for i in range(F): @@ -300,7 +338,6 @@ def _parse_private_init_kwargs(self, **kwargs): """ Parse NOTMAD kwargs into model init dicts. """ - # Encoder Parameters self._init_kwargs["model"]["encoder_kwargs"] = { "type": kwargs.pop( "encoder_type", self._init_kwargs["model"]["encoder_type"] @@ -312,7 +349,6 @@ def _parse_private_init_kwargs(self, **kwargs): }, } - # Archetype parameters archetype_dag_loss_type = kwargs.pop( "archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE ) @@ -339,7 +375,6 @@ def _parse_private_init_kwargs(self, **kwargs): ) self._init_kwargs["model"]["archetype_loss_params"]["num_archetypes"] = 16 - # Allow convenience overrides for archetype DAG params for param, value in self._init_kwargs["model"]["archetype_loss_params"]["dag"][ "params" ].items(): @@ -347,7 +382,6 @@ def _parse_private_init_kwargs(self, **kwargs): param ] = kwargs.pop(f"archetype_{param}", value) - # Sample-specific parameters sample_specific_dag_loss_type = kwargs.pop( "sample_specific_dag_loss_type", DEFAULT_DAG_LOSS_TYPE ) @@ -371,7 +405,6 @@ def _parse_private_init_kwargs(self, **kwargs): param ] = kwargs.pop(f"sample_specific_{param}", value) - # Optimization parameters self._init_kwargs["model"]["opt_params"] = { "learning_rate": kwargs.pop("learning_rate", 1e-3), "step": kwargs.pop("step", 50), @@ -422,16 +455,15 @@ def __init__(self, **kwargs): def predict_params( self, C: np.ndarray, **kwargs - ) -> 
Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """ Predicts context-specific Bayesian network parameters (SEM coefficients). """ - # No mus for NOTMAD at present. return super().predict_params(C, model_includes_mus=False, **kwargs) def predict_networks( self, C: np.ndarray, project_to_dag: bool = True, **kwargs - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """ Predicts context-specific Bayesian networks (optionally projected to DAG). """ @@ -444,13 +476,16 @@ def predict_networks( def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> Union[np.ndarray, List[np.ndarray], None]: """ Measures mean-squared errors of DAG-based reconstruction. """ betas = self.predict_networks(C, individual_preds=True, **kwargs) - mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples + if betas is None: + return None + + mses = np.zeros((len(betas), len(C))) for b in range(len(betas)): X_pred = dag_pred_np(X, betas[b]) mses[b, :] = np.mean((X - X_pred) ** 2, axis=1) - return mses if individual_preds else np.mean(mses, axis=0) \ No newline at end of file + return mses if individual_preds else np.mean(mses, axis=0) diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index ade6f4ca..f054836e 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -46,26 +46,141 @@ def _is_main_process() -> bool: """Check if this is the main process (rank 0).""" return _get_rank() == 0 +def _flatten_pl_predict_output(preds): + """ + Lightning can return: + - list[dict] (single dataloader) + - list[list[dict]] (multiple dataloaders) + Normalize to list[dict]. 
+ """ + if preds is None: + return [] + if len(preds) > 0 and isinstance(preds[0], list): + out = [] + for sub in preds: + out.extend(sub) + return out + return preds + -def _gather_predictions(local_preds: np.ndarray, world_size: int) -> np.ndarray: +def _pack_local_pred_payload(pred_list: list) -> dict: """ - Gather predictions from all ranks to rank 0. - Returns full predictions on rank 0, None on other ranks. + Convert list[dict] -> dict[str, np.ndarray] by concatenating along axis 0. + Assumes each dict entry is either a torch.Tensor (CPU) or a Python scalar. """ - if not _is_distributed() or world_size == 1: - return local_preds - - local_tensor = torch.from_numpy(local_preds).cuda() - + pred_list = _flatten_pl_predict_output(pred_list) + if not pred_list: + return {} + + # Union of keys across batches (some models include extra keys) + keys = set() + for d in pred_list: + keys.update(d.keys()) + + packed = {} + for k in keys: + chunks = [] + for d in pred_list: + if k not in d: + continue + v = d[k] + if torch.is_tensor(v): + chunks.append(v.detach().cpu().numpy()) + else: + chunks.append(np.asarray(v)) + if not chunks: + continue + # Concatenate on first dim where possible; fallback to stack + try: + packed[k] = np.concatenate(chunks, axis=0) + except Exception: + packed[k] = np.stack(chunks, axis=0) + return packed + + +def _gather_object_to_rank0(obj): + """ + Gather arbitrary Python objects to rank 0. + Returns: list[obj] on rank 0, None on non-zero ranks. 
+ """ + if not _is_distributed(): + return [obj] + + world_size = dist.get_world_size() + if world_size == 1: + return [obj] + if _is_main_process(): - gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)] - dist.gather(local_tensor, gather_list=gathered, dst=0) - return torch.cat(gathered, dim=0).cpu().numpy() + gathered = [None for _ in range(world_size)] + dist.gather_object(obj, object_gather_list=gathered, dst=0) + return gathered else: - dist.gather(local_tensor, dst=0) + dist.gather_object(obj, object_gather_list=None, dst=0) return None +def _merge_packed_payloads(payloads: list) -> dict: + """ + Merge list[dict[str, np.ndarray]] -> dict[str, np.ndarray] by concatenation axis 0. + """ + merged = {} + if not payloads: + return merged + + keys = set() + for p in payloads: + if p: + keys.update(p.keys()) + + for k in keys: + chunks = [p[k] for p in payloads if p and (k in p) and (p[k] is not None) and (len(p[k]) > 0)] + if not chunks: + continue + merged[k] = np.concatenate(chunks, axis=0) + return merged + + +def _stable_sort_and_dedupe_by_key(payload: dict, primary: str, secondary: tuple = ()) -> dict: + """ + Sort payload arrays by a composite key (primary + optional secondary indices), + then dedupe (needed because DistributedSampler may pad/duplicate). 
+ """ + if (payload is None) or (primary not in payload) or (len(payload[primary]) == 0): + return payload + + primary_arr = payload[primary].astype(np.int64) + + # Build composite key + if secondary: + parts = [primary_arr] + for s in secondary: + if s in payload: + parts.append(payload[s].astype(np.int64)) + if len(parts) == 1: + key = primary_arr + else: + # lexsort uses last key as primary; reverse order + order = np.lexsort(tuple(reversed(parts))) + key_sorted = np.stack([p[order] for p in parts], axis=1) + # Dedup by full composite row + _, uniq_pos = np.unique(key_sorted, axis=0, return_index=True) + keep = order[np.sort(uniq_pos)] + else: + order = np.argsort(primary_arr, kind="mergesort") + key_sorted = primary_arr[order] + _, uniq_pos = np.unique(key_sorted, return_index=True) + keep = order[np.sort(uniq_pos)] + + out = {} + for k, v in payload.items(): + if isinstance(v, np.ndarray) and (v.shape[0] == primary_arr.shape[0]): + out[k] = v[keep] + else: + out[k] = v + return out + + + class SKLearnWrapper: """ An sklearn-like wrapper for Contextualized models. @@ -627,9 +742,6 @@ def _get_inference_device(self) -> torch.device: return torch.device("cpu") def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs): - """ - FIXED: Proper single-device inference that works after DDP training. 
- """ if not hasattr(self, "models") or self.models is None: raise ValueError("Trying to predict with a model that hasn't been trained yet.") @@ -637,10 +749,6 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, Xq = self._maybe_scale_X(X) Yq = np.zeros((len(Cq), self.y_dim), dtype=np.float32) - # FIXED: Use single device for inference - device = self._get_inference_device() - - # Build dataloader without distributed sampler dm = self._build_datamodule( C=Cq, X=Xq, @@ -651,7 +759,7 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), - num_workers=0, # Single-threaded for inference simplicity + num_workers=0, pin_memory=False, persistent_workers=False, shuffle_train=False, @@ -661,54 +769,102 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) else "singletask_multivariate", ) - - # Setup the datamodule - dm.setup(stage="predict") - pred_loader = dm.predict_dataloader() + # Let Lightning handle sharding under DDP preds = [] + n_expected = len(Cq) + for i in range(len(self.models)): model = self.models[i] model.eval() - model.to(device) - - out_batches = [] - with torch.no_grad(): - for b_idx, batch in enumerate(pred_loader): - # Move batch to device - batch = { - k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) - for k, v in batch.items() - } - - out = model.predict_step(batch, b_idx) - - Cb = out.get("contexts") - Xb = out.get("predictors") - betas = out["betas"] - mus = out["mus"] - - yb = model._predict_y(Cb, Xb, betas, mus) - out_batches.append(yb.detach().cpu()) - - yhat = 
torch.cat(out_batches, dim=0).numpy() - preds.append(yhat) - predictions = np.array(preds) + # Prefer the trainer created during fit (keeps strategy/devices consistent) + trainer = None + if hasattr(self, "trainers") and self.trainers is not None and i < len(self.trainers): + trainer = self.trainers[i] + if _is_distributed() and trainer is not None: + # ---- DDP path: use trainer.predict + gather outputs to rank 0 ---- + local_pred = trainer.predict(model, datamodule=dm) + + local_packed = _pack_local_pred_payload(local_pred) + gathered = _gather_object_to_rank0(local_packed) + + if not _is_main_process(): + # Non-zero ranks return nothing; rank 0 will return the final answer. + return None + + merged = _merge_packed_payloads(gathered) + + # Sort/dedupe by orig_idx (DistributedSampler may pad) + merged = _stable_sort_and_dedupe_by_key(merged, primary="orig_idx") + + if "betas" not in merged or "mus" not in merged or "orig_idx" not in merged: + raise RuntimeError("predict: Missing required keys in gathered payload: need orig_idx, betas, mus.") + + orig_idx = merged["orig_idx"].astype(np.int64) + betas = torch.as_tensor(merged["betas"]) + mus = torch.as_tensor(merged["mus"]) + + # Ensure we are aligned to query order + # (orig_idx is row-id into the query arrays because predict_idx=np.arange(n)) + C_sorted = torch.as_tensor(Cq[orig_idx], dtype=betas.dtype) + X_sorted = torch.as_tensor(Xq[orig_idx], dtype=betas.dtype) + + # Compute yhat on rank 0 in correct global order + with torch.no_grad(): + yhat = model._predict_y(C_sorted, X_sorted, betas, mus).detach().cpu().numpy() + + # If DDP padded, we may have > n_expected; trim safely by orig_idx range + # (should not happen if orig_idx is in [0, n_expected)) + if yhat.shape[0] != n_expected: + # Build dense output in original query order + dense = np.zeros((n_expected,) + yhat.shape[1:], dtype=yhat.dtype) + dense[orig_idx] = yhat + yhat = dense + + preds.append(yhat) + + else: + # ---- Single-process fallback: iterate 
predict_dataloader directly ---- + dm.setup(stage="predict") + pred_loader = dm.predict_dataloader() + + out_batches = [] + device = self._get_inference_device() + model.to(device) + + with torch.no_grad(): + for b_idx, batch in enumerate(pred_loader): + batch = { + k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) + for k, v in batch.items() + } + + out = model.predict_step(batch, b_idx) + betas = out["betas"] + mus = out["mus"] + + # IMPORTANT: use the *batch* for C/X, not the output payload + yb = model._predict_y(batch["contexts"], batch["predictors"], betas, mus) + out_batches.append(yb.detach().cpu()) + + yhat = torch.cat(out_batches, dim=0).numpy() + preds.append(yhat) + + predictions = np.array(preds) if not individual_preds: predictions = np.mean(predictions, axis=0) if self.normalize and self.scalers["Y"] is not None: if individual_preds: - predictions = np.array( - [self.scalers["Y"].inverse_transform(p) for p in predictions] - ) + predictions = np.array([self.scalers["Y"].inverse_transform(p) for p in predictions]) else: predictions = self.scalers["Y"].inverse_transform(predictions) return predictions + def predict_params( self, C: np.ndarray, @@ -716,9 +872,6 @@ def predict_params( model_includes_mus: bool = True, **kwargs, ): - """ - FIXED: Proper single-device inference for parameter prediction. 
- """ if not hasattr(self, "models") or self.models is None: raise ValueError("Trying to predict with a model that hasn't been trained yet.") @@ -727,7 +880,6 @@ def predict_params( Y_zero = np.zeros((len(Cq), self.y_dim), dtype=np.float32) uses_y = kwargs.pop("uses_y", True) - device = self._get_inference_device() dm = self._build_datamodule( C=Cq, @@ -749,51 +901,87 @@ def predict_params( task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) else "singletask_multivariate", ) - - dm.setup(stage="predict") - pred_loader = dm.predict_dataloader() out_betas, out_mus = [], [] + n_expected = len(Cq) for i in range(len(self.models)): model = self.models[i] model.eval() - model.to(device) - beta_batches, mu_batches = [], [] + trainer = None + if hasattr(self, "trainers") and self.trainers is not None and i < len(self.trainers): + trainer = self.trainers[i] + + if _is_distributed() and trainer is not None: + local_pred = trainer.predict(model, datamodule=dm) + local_packed = _pack_local_pred_payload(local_pred) + gathered = _gather_object_to_rank0(local_packed) + + if not _is_main_process(): + return (None, None) if model_includes_mus else None + - with torch.no_grad(): - for b_idx, batch in enumerate(pred_loader): - batch = { - k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) - for k, v in batch.items() - } - out = model.predict_step(batch, b_idx) + merged = _merge_packed_payloads(gathered) + merged = _stable_sort_and_dedupe_by_key(merged, primary="orig_idx") - betas_b = out["betas"].detach().cpu() - beta_batches.append(betas_b) + if "betas" not in merged or "orig_idx" not in merged: + raise RuntimeError("predict_params: Missing required keys in gathered payload: need orig_idx, betas.") - if model_includes_mus: - mus_b = out["mus"].detach().cpu() - mu_batches.append(mus_b) + orig_idx = merged["orig_idx"].astype(np.int64) + + betas_i = merged["betas"] + if betas_i.shape[0] != n_expected: + dense_b = 
np.zeros((n_expected,) + betas_i.shape[1:], dtype=betas_i.dtype) + dense_b[orig_idx] = betas_i + betas_i = dense_b - betas_i = torch.cat(beta_batches, dim=0).numpy() - if model_includes_mus: - mus_i = torch.cat(mu_batches, dim=0).numpy() out_betas.append(betas_i) - out_mus.append(mus_i) + + if model_includes_mus: + if "mus" not in merged: + raise RuntimeError("predict_params: model_includes_mus=True but mus missing in payload.") + mus_i = merged["mus"] + if mus_i.shape[0] != n_expected: + dense_m = np.zeros((n_expected,) + mus_i.shape[1:], dtype=mus_i.dtype) + dense_m[orig_idx] = mus_i + mus_i = dense_m + out_mus.append(mus_i) + else: + # Single-process fallback (local ordered) + dm.setup(stage="predict") + pred_loader = dm.predict_dataloader() + + device = self._get_inference_device() + model.to(device) + + beta_batches, mu_batches = [], [] + with torch.no_grad(): + for b_idx, batch in enumerate(pred_loader): + batch = { + k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) + for k, v in batch.items() + } + out = model.predict_step(batch, b_idx) + beta_batches.append(out["betas"].detach().cpu()) + if model_includes_mus: + mu_batches.append(out["mus"].detach().cpu()) + + betas_i = torch.cat(beta_batches, dim=0).numpy() out_betas.append(betas_i) + if model_includes_mus: + mus_i = torch.cat(mu_batches, dim=0).numpy() + out_mus.append(mus_i) + + betas = np.array(out_betas) if model_includes_mus: - betas = np.array(out_betas) mus = np.array(out_mus) - if individual_preds: - return betas, mus - return np.mean(betas, axis=0), np.mean(mus, axis=0) - else: - betas = np.array(out_betas) - return betas if individual_preds else np.mean(betas, axis=0) + return (betas, mus) if individual_preds else (np.mean(betas, axis=0), np.mean(mus, axis=0)) + + return betas if individual_preds else np.mean(betas, axis=0) + def fit(self, *args, **kwargs) -> None: """ diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py index 
82539886..ff2df7c3 100644 --- a/contextualized/regression/datamodules.py +++ b/contextualized/regression/datamodules.py @@ -47,6 +47,18 @@ def _maybe_index(x: torch.Tensor, idx: IndexLike) -> torch.Tensor: # assume Sequence[int] return x[torch.as_tensor(idx, dtype=torch.long)] +def _to_index_tensor(idx: IndexLike) -> Optional[torch.Tensor]: + """Normalize an index-like into a 1D CPU LongTensor.""" + if idx is None: + return None + if isinstance(idx, torch.Tensor): + out = idx.to(dtype=torch.long, device="cpu") + elif isinstance(idx, np.ndarray): + out = torch.as_tensor(idx, dtype=torch.long, device="cpu") + else: + # assume Sequence[int] + out = torch.as_tensor(idx, dtype=torch.long, device="cpu") + return out.view(-1) # ensure 1D class ContextualizedRegressionDataModule(pl.LightningDataModule): """ @@ -164,18 +176,22 @@ def setup(self, stage: Optional[str] = None) -> None: def _mk_dataset(idx: IndexLike): if idx is None: return None - C_s = _maybe_index(C, idx) - X_s = _maybe_index(X, idx) - Y_s = None if (Y is None) else _maybe_index(Y, idx) + + idx_t = _to_index_tensor(idx) # <-- NEW: stable mapping to original rows + + C_s = _maybe_index(C, idx_t) + X_s = _maybe_index(X, idx_t) + Y_s = None if (Y is None) else _maybe_index(Y, idx_t) ds_cls = TASK_TO_DATASET[self.task_type] if Y_s is None: # Allow unsupervised / network-style usage where Y is omitted. # In that case, use X as a dummy target so shapes line up. - # This mirrors the old CorrelationDataModule behavior (Y = X). 
Y_s = X_s - return ds_cls(C_s, X_s, Y_s, dtype=self.dtype) + # IMPORTANT: pass orig_idx so every item can report its original row id + return ds_cls(C_s, X_s, Y_s, orig_idx=idx_t, dtype=self.dtype) + @@ -188,24 +204,25 @@ def _mk_dataset(idx: IndexLike): self.C, self.X, self.Y = C, X, Y # ---- Dataloaders ---- - def _common_dl_kwargs(self, batch_size: int) -> Dict: + def _common_dl_kwargs(self, batch_size: int, *, drop_last: Optional[bool] = None) -> Dict: return { "batch_size": batch_size, "num_workers": self.num_workers, "pin_memory": self.pin_memory, "persistent_workers": bool(self.num_workers > 0 and self.persistent_workers), - "drop_last": self.drop_last, + "drop_last": self.drop_last if drop_last is None else bool(drop_last), } + def train_dataloader(self) -> DataLoader: if self.ds_train is None: raise RuntimeError("train dataset is not set; provide train_idx or splitter.") return DataLoader( dataset=self.ds_train, shuffle=self.shuffle_train, - **self._common_dl_kwargs(self.train_batch_size), + **self._common_dl_kwargs(self.train_batch_size, drop_last=self.drop_last), ) def val_dataloader(self): @@ -214,7 +231,8 @@ def val_dataloader(self): return DataLoader( dataset=self.ds_val, shuffle=self.shuffle_eval, - **self._common_dl_kwargs(self.val_batch_size), + # NEVER drop samples for eval (avoids silent data loss / mis-ordering) + **self._common_dl_kwargs(self.val_batch_size, drop_last=False), ) def test_dataloader(self) -> DataLoader: @@ -223,7 +241,7 @@ def test_dataloader(self) -> DataLoader: return DataLoader( dataset=self.ds_test, shuffle=self.shuffle_eval, - **self._common_dl_kwargs(self.test_batch_size), + **self._common_dl_kwargs(self.test_batch_size, drop_last=False), ) def predict_dataloader(self) -> DataLoader: @@ -232,5 +250,6 @@ def predict_dataloader(self) -> DataLoader: return DataLoader( dataset=self.ds_predict, shuffle=False, - **self._common_dl_kwargs(self.predict_batch_size), + **self._common_dl_kwargs(self.predict_batch_size, 
drop_last=False), ) + diff --git a/contextualized/regression/datasets.py b/contextualized/regression/datasets.py index ce93708f..64791876 100644 --- a/contextualized/regression/datasets.py +++ b/contextualized/regression/datasets.py @@ -11,21 +11,31 @@ class MultivariateDataset(Dataset): """ Simple multivariate dataset with context, predictors, and outcomes. """ - def __init__(self, C, X, Y, dtype=torch.float): + def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): self.C = torch.as_tensor(C, dtype=dtype) self.X = torch.as_tensor(X, dtype=dtype) self.Y = torch.as_tensor(Y, dtype=dtype) - self.c_dim = C.shape[-1] - self.x_dim = X.shape[-1] - self.y_dim = Y.shape[-1] + + # NEW: stable original-row index for distributed ordered gather + # FIX: enforce 1D LongTensor when provided + if orig_idx is None: + self.orig_idx = torch.arange(len(self.C), dtype=torch.long) + else: + self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) + + # FIX: derive dims from converted tensors to prevent shape mismatches + self.c_dim = self.C.shape[-1] + self.x_dim = self.X.shape[-1] + self.y_dim = self.Y.shape[-1] self.dtype = dtype - + def __len__(self): return len(self.C) - + def __getitem__(self, idx): return { - "idx": idx, + "idx": idx, # dataset-local position + "orig_idx": self.orig_idx[idx], # NEW: original-row id "contexts": self.C[idx], "predictors": self.X[idx].expand(self.y_dim, -1), "outcomes": self.Y[idx].unsqueeze(-1), @@ -36,21 +46,31 @@ class UnivariateDataset(Dataset): """ Simple univariate dataset with context, predictors, and one outcome. 
""" - def __init__(self, C, X, Y, dtype=torch.float): + def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): self.C = torch.as_tensor(C, dtype=dtype) self.X = torch.as_tensor(X, dtype=dtype) self.Y = torch.as_tensor(Y, dtype=dtype) - self.c_dim = C.shape[-1] - self.x_dim = X.shape[-1] - self.y_dim = Y.shape[-1] + + # NEW: stable original-row index + # FIX: enforce 1D LongTensor when provided + if orig_idx is None: + self.orig_idx = torch.arange(len(self.C), dtype=torch.long) + else: + self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) + + # FIX: derive dims from converted tensors to prevent shape mismatches + self.c_dim = self.C.shape[-1] + self.x_dim = self.X.shape[-1] + self.y_dim = self.Y.shape[-1] self.dtype = dtype - + def __len__(self): return len(self.C) - + def __getitem__(self, idx): return { "idx": idx, + "orig_idx": self.orig_idx[idx], # NEW "contexts": self.C[idx], "predictors": self.X[idx].expand(self.y_dim, -1).unsqueeze(-1), "outcomes": self.Y[idx].expand(self.x_dim, -1).T.unsqueeze(-1), @@ -61,85 +81,85 @@ class MultitaskMultivariateDataset(Dataset): """ Multi-task Multivariate Dataset. 
""" - def __init__(self, C, X, Y, dtype=torch.float): + def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): self.C = C.to(dtype) if isinstance(C, torch.Tensor) else torch.as_tensor(C, dtype=dtype) self.X = X.to(dtype) if isinstance(X, torch.Tensor) else torch.as_tensor(X, dtype=dtype) self.Y = Y.to(dtype) if isinstance(Y, torch.Tensor) else torch.as_tensor(Y, dtype=dtype) - self.c_dim = C.shape[-1] - self.x_dim = X.shape[-1] - self.y_dim = Y.shape[-1] + # NEW: stable original-row index per sample + # FIX: enforce 1D LongTensor when provided + if orig_idx is None: + self.orig_idx = torch.arange(len(self.C), dtype=torch.long) + else: + self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) + + # FIX: derive dims from converted tensors to prevent shape mismatches + self.c_dim = self.C.shape[-1] + self.x_dim = self.X.shape[-1] + self.y_dim = self.Y.shape[-1] self.dtype = dtype - + def __len__(self): return len(self.C) * self.y_dim - + def __getitem__(self, idx): - # Get task-split sample indices n_i = idx // self.y_dim y_i = idx % self.y_dim - # Create a one-hot encoding for the task - t = torch.zeros(self.y_dim) + + # Minor improvement: task vector dtype matches dataset dtype + t = torch.zeros(self.y_dim, dtype=self.dtype) t[y_i] = 1 + return { - "idx": idx, + "idx": idx, # dataset-item index + "orig_idx": self.orig_idx[n_i], # NEW: original-row id of the sample "contexts": self.C[n_i], "task": t, "predictors": self.X[n_i], "outcomes": self.Y[n_i, y_i].unsqueeze(0), - "sample_idx": n_i, + "sample_idx": n_i, # local sample index within this dataset "outcome_idx": y_i, } - # def __next__(self): - # if self.y_i >= self.y_dim: - # self.n_i += 1 - # self.y_i = 0 - # if self.n_i >= self.n: - # self.n_i = 0 - # raise StopIteration - # t = torch.zeros(self.y_dim) - # t[self.y_i] = 1 - # ret = ( - # self.C[self.n_i], - # t, - # self.X[self.n_i], - # self.Y[self.n_i, self.y_i].unsqueeze(0), - # self.n_i, - # self.y_i, - # ) - # self.y_i += 1 - # return 
ret - class MultitaskUnivariateDataset(Dataset): """ Multitask Univariate Dataset. Splits each sample into univariate X and Y feature pairs for univariate regression tasks. - """ - def __init__(self, C, X, Y, dtype=torch.float): + """ + def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): self.C = torch.as_tensor(C, dtype=dtype) self.X = torch.as_tensor(X, dtype=dtype) self.Y = torch.as_tensor(Y, dtype=dtype) - self.c_dim = C.shape[-1] - self.x_dim = X.shape[-1] - self.y_dim = Y.shape[-1] + + # NEW: stable original-row index per sample + # FIX: enforce 1D LongTensor when provided + if orig_idx is None: + self.orig_idx = torch.arange(len(self.C), dtype=torch.long) + else: + self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) + + # FIX: derive dims from converted tensors to prevent shape mismatches + self.c_dim = self.C.shape[-1] + self.x_dim = self.X.shape[-1] + self.y_dim = self.Y.shape[-1] self.dtype = dtype - + def __len__(self): return len(self.C) * self.x_dim * self.y_dim - + def __getitem__(self, idx): - # Get task-split sample indices n_i = idx // (self.x_dim * self.y_dim) x_i = (idx // self.y_dim) % self.x_dim y_i = idx % self.y_dim - # Create a one-hot encoding for the task - t = torch.zeros(self.x_dim + self.y_dim) + + t = torch.zeros(self.x_dim + self.y_dim, dtype=self.dtype) t[x_i] = 1 t[self.x_dim + y_i] = 1 + return { - "idx": idx, + "idx": idx, # dataset-item index + "orig_idx": self.orig_idx[n_i], # NEW: original-row id of the sample "contexts": self.C[n_i], "task": t, "predictors": self.X[n_i, x_i].unsqueeze(0), @@ -147,4 +167,4 @@ def __getitem__(self, idx): "sample_idx": n_i, "predictor_idx": x_i, "outcome_idx": y_i, - } \ No newline at end of file + } diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index 2f74eaea..ad04dfc9 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -242,55 +242,119 @@ def 
configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) return optimizer - def training_step(self, batch, batch_idx): - """ + def _batch_size_from_batch(self, batch: dict) -> int: + # all your datasets provide "contexts" in the batch dict + if isinstance(batch, dict) and "contexts" in batch and isinstance(batch["contexts"], torch.Tensor): + return int(batch["contexts"].shape[0]) + return 1 - :param batch: - :param batch_idx: + def _predict_payload(self, batch: dict, **outputs) -> dict: """ - loss = self._batch_loss(batch, batch_idx) - self.log_dict({"train_loss": loss}) - return loss - - def validation_step(self, batch, batch_idx): + Return a minimal, DDP-safe payload for trainer.predict: + - indices needed to reorder across ranks + - model outputs + Everything is detached and moved to CPU to avoid GPU memory blow-ups. """ + out = {} + for k in ("idx", "orig_idx", "sample_idx", "outcome_idx", "predictor_idx"): + if isinstance(batch, dict) and k in batch: + out[k] = batch[k] - :param batch: - :param batch_idx: + out.update(outputs) - """ + # Detach + move tensors to CPU for cheap gather/reorder in wrapper code later + for k, v in list(out.items()): + if isinstance(v, torch.Tensor): + out[k] = v.detach().cpu() + return out + + + def training_step(self, batch, batch_idx): loss = self._batch_loss(batch, batch_idx) - self.log_dict({"val_loss": loss}) + bs = self._batch_size_from_batch(batch) + self.log( + "train_loss", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=bs, + ) return loss - def test_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - """ + def validation_step(self, batch, batch_idx): loss = self._batch_loss(batch, batch_idx) - self.log_dict({"test_loss": loss}) + bs = self._batch_size_from_batch(batch) + self.log( + "val_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=bs, + ) return loss + + def 
test_step(self, batch, batch_idx): + loss = self._batch_loss(batch, batch_idx) + bs = self._batch_size_from_batch(batch) + self.log( + "test_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + sync_dist=True, + batch_size=bs, + ) + return loss + def _predict_from_models(self, X, beta_hat, mu_hat): """ Make shapes consistent before computing: y = g( (beta ⊙ X).sum(-1, keepdim=True) + mu ) + ... + """ - Expected canonical shapes: - - beta_hat: (B, y_dim, x_dim) - - mu_hat: (B, y_dim, 1) or (B, y_dim) - - X: one of - * (B, x_dim) - * (B, 1, x_dim) - * (B, y_dim, x_dim) + # ---- Univariate grid case: X is (B, y_dim, x_dim, 1) ---- + # singletask_univariate dataset convention produces predictors shaped (B, y, x, 1) + if isinstance(X, torch.Tensor) and X.dim() == 4 and X.shape[-1] == 1: + # move X to device/dtype + X = X.to(device=beta_hat.device, dtype=beta_hat.dtype) - We also accept beta_hat/mu_hat with an extra trailing singleton dim: - * (B, y_dim, x_dim, 1) -> squeeze to (B, y_dim, x_dim) - """ + # beta_hat should be (B, y, x, 1) in this regime + if beta_hat.dim() == 3: + beta_hat = beta_hat.unsqueeze(-1) + if beta_hat.dim() != 4 or beta_hat.shape[-1] != 1: + raise RuntimeError(f"Univariate expects beta_hat (B,y,x,1); got {beta_hat.shape}") + + # mu_hat should broadcast to (B, y, x, 1) + if not isinstance(mu_hat, torch.Tensor): + mu_hat = torch.as_tensor(mu_hat, device=beta_hat.device, dtype=beta_hat.dtype) + else: + mu_hat = mu_hat.to(device=beta_hat.device, dtype=beta_hat.dtype) + + if mu_hat.dim() == 2: + # (B, y) -> (B, y, 1, 1) -> expand across x + mu_hat = mu_hat.unsqueeze(-1).unsqueeze(-1).expand(-1, beta_hat.shape[1], beta_hat.shape[2], 1) + elif mu_hat.dim() == 3: + # (B, y, x) or (B, y, 1) -> (B, y, x, 1) + if mu_hat.shape[-1] == 1: + mu_hat = mu_hat.unsqueeze(-1).expand(-1, beta_hat.shape[1], beta_hat.shape[2], 1) + else: + mu_hat = mu_hat.unsqueeze(-1) + elif mu_hat.dim() == 4 and mu_hat.shape[-1] == 1: + pass + else: + raise 
RuntimeError(f"Unsupported mu_hat shape for univariate: {mu_hat.shape}") + + out = (beta_hat * X).sum(dim=-1, keepdim=True) + mu_hat + return self.link_fn(out) # ---- Normalize beta_hat to (B, y_dim, x_dim) ---- if not isinstance(beta_hat, torch.Tensor): @@ -564,18 +628,10 @@ def _batch_loss(self, batch, batch_idx): return pred_loss + reg_loss def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, - "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), - }) - return batch + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + # def _params_reshape(self, preds, dataloader): # """ @@ -741,11 +797,10 @@ def _predict_y(self, C, X, beta_hat, mu_hat): def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, - "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), - }) - return batch + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + + # def _params_reshape(self, preds, dataloader): @@ -882,11 +937,9 @@ def _predict_y(self, C, X, beta_hat, mu_hat): def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, - "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), - }) - return batch + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + # def _batch_loss(self, batch, batch_idx): @@ -1041,18 +1094,10 @@ def _batch_loss(self, batch, batch_idx): return pred_loss + reg_loss def predict_step(self, batch, batch_idx): - """ - - :param batch: - :param batch_idx: - - """ beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, # keep last dim; downstream handles shape uniformly - "mus": mu_hat if 
mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), - }) - return batch + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + # def _params_reshape(self, preds, dataloader): # """ @@ -1176,11 +1221,9 @@ def _predict_y(self, C, X, beta_hat, mu_hat): def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, # keep last dim - "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), - }) - return batch + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + class TasksplitContextualizedUnivariateRegression(ContextualizedRegressionBase): @@ -1273,11 +1316,9 @@ def _predict_y(self, C, X, beta_hat, mu_hat): def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) - batch.update({ - "betas": beta_hat, # keep last dim - "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), - }) - return batch + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + # def _params_reshape(self, preds, dataloader): @@ -1340,17 +1381,16 @@ def __init__(self, context_dim, x_dim, **kwargs): def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) - beta_hat = beta_hat.squeeze(-1) # (B, y, x) + beta_hat = beta_hat.squeeze(-1) # (B, y, x) + beta_hat_T = beta_hat.transpose(1, 2) signs = torch.sign(beta_hat) signs[signs != signs.transpose(1, 2)] = 0 correlations = signs * torch.sqrt(torch.abs(beta_hat * beta_hat_T)) - batch.update({ - "betas": beta_hat, - "mus": mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1), - "correlations": correlations, - }) - return batch + + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat, correlations=correlations) + @@ -1410,14 +1450,12 @@ def __init__( 
self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): - beta_hat, mu_hat = self(batch) # self.forward expects dict batch - # Zero diagonal (mask pre-registered in __init__) + beta_hat, mu_hat = self(batch) # dict batch beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1) - batch.update({ - "betas": beta_hat, - "mus": mu_hat, - }) - return batch + + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + @@ -1440,11 +1478,9 @@ def __init__(self, context_dim, x_dim, **kwargs): def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) # dict batch - # Enforce symmetry (hotfix) and zero diagonal beta_hat = beta_hat + beta_hat.transpose(1, 2) beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1) - batch.update({ - "betas": beta_hat, - "mus": mu_hat, - }) - return batch + + mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) + return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) + diff --git a/contextualized/regression/trainers.py b/contextualized/regression/trainers.py index da96e481..527c1164 100644 --- a/contextualized/regression/trainers.py +++ b/contextualized/regression/trainers.py @@ -2,27 +2,180 @@ PyTorch-Lightning trainers used for Contextualized regression. 
""" -from typing import Any, Tuple, List +from typing import Any, Tuple, List, Dict, Optional import numpy as np import torch +import torch.distributed as dist import pytorch_lightning as pl from pytorch_lightning.plugins.environments import LightningEnvironment import os from pytorch_lightning.strategies import DDPStrategy + def _stack_from_preds(preds: List[dict], key: str) -> torch.Tensor: """Concatenate a tensor field from the list of batch dicts returned by predict().""" + preds = _flatten_pl_predict_output(preds) parts = [] for p in preds: val = p[key] - # ensure tensor on cpu if isinstance(val, np.ndarray): val = torch.from_numpy(val) parts.append(val.detach().cpu()) return torch.cat(parts, dim=0) +def _is_distributed() -> bool: + return dist.is_available() and dist.is_initialized() + + +def _is_main_process() -> bool: + return (not _is_distributed()) or dist.get_rank() == 0 + + +def _flatten_pl_predict_output(preds): + """ + Lightning can return: + - list[dict] (single dataloader) + - list[list[dict]] (multiple dataloaders) + Normalize to list[dict]. + """ + if preds is None: + return [] + if len(preds) > 0 and isinstance(preds[0], list): + out = [] + for sub in preds: + out.extend(sub) + return out + return preds + + +def _to_numpy_cpu(x): + if x is None: + return None + if isinstance(x, np.ndarray): + return x + if torch.is_tensor(x): + return x.detach().cpu().numpy() + return np.asarray(x) + + +def _pack_keys_from_preds(preds: list, keys: Tuple[str, ...]) -> Dict[str, np.ndarray]: + """ + Pack only requested keys from list[dict] predictions into numpy arrays. + Concats on axis 0. 
+ """ + preds = _flatten_pl_predict_output(preds) + if not preds: + return {} + + packed: Dict[str, List[np.ndarray]] = {k: [] for k in keys} + for p in preds: + for k in keys: + if k in p: + v = _to_numpy_cpu(p[k]) + if v is not None: + packed[k].append(v) + + out: Dict[str, np.ndarray] = {} + for k, parts in packed.items(): + if not parts: + continue + out[k] = np.concatenate(parts, axis=0) + return out + + +def _gather_object_to_rank0(obj): + """ + Gather arbitrary Python objects to rank 0. + Returns list[obj] on rank 0, None on other ranks. + """ + if not _is_distributed(): + return [obj] + + world_size = dist.get_world_size() + if world_size == 1: + return [obj] + + if _is_main_process(): + gathered = [None for _ in range(world_size)] + dist.gather_object(obj, object_gather_list=gathered, dst=0) + return gathered + else: + dist.gather_object(obj, object_gather_list=None, dst=0) + return None + + +def _merge_packed_payloads(payloads: List[Optional[Dict[str, np.ndarray]]]) -> Dict[str, np.ndarray]: + """ + Merge list[dict[str, np.ndarray]] -> dict[str, np.ndarray] by concatenation axis 0. + """ + merged: Dict[str, np.ndarray] = {} + payloads = [p for p in payloads if p] + if not payloads: + return merged + + keys = set() + for p in payloads: + keys.update(p.keys()) + + for k in keys: + chunks = [p[k] for p in payloads if (k in p) and (p[k] is not None) and (len(p[k]) > 0)] + if not chunks: + continue + merged[k] = np.concatenate(chunks, axis=0) + return merged + + +def _stable_sort_and_dedupe(payload: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Sort payload arrays by dataset-local 'idx' when present (correct for subsets), + else fall back to 'orig_idx'. Then dedupe (DistributedSampler may pad/duplicate). 
+ """ + if not payload: + return payload + + key = "idx" if "idx" in payload else ("orig_idx" if "orig_idx" in payload else None) + if key is None: + return payload + + k = payload[key].astype(np.int64) + if k.size == 0: + return payload + + order = np.argsort(k, kind="mergesort") + k_sorted = k[order] + _, uniq_pos = np.unique(k_sorted, return_index=True) + keep = order[np.sort(uniq_pos)] + + out: Dict[str, np.ndarray] = {} + for name, v in payload.items(): + if isinstance(v, np.ndarray) and v.shape[0] == k.shape[0]: + out[name] = v[keep] + else: + out[name] = v + return out + + + +def _gather_predict_payload(preds, keys: Tuple[str, ...]) -> Optional[Dict[str, np.ndarray]]: + """ + Packs requested keys from local preds, gathers to rank0 under DDP, merges, and + stable-sorts/dedupes by orig_idx (if present). + Returns payload dict on rank0; returns None on non-rank0 in DDP. + """ + local = _pack_keys_from_preds(preds, keys) + + gathered = _gather_object_to_rank0(local) + if gathered is None: + return None # non-rank0 DDP + + merged = _merge_packed_payloads(gathered) + merged = _stable_sort_and_dedupe(merged) + return merged + + + class RegressionTrainer(pl.Trainer): """ Trains the contextualized.regression lightning_modules @@ -33,71 +186,80 @@ class RegressionTrainer(pl.Trainer): @torch.no_grad() def predict_params(self, model: pl.LightningModule, dataloader) -> Tuple[np.ndarray, np.ndarray]: - """ - Returns context-specific regression parameters. 
+ preds = super().predict(model, dataloader) + + payload = _gather_predict_payload(preds, keys=("idx", "orig_idx", "betas", "mus")) + if payload is None: + # non-rank0 DDP: return nothing to avoid duplicated outputs + return None, None + + + if "betas" not in payload or "mus" not in payload: + raise RuntimeError("predict_params: predict_step must return 'betas' and 'mus' (and ideally 'orig_idx').") + + return payload["betas"], payload["mus"] - Returns - ------- - (betas, mus) - betas: (n, y_dim, x_dim) - mus: (n, y_dim) or (n, y_dim, 1) depending on the model - """ - preds = super().predict(model, dataloader) # list of batch dicts - betas = _stack_from_preds(preds, "betas") - mus = _stack_from_preds(preds, "mus") - return betas.numpy(), mus.numpy() @torch.no_grad() def predict_y(self, model: pl.LightningModule, dataloader) -> np.ndarray: - """ - Returns context-specific predictions of the response Y. + preds = super().predict(model, dataloader) - Returns - ------- - y_hat : (n, y_dim, 1) for multivariate, or (n, y_dim, x_dim) for univariate - """ - preds = super().predict(model, dataloader) # list of batch dicts + # Prefer lightweight gather, but allow legacy keys if present. + payload = _gather_predict_payload(preds, keys=("idx", "orig_idx", "betas", "mus")) + if payload is None: + return None # non-rank0 DDP + + if "betas" not in payload or "mus" not in payload: + raise RuntimeError("predict_y: predict_step must return 'betas' and 'mus'.") + + betas = torch.as_tensor(payload["betas"]) + mus = torch.as_tensor(payload["mus"]) - y_parts = [] - for p in preds: - # Required keys were added by model.predict_step(...) - C = p["contexts"] - X = p["predictors"] - betas = p["betas"] - mus = p["mus"] + # If legacy contexts/predictors were returned and gathered, use them. 
+ if ("contexts" in payload) and ("predictors" in payload): + C = torch.as_tensor(payload["contexts"]) + X = torch.as_tensor(payload["predictors"]) + else: + # Option A path: reconstruct from dataset via dataset-local idx (NOT orig_idx) + ds = getattr(dataloader, "dataset", None) + if ds is None: + raise RuntimeError("predict_y: dataloader has no .dataset; cannot reconstruct C/X.") - # Ensure tensors on CPU first; model will move as needed inside helpers - if not torch.is_tensor(C): C = torch.as_tensor(C) - if not torch.is_tensor(X): X = torch.as_tensor(X) - if not torch.is_tensor(betas): betas = torch.as_tensor(betas) - if not torch.is_tensor(mus): mus = torch.as_tensor(mus) + idx_np = payload["idx"].astype(np.int64) + idx_t = torch.as_tensor(idx_np, dtype=torch.long) - # --- shape fixes for multivariate (3D) and univariate (4D) --- - # Multivariate convention: X (B, y, x), betas (B, y, x), mus (B, y, 1) - # Univariate convention: X (B, y, x, 1), betas (B, y, x, 1), mus (B, y, x, 1) + # Support Subset wrapper if user wrapped loaders externally + if hasattr(ds, "dataset") and hasattr(ds, "indices"): + base = ds.dataset + if not (hasattr(base, "C") and hasattr(base, "X")): + raise RuntimeError("predict_y: Subset base dataset must expose .C and .X.") + base_pos = np.asarray(ds.indices, dtype=np.int64)[idx_np] + base_pos_t = torch.as_tensor(base_pos, dtype=torch.long) + C = base.C[base_pos_t] + X = base.X[base_pos_t] + else: + if not (hasattr(ds, "C") and hasattr(ds, "X")): + raise RuntimeError("predict_y: dataset must expose .C and .X tensors for Option A prediction.") + C = ds.C[idx_t] + X = ds.X[idx_t] - # If X is (B, x) and betas is (B, y, x), expand X -> (B, 1, x) - if X.dim() == 2 and betas.dim() == 3 and betas.size(-1) == X.size(-1): - X = X.unsqueeze(1) + # dtype align + if torch.is_tensor(C): + C = C.to(dtype=betas.dtype) + else: + C = torch.as_tensor(C, dtype=betas.dtype) - # If betas is (B, y, x) but X is (B, y, x, 1), add trailing singleton to betas - if 
betas.dim() == 3 and X.dim() == 4 and betas.size(-1) == X.size(-2): - betas = betas.unsqueeze(-1) - + if torch.is_tensor(X): + X = X.to(dtype=betas.dtype) + else: + X = torch.as_tensor(X, dtype=betas.dtype) - # Ensure mus trailing dim is singleton - if mus.dim() == 2: # (B, y) - mus = mus.unsqueeze(-1) # (B, y, 1) - elif mus.dim() == 3 and X.dim() == 4 and mus.size(-1) != 1: - mus = mus.unsqueeze(-1) # (B, y, x, 1) - # --- end shape fixes --- + with torch.no_grad(): + yhat = model._predict_y(C, X, betas, mus).detach().cpu().numpy() - yhat = model._predict_y(C, X, betas, mus) # uses model's link - y_parts.append(yhat.detach().cpu()) + return yhat - y = torch.cat(y_parts, dim=0) - return y.numpy() @@ -109,27 +271,30 @@ class CorrelationTrainer(RegressionTrainer): @torch.no_grad() def predict_correlation(self, model: pl.LightningModule, dataloader) -> np.ndarray: - """ - Returns context-specific correlation networks containing Pearson's correlation coefficient. - - Returns - ------- - correlations : (n, x_dim, x_dim) - """ - # If the model already returns 'correlations' in predict_step, prefer that. preds = super().predict(model, dataloader) - if "correlations" in preds[0]: - cors = torch.cat([p["correlations"].detach().cpu() for p in preds], dim=0) - return cors.numpy() - - # Fallback: derive from betas like before + preds_flat = _flatten_pl_predict_output(preds) + + # If model returns correlations directly, gather and reorder them. 
+ if preds_flat and ("correlations" in preds_flat[0]): + payload = _gather_predict_payload(preds, keys=("orig_idx", "correlations")) + if payload is None: + return None # non-rank0 DDP + if "correlations" not in payload: + raise RuntimeError("predict_correlation: predict_step returned no 'correlations'.") + return payload["correlations"] + + # Fallback: derive from betas betas, _ = self.predict_params(model, dataloader) + if betas is None: + return None # non-rank0 DDP + signs = np.sign(betas) signs[signs != np.transpose(signs, (0, 2, 1))] = 0 correlations = signs * np.sqrt(np.abs(betas * np.transpose(betas, (0, 2, 1)))) return correlations + class MarkovTrainer(CorrelationTrainer): """ Trains the contextualized.regression markov graph lightning_modules From 91f6562adadbf6cb18243ed606838bc906faa9e2 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Tue, 30 Dec 2025 22:25:15 -0500 Subject: [PATCH 14/19] Update lightning_modules error --- 01_regressor_cpu_single.py | 92 - 02_networks_cpu_single.py | 49 - bench_scale_contextualized_regression.py | 766 --------- contextualized/easy/ContextualizedNetworks.py | 82 +- .../easy/wrappers/SKLearnWrapper.py | 1504 ++++++++--------- .../regression/lightning_modules.py | 19 +- regression_scale_bench.py | 480 ------ scale_bench.py | 540 ++++++ scale_bench_networks.py | 623 +++++++ 9 files changed, 1910 insertions(+), 2245 deletions(-) delete mode 100644 01_regressor_cpu_single.py delete mode 100644 02_networks_cpu_single.py delete mode 100644 bench_scale_contextualized_regression.py delete mode 100644 regression_scale_bench.py create mode 100644 scale_bench.py create mode 100644 scale_bench_networks.py diff --git a/01_regressor_cpu_single.py b/01_regressor_cpu_single.py deleted file mode 100644 index 3b902840..00000000 --- a/01_regressor_cpu_single.py +++ /dev/null @@ -1,92 +0,0 @@ -import numpy as np - -from contextualized.easy import ContextualizedRegressor - -def main(): - np.random.seed(0) - - n = 96 - c_dim = 4 - x_dim = 6 - 
y_dim = 2 - - C = np.random.randn(n, c_dim).astype(np.float32) - X = np.random.randn(n, x_dim).astype(np.float32) - - # Construct a learnable signal: Y = X @ W + noise - W = np.array([[1.5, -0.5], - [0.7, 0.2], - [0.0, 0.0], - [0.3, -1.0], - [0.0, 0.0], - [0.2, 0.1]], dtype=np.float32) # (x_dim, y_dim) - - Y = (X @ W + 0.05 * np.random.randn(n, y_dim).astype(np.float32)).astype(np.float32) - - model = ContextualizedRegressor( - metamodel_type="subtype", - num_archetypes=4, - univariate=False, - ) - - # CPU-only fit - model.fit( - C=C, X=X, Y=Y, - accelerator="cpu", - devices=1, - strategy="auto", - max_epochs=3, - val_split=0.2, - num_workers=0, - enable_progress_bar=False, - logger=False, - ) - - yhat = model.predict(C, X) - betas, mus = model.predict_params(C) - - # --- shape sanity --- - yhat_arr = np.asarray(yhat) - betas_arr = np.asarray(betas) - mus_arr = np.asarray(mus) - - print("SINGLE CPU REGRESSOR") - print("yhat.shape:", yhat_arr.shape) - print("betas.shape:", betas_arr.shape) - print("mus.shape:", mus_arr.shape) - - # Expected conventions (based on your current implementation) - assert yhat_arr.shape[0] == n, "yhat first dim should be n" - assert betas_arr.shape[0] == n, "betas first dim should be n" - assert mus_arr.shape[0] == n, "mus first dim should be n" - - # yhat is typically (n, y_dim, 1) for multivariate in your code path - # betas is (n, y_dim, x_dim) - assert betas_arr.shape[1] == y_dim and betas_arr.shape[2] == x_dim, "betas expected (n, y_dim, x_dim)" - - # --- quick quality check: MSE vs baseline mean predictor --- - # squeeze last dim if present - yhat_s = yhat_arr[..., 0] if (yhat_arr.ndim == 3 and yhat_arr.shape[-1] == 1) else yhat_arr - y_true = Y - - mse = np.mean((yhat_s - y_true) ** 2) - baseline = np.mean((np.mean(y_true, axis=0, keepdims=True) - y_true) ** 2) - - print("MSE:", float(mse)) - print("Baseline MSE (mean predictor):", float(baseline)) - assert np.isfinite(mse), "MSE must be finite" - assert mse < baseline, "Model 
should beat baseline mean predictor on this synthetic signal" - - # --- ordering check (this is critical for your gather/sort design) --- - perm = np.random.permutation(n) - yhat_perm = np.asarray(model.predict(C[perm], X[perm])) - yhat_perm_s = yhat_perm[..., 0] if (yhat_perm.ndim == 3 and yhat_perm.shape[-1] == 1) else yhat_perm - - max_err = np.max(np.abs(yhat_perm_s - yhat_s[perm])) - print("Ordering check max_err:", float(max_err)) - assert max_err < 1e-5, "Prediction order is not stable under permutation" - - print("PASS: single-process CPU regressor tests") - -if __name__ == "__main__": - main() diff --git a/02_networks_cpu_single.py b/02_networks_cpu_single.py deleted file mode 100644 index 5999563c..00000000 --- a/02_networks_cpu_single.py +++ /dev/null @@ -1,49 +0,0 @@ -import numpy as np - -from contextualized.easy import ContextualizedCorrelationNetworks # adjust if your import path differs - -def main(): - np.random.seed(0) - - n = 80 - c_dim = 3 - x_dim = 5 - - C = np.random.randn(n, c_dim).astype(np.float32) - X = np.random.randn(n, x_dim).astype(np.float32) - - net = ContextualizedCorrelationNetworks( - metamodel_type="subtype", - num_archetypes=4, - ) - - net.fit( - C=C, X=X, - accelerator="cpu", - devices=1, - strategy="auto", - max_epochs=2, - val_split=0.2, - num_workers=0, - enable_progress_bar=False, - logger=False, - ) - - rhos2 = net.predict_correlation(C, individual_preds=False, squared=True) - rhos2 = np.asarray(rhos2) - - print("SINGLE CPU CORRELATION NETWORKS") - print("rhos2.shape:", rhos2.shape) - - assert rhos2.shape[0] == n and rhos2.shape[1] == x_dim and rhos2.shape[2] == x_dim, \ - "Expected (n, x_dim, x_dim)" - assert np.all(np.isfinite(rhos2)), "Correlations must be finite" - - # Symmetry sanity (should be symmetric-ish) - sym_err = np.max(np.abs(rhos2 - np.transpose(rhos2, (0, 2, 1)))) - print("Symmetry max_err:", float(sym_err)) - - print("PASS: single-process CPU networks tests") - -if __name__ == "__main__": - main() diff 
--git a/bench_scale_contextualized_regression.py b/bench_scale_contextualized_regression.py deleted file mode 100644 index 65b8e6c1..00000000 --- a/bench_scale_contextualized_regression.py +++ /dev/null @@ -1,766 +0,0 @@ -#!/usr/bin/env python3 -""" -bench_scale_contextualized_regression.py - -Synthetic scaling benchmark for Contextualized regression workflow. - -Modes: - - run : run a single config (supports torch distributed launch) - - sweep : run CPU + GPU(1..K) sequentially (spawns torch distributed runs) and plot - -Examples (Lambda 4x GPU single node): - # Full sweep: CPU + 1/2/3/4 GPU and plots - python bench_scale_contextualized_regression.py sweep \ - --include_cpu \ - --max_gpus 4 \ - --n 200000 \ - --c_dim 16 --x_dim 64 --y_dim 8 \ - --epochs 5 \ - --train_batch_size 2048 --val_batch_size 2048 --test_batch_size 4096 \ - --num_workers 4 \ - --out_dir ./scale_runs/run1 - - # Single run on 1 GPU (no torchrun needed for 1 device) - python bench_scale_contextualized_regression.py run --accelerator gpu --devices 1 --out_dir ./one_gpu - - # Single run on 4 GPUs using torch distributed launcher - python -m torch.distributed.run --standalone --nproc_per_node=4 \ - bench_scale_contextualized_regression.py run --accelerator gpu --devices 4 --out_dir ./four_gpu -""" - -from __future__ import annotations - -import argparse -import csv -import json -import os -import platform -import re -import shutil -import subprocess -import sys -import time -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Dict, Optional, Tuple - -import numpy as np - - -# ------------------------- -# Utilities -# ------------------------- - -def _now_ts() -> str: - return time.strftime("%Y%m%d_%H%M%S") - - -def _maybe_git_commit() -> Optional[str]: - try: - if not shutil.which("git"): - return None - out = subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL).decode().strip() - if re.fullmatch(r"[0-9a-f]{40}", out): - return 
out - except Exception: - pass - return None - - -def _rank_world() -> Tuple[int, int, int]: - """(rank, world_size, local_rank) for torchrun-style environments.""" - rank = int(os.environ.get("RANK", "0")) - world = int(os.environ.get("WORLD_SIZE", "1")) - local = int(os.environ.get("LOCAL_RANK", "0")) - return rank, world, local - - -def _set_seed(seed: int) -> None: - np.random.seed(seed) - try: - import torch - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - except Exception: - pass - - -def _cuda_sync_if_available() -> None: - try: - import torch - if torch.cuda.is_available(): - torch.cuda.synchronize() - except Exception: - pass - - -def _safe_float(x: Any) -> float: - try: - return float(x) - except Exception: - return float("nan") - - -def _ensure_dir(p: Path) -> None: - p.mkdir(parents=True, exist_ok=True) - - -# ------------------------- -# Synthetic data generator -# ------------------------- - -def make_synth_contextual_regression( - n: int, - c_dim: int, - x_dim: int, - y_dim: int, - noise: float, - seed: int, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[str, np.ndarray]]: - """ - Create synthetic data where coefficients beta(C) vary with context C: - beta_flat = C @ W^T + b - beta = reshape(beta_flat, y_dim, x_dim) - mu = C @ V - y = sum_j beta[..., j]*x_j + mu + eps - """ - rng = np.random.default_rng(seed) - - C = rng.normal(size=(n, c_dim)).astype(np.float32) - X = rng.normal(size=(n, x_dim)).astype(np.float32) - - # Context -> beta mapping - W = (rng.normal(size=(y_dim * x_dim, c_dim)).astype(np.float32) / np.sqrt(c_dim)).astype(np.float32) - b = (0.1 * rng.normal(size=(y_dim * x_dim,))).astype(np.float32) - - beta_flat = C @ W.T + b[None, :] # (n, y_dim*x_dim) - beta = beta_flat.reshape(n, y_dim, x_dim) # (n, y_dim, x_dim) - - # Context -> intercept - V = (0.1 * rng.normal(size=(c_dim, y_dim))).astype(np.float32) - mu = (C @ V).astype(np.float32) # (n, y_dim) - - y = (beta * X[:, None, :]).sum(axis=-1) + mu - y = y + noise * 
rng.normal(size=(n, y_dim)).astype(np.float32) - Y = y.astype(np.float32) - - truth = {"beta": beta, "mu": mu} - return C, X, Y, truth - - -def make_splits(n: int, val_frac: float, test_frac: float, seed: int) -> Dict[str, np.ndarray]: - assert 0.0 <= val_frac < 1.0 - assert 0.0 <= test_frac < 1.0 - assert val_frac + test_frac < 1.0 - - rng = np.random.default_rng(seed) - idx = np.arange(n, dtype=np.int64) - rng.shuffle(idx) - - n_test = int(round(n * test_frac)) - n_val = int(round(n * val_frac)) - n_train = n - n_val - n_test - - train_idx = idx[:n_train] - val_idx = idx[n_train:n_train + n_val] - test_idx = idx[n_train + n_val:] - - return {"train_idx": train_idx, "val_idx": val_idx, "test_idx": test_idx} - - -# ------------------------- -# Result schema -# ------------------------- - -@dataclass -class BenchResult: - tag: str - accelerator: str - devices: int - backend: str - n: int - n_train: int - n_val: int - n_test: int - c_dim: int - x_dim: int - y_dim: int - epochs: int - train_batch_size: int - val_batch_size: int - test_batch_size: int - num_workers: int - seed: int - - fit_time_s: float - predict_time_s: float - total_time_s: float - - train_throughput_sps: float # samples/sec (global, unique samples per epoch) - predict_throughput_sps: float - - test_mse: float - - hostname: str - python: str - platform: str - torch: Optional[str] - lightning: Optional[str] - git_commit: Optional[str] - - -# ------------------------- -# Core runner -# ------------------------- - -def run_one(args: argparse.Namespace) -> Optional[BenchResult]: - rank, world, local_rank = _rank_world() - - # Make output directory (all ranks see it; only rank0 writes result). - out_dir = Path(args.out_dir).resolve() - _ensure_dir(out_dir) - - # Deferred imports so CPU-only environments don't choke on CUDA imports early. - try: - import torch - import pytorch_lightning as pl - except Exception as e: - if rank == 0: - raise RuntimeError( - "Failed to import torch / pytorch_lightning. 
Ensure your env has them installed." - ) from e - return None - - # Set device for GPU - if args.accelerator == "gpu": - if not torch.cuda.is_available(): - if rank == 0: - raise RuntimeError("Requested accelerator=gpu but torch.cuda.is_available() is False.") - return None - torch.cuda.set_device(local_rank) - - # Determinism (best-effort) - _set_seed(args.seed) - try: - pl.seed_everything(args.seed, workers=True) - except Exception: - pass - - # Synthesize data (replicated across ranks; ok for benchmarking) - C, X, Y, _truth = make_synth_contextual_regression( - n=args.n, - c_dim=args.c_dim, - x_dim=args.x_dim, - y_dim=args.y_dim, - noise=args.noise, - seed=args.seed, - ) - - splits = make_splits(args.n, args.val_frac, args.test_frac, args.seed) - train_idx = splits["train_idx"] - val_idx = splits["val_idx"] - test_idx = splits["test_idx"] - - # Build model using your regression workflow - # We prefer the easy wrapper if available, as it exercises your end-to-end stack. - try: - from contextualized.easy import ContextualizedRegressor - except Exception as e: - if rank == 0: - raise RuntimeError( - "Could not import contextualized.easy.ContextualizedRegressor. " - "Verify your package is importable from this environment." - ) from e - return None - - # Strategy configuration (DDP for multi-GPU) - strategy_obj = None - if args.devices > 1: - # Use Lightning DDPStrategy explicitly to control backend (nccl/gloo) - try: - from pytorch_lightning.strategies import DDPStrategy - strategy_obj = DDPStrategy( - process_group_backend=args.backend, - find_unused_parameters=False, - ) - except Exception: - strategy_obj = "ddp" # fallback: let Lightning decide - - # Create the regressor (robust to wrapper signature differences) - # We attempt common kwargs; if wrapper rejects, we show a clear error on rank0. 
- model_kwargs: Dict[str, Any] = dict( - num_archetypes=args.num_archetypes, - encoder_type=args.encoder_type, - max_epochs=args.epochs, - learning_rate=args.learning_rate, - # data / loader knobs (if wrapper exposes them) - train_batch_size=args.train_batch_size, - val_batch_size=args.val_batch_size, - test_batch_size=args.test_batch_size, - val_split=args.val_frac, - # trainer knobs - accelerator=("gpu" if args.accelerator == "gpu" else "cpu"), - devices=args.devices, - num_workers=args.num_workers, - deterministic=args.deterministic, - enable_checkpointing=False, - logger=False, - enable_progress_bar=False, - ) - - # Some wrappers may not accept the above keys; strip unsupported keys dynamically. - def instantiate_contextualized_regressor() -> Any: - import inspect - sig = inspect.signature(ContextualizedRegressor.__init__) - accepted = set(sig.parameters.keys()) - # Always remove 'self' - accepted.discard("self") - filt = {k: v for k, v in model_kwargs.items() if k in accepted} - # Strategy: some wrappers accept "strategy" directly - if "strategy" in accepted and strategy_obj is not None: - filt["strategy"] = strategy_obj - return ContextualizedRegressor(**filt) - - try: - reg = instantiate_contextualized_regressor() - except TypeError as e: - if rank == 0: - raise RuntimeError( - "Failed to instantiate ContextualizedRegressor with inferred kwargs.\n" - "This usually means the wrapper signature differs from what this benchmark expects.\n" - "Action: open contextualized/easy/wrappers.py and confirm which Trainer/loader args are supported.\n" - f"Original error: {e}" - ) - return None - - # Fit timing - _cuda_sync_if_available() - t0 = time.perf_counter() - - # Prefer explicit indices to exercise your stable-index paths if supported - fit_kwargs: Dict[str, Any] = dict(C=C, X=X, Y=Y) - try: - import inspect - fit_sig = inspect.signature(reg.fit) - if "train_idx" in fit_sig.parameters: - fit_kwargs["train_idx"] = train_idx - if "val_idx" in fit_sig.parameters: - 
fit_kwargs["val_idx"] = val_idx - if "test_idx" in fit_sig.parameters: - fit_kwargs["test_idx"] = test_idx - except Exception: - pass - - reg.fit(**fit_kwargs) - - _cuda_sync_if_available() - fit_time = time.perf_counter() - t0 - - # Predict timing (prefer predict_idx if supported) - _cuda_sync_if_available() - t1 = time.perf_counter() - - yhat = None - pred_kwargs_full: Dict[str, Any] = dict(C=C, X=X) - try: - import inspect - pred_sig = inspect.signature(reg.predict) - if "predict_idx" in pred_sig.parameters: - pred_kwargs_full["predict_idx"] = test_idx - yhat = reg.predict(**pred_kwargs_full) - else: - # fallback: feed the subset directly - yhat = reg.predict(C[test_idx], X[test_idx]) - except Exception: - # fallback: feed the subset directly - yhat = reg.predict(C[test_idx], X[test_idx]) - - _cuda_sync_if_available() - pred_time = time.perf_counter() - t1 - - # Only rank0 should compute/report metrics if wrapper returns None on non-rank0 - if yhat is None: - return None - - # Convert prediction to numpy - if hasattr(yhat, "detach"): - yhat_np = yhat.detach().cpu().numpy() - else: - yhat_np = np.asarray(yhat) - - y_true = Y[test_idx] - test_mse = float(np.mean((yhat_np - y_true) ** 2)) - - total_time = fit_time + pred_time - - # Throughput (global unique samples per epoch) - n_train = int(train_idx.shape[0]) - n_val = int(val_idx.shape[0]) - n_test = int(test_idx.shape[0]) - - train_throughput = (n_train * args.epochs) / max(fit_time, 1e-9) - pred_throughput = (n_test) / max(pred_time, 1e-9) - - # Version info - torch_ver = getattr(torch, "__version__", None) - lightning_ver = getattr(pl, "__version__", None) - - tag = args.tag or f"{args.accelerator}_{args.devices}dev" - result = BenchResult( - tag=tag, - accelerator=args.accelerator, - devices=args.devices, - backend=args.backend, - - n=args.n, - n_train=n_train, - n_val=n_val, - n_test=n_test, - c_dim=args.c_dim, - x_dim=args.x_dim, - y_dim=args.y_dim, - epochs=args.epochs, - 
train_batch_size=args.train_batch_size, - val_batch_size=args.val_batch_size, - test_batch_size=args.test_batch_size, - num_workers=args.num_workers, - seed=args.seed, - - fit_time_s=_safe_float(fit_time), - predict_time_s=_safe_float(pred_time), - total_time_s=_safe_float(total_time), - - train_throughput_sps=_safe_float(train_throughput), - predict_throughput_sps=_safe_float(pred_throughput), - - test_mse=_safe_float(test_mse), - - hostname=platform.node(), - python=sys.version.replace("\n", " "), - platform=f"{platform.system()} {platform.release()} ({platform.machine()})", - torch=torch_ver, - lightning=lightning_ver, - git_commit=_maybe_git_commit(), - ) - - # Write per-run JSON on rank0 only - if rank == 0: - out_json = out_dir / f"result_{tag}.json" - with out_json.open("w") as f: - json.dump(asdict(result), f, indent=2) - print(f"[rank0] Wrote: {out_json}") - print( - f"[rank0] fit={result.fit_time_s:.3f}s " - f"pred={result.predict_time_s:.3f}s " - f"total={result.total_time_s:.3f}s " - f"train_thr={result.train_throughput_sps:.1f} samp/s " - f"pred_thr={result.predict_throughput_sps:.1f} samp/s " - f"test_mse={result.test_mse:.6f}" - ) - - return result - - -# ------------------------- -# Sweep + plotting -# ------------------------- - -def _load_results(out_dir: Path) -> Dict[str, Dict[str, Any]]: - results: Dict[str, Dict[str, Any]] = {} - for p in sorted(out_dir.glob("result_*.json")): - with p.open("r") as f: - d = json.load(f) - results[d["tag"]] = d - return results - - -def _write_csv(out_dir: Path, rows: Dict[str, Dict[str, Any]]) -> Path: - out_csv = out_dir / "results.csv" - keys = sorted(next(iter(rows.values())).keys()) - with out_csv.open("w", newline="") as f: - w = csv.DictWriter(f, fieldnames=keys) - w.writeheader() - for _tag, d in sorted(rows.items(), key=lambda kv: (kv[1]["accelerator"], kv[1]["devices"])): - w.writerow({k: d.get(k, None) for k in keys}) - return out_csv - - -def plot_results(out_dir: Path, baseline_devices: int = 1, 
include_cpu_speedup: bool = True) -> None: - # Use non-interactive backend for headless servers - import matplotlib - matplotlib.use("Agg") - import matplotlib.pyplot as plt - - rows = _load_results(out_dir) - if not rows: - raise RuntimeError(f"No result_*.json found in {out_dir}") - - # Split CPU vs GPU - cpu = [d for d in rows.values() if d["accelerator"] == "cpu"] - gpu = [d for d in rows.values() if d["accelerator"] == "gpu"] - - gpu_sorted = sorted(gpu, key=lambda d: d["devices"]) - cpu_sorted = sorted(cpu, key=lambda d: d["devices"]) - - # Baseline for speedup/efficiency - base_gpu = next((d for d in gpu_sorted if d["devices"] == baseline_devices), None) - if base_gpu is None and gpu_sorted: - base_gpu = gpu_sorted[0] - - # Helper series - def series(ds, key): - return [float(d[key]) for d in ds] - - # Plot 1: wall time (fit/predict/total) vs devices (GPU) - if gpu_sorted: - x = [d["devices"] for d in gpu_sorted] - - fit_t = series(gpu_sorted, "fit_time_s") - pred_t = series(gpu_sorted, "predict_time_s") - tot_t = series(gpu_sorted, "total_time_s") - - plt.figure(figsize=(8, 5)) - plt.plot(x, fit_t, marker="o", label="fit_time_s") - plt.plot(x, pred_t, marker="o", label="predict_time_s") - plt.plot(x, tot_t, marker="o", label="total_time_s") - plt.xlabel("GPUs (devices)") - plt.ylabel("Seconds") - plt.title("Wall time vs GPUs") - plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6) - plt.legend() - p = out_dir / "wall_time_vs_gpus.png" - plt.tight_layout() - plt.savefig(p, dpi=200) - plt.close() - - # Plot 2: throughput vs devices (GPU) - if gpu_sorted: - x = [d["devices"] for d in gpu_sorted] - thr = series(gpu_sorted, "train_throughput_sps") - - plt.figure(figsize=(8, 5)) - plt.plot(x, thr, marker="o") - plt.xlabel("GPUs (devices)") - plt.ylabel("Train throughput (samples/sec, global)") - plt.title("Train throughput vs GPUs") - plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6) - p = out_dir / "throughput_vs_gpus.png" - plt.tight_layout() - 
plt.savefig(p, dpi=200) - plt.close() - - # Plot 3: speedup + efficiency vs devices (GPU) - if gpu_sorted and base_gpu is not None: - x = [d["devices"] for d in gpu_sorted] - base_thr = float(base_gpu["train_throughput_sps"]) - speedup = [float(d["train_throughput_sps"]) / max(base_thr, 1e-9) for d in gpu_sorted] - efficiency = [s / max(dev, 1e-9) for s, dev in zip(speedup, x)] - - plt.figure(figsize=(8, 5)) - plt.plot(x, speedup, marker="o", label=f"Speedup vs {base_gpu['devices']} GPU") - plt.plot(x, x, linestyle="--", label="Ideal linear speedup") - plt.xlabel("GPUs (devices)") - plt.ylabel("Speedup") - plt.title("Speedup vs GPUs") - plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6) - plt.legend() - p = out_dir / "speedup_vs_gpus.png" - plt.tight_layout() - plt.savefig(p, dpi=200) - plt.close() - - plt.figure(figsize=(8, 5)) - plt.plot(x, efficiency, marker="o") - plt.xlabel("GPUs (devices)") - plt.ylabel("Scaling efficiency (speedup / GPUs)") - plt.title("Scaling efficiency vs GPUs") - plt.grid(True, linestyle="--", linewidth=0.6, alpha=0.6) - p = out_dir / "efficiency_vs_gpus.png" - plt.tight_layout() - plt.savefig(p, dpi=200) - plt.close() - - # Optional: CPU vs best GPU throughput comparison - if include_cpu_speedup and cpu_sorted and gpu_sorted: - cpu_thr = float(cpu_sorted[0]["train_throughput_sps"]) - best_gpu = max(gpu_sorted, key=lambda d: float(d["train_throughput_sps"])) - best_thr = float(best_gpu["train_throughput_sps"]) - ratio = best_thr / max(cpu_thr, 1e-9) - - plt.figure(figsize=(8, 5)) - labels = ["CPU (1)", f"GPU ({best_gpu['devices']})"] - vals = [cpu_thr, best_thr] - plt.bar(labels, vals) - plt.ylabel("Train throughput (samples/sec, global)") - plt.title(f"CPU vs best GPU throughput (GPU/CPU = {ratio:.2f}x)") - plt.grid(True, axis="y", linestyle="--", linewidth=0.6, alpha=0.6) - p = out_dir / "cpu_vs_best_gpu_throughput.png" - plt.tight_layout() - plt.savefig(p, dpi=200) - plt.close() - - # Write CSV for convenience - out_csv = 
_write_csv(out_dir, rows) - print(f"Wrote plots + {out_csv}") - - -def sweep(args: argparse.Namespace) -> None: - out_dir = Path(args.out_dir).resolve() - _ensure_dir(out_dir) - - # Write run config - run_cfg = vars(args).copy() - run_cfg["timestamp"] = _now_ts() - run_cfg["git_commit"] = _maybe_git_commit() - with (out_dir / "run_config.json").open("w") as f: - json.dump(run_cfg, f, indent=2) - - # Build base run args (forwarded to run subcommand) - base_run = [ - sys.executable, - str(Path(__file__).resolve()), - "run", - "--n", str(args.n), - "--c_dim", str(args.c_dim), - "--x_dim", str(args.x_dim), - "--y_dim", str(args.y_dim), - "--noise", str(args.noise), - "--val_frac", str(args.val_frac), - "--test_frac", str(args.test_frac), - "--epochs", str(args.epochs), - "--train_batch_size", str(args.train_batch_size), - "--val_batch_size", str(args.val_batch_size), - "--test_batch_size", str(args.test_batch_size), - "--num_workers", str(args.num_workers), - "--learning_rate", str(args.learning_rate), - "--num_archetypes", str(args.num_archetypes), - "--encoder_type", str(args.encoder_type), - "--backend", str(args.backend), - "--seed", str(args.seed), - "--out_dir", str(out_dir), - ] - if args.deterministic: - base_run.append("--deterministic") - - # 1) CPU (single proc) - if args.include_cpu: - cmd = base_run + ["--accelerator", "cpu", "--devices", "1", "--tag", "cpu_1dev"] - print("\n=== Running CPU (1 device) ===") - subprocess.run(cmd, check=True) - - # 2) GPU sweeps - for k in range(1, args.max_gpus + 1): - tag = f"gpu_{k}dev" - print(f"\n=== Running GPU ({k} device{'s' if k > 1 else ''}) ===") - if k == 1 and not args.force_torchrun_for_1gpu: - # Single process GPU - cmd = base_run + ["--accelerator", "gpu", "--devices", "1", "--tag", tag] - subprocess.run(cmd, check=True) - else: - # Multi-process launch (also works for 1 GPU if forced) - cmd = [ - sys.executable, "-m", "torch.distributed.run", - "--standalone", - f"--nproc_per_node={k}", - 
str(Path(__file__).resolve()), - "run", - "--accelerator", "gpu", - "--devices", str(k), - "--tag", tag, - "--out_dir", str(out_dir), - "--n", str(args.n), - "--c_dim", str(args.c_dim), - "--x_dim", str(args.x_dim), - "--y_dim", str(args.y_dim), - "--noise", str(args.noise), - "--val_frac", str(args.val_frac), - "--test_frac", str(args.test_frac), - "--epochs", str(args.epochs), - "--train_batch_size", str(args.train_batch_size), - "--val_batch_size", str(args.val_batch_size), - "--test_batch_size", str(args.test_batch_size), - "--num_workers", str(args.num_workers), - "--learning_rate", str(args.learning_rate), - "--num_archetypes", str(args.num_archetypes), - "--encoder_type", str(args.encoder_type), - "--backend", str(args.backend), - "--seed", str(args.seed), - ] - if args.deterministic: - cmd.append("--deterministic") - subprocess.run(cmd, check=True) - - # Plot at end - plot_results(out_dir, baseline_devices=args.speedup_baseline_gpu) - - -# ------------------------- -# CLI -# ------------------------- - -def build_parser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(description="Contextualized regression scaling benchmark (synthetic).") - sub = p.add_subparsers(dest="cmd", required=True) - - common = argparse.ArgumentParser(add_help=False) - common.add_argument("--out_dir", type=str, required=True, help="Output directory for results and plots.") - common.add_argument("--tag", type=str, default="", help="Tag for this run; used in filename result_.json") - - common.add_argument("--n", type=int, default=200_000) - common.add_argument("--c_dim", type=int, default=16) - common.add_argument("--x_dim", type=int, default=64) - common.add_argument("--y_dim", type=int, default=8) - common.add_argument("--noise", type=float, default=0.5) - - common.add_argument("--val_frac", type=float, default=0.2) - common.add_argument("--test_frac", type=float, default=0.1) - - common.add_argument("--epochs", type=int, default=5) - 
common.add_argument("--train_batch_size", type=int, default=2048) - common.add_argument("--val_batch_size", type=int, default=2048) - common.add_argument("--test_batch_size", type=int, default=4096) - common.add_argument("--num_workers", type=int, default=4) - - common.add_argument("--learning_rate", type=float, default=1e-3) - common.add_argument("--num_archetypes", type=int, default=0) - common.add_argument("--encoder_type", type=str, default="mlp") - - common.add_argument("--seed", type=int, default=123) - common.add_argument("--backend", type=str, default="nccl", choices=["nccl", "gloo"]) - - common.add_argument("--deterministic", action="store_true", help="Best-effort deterministic training.") - - # run - pr = sub.add_parser("run", parents=[common], help="Run one benchmark configuration.") - pr.add_argument("--accelerator", type=str, required=True, choices=["cpu", "gpu"]) - pr.add_argument("--devices", type=int, required=True) - - # sweep - ps = sub.add_parser("sweep", parents=[common], help="Run CPU + GPU(1..K) sweep and plot.") - ps.add_argument("--include_cpu", action="store_true", help="Include a CPU baseline run.") - ps.add_argument("--max_gpus", type=int, default=4, help="Max GPUs to sweep up to (inclusive).") - ps.add_argument("--force_torchrun_for_1gpu", action="store_true", - help="Also launch 1-GPU via torch.distributed.run for consistency.") - ps.add_argument("--speedup_baseline_gpu", type=int, default=1, - help="Baseline GPU count for speedup/efficiency plots (default: 1).") - - return p - - -def main() -> None: - parser = build_parser() - args = parser.parse_args() - - if args.cmd == "run": - # In DDP, only rank0 writes; non-rank0 returns None - run_one(args) - - elif args.cmd == "sweep": - sweep(args) - - else: - raise RuntimeError(f"Unknown cmd: {args.cmd}") - - -if __name__ == "__main__": - main() diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py index 9e18e7b3..04e11aa2 100644 --- 
a/contextualized/easy/ContextualizedNetworks.py +++ b/contextualized/easy/ContextualizedNetworks.py @@ -5,7 +5,9 @@ 1) When using a LightningDataModule outside Trainer.fit/predict, you MUST call dm.setup(stage="predict") before dm.predict_dataloader(). 2) Under DDP, prediction helpers are rank-0 only (by design in your trainers/wrapper). - We therefore early-return None on non-rank0 to avoid constructing np.array([None,...]). + We therefore avoid constructing np.array([None,...]) and return None on non-rank0, + while still executing the full per-model predict loop on all ranks to prevent + collective mismatches/hangs. """ from typing import List, Tuple, Union, Optional @@ -101,7 +103,6 @@ def predict_networks( return (betas, mus) if with_offsets else betas - def predict_X( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs ) -> Union[np.ndarray, List[np.ndarray]]: @@ -118,9 +119,7 @@ class ContextualizedCorrelationNetworks(ContextualizedNetworks): """ def __init__(self, **kwargs): - super().__init__( - ContextualizedCorrelation, [], [], CorrelationTrainer, **kwargs - ) + super().__init__(ContextualizedCorrelation, [], [], CorrelationTrainer, **kwargs) def predict_correlation( self, C: np.ndarray, individual_preds: bool = True, squared: bool = True @@ -129,9 +128,9 @@ def predict_correlation( Returns per-sample correlation matrices (or squared correlations). DDP behavior: - - All ranks must execute the predict loop to avoid collective mismatches. + - All ranks must execute the full per-model predict loop to avoid collective mismatches. 
- rank0 returns arrays - - non-rank0 returns None (propagated from trainer) + - non-rank0 returns None (rank-0-only trainer outputs are propagated) """ C_scaled = self._maybe_scale_C(C) Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) @@ -159,18 +158,24 @@ def predict_correlation( task_type="singletask_univariate", # correlation uses univariate convention ) - # CRITICAL FIX: setup before calling predict_dataloader() when not using Trainer.predict(datamodule=...) + # FIX (1): setup before calling predict_dataloader() when not using Trainer.predict(datamodule=...) dm.setup(stage="predict") pred_loader = dm.predict_dataloader() + saw_none = False rhos_list = [] + + # FIX (2): call predict for all models on all ranks; only rank0 accumulates results for i in range(len(self.models)): rho_i = self.trainers[i].predict_correlation(self.models[i], pred_loader) if rho_i is None: - # non-rank0 under DDP - return None + saw_none = True + continue rhos_list.append(rho_i) + if saw_none: + return None + rhos = np.array(rhos_list) if individual_preds: @@ -179,7 +184,6 @@ def predict_correlation( mean_rhos = np.mean(rhos, axis=0) return np.square(mean_rhos) if squared else mean_rhos - def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False ) -> Union[np.ndarray, List[np.ndarray], None]: @@ -218,7 +222,7 @@ def measure_mses( X_true = X_eff[None, :, :] residuals = X_hat - X_true - mses = (residuals ** 2).mean(axis=-1) + mses = (residuals**2).mean(axis=-1) else: B, N_hat, F1, F2 = X_hat.shape @@ -237,7 +241,7 @@ def measure_mses( X_true = X_eff[None, :, :, None] residuals = X_hat - X_true - mses = (residuals ** 2).mean(axis=(-1, -2)) + mses = (residuals**2).mean(axis=(-1, -2)) return mses if individual_preds else mses.mean(axis=0) @@ -257,9 +261,9 @@ def predict_precisions( Predicts context-specific precision matrices. DDP behavior: - - All ranks must execute the predict loop to avoid collective mismatches. 
+ - All ranks must execute the full per-model predict loop to avoid collective mismatches. - rank0 returns arrays - - non-rank0 returns None (propagated from trainer) + - non-rank0 returns None (rank-0-only trainer outputs are propagated) """ C_scaled = self._maybe_scale_C(C) Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) @@ -287,22 +291,27 @@ def predict_precisions( task_type="singletask_univariate", ) - # CRITICAL FIX: setup before calling predict_dataloader() + # FIX (1): setup before calling predict_dataloader() dm.setup(stage="predict") pred_loader = dm.predict_dataloader() + saw_none = False prec_list = [] + + # FIX (2): call predict for all models on all ranks; only rank0 accumulates results for i in range(len(self.models)): p_i = self.trainers[i].predict_precision(self.models[i], pred_loader) if p_i is None: - # non-rank0 under DDP - return None + saw_none = True + continue prec_list.append(p_i) + if saw_none: + return None + precisions = np.array(prec_list) return precisions if individual_preds else np.mean(precisions, axis=0) - def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False ) -> Union[np.ndarray, List[np.ndarray], None]: @@ -319,10 +328,7 @@ def measure_mses( for b in range(len(betas)): for i in range(F): preds = np.array( - [ - X[j].dot(betas[b, j, i, :]) + mus[b, j, i] - for j in range(len(X)) - ] + [X[j].dot(betas[b, j, i, :]) + mus[b, j, i] for j in range(len(X))] ) residuals = X[:, i] - preds mses[b, :] += residuals**2 / F @@ -339,9 +345,7 @@ def _parse_private_init_kwargs(self, **kwargs): Parse NOTMAD kwargs into model init dicts. 
""" self._init_kwargs["model"]["encoder_kwargs"] = { - "type": kwargs.pop( - "encoder_type", self._init_kwargs["model"]["encoder_type"] - ), + "type": kwargs.pop("encoder_type", self._init_kwargs["model"]["encoder_type"]), "params": { "width": self.constructor_kwargs["encoder_kwargs"]["width"], "layers": self.constructor_kwargs["encoder_kwargs"]["layers"], @@ -349,9 +353,7 @@ def _parse_private_init_kwargs(self, **kwargs): }, } - archetype_dag_loss_type = kwargs.pop( - "archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE - ) + archetype_dag_loss_type = kwargs.pop("archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE) self._init_kwargs["model"]["archetype_loss_params"] = { "l1": kwargs.get("archetype_l1", 0.0), "dag": kwargs.get( @@ -378,9 +380,9 @@ def _parse_private_init_kwargs(self, **kwargs): for param, value in self._init_kwargs["model"]["archetype_loss_params"]["dag"][ "params" ].items(): - self._init_kwargs["model"]["archetype_loss_params"]["dag"]["params"][ - param - ] = kwargs.pop(f"archetype_{param}", value) + self._init_kwargs["model"]["archetype_loss_params"]["dag"]["params"][param] = ( + kwargs.pop(f"archetype_{param}", value) + ) sample_specific_dag_loss_type = kwargs.pop( "sample_specific_dag_loss_type", DEFAULT_DAG_LOSS_TYPE @@ -398,12 +400,12 @@ def _parse_private_init_kwargs(self, **kwargs): }, ), } - for param, value in self._init_kwargs["model"]["sample_specific_loss_params"][ - "dag" - ]["params"].items(): - self._init_kwargs["model"]["sample_specific_loss_params"]["dag"]["params"][ - param - ] = kwargs.pop(f"sample_specific_{param}", value) + for param, value in self._init_kwargs["model"]["sample_specific_loss_params"]["dag"][ + "params" + ].items(): + self._init_kwargs["model"]["sample_specific_loss_params"]["dag"]["params"][param] = ( + kwargs.pop(f"sample_specific_{param}", value) + ) self._init_kwargs["model"]["opt_params"] = { "learning_rate": kwargs.pop("learning_rate", 1e-3), @@ -469,9 +471,7 @@ def predict_networks( """ if 
kwargs.pop("with_offsets", False): print("No offsets can be returned by NOTMAD.") - betas = self.predict_params( - C, uses_y=False, project_to_dag=project_to_dag, **kwargs - ) + betas = self.predict_params(C, uses_y=False, project_to_dag=project_to_dag, **kwargs) return betas def measure_mses( diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index f054836e..efee1506 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -1,26 +1,43 @@ -# --- imports you need above the class --- +""" +An sklearn-like wrapper for Contextualized models. + +Design goals (compat + correctness): +- Preserve prior public API: fit(), predict(), predict_params(), kwarg routing, normalization. +- Default to the DDP-safe, map-style ContextualizedRegressionDataModule when available. +- Avoid redundant DDP gather/ordering logic in the wrapper: + * DDP-safe predict assembly is handled by contextualized.regression.trainers.RegressionTrainer + (predict_y / predict_params) together with lightning_modules.py predict_step payloads. +- Keep legacy compatibility for older models that still expose `model.dataloader(...)`. 
+""" + import copy import os -from typing import * +from typing import Any, Dict, List, Optional, Tuple, Union + import numpy as np import torch import torch.distributed as dist -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler -from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.plugins.environments import LightningEnvironment +from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.strategies import DDPStrategy +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler from contextualized.functions import LINK_FUNCTIONS -from contextualized.regression import REGULARIZERS, LOSSES -from contextualized.regression.datamodules import ContextualizedRegressionDataModule +from contextualized.regression import LOSSES, REGULARIZERS + +# Prefer the new, DDP-safe DataModule path when available. 
+try: + from contextualized.regression.datamodules import ContextualizedRegressionDataModule +except Exception: # pragma: no cover + ContextualizedRegressionDataModule = None # type: ignore + DEFAULT_LEARNING_RATE = 1e-3 DEFAULT_N_BOOTSTRAPS = 1 DEFAULT_ES_PATIENCE = 1 DEFAULT_VAL_BATCH_SIZE = 16 -DEFAULT_TRAIN_BATCH_SIZE = 64 +DEFAULT_TRAIN_BATCH_SIZE = 1 # keep legacy default DEFAULT_TEST_BATCH_SIZE = 16 DEFAULT_VAL_SPLIT = 0.2 DEFAULT_ENCODER_TYPE = "mlp" @@ -30,168 +47,43 @@ DEFAULT_NORMALIZE = False -def _is_distributed() -> bool: - """Check if we're in a distributed context.""" +def _dist_initialized() -> bool: return dist.is_available() and dist.is_initialized() -def _get_rank() -> int: - """Get current process rank.""" - if _is_distributed(): - return dist.get_rank() +def _rank() -> int: + if _dist_initialized(): + return int(dist.get_rank()) return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))) def _is_main_process() -> bool: - """Check if this is the main process (rank 0).""" - return _get_rank() == 0 - -def _flatten_pl_predict_output(preds): - """ - Lightning can return: - - list[dict] (single dataloader) - - list[list[dict]] (multiple dataloaders) - Normalize to list[dict]. - """ - if preds is None: - return [] - if len(preds) > 0 and isinstance(preds[0], list): - out = [] - for sub in preds: - out.extend(sub) - return out - return preds - - -def _pack_local_pred_payload(pred_list: list) -> dict: - """ - Convert list[dict] -> dict[str, np.ndarray] by concatenating along axis 0. - Assumes each dict entry is either a torch.Tensor (CPU) or a Python scalar. 
- """ - pred_list = _flatten_pl_predict_output(pred_list) - if not pred_list: - return {} - - # Union of keys across batches (some models include extra keys) - keys = set() - for d in pred_list: - keys.update(d.keys()) - - packed = {} - for k in keys: - chunks = [] - for d in pred_list: - if k not in d: - continue - v = d[k] - if torch.is_tensor(v): - chunks.append(v.detach().cpu().numpy()) - else: - chunks.append(np.asarray(v)) - if not chunks: - continue - # Concatenate on first dim where possible; fallback to stack - try: - packed[k] = np.concatenate(chunks, axis=0) - except Exception: - packed[k] = np.stack(chunks, axis=0) - return packed + return _rank() == 0 -def _gather_object_to_rank0(obj): - """ - Gather arbitrary Python objects to rank 0. - Returns: list[obj] on rank 0, None on non-zero ranks. - """ - if not _is_distributed(): - return [obj] - - world_size = dist.get_world_size() - if world_size == 1: - return [obj] - - if _is_main_process(): - gathered = [None for _ in range(world_size)] - dist.gather_object(obj, object_gather_list=gathered, dst=0) - return gathered - else: - dist.gather_object(obj, object_gather_list=None, dst=0) - return None - - -def _merge_packed_payloads(payloads: list) -> dict: - """ - Merge list[dict[str, np.ndarray]] -> dict[str, np.ndarray] by concatenation axis 0. - """ - merged = {} - if not payloads: - return merged - - keys = set() - for p in payloads: - if p: - keys.update(p.keys()) - - for k in keys: - chunks = [p[k] for p in payloads if p and (k in p) and (p[k] is not None) and (len(p[k]) > 0)] - if not chunks: - continue - merged[k] = np.concatenate(chunks, axis=0) - return merged - - -def _stable_sort_and_dedupe_by_key(payload: dict, primary: str, secondary: tuple = ()) -> dict: - """ - Sort payload arrays by a composite key (primary + optional secondary indices), - then dedupe (needed because DistributedSampler may pad/duplicate). 
- """ - if (payload is None) or (primary not in payload) or (len(payload[primary]) == 0): - return payload - - primary_arr = payload[primary].astype(np.int64) - - # Build composite key - if secondary: - parts = [primary_arr] - for s in secondary: - if s in payload: - parts.append(payload[s].astype(np.int64)) - if len(parts) == 1: - key = primary_arr - else: - # lexsort uses last key as primary; reverse order - order = np.lexsort(tuple(reversed(parts))) - key_sorted = np.stack([p[order] for p in parts], axis=1) - # Dedup by full composite row - _, uniq_pos = np.unique(key_sorted, axis=0, return_index=True) - keep = order[np.sort(uniq_pos)] - else: - order = np.argsort(primary_arr, kind="mergesort") - key_sorted = primary_arr[order] - _, uniq_pos = np.unique(key_sorted, return_index=True) - keep = order[np.sort(uniq_pos)] - - out = {} - for k, v in payload.items(): - if isinstance(v, np.ndarray) and (v.shape[0] == primary_arr.shape[0]): - out[k] = v[keep] - else: - out[k] = v - return out - +def _world_size_env() -> int: + try: + return int(os.environ.get("WORLD_SIZE", "1")) + except Exception: + return 1 class SKLearnWrapper: """ An sklearn-like wrapper for Contextualized models. - - FIXED VERSION with proper DDP handling for: - - Prediction (avoids duplicate computation) - - Data loading (proper num_workers) - - Distributed inference + + Args: + base_constructor (callable/class): LightningModule constructor for the model. + extra_model_kwargs (list[str] or set[str]): extra kw names allowed in "model". + extra_data_kwargs (list[str] or set[str]): extra kw names allowed in "data". + trainer_constructor (class): Trainer class (should provide predict_y / predict_params for DDP-safe inference). + **kwargs: routed into model/data/trainer/wrapper based on acceptable_kwargs. 
""" - def _set_defaults(self): + # ---------------------------- + # Defaults / initialization + # ---------------------------- + def _set_defaults(self) -> None: self.default_learning_rate = DEFAULT_LEARNING_RATE self.default_n_bootstraps = DEFAULT_N_BOOTSTRAPS self.default_es_patience = DEFAULT_ES_PATIENCE @@ -214,28 +106,32 @@ def __init__( **kwargs, ): self._set_defaults() + self.base_constructor = base_constructor self.trainer_constructor = trainer_constructor + # Optional: allow callers to pass default trainer kwargs in a single dict self._trainer_init_kwargs = kwargs.pop("trainer_kwargs", None) - self.n_bootstraps = 1 - self.models = None - self.trainers = None - - # Track if we trained with DDP (affects prediction strategy) - self._trained_with_ddp = False - self._trained_devices = 1 - - self.normalize = kwargs.pop("normalize", self.default_normalize) - self.scalers = {"C": None, "X": None, "Y": None} - self.context_dim = None - self.x_dim = None - self.y_dim = None - self.accelerator = "cuda" if torch.cuda.is_available() else "cpu" - - # Accepted kwarg routes - self.acceptable_kwargs = { + self.n_bootstraps: int = 1 + self.models: Optional[List[Any]] = None + self.trainers: Optional[List[Any]] = None + + # Keep legacy attribute for external users who expect it + self.dataloaders: Optional[Dict[str, List[Any]]] = None + + self.normalize: bool = bool(kwargs.pop("normalize", self.default_normalize)) + self.scalers: Dict[str, Optional[StandardScaler]] = {"C": None, "X": None, "Y": None} + + self.context_dim: Optional[int] = None + self.x_dim: Optional[int] = None + self.y_dim: Optional[int] = None + + # Lightning expects "gpu" / "cpu" (legacy wrapper used "gpu") + self.accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" + + # Expanded routing (superset of legacy); safe for backward compatibility. 
+ self.acceptable_kwargs: Dict[str, List[str]] = { "data": [ "train_batch_size", "val_batch_size", @@ -265,6 +161,10 @@ def __init__( "context_dim", "x_dim", "y_dim", + # legacy-friendly knobs + "width", + "layers", + "encoder_link_fn", ], "trainer": [ "max_epochs", @@ -296,14 +196,11 @@ def __init__( "normalize", ], } + self._update_acceptable_kwargs("model", extra_model_kwargs) self._update_acceptable_kwargs("data", extra_data_kwargs) - self._update_acceptable_kwargs( - "model", kwargs.pop("remove_model_kwargs", []), acceptable=False - ) - self._update_acceptable_kwargs( - "data", kwargs.pop("remove_data_kwargs", []), acceptable=False - ) + self._update_acceptable_kwargs("model", kwargs.pop("remove_model_kwargs", []), acceptable=False) + self._update_acceptable_kwargs("data", kwargs.pop("remove_data_kwargs", []), acceptable=False) self.convenience_kwargs = [ "alpha", @@ -315,51 +212,66 @@ def __init__( "encoder_link_fn", ] + # Build model-constructor defaults (and allow legacy + new encoder_kwargs styles) self.constructor_kwargs = self._organize_constructor_kwargs(**kwargs) - self.constructor_kwargs["encoder_kwargs"]["width"] = kwargs.pop( - "width", self.constructor_kwargs["encoder_kwargs"]["width"] - ) - self.constructor_kwargs["encoder_kwargs"]["layers"] = kwargs.pop( - "layers", self.constructor_kwargs["encoder_kwargs"]["layers"] - ) - self.constructor_kwargs["encoder_kwargs"]["link_fn"] = kwargs.pop( - "encoder_link_fn", - self.constructor_kwargs["encoder_kwargs"].get( - "link_fn", self.default_encoder_link_fn - ), - ) + # Apply convenience overrides (legacy keys) + if "encoder_kwargs" in self.constructor_kwargs: + ek = self.constructor_kwargs["encoder_kwargs"] + ek["width"] = kwargs.pop("width", ek.get("width", self.default_encoder_width)) + ek["layers"] = kwargs.pop("layers", ek.get("layers", self.default_encoder_layers)) + ek["link_fn"] = kwargs.pop("encoder_link_fn", ek.get("link_fn", self.default_encoder_link_fn)) + else: + 
self.constructor_kwargs["width"] = kwargs.pop("width", self.constructor_kwargs.get("width", self.default_encoder_width)) + self.constructor_kwargs["layers"] = kwargs.pop("layers", self.constructor_kwargs.get("layers", self.default_encoder_layers)) + self.constructor_kwargs["encoder_link_fn"] = kwargs.pop( + "encoder_link_fn", + self.constructor_kwargs.get("encoder_link_fn", self.default_encoder_link_fn), + ) + + # Store remaining kwargs to be organized by router self.not_constructor_kwargs = { - k: v - for k, v in kwargs.items() - if k not in self.constructor_kwargs and k not in self.convenience_kwargs + k: v for k, v in kwargs.items() if k not in self.constructor_kwargs and k not in self.convenience_kwargs } - self._init_kwargs, unrecognized = self._organize_kwargs( - **self.not_constructor_kwargs - ) + self._init_kwargs, unrecognized = self._organize_kwargs(**self.not_constructor_kwargs) + + # Inject constructor kwargs into model bucket for k, v in self.constructor_kwargs.items(): self._init_kwargs["model"][k] = v - if self._trainer_init_kwargs is not None: + # Inject trainer init kwargs + if isinstance(self._trainer_init_kwargs, dict): self._init_kwargs["trainer"].update(self._trainer_init_kwargs) + # Allow subclasses to swallow additional init kwargs + recognized_private = set(self._parse_private_init_kwargs(**kwargs)) for kw in unrecognized: - print(f"Received unknown keyword argument {kw}, probably ignoring.") + if kw not in recognized_private: + print(f"Received unknown keyword argument {kw}, probably ignoring.") + + # ---------------------------- + # Hooks for subclasses + # ---------------------------- + def _parse_private_fit_kwargs(self, **kwargs) -> List[str]: + return [] + + def _parse_private_init_kwargs(self, **kwargs) -> List[str]: + return [] - def _update_acceptable_kwargs(self, category, new_kwargs, acceptable=True): + # ---------------------------- + # Kwarg routing / organization + # ---------------------------- + def 
_update_acceptable_kwargs(self, category, new_kwargs, acceptable: bool = True) -> None: + new_kwargs = list(new_kwargs) if new_kwargs is not None else [] if acceptable: - self.acceptable_kwargs[category] = list( - set(self.acceptable_kwargs[category]).union(set(new_kwargs)) - ) + self.acceptable_kwargs[category] = list(set(self.acceptable_kwargs[category]).union(set(new_kwargs))) else: - self.acceptable_kwargs[category] = list( - set(self.acceptable_kwargs[category]) - set(new_kwargs) - ) + self.acceptable_kwargs[category] = list(set(self.acceptable_kwargs[category]) - set(new_kwargs)) - def _organize_kwargs(self, **kwargs): + def _organize_kwargs(self, **kwargs) -> Tuple[Dict[str, Dict[str, Any]], List[str]]: out = {cat: {} for cat in self.acceptable_kwargs} - unknown = [] + unknown: List[str] = [] for k, v in kwargs.items(): placed = False for cat, allowed in self.acceptable_kwargs.items(): @@ -371,135 +283,257 @@ def _organize_kwargs(self, **kwargs): unknown.append(k) return out, unknown - def _organize_constructor_kwargs(self, **kwargs): - model = {} + def _organize_constructor_kwargs(self, **kwargs) -> Dict[str, Any]: + """ + Create default model constructor kwargs, supporting both: + - new style: encoder_kwargs={width,layers,link_fn} + - legacy style: width/layers/encoder_link_fn top-level + """ + ctor: Dict[str, Any] = {} def maybe_add(kw, default_val): if kw in self.acceptable_kwargs["model"]: - model[kw] = kwargs.get(kw, default_val) + ctor[kw] = kwargs.get(kw, default_val) maybe_add("link_fn", LINK_FUNCTIONS["identity"]) maybe_add("univariate", False) - maybe_add("encoder_type", DEFAULT_ENCODER_TYPE) + maybe_add("encoder_type", self.default_encoder_type) maybe_add("loss_fn", LOSSES["mse"]) - maybe_add( - "encoder_kwargs", - { - "width": kwargs.get("encoder_width", DEFAULT_ENCODER_WIDTH), - "layers": kwargs.get("encoder_layers", DEFAULT_ENCODER_LAYERS), - "link_fn": kwargs.get("encoder_link_fn", DEFAULT_ENCODER_LINK_FN), - }, - ) - if 
kwargs.get("subtype_probabilities", False): - model["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"] + # Prefer new style if allowed + if "encoder_kwargs" in self.acceptable_kwargs["model"]: + ctor["encoder_kwargs"] = kwargs.get( + "encoder_kwargs", + { + "width": kwargs.get("encoder_width", self.default_encoder_width), + "layers": kwargs.get("encoder_layers", self.default_encoder_layers), + "link_fn": kwargs.get("encoder_link_fn", self.default_encoder_link_fn), + }, + ) + if kwargs.get("subtype_probabilities", False): + ctor["encoder_kwargs"]["link_fn"] = LINK_FUNCTIONS["softmax"] + else: + maybe_add("width", self.default_encoder_width) + maybe_add("layers", self.default_encoder_layers) + maybe_add("encoder_link_fn", self.default_encoder_link_fn) + if kwargs.get("subtype_probabilities", False): + ctor["encoder_link_fn"] = LINK_FUNCTIONS["softmax"] + + # Regularizer if "model_regularizer" in self.acceptable_kwargs["model"]: - if kwargs.get("alpha", 0) > 0: - model["model_regularizer"] = REGULARIZERS["l1_l2"]( - kwargs["alpha"], + alpha = float(kwargs.get("alpha", 0.0) or 0.0) + if alpha > 0: + ctor["model_regularizer"] = REGULARIZERS["l1_l2"]( + alpha, kwargs.get("l1_ratio", 1.0), kwargs.get("mu_ratio", 0.5), ) else: - model["model_regularizer"] = kwargs.get( - "model_regularizer", REGULARIZERS["none"] - ) - return model + ctor["model_regularizer"] = kwargs.get("model_regularizer", REGULARIZERS["none"]) + + return ctor + + # ---------------------------- + # Utilities + # ---------------------------- + def _maybe_scale_C(self, C: np.ndarray) -> np.ndarray: + if self.normalize and self.scalers["C"] is not None: + return self.scalers["C"].transform(C) + return C + + def _maybe_scale_X(self, X: np.ndarray) -> np.ndarray: + if self.normalize and self.scalers["X"] is not None: + return self.scalers["X"].transform(X) + return X + + def _nanrobust_mean(self, arr: np.ndarray, axis: int = 0) -> np.ndarray: + if not np.isfinite(arr).all(): + arr = 
np.where(np.isfinite(arr), arr, np.nan) + with np.errstate(invalid="ignore"): + mean = np.nanmean(arr, axis=axis) + if np.isnan(mean).any(): + raise RuntimeError("All bootstraps produced non-finite predictions for some items.") + return mean - @staticmethod - def _retarget_or_strip_early_stopping(cb, use_val: bool, train_monitor="train_loss"): - try: - from pytorch_lightning.callbacks.early_stopping import EarlyStopping as _ES - except Exception: - return cb - if not isinstance(cb, _ES): - return cb - if use_val: - return cb - monitor = getattr(cb, "monitor", None) - if (monitor is None) or (isinstance(monitor, str) and monitor.startswith("val_")): - return _ES( - monitor=train_monitor, - mode=getattr(cb, "mode", "min"), - patience=getattr(cb, "patience", 1), - verbose=getattr(cb, "verbose", False), - min_delta=getattr(cb, "min_delta", 0.0), - ) - return cb - def _default_num_workers(self, devices: int) -> int: - """ - Heuristic for default DataLoader workers. - FIXED: CPU also benefits from workers for I/O overlap. 
- """ try: n_cpu = os.cpu_count() or 0 except Exception: n_cpu = 0 - if n_cpu <= 0: return 0 - - # For CPU-only, still use some workers for data loading overlap - if self.accelerator not in ("cuda", "gpu"): + if self.accelerator != "gpu": return min(2, n_cpu) - world_size_env = os.environ.get("WORLD_SIZE", None) - if world_size_env is not None: - try: - world_size = max(1, int(world_size_env)) - except ValueError: - world_size = 1 - else: - world_size = max(1, devices) - - cpu_per_rank = max(1, n_cpu // world_size) - # 2-4 workers per rank, capped + world = max(1, _world_size_env() if _world_size_env() > 1 else devices) + cpu_per_rank = max(1, n_cpu // world) return int(min(4, max(2, cpu_per_rank // 2))) - def _organize_and_expand_fit_kwargs(self, **kwargs): + def _safe_val_split(self, n: int, val_split: float) -> float: + vs = float(val_split) + if vs <= 0.0: + return 0.0 + # require at least 2 validation samples for stable metrics + if int(round(n * vs)) < 2: + return 0.0 + return vs + + def _resolve_train_val_arrays( + self, + C: np.ndarray, + X: np.ndarray, + Y: Optional[np.ndarray], + *, + C_val: Optional[np.ndarray], + X_val: Optional[np.ndarray], + Y_val: Optional[np.ndarray], + Y_required: bool, + val_split: float, + random_state: int = 42, + shuffle: bool = True, + ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]: """ - Expand/normalize kwargs for data/model/trainer/wrapper/fit. - FIXED: Better DDP defaults and tracking. 
+ Returns: + C_all, X_all, Y_all, train_idx, val_idx_or_None + + Supports: + - val arrays (C_val/X_val[/Y_val]) by concatenation + - otherwise uses val_split indices inside the original arrays """ + if C_val is not None and X_val is not None and (not Y_required or Y_val is not None): + # concatenate and build disjoint train/val indices + n_tr = int(C.shape[0]) + C_all = np.concatenate([C, C_val], axis=0) + X_all = np.concatenate([X, X_val], axis=0) + if Y is None: + Y_all = None + else: + if Y_val is None and Y_required: + raise ValueError("Y_val is required when Y is provided.") + Y_all = np.concatenate([Y, Y_val], axis=0) if Y_val is not None else Y + + train_idx = np.arange(n_tr) + val_idx = np.arange(n_tr, int(C_all.shape[0])) + return C_all, X_all, Y_all, train_idx, val_idx + + # split by indices within the same arrays + n = int(C.shape[0]) + vs = self._safe_val_split(n, val_split) + if vs <= 0.0: + return C, X, Y, np.arange(n), None + + tr_idx, va_idx = train_test_split( + np.arange(n), + test_size=vs, + shuffle=shuffle, + random_state=random_state, # fixed seed to keep DDP ranks consistent + ) + return C, X, Y, tr_idx, va_idx + + def _build_datamodule( + self, + C: np.ndarray, + X: np.ndarray, + Y: Optional[np.ndarray], + *, + train_idx: Optional[np.ndarray], + val_idx: Optional[np.ndarray], + test_idx: Optional[np.ndarray], + predict_idx: Optional[np.ndarray], + data_kwargs: Dict[str, Any], + task_type: str, + ): + if ContextualizedRegressionDataModule is None: + raise RuntimeError("ContextualizedRegressionDataModule is not available in this installation.") + + dk = { + "train_batch_size": self.default_train_batch_size, + "val_batch_size": self.default_val_batch_size, + "test_batch_size": self.default_test_batch_size, + "predict_batch_size": self.default_val_batch_size, + "num_workers": 0, + "pin_memory": (self.accelerator == "gpu"), + "persistent_workers": False, + "drop_last": False, + "shuffle_train": True, + "shuffle_eval": False, + "dtype": 
torch.float, + } + dk.update(data_kwargs or {}) + + return ContextualizedRegressionDataModule( + C=C, + X=X, + Y=Y, + task_type=task_type, + train_idx=train_idx, + val_idx=val_idx, + test_idx=test_idx, + predict_idx=predict_idx, + train_batch_size=dk["train_batch_size"], + val_batch_size=dk["val_batch_size"], + test_batch_size=dk["test_batch_size"], + predict_batch_size=dk["predict_batch_size"], + num_workers=dk["num_workers"], + pin_memory=dk["pin_memory"], + persistent_workers=dk["persistent_workers"], + drop_last=dk["drop_last"], + shuffle_train=dk["shuffle_train"], + shuffle_eval=dk["shuffle_eval"], + dtype=dk["dtype"], + ) + + def _use_datamodule_for_model(self, model: Any) -> bool: + # Prefer DataModule when available and model doesn't provide legacy dataloader(). + if ContextualizedRegressionDataModule is None: + return False + return not callable(getattr(model, "dataloader", None)) + + # ---------------------------- + # Fit kwargs expansion + # ---------------------------- + def _organize_and_expand_fit_kwargs(self, **kwargs) -> Dict[str, Dict[str, Any]]: organized, unrecognized = self._organize_kwargs(**kwargs) + recognized_private = set(self._parse_private_fit_kwargs(**kwargs)) + for kw in unrecognized: + if kw not in recognized_private: + print(f"Received unknown keyword argument {kw}, probably ignoring.") + # Merge init defaults (fit kwargs win) for category, cat_kwargs in self._init_kwargs.items(): for k, v in cat_kwargs.items(): organized[category].setdefault(k, v) - max_epochs_cli = kwargs.get("max_epochs", None) - epochs_cli = kwargs.get("epochs", None) - if max_epochs_cli is not None: - organized["trainer"]["max_epochs"] = int(max_epochs_cli) - elif epochs_cli is not None: - organized["trainer"]["max_epochs"] = int(epochs_cli) - else: - organized["trainer"].setdefault("max_epochs", 3) + # Helper + def maybe_add(cat: str, k: str, default_val: Any) -> None: + if k in self.acceptable_kwargs[cat]: + organized[cat][k] = organized[cat].get(k, 
default_val) + + # Model dims / lr + maybe_add("model", "learning_rate", self.default_learning_rate) + maybe_add("model", "context_dim", self.context_dim) + maybe_add("model", "x_dim", self.x_dim) + maybe_add("model", "y_dim", self.y_dim) + + if organized["model"].get("num_archetypes", 1) == 0: + organized["model"].pop("num_archetypes", None) - current_val_split = organized["data"].get("val_split", self.default_val_split) - organized["data"]["val_split"] = current_val_split - use_val = float(current_val_split) > 0.0 + # Data defaults + maybe_add("data", "train_batch_size", self.default_train_batch_size) + maybe_add("data", "val_batch_size", self.default_val_batch_size) + maybe_add("data", "test_batch_size", self.default_test_batch_size) + maybe_add("data", "predict_batch_size", organized["data"].get("val_batch_size", self.default_val_batch_size)) - organized["trainer"].setdefault("accelerator", self.accelerator) + # Trainer defaults + maybe_add("trainer", "accelerator", self.accelerator) organized["trainer"].setdefault("enable_progress_bar", False) organized["trainer"].setdefault("logger", False) - organized["trainer"].setdefault("enable_checkpointing", False) organized["trainer"].setdefault("num_sanity_val_steps", 0) - - # FIXED: Default to mixed precision on GPU - if self.accelerator in ("cuda", "gpu"): - organized["trainer"].setdefault("precision", "16-mixed") - else: - organized["trainer"].setdefault("precision", 32) - if not use_val: - organized["trainer"].setdefault("limit_val_batches", 0) - - world_size_env = int(os.environ.get("WORLD_SIZE", "1")) + # devices/strategy defaults for torchrun/DDP safety + world = _world_size_env() if "devices" not in organized["trainer"]: - # When torchrun is active, devices must match world_size - organized["trainer"]["devices"] = world_size_env if world_size_env > 1 else 1 + organized["trainer"]["devices"] = world if world > 1 else 1 devices_cfg = organized["trainer"].get("devices", 1) if isinstance(devices_cfg, int): @@ 
-508,362 +542,380 @@ def _organize_and_expand_fit_kwargs(self, **kwargs): devices = len(devices_cfg) else: devices = 1 - - # Validate: if torchrun sets WORLD_SIZE > 1, devices must match - if world_size_env > 1 and devices != world_size_env: - if _is_main_process(): - print(f"[WARNING] torchrun WORLD_SIZE={world_size_env} but devices={devices}. " - f"Overriding devices to {world_size_env}.") - devices = world_size_env - organized["trainer"]["devices"] = devices - # Track for prediction strategy - self._trained_devices = devices - self._trained_with_ddp = devices > 1 + if world > 1 and devices != world: + if _is_main_process(): + print(f"[WARNING] WORLD_SIZE={world} but devices={devices}; overriding devices -> {world}.") + organized["trainer"]["devices"] = world + devices = world if "strategy" not in organized["trainer"]: - if devices > 1 or world_size_env > 1: - from datetime import timedelta - # Check if we're under torchrun (process group may already exist) - if world_size_env > 1: - # torchrun case: let Lightning use existing process group - organized["trainer"]["strategy"] = "ddp" - else: - # Lightning-spawned DDP case - organized["trainer"]["strategy"] = DDPStrategy( - process_group_backend="nccl" if torch.cuda.is_available() else "gloo", - find_unused_parameters=False, - broadcast_buffers=False, - timeout=timedelta(minutes=30), - ) + if devices > 1 or world > 1: + organized["trainer"]["strategy"] = DDPStrategy( + find_unused_parameters=False, + broadcast_buffers=False, + process_group_backend="nccl" if torch.cuda.is_available() else "gloo", + ) else: organized["trainer"]["strategy"] = "auto" - if ( - organized["trainer"].get("strategy") in ("auto", None) - and organized["trainer"].get("devices", 1) == 1 - and world_size_env == 1 # Not under torchrun - and "plugins" not in organized["trainer"] - ): - organized["trainer"]["plugins"] = [LightningEnvironment()] - - def maybe_add(cat, k, default): - if k in self.acceptable_kwargs[cat]: - organized[cat][k] = 
organized[cat].get(k, default) - - maybe_add("model", "learning_rate", self.default_learning_rate) - maybe_add("model", "context_dim", self.context_dim) - maybe_add("model", "x_dim", self.x_dim) - maybe_add("model", "y_dim", self.y_dim) - if organized["model"].get("num_archetypes", 1) == 0: - organized["model"].pop("num_archetypes", None) - - maybe_add("data", "train_batch_size", self.default_train_batch_size) - maybe_add("data", "val_batch_size", self.default_val_batch_size) - maybe_add("data", "test_batch_size", self.default_test_batch_size) - maybe_add("data", "predict_batch_size", self.default_val_batch_size) - - # FIXED: Better num_workers default - default_nw = self._default_num_workers(devices) - maybe_add("data", "num_workers", default_nw) - - maybe_add("data", "pin_memory", self.accelerator in ("cuda", "gpu")) - - persistent_default = organized["data"].get("num_workers", 0) > 0 - maybe_add("data", "persistent_workers", persistent_default) - - drop_last_default = devices > 1 - maybe_add("data", "drop_last", drop_last_default) + # Performance defaults + if self.accelerator == "gpu": + organized["trainer"].setdefault("precision", "16-mixed") + else: + organized["trainer"].setdefault("precision", 32) + # DataLoader perf defaults + maybe_add("data", "num_workers", self._default_num_workers(devices)) + maybe_add("data", "pin_memory", self.accelerator == "gpu") + maybe_add("data", "persistent_workers", organized["data"].get("num_workers", 0) > 0) + maybe_add("data", "drop_last", (devices > 1 or world > 1)) maybe_add("data", "shuffle_train", True) maybe_add("data", "shuffle_eval", False) maybe_add("data", "dtype", torch.float) + # Wrapper defaults maybe_add("wrapper", "n_bootstraps", self.default_n_bootstraps) + # Validation split + val_split = float(organized["data"].get("val_split", self.default_val_split)) + organized["data"]["val_split"] = val_split + + # Callbacks: preserve legacy behavior (EarlyStopping + ModelCheckpoint by default) + use_val = 
self._safe_val_split(10, val_split) > 0.0 # placeholder check; refined at fit time + es_patience = organized["wrapper"].get("es_patience", self.default_es_patience) es_monitor = organized["wrapper"].get("es_monitor", "val_loss" if use_val else "train_loss") es_mode = organized["wrapper"].get("es_mode", "min") - es_patience = organized["wrapper"].get("es_patience", self.default_es_patience) es_verbose = organized["wrapper"].get("es_verbose", False) es_min_delta = organized["wrapper"].get("es_min_delta", 0.0) - cb_ctors = organized["trainer"].get("callback_constructors", []) + cb_ctors = organized["trainer"].get("callback_constructors", None) + if cb_ctors is None: + cb_ctors = [] - if use_val and (es_patience is not None and es_patience > 0): + # Default: enable checkpointing unless explicitly disabled + organized["trainer"].setdefault("enable_checkpointing", True) + + # Add EarlyStopping only if patience > 0 + if es_patience is not None and int(es_patience) > 0: cb_ctors.append( lambda i: EarlyStopping( monitor=es_monitor, mode=es_mode, - patience=es_patience, - verbose=es_verbose, - min_delta=es_min_delta, + patience=int(es_patience), + verbose=bool(es_verbose), + min_delta=float(es_min_delta), ) ) - if organized["trainer"].get("enable_checkpointing", False): + if bool(organized["trainer"].get("enable_checkpointing", True)): cb_ctors.append( lambda i: ModelCheckpoint( - monitor=("val_loss" if use_val else None), + monitor=es_monitor, dirpath=f"{kwargs.get('checkpoint_path', './lightning_logs')}/boot_{i}_checkpoints", - filename=("{epoch}-{val_loss:.4f}" if use_val else "{epoch}"), + filename="{epoch}-{val_loss:.4f}", ) ) + organized["trainer"]["callback_constructors"] = cb_ctors + return organized - for kw in unrecognized: - print(f"Received unknown keyword argument {kw}, probably ignoring.") + # ---------------------------- + # Public API + # ---------------------------- + def fit(self, *args, **kwargs) -> None: + """ + Fit contextualized model to data. 
- cb_list = organized["trainer"].get("callbacks", []) - cb_list = [self._retarget_or_strip_early_stopping(cb, use_val) for cb in cb_list] - organized["trainer"]["callbacks"] = cb_list + Backward compatible with legacy: + - fit(C, X) -> uses X as targets (Contextualized Networks behavior) + - fit(C, X, Y) + - fit(..., Y=...) override + - supports C_val/X_val/Y_val and/or val_split + """ + self.models, self.trainers = [], [] + self.dataloaders = {"train": [], "val": [], "test": []} - ctor_list = organized["trainer"].get("callback_constructors", []) + if len(args) < 2: + raise ValueError("fit expects at least (C, X) as positional args.") - def _wrap_ctor(ctor): - def _wrapped(i): - cb = ctor(i) - return self._retarget_or_strip_early_stopping(cb, use_val) - return _wrapped + C = kwargs.pop("C", None) + X = kwargs.pop("X", None) + Y = kwargs.pop("Y", None) - organized["trainer"]["callback_constructors"] = [_wrap_ctor(c) for c in ctor_list] + if C is None or X is None: + C = args[0] + X = args[1] + if len(args) >= 3: + Y = args[2] + if C is None or X is None: + raise ValueError("fit requires C and X.") - return organized + C = np.asarray(C) + X = np.asarray(X) + if Y is not None: + Y = np.asarray(Y) - def _build_datamodule( - self, - C: np.ndarray, - X: np.ndarray, - Y: Optional[np.ndarray], - *, - train_idx=None, - val_idx=None, - test_idx=None, - predict_idx=None, - data_kwargs: Optional[dict] = None, - task_type: str = "singletask_multivariate", - ) -> ContextualizedRegressionDataModule: - dk = dict( - train_batch_size=self.default_train_batch_size, - val_batch_size=self.default_val_batch_size, - test_batch_size=self.default_test_batch_size, - predict_batch_size=self.default_val_batch_size, - num_workers=0, - pin_memory=(self.accelerator in ("cuda", "gpu")), - persistent_workers=False, - drop_last=False, - shuffle_train=True, - shuffle_eval=False, - dtype=torch.float, - ) - if data_kwargs: - dk.update(data_kwargs) + # Normalize / scale + if self.normalize: + if 
self.scalers["C"] is None: + self.scalers["C"] = StandardScaler().fit(C) + C = self.scalers["C"].transform(C) - dm = ContextualizedRegressionDataModule( - C=C, - X=X, - Y=Y, - task_type=task_type, - train_idx=train_idx, - val_idx=val_idx, - test_idx=test_idx, - predict_idx=predict_idx, - train_batch_size=dk["train_batch_size"], - val_batch_size=dk["val_batch_size"], - test_batch_size=dk["test_batch_size"], - predict_batch_size=dk["predict_batch_size"], - num_workers=dk["num_workers"], - pin_memory=dk["pin_memory"], - persistent_workers=dk["persistent_workers"], - drop_last=dk["drop_last"], - shuffle_train=dk["shuffle_train"], - shuffle_eval=dk["shuffle_eval"], - dtype=dk["dtype"], - ) - return dm + if self.scalers["X"] is None: + self.scalers["X"] = StandardScaler().fit(X) + X = self.scalers["X"].transform(X) - def _split_train_data( - self, - C: np.ndarray, - X: np.ndarray, - Y: Optional[np.ndarray] = None, - *, - Y_required: bool = True, - val_split: Optional[float] = None, - random_state: Optional[int] = None, - shuffle: bool = True, - **_, - ): - """Return (train_idx, val_idx) over rows.""" - if Y_required and Y is None: - raise ValueError("Y is required but was not provided.") - n = C.shape[0] - vs = self.default_val_split if val_split is None else float(val_split) - if vs <= 0.0: - idx = np.arange(n) - return idx, None - - # FIXED: Handle small datasets - min_val_samples = max(1, int(n * vs)) - if min_val_samples < 2: - # Too small for validation split - idx = np.arange(n) - return idx, None - - # CRITICAL FIX: Use deterministic random_state for DDP - # All ranks MUST get the same train/val split - if random_state is None: - random_state = 42 # Fixed seed for reproducibility across ranks - - tr_idx, va_idx = train_test_split( - np.arange(n), - test_size=vs, - shuffle=shuffle, - random_state=random_state, - ) - return tr_idx, va_idx + self.context_dim = int(C.shape[-1]) + self.x_dim = int(X.shape[-1]) - def _maybe_scale_C(self, C: np.ndarray) -> np.ndarray: - 
if self.normalize and self.scalers["C"] is not None: - return self.scalers["C"].transform(C) - return C + # Legacy semantics: if Y not provided, use X as targets + if Y is None: + Y = X + else: + if Y.ndim == 1: + Y = np.expand_dims(Y, 1) - def _maybe_scale_X(self, X: np.ndarray) -> np.ndarray: - if self.normalize and self.scalers["X"] is not None: - return self.scalers["X"].transform(X) - return X + if self.normalize and self.scalers["Y"] is not None: + # already fitted (e.g., multiple fits). keep behavior consistent. + pass - def _get_inference_device(self) -> torch.device: - """ - Get the device to use for inference. - FIXED: Always use single device for prediction to avoid DDP complexity. - """ - if self.accelerator in ("cuda", "gpu") and torch.cuda.is_available(): - return torch.device("cuda:0") - return torch.device("cpu") + # Scale Y if it's continuous (avoid scaling binary) + if self.normalize and not np.array_equal(np.unique(Y), np.array([0, 1])): + if self.scalers["Y"] is None: + self.scalers["Y"] = StandardScaler().fit(Y) + Y = self.scalers["Y"].transform(Y) - def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs): - if not hasattr(self, "models") or self.models is None: - raise ValueError("Trying to predict with a model that hasn't been trained yet.") + self.y_dim = int(Y.shape[-1]) - Cq = self._maybe_scale_C(C) - Xq = self._maybe_scale_X(X) - Yq = np.zeros((len(Cq), self.y_dim), dtype=np.float32) - - dm = self._build_datamodule( - C=Cq, - X=Xq, - Y=Yq, - predict_idx=np.arange(len(Cq)), - data_kwargs=dict( - train_batch_size=self._init_kwargs["data"].get("train_batch_size", self.default_train_batch_size), - val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), - test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), - predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), - num_workers=0, - 
pin_memory=False, - persistent_workers=False, - shuffle_train=False, - shuffle_eval=False, - dtype=self._init_kwargs["data"].get("dtype", torch.float), - ), - task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) - else "singletask_multivariate", + organized = self._organize_and_expand_fit_kwargs(**kwargs) + self.n_bootstraps = int(organized["wrapper"].get("n_bootstraps", self.n_bootstraps)) + + # Determine val split now that we know n + val_split = float(organized["data"].get("val_split", self.default_val_split)) + val_split = self._safe_val_split(int(C.shape[0]), val_split) + organized["data"]["val_split"] = val_split + use_val = val_split > 0.0 + + # If no val, retarget default monitors to train_loss (legacy had a try/except; we make it explicit) + if not use_val: + # Adjust any callback_constructors that still monitor val_loss + new_ctors = [] + for ctor in organized["trainer"].get("callback_constructors", []): + def _wrap_ctor(_ctor): + def _inner(i): + cb = _ctor(i) + if isinstance(cb, EarlyStopping) and isinstance(getattr(cb, "monitor", ""), str) and cb.monitor.startswith("val_"): + return EarlyStopping( + monitor="train_loss", + mode=getattr(cb, "mode", "min"), + patience=getattr(cb, "patience", self.default_es_patience), + verbose=getattr(cb, "verbose", False), + min_delta=getattr(cb, "min_delta", 0.0), + ) + if isinstance(cb, ModelCheckpoint) and isinstance(getattr(cb, "monitor", ""), str) and cb.monitor.startswith("val_"): + # If no validation, checkpointing on val_loss is meaningless; keep callback but disable monitor. 
+ cb.monitor = None # PL will checkpoint on epoch end + return cb + return _inner + new_ctors.append(_wrap_ctor(ctor)) + organized["trainer"]["callback_constructors"] = new_ctors + + # Also avoid running validation loops + organized["trainer"].setdefault("limit_val_batches", 0) + + # Optional explicit val arrays + C_val = organized["data"].get("C_val", None) + X_val = organized["data"].get("X_val", None) + Y_val = organized["data"].get("Y_val", None) + + # Task type + univariate_flag = bool(organized["model"].get("univariate", False)) + task_type = "singletask_univariate" if univariate_flag else "singletask_multivariate" + + # Build final arrays + indices (supports separate val arrays by concatenation) + C_all, X_all, Y_all, train_idx, val_idx = self._resolve_train_val_arrays( + C, + X, + Y, + C_val=C_val, + X_val=X_val, + Y_val=Y_val, + Y_required=True, + val_split=val_split, ) - # Let Lightning handle sharding under DDP - preds = [] - n_expected = len(Cq) + for b in range(self.n_bootstraps): + # Construct model kwargs (do NOT pass wrapper-only keys) + model_kwargs = dict(organized["model"]) + # univariate is a wrapper concern; base_constructor should already encode univariate vs multivariate + model_kwargs.pop("univariate", None) - for i in range(len(self.models)): - model = self.models[i] - model.eval() + model = self.base_constructor(**model_kwargs) - # Prefer the trainer created during fit (keeps strategy/devices consistent) - trainer = None - if hasattr(self, "trainers") and self.trainers is not None and i < len(self.trainers): - trainer = self.trainers[i] + use_dm = self._use_datamodule_for_model(model) + + # Build trainer kwargs + callbacks + trainer_kwargs = copy.deepcopy(organized["trainer"]) + cb_ctors = trainer_kwargs.pop("callback_constructors", []) + callbacks = list(trainer_kwargs.get("callbacks", [])) + callbacks.extend([ctor(b) for ctor in cb_ctors]) + trainer_kwargs["callbacks"] = callbacks - if _is_distributed() and trainer is not None: - # ---- 
DDP path: use trainer.predict + gather outputs to rank 0 ---- - local_pred = trainer.predict(model, datamodule=dm) + # Ensure checkpoint directories exist + for cb in callbacks: + if isinstance(cb, ModelCheckpoint): + try: + os.makedirs(cb.dirpath, exist_ok=True) + except Exception: + pass - local_packed = _pack_local_pred_payload(local_pred) - gathered = _gather_object_to_rank0(local_packed) + # Construct trainer via factory that handles env/plugins correctly + from contextualized.regression.trainers import make_trainer_with_env - if not _is_main_process(): - # Non-zero ranks return nothing; rank 0 will return the final answer. - return None + trainer = make_trainer_with_env(self.trainer_constructor, **trainer_kwargs) + + if use_dm: + dm = self._build_datamodule( + C=C_all, + X=X_all, + Y=Y_all, + train_idx=train_idx, + val_idx=val_idx if use_val else None, + test_idx=None, + predict_idx=None, + data_kwargs=organized["data"], + task_type=task_type, + ) - merged = _merge_packed_payloads(gathered) + if _is_main_process(): + print(f"[RANK {_rank()}] train_idx[:5]={train_idx[:5]}, val_idx[:5]={val_idx[:5] if val_idx is not None else None}") - # Sort/dedupe by orig_idx (DistributedSampler may pad) - merged = _stable_sort_and_dedupe_by_key(merged, primary="orig_idx") + trainer.fit(model, datamodule=dm, **organized["fit"]) - if "betas" not in merged or "mus" not in merged or "orig_idx" not in merged: - raise RuntimeError("predict: Missing required keys in gathered payload: need orig_idx, betas, mus.") + # Keep dataloaders for compatibility (best-effort) + try: + dm.setup("fit") + self.dataloaders["train"].append(dm.train_dataloader()) + self.dataloaders["val"].append(dm.val_dataloader() if use_val else None) + self.dataloaders["test"].append(None) + except Exception: + self.dataloaders["train"].append(None) + self.dataloaders["val"].append(None) + self.dataloaders["test"].append(None) - orig_idx = merged["orig_idx"].astype(np.int64) - betas = 
torch.as_tensor(merged["betas"]) - mus = torch.as_tensor(merged["mus"]) + else: + # Legacy path: model provides dataloader(C, X, Y, batch_size=...) + train_data = [C_all[train_idx], X_all[train_idx], Y_all[train_idx]] if Y_all is not None else [C_all[train_idx], X_all[train_idx]] + val_data = None + if use_val and val_idx is not None: + val_data = [C_all[val_idx], X_all[val_idx], Y_all[val_idx]] if Y_all is not None else [C_all[val_idx], X_all[val_idx]] + + train_dl = model.dataloader(*train_data, batch_size=organized["data"].get("train_batch_size", self.default_train_batch_size)) + val_dl = None + if val_data is not None: + val_dl = model.dataloader(*val_data, batch_size=organized["data"].get("val_batch_size", self.default_val_batch_size)) + + try: + trainer.fit(model, train_dl, val_dl, **organized["fit"]) + except Exception: + trainer.fit(model, train_dl, **organized["fit"]) + + self.dataloaders["train"].append(train_dl) + self.dataloaders["val"].append(val_dl) + self.dataloaders["test"].append(None) + + # Load best checkpoint (legacy behavior) if present + ckpt_cb = next((cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), None) + if ckpt_cb is not None and getattr(ckpt_cb, "best_model_path", None): + best_path = ckpt_cb.best_model_path + if isinstance(best_path, str) and best_path and os.path.exists(best_path): + try: + best = torch.load(best_path, map_location="cpu") + if isinstance(best, dict) and "state_dict" in best: + model.load_state_dict(best["state_dict"]) + except Exception: + pass - # Ensure we are aligned to query order - # (orig_idx is row-id into the query arrays because predict_idx=np.arange(n)) - C_sorted = torch.as_tensor(Cq[orig_idx], dtype=betas.dtype) - X_sorted = torch.as_tensor(Xq[orig_idx], dtype=betas.dtype) + self.models.append(model) + self.trainers.append(trainer) - # Compute yhat on rank 0 in correct global order - with torch.no_grad(): - yhat = model._predict_y(C_sorted, X_sorted, betas, mus).detach().cpu().numpy() + 
return None - # If DDP padded, we may have > n_expected; trim safely by orig_idx range - # (should not happen if orig_idx is in [0, n_expected)) - if yhat.shape[0] != n_expected: - # Build dense output in original query order - dense = np.zeros((n_expected,) + yhat.shape[1:], dtype=yhat.dtype) - dense[orig_idx] = yhat - yhat = dense + def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs): + if self.models is None or self.trainers is None: + raise ValueError("Trying to predict with a model that hasn't been trained yet.") - preds.append(yhat) + C = np.asarray(C) + X = np.asarray(X) - else: - # ---- Single-process fallback: iterate predict_dataloader directly ---- - dm.setup(stage="predict") - pred_loader = dm.predict_dataloader() + Cq = self._maybe_scale_C(C) + Xq = self._maybe_scale_X(X) - out_batches = [] - device = self._get_inference_device() - model.to(device) + # Build a DDP-safe predict loader via DataModule when possible + preds_all: List[np.ndarray] = [] - with torch.no_grad(): - for b_idx, batch in enumerate(pred_loader): - batch = { - k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) - for k, v in batch.items() - } + for i, (model, trainer) in enumerate(zip(self.models, self.trainers)): + if not hasattr(trainer, "predict_y"): + raise RuntimeError( + "Trainer does not implement predict_y(). " + "Use contextualized.regression.trainers.RegressionTrainer (or a subclass)." 
+ ) + + use_dm = self._use_datamodule_for_model(model) + + if use_dm: + # zeros outcomes for predict + Yq = np.zeros((len(Cq), int(self.y_dim or 1)), dtype=np.float32) + + task_type = "singletask_univariate" if bool(getattr(model, "hparams", {}).get("univariate", False)) else "singletask_multivariate" + # Prefer wrapper's univariate flag if present + univariate_flag = bool(self._init_kwargs.get("model", {}).get("univariate", False)) + task_type = "singletask_univariate" if univariate_flag else "singletask_multivariate" + + dm = self._build_datamodule( + C=Cq, + X=Xq, + Y=Yq, + train_idx=None, + val_idx=None, + test_idx=None, + predict_idx=np.arange(len(Cq)), + data_kwargs={**self._init_kwargs.get("data", {}), **kwargs}, + task_type=task_type, + ) + dm.setup("predict") + dl = dm.predict_dataloader() + else: + dl = model.dataloader(Cq, Xq, np.zeros((len(Cq), int(self.y_dim or 1))), batch_size=kwargs.get("predict_batch_size", self.default_val_batch_size)) - out = model.predict_step(batch, b_idx) - betas = out["betas"] - mus = out["mus"] + yhat = trainer.predict_y(model, dl, **kwargs) + if yhat is None: + # DDP non-rank0: avoid duplicating work/outputs + return None - # IMPORTANT: use the *batch* for C/X, not the output payload - yb = model._predict_y(batch["contexts"], batch["predictors"], betas, mus) - out_batches.append(yb.detach().cpu()) + preds_all.append(np.asarray(yhat, dtype=float)) - yhat = torch.cat(out_batches, dim=0).numpy() - preds.append(yhat) + predictions = np.array(preds_all, dtype=float) - predictions = np.array(preds) - if not individual_preds: - predictions = np.mean(predictions, axis=0) + if individual_preds: + out = predictions + else: + bad = ~np.isfinite(predictions) + if bad.any(): + num_bad_boots = np.unique(np.where(bad)[0]).size + print( + f"Warning: {num_bad_boots}/{len(preds_all)} bootstraps produced non-finite predictions; excluding them from the ensemble." 
+ ) + out = self._nanrobust_mean(predictions, axis=0) if self.normalize and self.scalers["Y"] is not None: if individual_preds: - predictions = np.array([self.scalers["Y"].inverse_transform(p) for p in predictions]) + out = np.array([self.scalers["Y"].inverse_transform(p) for p in out]) else: - predictions = self.scalers["Y"].inverse_transform(predictions) - - return predictions + out = self.scalers["Y"].inverse_transform(out) + return out def predict_params( self, @@ -871,252 +923,74 @@ def predict_params( individual_preds: bool = False, model_includes_mus: bool = True, **kwargs, - ): - if not hasattr(self, "models") or self.models is None: + ) -> Union[ + np.ndarray, + Tuple[np.ndarray, np.ndarray], + ]: + if self.models is None or self.trainers is None: raise ValueError("Trying to predict with a model that hasn't been trained yet.") + C = np.asarray(C) Cq = self._maybe_scale_C(C) - X_zero = np.zeros((len(Cq), self.x_dim), dtype=np.float32) - Y_zero = np.zeros((len(Cq), self.y_dim), dtype=np.float32) - - uses_y = kwargs.pop("uses_y", True) - - dm = self._build_datamodule( - C=Cq, - X=X_zero, - Y=Y_zero if uses_y else None, - predict_idx=np.arange(len(Cq)), - data_kwargs=dict( - train_batch_size=self._init_kwargs["data"].get("train_batch_size", self.default_train_batch_size), - val_batch_size=self._init_kwargs["data"].get("val_batch_size", self.default_val_batch_size), - test_batch_size=self._init_kwargs["data"].get("test_batch_size", self.default_test_batch_size), - predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", self.default_val_batch_size), - num_workers=0, - pin_memory=False, - persistent_workers=False, - shuffle_train=False, - shuffle_eval=False, - dtype=self._init_kwargs["data"].get("dtype", torch.float), - ), - task_type="singletask_univariate" if self._init_kwargs["model"].get("univariate", False) - else "singletask_multivariate", - ) - - out_betas, out_mus = [], [] - n_expected = len(Cq) - for i in range(len(self.models)): - model = 
self.models[i] - model.eval() + uses_y = bool(kwargs.pop("uses_y", True)) - trainer = None - if hasattr(self, "trainers") and self.trainers is not None and i < len(self.trainers): - trainer = self.trainers[i] + betas_list: List[np.ndarray] = [] + mus_list: List[np.ndarray] = [] - if _is_distributed() and trainer is not None: - local_pred = trainer.predict(model, datamodule=dm) - local_packed = _pack_local_pred_payload(local_pred) - gathered = _gather_object_to_rank0(local_packed) - - if not _is_main_process(): - return (None, None) if model_includes_mus else None - - - merged = _merge_packed_payloads(gathered) - merged = _stable_sort_and_dedupe_by_key(merged, primary="orig_idx") - - if "betas" not in merged or "orig_idx" not in merged: - raise RuntimeError("predict_params: Missing required keys in gathered payload: need orig_idx, betas.") - - orig_idx = merged["orig_idx"].astype(np.int64) - - betas_i = merged["betas"] - if betas_i.shape[0] != n_expected: - dense_b = np.zeros((n_expected,) + betas_i.shape[1:], dtype=betas_i.dtype) - dense_b[orig_idx] = betas_i - betas_i = dense_b - - out_betas.append(betas_i) - - if model_includes_mus: - if "mus" not in merged: - raise RuntimeError("predict_params: model_includes_mus=True but mus missing in payload.") - mus_i = merged["mus"] - if mus_i.shape[0] != n_expected: - dense_m = np.zeros((n_expected,) + mus_i.shape[1:], dtype=mus_i.dtype) - dense_m[orig_idx] = mus_i - mus_i = dense_m - out_mus.append(mus_i) + for model, trainer in zip(self.models, self.trainers): + if not hasattr(trainer, "predict_params"): + raise RuntimeError( + "Trainer does not implement predict_params(). " + "Use contextualized.regression.trainers.RegressionTrainer (or a subclass)." 
+ ) + use_dm = self._use_datamodule_for_model(model) + + if use_dm: + X_zero = np.zeros((len(Cq), int(self.x_dim or 1)), dtype=np.float32) + Y_zero = np.zeros((len(Cq), int(self.y_dim or 1)), dtype=np.float32) if uses_y else None + + univariate_flag = bool(self._init_kwargs.get("model", {}).get("univariate", False)) + task_type = "singletask_univariate" if univariate_flag else "singletask_multivariate" + + dm = self._build_datamodule( + C=Cq, + X=X_zero, + Y=Y_zero, + train_idx=None, + val_idx=None, + test_idx=None, + predict_idx=np.arange(len(Cq)), + data_kwargs={**self._init_kwargs.get("data", {}), **kwargs}, + task_type=task_type, + ) + dm.setup("predict") + dl = dm.predict_dataloader() else: - # Single-process fallback (local ordered) - dm.setup(stage="predict") - pred_loader = dm.predict_dataloader() - - device = self._get_inference_device() - model.to(device) - - beta_batches, mu_batches = [], [] - with torch.no_grad(): - for b_idx, batch in enumerate(pred_loader): - batch = { - k: (v.to(device, non_blocking=True) if torch.is_tensor(v) else v) - for k, v in batch.items() - } - out = model.predict_step(batch, b_idx) - beta_batches.append(out["betas"].detach().cpu()) - if model_includes_mus: - mu_batches.append(out["mus"].detach().cpu()) - - betas_i = torch.cat(beta_batches, dim=0).numpy() - out_betas.append(betas_i) - - if model_includes_mus: - mus_i = torch.cat(mu_batches, dim=0).numpy() - out_mus.append(mus_i) - - betas = np.array(out_betas) - if model_includes_mus: - mus = np.array(out_mus) - return (betas, mus) if individual_preds else (np.mean(betas, axis=0), np.mean(mus, axis=0)) - - return betas if individual_preds else np.mean(betas, axis=0) - - - def fit(self, *args, **kwargs) -> None: - """ - Fit contextualized model to data. - FIXED: Proper DDP handling and device tracking. 
- """ - self.models, self.trainers = [], [] - - # Normalize argument order - C_in = kwargs.pop("C", None) - X_in = kwargs.pop("X", None) - Y_in = kwargs.pop("Y", None) - - if (C_in is not None) and (X_in is not None): - C, X, Y = C_in, X_in, Y_in - else: - if len(args) == 3: - A, B, Carg = args - if A.shape[0] == B.shape[0] == Carg.shape[0]: - if (B.ndim == 1) or (B.ndim == 2 and B.shape[1] <= 4): - X, Y, C = A, B, Carg - else: - C, X, Y = A, B, Carg + if uses_y: + dl = model.dataloader(Cq, np.zeros((len(Cq), int(self.x_dim or 1))), np.zeros((len(Cq), int(self.y_dim or 1)))) else: - raise ValueError("Mismatched sample counts among provided arrays.") - elif len(args) == 2: - A, B = args - if A.shape[0] != B.shape[0]: - raise ValueError("Mismatched sample counts for two-argument fit.") - C, X, Y = A, B, None - else: - raise ValueError("fit expects (C,X[,Y]) or (X,Y,C) or kw-only C=..., X=...") + dl = model.dataloader(Cq, np.zeros((len(Cq), int(self.x_dim or 1)))) - # Optional scaling - if self.normalize: - if self.scalers["C"] is None: - self.scalers["C"] = StandardScaler().fit(C) - C = self.scalers["C"].transform(C) - if self.scalers["X"] is None: - self.scalers["X"] = StandardScaler().fit(X) - X = self.scalers["X"].transform(X) - - self.context_dim = C.shape[-1] - self.x_dim = X.shape[-1] + out = trainer.predict_params(model, dl, **kwargs) + if out is None or (isinstance(out, tuple) and out[0] is None): + return (None, None) if model_includes_mus else None # DDP non-rank0 - if Y is not None: - if len(Y.shape) == 1: - Y = np.expand_dims(Y, 1) - if self.normalize and not np.array_equal(np.unique(Y), np.array([0, 1])): - if self.scalers["Y"] is None: - self.scalers["Y"] = StandardScaler().fit(Y) - Y = self.scalers["Y"].transform(Y) - self.y_dim = Y.shape[-1] - args = (C, X, Y) - else: - self.y_dim = self.x_dim - args = (C, X) - - organized = self._organize_and_expand_fit_kwargs(**kwargs) - self.n_bootstraps = organized["wrapper"].get("n_bootstraps", self.n_bootstraps) 
- - n = C.shape[0] - val_split = organized["data"].get("val_split", self.default_val_split) - use_val = val_split > 0.0 - - for b in range(self.n_bootstraps): - # Model - _model_kwargs = dict(organized["model"]) - _model_kwargs.pop("univariate", None) - model = self.base_constructor(**_model_kwargs) - self.model_ = model - - # Indices - train_idx, val_idx = self._split_train_data( - C, X, (args[2] if len(args) == 3 else None), - Y_required=(len(args) == 3), - val_split=val_split, - ) - print(f"[RANK {os.environ.get('RANK', 0)}] train_idx[:5]={train_idx[:5]}, val_idx[:5]={val_idx[:5] if val_idx is not None else None}") - - test_idx = None - - # DataModule - task_type = "singletask_univariate" if organized["model"].get("univariate", False) else "singletask_multivariate" - dm = self._build_datamodule( - C=args[0], X=args[1], Y=(args[2] if len(args) == 3 else None), - train_idx=train_idx, val_idx=val_idx, test_idx=test_idx, - data_kwargs=dict( - train_batch_size=organized["data"].get("train_batch_size", self.default_train_batch_size), - val_batch_size=organized["data"].get("val_batch_size", self.default_val_batch_size), - test_batch_size=organized["data"].get("test_batch_size", self.default_test_batch_size), - predict_batch_size=organized["data"].get("predict_batch_size", self.default_val_batch_size), - num_workers=organized["data"].get("num_workers", 0), - pin_memory=organized["data"].get("pin_memory", self.accelerator in ("cuda", "gpu")), - persistent_workers=organized["data"].get("persistent_workers", False), - drop_last=organized["data"].get("drop_last", False), - shuffle_train=organized["data"].get("shuffle_train", True), - shuffle_eval=organized["data"].get("shuffle_eval", False), - dtype=organized["data"].get("dtype", torch.float), - ), - task_type=task_type, - ) - - # Trainer - trainer_kwargs = copy.deepcopy(organized["trainer"]) - trainer_kwargs["callbacks"] = [f(b) for f in trainer_kwargs.get("callback_constructors", [])] - 
trainer_kwargs.pop("callback_constructors", None) - - from contextualized.regression.trainers import make_trainer_with_env - trainer = make_trainer_with_env( - self.trainer_constructor, - **trainer_kwargs, - ) - - for cb in trainer_kwargs.get("callbacks", []): - if isinstance(cb, ModelCheckpoint): - os.makedirs(cb.dirpath, exist_ok=True) - - # Ensure all ranks have setup data before training - if torch.cuda.is_available(): - torch.cuda.synchronize() - - # Fit - trainer.fit( - model, - datamodule=dm, - **organized["fit"], - ) + if model_includes_mus: + b, m = out + betas_list.append(np.asarray(b)) + mus_list.append(np.asarray(m)) + else: + betas_list.append(np.asarray(out)) - # Load best checkpoint if enabled - if trainer_kwargs.get("enable_checkpointing", False): - ckpt_cb = next((cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), None) - if ckpt_cb and ckpt_cb.best_model_path and os.path.exists(ckpt_cb.best_model_path): - best = torch.load(ckpt_cb.best_model_path, map_location="cpu") - model.load_state_dict(best["state_dict"]) + betas = np.array(betas_list) + if model_includes_mus: + mus = np.array(mus_list) + if individual_preds: + return betas, mus + return np.mean(betas, axis=0), np.mean(mus, axis=0) - self.models.append(model) - self.trainers.append(trainer) \ No newline at end of file + if individual_preds: + return betas + return np.mean(betas, axis=0) diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index ad04dfc9..983c2056 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -273,18 +273,33 @@ def _predict_payload(self, batch: dict, **outputs) -> dict: def training_step(self, batch, batch_idx): loss = self._batch_loss(batch, batch_idx) bs = self._batch_size_from_batch(batch) + + # Step-level logging: keep visibility, avoid per-step all-reduce (DDP scaling killer) self.log( - "train_loss", + "train_loss_step", loss, 
on_step=True, - on_epoch=True, + on_epoch=False, prog_bar=True, + sync_dist=False, + batch_size=bs, + ) + + # Epoch-level logging: sync across ranks once per epoch (correct global metric) + self.log( + "train_loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=False, sync_dist=True, batch_size=bs, ) + return loss + def validation_step(self, batch, batch_idx): loss = self._batch_loss(batch, batch_idx) bs = self._batch_size_from_batch(batch) diff --git a/regression_scale_bench.py b/regression_scale_bench.py deleted file mode 100644 index f105fe90..00000000 --- a/regression_scale_bench.py +++ /dev/null @@ -1,480 +0,0 @@ -#!/usr/bin/env python3 - -""" -# 0) See what NICs you actually have (optional, for sanity): -ls -1 /sys/class/net -ip -o link show | awk -F': ' '{print NR-1": "$2}' - -# 1) Minimal, safe NCCL/torch env (no hard-coded eth0): -export CUDA_VISIBLE_DEVICES=0,1,2,3 -export OMP_NUM_THREADS=1 -export MKL_NUM_THREADS=1 -export TOKENIZERS_PARALLELISM=false -export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 -export NCCL_DEBUG=WARN -export NCCL_P2P_DISABLE=0 -export NCCL_IB_DISABLE=1 -export NCCL_SOCKET_IFNAME=$(ls /sys/class/net | grep -E '^(ens|enp|eno|eth|bond|ib)' | head -n1) -# If that prints nothing on your machine, fall back to auto-exclude: -[ -z "$NCCL_SOCKET_IFNAME" ] && export NCCL_SOCKET_IFNAME="^lo,docker0" - -# CUDA allocator tweak (fine to keep) -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - -# 2) Kill any stragglers (optional) -pkill -f scale_bench.py || true -pkill -f torchrun || true - -# 3a) Single-GPU run (torchrun, WORLD_SIZE=1) -torchrun --standalone --nproc_per_node=1 scale_bench.py \ - --epochs 3 --batch-size 2048 --num-workers 8 --precision bf16 \ - --num-samples 1800000 --outdir bench_out/gpu1 - -# 3b) Two GPUs -torchrun --standalone --nproc_per_node=2 scale_bench.py \ - --epochs 3 --batch-size 2048 --num-workers 8 --precision bf16 \ - --num-samples 1800000 --outdir bench_out/gpu2 - -# 3c) Three GPUs -torchrun --standalone 
--nproc_per_node=3 scale_bench.py \ - --epochs 3 --batch-size 2048 --num-workers 8 --precision bf16 \ - --num-samples 1800000 --outdir bench_out/gpu3 - -# 3d) Four GPUs -torchrun --standalone --nproc_per_node=4 scale_bench.py \ - --epochs 3 --batch-size 2048 --num-workers 8 --precision bf16 \ - --num-samples 1800000 --outdir bench_out/gpu4 - -""" -import os, time, csv, argparse, math, json -from dataclasses import dataclass -from typing import List, Dict -from datetime import timedelta - -import numpy as np -import torch -import pytorch_lightning as pl -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.strategies import DDPStrategy - -# ---- your package pieces ---- -from contextualized.regression import ContextualizedRegression -from contextualized.regression.datamodules import ContextualizedRegressionDataModule - - -# ---------------- launcher/cluster helpers ---------------- -def under_torchrun() -> bool: - e = os.environ - return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) - -def world_size() -> int: - try: - return int(os.environ.get("WORLD_SIZE", "1")) - except Exception: - return 1 - -def is_global_zero() -> bool: - return int(os.environ.get("RANK", "0")) == 0 - - -# ---------------- env + perf ---------------- -def set_env_defaults(): - os.environ.setdefault("OMP_NUM_THREADS", "1") - os.environ.setdefault("MKL_NUM_THREADS", "1") - os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - - # Safer NCCL defaults on cloud single node - os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") - os.environ.setdefault("NCCL_DEBUG", "WARN") - os.environ.setdefault("NCCL_P2P_DISABLE", "0") - os.environ.setdefault("NCCL_IB_DISABLE", "1") # IB usually unavailable on single-node Lambda - - # Pick an interface if not set - if "NCCL_SOCKET_IFNAME" not in os.environ: - try: - ifaces = [d for d in os.listdir("/sys/class/net") if os.path.isdir(f"/sys/class/net/{d}")] - cand = next((i for i in ifaces if i not in ("lo", 
"docker0")), None) - os.environ["NCCL_SOCKET_IFNAME"] = cand or "lo" - except Exception: - os.environ["NCCL_SOCKET_IFNAME"] = "lo" - - # Rendezvous (used only by ddp_spawn mode) - os.environ.setdefault("MASTER_ADDR", "127.0.0.1") - os.environ.setdefault("MASTER_PORT", str(12355 + (os.getpid() % 20000))) - - if is_global_zero(): - keys = ["NCCL_DEBUG","NCCL_IB_DISABLE","NCCL_P2P_DISABLE","NCCL_SOCKET_IFNAME","MASTER_ADDR","MASTER_PORT"] - print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) - - # Ampere+ matmul speedups - try: - torch.set_float32_matmul_precision("high") - except Exception: - pass - - -def map_precision(p): - p = (p or "").lower() - if p in ("bf16", "bfloat16", "bf16-mixed"): - return "bf16-mixed" - if p in ("fp16", "16", "16-mixed"): - return "16-mixed" - return 32 # full precision - - -class EpochTimer(Callback): - def __init__(self): - self._epoch_start = None - self.epoch_times = [] - - @staticmethod - def _using_cuda(trainer) -> bool: - try: - return trainer.accelerator is not None and "cuda" in str(trainer.accelerator).lower() - except Exception: - return torch.cuda.is_available() - - def on_train_epoch_start(self, trainer, pl_module): - if self._using_cuda(trainer): - torch.cuda.synchronize() - self._epoch_start = time.time() - - def on_train_epoch_end(self, trainer, pl_module): - if self._using_cuda(trainer): - torch.cuda.synchronize() - self.epoch_times.append(time.time() - self._epoch_start) - - -# ---------------- synthetic data ---------------- -def make_synthetic(n, c_dim, x_dim, y_dim, seed=42): - rng = np.random.default_rng(seed) - C = rng.standard_normal((n, c_dim)).astype(np.float32) - X = rng.standard_normal((n, x_dim)).astype(np.float32) - W = rng.standard_normal((y_dim, x_dim)).astype(np.float32) - MU = rng.standard_normal((y_dim, 1)).astype(np.float32) - Y = (X @ W.T) + MU.squeeze(-1) + 0.01 * rng.standard_normal((n, y_dim)).astype(np.float32) - return C, X, Y - - -# ---------------- model/trainer builders 
---------------- -def build_model(c_dim, x_dim, y_dim, width, layers, lr): - model = ContextualizedRegression( - context_dim=c_dim, - x_dim=x_dim, - y_dim=y_dim, - num_archetypes=8, - encoder_type="mlp", - encoder_kwargs={"width": width, "layers": layers, "link_fn": "identity"}, - learning_rate=lr, - fit_intercept=True, - link_fn="identity", - loss_fn="mse", - model_regularizer="none", - ) - return model - - -def build_dm( - C, X, Y, - train_batch_size: int, - num_workers: int, - pin_memory: bool, -): - n = C.shape[0] - perm = np.random.permutation(n) - n_train = int(0.9 * n) - train_idx = perm[:n_train] - val_idx = perm[n_train:] - - dm = ContextualizedRegressionDataModule( - C=C, X=X, Y=Y, - task_type="singletask_multivariate", - train_idx=train_idx, - val_idx=val_idx, - test_idx=None, - predict_idx=None, - train_batch_size=train_batch_size, - val_batch_size=train_batch_size, - test_batch_size=train_batch_size, - predict_batch_size=train_batch_size, - num_workers=num_workers, - pin_memory=bool(pin_memory), - persistent_workers=bool(num_workers > 0), - drop_last=True, - shuffle_train=True, - shuffle_eval=False, - dtype=torch.float, - ) - dm.prepare_data(); dm.setup() - return dm - - -def build_trainer(devices, precision, epochs, ddp_timeout_s=120, torchrun_mode=False): - """ - devices: - - 0 => cpu - - >=1 => number of devices this process should report to Lightning - - torchrun_mode: - - True => launched via torchrun; use DDP with devices = WORLD_SIZE, - no spawn. Satisfies Lightning's validation. 
- """ - timer = EpochTimer() - - if devices == 0: - accelerator = "cpu" - devices_arg = 1 - strategy = "auto" - else: - accelerator = "gpu" - if torchrun_mode: - ws = world_size() - devices_arg = ws # <-- IMPORTANT: devices must equal WORLD_SIZE here - strategy = DDPStrategy( - find_unused_parameters=False, - gradient_as_bucket_view=True, - static_graph=True, - timeout=timedelta(seconds=ddp_timeout_s), - ) - else: - devices_arg = devices - strategy = "auto" if devices == 1 else DDPStrategy( - start_method="spawn", - find_unused_parameters=False, - gradient_as_bucket_view=True, - static_graph=True, - timeout=timedelta(seconds=ddp_timeout_s), - ) - - trainer = pl.Trainer( - accelerator=accelerator, - devices=devices_arg, - strategy=strategy, - precision=precision, - max_epochs=epochs, - logger=False, - enable_checkpointing=False, - num_sanity_val_steps=0, - enable_progress_bar=False, - log_every_n_steps=50, - callbacks=[timer], - inference_mode=False, - detect_anomaly=False, - ) - return trainer, timer - - -# ---------------- benchmark runner ---------------- -@dataclass -class BenchCfg: - label: str - devices: int # >=1 gpus - - -def run_once(cfg: BenchCfg, C, X, Y, args, torchrun_mode: bool) -> Dict: - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - dm = build_dm( - C, X, Y, - train_batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=True, - ) - model = build_model(args.context_dim, args.x_dim, args.y_dim, - args.width, args.layers, args.lr) - - # Warm-up (stabilize kernels/allocators) on same accelerator config - tiny = max(1024, math.ceil(0.01 * C.shape[0])) - dm_warm = build_dm( - C[:tiny], X[:tiny], Y[:tiny], - train_batch_size=args.batch_size, - num_workers=0, - pin_memory=True, - ) - warm_trainer, _ = build_trainer( - devices=(world_size() if torchrun_mode else cfg.devices), # <-- fix - precision=map_precision(args.precision), - epochs=1, - ddp_timeout_s=args.ddp_timeout, - torchrun_mode=torchrun_mode, - ) - 
warm_trainer.fit(model, train_dataloaders=dm_warm.train_dataloader()) - - # Timed run - trainer, timer = build_trainer( - devices=(world_size() if torchrun_mode else cfg.devices), # <-- fix - precision=map_precision(args.precision), - epochs=args.epochs, - ddp_timeout_s=args.ddp_timeout, - torchrun_mode=torchrun_mode, - ) - - if torch.cuda.is_available(): - torch.cuda.synchronize() - t0 = time.time() - trainer.fit(model, train_dataloaders=dm.train_dataloader()) - if torch.cuda.is_available(): - torch.cuda.synchronize() - wall = time.time() - t0 - - train_samples = len(dm.train_dataloader().dataset) - samples_total = train_samples * args.epochs - throughput = samples_total / max(wall, 1e-9) - - world = world_size() if torchrun_mode else cfg.devices - per_device = throughput / max(world, 1) - - epoch_times = timer.epoch_times[:] - - res = dict( - label=cfg.label, - devices=world, - wall_seconds=wall, - samples_total=int(samples_total), - throughput_samples_per_s=throughput, - per_device_throughput=per_device, - steps_per_epoch=math.ceil(train_samples / args.batch_size), - samples_per_epoch=int(train_samples), - epoch_times=epoch_times, - ) - if is_global_zero(): - print(json.dumps({ - "label": res["label"], - "devices": res["devices"], - "wall_s": round(res["wall_seconds"], 3), - "throughput_sps": round(res["throughput_samples_per_s"], 2), - "per_device_sps": round(res["per_device_throughput"], 2), - "avg_epoch_s": round(float(np.mean(res["epoch_times"])) if res["epoch_times"] else float("nan"), 3) - }, indent=2)) - return res - - -def save_csv(rows: List[Dict], outdir: str): - os.makedirs(outdir, exist_ok=True) - path = os.path.join(outdir, "scale_results.csv") - fields = ["label","devices","wall_seconds","samples_total", - "throughput_samples_per_s","per_device_throughput", - "steps_per_epoch","samples_per_epoch","epoch_times"] - with open(path, "w", newline="") as f: - w = csv.DictWriter(f, fieldnames=fields) - w.writeheader() - for r in rows: - r2 = r.copy() - 
r2["epoch_times"] = ";".join(f"{x:.6f}" for x in r["epoch_times"]) - w.writerow(r2) - return path - - -def plot_curves(rows: List[Dict], outdir: str): - import matplotlib - matplotlib.use("Agg") - import matplotlib.pyplot as plt - import numpy as np - - os.makedirs(outdir, exist_ok=True) - labels = [r["label"] for r in rows] - devs = [r["devices"] for r in rows] - thr = [r["throughput_samples_per_s"] for r in rows] - wall = [r["wall_seconds"] for r in rows] - avg_epoch = [np.mean(r["epoch_times"]) if r["epoch_times"] else float("nan") for r in rows] - - plt.figure() - plt.plot(devs, thr, marker="o") - plt.xticks(devs, labels, rotation=30, ha="right") - plt.xlabel("Devices") - plt.ylabel("Throughput (samples/s)") - plt.title("Throughput vs Devices") - plt.tight_layout() - plt.savefig(os.path.join(outdir, "throughput_vs_devices.png")) - plt.close() - - plt.figure() - plt.plot(devs, wall, marker="o") - plt.xticks(devs, labels, rotation=30, ha="right") - plt.xlabel("Devices") - plt.ylabel("Total Wall Time (s)") - plt.title("Wall Time vs Devices") - plt.tight_layout() - plt.savefig(os.path.join(outdir, "walltime_vs_devices.png")) - plt.close() - - plt.figure() - plt.plot(devs, avg_epoch, marker="o") - plt.xticks(devs, labels, rotation=30, ha="right") - plt.xlabel("Devices") - plt.ylabel("Avg Train Epoch Time (s)") - plt.title("Epoch Time vs Devices") - plt.tight_layout() - plt.savefig(os.path.join(outdir, "epoch_time_vs_devices.png")) - plt.close() - - -# ---------------- main ---------------- -def parse_args(): - ap = argparse.ArgumentParser() - ap.add_argument("--epochs", type=int, default=5) - ap.add_argument("--batch-size", type=int, default=2048) # PER GPU - ap.add_argument("--num-workers", type=int, default=8) - ap.add_argument("--precision", type=str, default="bf16") - - # Accept BOTH forms; they write to the same dest - ap.add_argument("--num-samples", dest="num_samples", type=int, default=2_000_000) - ap.add_argument("--n", dest="num_samples", type=int) # 
optional legacy alias - - ap.add_argument("--context-dim", type=int, default=16) - ap.add_argument("--x-dim", type=int, default=512) - ap.add_argument("--y-dim", type=int, default=64) - ap.add_argument("--width", type=int, default=1024) - ap.add_argument("--layers", type=int, default=4) - ap.add_argument("--lr", type=float, default=1e-3) - ap.add_argument("--outdir", type=str, default="bench_out") - ap.add_argument("--ddp-timeout", type=int, default=180) - ap.add_argument("--max-gpus", type=int, default=4) - return ap.parse_args() - - - -def main(): - set_env_defaults() - args = parse_args() - os.makedirs(args.outdir, exist_ok=True) - - if torch.cuda.is_available(): - torch.backends.cudnn.benchmark = True - os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") - - # Generate data once - C, X, Y = make_synthetic(args.num_samples, args.context_dim, args.x_dim, args.y_dim) - - results = [] - torchrun_mode = under_torchrun() - - if torchrun_mode: - # Run a single config under torchrun (WORLD_SIZE GPUs, 1 per process) - cfg = BenchCfg(label=f"gpu-{world_size()}", devices=1) - if is_global_zero(): - print(f"\n=== Running {cfg.label} (torchrun, {world_size()} processes) ===") - res = run_once(cfg, C, X, Y, args, torchrun_mode=True) - results.append(res) - else: - # Standalone: GPU-only sweep 1..k (skip CPU entirely) - gpus = torch.cuda.device_count() - dev_list = [BenchCfg(f"gpu-{k}", k) for k in range(1, min(args.max_gpus, gpus) + 1)] - for cfg in dev_list: - if is_global_zero(): - print(f"\n=== Running {cfg.label} ===") - res = run_once(cfg, C, X, Y, args, torchrun_mode=False) - results.append(res) - - # Save outputs - if is_global_zero(): - csv_path = save_csv(results, args.outdir) - plot_curves(results, args.outdir) - print(f"\nSaved CSV → {csv_path}") - print(f"Saved plots → {args.outdir}/throughput_vs_devices.png, " - f"walltime_vs_devices.png, epoch_time_vs_devices.png") - - -if __name__ == "__main__": - main() diff --git a/scale_bench.py 
b/scale_bench.py new file mode 100644 index 00000000..29937bac --- /dev/null +++ b/scale_bench.py @@ -0,0 +1,540 @@ +#!/usr/bin/env python3 +""" +scale_bench.py + +A single-node, torchrun-friendly DDP scaling benchmark for ContextualizedRegression. + +Design goals (to reveal true scaling): + - Fixed number of optimizer steps (not epochs) so each run does identical work. + - Optional GPU-resident synthetic dataset to remove CPU dataloading/transfer bottlenecks. + - Measures only the *steady-state* region (warmup steps excluded). + - Uses Lightning DDP under torchrun correctly (devices=1 per process). + +------------------------------------------------------------ +Quick start (single node, 1..4 GPUs) +------------------------------------------------------------ + +# 0) See NICs (optional) +ls -1 /sys/class/net +ip -o link show | awk -F': ' '{print NR-1": "$2}' + +# 1) Minimal, safe NCCL/torch env (no hard-coded eth0): +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export TOKENIZERS_PARALLELISM=false +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export NCCL_DEBUG=WARN +export NCCL_P2P_DISABLE=0 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=$(ls /sys/class/net | grep -E '^(ens|enp|eno|eth|bond|ib)' | head -n1) +[ -z "$NCCL_SOCKET_IFNAME" ] && export NCCL_SOCKET_IFNAME="^lo,docker0" + +# CUDA allocator tweak (fine to keep) +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# 2) Kill any stragglers (optional) +pkill -f scale_bench.py || true +pkill -f torchrun || true + +# 3) Runs (IMPORTANT: --batch-size is PER GPU) +# Suggested defaults: steps=400 warmup=50 (steady state measured steps=400) + +torchrun --standalone --nproc_per_node=1 scale_bench.py \ + --steps 400 --warmup-steps 50 \ + --batch-size 2048 --precision bf16 \ + --context-dim 16 --x-dim 512 --y-dim 64 \ + --width 1024 --layers 4 \ + --buffer-batches 32 --data-device auto \ + --outdir bench_out/gpu1 + +torchrun --standalone --nproc_per_node=2 
scale_bench.py \ + --steps 400 --warmup-steps 50 \ + --batch-size 2048 --precision bf16 \ + --context-dim 16 --x-dim 512 --y-dim 64 \ + --width 1024 --layers 4 \ + --buffer-batches 32 --data-device auto \ + --outdir bench_out/gpu2 + +torchrun --standalone --nproc_per_node=3 scale_bench.py \ + --steps 400 --warmup-steps 50 \ + --batch-size 2048 --precision bf16 \ + --context-dim 16 --x-dim 512 --y-dim 64 \ + --width 1024 --layers 4 \ + --buffer-batches 32 --data-device auto \ + --outdir bench_out/gpu3 + +torchrun --standalone --nproc_per_node=4 scale_bench.py \ + --steps 400 --warmup-steps 50 \ + --batch-size 2048 --precision bf16 \ + --context-dim 16 --x-dim 512 --y-dim 64 \ + --width 1024 --layers 4 \ + --buffer-batches 32 --data-device auto \ + --outdir bench_out/gpu4 + +Notes: + - If scaling is still poor with this benchmark, it is very likely a *real* bottleneck + (GPU interconnect/topology, NCCL config, too-small batch, CPU frequency limits, etc.), + not a dataloader artifact. +""" + +import os +import time +import json +import math +import argparse +from dataclasses import dataclass +from datetime import timedelta +from typing import Dict, Optional + +import numpy as np +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.strategies import DDPStrategy + +# ---- your package pieces ---- +from contextualized.regression import ContextualizedRegression +from contextualized.regression.datamodules import ContextualizedRegressionDataModule + + +# ---------------- launcher/cluster helpers ---------------- +def under_torchrun() -> bool: + e = os.environ + return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) + + +def world_size() -> int: + try: + return int(os.environ.get("WORLD_SIZE", "1")) + except Exception: + return 1 + + +def global_rank() -> int: + try: + return int(os.environ.get("RANK", "0")) + except Exception: + return 0 + + +def local_rank() -> int: + try: + return 
int(os.environ.get("LOCAL_RANK", "0")) + except Exception: + return 0 + + +def is_global_zero() -> bool: + return global_rank() == 0 + + +# ---------------- env + perf ---------------- +def set_env_defaults(): + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + + os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") + os.environ.setdefault("NCCL_DEBUG", "WARN") + os.environ.setdefault("NCCL_P2P_DISABLE", "0") + os.environ.setdefault("NCCL_IB_DISABLE", "1") + + if "NCCL_SOCKET_IFNAME" not in os.environ: + try: + ifaces = [ + d + for d in os.listdir("/sys/class/net") + if os.path.isdir(f"/sys/class/net/{d}") + ] + cand = next((i for i in ifaces if i not in ("lo", "docker0")), None) + os.environ["NCCL_SOCKET_IFNAME"] = cand or "^lo,docker0" + except Exception: + os.environ["NCCL_SOCKET_IFNAME"] = "^lo,docker0" + + # TF32 / matmul speedups (safe for benchmarking throughput) + if torch.cuda.is_available(): + try: + torch.backends.cuda.matmul.allow_tf32 = True + except Exception: + pass + try: + torch.set_float32_matmul_precision("high") + except Exception: + pass + try: + torch.backends.cudnn.benchmark = True + except Exception: + pass + + if under_torchrun() and torch.cuda.is_available(): + # Ensures each rank uses its intended GPU even if something upstream is odd. 
+ try: + torch.cuda.set_device(local_rank()) + except Exception: + pass + + if is_global_zero(): + keys = [ + "NCCL_DEBUG", + "NCCL_IB_DISABLE", + "NCCL_P2P_DISABLE", + "NCCL_SOCKET_IFNAME", + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + ] + print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) + if torch.cuda.is_available(): + print( + "CUDA:", + { + "torch": torch.__version__, + "lightning": pl.__version__, + "gpus_visible": torch.cuda.device_count(), + }, + ) + + +def map_precision(p: str): + p = (p or "").lower() + if p in ("bf16", "bfloat16", "bf16-mixed"): + return "bf16-mixed" + if p in ("fp16", "16", "16-mixed"): + return "16-mixed" + return 32 + + +# ---------------- timing ---------------- +class SteadyStateStepTimer(Callback): + """ + Times optimizer steps in a steady-state window: + - ignore first warmup_steps + - measure next measure_steps + + Assumes accumulate_grad_batches == 1. + """ + + def __init__(self, warmup_steps: int, measure_steps: int): + super().__init__() + self.warmup_steps = int(warmup_steps) + self.measure_steps = int(measure_steps) + self._seen_steps = 0 + self.step_times = [] + self._step_start_t = None + + @staticmethod + def _sync_if_cuda(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + s = self._seen_steps + if self.warmup_steps <= s < (self.warmup_steps + self.measure_steps): + self._sync_if_cuda() + self._step_start_t = time.time() + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + s = self._seen_steps + if self.warmup_steps <= s < (self.warmup_steps + self.measure_steps): + self._sync_if_cuda() + dt = time.time() - (self._step_start_t or time.time()) + self.step_times.append(dt) + + self._seen_steps += 1 + + def measured_wall_time(self) -> float: + return float(sum(self.step_times)) + + +def dist_max(value: float) -> float: + """ + Returns max(value across ranks) if distributed is initialized; else returns 
value. + """ + try: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + t = torch.tensor([value], device="cuda" if torch.cuda.is_available() else "cpu") + dist.all_reduce(t, op=dist.ReduceOp.MAX) + return float(t.item()) + except Exception: + pass + return float(value) + + +# ---------------- synthetic data ---------------- +def make_synthetic_tensors( + n: int, + c_dim: int, + x_dim: int, + y_dim: int, + device: torch.device, + seed: int, +) -> Dict[str, torch.Tensor]: + """ + Generates a fixed buffer of synthetic data. + + IMPORTANT: This runs once before timing begins. Keep n reasonable. + """ + g = torch.Generator(device=device) + g.manual_seed(int(seed) + 1000 * global_rank()) + + C = torch.randn((n, c_dim), generator=g, device=device, dtype=torch.float32) + X = torch.randn((n, x_dim), generator=g, device=device, dtype=torch.float32) + Y = torch.randn((n, y_dim), generator=g, device=device, dtype=torch.float32) + return {"C": C, "X": X, "Y": Y} + + +# ---------------- model/trainer/datamodule ---------------- +def build_model(args) -> ContextualizedRegression: + # Uses your current link_fn handling (string keys are valid). + return ContextualizedRegression( + context_dim=args.context_dim, + x_dim=args.x_dim, + y_dim=args.y_dim, + num_archetypes=args.num_archetypes, + encoder_type=args.encoder_type, + encoder_kwargs={"width": args.width, "layers": args.layers, "link_fn": "identity"}, + learning_rate=args.lr, + fit_intercept=True, + link_fn="identity", + loss_fn="mse", + model_regularizer="none", + ) + + +def build_dm(args, C, X, Y) -> ContextualizedRegressionDataModule: + n = int(C.shape[0]) + # Simple split; validation never runs in this benchmark (we pass only train_dataloader). 
+ n_train = int(0.95 * n) + train_idx = np.arange(0, n_train, dtype=np.int64) + val_idx = np.arange(n_train, n, dtype=np.int64) + + dm = ContextualizedRegressionDataModule( + C=C, + X=X, + Y=Y, + task_type="singletask_multivariate", + train_idx=train_idx, + val_idx=val_idx, + test_idx=None, + predict_idx=None, + train_batch_size=args.batch_size, + val_batch_size=args.batch_size, + test_batch_size=args.batch_size, + predict_batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=bool(args.pin_memory), + persistent_workers=bool(args.num_workers > 0), + drop_last=True, + shuffle_train=False, # let Lightning/DDP sampler handle partitioning; shuffle not needed for perf + shuffle_eval=False, + dtype=torch.float, + ) + dm.prepare_data() + dm.setup() + return dm + + +def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: + if torch.cuda.is_available(): + accelerator = "gpu" + devices = 1 if under_torchrun() else min(args.devices, torch.cuda.device_count()) + strategy = ( + DDPStrategy( + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + timeout=timedelta(seconds=args.ddp_timeout), + ) + if (under_torchrun() or devices > 1) + else "auto" + ) + else: + accelerator = "cpu" + devices = 1 + strategy = "auto" + + # We benchmark *steps*, not epochs. 
+ max_steps = args.warmup_steps + args.steps + + trainer = pl.Trainer( + accelerator=accelerator, + devices=devices, + strategy=strategy, + precision=map_precision(args.precision), + max_steps=max_steps, + max_epochs=10_000, # irrelevant when max_steps is set + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + num_sanity_val_steps=0, + log_every_n_steps=50, + callbacks=[timer], + inference_mode=False, + detect_anomaly=False, + enable_model_summary=False, + use_distributed_sampler=True, + accumulate_grad_batches=1, + limit_val_batches=0, + ) + return trainer + + +# ---------------- benchmark runner ---------------- +@dataclass +class Result: + world_size: int + batch_size_per_gpu: int + global_batch_size: int + warmup_steps: int + measured_steps: int + measured_wall_s: float + throughput_samples_per_s: float + per_gpu_throughput_samples_per_s: float + avg_step_s: float + p95_step_s: float + + +def run_bench(args) -> Result: + ws = world_size() if under_torchrun() else int(args.devices) + dev = torch.device("cuda", local_rank()) if (args.data_device == "cuda" and torch.cuda.is_available()) else torch.device("cpu") + + # If auto: keep data on GPU when available (this removes input bottlenecks). + if args.data_device == "auto": + if torch.cuda.is_available(): + dev = torch.device("cuda", local_rank()) + else: + dev = torch.device("cpu") + + # Dataloader workers cannot safely handle CUDA tensors. 
+ if dev.type == "cuda" and args.num_workers != 0: + if is_global_zero(): + print("NOTE: forcing --num-workers=0 because data-device is CUDA.") + args.num_workers = 0 + + # Build fixed synthetic buffer (not timed) + n = int(args.batch_size * args.buffer_batches) + tensors = make_synthetic_tensors( + n=n, + c_dim=args.context_dim, + x_dim=args.x_dim, + y_dim=args.y_dim, + device=dev, + seed=args.seed, + ) + + dm = build_dm(args, tensors["C"], tensors["X"], tensors["Y"]) + model = build_model(args) + + timer = SteadyStateStepTimer(args.warmup_steps, args.steps) + trainer = build_trainer(args, timer) + + if is_global_zero(): + print( + "\nConfig:", + json.dumps( + { + "torchrun": under_torchrun(), + "world_size": ws, + "local_rank": local_rank(), + "batch_size_per_gpu": args.batch_size, + "global_batch_size": args.batch_size * ws, + "steps_measured": args.steps, + "steps_warmup": args.warmup_steps, + "buffer_samples": n, + "data_device": str(dev), + "precision": map_precision(args.precision), + }, + indent=2, + ), + ) + + trainer.fit(model, train_dataloaders=dm.train_dataloader()) + + measured_wall = timer.measured_wall_time() + measured_wall = dist_max(measured_wall) # slowest rank dictates wall time + + measured_steps = int(args.steps) + global_batch = int(args.batch_size * ws) + samples_total = global_batch * measured_steps + throughput = samples_total / max(measured_wall, 1e-12) + per_gpu = throughput / max(ws, 1) + + step_times = timer.step_times[:] if timer.step_times else [float("nan")] + avg_step = float(np.mean(step_times)) + p95_step = float(np.percentile(step_times, 95)) if len(step_times) > 1 else float("nan") + + return Result( + world_size=ws, + batch_size_per_gpu=int(args.batch_size), + global_batch_size=int(global_batch), + warmup_steps=int(args.warmup_steps), + measured_steps=int(measured_steps), + measured_wall_s=float(measured_wall), + throughput_samples_per_s=float(throughput), + per_gpu_throughput_samples_per_s=float(per_gpu), + 
avg_step_s=float(avg_step), + p95_step_s=float(p95_step), + ) + + +def save_result(outdir: str, res: Result): + os.makedirs(outdir, exist_ok=True) + path = os.path.join(outdir, "result.json") + with open(path, "w") as f: + json.dump(res.__dict__, f, indent=2) + return path + + +# ---------------- main ---------------- +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--steps", type=int, default=400, help="Measured optimizer steps") + ap.add_argument("--warmup-steps", type=int, default=50, help="Warmup steps excluded from timing") + + ap.add_argument("--batch-size", type=int, default=2048, help="Per-GPU batch size") + ap.add_argument("--num-workers", type=int, default=0) + ap.add_argument("--pin-memory", action="store_true", default=False) + ap.add_argument("--precision", type=str, default="bf16") + + ap.add_argument("--context-dim", type=int, default=16) + ap.add_argument("--x-dim", type=int, default=512) + ap.add_argument("--y-dim", type=int, default=64) + + ap.add_argument("--encoder-type", type=str, default="mlp") + ap.add_argument("--num-archetypes", type=int, default=8) + ap.add_argument("--width", type=int, default=1024) + ap.add_argument("--layers", type=int, default=4) + ap.add_argument("--lr", type=float, default=1e-3) + + ap.add_argument("--buffer-batches", type=int, default=32, help="Dataset buffer size = batch_size * buffer_batches") + ap.add_argument("--data-device", type=str, choices=["auto", "cpu", "cuda"], default="auto") + ap.add_argument("--devices", type=int, default=1, help="Only used when NOT under torchrun") + + ap.add_argument("--ddp-timeout", type=int, default=180) + ap.add_argument("--seed", type=int, default=123) + ap.add_argument("--outdir", type=str, default="bench_out") + + return ap.parse_args() + + +def main(): + set_env_defaults() + args = parse_args() + + if args.data_device == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" # ensure no accidental CUDA use + + res = run_bench(args) + + if is_global_zero(): + 
path = save_result(args.outdir, res) + print( + "\nResult:", + json.dumps(res.__dict__, indent=2), + ) + print(f"\nSaved → {path}") + + +if __name__ == "__main__": + main() diff --git a/scale_bench_networks.py b/scale_bench_networks.py new file mode 100644 index 00000000..2c6125a4 --- /dev/null +++ b/scale_bench_networks.py @@ -0,0 +1,623 @@ +#!/usr/bin/env python3 +""" +scale_bench_networks.py + +A torchrun-friendly DDP scaling benchmark for Contextualized *Networks* lightning modules +(e.g., ContextualizedCorrelation, ContextualizedMarkovGraph, NOTMAD). + +Design goals (to reveal true scaling): + - Fixed number of optimizer steps (not epochs) so each run does identical work. + - Optional GPU-resident synthetic dataset to remove CPU dataloading/transfer bottlenecks. + - Measures only the *steady-state* region (warmup steps excluded). + - Uses Lightning DDP under torchrun correctly (devices=1 per process). + - No validation, no logging, no checkpoints. + +------------------------------------------------------------ +Quick start (single node, 1..4 GPUs) +------------------------------------------------------------ + +# 0) NICs (optional) +ls -1 /sys/class/net +ip -o link show | awk -F': ' '{print NR-1": "$2}' + +# 1) Minimal, safe NCCL/torch env (no hard-coded eth0): +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export TOKENIZERS_PARALLELISM=false +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export NCCL_DEBUG=WARN +export NCCL_P2P_DISABLE=0 +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=$(ls /sys/class/net | grep -E '^(ens|enp|eno|eth|bond|ib)' | head -n1) +[ -z "$NCCL_SOCKET_IFNAME" ] && export NCCL_SOCKET_IFNAME="^lo,docker0" + +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# 2) Kill any stragglers (optional) +pkill -f scale_bench_networks.py || true +pkill -f torchrun || true + +# 3) Runs (IMPORTANT: --batch-size is PER GPU) + +# Correlation networks +torchrun --standalone --nproc_per_node=1 
scale_bench_networks.py \ + --network correlation \ + --steps 400 --warmup-steps 50 \ + --batch-size 2048 --precision bf16 \ + --context-dim 16 --x-dim 512 \ + --encoder-type mlp --width 1024 --layers 4 \ + --num-archetypes 8 \ + --buffer-batches 32 --data-device auto \ + --outdir bench_out/corr_gpu1 + +torchrun --standalone --nproc_per_node=2 scale_bench_networks.py \ + --network correlation \ + --steps 400 --warmup-steps 50 \ + --batch-size 2048 --precision bf16 \ + --context-dim 16 --x-dim 512 \ + --encoder-type mlp --width 1024 --layers 4 \ + --num-archetypes 8 \ + --buffer-batches 32 --data-device auto \ + --outdir bench_out/corr_gpu2 + +# Markov networks (precision matrices) +torchrun --standalone --nproc_per_node=4 scale_bench_networks.py \ + --network markov \ + --steps 400 --warmup-steps 50 \ + --batch-size 1024 --precision bf16 \ + --context-dim 16 --x-dim 256 \ + --encoder-type mlp --width 512 --layers 3 \ + --num-archetypes 8 \ + --buffer-batches 32 --data-device auto \ + --outdir bench_out/markov_gpu4 + +Notes: + - If scaling is poor with --data-device=cuda (or auto on GPU), the bottleneck is + likely *real* (NCCL/topology/comm, too-small batch, CPU freq limits, etc.). + - Multi-node: remove --standalone and use --nnodes/--node_rank with a shared rdzv endpoint. 
+""" + +import os +import time +import json +import math +import argparse +from dataclasses import dataclass +from datetime import timedelta +from typing import Dict + +import numpy as np +import torch +import pytorch_lightning as pl +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.strategies import DDPStrategy + +# ---- your package pieces ---- +from contextualized.regression.datamodules import ContextualizedRegressionDataModule +from contextualized.regression.lightning_modules import ( + ContextualizedCorrelation, + ContextualizedMarkovGraph, +) +from contextualized.dags.lightning_modules import NOTMAD + + +# ---------------- launcher/cluster helpers ---------------- +def under_torchrun() -> bool: + e = os.environ + return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) + + +def world_size() -> int: + try: + return int(os.environ.get("WORLD_SIZE", "1")) + except Exception: + return 1 + + +def global_rank() -> int: + try: + return int(os.environ.get("RANK", "0")) + except Exception: + return 0 + + +def local_rank() -> int: + try: + return int(os.environ.get("LOCAL_RANK", "0")) + except Exception: + return 0 + + +def is_global_zero() -> bool: + return global_rank() == 0 + + +# ---------------- env + perf ---------------- +def set_env_defaults(): + os.environ.setdefault("OMP_NUM_THREADS", "1") + os.environ.setdefault("MKL_NUM_THREADS", "1") + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + + os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") + os.environ.setdefault("NCCL_DEBUG", "WARN") + os.environ.setdefault("NCCL_P2P_DISABLE", "0") + os.environ.setdefault("NCCL_IB_DISABLE", "1") + + if "NCCL_SOCKET_IFNAME" not in os.environ: + try: + ifaces = [ + d + for d in os.listdir("/sys/class/net") + if os.path.isdir(f"/sys/class/net/{d}") + ] + cand = next((i for i in ifaces if i not in ("lo", "docker0")), None) + os.environ["NCCL_SOCKET_IFNAME"] = cand or "^lo,docker0" + except Exception: + 
os.environ["NCCL_SOCKET_IFNAME"] = "^lo,docker0" + + # TF32 / matmul speedups (safe for throughput benchmarking) + if torch.cuda.is_available(): + try: + torch.backends.cuda.matmul.allow_tf32 = True + except Exception: + pass + try: + torch.set_float32_matmul_precision("high") + except Exception: + pass + try: + torch.backends.cudnn.benchmark = True + except Exception: + pass + + if under_torchrun() and torch.cuda.is_available(): + try: + torch.cuda.set_device(local_rank()) + except Exception: + pass + + if is_global_zero(): + keys = [ + "NCCL_DEBUG", + "NCCL_IB_DISABLE", + "NCCL_P2P_DISABLE", + "NCCL_SOCKET_IFNAME", + "TORCH_NCCL_ASYNC_ERROR_HANDLING", + ] + print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) + if torch.cuda.is_available(): + print( + "CUDA:", + { + "torch": torch.__version__, + "lightning": pl.__version__, + "gpus_visible": torch.cuda.device_count(), + }, + ) + + +def map_precision(p: str): + p = (p or "").lower() + if p in ("bf16", "bfloat16", "bf16-mixed"): + return "bf16-mixed" + if p in ("fp16", "16", "16-mixed"): + return "16-mixed" + return 32 + + +# ---------------- timing ---------------- +class SteadyStateStepTimer(Callback): + """ + Times optimizer steps in a steady-state window: + - ignore first warmup_steps + - measure next measure_steps + Assumes accumulate_grad_batches == 1. 
+ """ + + def __init__(self, warmup_steps: int, measure_steps: int): + super().__init__() + self.warmup_steps = int(warmup_steps) + self.measure_steps = int(measure_steps) + self._seen_steps = 0 + self.step_times = [] + self._step_start_t = None + + @staticmethod + def _sync_if_cuda(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + s = self._seen_steps + if self.warmup_steps <= s < (self.warmup_steps + self.measure_steps): + self._sync_if_cuda() + self._step_start_t = time.time() + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + s = self._seen_steps + if self.warmup_steps <= s < (self.warmup_steps + self.measure_steps): + self._sync_if_cuda() + dt = time.time() - (self._step_start_t or time.time()) + self.step_times.append(dt) + self._seen_steps += 1 + + def measured_wall_time(self) -> float: + return float(sum(self.step_times)) + + +def dist_max(value: float) -> float: + """ + Returns max(value across ranks) if distributed is initialized; else returns value. + """ + try: + import torch.distributed as dist + + if dist.is_available() and dist.is_initialized(): + t = torch.tensor( + [value], device="cuda" if torch.cuda.is_available() else "cpu" + ) + dist.all_reduce(t, op=dist.ReduceOp.MAX) + return float(t.item()) + except Exception: + pass + return float(value) + + +# ---------------- synthetic data ---------------- +def make_synthetic_tensors( + n: int, + c_dim: int, + x_dim: int, + device: torch.device, + seed: int, +) -> Dict[str, torch.Tensor]: + """ + Builds a fixed synthetic buffer (not timed). Shapes: + C: (n, c_dim) + X: (n, x_dim) + Y: (n, x_dim) # for networks we follow the wrapper convention (univariate task uses y_dim=x_dim) + """ + g = torch.Generator(device=device) + # Per-rank seed to avoid identical data, while keeping identical shapes across ranks. 
+ g.manual_seed(int(seed) + 1000 * global_rank()) + + C = torch.randn((n, c_dim), generator=g, device=device, dtype=torch.float32) + X = torch.randn((n, x_dim), generator=g, device=device, dtype=torch.float32) + Y = torch.randn((n, x_dim), generator=g, device=device, dtype=torch.float32) + return {"C": C, "X": X, "Y": Y} + + +# ---------------- model/datamodule/trainer ---------------- +def build_model(args): + """ + Robustly instantiate the selected network LightningModule. + + We pass a *superset* of kwargs and filter by the model's __init__ signature to + remain compatible with small constructor differences across implementations. + """ + import inspect + + if args.network == "correlation": + model_cls = ContextualizedCorrelation + elif args.network == "markov": + model_cls = ContextualizedMarkovGraph + elif args.network == "bayesian": + model_cls = NOTMAD + else: + raise ValueError(f"Unknown --network {args.network}") + + encoder_kwargs = {"width": args.width, "layers": args.layers, "link_fn": "identity"} + + # Common superset + kw = dict( + context_dim=args.context_dim, + x_dim=args.x_dim, + y_dim=args.x_dim, # networks wrapper convention + univariate=True, + num_archetypes=args.num_archetypes, + encoder_type=args.encoder_type, + encoder_kwargs=encoder_kwargs, + learning_rate=args.lr, + link_fn="identity", + fit_intercept=True, + loss_fn="mse", + model_regularizer="none", + ) + + # NOTMAD-specific defaults (safe baseline; tune as needed) + if args.network == "bayesian": + kw.update( + archetype_loss_params=dict( + l1=0.0, + dag=dict(loss_type="notears", params=dict(alpha=1.0, rho=1.0, s=1.0, tol=1e-8)), + init_mat=None, + num_factors=0, + factor_mat_l1=0.0, + num_archetypes=max(1, int(args.num_archetypes)), + ), + sample_specific_loss_params=dict( + l1=0.0, + dag=dict(loss_type="notears", params=dict(alpha=1.0, rho=1.0, s=1.0, tol=1e-8)), + ), + opt_params=dict( + learning_rate=args.lr, + step=50, + ), + ) + + sig = inspect.signature(model_cls.__init__) + 
accepts_var_kw = any( + p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() + ) + if accepts_var_kw: + return model_cls(**kw) + + filtered = {k: v for k, v in kw.items() if k in sig.parameters} + # Basic required-arg check (only for explicit signatures) + required = [ + name + for name, p in sig.parameters.items() + if name != "self" + and p.default is inspect._empty + and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + ] + missing = [r for r in required if r not in filtered] + if missing: + raise TypeError( + f"{model_cls.__name__}.__init__ missing required args {missing}. " + f"Accepted params in script: {sorted(filtered.keys())}. " + f"Signature: {sig}" + ) + return model_cls(**filtered) + + +def build_dm(args, C, X, Y) -> ContextualizedRegressionDataModule: + """ + Uses the same DataModule family as the wrapper (consistent batch structure). + IMPORTANT: If data lives on CUDA, we force num_workers=0. + """ + n = int(C.shape[0]) + n_train = max(1, int(0.98 * n)) + train_idx = np.arange(0, n_train, dtype=np.int64) + val_idx = np.arange(n_train, n, dtype=np.int64) + + task_type = args.task_type + if task_type is None: + # Networks wrappers use the univariate convention. 
+ task_type = "singletask_univariate" + + dm = ContextualizedRegressionDataModule( + C=C, + X=X, + Y=Y, + task_type=task_type, + train_idx=train_idx, + val_idx=val_idx, + test_idx=None, + predict_idx=None, + train_batch_size=args.batch_size, + val_batch_size=args.batch_size, + test_batch_size=args.batch_size, + predict_batch_size=args.batch_size, + num_workers=args.num_workers, + pin_memory=bool(args.pin_memory), + persistent_workers=bool(args.num_workers > 0), + drop_last=True, + shuffle_train=False, + shuffle_eval=False, + dtype=torch.float, + ) + dm.prepare_data() + dm.setup() + return dm + + +def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: + if torch.cuda.is_available(): + accelerator = "gpu" + # Under torchrun: each process uses exactly 1 device + devices = 1 if under_torchrun() else min(args.devices, torch.cuda.device_count()) + strategy = ( + DDPStrategy( + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=True, + timeout=timedelta(seconds=args.ddp_timeout), + ) + if (under_torchrun() or devices > 1) + else "auto" + ) + else: + accelerator = "cpu" + devices = 1 + strategy = "auto" + + max_steps = int(args.warmup_steps + args.steps) + + return pl.Trainer( + accelerator=accelerator, + devices=devices, + strategy=strategy, + precision=map_precision(args.precision), + max_steps=max_steps, + max_epochs=10_000, # irrelevant when max_steps is set + logger=False, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + num_sanity_val_steps=0, + log_every_n_steps=50, + callbacks=[timer], + inference_mode=False, + detect_anomaly=False, + accumulate_grad_batches=1, + limit_val_batches=0, # no validation + use_distributed_sampler=False # IMPORTANT: our synthetic buffer is already identical-sized per rank + ) + + +# ---------------- benchmark runner ---------------- +@dataclass +class Result: + network: str + world_size: int + batch_size_per_gpu: int + global_batch_size: int + warmup_steps: 
int + measured_steps: int + measured_wall_s: float + throughput_samples_per_s: float + per_gpu_throughput_samples_per_s: float + avg_step_s: float + p95_step_s: float + data_device: str + + +def run_bench(args) -> Result: + ws = world_size() if under_torchrun() else int(args.devices) + + # Resolve data device + if args.data_device == "cpu": + dev = torch.device("cpu") + elif args.data_device == "cuda": + dev = torch.device("cuda", local_rank()) if torch.cuda.is_available() else torch.device("cpu") + else: # auto + dev = torch.device("cuda", local_rank()) if torch.cuda.is_available() else torch.device("cpu") + + # Dataloader workers cannot safely handle CUDA tensors + if dev.type == "cuda" and args.num_workers != 0: + if is_global_zero(): + print("NOTE: forcing --num-workers=0 because data-device is CUDA.") + args.num_workers = 0 + + # Build fixed synthetic buffer (not timed) + n = int(args.batch_size * args.buffer_batches) + tensors = make_synthetic_tensors( + n=n, + c_dim=args.context_dim, + x_dim=args.x_dim, + device=dev, + seed=args.seed, + ) + + dm = build_dm(args, tensors["C"], tensors["X"], tensors["Y"]) + model = build_model(args) + + timer = SteadyStateStepTimer(args.warmup_steps, args.steps) + trainer = build_trainer(args, timer) + + if is_global_zero(): + print( + "\nConfig:", + json.dumps( + { + "network": args.network, + "torchrun": under_torchrun(), + "world_size": ws, + "local_rank": local_rank(), + "batch_size_per_gpu": args.batch_size, + "global_batch_size": args.batch_size * ws, + "steps_measured": args.steps, + "steps_warmup": args.warmup_steps, + "buffer_samples_per_rank": n, + "data_device": str(dev), + "precision": map_precision(args.precision), + "task_type": args.task_type or "singletask_univariate", + }, + indent=2, + ), + ) + + trainer.fit(model, train_dataloaders=dm.train_dataloader()) + + measured_wall = timer.measured_wall_time() + measured_wall = dist_max(measured_wall) # slowest rank dictates + + measured_steps = int(args.steps) + 
global_batch = int(args.batch_size * ws) + samples_total = global_batch * measured_steps + throughput = samples_total / max(measured_wall, 1e-12) + per_gpu = throughput / max(ws, 1) + + step_times = timer.step_times[:] if timer.step_times else [float("nan")] + avg_step = float(np.mean(step_times)) + p95_step = float(np.percentile(step_times, 95)) if len(step_times) > 1 else float("nan") + + return Result( + network=str(args.network), + world_size=int(ws), + batch_size_per_gpu=int(args.batch_size), + global_batch_size=int(global_batch), + warmup_steps=int(args.warmup_steps), + measured_steps=int(measured_steps), + measured_wall_s=float(measured_wall), + throughput_samples_per_s=float(throughput), + per_gpu_throughput_samples_per_s=float(per_gpu), + avg_step_s=float(avg_step), + p95_step_s=float(p95_step), + data_device=str(dev), + ) + + +def save_result(outdir: str, res: Result) -> str: + os.makedirs(outdir, exist_ok=True) + path = os.path.join(outdir, "result.json") + with open(path, "w") as f: + json.dump(res.__dict__, f, indent=2) + return path + + +# ---------------- main ---------------- +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("--network", type=str, choices=["correlation", "markov", "bayesian"], default="correlation") + + ap.add_argument("--steps", type=int, default=400, help="Measured optimizer steps") + ap.add_argument("--warmup-steps", type=int, default=50, help="Warmup steps excluded from timing") + + ap.add_argument("--batch-size", type=int, default=2048, help="Per-GPU batch size") + ap.add_argument("--num-workers", type=int, default=0) + ap.add_argument("--pin-memory", action="store_true", default=False) + ap.add_argument("--precision", type=str, default="bf16") + + ap.add_argument("--context-dim", type=int, default=16) + ap.add_argument("--x-dim", type=int, default=512) + + ap.add_argument("--encoder-type", type=str, default="mlp") + ap.add_argument("--num-archetypes", type=int, default=8) + ap.add_argument("--width", 
type=int, default=1024) + ap.add_argument("--layers", type=int, default=4) + ap.add_argument("--lr", type=float, default=1e-3) + + ap.add_argument("--buffer-batches", type=int, default=32, + help="Per-rank buffer size = batch_size * buffer_batches") + ap.add_argument("--data-device", type=str, choices=["auto", "cpu", "cuda"], default="auto") + + ap.add_argument("--task-type", type=str, default=None, + help="Override task_type if needed (default: singletask_univariate)") + + ap.add_argument("--devices", type=int, default=1, help="Only used when NOT under torchrun") + ap.add_argument("--ddp-timeout", type=int, default=180) + ap.add_argument("--seed", type=int, default=123) + ap.add_argument("--outdir", type=str, default="bench_out") + return ap.parse_args() + + +def main(): + set_env_defaults() + args = parse_args() + + if args.data_device == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" # prevent accidental CUDA use + + res = run_bench(args) + + if is_global_zero(): + path = save_result(args.outdir, res) + print("\nResult:", json.dumps(res.__dict__, indent=2)) + print(f"\nSaved → {path}") + + +if __name__ == "__main__": + main() From ca4b9a78142a177b3cf44123bd17dcb598543247 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Sun, 11 Jan 2026 23:33:30 -0500 Subject: [PATCH 15/19] updates to file formatting --- README_HPC.md | 147 +++++ contextualized/data.py | 2 +- contextualized/easy/ContextualizedNetworks.py | 360 ++++++++--- .../easy/ContextualizedRegressor.py | 12 +- .../easy/wrappers/SKLearnWrapper.py | 407 +++++++----- contextualized/modules.py | 41 +- contextualized/regression/datamodules.py | 38 +- contextualized/regression/datasets.py | 33 +- .../regression/lightning_modules.py | 587 +++++------------- contextualized/regression/trainers.py | 138 ++-- network_scaling_heavy.py | 465 +++++--------- networks_pert_scale_bench.py | 94 +-- scale_bench.py | 512 +++++++-------- scale_bench_networks.py | 149 +---- 14 files changed, 1375 insertions(+), 1610 deletions(-) 
create mode 100644 README_HPC.md diff --git a/README_HPC.md b/README_HPC.md new file mode 100644 index 00000000..76bff00a --- /dev/null +++ b/README_HPC.md @@ -0,0 +1,147 @@ +# HPC and DDP Usage Guide + +This package supports single GPU, multi-GPU, and HPC clusters using PyTorch Lightning. The primary goal is consistent behavior across CPU, GPU, and DDP environments. The secondary goal is correct ordering of predictions under DDP. + +## Core Principles + +**Map-style datasets**: Lightning can shard data with DistributedSampler when using `torch.utils.data.Dataset`. + +**LightningDataModule pattern**: Builds datasets from user arrays and manages train/val/test splits consistently. + +**Stable prediction ordering**: Return prediction payloads that include stable indices, then gather and reorder on rank 0. + +## Dataset and Batch Structure + +Datasets are map-style (`torch.utils.data.Dataset`). Each `__getitem__` returns a dict with standard keys: `contexts`, `predictors`, `outcomes`, plus indexing keys for DDP-safe prediction assembly: `idx` (dataset local index), `orig_idx` (stable original row ID). For multitask variants, additional keys include `sample_idx`, `outcome_idx`, `predictor_idx`. DDP sharding can change sample order and pad the last batch, so these indices enable correct reconstruction. + +## DataModule Usage + +The DataModule converts numpy or pandas arrays into tensors and slices by split indices. It passes `orig_idx` into each dataset so every sample reports its original row ID. + +**Split configuration**: Provide `train_idx`, `val_idx`, `test_idx`, `predict_idx` directly, or provide a `splitter(C, X, Y)` callable that returns `(train_idx, val_idx, test_idx)`. If `predict_idx` is not provided, it defaults to `test_idx` when present, otherwise defaults to the full range. 
+ +**Example instantiation**: +```python +from contextualized.regression.datamodules import ContextualizedRegressionDataModule + +dm = ContextualizedRegressionDataModule( + C=C, X=X, Y=Y, + task_type="singletask_multivariate", + train_idx=train_idx, + val_idx=val_idx, + test_idx=test_idx, + predict_idx=predict_idx, + train_batch_size=32, + val_batch_size=64, + test_batch_size=64, + predict_batch_size=64, + num_workers=4, + pin_memory=True, + persistent_workers=True, + drop_last=False, + shuffle_train=True, + shuffle_eval=False, +) + +trainer.fit(model, datamodule=dm) +preds = trainer.predict(model, datamodule=dm) +``` + +If calling loaders manually, call `dm.setup(stage="predict")` before retrieving the dataloader. + +## DDP Prediction Mechanics + +Prediction assembly occurs only on rank 0. Non-rank-0 processes return `None`. This prevents duplicated outputs and keeps the API stable under DDP. + +**Predict step payload**: Each `LightningModule.predict_step` returns a dict containing indices (`idx`, `orig_idx`, and optional task indices), batch content when needed (`contexts`, `predictors`), and model outputs (`betas`, `mus`, and sometimes `correlations`). Tensors are detached and moved to CPU inside `predict_step` to keep GPU memory stable. + +**Gather and reorder process**: The trainer helper packs requested keys from local predict outputs, gathers to rank 0, merges payloads by concatenating on axis 0, stable sorts by `idx` when present (otherwise `orig_idx`), and deduplicates padded samples from `DistributedSampler`. If using `dist.gather_object`, ensure the collective backend supports object gathers (commonly Gloo). If your default process group is NCCL, use a Gloo group for object gather or switch to tensor all-gather. + +## Training Configuration + +**Single node, single GPU**: Use `Trainer(accelerator="gpu", devices=1, strategy="auto")`. 
Recommended data settings: `pin_memory=True`, `num_workers` tuned to CPU count and batch size, `persistent_workers=True` if `num_workers > 0`. + +**Single node, multi-GPU**: Two supported patterns exist. + +Pattern A (Lightning spawns processes): Use a normal Python launch and let Lightning spawn processes. Set `devices` to GPU count and `strategy="ddp"` or `DDPStrategy(...)`. Example for 2 GPUs: `Trainer(accelerator="gpu", devices=2, strategy="ddp")`. + +Pattern B (torchrun launch, recommended for clusters): Use torchrun to launch one process per GPU. Each process should use `devices=1` (or omit devices) and set `strategy="ddp"` or `DDPStrategy(...)`. Example for 2 GPUs on one node: +```bash +export CUDA_VISIBLE_DEVICES=0,1 +torchrun --standalone --nproc_per_node=2 your_script.py +``` + +Recommended environment variables for NCCL: +```bash +export OMP_NUM_THREADS=1 +export MKL_NUM_THREADS=1 +export TOKENIZERS_PARALLELISM=false +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export NCCL_DEBUG=WARN +export NCCL_IB_DISABLE=1 +export NCCL_SOCKET_IFNAME=eno1 +``` + +Set `NCCL_SOCKET_IFNAME` to your actual NIC. If you see hangs, confirm `WORLD_SIZE`, `RANK`, `LOCAL_RANK` are correct. + +## Sklearn-Style Wrappers + +The sklearn-style API uses the same DataModule path when available, keeping the public surface area stable. +```python +from contextualized.easy.wrappers import ContextualizedRegressor + +m = ContextualizedRegressor( + num_archetypes=8, + encoder_type="mlp", + encoder_kwargs={"width": 256, "layers": 2, "link_fn": "identity"}, + trainer_kwargs={ + "accelerator": "gpu", + "devices": 1, + "max_epochs": 50, + }, +) + +m.fit(C_train, X_train, Y_train, val_split=0.2) +yhat = m.predict(C_test, X_test) +betas, mus = m.predict_params(C_test) +``` + +**DDP behavior for wrappers**: Under DDP, rank 0 returns arrays and non-rank-0 returns `None`. If running under torchrun, only rank 0 should consume outputs. Do not stack outputs across ranks. 
+ +## Batch Sizing for Strong Scaling + +Strong scaling means fixed global batch size. Per-GPU batch is `global_batch_size / world_size`. Set one global batch size and keep it fixed. As you add GPUs, let the per-GPU batch shrink. This reduces OOM risk when scaling up GPU count. + +If you hit OOM: reduce global batch first, then reduce model width or layers, then reduce DataLoader workers if CPU memory becomes the bottleneck. + +## Data Movement and Pinned Memory + +Pinned CPU buffers improve host-to-device transfer speed. Recommended for GPU training: `pin_memory=True` in DataLoader, use pinned host tensors in synthetic or streaming benchmarks. For real datasets, use normal CPU tensors and let DataLoader pin memory. + +## Common Pitfalls + +**Wrong batch dict keys**: Models expect batch dict keys `contexts`, `predictors`, `outcomes`. If using custom datasets, match these names. + +**Device mismatch under torchrun**: Under torchrun, processes already map to devices. Use `devices=1` per process (or omit devices), and each process uses `LOCAL_RANK` as its CUDA device. + +**Dropping samples during eval**: Do not drop samples for validation, test, or predict. It breaks ordering assumptions. Use `drop_last=False` for eval loaders. + +**Expecting prediction output on every rank**: Prediction helpers are rank-0-only by design. Non-rank-0 returns `None`. + +## Minimal DDP Launch Recipe + +Single node, 4 GPUs: +```bash +export CUDA_VISIBLE_DEVICES=0,1,2,3 +torchrun --standalone --nproc_per_node=4 train_script.py +``` + +Trainer settings that work well: `strategy=DDPStrategy(find_unused_parameters=False, broadcast_buffers=False)`, mixed precision on GPU if stable for your model, logging sync only on epoch metrics (not per step). 
+ +## Benchmark Pattern for Scaling + +A good scaling benchmark uses fixed global batch size, uses a warmup window then measures steady state, uses already batched inputs to reduce DataLoader overhead, uses pinned CPU memory when measuring host-to-device transfer. Loss computation should be simple and shape-safe. Keep it in the benchmark harness. Avoid clever reshapes that depend on internal model conventions. + +## Summary + +HPC readiness comes from three components: map-style datasets and a DataModule so Lightning can shard correctly, prediction payloads that include stable indices and are CPU-friendly, rank-0 gather and reorder so predictions match user-expected order. Following these patterns ensures multi-GPU training and prediction are stable and repeatable. \ No newline at end of file diff --git a/contextualized/data.py b/contextualized/data.py index 5c8fa627..a43bc03c 100644 --- a/contextualized/data.py +++ b/contextualized/data.py @@ -1,5 +1,5 @@ import torch -from pytorch_lightning import LightningDataModule +from lightning.pytorch import LightningDataModule from contextualized.regression.datasets import MultivariateDataset, UnivariateDataset, MultitaskMultivariateDataset, MultitaskUnivariateDataset from sklearn.model_selection import train_test_split diff --git a/contextualized/easy/ContextualizedNetworks.py b/contextualized/easy/ContextualizedNetworks.py index 04e11aa2..ec3e2155 100644 --- a/contextualized/easy/ContextualizedNetworks.py +++ b/contextualized/easy/ContextualizedNetworks.py @@ -1,16 +1,8 @@ """ sklearn-like interface to Contextualized Networks. - -CPU/DDP FIXES (drag-and-drop): -1) When using a LightningDataModule outside Trainer.fit/predict, you MUST call - dm.setup(stage="predict") before dm.predict_dataloader(). -2) Under DDP, prediction helpers are rank-0 only (by design in your trainers/wrapper). 
- We therefore avoid constructing np.array([None,...]) and return None on non-rank0, - while still executing the full per-model predict loop on all ranks to prevent - collective mismatches/hangs. """ -from typing import List, Tuple, Union, Optional +from typing import * import numpy as np import torch @@ -32,10 +24,12 @@ def _is_distributed() -> bool: + """Returns True if torch.distributed is available and initialized.""" return dist.is_available() and dist.is_initialized() def _rank() -> int: + """Returns the current distributed rank, defaulting to 0 when not distributed.""" if _is_distributed(): return dist.get_rank() return 0 @@ -58,9 +52,25 @@ def _split_train_data( shuffle: bool = True, **kwargs, ) -> Tuple[np.ndarray, Optional[np.ndarray]]: - """ - Override only to change the default behavior (networks do not *require* Y), - but keep the signature compatible with SKLearnWrapper._split_train_data. + """Splits data into train and test sets. + + Notes: + This override exists to set the default behavior for networks (Y is not required), + while preserving compatibility with SKLearnWrapper._split_train_data. + + Args: + C (np.ndarray): Contextual features for each sample. + X (np.ndarray): The data matrix. + Y (Optional[np.ndarray], optional): Optional targets. Defaults to None. + Y_required (bool, optional): Whether Y is required. Defaults to False. + val_split (Optional[float], optional): Validation split fraction. Defaults to None. + random_state (Optional[int], optional): Random state for splitting. Defaults to None. + shuffle (bool, optional): Whether to shuffle before splitting. Defaults to True. + **kwargs: Additional keyword arguments forwarded to the base implementation. + + Returns: + Tuple[np.ndarray, Optional[np.ndarray]]: The train/test split outputs as returned by + SKLearnWrapper._split_train_data. 
""" return super()._split_train_data( C, @@ -86,14 +96,30 @@ def predict_networks( Tuple[List[np.ndarray], List[np.ndarray]], None, ]: + """Predicts context-specific networks given contextual features. + + Notes: + Under DDP, prediction helpers are rank-0 only (by design in the trainers/wrapper). + In such cases, this method returns None on non-rank-0 processes. + + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + with_offsets (bool, optional): If True, returns both the network parameters and + offsets (when available). Defaults to False. + individual_preds (bool, optional): If True, returns the predictions for each + bootstrap. Defaults to False. + **kwargs: Keyword arguments forwarded to predict_params. + + Returns: + Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray, np.ndarray], + Tuple[List[np.ndarray], List[np.ndarray]], None]: + The predicted network parameters (and offsets if with_offsets is True). + Returned as lists of individual bootstraps if individual_preds is True. + Returns None on non-rank-0 under DDP. """ - Predicts context-specific network parameters (and offsets if available). - - DDP behavior: - - rank0 returns arrays/tuples - - non-rank0 returns None - """ - out = self.predict_params(C, individual_preds=individual_preds, uses_y=False, **kwargs) + out = self.predict_params( + C, individual_preds=individual_preds, uses_y=False, **kwargs + ) if out is None: return None @@ -101,36 +127,72 @@ def predict_networks( if betas is None: return None - return (betas, mus) if with_offsets else betas + if with_offsets: + return betas, mus + return betas def predict_X( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs ) -> Union[np.ndarray, List[np.ndarray]]: - """ - Reconstructs X via predicted networks using the base wrapper predict(). + """Reconstructs the data matrix based on predicted contextualized networks and + the true data matrix. 
+ + Useful for measuring reconstruction error or for imputation. + + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + X (np.ndarray): The data matrix (n_samples, n_features). + individual_preds (bool, optional): If True, returns the predictions for each + bootstrap. Defaults to False. + **kwargs: Keyword arguments for the Lightning trainer's prediction method. + + Returns: + Union[np.ndarray, List[np.ndarray]]: The predicted data matrix, or matrices for + each bootstrap if individual_preds is True (n_samples, n_features). """ return self.predict(C, X, individual_preds=individual_preds, **kwargs) class ContextualizedCorrelationNetworks(ContextualizedNetworks): """ - Contextualized Correlation Networks reveal context-varying feature correlations. + Contextualized Correlation Networks reveal context-varying feature correlations, + interaction strengths, and dependencies in feature groups. + Uses the Contextualized Networks model. + + Notes: + This implementation includes CPU/DDP-safe prediction behavior: + - When using a LightningDataModule outside Trainer.fit/predict, setup(stage="predict") + is called before predict_dataloader(). + - Under DDP, only rank-0 returns numpy outputs; non-rank-0 returns None, while still + executing the per-model predict loop to avoid collective mismatches/hangs. """ def __init__(self, **kwargs): - super().__init__(ContextualizedCorrelation, [], [], CorrelationTrainer, **kwargs) + super().__init__( + ContextualizedCorrelation, [], [], CorrelationTrainer, **kwargs + ) def predict_correlation( self, C: np.ndarray, individual_preds: bool = True, squared: bool = True ) -> Union[np.ndarray, List[np.ndarray], None]: - """ - Returns per-sample correlation matrices (or squared correlations). - - DDP behavior: - - All ranks must execute the full per-model predict loop to avoid collective mismatches. 
- - rank0 returns arrays - - non-rank0 returns None (rank-0-only trainer outputs are propagated) + """Predicts context-specific correlations between features. + + Notes: + Under DDP, only rank-0 returns numpy outputs. If any per-model prediction returns + None (rank-0-only behavior), this method returns None. + + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + individual_preds (bool, optional): If True, returns the predictions for each + bootstrap. Defaults to True. + squared (bool, optional): If True, returns the squared correlations. Defaults to True. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: + The predicted context-specific correlation matrices, or matrices for each + bootstrap if individual_preds is True (n_samples, n_features, n_features). + Returns None on non-rank-0 under DDP. """ C_scaled = self._maybe_scale_C(C) Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) @@ -144,28 +206,30 @@ def predict_correlation( train_batch_size=self._init_kwargs["data"].get("train_batch_size", 16), val_batch_size=self._init_kwargs["data"].get("val_batch_size", 16), test_batch_size=self._init_kwargs["data"].get("test_batch_size", 16), - predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", 16), + predict_batch_size=self._init_kwargs["data"].get( + "predict_batch_size", 16 + ), num_workers=self._init_kwargs["data"].get("num_workers", 0), pin_memory=self._init_kwargs["data"].get( "pin_memory", (self.accelerator in ("cuda", "gpu")) ), - persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + persistent_workers=self._init_kwargs["data"].get( + "persistent_workers", False + ), drop_last=False, shuffle_train=False, shuffle_eval=False, dtype=self._init_kwargs["data"].get("dtype", torch.float), ), - task_type="singletask_univariate", # correlation uses univariate convention + task_type="singletask_univariate", ) - # FIX (1): setup before calling predict_dataloader() 
when not using Trainer.predict(datamodule=...) dm.setup(stage="predict") pred_loader = dm.predict_dataloader() saw_none = False - rhos_list = [] + rhos_list: List[np.ndarray] = [] - # FIX (2): call predict for all models on all ranks; only rank0 accumulates results for i in range(len(self.models)): rho_i = self.trainers[i].predict_correlation(self.models[i], pred_loader) if rho_i is None: @@ -179,17 +243,35 @@ def predict_correlation( rhos = np.array(rhos_list) if individual_preds: - return np.square(rhos) if squared else rhos + if squared: + return np.square(rhos) + return rhos mean_rhos = np.mean(rhos, axis=0) - return np.square(mean_rhos) if squared else mean_rhos + if squared: + return np.square(mean_rhos) + return mean_rhos def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False ) -> Union[np.ndarray, List[np.ndarray], None]: - """ - Measures mean-squared reconstruction errors between true X and reconstructed X_hat. - (Behavior unchanged; this already handles N_hat != N_true.) + """Measures mean-squared errors. + + Notes: + This method computes MSEs from reconstructions returned by predict_X, including + handling potential (bootstrap, sample, feature) or (bootstrap, sample, feature, feature) + tensor shapes, and handling N_hat != N_true by truncation to min(N_hat, N_true). + + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + X (np.ndarray): The data matrix (n_samples, n_features). + individual_preds (bool, optional): If True, returns the MSEs for each bootstrap. + Defaults to False. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: + The mean-squared errors for each sample, or for each bootstrap if + individual_preds is True (n_samples). Returns None on non-rank-0 under DDP. 
""" X_hat = self.predict_X(C, X, individual_preds=True) if X_hat is None: @@ -243,12 +325,22 @@ def measure_mses( residuals = X_hat - X_true mses = (residuals**2).mean(axis=(-1, -2)) - return mses if individual_preds else mses.mean(axis=0) + if individual_preds: + return mses + return mses.mean(axis=0) class ContextualizedMarkovNetworks(ContextualizedNetworks): """ - Contextualized Markov Networks (Gaussian precision matrices). + Contextualized Markov Networks reveal context-varying feature dependencies, cliques, + and modules. + + Implemented as Contextualized Gaussian Precision Matrices, directly interpretable as + Markov Networks. + + Notes: + This implementation includes CPU/DDP-safe prediction behavior analogous to + ContextualizedCorrelationNetworks.predict_correlation. """ def __init__(self, **kwargs): @@ -257,13 +349,22 @@ def __init__(self, **kwargs): def predict_precisions( self, C: np.ndarray, individual_preds: bool = True ) -> Union[np.ndarray, List[np.ndarray], None]: - """ - Predicts context-specific precision matrices. - - DDP behavior: - - All ranks must execute the full per-model predict loop to avoid collective mismatches. - - rank0 returns arrays - - non-rank0 returns None (rank-0-only trainer outputs are propagated) + """Predicts context-specific precision matrices. + + Notes: + Under DDP, only rank-0 returns numpy outputs. If any per-model prediction returns + None (rank-0-only behavior), this method returns None. + + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + individual_preds (bool, optional): If True, returns the predictions for each + bootstrap. Defaults to True. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: + The predicted context-specific precision matrices, or matrices for each + bootstrap if individual_preds is True (n_samples, n_features, n_features). + Returns None on non-rank-0 under DDP. 
""" C_scaled = self._maybe_scale_C(C) Y_zero = np.zeros((len(C_scaled), self.x_dim), dtype=np.float32) @@ -277,12 +378,16 @@ def predict_precisions( train_batch_size=self._init_kwargs["data"].get("train_batch_size", 16), val_batch_size=self._init_kwargs["data"].get("val_batch_size", 16), test_batch_size=self._init_kwargs["data"].get("test_batch_size", 16), - predict_batch_size=self._init_kwargs["data"].get("predict_batch_size", 16), + predict_batch_size=self._init_kwargs["data"].get( + "predict_batch_size", 16 + ), num_workers=self._init_kwargs["data"].get("num_workers", 0), pin_memory=self._init_kwargs["data"].get( "pin_memory", (self.accelerator in ("cuda", "gpu")) ), - persistent_workers=self._init_kwargs["data"].get("persistent_workers", False), + persistent_workers=self._init_kwargs["data"].get( + "persistent_workers", False + ), drop_last=False, shuffle_train=False, shuffle_eval=False, @@ -291,14 +396,12 @@ def predict_precisions( task_type="singletask_univariate", ) - # FIX (1): setup before calling predict_dataloader() dm.setup(stage="predict") pred_loader = dm.predict_dataloader() saw_none = False - prec_list = [] + prec_list: List[np.ndarray] = [] - # FIX (2): call predict for all models on all ranks; only rank0 accumulates results for i in range(len(self.models)): p_i = self.trainers[i].predict_precision(self.models[i], pred_loader) if p_i is None: @@ -310,42 +413,76 @@ def predict_precisions( return None precisions = np.array(prec_list) - return precisions if individual_preds else np.mean(precisions, axis=0) + if individual_preds: + return precisions + return np.mean(precisions, axis=0) def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False ) -> Union[np.ndarray, List[np.ndarray], None]: - """ - Measures mean-squared reconstruction errors using precision-implied betas/mus. + """Measures mean-squared errors. + + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). 
+ X (np.ndarray): The data matrix (n_samples, n_features). + individual_preds (bool, optional): If True, returns the MSEs for each bootstrap. + Defaults to False. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: + The mean-squared errors for each sample, or for each bootstrap if + individual_preds is True (n_samples). Returns None on non-rank-0 under DDP. """ out = self.predict_networks(C, individual_preds=True, with_offsets=True) if out is None: return None betas, mus = out - mses = np.zeros((len(betas), len(C))) + mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples F = X.shape[-1] for b in range(len(betas)): for i in range(F): preds = np.array( - [X[j].dot(betas[b, j, i, :]) + mus[b, j, i] for j in range(len(X))] + [ + X[j].dot(betas[b, j, i, :]) + mus[b, j, i] + for j in range(len(X)) + ] ) residuals = X[:, i] - preds mses[b, :] += residuals**2 / F - return mses if individual_preds else np.mean(mses, axis=0) + + if individual_preds: + return mses + return np.mean(mses, axis=0) class ContextualizedBayesianNetworks(ContextualizedNetworks): """ - Contextualized Bayesian Networks (NOTMAD): context-dependent DAGs. + Contextualized Bayesian Networks and Directed Acyclic Graphs (DAGs) reveal + context-dependent causal relationships, effect sizes, and variable ordering. + + Uses the NOTMAD model. + + Notes: + This wrapper preserves the HPC/DDP behavior: rank-0 produces arrays, non-rank-0 + returns None where applicable. """ def _parse_private_init_kwargs(self, **kwargs): + """Parses the kwargs for the NOTMAD model. + + Args: + **kwargs: Keyword arguments for the NOTMAD model, including the encoder, + archetype loss, sample-specific loss, and optimization parameters. + + Returns: + List[str]: Names of kwargs consumed/handled by this parser. """ - Parse NOTMAD kwargs into model init dicts. 
- """ + # Encoder Parameters self._init_kwargs["model"]["encoder_kwargs"] = { - "type": kwargs.pop("encoder_type", self._init_kwargs["model"]["encoder_type"]), + "type": kwargs.pop( + "encoder_type", self._init_kwargs["model"]["encoder_type"] + ), "params": { "width": self.constructor_kwargs["encoder_kwargs"]["width"], "layers": self.constructor_kwargs["encoder_kwargs"]["layers"], @@ -353,7 +490,10 @@ def _parse_private_init_kwargs(self, **kwargs): }, } - archetype_dag_loss_type = kwargs.pop("archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE) + # Archetype-specific parameters + archetype_dag_loss_type = kwargs.pop( + "archetype_dag_loss_type", DEFAULT_DAG_LOSS_TYPE + ) self._init_kwargs["model"]["archetype_loss_params"] = { "l1": kwargs.get("archetype_l1", 0.0), "dag": kwargs.get( @@ -371,19 +511,22 @@ def _parse_private_init_kwargs(self, **kwargs): "factor_mat_l1": kwargs.pop("factor_mat_l1", 0), "num_archetypes": kwargs.pop("num_archetypes", 16), } + if self._init_kwargs["model"]["archetype_loss_params"]["num_archetypes"] <= 0: print( "WARNING: num_archetypes is 0. NOTMAD requires archetypes. Setting num_archetypes to 16." 
) self._init_kwargs["model"]["archetype_loss_params"]["num_archetypes"] = 16 + # Possibly update values with convenience parameters for param, value in self._init_kwargs["model"]["archetype_loss_params"]["dag"][ "params" ].items(): - self._init_kwargs["model"]["archetype_loss_params"]["dag"]["params"][param] = ( - kwargs.pop(f"archetype_{param}", value) - ) + self._init_kwargs["model"]["archetype_loss_params"]["dag"]["params"][ + param + ] = kwargs.pop(f"archetype_{param}", value) + # Sample-specific parameters sample_specific_dag_loss_type = kwargs.pop( "sample_specific_dag_loss_type", DEFAULT_DAG_LOSS_TYPE ) @@ -395,18 +538,23 @@ def _parse_private_init_kwargs(self, **kwargs): "loss_type": sample_specific_dag_loss_type, "params": kwargs.pop( "sample_specific_dag_loss_params", - DEFAULT_DAG_LOSS_PARAMS[sample_specific_dag_loss_type].copy(), + DEFAULT_DAG_LOSS_PARAMS[ + sample_specific_dag_loss_type + ].copy(), ), }, ), } - for param, value in self._init_kwargs["model"]["sample_specific_loss_params"]["dag"][ - "params" - ].items(): - self._init_kwargs["model"]["sample_specific_loss_params"]["dag"]["params"][param] = ( - kwargs.pop(f"sample_specific_{param}", value) - ) + # Possibly update values with convenience parameters + for param, value in self._init_kwargs["model"]["sample_specific_loss_params"][ + "dag" + ]["params"].items(): + self._init_kwargs["model"]["sample_specific_loss_params"]["dag"]["params"][ + param + ] = kwargs.pop(f"sample_specific_{param}", value) + + # Optimization parameters self._init_kwargs["model"]["opt_params"] = { "learning_rate": kwargs.pop("learning_rate", 1e-3), "step": kwargs.pop("step", 50), @@ -458,34 +606,74 @@ def __init__(self, **kwargs): def predict_params( self, C: np.ndarray, **kwargs ) -> Union[np.ndarray, List[np.ndarray], None]: + """Predicts context-specific Bayesian network parameters as linear coefficients + in a linear structural equation model (SEM). 
+ + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + **kwargs: Keyword arguments for contextualized.dags.GraphTrainer.predict_params. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: + The linear coefficients of the predicted context-specific Bayesian network + parameters (n_samples, n_features, n_features). Returned as lists of + individual bootstraps if individual_preds is True. Returns None on + non-rank-0 under DDP. """ - Predicts context-specific Bayesian network parameters (SEM coefficients). - """ + # No mus for NOTMAD at present. return super().predict_params(C, model_includes_mus=False, **kwargs) def predict_networks( self, C: np.ndarray, project_to_dag: bool = True, **kwargs ) -> Union[np.ndarray, List[np.ndarray], None]: - """ - Predicts context-specific Bayesian networks (optionally projected to DAG). + """Predicts context-specific Bayesian networks. + + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + project_to_dag (bool, optional): If True, guarantees returned graphs are DAGs by + trimming edges until acyclicity is satisified. Defaults to True. + **kwargs: Keyword arguments for contextualized.dags.GraphTrainer.predict_params. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: + The linear coefficients of the predicted context-specific Bayesian network + parameters (n_samples, n_features, n_features). Returned as lists of + individual bootstraps if individual_preds is True. Returns None on + non-rank-0 under DDP. 
""" if kwargs.pop("with_offsets", False): print("No offsets can be returned by NOTMAD.") - betas = self.predict_params(C, uses_y=False, project_to_dag=project_to_dag, **kwargs) + betas = self.predict_params( + C, uses_y=False, project_to_dag=project_to_dag, **kwargs + ) return betas def measure_mses( self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs ) -> Union[np.ndarray, List[np.ndarray], None]: - """ - Measures mean-squared errors of DAG-based reconstruction. + """Measures mean-squared errors. + + Args: + C (np.ndarray): Contextual features for each sample (n_samples, n_context_features). + X (np.ndarray): The data matrix (n_samples, n_features). + individual_preds (bool, optional): If True, returns the MSEs for each bootstrap. + Defaults to False. + **kwargs: Keyword arguments for contextualized.dags.GraphTrainer.predict_params. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: + The mean-squared errors for each sample, or for each bootstrap if + individual_preds is True (n_samples). Returns None on non-rank-0 under DDP. 
""" betas = self.predict_networks(C, individual_preds=True, **kwargs) if betas is None: return None - mses = np.zeros((len(betas), len(C))) + mses = np.zeros((len(betas), len(C))) # n_bootstraps x n_samples for b in range(len(betas)): X_pred = dag_pred_np(X, betas[b]) mses[b, :] = np.mean((X - X_pred) ** 2, axis=1) - return mses if individual_preds else np.mean(mses, axis=0) + + if individual_preds: + return mses + return np.mean(mses, axis=0) diff --git a/contextualized/easy/ContextualizedRegressor.py b/contextualized/easy/ContextualizedRegressor.py index 932fa971..8e1a350f 100644 --- a/contextualized/easy/ContextualizedRegressor.py +++ b/contextualized/easy/ContextualizedRegressor.py @@ -7,7 +7,7 @@ ContextualizedRegression, ) from contextualized.easy.wrappers import SKLearnWrapper -from contextualized.regression.trainers import RegressionTrainer # <-- updated import +from contextualized.regression.trainers import RegressionTrainer class ContextualizedRegressor(SKLearnWrapper): @@ -32,17 +32,14 @@ def __init__(self, **kwargs): elif self.num_archetypes > 0: constructor = ContextualizedRegression else: - print( - f""" - Was told to construct a ContextualizedRegressor with {self.num_archetypes} - archetypes, but this should be a non-negative integer.""" + raise ValueError( + f"num_archetypes must be a non-negative integer, got {self.num_archetypes}." ) - # Wrapper will accept these; no need to expose DataModule specifics here. 
+ extra_model_kwargs = ["base_param_predictor", "base_y_predictor", "y_dim"] extra_data_kwargs = ["Y_val"] trainer_constructor = RegressionTrainer - super().__init__( constructor, extra_model_kwargs, @@ -51,6 +48,5 @@ def __init__(self, **kwargs): **kwargs, ) - # Preserve legacy behavior that Y is expected/required for regression fits def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs): return super()._split_train_data(C, X, Y, Y_required=True, **kwargs) \ No newline at end of file diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index efee1506..12792e3c 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -1,13 +1,5 @@ """ An sklearn-like wrapper for Contextualized models. - -Design goals (compat + correctness): -- Preserve prior public API: fit(), predict(), predict_params(), kwarg routing, normalization. -- Default to the DDP-safe, map-style ContextualizedRegressionDataModule when available. -- Avoid redundant DDP gather/ordering logic in the wrapper: - * DDP-safe predict assembly is handled by contextualized.regression.trainers.RegressionTrainer - (predict_y / predict_params) together with lightning_modules.py predict_step payloads. -- Keep legacy compatibility for older models that still expose `model.dataloader(...)`. 
""" import copy @@ -17,9 +9,9 @@ import numpy as np import torch import torch.distributed as dist -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.strategies import DDPStrategy +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks.early_stopping import EarlyStopping +from lightning.pytorch.strategies import DDPStrategy from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler @@ -29,15 +21,15 @@ # Prefer the new, DDP-safe DataModule path when available. try: from contextualized.regression.datamodules import ContextualizedRegressionDataModule -except Exception: # pragma: no cover - ContextualizedRegressionDataModule = None # type: ignore +except Exception: + ContextualizedRegressionDataModule = None DEFAULT_LEARNING_RATE = 1e-3 DEFAULT_N_BOOTSTRAPS = 1 DEFAULT_ES_PATIENCE = 1 DEFAULT_VAL_BATCH_SIZE = 16 -DEFAULT_TRAIN_BATCH_SIZE = 1 # keep legacy default +DEFAULT_TRAIN_BATCH_SIZE = 1 DEFAULT_TEST_BATCH_SIZE = 16 DEFAULT_VAL_SPLIT = 0.2 DEFAULT_ENCODER_TYPE = "mlp" @@ -74,15 +66,22 @@ class SKLearnWrapper: Args: base_constructor (callable/class): LightningModule constructor for the model. - extra_model_kwargs (list[str] or set[str]): extra kw names allowed in "model". - extra_data_kwargs (list[str] or set[str]): extra kw names allowed in "data". + extra_model_kwargs (list[str] or set[str]): Extra kw names allowed in "model". + extra_data_kwargs (list[str] or set[str]): Extra kw names allowed in "data". trainer_constructor (class): Trainer class (should provide predict_y / predict_params for DDP-safe inference). - **kwargs: routed into model/data/trainer/wrapper based on acceptable_kwargs. + n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. + encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". 
+ loss_fn (torch.nn.Module, optional): Loss function. Defaults to LOSSES["mse"]. + link_fn (torch.nn.Module, optional): Link function. Defaults to LINK_FUNCTIONS["identity"]. + alpha (float, optional): Regularization strength. Defaults to 0.0. + mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to + context-specific parameters or context-specific offsets. + l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 + vs l2 parameter norms. + normalize (bool, optional): If True, automatically standardize inputs during training and inverse-transform + predictions. Defaults to False. """ - # ---------------------------- - # Defaults / initialization - # ---------------------------- def _set_defaults(self) -> None: self.default_learning_rate = DEFAULT_LEARNING_RATE self.default_n_bootstraps = DEFAULT_N_BOOTSTRAPS @@ -110,14 +109,11 @@ def __init__( self.base_constructor = base_constructor self.trainer_constructor = trainer_constructor - # Optional: allow callers to pass default trainer kwargs in a single dict self._trainer_init_kwargs = kwargs.pop("trainer_kwargs", None) self.n_bootstraps: int = 1 self.models: Optional[List[Any]] = None self.trainers: Optional[List[Any]] = None - - # Keep legacy attribute for external users who expect it self.dataloaders: Optional[Dict[str, List[Any]]] = None self.normalize: bool = bool(kwargs.pop("normalize", self.default_normalize)) @@ -127,10 +123,8 @@ def __init__( self.x_dim: Optional[int] = None self.y_dim: Optional[int] = None - # Lightning expects "gpu" / "cpu" (legacy wrapper used "gpu") self.accelerator: str = "gpu" if torch.cuda.is_available() else "cpu" - # Expanded routing (superset of legacy); safe for backward compatibility. 
self.acceptable_kwargs: Dict[str, List[str]] = { "data": [ "train_batch_size", @@ -161,7 +155,6 @@ def __init__( "context_dim", "x_dim", "y_dim", - # legacy-friendly knobs "width", "layers", "encoder_link_fn", @@ -199,8 +192,12 @@ def __init__( self._update_acceptable_kwargs("model", extra_model_kwargs) self._update_acceptable_kwargs("data", extra_data_kwargs) - self._update_acceptable_kwargs("model", kwargs.pop("remove_model_kwargs", []), acceptable=False) - self._update_acceptable_kwargs("data", kwargs.pop("remove_data_kwargs", []), acceptable=False) + self._update_acceptable_kwargs( + "model", kwargs.pop("remove_model_kwargs", []), acceptable=False + ) + self._update_acceptable_kwargs( + "data", kwargs.pop("remove_data_kwargs", []), acceptable=False + ) self.convenience_kwargs = [ "alpha", @@ -212,64 +209,84 @@ def __init__( "encoder_link_fn", ] - # Build model-constructor defaults (and allow legacy + new encoder_kwargs styles) self.constructor_kwargs = self._organize_constructor_kwargs(**kwargs) - # Apply convenience overrides (legacy keys) if "encoder_kwargs" in self.constructor_kwargs: ek = self.constructor_kwargs["encoder_kwargs"] ek["width"] = kwargs.pop("width", ek.get("width", self.default_encoder_width)) ek["layers"] = kwargs.pop("layers", ek.get("layers", self.default_encoder_layers)) - ek["link_fn"] = kwargs.pop("encoder_link_fn", ek.get("link_fn", self.default_encoder_link_fn)) + ek["link_fn"] = kwargs.pop( + "encoder_link_fn", ek.get("link_fn", self.default_encoder_link_fn) + ) else: - self.constructor_kwargs["width"] = kwargs.pop("width", self.constructor_kwargs.get("width", self.default_encoder_width)) - self.constructor_kwargs["layers"] = kwargs.pop("layers", self.constructor_kwargs.get("layers", self.default_encoder_layers)) + self.constructor_kwargs["width"] = kwargs.pop( + "width", self.constructor_kwargs.get("width", self.default_encoder_width) + ) + self.constructor_kwargs["layers"] = kwargs.pop( + "layers", 
self.constructor_kwargs.get("layers", self.default_encoder_layers) + ) self.constructor_kwargs["encoder_link_fn"] = kwargs.pop( "encoder_link_fn", self.constructor_kwargs.get("encoder_link_fn", self.default_encoder_link_fn), ) - # Store remaining kwargs to be organized by router self.not_constructor_kwargs = { - k: v for k, v in kwargs.items() if k not in self.constructor_kwargs and k not in self.convenience_kwargs + k: v + for k, v in kwargs.items() + if k not in self.constructor_kwargs and k not in self.convenience_kwargs } self._init_kwargs, unrecognized = self._organize_kwargs(**self.not_constructor_kwargs) - # Inject constructor kwargs into model bucket for k, v in self.constructor_kwargs.items(): self._init_kwargs["model"][k] = v - # Inject trainer init kwargs if isinstance(self._trainer_init_kwargs, dict): self._init_kwargs["trainer"].update(self._trainer_init_kwargs) - # Allow subclasses to swallow additional init kwargs recognized_private = set(self._parse_private_init_kwargs(**kwargs)) for kw in unrecognized: if kw not in recognized_private: print(f"Received unknown keyword argument {kw}, probably ignoring.") - # ---------------------------- - # Hooks for subclasses - # ---------------------------- def _parse_private_fit_kwargs(self, **kwargs) -> List[str]: + """ + Parse private (model-specific) kwargs passed to fit function. + Return the list of parsed kwargs. + """ return [] def _parse_private_init_kwargs(self, **kwargs) -> List[str]: + """ + Parse private (model-specific) kwargs passed to constructor. + Return the list of parsed kwargs. + """ return [] - # ---------------------------- - # Kwarg routing / organization - # ---------------------------- - def _update_acceptable_kwargs(self, category, new_kwargs, acceptable: bool = True) -> None: + def _update_acceptable_kwargs( + self, category, new_kwargs, acceptable: bool = True + ) -> None: + """ + Helper function to update the acceptable kwargs. 
+ + If acceptable=True, the new kwargs will be added to the list of acceptable kwargs. + If acceptable=False, the new kwargs will be removed from the list of acceptable kwargs. + """ new_kwargs = list(new_kwargs) if new_kwargs is not None else [] if acceptable: - self.acceptable_kwargs[category] = list(set(self.acceptable_kwargs[category]).union(set(new_kwargs))) + self.acceptable_kwargs[category] = list( + set(self.acceptable_kwargs[category]).union(set(new_kwargs)) + ) else: - self.acceptable_kwargs[category] = list(set(self.acceptable_kwargs[category]) - set(new_kwargs)) + self.acceptable_kwargs[category] = list( + set(self.acceptable_kwargs[category]) - set(new_kwargs) + ) def _organize_kwargs(self, **kwargs) -> Tuple[Dict[str, Dict[str, Any]], List[str]]: + """ + Private helper function to organize kwargs passed to constructor or fit function. + Organizes kwargs into data, model, trainer, fit, and wrapper categories. + """ out = {cat: {} for cat in self.acceptable_kwargs} unknown: List[str] = [] for k, v in kwargs.items(): @@ -285,9 +302,7 @@ def _organize_kwargs(self, **kwargs) -> Tuple[Dict[str, Dict[str, Any]], List[st def _organize_constructor_kwargs(self, **kwargs) -> Dict[str, Any]: """ - Create default model constructor kwargs, supporting both: - - new style: encoder_kwargs={width,layers,link_fn} - - legacy style: width/layers/encoder_link_fn top-level + Helper function to set all the default constructor kwargs or changes allowed. 
""" ctor: Dict[str, Any] = {} @@ -300,7 +315,6 @@ def maybe_add(kw, default_val): maybe_add("encoder_type", self.default_encoder_type) maybe_add("loss_fn", LOSSES["mse"]) - # Prefer new style if allowed if "encoder_kwargs" in self.acceptable_kwargs["model"]: ctor["encoder_kwargs"] = kwargs.get( "encoder_kwargs", @@ -319,7 +333,6 @@ def maybe_add(kw, default_val): if kwargs.get("subtype_probabilities", False): ctor["encoder_link_fn"] = LINK_FUNCTIONS["softmax"] - # Regularizer if "model_regularizer" in self.acceptable_kwargs["model"]: alpha = float(kwargs.get("alpha", 0.0) or 0.0) if alpha > 0: @@ -329,13 +342,12 @@ def maybe_add(kw, default_val): kwargs.get("mu_ratio", 0.5), ) else: - ctor["model_regularizer"] = kwargs.get("model_regularizer", REGULARIZERS["none"]) + ctor["model_regularizer"] = kwargs.get( + "model_regularizer", REGULARIZERS["none"] + ) return ctor - # ---------------------------- - # Utilities - # ---------------------------- def _maybe_scale_C(self, C: np.ndarray) -> np.ndarray: if self.normalize and self.scalers["C"] is not None: return self.scalers["C"].transform(C) @@ -352,7 +364,9 @@ def _nanrobust_mean(self, arr: np.ndarray, axis: int = 0) -> np.ndarray: with np.errstate(invalid="ignore"): mean = np.nanmean(arr, axis=axis) if np.isnan(mean).any(): - raise RuntimeError("All bootstraps produced non-finite predictions for some items.") + raise RuntimeError( + "All bootstraps produced non-finite predictions for some items." 
+ ) return mean def _default_num_workers(self, devices: int) -> int: @@ -373,7 +387,6 @@ def _safe_val_split(self, n: int, val_split: float) -> float: vs = float(val_split) if vs <= 0.0: return 0.0 - # require at least 2 validation samples for stable metrics if int(round(n * vs)) < 2: return 0.0 return vs @@ -392,19 +405,15 @@ def _resolve_train_val_arrays( random_state: int = 42, shuffle: bool = True, ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]: - """ - Returns: - C_all, X_all, Y_all, train_idx, val_idx_or_None - - Supports: - - val arrays (C_val/X_val[/Y_val]) by concatenation - - otherwise uses val_split indices inside the original arrays - """ - if C_val is not None and X_val is not None and (not Y_required or Y_val is not None): - # concatenate and build disjoint train/val indices + if ( + C_val is not None + and X_val is not None + and (not Y_required or Y_val is not None) + ): n_tr = int(C.shape[0]) C_all = np.concatenate([C, C_val], axis=0) X_all = np.concatenate([X, X_val], axis=0) + if Y is None: Y_all = None else: @@ -416,7 +425,6 @@ def _resolve_train_val_arrays( val_idx = np.arange(n_tr, int(C_all.shape[0])) return C_all, X_all, Y_all, train_idx, val_idx - # split by indices within the same arrays n = int(C.shape[0]) vs = self._safe_val_split(n, val_split) if vs <= 0.0: @@ -426,7 +434,7 @@ def _resolve_train_val_arrays( np.arange(n), test_size=vs, shuffle=shuffle, - random_state=random_state, # fixed seed to keep DDP ranks consistent + random_state=random_state, ) return C, X, Y, tr_idx, va_idx @@ -444,7 +452,9 @@ def _build_datamodule( task_type: str, ): if ContextualizedRegressionDataModule is None: - raise RuntimeError("ContextualizedRegressionDataModule is not available in this installation.") + raise RuntimeError( + "ContextualizedRegressionDataModule is not available in this installation." 
+ ) dk = { "train_batch_size": self.default_train_batch_size, @@ -484,32 +494,28 @@ def _build_datamodule( ) def _use_datamodule_for_model(self, model: Any) -> bool: - # Prefer DataModule when available and model doesn't provide legacy dataloader(). if ContextualizedRegressionDataModule is None: return False return not callable(getattr(model, "dataloader", None)) - # ---------------------------- - # Fit kwargs expansion - # ---------------------------- def _organize_and_expand_fit_kwargs(self, **kwargs) -> Dict[str, Dict[str, Any]]: + """ + Private function to organize kwargs passed to constructor or fit function. + """ organized, unrecognized = self._organize_kwargs(**kwargs) recognized_private = set(self._parse_private_fit_kwargs(**kwargs)) for kw in unrecognized: if kw not in recognized_private: print(f"Received unknown keyword argument {kw}, probably ignoring.") - # Merge init defaults (fit kwargs win) for category, cat_kwargs in self._init_kwargs.items(): for k, v in cat_kwargs.items(): organized[category].setdefault(k, v) - # Helper def maybe_add(cat: str, k: str, default_val: Any) -> None: if k in self.acceptable_kwargs[cat]: organized[cat][k] = organized[cat].get(k, default_val) - # Model dims / lr maybe_add("model", "learning_rate", self.default_learning_rate) maybe_add("model", "context_dim", self.context_dim) maybe_add("model", "x_dim", self.x_dim) @@ -518,22 +524,27 @@ def maybe_add(cat: str, k: str, default_val: Any) -> None: if organized["model"].get("num_archetypes", 1) == 0: organized["model"].pop("num_archetypes", None) - # Data defaults maybe_add("data", "train_batch_size", self.default_train_batch_size) maybe_add("data", "val_batch_size", self.default_val_batch_size) maybe_add("data", "test_batch_size", self.default_test_batch_size) - maybe_add("data", "predict_batch_size", organized["data"].get("val_batch_size", self.default_val_batch_size)) + maybe_add( + "data", + "predict_batch_size", + organized["data"].get("val_batch_size", 
self.default_val_batch_size), + ) - # Trainer defaults maybe_add("trainer", "accelerator", self.accelerator) organized["trainer"].setdefault("enable_progress_bar", False) organized["trainer"].setdefault("logger", False) organized["trainer"].setdefault("num_sanity_val_steps", 0) - # devices/strategy defaults for torchrun/DDP safety world = _world_size_env() + launched_externally = world > 1 and ( + os.environ.get("LOCAL_RANK") is not None or os.environ.get("RANK") is not None + ) + if "devices" not in organized["trainer"]: - organized["trainer"]["devices"] = world if world > 1 else 1 + organized["trainer"]["devices"] = 1 if launched_externally else (world if world > 1 else 1) devices_cfg = organized["trainer"].get("devices", 1) if isinstance(devices_cfg, int): @@ -543,12 +554,16 @@ def maybe_add(cat: str, k: str, default_val: Any) -> None: else: devices = 1 - if world > 1 and devices != world: + if world > 1 and (not launched_externally) and devices != world: if _is_main_process(): - print(f"[WARNING] WORLD_SIZE={world} but devices={devices}; overriding devices -> {world}.") + print( + f"[WARNING] WORLD_SIZE={world} but devices={devices}; " + f"overriding devices -> {world}." 
+ ) organized["trainer"]["devices"] = world devices = world + if "strategy" not in organized["trainer"]: if devices > 1 or world > 1: organized["trainer"]["strategy"] = DDPStrategy( @@ -559,32 +574,33 @@ def maybe_add(cat: str, k: str, default_val: Any) -> None: else: organized["trainer"]["strategy"] = "auto" - # Performance defaults if self.accelerator == "gpu": organized["trainer"].setdefault("precision", "16-mixed") else: organized["trainer"].setdefault("precision", 32) - # DataLoader perf defaults maybe_add("data", "num_workers", self._default_num_workers(devices)) maybe_add("data", "pin_memory", self.accelerator == "gpu") - maybe_add("data", "persistent_workers", organized["data"].get("num_workers", 0) > 0) + maybe_add( + "data", + "persistent_workers", + organized["data"].get("num_workers", 0) > 0, + ) maybe_add("data", "drop_last", (devices > 1 or world > 1)) maybe_add("data", "shuffle_train", True) maybe_add("data", "shuffle_eval", False) maybe_add("data", "dtype", torch.float) - # Wrapper defaults maybe_add("wrapper", "n_bootstraps", self.default_n_bootstraps) - # Validation split val_split = float(organized["data"].get("val_split", self.default_val_split)) organized["data"]["val_split"] = val_split - # Callbacks: preserve legacy behavior (EarlyStopping + ModelCheckpoint by default) - use_val = self._safe_val_split(10, val_split) > 0.0 # placeholder check; refined at fit time + use_val = self._safe_val_split(10, val_split) > 0.0 es_patience = organized["wrapper"].get("es_patience", self.default_es_patience) - es_monitor = organized["wrapper"].get("es_monitor", "val_loss" if use_val else "train_loss") + es_monitor = organized["wrapper"].get( + "es_monitor", "val_loss" if use_val else "train_loss" + ) es_mode = organized["wrapper"].get("es_mode", "min") es_verbose = organized["wrapper"].get("es_verbose", False) es_min_delta = organized["wrapper"].get("es_min_delta", 0.0) @@ -593,10 +609,8 @@ def maybe_add(cat: str, k: str, default_val: Any) -> None: if 
cb_ctors is None: cb_ctors = [] - # Default: enable checkpointing unless explicitly disabled organized["trainer"].setdefault("enable_checkpointing", True) - # Add EarlyStopping only if patience > 0 if es_patience is not None and int(es_patience) > 0: cb_ctors.append( lambda i: EarlyStopping( @@ -620,18 +634,25 @@ def maybe_add(cat: str, k: str, default_val: Any) -> None: organized["trainer"]["callback_constructors"] = cb_ctors return organized - # ---------------------------- - # Public API - # ---------------------------- def fit(self, *args, **kwargs) -> None: """ Fit contextualized model to data. - Backward compatible with legacy: - - fit(C, X) -> uses X as targets (Contextualized Networks behavior) - - fit(C, X, Y) - - fit(..., Y=...) override - - supports C_val/X_val/Y_val and/or val_split + Args: + C (np.ndarray): Context array of shape (n_samples, n_context_features) + X (np.ndarray): Predictor array of shape (N, n_features) + Y (np.ndarray, optional): Target array of shape (N, n_targets). Defaults to None. + max_epochs (int, optional): Maximum number of epochs to train for. Defaults to 1. + learning_rate (float, optional): Learning rate for optimizer. Defaults to 1e-3. + val_split (float, optional): Proportion of data to use for validation and early stopping. Defaults to 0.2. + n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. + train_batch_size (int, optional): Batch size for training. Defaults to 1. + val_batch_size (int, optional): Batch size for validation. Defaults to 16. + test_batch_size (int, optional): Batch size for testing. Defaults to 16. + es_patience (int, optional): Number of epochs to wait before early stopping. Defaults to 1. + es_monitor (str, optional): Metric to monitor for early stopping. Defaults to "val_loss". + es_mode (str, optional): Mode for early stopping. Defaults to "min". + es_verbose (bool, optional): Whether to print early stopping updates. Defaults to False. 
""" self.models, self.trainers = [], [] self.dataloaders = {"train": [], "val": [], "test": []} @@ -656,7 +677,6 @@ def fit(self, *args, **kwargs) -> None: if Y is not None: Y = np.asarray(Y) - # Normalize / scale if self.normalize: if self.scalers["C"] is None: self.scalers["C"] = StandardScaler().fit(C) @@ -669,7 +689,6 @@ def fit(self, *args, **kwargs) -> None: self.context_dim = int(C.shape[-1]) self.x_dim = int(X.shape[-1]) - # Legacy semantics: if Y not provided, use X as targets if Y is None: Y = X else: @@ -677,10 +696,8 @@ def fit(self, *args, **kwargs) -> None: Y = np.expand_dims(Y, 1) if self.normalize and self.scalers["Y"] is not None: - # already fitted (e.g., multiple fits). keep behavior consistent. pass - # Scale Y if it's continuous (avoid scaling binary) if self.normalize and not np.array_equal(np.unique(Y), np.array([0, 1])): if self.scalers["Y"] is None: self.scalers["Y"] = StandardScaler().fit(Y) @@ -689,23 +706,27 @@ def fit(self, *args, **kwargs) -> None: self.y_dim = int(Y.shape[-1]) organized = self._organize_and_expand_fit_kwargs(**kwargs) - self.n_bootstraps = int(organized["wrapper"].get("n_bootstraps", self.n_bootstraps)) + self.n_bootstraps = int( + organized["wrapper"].get("n_bootstraps", self.n_bootstraps) + ) - # Determine val split now that we know n val_split = float(organized["data"].get("val_split", self.default_val_split)) val_split = self._safe_val_split(int(C.shape[0]), val_split) organized["data"]["val_split"] = val_split use_val = val_split > 0.0 - # If no val, retarget default monitors to train_loss (legacy had a try/except; we make it explicit) if not use_val: - # Adjust any callback_constructors that still monitor val_loss new_ctors = [] for ctor in organized["trainer"].get("callback_constructors", []): + def _wrap_ctor(_ctor): def _inner(i): cb = _ctor(i) - if isinstance(cb, EarlyStopping) and isinstance(getattr(cb, "monitor", ""), str) and cb.monitor.startswith("val_"): + if ( + isinstance(cb, EarlyStopping) + and 
isinstance(getattr(cb, "monitor", ""), str) + and cb.monitor.startswith("val_") + ): return EarlyStopping( monitor="train_loss", mode=getattr(cb, "mode", "min"), @@ -713,27 +734,27 @@ def _inner(i): verbose=getattr(cb, "verbose", False), min_delta=getattr(cb, "min_delta", 0.0), ) - if isinstance(cb, ModelCheckpoint) and isinstance(getattr(cb, "monitor", ""), str) and cb.monitor.startswith("val_"): - # If no validation, checkpointing on val_loss is meaningless; keep callback but disable monitor. - cb.monitor = None # PL will checkpoint on epoch end + if ( + isinstance(cb, ModelCheckpoint) + and isinstance(getattr(cb, "monitor", ""), str) + and cb.monitor.startswith("val_") + ): + cb.monitor = None return cb + return _inner + new_ctors.append(_wrap_ctor(ctor)) organized["trainer"]["callback_constructors"] = new_ctors - - # Also avoid running validation loops organized["trainer"].setdefault("limit_val_batches", 0) - # Optional explicit val arrays C_val = organized["data"].get("C_val", None) X_val = organized["data"].get("X_val", None) Y_val = organized["data"].get("Y_val", None) - # Task type univariate_flag = bool(organized["model"].get("univariate", False)) task_type = "singletask_univariate" if univariate_flag else "singletask_multivariate" - # Build final arrays + indices (supports separate val arrays by concatenation) C_all, X_all, Y_all, train_idx, val_idx = self._resolve_train_val_arrays( C, X, @@ -746,23 +767,19 @@ def _inner(i): ) for b in range(self.n_bootstraps): - # Construct model kwargs (do NOT pass wrapper-only keys) model_kwargs = dict(organized["model"]) - # univariate is a wrapper concern; base_constructor should already encode univariate vs multivariate model_kwargs.pop("univariate", None) model = self.base_constructor(**model_kwargs) use_dm = self._use_datamodule_for_model(model) - # Build trainer kwargs + callbacks trainer_kwargs = copy.deepcopy(organized["trainer"]) cb_ctors = trainer_kwargs.pop("callback_constructors", []) callbacks = 
list(trainer_kwargs.get("callbacks", [])) callbacks.extend([ctor(b) for ctor in cb_ctors]) trainer_kwargs["callbacks"] = callbacks - # Ensure checkpoint directories exist for cb in callbacks: if isinstance(cb, ModelCheckpoint): try: @@ -770,7 +787,6 @@ def _inner(i): except Exception: pass - # Construct trainer via factory that handles env/plugins correctly from contextualized.regression.trainers import make_trainer_with_env trainer = make_trainer_with_env(self.trainer_constructor, **trainer_kwargs) @@ -789,11 +805,13 @@ def _inner(i): ) if _is_main_process(): - print(f"[RANK {_rank()}] train_idx[:5]={train_idx[:5]}, val_idx[:5]={val_idx[:5] if val_idx is not None else None}") + print( + f"[RANK {_rank()}] train_idx[:5]={train_idx[:5]}, " + f"val_idx[:5]={val_idx[:5] if val_idx is not None else None}" + ) trainer.fit(model, datamodule=dm, **organized["fit"]) - # Keep dataloaders for compatibility (best-effort) try: dm.setup("fit") self.dataloaders["train"].append(dm.train_dataloader()) @@ -805,16 +823,35 @@ def _inner(i): self.dataloaders["test"].append(None) else: - # Legacy path: model provides dataloader(C, X, Y, batch_size=...) 
- train_data = [C_all[train_idx], X_all[train_idx], Y_all[train_idx]] if Y_all is not None else [C_all[train_idx], X_all[train_idx]] + train_data = ( + [C_all[train_idx], X_all[train_idx], Y_all[train_idx]] + if Y_all is not None + else [C_all[train_idx], X_all[train_idx]] + ) + val_data = None if use_val and val_idx is not None: - val_data = [C_all[val_idx], X_all[val_idx], Y_all[val_idx]] if Y_all is not None else [C_all[val_idx], X_all[val_idx]] + val_data = ( + [C_all[val_idx], X_all[val_idx], Y_all[val_idx]] + if Y_all is not None + else [C_all[val_idx], X_all[val_idx]] + ) + + train_dl = model.dataloader( + *train_data, + batch_size=organized["data"].get( + "train_batch_size", self.default_train_batch_size + ), + ) - train_dl = model.dataloader(*train_data, batch_size=organized["data"].get("train_batch_size", self.default_train_batch_size)) val_dl = None if val_data is not None: - val_dl = model.dataloader(*val_data, batch_size=organized["data"].get("val_batch_size", self.default_val_batch_size)) + val_dl = model.dataloader( + *val_data, + batch_size=organized["data"].get( + "val_batch_size", self.default_val_batch_size + ), + ) try: trainer.fit(model, train_dl, val_dl, **organized["fit"]) @@ -825,8 +862,10 @@ def _inner(i): self.dataloaders["val"].append(val_dl) self.dataloaders["test"].append(None) - # Load best checkpoint (legacy behavior) if present - ckpt_cb = next((cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), None) + ckpt_cb = next( + (cb for cb in trainer.callbacks if isinstance(cb, ModelCheckpoint)), + None, + ) if ckpt_cb is not None and getattr(ckpt_cb, "best_model_path", None): best_path = ckpt_cb.best_model_path if isinstance(best_path, str) and best_path and os.path.exists(best_path): @@ -842,20 +881,32 @@ def _inner(i): return None - def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs): + def predict( + self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, **kwargs + ) -> 
Union[np.ndarray, List[np.ndarray], None]: + """Predict outcomes from context C and predictors X. + + Args: + C (np.ndarray): Context array of shape (n_samples, n_context_features) + X (np.ndarray): Predictor array of shape (N, n_features) + individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False. + + Returns: + Union[np.ndarray, List[np.ndarray], None]: Predicted outcomes. If individual_preds is True, returns + predictions for each bootstrap. Returns None if any trainer returns None. + """ if self.models is None or self.trainers is None: raise ValueError("Trying to predict with a model that hasn't been trained yet.") C = np.asarray(C) X = np.asarray(X) - Cq = self._maybe_scale_C(C) Xq = self._maybe_scale_X(X) - # Build a DDP-safe predict loader via DataModule when possible preds_all: List[np.ndarray] = [] + saw_none = False - for i, (model, trainer) in enumerate(zip(self.models, self.trainers)): + for model, trainer in zip(self.models, self.trainers): if not hasattr(trainer, "predict_y"): raise RuntimeError( "Trainer does not implement predict_y(). 
" @@ -865,13 +916,14 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, use_dm = self._use_datamodule_for_model(model) if use_dm: - # zeros outcomes for predict Yq = np.zeros((len(Cq), int(self.y_dim or 1)), dtype=np.float32) - task_type = "singletask_univariate" if bool(getattr(model, "hparams", {}).get("univariate", False)) else "singletask_multivariate" - # Prefer wrapper's univariate flag if present univariate_flag = bool(self._init_kwargs.get("model", {}).get("univariate", False)) - task_type = "singletask_univariate" if univariate_flag else "singletask_multivariate" + task_type = ( + "singletask_univariate" + if univariate_flag + else "singletask_multivariate" + ) dm = self._build_datamodule( C=Cq, @@ -887,15 +939,25 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, dm.setup("predict") dl = dm.predict_dataloader() else: - dl = model.dataloader(Cq, Xq, np.zeros((len(Cq), int(self.y_dim or 1))), batch_size=kwargs.get("predict_batch_size", self.default_val_batch_size)) + dl = model.dataloader( + Cq, + Xq, + np.zeros((len(Cq), int(self.y_dim or 1))), + batch_size=kwargs.get( + "predict_batch_size", self.default_val_batch_size + ), + ) yhat = trainer.predict_y(model, dl, **kwargs) if yhat is None: - # DDP non-rank0: avoid duplicating work/outputs - return None + saw_none = True + continue preds_all.append(np.asarray(yhat, dtype=float)) + if saw_none: + return None + predictions = np.array(preds_all, dtype=float) if individual_preds: @@ -905,7 +967,8 @@ def predict(self, C: np.ndarray, X: np.ndarray, individual_preds: bool = False, if bad.any(): num_bad_boots = np.unique(np.where(bad)[0]).size print( - f"Warning: {num_bad_boots}/{len(preds_all)} bootstraps produced non-finite predictions; excluding them from the ensemble." + f"Warning: {num_bad_boots}/{len(preds_all)} bootstraps produced " + f"non-finite predictions; excluding them from the ensemble." 
) out = self._nanrobust_mean(predictions, axis=0) @@ -923,10 +986,23 @@ def predict_params( individual_preds: bool = False, model_includes_mus: bool = True, **kwargs, - ) -> Union[ - np.ndarray, - Tuple[np.ndarray, np.ndarray], - ]: + ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], Tuple[None, None], None]: + """ + Predict context-specific model parameters from context C. + + Args: + C (np.ndarray): Context array of shape (n_samples, n_context_features) + individual_preds (bool, optional): Whether to return individual model predictions for each bootstrap. + Defaults to False, averaging across bootstraps. + model_includes_mus (bool, optional): Whether the model includes context-specific offsets (mu). + Defaults to True. + + Returns: + Union[np.ndarray, Tuple[np.ndarray, np.ndarray], Tuple[None, None], None]: + If model_includes_mus is True, returns (betas, mus); otherwise returns betas. + If individual_preds is True, returns arrays stacked over bootstraps. + Returns (None, None) or None if any trainer returns None. 
+ """ if self.models is None or self.trainers is None: raise ValueError("Trying to predict with a model that hasn't been trained yet.") @@ -937,6 +1013,7 @@ def predict_params( betas_list: List[np.ndarray] = [] mus_list: List[np.ndarray] = [] + saw_none = False for model, trainer in zip(self.models, self.trainers): if not hasattr(trainer, "predict_params"): @@ -949,10 +1026,18 @@ def predict_params( if use_dm: X_zero = np.zeros((len(Cq), int(self.x_dim or 1)), dtype=np.float32) - Y_zero = np.zeros((len(Cq), int(self.y_dim or 1)), dtype=np.float32) if uses_y else None + Y_zero = ( + np.zeros((len(Cq), int(self.y_dim or 1)), dtype=np.float32) + if uses_y + else None + ) univariate_flag = bool(self._init_kwargs.get("model", {}).get("univariate", False)) - task_type = "singletask_univariate" if univariate_flag else "singletask_multivariate" + task_type = ( + "singletask_univariate" + if univariate_flag + else "singletask_multivariate" + ) dm = self._build_datamodule( C=Cq, @@ -969,13 +1054,21 @@ def predict_params( dl = dm.predict_dataloader() else: if uses_y: - dl = model.dataloader(Cq, np.zeros((len(Cq), int(self.x_dim or 1))), np.zeros((len(Cq), int(self.y_dim or 1)))) + dl = model.dataloader( + Cq, + np.zeros((len(Cq), int(self.x_dim or 1))), + np.zeros((len(Cq), int(self.y_dim or 1))), + ) else: - dl = model.dataloader(Cq, np.zeros((len(Cq), int(self.x_dim or 1)))) + dl = model.dataloader( + Cq, + np.zeros((len(Cq), int(self.x_dim or 1))), + ) out = trainer.predict_params(model, dl, **kwargs) if out is None or (isinstance(out, tuple) and out[0] is None): - return (None, None) if model_includes_mus else None # DDP non-rank0 + saw_none = True + continue if model_includes_mus: b, m = out @@ -984,7 +1077,11 @@ def predict_params( else: betas_list.append(np.asarray(out)) + if saw_none: + return (None, None) if model_includes_mus else None + betas = np.array(betas_list) + if model_includes_mus: mus = np.array(mus_list) if individual_preds: diff --git 
a/contextualized/modules.py b/contextualized/modules.py index 596d033b..52c7c6ff 100644 --- a/contextualized/modules.py +++ b/contextualized/modules.py @@ -7,6 +7,24 @@ from contextualized.functions import LINK_FUNCTIONS +def _resolve_link_fn(maybe_link): + """ + Accepts either: + - a string key (looked up in LINK_FUNCTIONS), or + - a callable (returned as-is, including functools.partial) + """ + if isinstance(maybe_link, str): + try: + return LINK_FUNCTIONS[maybe_link] + except KeyError as e: + raise KeyError( + f"Unknown link_fn '{maybe_link}'. " + f"Valid options: {list(LINK_FUNCTIONS.keys())}" + ) from e + if callable(maybe_link): + return maybe_link + raise TypeError(f"link_fn must be str or callable, got {type(maybe_link).__name__}") + class SoftSelect(nn.Module): """ @@ -60,30 +78,10 @@ def set_archetypes(self, archetypes): class Explainer(SoftSelect): - """ - 2D subtype-archetype parameter sharing - """ - + """ 2D subtype-archetype parameter sharing """ def __init__(self, k, out_shape): super().__init__((k,), out_shape) -def _resolve_link_fn(maybe_link): - """ - Accepts either: - - a string key (looked up in LINK_FUNCTIONS), or - - a callable (returned as-is, including functools.partial) - """ - if isinstance(maybe_link, str): - try: - return LINK_FUNCTIONS[maybe_link] - except KeyError as e: - raise KeyError( - f"Unknown link_fn '{maybe_link}'. " - f"Valid options: {list(LINK_FUNCTIONS.keys())}" - ) from e - if callable(maybe_link): - return maybe_link - raise TypeError(f"link_fn must be str or callable, got {type(maybe_link).__name__}") @@ -139,7 +137,6 @@ def __init__( self.input_dim = input_dim self.output_dim = output_dim - # Internal NAM pieces should be identity-linked; the global link is applied after summation. 
per_feat_link = "identity" self.nams = nn.ModuleList( diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py index ff2df7c3..8c6b601b 100644 --- a/contextualized/regression/datamodules.py +++ b/contextualized/regression/datamodules.py @@ -31,11 +31,9 @@ def _to_tensor(x: TensorLike, dtype: torch.dtype) -> torch.Tensor: return x.to(dtype=dtype, copy=False) if isinstance(x, (pd.DataFrame, pd.Series)): x = x.to_numpy(copy=False) - # np.ndarray -> avoid copy where possible return torch.as_tensor(x, dtype=dtype) - def _maybe_index(x: torch.Tensor, idx: IndexLike) -> torch.Tensor: if idx is None: return x @@ -44,9 +42,9 @@ def _maybe_index(x: torch.Tensor, idx: IndexLike) -> torch.Tensor: if isinstance(idx, np.ndarray): idx = torch.as_tensor(idx, dtype=torch.long) return x[idx] - # assume Sequence[int] return x[torch.as_tensor(idx, dtype=torch.long)] + def _to_index_tensor(idx: IndexLike) -> Optional[torch.Tensor]: """Normalize an index-like into a 1D CPU LongTensor.""" if idx is None: @@ -56,9 +54,9 @@ def _to_index_tensor(idx: IndexLike) -> Optional[torch.Tensor]: elif isinstance(idx, np.ndarray): out = torch.as_tensor(idx, dtype=torch.long, device="cpu") else: - # assume Sequence[int] out = torch.as_tensor(idx, dtype=torch.long, device="cpu") - return out.view(-1) # ensure 1D + return out.view(-1) + class ContextualizedRegressionDataModule(pl.LightningDataModule): """ @@ -80,7 +78,6 @@ def __init__( Y: Optional[TensorLike], *, task_type: str, - # splits: pass explicit index arrays OR a splitter callable train_idx: IndexLike = None, val_idx: IndexLike = None, test_idx: IndexLike = None, @@ -89,7 +86,6 @@ def __init__( Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], Tuple[IndexLike, IndexLike, IndexLike]] ] = None, - # dataloader config train_batch_size: int = 32, val_batch_size: int = 32, test_batch_size: int = 32, @@ -101,7 +97,6 @@ def __init__( shuffle_train: bool = True, shuffle_eval: bool = False, dtype: 
torch.dtype = torch.float, - ): super().__init__() if task_type not in TASK_TO_DATASET: @@ -111,19 +106,16 @@ def __init__( ) self.task_type = task_type - # raw inputs (convert in setup) self._C_raw = C self._X_raw = X self._Y_raw = Y - # split config self.train_idx = train_idx self.val_idx = val_idx self.test_idx = test_idx self.predict_idx = predict_idx self.splitter = splitter - # dl config self.train_batch_size = train_batch_size self.val_batch_size = val_batch_size self.test_batch_size = test_batch_size @@ -136,8 +128,6 @@ def __init__( self.shuffle_eval = shuffle_eval self.dtype = dtype - - # will be set in setup() self.C: Optional[torch.Tensor] = None self.X: Optional[torch.Tensor] = None self.Y: Optional[torch.Tensor] = None @@ -147,37 +137,30 @@ def __init__( self.ds_test = None self.ds_predict = None - # One-time downloads or heavy ops would go here; we have none. def prepare_data(self) -> None: pass def setup(self, stage: Optional[str] = None) -> None: - # Convert inputs to tensors C = _to_tensor(self._C_raw, self.dtype) X = _to_tensor(self._X_raw, self.dtype) Y = None if self._Y_raw is None else _to_tensor(self._Y_raw, self.dtype) - # Basic shape sanity could be added here if desired. - - # If no explicit indices were given, allow a splitter to define them. 
if self.train_idx is None and self.val_idx is None and self.test_idx is None: if self.splitter is not None: tr, va, te = self.splitter(C, X, Y) self.train_idx, self.val_idx, self.test_idx = tr, va, te - # If predict_idx not given, default to test indices (or full range if all None) if self.predict_idx is None: if self.test_idx is not None: self.predict_idx = self.test_idx else: self.predict_idx = torch.arange(C.shape[0], dtype=torch.long) - # Slice tensors per split (map-style datasets rely on correct len() for sharding) def _mk_dataset(idx: IndexLike): if idx is None: return None - idx_t = _to_index_tensor(idx) # <-- NEW: stable mapping to original rows + idx_t = _to_index_tensor(idx) C_s = _maybe_index(C, idx_t) X_s = _maybe_index(X, idx_t) @@ -185,25 +168,17 @@ def _mk_dataset(idx: IndexLike): ds_cls = TASK_TO_DATASET[self.task_type] if Y_s is None: - # Allow unsupervised / network-style usage where Y is omitted. - # In that case, use X as a dummy target so shapes line up. Y_s = X_s - # IMPORTANT: pass orig_idx so every item can report its original row id return ds_cls(C_s, X_s, Y_s, orig_idx=idx_t, dtype=self.dtype) - - - self.ds_train = _mk_dataset(self.train_idx) self.ds_val = _mk_dataset(self.val_idx) self.ds_test = _mk_dataset(self.test_idx) self.ds_predict = _mk_dataset(self.predict_idx) - # Keep tensors for potential later use self.C, self.X, self.Y = C, X, Y - # ---- Dataloaders ---- def _common_dl_kwargs(self, batch_size: int, *, drop_last: Optional[bool] = None) -> Dict: return { "batch_size": batch_size, @@ -213,9 +188,6 @@ def _common_dl_kwargs(self, batch_size: int, *, drop_last: Optional[bool] = None "drop_last": self.drop_last if drop_last is None else bool(drop_last), } - - - def train_dataloader(self) -> DataLoader: if self.ds_train is None: raise RuntimeError("train dataset is not set; provide train_idx or splitter.") @@ -231,7 +203,6 @@ def val_dataloader(self): return DataLoader( dataset=self.ds_val, shuffle=self.shuffle_eval, - # NEVER drop 
samples for eval (avoids silent data loss / mis-ordering) **self._common_dl_kwargs(self.val_batch_size, drop_last=False), ) @@ -252,4 +223,3 @@ def predict_dataloader(self) -> DataLoader: shuffle=False, **self._common_dl_kwargs(self.predict_batch_size, drop_last=False), ) - diff --git a/contextualized/regression/datasets.py b/contextualized/regression/datasets.py index 64791876..30b36d05 100644 --- a/contextualized/regression/datasets.py +++ b/contextualized/regression/datasets.py @@ -16,14 +16,11 @@ def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): self.X = torch.as_tensor(X, dtype=dtype) self.Y = torch.as_tensor(Y, dtype=dtype) - # NEW: stable original-row index for distributed ordered gather - # FIX: enforce 1D LongTensor when provided if orig_idx is None: self.orig_idx = torch.arange(len(self.C), dtype=torch.long) else: self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) - # FIX: derive dims from converted tensors to prevent shape mismatches self.c_dim = self.C.shape[-1] self.x_dim = self.X.shape[-1] self.y_dim = self.Y.shape[-1] @@ -34,8 +31,8 @@ def __len__(self): def __getitem__(self, idx): return { - "idx": idx, # dataset-local position - "orig_idx": self.orig_idx[idx], # NEW: original-row id + "idx": idx, + "orig_idx": self.orig_idx[idx], "contexts": self.C[idx], "predictors": self.X[idx].expand(self.y_dim, -1), "outcomes": self.Y[idx].unsqueeze(-1), @@ -51,14 +48,11 @@ def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): self.X = torch.as_tensor(X, dtype=dtype) self.Y = torch.as_tensor(Y, dtype=dtype) - # NEW: stable original-row index - # FIX: enforce 1D LongTensor when provided if orig_idx is None: self.orig_idx = torch.arange(len(self.C), dtype=torch.long) else: self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) - # FIX: derive dims from converted tensors to prevent shape mismatches self.c_dim = self.C.shape[-1] self.x_dim = self.X.shape[-1] self.y_dim = self.Y.shape[-1] @@ -70,7 +64,7 @@ def 
__len__(self): def __getitem__(self, idx): return { "idx": idx, - "orig_idx": self.orig_idx[idx], # NEW + "orig_idx": self.orig_idx[idx], "contexts": self.C[idx], "predictors": self.X[idx].expand(self.y_dim, -1).unsqueeze(-1), "outcomes": self.Y[idx].expand(self.x_dim, -1).T.unsqueeze(-1), @@ -86,14 +80,11 @@ def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): self.X = X.to(dtype) if isinstance(X, torch.Tensor) else torch.as_tensor(X, dtype=dtype) self.Y = Y.to(dtype) if isinstance(Y, torch.Tensor) else torch.as_tensor(Y, dtype=dtype) - # NEW: stable original-row index per sample - # FIX: enforce 1D LongTensor when provided if orig_idx is None: self.orig_idx = torch.arange(len(self.C), dtype=torch.long) else: self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) - # FIX: derive dims from converted tensors to prevent shape mismatches self.c_dim = self.C.shape[-1] self.x_dim = self.X.shape[-1] self.y_dim = self.Y.shape[-1] @@ -103,21 +94,22 @@ def __len__(self): return len(self.C) * self.y_dim def __getitem__(self, idx): + # Get task-split sample indices n_i = idx // self.y_dim y_i = idx % self.y_dim - # Minor improvement: task vector dtype matches dataset dtype + # Create a one-hot encoding for the task t = torch.zeros(self.y_dim, dtype=self.dtype) t[y_i] = 1 return { - "idx": idx, # dataset-item index - "orig_idx": self.orig_idx[n_i], # NEW: original-row id of the sample + "idx": idx, + "orig_idx": self.orig_idx[n_i], "contexts": self.C[n_i], "task": t, "predictors": self.X[n_i], "outcomes": self.Y[n_i, y_i].unsqueeze(0), - "sample_idx": n_i, # local sample index within this dataset + "sample_idx": n_i, "outcome_idx": y_i, } @@ -132,14 +124,11 @@ def __init__(self, C, X, Y, orig_idx=None, dtype=torch.float): self.X = torch.as_tensor(X, dtype=dtype) self.Y = torch.as_tensor(Y, dtype=dtype) - # NEW: stable original-row index per sample - # FIX: enforce 1D LongTensor when provided if orig_idx is None: self.orig_idx = torch.arange(len(self.C), 
dtype=torch.long) else: self.orig_idx = torch.as_tensor(orig_idx, dtype=torch.long).view(-1) - # FIX: derive dims from converted tensors to prevent shape mismatches self.c_dim = self.C.shape[-1] self.x_dim = self.X.shape[-1] self.y_dim = self.Y.shape[-1] @@ -149,17 +138,19 @@ def __len__(self): return len(self.C) * self.x_dim * self.y_dim def __getitem__(self, idx): + # Get task-split sample indices n_i = idx // (self.x_dim * self.y_dim) x_i = (idx // self.y_dim) % self.x_dim y_i = idx % self.y_dim + # Create a one-hot encoding for the task t = torch.zeros(self.x_dim + self.y_dim, dtype=self.dtype) t[x_i] = 1 t[self.x_dim + y_i] = 1 return { - "idx": idx, # dataset-item index - "orig_idx": self.orig_idx[n_i], # NEW: original-row id of the sample + "idx": idx, + "orig_idx": self.orig_idx[n_i], "contexts": self.C[n_i], "task": t, "predictors": self.X[n_i, x_i].unsqueeze(0), diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index 983c2056..a635369d 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -10,20 +10,19 @@ Implemented with PyTorch Lightning """ -# For distributed runs, use the ContextualizedRegressionDataModule which returns -# map-style datasets and allows Lightning's Trainer to auto-shard with DDP. 
-from .datamodules import ContextualizedRegressionDataModule # noqa: F401 + +from .datamodules import ContextualizedRegressionDataModule from abc import abstractmethod import numpy as np import torch from torch.utils.data import DataLoader import pytorch_lightning as pl + from contextualized.regression.regularizers import REGULARIZERS -from contextualized.regression.losses import MSE +from contextualized.regression.losses import MSE from contextualized.functions import LINK_FUNCTIONS - from contextualized.regression.metamodels import ( NaiveMetamodel, SubtypeMetamodel, @@ -33,9 +32,15 @@ MULTITASK_METAMODELS, ) -# --- Accept both string registry keys and callables for link_fn / loss_fn --- + def _resolve_registry_or_callable(maybe_obj, registry, name: str): - """Return a function from a registry by key, or the callable directly.""" + """ + + :param maybe_obj: + :param registry: + :param name: + + """ if isinstance(maybe_obj, str): try: return registry[maybe_obj] @@ -45,15 +50,16 @@ def _resolve_registry_or_callable(maybe_obj, registry, name: str): ) from e if callable(maybe_obj): return maybe_obj - raise TypeError(f"{name} must be a string key or a callable, got {type(maybe_obj).__name__}") + raise TypeError( + f"{name} must be a string key or a callable, got {type(maybe_obj).__name__}" + ) def _resolve_loss(maybe_loss): """ - Allow: - * 'mse' string (maps to local MSE), - * any callable (already constructed loss), - and reject unknown strings to avoid circular imports with package-level registries. 
+ + :param maybe_loss: + """ if isinstance(maybe_loss, str): if maybe_loss.lower() == "mse": @@ -64,13 +70,16 @@ def _resolve_loss(maybe_loss): ) if callable(maybe_loss): return maybe_loss - raise TypeError(f"loss_fn must be a string key or a callable, got {type(maybe_loss).__name__}") -# --------------------------------------------------------------------------- + raise TypeError( + f"loss_fn must be a string key or a callable, got {type(maybe_loss).__name__}" + ) + + def _resolve_regularizer(maybe_reg): """ - Allow: - * string key -> lookup in REGULARIZERS - * callable -> pass through directly + + :param maybe_reg: + """ if isinstance(maybe_reg, str): try: @@ -82,7 +91,10 @@ def _resolve_regularizer(maybe_reg): ) from e if callable(maybe_reg): return maybe_reg - raise TypeError(f"model_regularizer must be a string key or a callable, got {type(maybe_reg).__name__}") + raise TypeError( + "model_regularizer must be a string key or a callable, got " + f"{type(maybe_reg).__name__}" + ) class ContextualizedRegressionBase(pl.LightningModule): @@ -126,7 +138,7 @@ class ContextualizedRegressionBase(pl.LightningModule): # self.base_y_predictor = base_y_predictor # self.base_param_predictor = base_param_predictor # self._build_metamodel( - # context_dim, + # context_dim, # x_dim, # y_dim, # univariate, @@ -138,8 +150,8 @@ class ContextualizedRegressionBase(pl.LightningModule): # @abstractmethod # def _build_metamodel( - # self, - # context_dim, + # self, + # context_dim, # x_dim, # y_dim, # univariate, @@ -156,7 +168,7 @@ class ContextualizedRegressionBase(pl.LightningModule): # """ # # builds the metamodel # self.metamodel = SINGLE_TASK_METAMODELS[self.metamodel_type]( - # context_dim, + # context_dim, # x_dim, # y_dim, # univariate, @@ -230,7 +242,9 @@ def forward(self, batch): if not self.fit_intercept: mu = torch.zeros_like(mu) if self.base_param_predictor is not None: - base_beta, base_mu = self.base_param_predictor.predict_params(batch["contexts"]) + base_beta, 
base_mu = self.base_param_predictor.predict_params( + batch["contexts"] + ) beta = beta + base_beta.to(beta.device) mu = mu + base_mu.to(mu.device) return beta, mu @@ -243,27 +257,41 @@ def configure_optimizers(self): return optimizer def _batch_size_from_batch(self, batch: dict) -> int: - # all your datasets provide "contexts" in the batch dict - if isinstance(batch, dict) and "contexts" in batch and isinstance(batch["contexts"], torch.Tensor): + """ + + :param batch: + + """ + if ( + isinstance(batch, dict) + and "contexts" in batch + and isinstance(batch["contexts"], torch.Tensor) + ): return int(batch["contexts"].shape[0]) return 1 - def _predict_payload(self, batch: dict, **outputs) -> dict: """ - Return a minimal, DDP-safe payload for trainer.predict: - - indices needed to reorder across ranks - - model outputs - Everything is detached and moved to CPU to avoid GPU memory blow-ups. + + :param batch: + :param **outputs: + """ out = {} - for k in ("idx", "orig_idx", "sample_idx", "outcome_idx", "predictor_idx"): + for k in ( + "idx", + "orig_idx", + "sample_idx", + "outcome_idx", + "predictor_idx", + "contexts", + "predictors", + ): if isinstance(batch, dict) and k in batch: out[k] = batch[k] out.update(outputs) - # Detach + move tensors to CPU for cheap gather/reorder in wrapper code later for k, v in list(out.items()): if isinstance(v, torch.Tensor): out[k] = v.detach().cpu() @@ -271,10 +299,15 @@ def _predict_payload(self, batch: dict, **outputs) -> dict: def training_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ loss = self._batch_loss(batch, batch_idx) bs = self._batch_size_from_batch(batch) - # Step-level logging: keep visibility, avoid per-step all-reduce (DDP scaling killer) self.log( "train_loss_step", loss, @@ -285,7 +318,6 @@ def training_step(self, batch, batch_idx): batch_size=bs, ) - # Epoch-level logging: sync across ranks once per epoch (correct global metric) self.log( "train_loss", loss, @@ -298,9 +330,13 @@ 
def training_step(self, batch, batch_idx): return loss + def validation_step(self, batch, batch_idx): + """ + :param batch: + :param batch_idx: - def validation_step(self, batch, batch_idx): + """ loss = self._batch_loss(batch, batch_idx) bs = self._batch_size_from_batch(batch) self.log( @@ -314,8 +350,13 @@ def validation_step(self, batch, batch_idx): ) return loss - def test_step(self, batch, batch_idx): + """ + + :param batch: + :param batch_idx: + + """ loss = self._batch_loss(batch, batch_idx) bs = self._batch_size_from_batch(batch) self.log( @@ -328,56 +369,58 @@ def test_step(self, batch, batch_idx): batch_size=bs, ) return loss - + def _predict_from_models(self, X, beta_hat, mu_hat): """ - Make shapes consistent before computing: - y = g( (beta ⊙ X).sum(-1, keepdim=True) + mu ) - ... - """ - # ---- Univariate grid case: X is (B, y_dim, x_dim, 1) ---- - # singletask_univariate dataset convention produces predictors shaped (B, y, x, 1) + :param X: + :param beta_hat: + :param mu_hat: + + """ if isinstance(X, torch.Tensor) and X.dim() == 4 and X.shape[-1] == 1: - # move X to device/dtype X = X.to(device=beta_hat.device, dtype=beta_hat.dtype) - # beta_hat should be (B, y, x, 1) in this regime if beta_hat.dim() == 3: beta_hat = beta_hat.unsqueeze(-1) if beta_hat.dim() != 4 or beta_hat.shape[-1] != 1: - raise RuntimeError(f"Univariate expects beta_hat (B,y,x,1); got {beta_hat.shape}") + raise RuntimeError( + f"Univariate expects beta_hat (B,y,x,1); got {beta_hat.shape}" + ) - # mu_hat should broadcast to (B, y, x, 1) if not isinstance(mu_hat, torch.Tensor): - mu_hat = torch.as_tensor(mu_hat, device=beta_hat.device, dtype=beta_hat.dtype) + mu_hat = torch.as_tensor( + mu_hat, device=beta_hat.device, dtype=beta_hat.dtype + ) else: mu_hat = mu_hat.to(device=beta_hat.device, dtype=beta_hat.dtype) if mu_hat.dim() == 2: - # (B, y) -> (B, y, 1, 1) -> expand across x - mu_hat = mu_hat.unsqueeze(-1).unsqueeze(-1).expand(-1, beta_hat.shape[1], beta_hat.shape[2], 1) + mu_hat 
= ( + mu_hat.unsqueeze(-1) + .unsqueeze(-1) + .expand(-1, beta_hat.shape[1], beta_hat.shape[2], 1) + ) elif mu_hat.dim() == 3: - # (B, y, x) or (B, y, 1) -> (B, y, x, 1) if mu_hat.shape[-1] == 1: - mu_hat = mu_hat.unsqueeze(-1).expand(-1, beta_hat.shape[1], beta_hat.shape[2], 1) + mu_hat = mu_hat.unsqueeze(-1).expand( + -1, beta_hat.shape[1], beta_hat.shape[2], 1 + ) else: mu_hat = mu_hat.unsqueeze(-1) elif mu_hat.dim() == 4 and mu_hat.shape[-1] == 1: pass else: - raise RuntimeError(f"Unsupported mu_hat shape for univariate: {mu_hat.shape}") + raise RuntimeError( + f"Unsupported mu_hat shape for univariate: {mu_hat.shape}" + ) out = (beta_hat * X).sum(dim=-1, keepdim=True) + mu_hat return self.link_fn(out) - # ---- Normalize beta_hat to (B, y_dim, x_dim) ---- if not isinstance(beta_hat, torch.Tensor): - raise RuntimeError( - f"beta_hat must be a tensor, got {type(beta_hat)}" - ) + raise RuntimeError(f"beta_hat must be a tensor, got {type(beta_hat)}") - # Handle univariate case where shape is (B, y, x, 1) if beta_hat.dim() == 4 and beta_hat.shape[-1] == 1: beta_hat = beta_hat.squeeze(-1) @@ -389,14 +432,12 @@ def _predict_from_models(self, X, beta_hat, mu_hat): B, y_dim, x_dim = beta_hat.shape - # ---- Move and normalize X ---- if not isinstance(X, torch.Tensor): X = torch.as_tensor(X, device=beta_hat.device, dtype=beta_hat.dtype) else: X = X.to(device=beta_hat.device, dtype=beta_hat.dtype) if X.dim() == 2: - # (B, x_dim) -> broadcast over y_dim if X.shape[0] != B: raise RuntimeError( f"X batch dim {X.shape[0]} != beta_hat batch dim {B}. 
" @@ -417,11 +458,11 @@ def _predict_from_models(self, X, beta_hat, mu_hat): ) if X.shape[1] == y_dim and X.shape[2] == x_dim: - pass # already good + pass elif X.shape[1] == 1 and X.shape[2] == x_dim: - X = X.expand(-1, y_dim, -1) # (B,1,x) -> (B,y,x) + X = X.expand(-1, y_dim, -1) elif X.shape[1] == x_dim and X.shape[2] == y_dim and x_dim == y_dim: - X = X.permute(0, 2, 1) # (B,x,y) -> (B,y,x) + X = X.permute(0, 2, 1) else: raise RuntimeError( f"Unexpected X shape {X.shape} for beta_hat {beta_hat.shape}. " @@ -433,21 +474,17 @@ def _predict_from_models(self, X, beta_hat, mu_hat): f"expected 2 or 3. X.shape={X.shape}, beta_hat.shape={beta_hat.shape}" ) - # ---- Normalize mu_hat to broadcast correctly ---- if not isinstance(mu_hat, torch.Tensor): mu_hat = torch.as_tensor(mu_hat, device=beta_hat.device, dtype=beta_hat.dtype) else: mu_hat = mu_hat.to(device=beta_hat.device, dtype=beta_hat.dtype) - # Handle univariate case where mu_hat is (B, y, x, 1) if mu_hat.dim() == 4 and mu_hat.shape[-1] == 1: mu_hat = mu_hat.squeeze(-1) if mu_hat.dim() == 2: - # (B, y_dim) -> (B, y_dim, 1) mu_hat = mu_hat.unsqueeze(-1) elif mu_hat.dim() == 3: - # assume already (B, y_dim, 1) or (B, y_dim, x_dim) pass else: raise RuntimeError( @@ -458,10 +495,6 @@ def _predict_from_models(self, X, beta_hat, mu_hat): out = (beta_hat * X).sum(dim=-1, keepdim=True) + mu_hat return self.link_fn(out) - - - - def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -492,85 +525,6 @@ def _predict_y(self, C, X, beta_hat, mu_hat): # return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs) -# class NaiveContextualizedRegression(ContextualizedRegressionBase): -# """See NaiveMetamodel""" - -# def _build_metamodel(self, *args, **kwargs): -# """ - -# :param *args: -# :param **kwargs: - -# """ -# kwargs["univariate"] = False -# self.metamodel = NaiveMetamodel(*args, **kwargs) - -# def _batch_loss(self, batch, batch_idx): -# """ - -# :param batch: -# :param batch_idx: - -# """ -# C, X, Y, _ = 
batch -# beta_hat, mu_hat = self.predict_step(batch, batch_idx) -# pred_loss = self.loss_fn(Y, self._predict_y(C, X, beta_hat, mu_hat)) -# reg_loss = self.model_regularizer(beta_hat, mu_hat) -# return pred_loss + reg_loss - -# def predict_step(self, batch, batch_idx): -# """ - -# :param batch: -# :param batch_idx: - -# """ -# C, _, _, _ = batch -# beta_hat, mu_hat = self(C) -# return beta_hat, mu_hat - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, n_idx = data - # betas[n_idx] = beta_hats - # mus[n_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, X, _, n_idx = data - # ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultivariateDataset, **kwargs) - - class ContextualizedRegression(ContextualizedRegressionBase): """Supports SubtypeMetamodel and NaiveMetamodel, see selected metamodel for docs""" def __init__( @@ -608,7 +562,7 @@ def __init__( self.base_param_predictor = base_param_predictor if metamodel_type == "subtype": self.metamodel = SubtypeMetamodel( - context_dim=context_dim, + context_dim=context_dim, x_dim=x_dim, y_dim=y_dim, univariate=False, @@ -620,7 +574,7 @@ def __init__( if num_archetypes is not None: raise ValueError("NaiveMetamodel does not support num_archetypes.") self.metamodel = NaiveMetamodel( - context_dim=context_dim, + context_dim=context_dim, 
x_dim=x_dim, y_dim=y_dim, univariate=False, @@ -638,7 +592,10 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch["outcomes"], self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss @@ -648,51 +605,9 @@ def predict_step(self, batch, batch_idx): return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, n_idx = data - # betas[n_idx] = beta_hats - # mus[n_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, X, _, n_idx = data - # ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultivariateDataset, **kwargs) - - class NaiveContextualizedRegression(ContextualizedRegression): """Handle for NaiveMetamodel usage of ContextualizedRegression. - Does not use archetypes. + Does not use archetypes. 
""" def __init__( self, @@ -727,12 +642,11 @@ def __init__( loss_fn=loss_fn, model_regularizer=model_regularizer, base_y_predictor=base_y_predictor, - base_param_predictor=base_param_predictor + base_param_predictor=base_param_predictor, ) self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) - class MultitaskContextualizedRegression(ContextualizedRegressionBase): """See MultitaskMetamodel""" def __init__( @@ -782,7 +696,6 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"], batch["task"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu def _batch_loss(self, batch, batch_idx): @@ -793,10 +706,13 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch['outcomes'], self._predict_y(batch['contexts'], batch['predictors'], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss - + def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -807,7 +723,6 @@ def _predict_y(self, C, X, beta_hat, mu_hat): """ Y = self._predict_from_models(X, beta_hat, mu_hat) - # Does not support base_y_predictor return Y def predict_step(self, batch, batch_idx): @@ -816,50 +731,6 @@ def predict_step(self, batch, batch_idx): return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) - - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, _, n_idx, y_idx = data - # betas[n_idx, y_idx] = beta_hats - # mus[n_idx, y_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - 
# """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, _, X, _, n_idx, y_idx = data - # ys[n_idx, y_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs) - - class TasksplitContextualizedRegression(ContextualizedRegressionBase): """See TasksplitMetamodel""" @@ -912,7 +783,7 @@ def __init__( task_encoder_type=task_encoder_type, task_encoder_kwargs=task_encoder_kwargs, ) - + def forward(self, batch): """ @@ -922,7 +793,6 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"], batch["task"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu def _batch_loss(self, batch, batch_idx): @@ -933,10 +803,13 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch['outcomes'], self._predict_y(batch['contexts'], batch['predictors'], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss - + def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -947,7 +820,6 @@ def _predict_y(self, C, X, beta_hat, mu_hat): """ Y = self._predict_from_models(X, beta_hat, mu_hat) - # Does not support base_y_predictor return Y def predict_step(self, batch, batch_idx): @@ -956,75 +828,6 @@ def predict_step(self, batch, batch_idx): return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) - - # def _batch_loss(self, batch, batch_idx): - # """ - - # :param batch: - # :param batch_idx: - - # """ - # beta_hat, mu_hat = self(batch) 
- # pred_loss = self.loss_fn(batch["outcomes"], self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat)) - # reg_loss = self.model_regularizer(beta_hat, mu_hat) - # return pred_loss + reg_loss - - # def predict_step(self, batch, batch_idx): - # """ - - # :param batch: - # :param batch_idx: - - # """ - # beta_hat, mu_hat = self(batch) - # batch.update({ - # "betas": beta_hat, - # "mus": mu_hat.squeeze(-1) - # }) - # return batch - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, _, n_idx, y_idx = data - # betas[n_idx, y_idx] = beta_hats - # mus[n_idx, y_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, _, X, _, n_idx, y_idx = data - # ys[n_idx, y_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs) - - class ContextualizedUnivariateRegression(ContextualizedRegressionBase): """Supports SubtypeMetamodel and NaiveMetamodel, see selected metamodel for docs""" def __init__( @@ -1062,7 +865,7 @@ def __init__( self.base_param_predictor = base_param_predictor if metamodel_type == "subtype": self.metamodel = SubtypeMetamodel( - context_dim=context_dim, + context_dim=context_dim, x_dim=x_dim, y_dim=y_dim, univariate=True, @@ -1074,7 +877,7 @@ def __init__( if num_archetypes is not None: raise ValueError("NaiveMetamodel does not support 
num_archetypes.") self.metamodel = NaiveMetamodel( - context_dim=context_dim, + context_dim=context_dim, x_dim=x_dim, y_dim=y_dim, univariate=True, @@ -1083,7 +886,7 @@ def __init__( ) else: raise ValueError("Supported metamodel_type's: subtype, naive") - + def forward(self, batch): """ @@ -1093,9 +896,8 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu - + def _batch_loss(self, batch, batch_idx): """ @@ -1104,7 +906,10 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch["outcomes"], self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss @@ -1114,48 +919,6 @@ def predict_step(self, batch, batch_idx): return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, n_idx = data - # betas[n_idx] = beta_hats.squeeze(-1) - # mus[n_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, X, _, n_idx = data - # ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param 
**kwargs: - - # """ - # return self._dataloader(C, X, Y, UnivariateDataset, **kwargs) - - class MultitaskContextualizedUnivariateRegression(ContextualizedRegressionBase): """See MultitaskMetamodel""" @@ -1196,7 +959,7 @@ def __init__( encoder_type=encoder_type, encoder_kwargs=encoder_kwargs, ) - + def forward(self, batch): """ @@ -1206,7 +969,6 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"], batch["task"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu def _batch_loss(self, batch, batch_idx): @@ -1217,10 +979,13 @@ def _batch_loss(self, batch, batch_idx): """ beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch['outcomes'], self._predict_y(batch['contexts'], batch['predictors'], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss - + def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -1231,7 +996,6 @@ def _predict_y(self, C, X, beta_hat, mu_hat): """ Y = self._predict_from_models(X, beta_hat, mu_hat) - # Does not support base_y_predictor return Y def predict_step(self, batch, batch_idx): @@ -1240,7 +1004,6 @@ def predict_step(self, batch, batch_idx): return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) - class TasksplitContextualizedUnivariateRegression(ContextualizedRegressionBase): """See TasksplitMetamodel""" @@ -1291,7 +1054,7 @@ def __init__( task_encoder_type=task_encoder_type, task_encoder_kwargs=task_encoder_kwargs, ) - + def forward(self, batch): """ @@ -1301,7 +1064,6 @@ def forward(self, batch): beta, mu = self.metamodel(batch["contexts"], batch["task"]) if not self.fit_intercept: mu = torch.zeros_like(mu) - # Does not support base_param_predictor return beta, mu def _batch_loss(self, batch, batch_idx): @@ -1312,10 +1074,13 @@ def _batch_loss(self, batch, batch_idx): 
""" beta_hat, mu_hat = self(batch) - pred_loss = self.loss_fn(batch['outcomes'], self._predict_y(batch['contexts'], batch['predictors'], beta_hat, mu_hat)) + pred_loss = self.loss_fn( + batch["outcomes"], + self._predict_y(batch["contexts"], batch["predictors"], beta_hat, mu_hat), + ) reg_loss = self.model_regularizer(beta_hat, mu_hat) return pred_loss + reg_loss - + def _predict_y(self, C, X, beta_hat, mu_hat): """ @@ -1326,7 +1091,6 @@ def _predict_y(self, C, X, beta_hat, mu_hat): """ Y = self._predict_from_models(X, beta_hat, mu_hat) - # Does not support base_y_predictor return Y def predict_step(self, batch, batch_idx): @@ -1335,51 +1099,6 @@ def predict_step(self, batch, batch_idx): return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) - - # def _params_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # betas = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # mus = betas.copy() - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # _, _, _, _, n_idx, x_idx, y_idx = data - # betas[n_idx, y_idx, x_idx] = beta_hats.squeeze(-1) - # mus[n_idx, y_idx, x_idx] = mu_hats.squeeze(-1) - # return betas, mus - - # def _y_reshape(self, preds, dataloader): - # """ - - # :param preds: - # :param dataloader: - - # """ - # ds = dataloader.dataset.dataset - # ys = np.zeros((ds.n, ds.y_dim, ds.x_dim)) - # for (beta_hats, mu_hats), data in zip(preds, dataloader): - # C, _, X, _, n_idx, x_idx, y_idx = data - # ys[n_idx, y_idx, x_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze( - # -1 - # ) - # return ys - - # def dataloader(self, C, X, Y, **kwargs): - # """ - - # :param C: - # :param X: - # :param Y: - # :param **kwargs: - - # """ - # return self._dataloader(C, X, Y, MultitaskUnivariateDataset, **kwargs) - - class ContextualizedCorrelation(ContextualizedUnivariateRegression): """Using univariate contextualized regression to estimate Pearson's correlation See SubtypeMetamodel 
for assumptions and full docstring @@ -1393,10 +1112,9 @@ def __init__(self, context_dim, x_dim, **kwargs): super().__init__(context_dim, x_dim, x_dim, **kwargs) self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) - def predict_step(self, batch, batch_idx): beta_hat, mu_hat = self(batch) - beta_hat = beta_hat.squeeze(-1) # (B, y, x) + beta_hat = beta_hat.squeeze(-1) beta_hat_T = beta_hat.transpose(1, 2) signs = torch.sign(beta_hat) @@ -1404,10 +1122,9 @@ def predict_step(self, batch, batch_idx): correlations = signs * torch.sqrt(torch.abs(beta_hat * beta_hat_T)) mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) - return self._predict_payload(batch, betas=beta_hat, mus=mu_hat, correlations=correlations) - - - + return self._predict_payload( + batch, betas=beta_hat, mus=mu_hat, correlations=correlations + ) class MultitaskContextualizedCorrelation(MultitaskContextualizedUnivariateRegression): @@ -1424,7 +1141,6 @@ def __init__(self, context_dim, x_dim, **kwargs): self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) - class TasksplitContextualizedCorrelation(TasksplitContextualizedUnivariateRegression): """Using multitask univariate contextualized regression to estimate Pearson's correlation See TasksplitMetamodel for assumptions and full docstring @@ -1439,7 +1155,6 @@ def __init__(self, context_dim, x_dim, **kwargs): self.save_hyperparameters(ignore=["base_y_predictor", "base_param_predictor"]) - class ContextualizedNeighborhoodSelection(ContextualizedRegression): """Using singletask multivariate contextualized regression to do edge-regression for estimating conditional dependencies @@ -1465,16 +1180,13 @@ def __init__( self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): - beta_hat, mu_hat = self(batch) # dict batch + beta_hat, mu_hat = self(batch) beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1) mu_hat = mu_hat 
if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) - - - class ContextualizedMarkovGraph(ContextualizedRegression): """Using singletask multivariate contextualized regression to do edge-regression for estimating conditional dependencies @@ -1492,10 +1204,9 @@ def __init__(self, context_dim, x_dim, **kwargs): self.register_buffer("diag_mask", torch.ones(x_dim, x_dim) - torch.eye(x_dim)) def predict_step(self, batch, batch_idx): - beta_hat, mu_hat = self(batch) # dict batch + beta_hat, mu_hat = self(batch) beta_hat = beta_hat + beta_hat.transpose(1, 2) beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1) mu_hat = mu_hat if mu_hat.dim() >= 3 else mu_hat.unsqueeze(-1) return self._predict_payload(batch, betas=beta_hat, mus=mu_hat) - diff --git a/contextualized/regression/trainers.py b/contextualized/regression/trainers.py index 527c1164..107d1f9d 100644 --- a/contextualized/regression/trainers.py +++ b/contextualized/regression/trainers.py @@ -3,18 +3,20 @@ """ from typing import Any, Tuple, List, Dict, Optional + import numpy as np import torch import torch.distributed as dist -import pytorch_lightning as pl -from pytorch_lightning.plugins.environments import LightningEnvironment +import lightning.pytorch as pl +from lightning.pytorch.plugins.environments import LightningEnvironment import os -from pytorch_lightning.strategies import DDPStrategy - +from lightning.pytorch.strategies import DDPStrategy def _stack_from_preds(preds: List[dict], key: str) -> torch.Tensor: - """Concatenate a tensor field from the list of batch dicts returned by predict().""" + """ + Concatenate a tensor field from the list of batch dicts returned by predict(). 
+ """ preds = _flatten_pl_predict_output(preds) parts = [] for p in preds: @@ -36,8 +38,9 @@ def _is_main_process() -> bool: def _flatten_pl_predict_output(preds): """ Lightning can return: - - list[dict] (single dataloader) - - list[list[dict]] (multiple dataloaders) + - list[dict] (single dataloader) + - list[list[dict]] (multiple dataloaders) + Normalize to list[dict]. """ if preds is None: @@ -88,7 +91,10 @@ def _pack_keys_from_preds(preds: list, keys: Tuple[str, ...]) -> Dict[str, np.nd def _gather_object_to_rank0(obj): """ Gather arbitrary Python objects to rank 0. - Returns list[obj] on rank 0, None on other ranks. + + Returns: + - list[obj] on rank 0 + - None on other ranks """ if not _is_distributed(): return [obj] @@ -106,7 +112,9 @@ def _gather_object_to_rank0(obj): return None -def _merge_packed_payloads(payloads: List[Optional[Dict[str, np.ndarray]]]) -> Dict[str, np.ndarray]: +def _merge_packed_payloads( + payloads: List[Optional[Dict[str, np.ndarray]]], +) -> Dict[str, np.ndarray]: """ Merge list[dict[str, np.ndarray]] -> dict[str, np.ndarray] by concatenation axis 0. """ @@ -120,7 +128,11 @@ def _merge_packed_payloads(payloads: List[Optional[Dict[str, np.ndarray]]]) -> D keys.update(p.keys()) for k in keys: - chunks = [p[k] for p in payloads if (k in p) and (p[k] is not None) and (len(p[k]) > 0)] + chunks = [ + p[k] + for p in payloads + if (k in p) and (p[k] is not None) and (len(p[k]) > 0) + ] if not chunks: continue merged[k] = np.concatenate(chunks, axis=0) @@ -157,78 +169,89 @@ def _stable_sort_and_dedupe(payload: Dict[str, np.ndarray]) -> Dict[str, np.ndar return out - -def _gather_predict_payload(preds, keys: Tuple[str, ...]) -> Optional[Dict[str, np.ndarray]]: +def _gather_predict_payload( + preds, keys: Tuple[str, ...] +) -> Optional[Dict[str, np.ndarray]]: """ Packs requested keys from local preds, gathers to rank0 under DDP, merges, and stable-sorts/dedupes by orig_idx (if present). 
- Returns payload dict on rank0; returns None on non-rank0 in DDP. + + Returns: + - payload dict on rank 0 + - None on non-rank0 in DDP """ local = _pack_keys_from_preds(preds, keys) gathered = _gather_object_to_rank0(local) if gathered is None: - return None # non-rank0 DDP + return None merged = _merge_packed_payloads(gathered) merged = _stable_sort_and_dedupe(merged) return merged - class RegressionTrainer(pl.Trainer): """ Trains the contextualized.regression lightning_modules - and provides convenience prediction helpers that reshape - batched outputs into expected numpy arrays without relying - on model-private _*reshape helpers. """ @torch.no_grad() - def predict_params(self, model: pl.LightningModule, dataloader) -> Tuple[np.ndarray, np.ndarray]: + def predict_params( + self, model: pl.LightningModule, dataloader + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Returns context-specific regression models + - beta (numpy.ndarray): (n, y_dim, x_dim) + - mu (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate]) + """ preds = super().predict(model, dataloader) payload = _gather_predict_payload(preds, keys=("idx", "orig_idx", "betas", "mus")) if payload is None: - # non-rank0 DDP: return nothing to avoid duplicated outputs return None, None - if "betas" not in payload or "mus" not in payload: - raise RuntimeError("predict_params: predict_step must return 'betas' and 'mus' (and ideally 'orig_idx').") + raise RuntimeError( + "predict_params: predict_step must return 'betas' and 'mus' (and ideally 'orig_idx')." + ) return payload["betas"], payload["mus"] - @torch.no_grad() def predict_y(self, model: pl.LightningModule, dataloader) -> np.ndarray: + """ + Returns context-specific predictions of the response Y + - y_hat (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate]) + """ preds = super().predict(model, dataloader) - # Prefer lightweight gather, but allow legacy keys if present. 
- payload = _gather_predict_payload(preds, keys=("idx", "orig_idx", "betas", "mus")) + payload = _gather_predict_payload( + preds, keys=("idx", "orig_idx", "contexts", "predictors", "betas", "mus") + ) + if payload is None: - return None # non-rank0 DDP + return None if "betas" not in payload or "mus" not in payload: raise RuntimeError("predict_y: predict_step must return 'betas' and 'mus'.") betas = torch.as_tensor(payload["betas"]) - mus = torch.as_tensor(payload["mus"]) + mus = torch.as_tensor(payload["mus"]) - # If legacy contexts/predictors were returned and gathered, use them. if ("contexts" in payload) and ("predictors" in payload): C = torch.as_tensor(payload["contexts"]) X = torch.as_tensor(payload["predictors"]) else: - # Option A path: reconstruct from dataset via dataset-local idx (NOT orig_idx) ds = getattr(dataloader, "dataset", None) if ds is None: - raise RuntimeError("predict_y: dataloader has no .dataset; cannot reconstruct C/X.") + raise RuntimeError( + "predict_y: dataloader has no .dataset; cannot reconstruct C/X." + ) idx_np = payload["idx"].astype(np.int64) idx_t = torch.as_tensor(idx_np, dtype=torch.long) - # Support Subset wrapper if user wrapped loaders externally if hasattr(ds, "dataset") and hasattr(ds, "indices"): base = ds.dataset if not (hasattr(base, "C") and hasattr(base, "X")): @@ -239,11 +262,12 @@ def predict_y(self, model: pl.LightningModule, dataloader) -> np.ndarray: X = base.X[base_pos_t] else: if not (hasattr(ds, "C") and hasattr(ds, "X")): - raise RuntimeError("predict_y: dataset must expose .C and .X tensors for Option A prediction.") + raise RuntimeError( + "predict_y: dataset must expose .C and .X tensors for Option A prediction." 
+ ) C = ds.C[idx_t] X = ds.X[idx_t] - # dtype align if torch.is_tensor(C): C = C.to(dtype=betas.dtype) else: @@ -254,39 +278,39 @@ def predict_y(self, model: pl.LightningModule, dataloader) -> np.ndarray: else: X = torch.as_tensor(X, dtype=betas.dtype) - with torch.no_grad(): yhat = model._predict_y(C, X, betas, mus).detach().cpu().numpy() return yhat - - class CorrelationTrainer(RegressionTrainer): """ Trains the contextualized.regression correlation lightning_modules - and exposes a helper to compute context-specific correlation matrices. """ @torch.no_grad() def predict_correlation(self, model: pl.LightningModule, dataloader) -> np.ndarray: + """ + Returns context-specific correlation networks containing Pearson's correlation coefficient + - correlation (numpy.ndarray): (n, x_dim, x_dim) + """ preds = super().predict(model, dataloader) preds_flat = _flatten_pl_predict_output(preds) - # If model returns correlations directly, gather and reorder them. if preds_flat and ("correlations" in preds_flat[0]): payload = _gather_predict_payload(preds, keys=("orig_idx", "correlations")) if payload is None: - return None # non-rank0 DDP + return None if "correlations" not in payload: - raise RuntimeError("predict_correlation: predict_step returned no 'correlations'.") + raise RuntimeError( + "predict_correlation: predict_step returned no 'correlations'." + ) return payload["correlations"] - # Fallback: derive from betas betas, _ = self.predict_params(model, dataloader) if betas is None: - return None # non-rank0 DDP + return None signs = np.sign(betas) signs[signs != np.transpose(signs, (0, 2, 1))] = 0 @@ -294,50 +318,40 @@ def predict_correlation(self, model: pl.LightningModule, dataloader) -> np.ndarr return correlations - class MarkovTrainer(CorrelationTrainer): """ Trains the contextualized.regression markov graph lightning_modules - and exposes a helper to compute context-specific precision matrices. 
""" @torch.no_grad() def predict_precision(self, model: pl.LightningModule, dataloader) -> np.ndarray: """ - Returns context-specific precision matrix under a Gaussian graphical model. - + Returns context-specific precision matrix under a Gaussian graphical model Assuming all diagonal precisions are equal and constant over context, this is equivalent to the negative of the multivariate regression coefficient. - - Returns - ------- - precision : (n, x_dim, x_dim) + - precision (numpy.ndarray): (n, x_dim, x_dim) """ - # A trick in the markov lightning_module predict_step ensures the - # correlation output corresponds (up to sign) to precision entries. return -super().predict_correlation(model, dataloader) - def choose_lightning_environment() -> LightningEnvironment: - # If you have a custom Environment subclass, wire it here. - # Otherwise, the default LightningEnvironment is fine. + """ + Returns the Lightning environment plugin used for single-process runs. + """ return LightningEnvironment() + def make_trainer_with_env(trainer_cls, **trainer_kwargs): """ Factory that respects caller-provided `devices` and `strategy`. - FIXED: Don't inject LightningEnvironment when torchrun is managing processes. + Does not inject LightningEnvironment when torchrun is managing processes. 
""" import os - - # Check if we're under torchrun (WORLD_SIZE > 1 means torchrun is managing) + world_size = int(os.environ.get("WORLD_SIZE", "1")) - - # Only inject LightningEnvironment for single-process runs - # When torchrun is active, Lightning will auto-detect TorchElasticEnvironment + if "plugins" not in trainer_kwargs and world_size == 1: env = choose_lightning_environment() trainer_kwargs["plugins"] = [env] - return trainer_cls(**trainer_kwargs) \ No newline at end of file + return trainer_cls(**trainer_kwargs) diff --git a/network_scaling_heavy.py b/network_scaling_heavy.py index c1e04bff..77314ab2 100644 --- a/network_scaling_heavy.py +++ b/network_scaling_heavy.py @@ -1,38 +1,12 @@ #!/usr/bin/env python3 -""" -HEAVY ContextualizedCorrelationNetworks DDP Scaling Benchmark - -This benchmark tests multi-GPU scaling with the actual ContextualizedCorrelationNetworks -model, but configured for maximum compute to properly stress-test GPU parallelism. - -Key optimizations for heavier compute: -1. Larger encoder networks (more layers, wider hidden dims) -2. More archetypes (more mixture components to learn) -3. Multiple bootstraps (ensemble of models) -4. Larger batch sizes to saturate GPU memory -5. More training epochs -6. Increased data dimensionality (more PCs) - -The goal is to make the model heavy enough that: -- Forward/backward pass takes significant time (50-200ms per batch) -- GPU compute dominates over NCCL sync overhead -- Multi-GPU scaling approaches theoretical limits (85-95% efficiency) - -Usage: - # 1-GPU baseline - python ccn_scaling_heavy.py --epochs 20 --devices 1 --label 1gpu_baseline - - # Multi-GPU with torchrun - torchrun --standalone --nproc_per_node=4 ccn_scaling_heavy.py --epochs 20 --label 4gpu_ddp -""" +# Heavy DDP scaling benchmark for ContextualizedCorrelationNetworks with cached preprocessing and compute-intensive settings (no CSV output). 
import os import time -import csv import warnings import pickle from dataclasses import dataclass -from typing import Tuple, Optional, List +from typing import Tuple, Optional import numpy as np import pandas as pd @@ -49,17 +23,15 @@ from contextualized.easy import ContextualizedCorrelationNetworks -# ================= CONFIGURATION ================= - +# Configuration BASE_DIR = os.path.dirname(os.path.abspath(__file__)) DATA_DIR = os.path.join(os.path.dirname(BASE_DIR), "data") PATH_L1000 = os.path.join(DATA_DIR, "trt_cp_smiles_qc.csv") PATH_CTLS = os.path.join(DATA_DIR, "ctrls.csv") -# INCREASED: More PCs = larger feature space = more compute -N_DATA_PCS = 100 # Was 50 -N_CONTEXT_PCS = 100 # Control profile PCs +N_DATA_PCS = 100 +N_CONTEXT_PCS = 100 PERTURBATION_HOLDOUT_SIZE = 0.2 RANDOM_STATE = 42 @@ -67,10 +39,8 @@ morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=4096) -# ================= DISTRIBUTED HELPERS ================= - +# Distributed helpers def is_global_zero() -> bool: - """Return True only on global rank 0.""" if dist.is_available() and dist.is_initialized(): try: return dist.get_rank() == 0 @@ -105,33 +75,28 @@ def print_rank0(msg: str): print(msg, flush=True) -# ================= ENVIRONMENT SETUP ================= - +# Environment setup def set_env_defaults(): - """Optimized environment for heavy CCN training.""" world_size = int(os.environ.get("WORLD_SIZE", "1")) cpu_count = os.cpu_count() or 8 threads = max(1, cpu_count // max(world_size, 1)) - + os.environ.setdefault("OMP_NUM_THREADS", str(min(threads, 4))) os.environ.setdefault("MKL_NUM_THREADS", str(min(threads, 4))) os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - - # NCCL optimizations + os.environ.setdefault("NCCL_DEBUG", "WARN") os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "1") os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") os.environ.setdefault("NCCL_ALGO", "Ring") os.environ.setdefault("NCCL_NSOCKS_PERTHREAD", "4") 
os.environ.setdefault("NCCL_SOCKET_NTHREADS", "2") - - # PyTorch optimizations + try: torch.set_float32_matmul_precision("high") - except: + except Exception: pass - - # Deterministic seeds + np.random.seed(RANDOM_STATE) torch.manual_seed(RANDOM_STATE) if torch.cuda.is_available(): @@ -141,24 +106,20 @@ def set_env_defaults(): torch.backends.cudnn.benchmark = True -# ================= FINGERPRINT HELPER ================= - +# Fingerprint helper def smiles_to_morgan_fp(smiles: str) -> np.ndarray: - """Convert SMILES to Morgan fingerprint.""" try: mol = Chem.MolFromSmiles(smiles) if mol is None: return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) fp = morgan_gen.GetFingerprint(mol) return np.array(fp, dtype=np.float32) - except: + except Exception: return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) -# ================= DATA LOADING WITH CACHE ================= - +# Data loading with optional cache def get_cache_path(subsample_fraction: Optional[float], n_data_pcs: int) -> str: - """Generate cache path based on config.""" suffix = f"_sub{subsample_fraction}" if subsample_fraction else "" suffix += f"_pcs{n_data_pcs}" return os.path.join(DATA_DIR, f"ccn_heavy_cache{suffix}.pkl") @@ -170,140 +131,135 @@ def load_and_preprocess( n_data_pcs: int = N_DATA_PCS, n_context_pcs: int = N_CONTEXT_PCS, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """ - Load and preprocess data with configurable dimensionality. - Higher dimensions = more compute in the model. 
- """ cache_path = get_cache_path(subsample_fraction, n_data_pcs) - - # Try cache + if use_cache and os.path.exists(cache_path): print_rank0(f"[DATA] Loading from cache: {cache_path}") - with open(cache_path, 'rb') as f: + with open(cache_path, "rb") as f: cached = pickle.load(f) return ( - cached['C_train'], cached['X_train_norm'], - cached['C_test'], cached['X_test_norm'], - cached['cell_ids_train'], cached['cell_ids_test'] + cached["C_train"], + cached["X_train_norm"], + cached["C_test"], + cached["X_test_norm"], + cached["cell_ids_train"], + cached["cell_ids_test"], ) - - # Wait for rank 0 to create cache + if not is_global_zero() and use_cache: wait_count = 0 while not os.path.exists(cache_path) and wait_count < 600: time.sleep(1) wait_count += 1 if os.path.exists(cache_path): - with open(cache_path, 'rb') as f: + with open(cache_path, "rb") as f: cached = pickle.load(f) return ( - cached['C_train'], cached['X_train_norm'], - cached['C_test'], cached['X_test_norm'], - cached['cell_ids_train'], cached['cell_ids_test'] + cached["C_train"], + cached["X_train_norm"], + cached["C_test"], + cached["X_test_norm"], + cached["cell_ids_train"], + cached["cell_ids_test"], ) - + print_rank0(f"[DATA] Loading L1000 from {PATH_L1000}") df = pd.read_csv(PATH_L1000, engine="pyarrow") - + df = df[df["pert_type"].isin(["trt_cp"])] - + bad = ( - (df["distil_cc_q75"] < 0.2) | - (df["distil_cc_q75"] == -666) | - (df["distil_cc_q75"].isna()) | - (df["pct_self_rank_q25"] > 5) | - (df["pct_self_rank_q25"] == -666) | - (df["pct_self_rank_q25"].isna()) + (df["distil_cc_q75"] < 0.2) + | (df["distil_cc_q75"] == -666) + | (df["distil_cc_q75"].isna()) + | (df["pct_self_rank_q25"] > 5) + | (df["pct_self_rank_q25"] == -666) + | (df["pct_self_rank_q25"].isna()) ) df = df[~bad] df = df.dropna(subset=["canonical_smiles"]) df = df[df["canonical_smiles"] != ""] - + print_rank0(f"[DATA] Samples after QC: {len(df)}") - + if subsample_fraction is not None: df = df.sample(frac=subsample_fraction, 
random_state=RANDOM_STATE) - print_rank0(f"[DATA] Subsampled to {len(df)} ({subsample_fraction*100:.1f}%)") - - # Split by perturbation + print_rank0(f"[DATA] Subsampled to {len(df)} ({subsample_fraction * 100:.1f}%)") + unique_smiles = df["canonical_smiles"].unique() print_rank0(f"[DATA] Unique perturbations: {len(unique_smiles)}") - + smiles_train, smiles_test = train_test_split( unique_smiles, test_size=PERTURBATION_HOLDOUT_SIZE, random_state=RANDOM_STATE ) - + df_train = df[df["canonical_smiles"].isin(smiles_train)].copy() df_test = df[df["canonical_smiles"].isin(smiles_test)].copy() - + print_rank0(f"[DATA] Train: {len(df_train)}, Test: {len(df_test)}") - - # Handle missing values + pert_time_mean = df_train.loc[df_train["pert_time"] != -666, "pert_time"].mean() pert_dose_mean = df_train.loc[df_train["pert_dose"] != -666, "pert_dose"].mean() - - for df_split in [df_train, df_test]: + + for df_split in (df_train, df_test): df_split["ignore_flag_pert_time"] = (df_split["pert_time"] == -666).astype(int) df_split["ignore_flag_pert_dose"] = (df_split["pert_dose"] == -666).astype(int) df_split["pert_time"] = df_split["pert_time"].replace(-666, pert_time_mean) df_split["pert_dose"] = df_split["pert_dose"].replace(-666, pert_dose_mean) - + def process_split(df_split, name): numeric_cols = df_split.select_dtypes(include=[np.number]).columns drop_cols = ["pert_dose", "pert_dose_unit", "pert_time", "distil_cc_q75", "pct_self_rank_q25"] feature_cols = [c for c in numeric_cols if c not in drop_cols] X_raw = df_split[feature_cols].values.astype(np.float32) - + print_rank0(f"[DATA] [{name}] Generating fingerprints...") fps = np.stack([smiles_to_morgan_fp(s) for s in df_split["canonical_smiles"]]) print_rank0(f"[DATA] [{name}] Fingerprint shape: {fps.shape}") - + pert_time = df_split["pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) pert_dose = df_split["pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) ign_t = 
df_split["ignore_flag_pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) ign_d = df_split["ignore_flag_pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) - + return X_raw, fps, pert_time, pert_dose, ign_t, ign_d, df_split["cell_id"].to_numpy() - + X_train_raw, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train, cells_train = process_split(df_train, "train") X_test_raw, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test, cells_test = process_split(df_test, "test") - - # Scale features + print_rank0("[DATA] Scaling gene expression...") scaler_genes = StandardScaler() X_train_scaled = scaler_genes.fit_transform(X_train_raw) X_test_scaled = scaler_genes.transform(X_test_raw) - - # Load controls + print_rank0(f"[DATA] Loading controls from {PATH_CTLS}") ctrls_df = pd.read_csv(PATH_CTLS, index_col=0) - + unique_cells = np.union1d(np.unique(cells_train), np.unique(cells_test)) ctrls_df = ctrls_df.loc[ctrls_df.index.intersection(unique_cells)] - + scaler_ctrls = StandardScaler() ctrls_scaled = scaler_ctrls.fit_transform(ctrls_df.values) - - # INCREASED: More control PCs + actual_n_ctrl_pcs = min(n_context_pcs, ctrls_scaled.shape[0], ctrls_scaled.shape[1]) print_rank0(f"[DATA] Using {actual_n_ctrl_pcs} control PCs") - + pca_ctrls = PCA(n_components=actual_n_ctrl_pcs, random_state=RANDOM_STATE) ctrls_pcs = pca_ctrls.fit_transform(ctrls_scaled) cell2vec = dict(zip(ctrls_df.index, ctrls_pcs)) - + if not cell2vec: raise ValueError("No overlapping cell IDs") - + print_rank0(f"[DATA] Control embeddings for {len(cell2vec)} cells") - - def build_context(df_split, X_scaled, morgan, pt, pd, ign_t, ign_d, name, scaler=None, fit=False): + + def build_context(df_split, X_scaled, morgan, pt, pd, ign_t, ign_d, scaler=None, fit=False): cell_ids = df_split["cell_id"].to_numpy() unique_cells_split = np.sort(df_split["cell_id"].unique()) - + all_cont = [] valid_cells = [] - + for cell_id in unique_cells_split: if cell_id not in cell2vec: continue @@ -311,83 +267,87 @@ 
def build_context(df_split, X_scaled, morgan, pt, pd, ign_t, ign_d, name, scaler if mask.sum() == 0: continue valid_cells.append(cell_id) - cont = np.hstack([ - np.tile(cell2vec[cell_id], (mask.sum(), 1)), - pt[mask], - pd[mask], - ]).astype(np.float32) + cont = np.hstack( + [ + np.tile(cell2vec[cell_id], (mask.sum(), 1)), + pt[mask], + pd[mask], + ] + ).astype(np.float32) all_cont.append(cont) - + if fit: all_cont_stacked = np.vstack(all_cont) scaler = StandardScaler() scaler.fit(all_cont_stacked) - + X_list, C_list, cid_list = [], [], [] - + for i, cell_id in enumerate(valid_cells): mask = cell_ids == cell_id X_cell = X_scaled[mask] cont_scaled = scaler.transform(all_cont[i]) - C_cell = np.hstack([ - cont_scaled, - morgan[mask], - ign_t[mask], - ign_d[mask], - ]).astype(np.float32) - + C_cell = np.hstack( + [ + cont_scaled, + morgan[mask], + ign_t[mask], + ign_d[mask], + ] + ).astype(np.float32) + X_list.append(X_cell) C_list.append(C_cell) cid_list.append(cell_ids[mask]) - + X_final = np.vstack(X_list) C_final = np.vstack(C_list) cell_ids_final = np.concatenate(cid_list) - + return X_final, C_final, cell_ids_final, scaler - + print_rank0("[DATA] Building context matrices...") X_train, C_train, cell_ids_train, ctx_scaler = build_context( - df_train, X_train_scaled, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train, "train", fit=True + df_train, X_train_scaled, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train, fit=True ) X_test, C_test, cell_ids_test, _ = build_context( - df_test, X_test_scaled, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test, "test", scaler=ctx_scaler + df_test, X_test_scaled, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test, scaler=ctx_scaler ) - + print_rank0(f"[DATA] Context shapes: C_train={C_train.shape}, C_test={C_test.shape}") - - # INCREASED: More data PCs + actual_n_data_pcs = min(n_data_pcs, X_train.shape[1], X_train.shape[0]) print_rank0(f"[DATA] Using {actual_n_data_pcs} data PCs") - + pca_data = 
PCA(n_components=actual_n_data_pcs, random_state=RANDOM_STATE) X_train_pca = pca_data.fit_transform(X_train) X_test_pca = pca_data.transform(X_test) - + pca_scaler = StandardScaler() X_train_norm = pca_scaler.fit_transform(X_train_pca).astype(np.float32) X_test_norm = pca_scaler.transform(X_test_pca).astype(np.float32) - + print_rank0(f"[DATA] Final: X_train={X_train_norm.shape}, X_test={X_test_norm.shape}") print_rank0(f"[DATA] Final: C_train={C_train.shape}, C_test={C_test.shape}") - - # Save cache + if use_cache and is_global_zero(): cache_data = { - 'C_train': C_train, 'X_train_norm': X_train_norm, - 'C_test': C_test, 'X_test_norm': X_test_norm, - 'cell_ids_train': cell_ids_train, 'cell_ids_test': cell_ids_test, + "C_train": C_train, + "X_train_norm": X_train_norm, + "C_test": C_test, + "X_test_norm": X_test_norm, + "cell_ids_train": cell_ids_train, + "cell_ids_test": cell_ids_test, } os.makedirs(os.path.dirname(cache_path), exist_ok=True) - with open(cache_path, 'wb') as f: + with open(cache_path, "wb") as f: pickle.dump(cache_data, f) print_rank0(f"[DATA] Saved cache: {cache_path}") - - return C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test + return C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test -# ================= BENCHMARK RESULT ================= +# Benchmark result @dataclass class BenchResult: label: str @@ -406,8 +366,7 @@ class BenchResult: efficiency: float = 100.0 -# ================= MAIN BENCHMARK ================= - +# Benchmark runner def run_ccn_benchmark( label: str, C_train: np.ndarray, @@ -418,7 +377,6 @@ def run_ccn_benchmark( devices: int, batch_size_per_gpu: int = 512, num_workers: int = 4, - # Heavy CCN parameters num_archetypes: int = 64, encoder_width: int = 256, encoder_layers: int = 6, @@ -426,22 +384,11 @@ def run_ccn_benchmark( warmup_epochs: int = 1, baseline_time: Optional[float] = None, ) -> BenchResult: - """ - Run ContextualizedCorrelationNetworks benchmark with heavy 
configuration. - - Key parameters for increased compute: - - num_archetypes: More mixture components (64 vs default 16) - - encoder_width: Wider encoder networks (256 vs default 25) - - encoder_layers: Deeper encoders (6 vs default 3) - - n_bootstraps: Ensemble of models (3 vs default 1) - """ - world_size = int(os.environ.get("WORLD_SIZE", "1")) rank = get_rank() local_rank = get_local_rank() launched_with_torchrun = world_size > 1 - - # Device setup + if torch.cuda.is_available() and devices > 0: accelerator = "gpu" if launched_with_torchrun: @@ -450,17 +397,15 @@ def run_ccn_benchmark( accelerator = "cpu" devices = 1 num_workers = 0 - - # Reduce workers for multi-GPU + if launched_with_torchrun and num_workers > 2: num_workers = 2 - - # Batch size: scale with GPUs for proper throughput scaling + effective_batch = batch_size_per_gpu * max(world_size, 1) - - print_rank0(f"\n{'='*70}") + + print_rank0(f"\n{'=' * 70}") print_rank0(f"[{label}] HEAVY CCN BENCHMARK") - print_rank0(f"{'='*70}") + print_rank0(f"{'=' * 70}") print_rank0(f" World size: {world_size}") print_rank0(f" Accelerator: {accelerator}") print_rank0(f" Devices: {devices}") @@ -474,19 +419,18 @@ def run_ccn_benchmark( print_rank0(f" Encoder layers: {encoder_layers}") print_rank0(f" Bootstraps: {n_bootstraps}") print_rank0(f" Data dims: C={C_train.shape[1]}, X={X_train_norm.shape[1]}") - - # Log per-process info + print( f"[{label}] [RANK {rank} / LOCAL {local_rank}] " f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}", - flush=True + flush=True, ) - - # Strategy configuration + strategy_kwarg = "auto" if accelerator == "gpu" and launched_with_torchrun and world_size > 1: try: from pytorch_lightning.strategies import DDPStrategy + strategy_kwarg = DDPStrategy( process_group_backend="nccl", find_unused_parameters=False, @@ -496,8 +440,7 @@ def run_ccn_benchmark( except Exception as e: strategy_kwarg = "ddp" print_rank0(f"[{label}] Falling back to strategy='ddp': {e}") - - # Trainer kwargs 
+ trainer_kwargs = { "max_epochs": epochs + warmup_epochs, "accelerator": accelerator, @@ -509,12 +452,10 @@ def run_ccn_benchmark( "precision": "16-mixed" if accelerator == "gpu" else 32, "strategy": strategy_kwarg, } - + print_rank0(f"[{label}] Trainer kwargs: {trainer_kwargs}") - - # Construct HEAVY CCN model print_rank0(f"[{label}] Constructing ContextualizedCorrelationNetworks...") - + ccn = ContextualizedCorrelationNetworks( encoder_type="mlp", num_archetypes=num_archetypes, @@ -524,28 +465,27 @@ def run_ccn_benchmark( "layers": encoder_layers, }, trainer_kwargs=trainer_kwargs, - es_patience=0, # No early stopping for benchmark + es_patience=0, ) - - # Estimate parameter count - # CCN params ≈ n_bootstraps × (encoder_params + archetype_params + correlation_params) - # encoder_params ≈ (context_dim × width + width × width × (layers-1) + width × archetypes) - # archetype_params ≈ archetypes × x_dim × x_dim (correlation matrices) + context_dim = C_train.shape[1] x_dim = X_train_norm.shape[1] - encoder_params = context_dim * encoder_width + encoder_width * encoder_width * (encoder_layers - 1) + encoder_width * num_archetypes + encoder_params = ( + context_dim * encoder_width + + encoder_width * encoder_width * (encoder_layers - 1) + + encoder_width * num_archetypes + ) archetype_params = num_archetypes * x_dim * x_dim total_params = n_bootstraps * (encoder_params + archetype_params) - print_rank0(f"[{label}] Estimated parameters: ~{total_params:,} ({total_params/1e6:.2f}M)") - - # Synchronize before training + print_rank0(f"[{label}] Estimated parameters: ~{total_params:,} ({total_params / 1e6:.2f}M)") + barrier() if torch.cuda.is_available(): torch.cuda.synchronize() - + print_rank0(f"[{label}] Starting training...") t0 = time.time() - + ccn.fit( C_train, X_train_norm, @@ -556,34 +496,31 @@ def run_ccn_benchmark( persistent_workers=(num_workers > 0), pin_memory=(accelerator == "gpu"), ) - - # Synchronize after training + barrier() if torch.cuda.is_available(): 
torch.cuda.synchronize() - + wall = time.time() - t0 - - # Adjust for warmup + if warmup_epochs > 0 and epochs > 0: wall_per_epoch = wall / (epochs + warmup_epochs) wall = wall_per_epoch * epochs - + print_rank0(f"[{label}] Training completed in {wall:.2f}s") - - # Metrics + n_samples = C_train.shape[0] samples_per_sec = (n_samples * epochs) / max(wall, 1e-6) - + speedup = 1.0 efficiency = 100.0 if baseline_time is not None and baseline_time > 0: speedup = baseline_time / wall efficiency = (speedup / world_size) * 100 - + train_mse = float("nan") test_mse = float("nan") - + if is_global_zero(): try: print_rank0(f"[{label}] Computing MSE...") @@ -593,7 +530,7 @@ def run_ccn_benchmark( test_mse = float(np.mean(mse_test_vec)) except Exception as e: warnings.warn(f"[{label}] measure_mses failed: {e}") - + print_rank0(f"\n[{label}] RESULTS:") print_rank0(f" Wall time: {wall:.2f}s") print_rank0(f" Samples/sec: {samples_per_sec:.1f}") @@ -602,7 +539,7 @@ def run_ccn_benchmark( if baseline_time: print_rank0(f" Speedup: {speedup:.2f}x") print_rank0(f" Efficiency: {efficiency:.1f}%") - + return BenchResult( label=label, wall_seconds=wall, @@ -621,105 +558,45 @@ def run_ccn_benchmark( ) -# ================= CSV OUTPUT ================= - -def save_results_csv(results: List[BenchResult], outdir: str): - if not is_global_zero(): - return - - os.makedirs(outdir, exist_ok=True) - path = os.path.join(outdir, "ccn_heavy_scaling_results.csv") - - write_header = not os.path.exists(path) - - with open(path, "a", newline="") as f: - writer = csv.writer(f) - if write_header: - writer.writerow([ - "label", "wall_seconds", "train_mse", "test_mse", - "num_gpus", "batch_per_gpu", "effective_batch", "samples_per_sec", - "archetypes", "encoder_width", "encoder_layers", "bootstraps", - "speedup", "efficiency" - ]) - for r in results: - writer.writerow([ - r.label, - f"{r.wall_seconds:.4f}", - f"{r.train_mse_mean:.6f}", - f"{r.test_mse_mean:.6f}", - r.num_gpus, - r.batch_size_per_gpu, - 
r.effective_batch_size, - f"{r.samples_per_second:.2f}", - r.num_archetypes, - r.encoder_width, - r.encoder_layers, - r.n_bootstraps, - f"{r.speedup:.2f}", - f"{r.efficiency:.1f}", - ]) - - print_rank0(f"\n[OUTPUT] Results appended to: {path}") - - -# ================= CLI ================= - +# CLI def parse_args(): import argparse - - ap = argparse.ArgumentParser(description="Heavy ContextualizedCorrelationNetworks Scaling Benchmark") - - # Training config + + ap = argparse.ArgumentParser(description="Heavy ContextualizedCorrelationNetworks Scaling Benchmark (no CSV output)") + ap.add_argument("--epochs", type=int, default=20) ap.add_argument("--warmup-epochs", type=int, default=1) - ap.add_argument("--batch-size", type=int, default=512, - help="Batch size per GPU") + ap.add_argument("--batch-size", type=int, default=512) ap.add_argument("--num-workers", type=int, default=4) - - # CCN architecture (HEAVY defaults) - ap.add_argument("--archetypes", type=int, default=64, - help="Number of archetypes (default: 64, original: 16)") - ap.add_argument("--encoder-width", type=int, default=256, - help="Encoder hidden width (default: 256, original: 25)") - ap.add_argument("--encoder-layers", type=int, default=6, - help="Encoder depth (default: 6, original: 3)") - ap.add_argument("--bootstraps", type=int, default=3, - help="Number of bootstrap models (default: 3, original: 1)") - - # Data config - ap.add_argument("--data-pcs", type=int, default=100, - help="Number of data PCs (default: 100, original: 50)") - ap.add_argument("--context-pcs", type=int, default=100, - help="Number of context PCs (default: 100)") + + ap.add_argument("--archetypes", type=int, default=64) + ap.add_argument("--encoder-width", type=int, default=256) + ap.add_argument("--encoder-layers", type=int, default=6) + ap.add_argument("--bootstraps", type=int, default=3) + + ap.add_argument("--data-pcs", type=int, default=100) + ap.add_argument("--context-pcs", type=int, default=100) 
ap.add_argument("--subsample-fraction", type=float, default=None) - - # Runtime config + ap.add_argument("--devices", type=int, default=1) - ap.add_argument("--outdir", type=str, default="bench_results_ccn_heavy") ap.add_argument("--label", type=str, default=None) ap.add_argument("--baseline-time", type=float, default=None) ap.add_argument("--no-cache", action="store_true") - - return ap.parse_args() + return ap.parse_args() -# ================= MAIN ================= +# Main def main(): args = parse_args() set_env_defaults() - + world_size = get_world_size() - - # Auto-generate label if not provided - if args.label: - label = args.label - else: - label = f"{world_size}gpu_ccn_heavy" - - print_rank0("\n" + "="*70) - print_rank0("HEAVY ContextualizedCorrelationNetworks SCALING BENCHMARK") - print_rank0("="*70) + label = args.label or f"{world_size}gpu_ccn_heavy" + + print_rank0("\n" + "=" * 70) + print_rank0("HEAVY ContextualizedCorrelationNetworks SCALING BENCHMARK (NO CSV)") + print_rank0("=" * 70) print_rank0(f" World size: {world_size}") print_rank0(f" Epochs: {args.epochs}") print_rank0(f" Batch size: {args.batch_size}") @@ -727,19 +604,17 @@ def main(): print_rank0(f" Encoder: {args.encoder_width}w × {args.encoder_layers}L") print_rank0(f" Bootstraps: {args.bootstraps}") print_rank0(f" Data PCs: {args.data_pcs}") - - # Load data + C_train, X_train_norm, C_test, X_test_norm, _, _ = load_and_preprocess( subsample_fraction=args.subsample_fraction, use_cache=not args.no_cache, n_data_pcs=args.data_pcs, n_context_pcs=args.context_pcs, ) - + barrier() - - # Run benchmark - result = run_ccn_benchmark( + + _ = run_ccn_benchmark( label=label, C_train=C_train, X_train_norm=X_train_norm, @@ -756,11 +631,7 @@ def main(): warmup_epochs=args.warmup_epochs, baseline_time=args.baseline_time, ) - - # Save results - if is_global_zero(): - save_results_csv([result], args.outdir) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git 
a/networks_pert_scale_bench.py b/networks_pert_scale_bench.py index 2525de0b..edd15c6b 100644 --- a/networks_pert_scale_bench.py +++ b/networks_pert_scale_bench.py @@ -1,35 +1,5 @@ #!/usr/bin/env python3 -""" -Baseline scaling benchmark for unseen_pert with true 1-GPU vs 2-GPU comparison (DDP). - -- Preprocesses L1000 + controls, building C (context) and X (features) -- Trains a simple MLP regressor C -> X - -It runs two modes in ONE command: - 1) 1 GPU -> single-process training on cuda:0 - 2) 2 GPUs -> DistributedDataParallel (DDP) with 2 processes (ranks 0 and 1), - each bound to one GPU. -D -For each mode it prints: - - wall time (seconds) - - throughput (samples / second) - - final train MSE - - final test MSE - -Outputs: - - CSV: bench_out_unseen/scale_results_unseen_ddp.csv (two rows: 1gpu, 2gpu) - -Typical usage inside a 2-GPU interactive job: - - cd /fs/scratch/PAS2942/samuel_wales_mcgrath/hpc/Contextualized - conda activate contextpert-hpc - - python unseen_pert_scale_ddp.py \ - --epochs 20 \ - --batch-size 512 \ - --num-workers 0 \ - --subsample-fraction 1.0 -""" +# Benchmark script that preprocesses unseen_pert data and compares 1-GPU training vs 2-GPU DDP training for a simple MLP regressor. 
import os import time @@ -56,8 +26,7 @@ from rdkit.Chem import rdFingerprintGenerator -# ------------------- paths & basic config ------------------- - +# Paths and basic config BASE_DIR = os.path.dirname(os.path.abspath(__file__)) DATA_DIR = os.path.join(os.path.dirname(BASE_DIR), "data") @@ -71,15 +40,12 @@ morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=4096) -# ------------------- env + seeds ------------------- - +# Environment and RNG seeding def set_env_defaults(): - """Safe CPU/GPU threading + seeds (for non-DDP parts).""" os.environ.setdefault("OMP_NUM_THREADS", "1") os.environ.setdefault("MKL_NUM_THREADS", "1") os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - # Do NOT clear MASTER_ADDR/MASTER_PORT here; DDP will need them later. try: torch.set_float32_matmul_precision("high") except Exception: @@ -92,17 +58,14 @@ def set_env_defaults(): def set_seeds(rank: int): - """Per-process seeds for DDP workers.""" np.random.seed(RANDOM_STATE + rank) torch.manual_seed(RANDOM_STATE + rank) if torch.cuda.is_available(): torch.cuda.manual_seed_all(RANDOM_STATE + rank) -# ------------------- fingerprint helper ------------------- - +# Fingerprint helper def smiles_to_morgan_fp(smiles: str) -> np.ndarray: - """Convert a SMILES string to a Morgan fingerprint (binary vector).""" try: mol = Chem.MolFromSmiles(smiles) if mol is None: @@ -116,26 +79,15 @@ def smiles_to_morgan_fp(smiles: str) -> np.ndarray: return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) -# ------------------- data preprocessing (unseen_pert) ------------------- - +# Data preprocessing for unseen_pert def load_and_preprocess( subsample_fraction: Optional[float] = None, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """ - Implements the unseen_pert preprocessing, returning: - C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test - - where: - - X_*_norm are PCA+standardized gene features - - C_* are 
context vectors (ctrl PCs + Morgan + time/dose flags) - """ print(f"Reading L1000 data from {PATH_L1000}") df = pd.read_csv(PATH_L1000, engine="pyarrow") - # Only trt_cp perturbations df = df[df["pert_type"].isin(["trt_cp"])] - # Quality filters bad = ( (df["distil_cc_q75"] < 0.2) | (df["distil_cc_q75"] == -666) @@ -146,7 +98,6 @@ def load_and_preprocess( ) df = df[~bad] - # Valid SMILES only df = df.dropna(subset=["canonical_smiles"]) df = df[df["canonical_smiles"] != ""] @@ -156,7 +107,6 @@ def load_and_preprocess( df = df.sample(frac=subsample_fraction, random_state=RANDOM_STATE) print(f"Subsampled to {len(df)} samples ({subsample_fraction * 100:.1f}% of data)") - # Perturbation holdout: split on unique SMILES unique_smiles = df["canonical_smiles"].unique() print(f"Found {len(unique_smiles)} unique perturbations (SMILES)") smiles_train, smiles_test = train_test_split( @@ -171,7 +121,6 @@ def load_and_preprocess( print(f"Perturbation split: {len(smiles_train)} train, {len(smiles_test)} test perturbations") print(f"Sample split: {len(df_train)} train, {len(df_test)} test samples") - # Handle pert_time / pert_dose missing values with -666 logic pert_time_mean = None pert_dose_mean = None @@ -221,17 +170,14 @@ def process_data_split(df_split, split_name): df_test, "test" ) - # Scale gene expression print("Scaling gene expression...") scaler_genes = StandardScaler() X_train_scaled = scaler_genes.fit_transform(X_raw_train) X_test_scaled = scaler_genes.transform(X_raw_test) - # Morgan fingerprints as float (already binary) morgan_train_scaled = morgan_train.astype(np.float32) morgan_test_scaled = morgan_test.astype(np.float32) - # Load controls print(f"Reading control profiles from {PATH_CTLS}") ctrls_df = pd.read_csv(PATH_CTLS, index_col=0) @@ -356,7 +302,6 @@ def build_context_matrix( print(f"C_train: {C_train.shape}, X_train: {X_train.shape}") print(f"C_test: {C_test.shape}, X_test: {X_test.shape}") - # PCA on X then scale print("PCA + scaling on gene features...") 
pca_data = PCA(n_components=N_DATA_PCS, random_state=RANDOM_STATE) X_train_pca = pca_data.fit_transform(X_train) @@ -371,7 +316,6 @@ def build_context_matrix( return C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test - @dataclass class BenchResult: label: str @@ -404,7 +348,6 @@ def run_single_gpu( num_workers: int, subsample_fraction: Optional[float], ) -> BenchResult: - """Single-process, single-GPU training on cuda:0 (or CPU).""" label = "1gpu_single" print("\n================ 1-GPU baseline (single process) ================") @@ -481,7 +424,6 @@ def run_single_gpu( samples_total = n_samples * epochs throughput = samples_total / max(wall, 1e-9) - # Evaluation def eval_mse(loader, split_name: str) -> float: model.eval() total_loss = 0.0 @@ -529,13 +471,8 @@ def ddp_worker( subsample_fraction: Optional[float], result_dict, ): - """ - DDP worker function run by each spawned process (rank 0 and 1). - We only record metrics in rank 0 and put them in result_dict["2gpu_ddp"]. - """ set_seeds(rank) - # Device mapping: assume 2 GPUs visible, use local index = rank if torch.cuda.is_available(): torch.cuda.set_device(rank) device = torch.device(f"cuda:{rank}") @@ -557,7 +494,6 @@ def ddp_worker( if torch.cuda.is_available(): print(f"[{label}] Using GPUs 0 and 1 with DDP") - # IMPORTANT: we measure only training time; data loading can be duplicated. C_train, X_train_norm, C_test, X_test_norm, _, _ = load_and_preprocess( subsample_fraction=subsample_fraction ) @@ -586,7 +522,6 @@ def ddp_worker( pin_memory=torch.cuda.is_available(), ) - # test_loader will only be used on rank 0 after training (non-distributed). 
in_dim = C_train.shape[1] out_dim = X_train_norm.shape[1] @@ -603,13 +538,11 @@ def ddp_worker( n_samples = C_train.shape[0] - # Synchronize before timing dist.barrier() if torch.cuda.is_available(): torch.cuda.synchronize() t0 = time.time() - # Training loop for epoch in range(epochs): ddp_model.train() train_sampler.set_epoch(epoch) @@ -631,7 +564,6 @@ def ddp_worker( running_loss += loss.item() * bsz count_seen += bsz - # Aggregate epoch loss to rank 0 (average over all samples) loss_tensor = torch.tensor([running_loss, count_seen], dtype=torch.float64, device=device) dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM) if rank == 0: @@ -639,15 +571,12 @@ def ddp_worker( epoch_loss = total_loss / max(total_count, 1.0) print(f"[{label}] Epoch {epoch+1}/{epochs} - train MSE {epoch_loss:.6f}") - # Synchronize end of training dist.barrier() if torch.cuda.is_available(): torch.cuda.synchronize() wall = time.time() - t0 - # Only rank 0 computes evaluation, using full dataset on its GPU if rank == 0: - # For evaluation we use the underlying model (not wrapped in DDP) eval_model = ddp_model.module eval_model.eval() @@ -704,12 +633,10 @@ def eval_mse(loader, split_name: str) -> float: test_mse_mean=test_mse, ) - # Tear down process group dist.destroy_process_group() -# ------------------- CSV writer ------------------- - +# CSV writer def save_results_csv(results: List[BenchResult], outdir: str): os.makedirs(outdir, exist_ok=True) path = os.path.join(outdir, "scale_results_unseen_ddp.csv") @@ -739,8 +666,7 @@ def save_results_csv(results: List[BenchResult], outdir: str): print(f"\nSaved CSV → {path}") -# ------------------- CLI & main ------------------- - +# CLI and main def parse_args(): import argparse @@ -780,7 +706,6 @@ def main(): results: List[BenchResult] = [] - # 1-GPU baseline res_1gpu = run_single_gpu( epochs=args.epochs, batch_size=args.batch_size, @@ -789,10 +714,9 @@ def main(): ) results.append(res_1gpu) - # 2-GPU DDP baseline if torch.cuda.is_available() 
and torch.cuda.device_count() >= 2: world_size = 2 - port = args.ddp_port # use TCP init on localhost + port = args.ddp_port manager = mp.Manager() result_dict = manager.dict() @@ -823,4 +747,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/scale_bench.py b/scale_bench.py index 29937bac..4e5e96b1 100644 --- a/scale_bench.py +++ b/scale_bench.py @@ -1,135 +1,48 @@ #!/usr/bin/env python3 -""" -scale_bench.py - -A single-node, torchrun-friendly DDP scaling benchmark for ContextualizedRegression. - -Design goals (to reveal true scaling): - - Fixed number of optimizer steps (not epochs) so each run does identical work. - - Optional GPU-resident synthetic dataset to remove CPU dataloading/transfer bottlenecks. - - Measures only the *steady-state* region (warmup steps excluded). - - Uses Lightning DDP under torchrun correctly (devices=1 per process). - ------------------------------------------------------------- -Quick start (single node, 1..4 GPUs) ------------------------------------------------------------- - -# 0) See NICs (optional) -ls -1 /sys/class/net -ip -o link show | awk -F': ' '{print NR-1": "$2}' - -# 1) Minimal, safe NCCL/torch env (no hard-coded eth0): -export CUDA_VISIBLE_DEVICES=0,1,2,3 -export OMP_NUM_THREADS=1 -export MKL_NUM_THREADS=1 -export TOKENIZERS_PARALLELISM=false -export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 -export NCCL_DEBUG=WARN -export NCCL_P2P_DISABLE=0 -export NCCL_IB_DISABLE=1 -export NCCL_SOCKET_IFNAME=$(ls /sys/class/net | grep -E '^(ens|enp|eno|eth|bond|ib)' | head -n1) -[ -z "$NCCL_SOCKET_IFNAME" ] && export NCCL_SOCKET_IFNAME="^lo,docker0" - -# CUDA allocator tweak (fine to keep) -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - -# 2) Kill any stragglers (optional) -pkill -f scale_bench.py || true -pkill -f torchrun || true - -# 3) Runs (IMPORTANT: --batch-size is PER GPU) -# Suggested defaults: steps=400 warmup=50 (steady state measured steps=400) - -torchrun --standalone 
--nproc_per_node=1 scale_bench.py \ - --steps 400 --warmup-steps 50 \ - --batch-size 2048 --precision bf16 \ - --context-dim 16 --x-dim 512 --y-dim 64 \ - --width 1024 --layers 4 \ - --buffer-batches 32 --data-device auto \ - --outdir bench_out/gpu1 - -torchrun --standalone --nproc_per_node=2 scale_bench.py \ - --steps 400 --warmup-steps 50 \ - --batch-size 2048 --precision bf16 \ - --context-dim 16 --x-dim 512 --y-dim 64 \ - --width 1024 --layers 4 \ - --buffer-batches 32 --data-device auto \ - --outdir bench_out/gpu2 - -torchrun --standalone --nproc_per_node=3 scale_bench.py \ - --steps 400 --warmup-steps 50 \ - --batch-size 2048 --precision bf16 \ - --context-dim 16 --x-dim 512 --y-dim 64 \ - --width 1024 --layers 4 \ - --buffer-batches 32 --data-device auto \ - --outdir bench_out/gpu3 - -torchrun --standalone --nproc_per_node=4 scale_bench.py \ - --steps 400 --warmup-steps 50 \ - --batch-size 2048 --precision bf16 \ - --context-dim 16 --x-dim 512 --y-dim 64 \ - --width 1024 --layers 4 \ - --buffer-batches 32 --data-device auto \ - --outdir bench_out/gpu4 - -Notes: - - If scaling is still poor with this benchmark, it is very likely a *real* bottleneck - (GPU interconnect/topology, NCCL config, too-small batch, CPU frequency limits, etc.), - not a dataloader artifact. -""" +# Single-node strong-scaling benchmark runner for ContextualizedRegression using synthetic batched data. 
import os import time import json -import math import argparse from dataclasses import dataclass from datetime import timedelta -from typing import Dict, Optional +from typing import Any, Optional, Tuple import numpy as np import torch +import torch.nn.functional as F import pytorch_lightning as pl +from torch.utils.data import IterableDataset, DataLoader from pytorch_lightning.callbacks import Callback from pytorch_lightning.strategies import DDPStrategy -# ---- your package pieces ---- from contextualized.regression import ContextualizedRegression -from contextualized.regression.datamodules import ContextualizedRegressionDataModule -# ---------------- launcher/cluster helpers ---------------- +# Torchrun helpers def under_torchrun() -> bool: e = os.environ return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) def world_size() -> int: - try: - return int(os.environ.get("WORLD_SIZE", "1")) - except Exception: - return 1 + return int(os.environ.get("WORLD_SIZE", "1")) def global_rank() -> int: - try: - return int(os.environ.get("RANK", "0")) - except Exception: - return 0 + return int(os.environ.get("RANK", "0")) def local_rank() -> int: - try: - return int(os.environ.get("LOCAL_RANK", "0")) - except Exception: - return 0 + return int(os.environ.get("LOCAL_RANK", "0")) def is_global_zero() -> bool: return global_rank() == 0 -# ---------------- env + perf ---------------- +# Environment defaults def set_env_defaults(): os.environ.setdefault("OMP_NUM_THREADS", "1") os.environ.setdefault("MKL_NUM_THREADS", "1") @@ -142,17 +55,12 @@ def set_env_defaults(): if "NCCL_SOCKET_IFNAME" not in os.environ: try: - ifaces = [ - d - for d in os.listdir("/sys/class/net") - if os.path.isdir(f"/sys/class/net/{d}") - ] - cand = next((i for i in ifaces if i not in ("lo", "docker0")), None) + ifaces = [d for d in os.listdir("/sys/class/net") if d not in ("lo", "docker0")] + cand = next((i for i in ifaces if i.startswith(("ens", "enp", "eno", "eth", "bond", "ib"))), None) 
os.environ["NCCL_SOCKET_IFNAME"] = cand or "^lo,docker0" except Exception: os.environ["NCCL_SOCKET_IFNAME"] = "^lo,docker0" - # TF32 / matmul speedups (safe for benchmarking throughput) if torch.cuda.is_available(): try: torch.backends.cuda.matmul.allow_tf32 = True @@ -168,7 +76,6 @@ def set_env_defaults(): pass if under_torchrun() and torch.cuda.is_available(): - # Ensures each rank uses its intended GPU even if something upstream is odd. try: torch.cuda.set_device(local_rank()) except Exception: @@ -203,57 +110,44 @@ def map_precision(p: str): return 32 -# ---------------- timing ---------------- +# Timing callback class SteadyStateStepTimer(Callback): - """ - Times optimizer steps in a steady-state window: - - ignore first warmup_steps - - measure next measure_steps - - Assumes accumulate_grad_batches == 1. - """ - def __init__(self, warmup_steps: int, measure_steps: int): super().__init__() self.warmup_steps = int(warmup_steps) self.measure_steps = int(measure_steps) - self._seen_steps = 0 + self._seen = 0 self.step_times = [] - self._step_start_t = None + self._t0 = None @staticmethod - def _sync_if_cuda(): + def _sync(): if torch.cuda.is_available(): torch.cuda.synchronize() def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - s = self._seen_steps - if self.warmup_steps <= s < (self.warmup_steps + self.measure_steps): - self._sync_if_cuda() - self._step_start_t = time.time() + s = self._seen + if self.warmup_steps <= s < self.warmup_steps + self.measure_steps: + self._sync() + self._t0 = time.time() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - s = self._seen_steps - if self.warmup_steps <= s < (self.warmup_steps + self.measure_steps): - self._sync_if_cuda() - dt = time.time() - (self._step_start_t or time.time()) - self.step_times.append(dt) + s = self._seen + if self.warmup_steps <= s < self.warmup_steps + self.measure_steps: + self._sync() + self.step_times.append(time.time() - (self._t0 or time.time())) 
+ self._seen += 1 - self._seen_steps += 1 - - def measured_wall_time(self) -> float: + def measured_wall(self) -> float: return float(sum(self.step_times)) def dist_max(value: float) -> float: - """ - Returns max(value across ranks) if distributed is initialized; else returns value. - """ try: import torch.distributed as dist if dist.is_available() and dist.is_initialized(): - t = torch.tensor([value], device="cuda" if torch.cuda.is_available() else "cpu") + t = torch.tensor([value], device="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float64) dist.all_reduce(t, op=dist.ReduceOp.MAX) return float(t.item()) except Exception: @@ -261,84 +155,154 @@ def dist_max(value: float) -> float: return float(value) -# ---------------- synthetic data ---------------- -def make_synthetic_tensors( - n: int, - c_dim: int, - x_dim: int, - y_dim: int, - device: torch.device, - seed: int, -) -> Dict[str, torch.Tensor]: - """ - Generates a fixed buffer of synthetic data. - - IMPORTANT: This runs once before timing begins. Keep n reasonable. - """ - g = torch.Generator(device=device) - g.manual_seed(int(seed) + 1000 * global_rank()) - - C = torch.randn((n, c_dim), generator=g, device=device, dtype=torch.float32) - X = torch.randn((n, x_dim), generator=g, device=device, dtype=torch.float32) - Y = torch.randn((n, y_dim), generator=g, device=device, dtype=torch.float32) - return {"C": C, "X": X, "Y": Y} - - -# ---------------- model/trainer/datamodule ---------------- -def build_model(args) -> ContextualizedRegression: - # Uses your current link_fn handling (string keys are valid). 
- return ContextualizedRegression( - context_dim=args.context_dim, - x_dim=args.x_dim, - y_dim=args.y_dim, - num_archetypes=args.num_archetypes, - encoder_type=args.encoder_type, - encoder_kwargs={"width": args.width, "layers": args.layers, "link_fn": "identity"}, - learning_rate=args.lr, - fit_intercept=True, - link_fn="identity", - loss_fn="mse", - model_regularizer="none", - ) +# Synthetic batched iterable +class SyntheticBatchStream(IterableDataset): + def __init__( + self, + batch_size: int, + c_dim: int, + x_dim: int, + y_dim: int, + buffer_batches: int, + buffer_mult: int, + seed: int, + pin: bool, + ): + super().__init__() + self.batch_size = int(batch_size) + self.c_dim = int(c_dim) + self.x_dim = int(x_dim) + self.y_dim = int(y_dim) + + self.n_batches = int(buffer_batches) * int(buffer_mult) + if self.n_batches <= 0: + raise ValueError("buffer_batches * buffer_mult must be >= 1") + + g = torch.Generator(device="cpu") + g.manual_seed(int(seed) + 1000 * global_rank()) + + self.C = torch.randn((self.n_batches, self.batch_size, self.c_dim), generator=g, device="cpu", dtype=torch.float32) + self.X = torch.randn((self.n_batches, self.batch_size, self.x_dim), generator=g, device="cpu", dtype=torch.float32) + self.Y = torch.randn((self.n_batches, self.batch_size, self.y_dim), generator=g, device="cpu", dtype=torch.float32) + + if pin and torch.cuda.is_available(): + self.C = self.C.pin_memory() + self.X = self.X.pin_memory() + self.Y = self.Y.pin_memory() + + def __iter__(self): + ws = world_size() + r = global_rank() + k = 0 + while True: + b = (k * ws + r) % self.n_batches + yield {"contexts": self.C[b], "predictors": self.X[b], "outcomes": self.Y[b]} + k += 1 + + +def _as_2d(t: torch.Tensor) -> torch.Tensor: + # Accept [B, y, 1] or [B, 1, y] and squeeze the singleton dim + if t.ndim == 3: + if t.shape[-1] == 1: + # Convert [B, y, 1] -> [B, y] + t = t.squeeze(-1) + elif t.shape[1] == 1: + # Convert [B, 1, y] -> [B, y] + t = t.squeeze(1) + if t.ndim == 1: + 
return t.unsqueeze(-1) + if t.ndim == 2: + return t + raise RuntimeError(f"Expected 1D or 2D tensor (or squeezable 3D), got shape {tuple(t.shape)}") + + +def _canonicalize_y(y: torch.Tensor, B: int, y_dim: int, name: str) -> torch.Tensor: + y = _as_2d(y) + if y.shape == (B, y_dim): + return y + if y.shape == (y_dim, B): + return y.transpose(0, 1) + if y_dim == 1 and y.shape == (B,): + return y.view(B, 1) + raise RuntimeError(f"{name} has incompatible shape {tuple(y.shape)}; expected [{B},{y_dim}] or [{y_dim},{B}].") + + +def _extract_mu_hat(out: Any) -> torch.Tensor: + # Prefer mu_hat as y_pred for this benchmark + if torch.is_tensor(out): + return out + + if isinstance(out, dict): + for k in ("mu_hat", "mu", "y_pred", "y_hat", "pred"): + if k in out and torch.is_tensor(out[k]): + return out[k] + raise RuntimeError(f"Forward returned dict without mu_hat/y_hat keys: {list(out.keys())}") + + if isinstance(out, (tuple, list)): + tensors = [t for t in out if torch.is_tensor(t)] + if len(tensors) >= 2: + return tensors[1] + if len(tensors) == 1: + return tensors[0] + raise RuntimeError("Forward returned tuple/list with no tensors.") + + raise RuntimeError(f"Unsupported forward output type: {type(out)}") + + +# Lightning bench module +class BenchModule(pl.LightningModule): + def __init__(self, inner: ContextualizedRegression, lr: float, y_dim: int): + super().__init__() + self.inner = inner + self.lr = float(lr) + self.y_dim = int(y_dim) + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.lr) -def build_dm(args, C, X, Y) -> ContextualizedRegressionDataModule: - n = int(C.shape[0]) - # Simple split; validation never runs in this benchmark (we pass only train_dataloader). 
- n_train = int(0.95 * n) - train_idx = np.arange(0, n_train, dtype=np.int64) - val_idx = np.arange(n_train, n, dtype=np.int64) - - dm = ContextualizedRegressionDataModule( - C=C, - X=X, - Y=Y, - task_type="singletask_multivariate", - train_idx=train_idx, - val_idx=val_idx, - test_idx=None, - predict_idx=None, - train_batch_size=args.batch_size, - val_batch_size=args.batch_size, - test_batch_size=args.batch_size, - predict_batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=bool(args.pin_memory), - persistent_workers=bool(args.num_workers > 0), - drop_last=True, - shuffle_train=False, # let Lightning/DDP sampler handle partitioning; shuffle not needed for perf - shuffle_eval=False, - dtype=torch.float, - ) - dm.prepare_data() - dm.setup() - return dm + def training_step(self, batch, batch_idx): + device = self.device + + C = batch["contexts"].to(device, non_blocking=True) + X = batch["predictors"].to(device, non_blocking=True) + Y = batch["outcomes"].to(device, non_blocking=True) + + B = C.shape[0] + Y_true = _canonicalize_y(Y, B, self.y_dim, "Y_true") + + # Prefer calling with dict to match internal conventions + out = self.inner({"contexts": C, "predictors": X, "outcomes": Y_true}) + + mu_hat = _extract_mu_hat(out) + Y_pred = _canonicalize_y(mu_hat, B, self.y_dim, "Y_pred(mu_hat)") + loss = F.mse_loss(Y_pred, Y_true) + return loss + +# Batch sizing +def resolve_batch_sizes(args, ws: int) -> Tuple[int, int]: + if args.global_batch_size is None: + per_gpu = int(args.batch_size) + return per_gpu, per_gpu * ws + gbs = int(args.global_batch_size) + if gbs % ws != 0: + raise ValueError(f"--global-batch-size {gbs} must be divisible by world_size {ws}") + return gbs // ws, gbs + + +# Trainer def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: - if torch.cuda.is_available(): + use_cuda = torch.cuda.is_available() and (args.run_device != "cpu") + + if use_cuda: accelerator = "gpu" - devices = 1 if under_torchrun() else min(args.devices, 
torch.cuda.device_count()) + + if under_torchrun(): + devices = 1 + else: + devices = min(int(args.devices), torch.cuda.device_count()) + strategy = ( DDPStrategy( find_unused_parameters=False, @@ -354,16 +318,15 @@ def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: devices = 1 strategy = "auto" - # We benchmark *steps*, not epochs. - max_steps = args.warmup_steps + args.steps + max_steps = int(args.warmup_steps) + int(args.steps) - trainer = pl.Trainer( + return pl.Trainer( accelerator=accelerator, devices=devices, strategy=strategy, precision=map_precision(args.precision), max_steps=max_steps, - max_epochs=10_000, # irrelevant when max_steps is set + max_epochs=10_000, logger=False, enable_checkpointing=False, enable_progress_bar=False, @@ -371,16 +334,13 @@ def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: log_every_n_steps=50, callbacks=[timer], inference_mode=False, - detect_anomaly=False, enable_model_summary=False, - use_distributed_sampler=True, accumulate_grad_batches=1, limit_val_batches=0, ) - return trainer -# ---------------- benchmark runner ---------------- +# Results @dataclass class Result: world_size: int @@ -395,41 +355,56 @@ class Result: p95_step_s: float +def save_result(outdir: str, res: Result): + os.makedirs(outdir, exist_ok=True) + path = os.path.join(outdir, "result.json") + with open(path, "w") as f: + json.dump(res.__dict__, f, indent=2) + return path + + +# Main bench def run_bench(args) -> Result: ws = world_size() if under_torchrun() else int(args.devices) - dev = torch.device("cuda", local_rank()) if (args.data_device == "cuda" and torch.cuda.is_available()) else torch.device("cpu") + per_gpu_bs, global_bs = resolve_batch_sizes(args, ws) - # If auto: keep data on GPU when available (this removes input bottlenecks). 
- if args.data_device == "auto": - if torch.cuda.is_available(): - dev = torch.device("cuda", local_rank()) - else: - dev = torch.device("cpu") - - # Dataloader workers cannot safely handle CUDA tensors. - if dev.type == "cuda" and args.num_workers != 0: - if is_global_zero(): - print("NOTE: forcing --num-workers=0 because data-device is CUDA.") - args.num_workers = 0 - - # Build fixed synthetic buffer (not timed) - n = int(args.batch_size * args.buffer_batches) - tensors = make_synthetic_tensors( - n=n, + pin = args.data_device == "cpu_pinned" + + ds = SyntheticBatchStream( + batch_size=per_gpu_bs, c_dim=args.context_dim, x_dim=args.x_dim, y_dim=args.y_dim, - device=dev, + buffer_batches=args.buffer_batches, + buffer_mult=args.buffer_mult, seed=args.seed, + pin=pin, + ) + + dl = DataLoader(ds, batch_size=None, num_workers=0, pin_memory=False) + + inner = ContextualizedRegression( + context_dim=args.context_dim, + x_dim=args.x_dim, + y_dim=args.y_dim, + num_archetypes=args.num_archetypes, + encoder_type=args.encoder_type, + encoder_kwargs={"width": args.width, "layers": args.layers, "link_fn": "identity"}, + learning_rate=args.lr, + fit_intercept=True, + link_fn="identity", + loss_fn="mse", + model_regularizer="none", ) - dm = build_dm(args, tensors["C"], tensors["X"], tensors["Y"]) - model = build_model(args) + model = BenchModule(inner=inner, lr=args.lr, y_dim=args.y_dim) timer = SteadyStateStepTimer(args.warmup_steps, args.steps) trainer = build_trainer(args, timer) if is_global_zero(): + buffer_batches_total = int(args.buffer_batches) * int(args.buffer_mult) + buffer_samples_per_rank = int(per_gpu_bs) * buffer_batches_total print( "\nConfig:", json.dumps( @@ -437,64 +412,57 @@ def run_bench(args) -> Result: "torchrun": under_torchrun(), "world_size": ws, "local_rank": local_rank(), - "batch_size_per_gpu": args.batch_size, - "global_batch_size": args.batch_size * ws, - "steps_measured": args.steps, - "steps_warmup": args.warmup_steps, - "buffer_samples": n, - 
"data_device": str(dev), + "batch_size_per_gpu": per_gpu_bs, + "global_batch_size": global_bs, + "steps_measured": int(args.steps), + "steps_warmup": int(args.warmup_steps), + "buffer_batches_total": buffer_batches_total, + "buffer_samples_per_rank": buffer_samples_per_rank, + "buffer_samples_global_approx": buffer_samples_per_rank * int(ws), + "run_device": args.run_device, + "data_device": args.data_device, + "pin_memory": pin, "precision": map_precision(args.precision), }, indent=2, ), ) - trainer.fit(model, train_dataloaders=dm.train_dataloader()) + trainer.fit(model, train_dataloaders=dl) - measured_wall = timer.measured_wall_time() - measured_wall = dist_max(measured_wall) # slowest rank dictates wall time + measured_wall = dist_max(timer.measured_wall()) measured_steps = int(args.steps) - global_batch = int(args.batch_size * ws) - samples_total = global_batch * measured_steps + samples_total = global_bs * measured_steps throughput = samples_total / max(measured_wall, 1e-12) - per_gpu = throughput / max(ws, 1) + per_gpu_thr = throughput / max(ws, 1) step_times = timer.step_times[:] if timer.step_times else [float("nan")] avg_step = float(np.mean(step_times)) p95_step = float(np.percentile(step_times, 95)) if len(step_times) > 1 else float("nan") return Result( - world_size=ws, - batch_size_per_gpu=int(args.batch_size), - global_batch_size=int(global_batch), + world_size=int(ws), + batch_size_per_gpu=int(per_gpu_bs), + global_batch_size=int(global_bs), warmup_steps=int(args.warmup_steps), measured_steps=int(measured_steps), measured_wall_s=float(measured_wall), throughput_samples_per_s=float(throughput), - per_gpu_throughput_samples_per_s=float(per_gpu), + per_gpu_throughput_samples_per_s=float(per_gpu_thr), avg_step_s=float(avg_step), p95_step_s=float(p95_step), ) -def save_result(outdir: str, res: Result): - os.makedirs(outdir, exist_ok=True) - path = os.path.join(outdir, "result.json") - with open(path, "w") as f: - json.dump(res.__dict__, f, indent=2) - 
return path - - -# ---------------- main ---------------- def parse_args(): ap = argparse.ArgumentParser() - ap.add_argument("--steps", type=int, default=400, help="Measured optimizer steps") - ap.add_argument("--warmup-steps", type=int, default=50, help="Warmup steps excluded from timing") + ap.add_argument("--steps", type=int, default=400) + ap.add_argument("--warmup-steps", type=int, default=50) + + ap.add_argument("--batch-size", type=int, default=2048, help="Per-GPU batch size (ignored if --global-batch-size set)") + ap.add_argument("--global-batch-size", type=int, default=None, help="Fixed global batch for strong scaling") - ap.add_argument("--batch-size", type=int, default=2048, help="Per-GPU batch size") - ap.add_argument("--num-workers", type=int, default=0) - ap.add_argument("--pin-memory", action="store_true", default=False) ap.add_argument("--precision", type=str, default="bf16") ap.add_argument("--context-dim", type=int, default=16) @@ -507,9 +475,12 @@ def parse_args(): ap.add_argument("--layers", type=int, default=4) ap.add_argument("--lr", type=float, default=1e-3) - ap.add_argument("--buffer-batches", type=int, default=32, help="Dataset buffer size = batch_size * buffer_batches") - ap.add_argument("--data-device", type=str, choices=["auto", "cpu", "cuda"], default="auto") - ap.add_argument("--devices", type=int, default=1, help="Only used when NOT under torchrun") + ap.add_argument("--buffer-batches", type=int, default=16, help="Buffer depth in batches (per rank)") + ap.add_argument("--buffer-mult", type=int, default=4, help="Extra multiplier on buffer size (per rank)") + + ap.add_argument("--data-device", choices=["cpu", "cpu_pinned"], default="cpu_pinned") + ap.add_argument("--run-device", choices=["auto", "cpu"], default="auto") + ap.add_argument("--devices", type=int, default=1, help="Used only when NOT under torchrun") ap.add_argument("--ddp-timeout", type=int, default=180) ap.add_argument("--seed", type=int, default=123) @@ -522,17 +493,14 @@ 
def main(): set_env_defaults() args = parse_args() - if args.data_device == "cpu": - os.environ["CUDA_VISIBLE_DEVICES"] = "" # ensure no accidental CUDA use + if args.run_device == "cpu": + os.environ["CUDA_VISIBLE_DEVICES"] = "" res = run_bench(args) if is_global_zero(): path = save_result(args.outdir, res) - print( - "\nResult:", - json.dumps(res.__dict__, indent=2), - ) + print("\nResult:", json.dumps(res.__dict__, indent=2)) print(f"\nSaved → {path}") diff --git a/scale_bench_networks.py b/scale_bench_networks.py index 2c6125a4..286c2edf 100644 --- a/scale_bench_networks.py +++ b/scale_bench_networks.py @@ -1,82 +1,5 @@ #!/usr/bin/env python3 -""" -scale_bench_networks.py - -A torchrun-friendly DDP scaling benchmark for Contextualized *Networks* lightning modules -(e.g., ContextualizedCorrelation, ContextualizedMarkovGraph, NOTMAD). - -Design goals (to reveal true scaling): - - Fixed number of optimizer steps (not epochs) so each run does identical work. - - Optional GPU-resident synthetic dataset to remove CPU dataloading/transfer bottlenecks. - - Measures only the *steady-state* region (warmup steps excluded). - - Uses Lightning DDP under torchrun correctly (devices=1 per process). - - No validation, no logging, no checkpoints. 
- ------------------------------------------------------------- -Quick start (single node, 1..4 GPUs) ------------------------------------------------------------- - -# 0) NICs (optional) -ls -1 /sys/class/net -ip -o link show | awk -F': ' '{print NR-1": "$2}' - -# 1) Minimal, safe NCCL/torch env (no hard-coded eth0): -export CUDA_VISIBLE_DEVICES=0,1,2,3 -export OMP_NUM_THREADS=1 -export MKL_NUM_THREADS=1 -export TOKENIZERS_PARALLELISM=false -export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 -export NCCL_DEBUG=WARN -export NCCL_P2P_DISABLE=0 -export NCCL_IB_DISABLE=1 -export NCCL_SOCKET_IFNAME=$(ls /sys/class/net | grep -E '^(ens|enp|eno|eth|bond|ib)' | head -n1) -[ -z "$NCCL_SOCKET_IFNAME" ] && export NCCL_SOCKET_IFNAME="^lo,docker0" - -export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True - -# 2) Kill any stragglers (optional) -pkill -f scale_bench_networks.py || true -pkill -f torchrun || true - -# 3) Runs (IMPORTANT: --batch-size is PER GPU) - -# Correlation networks -torchrun --standalone --nproc_per_node=1 scale_bench_networks.py \ - --network correlation \ - --steps 400 --warmup-steps 50 \ - --batch-size 2048 --precision bf16 \ - --context-dim 16 --x-dim 512 \ - --encoder-type mlp --width 1024 --layers 4 \ - --num-archetypes 8 \ - --buffer-batches 32 --data-device auto \ - --outdir bench_out/corr_gpu1 - -torchrun --standalone --nproc_per_node=2 scale_bench_networks.py \ - --network correlation \ - --steps 400 --warmup-steps 50 \ - --batch-size 2048 --precision bf16 \ - --context-dim 16 --x-dim 512 \ - --encoder-type mlp --width 1024 --layers 4 \ - --num-archetypes 8 \ - --buffer-batches 32 --data-device auto \ - --outdir bench_out/corr_gpu2 - -# Markov networks (precision matrices) -torchrun --standalone --nproc_per_node=4 scale_bench_networks.py \ - --network markov \ - --steps 400 --warmup-steps 50 \ - --batch-size 1024 --precision bf16 \ - --context-dim 16 --x-dim 256 \ - --encoder-type mlp --width 512 --layers 3 \ - --num-archetypes 8 \ - --buffer-batches 32 
--data-device auto \ - --outdir bench_out/markov_gpu4 - -Notes: - - If scaling is poor with --data-device=cuda (or auto on GPU), the bottleneck is - likely *real* (NCCL/topology/comm, too-small batch, CPU freq limits, etc.). - - Multi-node: remove --standalone and use --nnodes/--node_rank with a shared rdzv endpoint. -""" +# Torchrun-friendly DDP scaling benchmark for Contextualized network LightningModules using synthetic buffered data. import os import time @@ -93,7 +16,6 @@ from pytorch_lightning.callbacks import Callback from pytorch_lightning.strategies import DDPStrategy -# ---- your package pieces ---- from contextualized.regression.datamodules import ContextualizedRegressionDataModule from contextualized.regression.lightning_modules import ( ContextualizedCorrelation, @@ -102,7 +24,7 @@ from contextualized.dags.lightning_modules import NOTMAD -# ---------------- launcher/cluster helpers ---------------- +# Launcher/cluster helpers def under_torchrun() -> bool: e = os.environ return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) @@ -133,7 +55,7 @@ def is_global_zero() -> bool: return global_rank() == 0 -# ---------------- env + perf ---------------- +# Environment defaults and GPU performance flags def set_env_defaults(): os.environ.setdefault("OMP_NUM_THREADS", "1") os.environ.setdefault("MKL_NUM_THREADS", "1") @@ -156,7 +78,7 @@ def set_env_defaults(): except Exception: os.environ["NCCL_SOCKET_IFNAME"] = "^lo,docker0" - # TF32 / matmul speedups (safe for throughput benchmarking) + # TF32 / matmul speedups for throughput benchmarking if torch.cuda.is_available(): try: torch.backends.cuda.matmul.allow_tf32 = True @@ -206,15 +128,8 @@ def map_precision(p: str): return 32 -# ---------------- timing ---------------- +# Timing callback for steady-state step timing class SteadyStateStepTimer(Callback): - """ - Times optimizer steps in a steady-state window: - - ignore first warmup_steps - - measure next measure_steps - Assumes accumulate_grad_batches 
== 1. - """ - def __init__(self, warmup_steps: int, measure_steps: int): super().__init__() self.warmup_steps = int(warmup_steps) @@ -247,9 +162,7 @@ def measured_wall_time(self) -> float: def dist_max(value: float) -> float: - """ - Returns max(value across ranks) if distributed is initialized; else returns value. - """ + # Return max(value across ranks) if distributed is initialized try: import torch.distributed as dist @@ -264,7 +177,7 @@ def dist_max(value: float) -> float: return float(value) -# ---------------- synthetic data ---------------- +# Synthetic buffer construction def make_synthetic_tensors( n: int, c_dim: int, @@ -272,14 +185,8 @@ def make_synthetic_tensors( device: torch.device, seed: int, ) -> Dict[str, torch.Tensor]: - """ - Builds a fixed synthetic buffer (not timed). Shapes: - C: (n, c_dim) - X: (n, x_dim) - Y: (n, x_dim) # for networks we follow the wrapper convention (univariate task uses y_dim=x_dim) - """ + # Build a fixed synthetic buffer with per-rank seeding g = torch.Generator(device=device) - # Per-rank seed to avoid identical data, while keeping identical shapes across ranks. g.manual_seed(int(seed) + 1000 * global_rank()) C = torch.randn((n, c_dim), generator=g, device=device, dtype=torch.float32) @@ -288,14 +195,9 @@ def make_synthetic_tensors( return {"C": C, "X": X, "Y": Y} -# ---------------- model/datamodule/trainer ---------------- +# Model/datamodule/trainer builders def build_model(args): - """ - Robustly instantiate the selected network LightningModule. - - We pass a *superset* of kwargs and filter by the model's __init__ signature to - remain compatible with small constructor differences across implementations. 
- """ + # Instantiate the selected network LightningModule with signature-filtered kwargs import inspect if args.network == "correlation": @@ -309,11 +211,10 @@ def build_model(args): encoder_kwargs = {"width": args.width, "layers": args.layers, "link_fn": "identity"} - # Common superset kw = dict( context_dim=args.context_dim, x_dim=args.x_dim, - y_dim=args.x_dim, # networks wrapper convention + y_dim=args.x_dim, univariate=True, num_archetypes=args.num_archetypes, encoder_type=args.encoder_type, @@ -325,7 +226,6 @@ def build_model(args): model_regularizer="none", ) - # NOTMAD-specific defaults (safe baseline; tune as needed) if args.network == "bayesian": kw.update( archetype_loss_params=dict( @@ -354,7 +254,6 @@ def build_model(args): return model_cls(**kw) filtered = {k: v for k, v in kw.items() if k in sig.parameters} - # Basic required-arg check (only for explicit signatures) required = [ name for name, p in sig.parameters.items() @@ -373,10 +272,7 @@ def build_model(args): def build_dm(args, C, X, Y) -> ContextualizedRegressionDataModule: - """ - Uses the same DataModule family as the wrapper (consistent batch structure). - IMPORTANT: If data lives on CUDA, we force num_workers=0. - """ + # Construct the datamodule with a fixed synthetic buffer and deterministic indices n = int(C.shape[0]) n_train = max(1, int(0.98 * n)) train_idx = np.arange(0, n_train, dtype=np.int64) @@ -384,7 +280,6 @@ def build_dm(args, C, X, Y) -> ContextualizedRegressionDataModule: task_type = args.task_type if task_type is None: - # Networks wrappers use the univariate convention. 
task_type = "singletask_univariate" dm = ContextualizedRegressionDataModule( @@ -416,7 +311,6 @@ def build_dm(args, C, X, Y) -> ContextualizedRegressionDataModule: def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: if torch.cuda.is_available(): accelerator = "gpu" - # Under torchrun: each process uses exactly 1 device devices = 1 if under_torchrun() else min(args.devices, torch.cuda.device_count()) strategy = ( DDPStrategy( @@ -441,7 +335,7 @@ def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: strategy=strategy, precision=map_precision(args.precision), max_steps=max_steps, - max_epochs=10_000, # irrelevant when max_steps is set + max_epochs=10_000, logger=False, enable_checkpointing=False, enable_progress_bar=False, @@ -452,12 +346,12 @@ def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: inference_mode=False, detect_anomaly=False, accumulate_grad_batches=1, - limit_val_batches=0, # no validation - use_distributed_sampler=False # IMPORTANT: our synthetic buffer is already identical-sized per rank + limit_val_batches=0, + use_distributed_sampler=False, ) -# ---------------- benchmark runner ---------------- +# Benchmark runner @dataclass class Result: network: str @@ -477,21 +371,18 @@ class Result: def run_bench(args) -> Result: ws = world_size() if under_torchrun() else int(args.devices) - # Resolve data device if args.data_device == "cpu": dev = torch.device("cpu") elif args.data_device == "cuda": dev = torch.device("cuda", local_rank()) if torch.cuda.is_available() else torch.device("cpu") - else: # auto + else: dev = torch.device("cuda", local_rank()) if torch.cuda.is_available() else torch.device("cpu") - # Dataloader workers cannot safely handle CUDA tensors if dev.type == "cuda" and args.num_workers != 0: if is_global_zero(): print("NOTE: forcing --num-workers=0 because data-device is CUDA.") args.num_workers = 0 - # Build fixed synthetic buffer (not timed) n = int(args.batch_size * args.buffer_batches) 
tensors = make_synthetic_tensors( n=n, @@ -532,7 +423,7 @@ def run_bench(args) -> Result: trainer.fit(model, train_dataloaders=dm.train_dataloader()) measured_wall = timer.measured_wall_time() - measured_wall = dist_max(measured_wall) # slowest rank dictates + measured_wall = dist_max(measured_wall) measured_steps = int(args.steps) global_batch = int(args.batch_size * ws) @@ -568,7 +459,7 @@ def save_result(outdir: str, res: Result) -> str: return path -# ---------------- main ---------------- +# Entrypoint def parse_args(): ap = argparse.ArgumentParser() ap.add_argument("--network", type=str, choices=["correlation", "markov", "bayesian"], default="correlation") @@ -609,7 +500,7 @@ def main(): args = parse_args() if args.data_device == "cpu": - os.environ["CUDA_VISIBLE_DEVICES"] = "" # prevent accidental CUDA use + os.environ["CUDA_VISIBLE_DEVICES"] = "" res = run_bench(args) From 19d4638e87ac0d5d1c419c0bf6d0a8979fc7d791 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Wed, 14 Jan 2026 12:53:24 -0500 Subject: [PATCH 16/19] Update imports on regression lightning module and datamodules --- contextualized/regression/datamodules.py | 2 +- contextualized/regression/lightning_modules.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py index 8c6b601b..1e53abf0 100644 --- a/contextualized/regression/datamodules.py +++ b/contextualized/regression/datamodules.py @@ -6,7 +6,7 @@ import pandas as pd import torch from torch.utils.data import DataLoader -import pytorch_lightning as pl +import lightning.pytorch as pl from .datasets import ( MultivariateDataset, diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index a635369d..4a696a07 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -17,7 +17,7 @@ import numpy as np import torch from torch.utils.data 
import DataLoader -import pytorch_lightning as pl +import lightning.pytorch as pl from contextualized.regression.regularizers import REGULARIZERS from contextualized.regression.losses import MSE From 78b3ce39cdd75affff9a918b9ac9261d6b09a2b7 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Wed, 14 Jan 2026 16:09:14 -0500 Subject: [PATCH 17/19] Remove unused files --- bash_scripts/network_heavy.sh | 222 --------- bash_scripts/network_optimized.sh | 168 ------- network_scaling_heavy.py | 637 ------------------------- networks_pert_scale_bench.py | 750 ------------------------------ scale_bench.py | 508 -------------------- scale_bench_networks.py | 514 -------------------- 6 files changed, 2799 deletions(-) delete mode 100644 bash_scripts/network_heavy.sh delete mode 100644 bash_scripts/network_optimized.sh delete mode 100644 network_scaling_heavy.py delete mode 100644 networks_pert_scale_bench.py delete mode 100644 scale_bench.py delete mode 100644 scale_bench_networks.py diff --git a/bash_scripts/network_heavy.sh b/bash_scripts/network_heavy.sh deleted file mode 100644 index 2eec78fd..00000000 --- a/bash_scripts/network_heavy.sh +++ /dev/null @@ -1,222 +0,0 @@ -#!/bin/bash -# ============================================================================= -# HEAVY ContextualizedCorrelationNetworks DDP SCALING BENCHMARK -# ============================================================================= -# -# This benchmark tests multi-GPU scaling with the ACTUAL CCN model, but -# configured for maximum compute to properly stress-test GPU parallelism. 
-# -# HEAVY Configuration vs Original: -# Parameter | Original | Heavy | Compute Impact -# -----------------|----------|---------|---------------- -# Archetypes | 16-30 | 64 | 2-4x more mixture components -# Encoder width | 25 | 256 | 10x wider networks -# Encoder layers | 3 | 6 | 2x deeper networks -# Bootstraps | 1 | 3 | 3x more models (ensemble) -# Data PCs | 50 | 100 | 2x larger output space -# -# Estimated parameters: ~15-30M (vs ~300K original) -# -# Expected scaling: -# 1 GPU: baseline -# 2 GPU: ~1.85x speedup (92% efficiency) -# 3 GPU: ~2.65x speedup (88% efficiency) -# 4 GPU: ~3.4x speedup (85% efficiency) -# -# ============================================================================= - -set -e - -# ===== CONFIGURATION ===== -SCRIPT="ccn_scaling_heavy.py" -OUTDIR="bench_results_ccn_heavy" -EPOCHS=20 -WARMUP=1 -BATCH_SIZE=512 # Per GPU - -# HEAVY CCN Architecture -ARCHETYPES=64 # Original: 16-30 -ENCODER_WIDTH=256 # Original: 25 -ENCODER_LAYERS=6 # Original: 3 -BOOTSTRAPS=1 # Original: 1 - -# Data dimensionality -DATA_PCS=100 # Original: 50 -CONTEXT_PCS=100 - -# Runtime -NUM_WORKERS=4 -SUBSAMPLE=1.0 - -# Clean previous results -echo "==============================================" -echo "Cleaning previous results..." 
-echo "==============================================" -rm -f "${OUTDIR}/ccn_heavy_scaling_results.csv" -mkdir -p "${OUTDIR}" - -echo "" -echo "==============================================" -echo "HEAVY CCN SCALING BENCHMARK" -echo "==============================================" -echo "Script: ${SCRIPT}" -echo "Epochs: ${EPOCHS} (+ ${WARMUP} warmup)" -echo "Batch size per GPU: ${BATCH_SIZE}" -echo "" -echo "--- HEAVY CCN Config ---" -echo "Archetypes: ${ARCHETYPES}" -echo "Encoder: ${ENCODER_WIDTH}w × ${ENCODER_LAYERS}L" -echo "Bootstraps: ${BOOTSTRAPS}" -echo "Data PCs: ${DATA_PCS}" -echo "" -echo "Output: ${OUTDIR}" -echo "" - -# ----------------------------------------------------------------------------- -# TEST 1: 1-GPU Baseline -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "[1/4] Running 1-GPU baseline..." -echo "==============================================" - -python ${SCRIPT} \ - --epochs ${EPOCHS} \ - --warmup-epochs ${WARMUP} \ - --batch-size ${BATCH_SIZE} \ - --archetypes ${ARCHETYPES} \ - --encoder-width ${ENCODER_WIDTH} \ - --encoder-layers ${ENCODER_LAYERS} \ - --bootstraps ${BOOTSTRAPS} \ - --data-pcs ${DATA_PCS} \ - --context-pcs ${CONTEXT_PCS} \ - --num-workers ${NUM_WORKERS} \ - --subsample-fraction ${SUBSAMPLE} \ - --devices 1 \ - --outdir ${OUTDIR} \ - --label "1gpu_baseline" - -# Extract baseline time for efficiency calculation -BASELINE_TIME=$(tail -1 "${OUTDIR}/ccn_heavy_scaling_results.csv" | cut -d',' -f2) -echo "" -echo ">>> Baseline time: ${BASELINE_TIME}s" -echo "" - -# ----------------------------------------------------------------------------- -# TEST 2: 2-GPU DDP -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "[2/4] Running 2-GPU DDP..." 
-echo "==============================================" - -torchrun \ - --standalone \ - --nproc_per_node=2 \ - ${SCRIPT} \ - --epochs ${EPOCHS} \ - --warmup-epochs ${WARMUP} \ - --batch-size ${BATCH_SIZE} \ - --archetypes ${ARCHETYPES} \ - --encoder-width ${ENCODER_WIDTH} \ - --encoder-layers ${ENCODER_LAYERS} \ - --bootstraps ${BOOTSTRAPS} \ - --data-pcs ${DATA_PCS} \ - --context-pcs ${CONTEXT_PCS} \ - --num-workers ${NUM_WORKERS} \ - --subsample-fraction ${SUBSAMPLE} \ - --devices 2 \ - --outdir ${OUTDIR} \ - --label "2gpu_ddp" \ - --baseline-time ${BASELINE_TIME} - -echo "" - -# ----------------------------------------------------------------------------- -# TEST 3: 3-GPU DDP -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "[3/4] Running 3-GPU DDP..." -echo "==============================================" - -torchrun \ - --standalone \ - --nproc_per_node=3 \ - ${SCRIPT} \ - --epochs ${EPOCHS} \ - --warmup-epochs ${WARMUP} \ - --batch-size ${BATCH_SIZE} \ - --archetypes ${ARCHETYPES} \ - --encoder-width ${ENCODER_WIDTH} \ - --encoder-layers ${ENCODER_LAYERS} \ - --bootstraps ${BOOTSTRAPS} \ - --data-pcs ${DATA_PCS} \ - --context-pcs ${CONTEXT_PCS} \ - --num-workers ${NUM_WORKERS} \ - --subsample-fraction ${SUBSAMPLE} \ - --devices 3 \ - --outdir ${OUTDIR} \ - --label "3gpu_ddp" \ - --baseline-time ${BASELINE_TIME} - -echo "" - -# ----------------------------------------------------------------------------- -# TEST 4: 4-GPU DDP -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "[4/4] Running 4-GPU DDP..." 
-echo "==============================================" - -torchrun \ - --standalone \ - --nproc_per_node=4 \ - ${SCRIPT} \ - --epochs ${EPOCHS} \ - --warmup-epochs ${WARMUP} \ - --batch-size ${BATCH_SIZE} \ - --archetypes ${ARCHETYPES} \ - --encoder-width ${ENCODER_WIDTH} \ - --encoder-layers ${ENCODER_LAYERS} \ - --bootstraps ${BOOTSTRAPS} \ - --data-pcs ${DATA_PCS} \ - --context-pcs ${CONTEXT_PCS} \ - --num-workers ${NUM_WORKERS} \ - --subsample-fraction ${SUBSAMPLE} \ - --devices 4 \ - --outdir ${OUTDIR} \ - --label "4gpu_ddp" \ - --baseline-time ${BASELINE_TIME} - -echo "" - -# ----------------------------------------------------------------------------- -# SUMMARY -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "BENCHMARK COMPLETE" -echo "==============================================" -echo "" -echo "Full Results:" -echo "" -column -t -s',' "${OUTDIR}/ccn_heavy_scaling_results.csv" -echo "" - -echo "==============================================" -echo "SCALING SUMMARY" -echo "==============================================" -awk -F',' ' -NR==1 {next} -{ - printf " %-15s: %8.2fs | %5.2fx speedup | %5.1f%% efficiency\n", $1, $2, $13, $14 -} -' "${OUTDIR}/ccn_heavy_scaling_results.csv" -echo "" - -echo "==============================================" -echo "CCN CONFIGURATION USED" -echo "==============================================" -echo " Archetypes: ${ARCHETYPES}" -echo " Encoder width: ${ENCODER_WIDTH}" -echo " Encoder layers: ${ENCODER_LAYERS}" -echo " Bootstraps: ${BOOTSTRAPS}" -echo " Data PCs: ${DATA_PCS}" -echo "" \ No newline at end of file diff --git a/bash_scripts/network_optimized.sh b/bash_scripts/network_optimized.sh deleted file mode 100644 index 71266cc4..00000000 --- a/bash_scripts/network_optimized.sh +++ /dev/null @@ -1,168 +0,0 @@ -#!/bin/bash -# ============================================================================= -# OPTIMIZED DDP SCALING 
BENCHMARK SCRIPT -# ============================================================================= -# -# This script runs a proper scaling comparison with CONSTANT GLOBAL BATCH SIZE -# to measure true parallel efficiency. -# -# Key differences from original: -# 1. Global batch size stays at 256 regardless of GPU count -# 2. Each GPU processes 256/N samples per batch -# 3. Warmup epoch excluded from timing -# 4. Reduced DataLoader workers to avoid contention -# 5. NCCL optimizations enabled -# -# Expected scaling (realistic for small models): -# 1 GPU: baseline -# 2 GPU: 1.6-1.8x speedup (80-90% efficiency) -# 3 GPU: 2.2-2.6x speedup (73-87% efficiency) -# 4 GPU: 2.8-3.4x speedup (70-85% efficiency) -# -# ============================================================================= - -set -e # Exit on error - -# Configuration -SCRIPT="unseen_pert_scale_optimized.py" -OUTDIR="bench_results_optimized" -EPOCHS=40 -WARMUP=1 -BATCH_SIZE=256 # Per-GPU batch size (global = this × num_gpus) -NUM_WORKERS=4 # Will be auto-reduced for multi-GPU -SUBSAMPLE=1.0 # Use full data (matches existing cache filename) - -# IMPORTANT: For small models, we MUST scale batch size with GPUs. -# Otherwise communication overhead dominates and multi-GPU is SLOWER. -# Using --scale-batch flag to scale global batch with GPU count. 
- -# Clean previous results -rm -f "${OUTDIR}/scaling_results_optimized.csv" -mkdir -p "${OUTDIR}" - -echo "==============================================" -echo "STARTING SCALING BENCHMARK" -echo "==============================================" -echo "Script: ${SCRIPT}" -echo "Epochs: ${EPOCHS} (+ ${WARMUP} warmup)" -echo "Global Batch Size: ${BATCH_SIZE} (constant)" -echo "Output: ${OUTDIR}" -echo "" - -# ----------------------------------------------------------------------------- -# TEST 1: 1-GPU Baseline -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "[1/4] Running 1-GPU baseline..." -echo "==============================================" - -python ${SCRIPT} \ - --epochs ${EPOCHS} \ - --warmup-epochs ${WARMUP} \ - --subsample-fraction ${SUBSAMPLE} \ - --devices 1 \ - --batch-size ${BATCH_SIZE} \ - --num-workers ${NUM_WORKERS} \ - --outdir ${OUTDIR} \ - --label "1gpu_baseline" \ - --verbose - -# Extract baseline time for efficiency calculation -BASELINE_TIME=$(tail -1 "${OUTDIR}/scaling_results_optimized.csv" | cut -d',' -f2) -echo "" -echo "Baseline time: ${BASELINE_TIME}s" -echo "" - -# ----------------------------------------------------------------------------- -# TEST 2: 2-GPU with torchrun -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "[2/4] Running 2-GPU DDP with torchrun..." 
-echo "==============================================" - -torchrun \ - --standalone \ - --nproc_per_node=2 \ - ${SCRIPT} \ - --epochs ${EPOCHS} \ - --warmup-epochs ${WARMUP} \ - --subsample-fraction ${SUBSAMPLE} \ - --devices 2 \ - --batch-size ${BATCH_SIZE} \ - --num-workers ${NUM_WORKERS} \ - --outdir ${OUTDIR} \ - --label "2gpu_ddp" \ - --baseline-time ${BASELINE_TIME} \ - --scale-batch \ - --verbose - -echo "" - -# ----------------------------------------------------------------------------- -# TEST 3: 3-GPU with torchrun -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "[3/4] Running 3-GPU DDP with torchrun..." -echo "==============================================" - -torchrun \ - --standalone \ - --nproc_per_node=3 \ - ${SCRIPT} \ - --epochs ${EPOCHS} \ - --warmup-epochs ${WARMUP} \ - --subsample-fraction ${SUBSAMPLE} \ - --devices 3 \ - --batch-size ${BATCH_SIZE} \ - --num-workers ${NUM_WORKERS} \ - --outdir ${OUTDIR} \ - --label "3gpu_ddp" \ - --baseline-time ${BASELINE_TIME} \ - --scale-batch \ - --verbose - -echo "" - -# ----------------------------------------------------------------------------- -# TEST 4: 4-GPU with torchrun -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "[4/4] Running 4-GPU DDP with torchrun..." 
-echo "==============================================" - -torchrun \ - --standalone \ - --nproc_per_node=4 \ - ${SCRIPT} \ - --epochs ${EPOCHS} \ - --warmup-epochs ${WARMUP} \ - --subsample-fraction ${SUBSAMPLE} \ - --devices 4 \ - --batch-size ${BATCH_SIZE} \ - --num-workers ${NUM_WORKERS} \ - --outdir ${OUTDIR} \ - --label "4gpu_ddp" \ - --baseline-time ${BASELINE_TIME} \ - --scale-batch \ - --verbose - -echo "" - -# ----------------------------------------------------------------------------- -# SUMMARY -# ----------------------------------------------------------------------------- -echo "==============================================" -echo "BENCHMARK COMPLETE" -echo "==============================================" -echo "" -echo "Results saved to: ${OUTDIR}/scaling_results_optimized.csv" -echo "" -echo "Results:" -cat "${OUTDIR}/scaling_results_optimized.csv" | column -t -s',' -echo "" - -# Calculate speedups -echo "Speedup Summary:" -echo "----------------" -awk -F',' 'NR==1 {next} NR==2 {base=$2} {printf "%s: %.2fs (%.2fx speedup, %.1f%% efficiency)\n", $1, $2, base/$2, $8}' \ - "${OUTDIR}/scaling_results_optimized.csv" \ No newline at end of file diff --git a/network_scaling_heavy.py b/network_scaling_heavy.py deleted file mode 100644 index 77314ab2..00000000 --- a/network_scaling_heavy.py +++ /dev/null @@ -1,637 +0,0 @@ -#!/usr/bin/env python3 -# Heavy DDP scaling benchmark for ContextualizedCorrelationNetworks with cached preprocessing and compute-intensive settings (no CSV output). 
- -import os -import time -import warnings -import pickle -from dataclasses import dataclass -from typing import Tuple, Optional - -import numpy as np -import pandas as pd -from sklearn.decomposition import PCA -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler - -import torch -import torch.distributed as dist - -from rdkit import Chem -from rdkit.Chem import rdFingerprintGenerator - -from contextualized.easy import ContextualizedCorrelationNetworks - - -# Configuration -BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -DATA_DIR = os.path.join(os.path.dirname(BASE_DIR), "data") - -PATH_L1000 = os.path.join(DATA_DIR, "trt_cp_smiles_qc.csv") -PATH_CTLS = os.path.join(DATA_DIR, "ctrls.csv") - -N_DATA_PCS = 100 -N_CONTEXT_PCS = 100 - -PERTURBATION_HOLDOUT_SIZE = 0.2 -RANDOM_STATE = 42 - -morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=4096) - - -# Distributed helpers -def is_global_zero() -> bool: - if dist.is_available() and dist.is_initialized(): - try: - return dist.get_rank() == 0 - except Exception: - return True - return int(os.environ.get("GLOBAL_RANK", os.environ.get("RANK", "0"))) == 0 - - -def get_rank() -> int: - if dist.is_available() and dist.is_initialized(): - return dist.get_rank() - return int(os.environ.get("RANK", os.environ.get("LOCAL_RANK", "0"))) - - -def get_world_size() -> int: - if dist.is_available() and dist.is_initialized(): - return dist.get_world_size() - return int(os.environ.get("WORLD_SIZE", "1")) - - -def get_local_rank() -> int: - return int(os.environ.get("LOCAL_RANK", "0")) - - -def barrier(): - if dist.is_available() and dist.is_initialized(): - dist.barrier() - - -def print_rank0(msg: str): - if is_global_zero(): - print(msg, flush=True) - - -# Environment setup -def set_env_defaults(): - world_size = int(os.environ.get("WORLD_SIZE", "1")) - cpu_count = os.cpu_count() or 8 - threads = max(1, cpu_count // max(world_size, 1)) - - 
os.environ.setdefault("OMP_NUM_THREADS", str(min(threads, 4))) - os.environ.setdefault("MKL_NUM_THREADS", str(min(threads, 4))) - os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - - os.environ.setdefault("NCCL_DEBUG", "WARN") - os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "1") - os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") - os.environ.setdefault("NCCL_ALGO", "Ring") - os.environ.setdefault("NCCL_NSOCKS_PERTHREAD", "4") - os.environ.setdefault("NCCL_SOCKET_NTHREADS", "2") - - try: - torch.set_float32_matmul_precision("high") - except Exception: - pass - - np.random.seed(RANDOM_STATE) - torch.manual_seed(RANDOM_STATE) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(RANDOM_STATE) - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True - torch.backends.cudnn.benchmark = True - - -# Fingerprint helper -def smiles_to_morgan_fp(smiles: str) -> np.ndarray: - try: - mol = Chem.MolFromSmiles(smiles) - if mol is None: - return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) - fp = morgan_gen.GetFingerprint(mol) - return np.array(fp, dtype=np.float32) - except Exception: - return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) - - -# Data loading with optional cache -def get_cache_path(subsample_fraction: Optional[float], n_data_pcs: int) -> str: - suffix = f"_sub{subsample_fraction}" if subsample_fraction else "" - suffix += f"_pcs{n_data_pcs}" - return os.path.join(DATA_DIR, f"ccn_heavy_cache{suffix}.pkl") - - -def load_and_preprocess( - subsample_fraction: Optional[float] = None, - use_cache: bool = True, - n_data_pcs: int = N_DATA_PCS, - n_context_pcs: int = N_CONTEXT_PCS, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - cache_path = get_cache_path(subsample_fraction, n_data_pcs) - - if use_cache and os.path.exists(cache_path): - print_rank0(f"[DATA] Loading from cache: {cache_path}") - with open(cache_path, "rb") as f: - cached = 
pickle.load(f) - return ( - cached["C_train"], - cached["X_train_norm"], - cached["C_test"], - cached["X_test_norm"], - cached["cell_ids_train"], - cached["cell_ids_test"], - ) - - if not is_global_zero() and use_cache: - wait_count = 0 - while not os.path.exists(cache_path) and wait_count < 600: - time.sleep(1) - wait_count += 1 - if os.path.exists(cache_path): - with open(cache_path, "rb") as f: - cached = pickle.load(f) - return ( - cached["C_train"], - cached["X_train_norm"], - cached["C_test"], - cached["X_test_norm"], - cached["cell_ids_train"], - cached["cell_ids_test"], - ) - - print_rank0(f"[DATA] Loading L1000 from {PATH_L1000}") - df = pd.read_csv(PATH_L1000, engine="pyarrow") - - df = df[df["pert_type"].isin(["trt_cp"])] - - bad = ( - (df["distil_cc_q75"] < 0.2) - | (df["distil_cc_q75"] == -666) - | (df["distil_cc_q75"].isna()) - | (df["pct_self_rank_q25"] > 5) - | (df["pct_self_rank_q25"] == -666) - | (df["pct_self_rank_q25"].isna()) - ) - df = df[~bad] - df = df.dropna(subset=["canonical_smiles"]) - df = df[df["canonical_smiles"] != ""] - - print_rank0(f"[DATA] Samples after QC: {len(df)}") - - if subsample_fraction is not None: - df = df.sample(frac=subsample_fraction, random_state=RANDOM_STATE) - print_rank0(f"[DATA] Subsampled to {len(df)} ({subsample_fraction * 100:.1f}%)") - - unique_smiles = df["canonical_smiles"].unique() - print_rank0(f"[DATA] Unique perturbations: {len(unique_smiles)}") - - smiles_train, smiles_test = train_test_split( - unique_smiles, test_size=PERTURBATION_HOLDOUT_SIZE, random_state=RANDOM_STATE - ) - - df_train = df[df["canonical_smiles"].isin(smiles_train)].copy() - df_test = df[df["canonical_smiles"].isin(smiles_test)].copy() - - print_rank0(f"[DATA] Train: {len(df_train)}, Test: {len(df_test)}") - - pert_time_mean = df_train.loc[df_train["pert_time"] != -666, "pert_time"].mean() - pert_dose_mean = df_train.loc[df_train["pert_dose"] != -666, "pert_dose"].mean() - - for df_split in (df_train, df_test): - 
df_split["ignore_flag_pert_time"] = (df_split["pert_time"] == -666).astype(int) - df_split["ignore_flag_pert_dose"] = (df_split["pert_dose"] == -666).astype(int) - df_split["pert_time"] = df_split["pert_time"].replace(-666, pert_time_mean) - df_split["pert_dose"] = df_split["pert_dose"].replace(-666, pert_dose_mean) - - def process_split(df_split, name): - numeric_cols = df_split.select_dtypes(include=[np.number]).columns - drop_cols = ["pert_dose", "pert_dose_unit", "pert_time", "distil_cc_q75", "pct_self_rank_q25"] - feature_cols = [c for c in numeric_cols if c not in drop_cols] - X_raw = df_split[feature_cols].values.astype(np.float32) - - print_rank0(f"[DATA] [{name}] Generating fingerprints...") - fps = np.stack([smiles_to_morgan_fp(s) for s in df_split["canonical_smiles"]]) - print_rank0(f"[DATA] [{name}] Fingerprint shape: {fps.shape}") - - pert_time = df_split["pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) - pert_dose = df_split["pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) - ign_t = df_split["ignore_flag_pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) - ign_d = df_split["ignore_flag_pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) - - return X_raw, fps, pert_time, pert_dose, ign_t, ign_d, df_split["cell_id"].to_numpy() - - X_train_raw, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train, cells_train = process_split(df_train, "train") - X_test_raw, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test, cells_test = process_split(df_test, "test") - - print_rank0("[DATA] Scaling gene expression...") - scaler_genes = StandardScaler() - X_train_scaled = scaler_genes.fit_transform(X_train_raw) - X_test_scaled = scaler_genes.transform(X_test_raw) - - print_rank0(f"[DATA] Loading controls from {PATH_CTLS}") - ctrls_df = pd.read_csv(PATH_CTLS, index_col=0) - - unique_cells = np.union1d(np.unique(cells_train), np.unique(cells_test)) - ctrls_df = ctrls_df.loc[ctrls_df.index.intersection(unique_cells)] - - scaler_ctrls = 
StandardScaler() - ctrls_scaled = scaler_ctrls.fit_transform(ctrls_df.values) - - actual_n_ctrl_pcs = min(n_context_pcs, ctrls_scaled.shape[0], ctrls_scaled.shape[1]) - print_rank0(f"[DATA] Using {actual_n_ctrl_pcs} control PCs") - - pca_ctrls = PCA(n_components=actual_n_ctrl_pcs, random_state=RANDOM_STATE) - ctrls_pcs = pca_ctrls.fit_transform(ctrls_scaled) - cell2vec = dict(zip(ctrls_df.index, ctrls_pcs)) - - if not cell2vec: - raise ValueError("No overlapping cell IDs") - - print_rank0(f"[DATA] Control embeddings for {len(cell2vec)} cells") - - def build_context(df_split, X_scaled, morgan, pt, pd, ign_t, ign_d, scaler=None, fit=False): - cell_ids = df_split["cell_id"].to_numpy() - unique_cells_split = np.sort(df_split["cell_id"].unique()) - - all_cont = [] - valid_cells = [] - - for cell_id in unique_cells_split: - if cell_id not in cell2vec: - continue - mask = cell_ids == cell_id - if mask.sum() == 0: - continue - valid_cells.append(cell_id) - cont = np.hstack( - [ - np.tile(cell2vec[cell_id], (mask.sum(), 1)), - pt[mask], - pd[mask], - ] - ).astype(np.float32) - all_cont.append(cont) - - if fit: - all_cont_stacked = np.vstack(all_cont) - scaler = StandardScaler() - scaler.fit(all_cont_stacked) - - X_list, C_list, cid_list = [], [], [] - - for i, cell_id in enumerate(valid_cells): - mask = cell_ids == cell_id - X_cell = X_scaled[mask] - cont_scaled = scaler.transform(all_cont[i]) - C_cell = np.hstack( - [ - cont_scaled, - morgan[mask], - ign_t[mask], - ign_d[mask], - ] - ).astype(np.float32) - - X_list.append(X_cell) - C_list.append(C_cell) - cid_list.append(cell_ids[mask]) - - X_final = np.vstack(X_list) - C_final = np.vstack(C_list) - cell_ids_final = np.concatenate(cid_list) - - return X_final, C_final, cell_ids_final, scaler - - print_rank0("[DATA] Building context matrices...") - X_train, C_train, cell_ids_train, ctx_scaler = build_context( - df_train, X_train_scaled, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train, fit=True - ) - X_test, 
C_test, cell_ids_test, _ = build_context( - df_test, X_test_scaled, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test, scaler=ctx_scaler - ) - - print_rank0(f"[DATA] Context shapes: C_train={C_train.shape}, C_test={C_test.shape}") - - actual_n_data_pcs = min(n_data_pcs, X_train.shape[1], X_train.shape[0]) - print_rank0(f"[DATA] Using {actual_n_data_pcs} data PCs") - - pca_data = PCA(n_components=actual_n_data_pcs, random_state=RANDOM_STATE) - X_train_pca = pca_data.fit_transform(X_train) - X_test_pca = pca_data.transform(X_test) - - pca_scaler = StandardScaler() - X_train_norm = pca_scaler.fit_transform(X_train_pca).astype(np.float32) - X_test_norm = pca_scaler.transform(X_test_pca).astype(np.float32) - - print_rank0(f"[DATA] Final: X_train={X_train_norm.shape}, X_test={X_test_norm.shape}") - print_rank0(f"[DATA] Final: C_train={C_train.shape}, C_test={C_test.shape}") - - if use_cache and is_global_zero(): - cache_data = { - "C_train": C_train, - "X_train_norm": X_train_norm, - "C_test": C_test, - "X_test_norm": X_test_norm, - "cell_ids_train": cell_ids_train, - "cell_ids_test": cell_ids_test, - } - os.makedirs(os.path.dirname(cache_path), exist_ok=True) - with open(cache_path, "wb") as f: - pickle.dump(cache_data, f) - print_rank0(f"[DATA] Saved cache: {cache_path}") - - return C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test - - -# Benchmark result -@dataclass -class BenchResult: - label: str - wall_seconds: float - train_mse_mean: float - test_mse_mean: float - num_gpus: int - batch_size_per_gpu: int - effective_batch_size: int - samples_per_second: float - num_archetypes: int - encoder_width: int - encoder_layers: int - n_bootstraps: int - speedup: float = 1.0 - efficiency: float = 100.0 - - -# Benchmark runner -def run_ccn_benchmark( - label: str, - C_train: np.ndarray, - X_train_norm: np.ndarray, - C_test: np.ndarray, - X_test_norm: np.ndarray, - epochs: int, - devices: int, - batch_size_per_gpu: int = 512, - num_workers: int = 4, 
- num_archetypes: int = 64, - encoder_width: int = 256, - encoder_layers: int = 6, - n_bootstraps: int = 3, - warmup_epochs: int = 1, - baseline_time: Optional[float] = None, -) -> BenchResult: - world_size = int(os.environ.get("WORLD_SIZE", "1")) - rank = get_rank() - local_rank = get_local_rank() - launched_with_torchrun = world_size > 1 - - if torch.cuda.is_available() and devices > 0: - accelerator = "gpu" - if launched_with_torchrun: - devices = world_size - else: - accelerator = "cpu" - devices = 1 - num_workers = 0 - - if launched_with_torchrun and num_workers > 2: - num_workers = 2 - - effective_batch = batch_size_per_gpu * max(world_size, 1) - - print_rank0(f"\n{'=' * 70}") - print_rank0(f"[{label}] HEAVY CCN BENCHMARK") - print_rank0(f"{'=' * 70}") - print_rank0(f" World size: {world_size}") - print_rank0(f" Accelerator: {accelerator}") - print_rank0(f" Devices: {devices}") - print_rank0(f" Batch size per GPU: {batch_size_per_gpu}") - print_rank0(f" Effective batch size: {effective_batch}") - print_rank0(f" Epochs: {epochs} (+ {warmup_epochs} warmup)") - print_rank0(f" Num workers: {num_workers}") - print_rank0(f" --- CCN Config (HEAVY) ---") - print_rank0(f" Archetypes: {num_archetypes}") - print_rank0(f" Encoder width: {encoder_width}") - print_rank0(f" Encoder layers: {encoder_layers}") - print_rank0(f" Bootstraps: {n_bootstraps}") - print_rank0(f" Data dims: C={C_train.shape[1]}, X={X_train_norm.shape[1]}") - - print( - f"[{label}] [RANK {rank} / LOCAL {local_rank}] " - f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}", - flush=True, - ) - - strategy_kwarg = "auto" - if accelerator == "gpu" and launched_with_torchrun and world_size > 1: - try: - from pytorch_lightning.strategies import DDPStrategy - - strategy_kwarg = DDPStrategy( - process_group_backend="nccl", - find_unused_parameters=False, - gradient_as_bucket_view=True, - ) - print_rank0(f"[{label}] Using DDPStrategy with NCCL + gradient_as_bucket_view") - except Exception as e: - 
strategy_kwarg = "ddp" - print_rank0(f"[{label}] Falling back to strategy='ddp': {e}") - - trainer_kwargs = { - "max_epochs": epochs + warmup_epochs, - "accelerator": accelerator, - "devices": devices, - "enable_progress_bar": False, - "logger": False, - "enable_checkpointing": False, - "num_sanity_val_steps": 0, - "precision": "16-mixed" if accelerator == "gpu" else 32, - "strategy": strategy_kwarg, - } - - print_rank0(f"[{label}] Trainer kwargs: {trainer_kwargs}") - print_rank0(f"[{label}] Constructing ContextualizedCorrelationNetworks...") - - ccn = ContextualizedCorrelationNetworks( - encoder_type="mlp", - num_archetypes=num_archetypes, - n_bootstraps=n_bootstraps, - encoder_kwargs={ - "width": encoder_width, - "layers": encoder_layers, - }, - trainer_kwargs=trainer_kwargs, - es_patience=0, - ) - - context_dim = C_train.shape[1] - x_dim = X_train_norm.shape[1] - encoder_params = ( - context_dim * encoder_width - + encoder_width * encoder_width * (encoder_layers - 1) - + encoder_width * num_archetypes - ) - archetype_params = num_archetypes * x_dim * x_dim - total_params = n_bootstraps * (encoder_params + archetype_params) - print_rank0(f"[{label}] Estimated parameters: ~{total_params:,} ({total_params / 1e6:.2f}M)") - - barrier() - if torch.cuda.is_available(): - torch.cuda.synchronize() - - print_rank0(f"[{label}] Starting training...") - t0 = time.time() - - ccn.fit( - C_train, - X_train_norm, - train_batch_size=batch_size_per_gpu, - val_batch_size=batch_size_per_gpu, - test_batch_size=batch_size_per_gpu, - num_workers=num_workers, - persistent_workers=(num_workers > 0), - pin_memory=(accelerator == "gpu"), - ) - - barrier() - if torch.cuda.is_available(): - torch.cuda.synchronize() - - wall = time.time() - t0 - - if warmup_epochs > 0 and epochs > 0: - wall_per_epoch = wall / (epochs + warmup_epochs) - wall = wall_per_epoch * epochs - - print_rank0(f"[{label}] Training completed in {wall:.2f}s") - - n_samples = C_train.shape[0] - samples_per_sec = (n_samples 
* epochs) / max(wall, 1e-6) - - speedup = 1.0 - efficiency = 100.0 - if baseline_time is not None and baseline_time > 0: - speedup = baseline_time / wall - efficiency = (speedup / world_size) * 100 - - train_mse = float("nan") - test_mse = float("nan") - - if is_global_zero(): - try: - print_rank0(f"[{label}] Computing MSE...") - mse_train_vec = ccn.measure_mses(C_train, X_train_norm, individual_preds=False) - mse_test_vec = ccn.measure_mses(C_test, X_test_norm, individual_preds=False) - train_mse = float(np.mean(mse_train_vec)) - test_mse = float(np.mean(mse_test_vec)) - except Exception as e: - warnings.warn(f"[{label}] measure_mses failed: {e}") - - print_rank0(f"\n[{label}] RESULTS:") - print_rank0(f" Wall time: {wall:.2f}s") - print_rank0(f" Samples/sec: {samples_per_sec:.1f}") - print_rank0(f" Train MSE: {train_mse:.6f}") - print_rank0(f" Test MSE: {test_mse:.6f}") - if baseline_time: - print_rank0(f" Speedup: {speedup:.2f}x") - print_rank0(f" Efficiency: {efficiency:.1f}%") - - return BenchResult( - label=label, - wall_seconds=wall, - train_mse_mean=train_mse, - test_mse_mean=test_mse, - num_gpus=world_size, - batch_size_per_gpu=batch_size_per_gpu, - effective_batch_size=effective_batch, - samples_per_second=samples_per_sec, - num_archetypes=num_archetypes, - encoder_width=encoder_width, - encoder_layers=encoder_layers, - n_bootstraps=n_bootstraps, - speedup=speedup, - efficiency=efficiency, - ) - - -# CLI -def parse_args(): - import argparse - - ap = argparse.ArgumentParser(description="Heavy ContextualizedCorrelationNetworks Scaling Benchmark (no CSV output)") - - ap.add_argument("--epochs", type=int, default=20) - ap.add_argument("--warmup-epochs", type=int, default=1) - ap.add_argument("--batch-size", type=int, default=512) - ap.add_argument("--num-workers", type=int, default=4) - - ap.add_argument("--archetypes", type=int, default=64) - ap.add_argument("--encoder-width", type=int, default=256) - ap.add_argument("--encoder-layers", type=int, default=6) - 
ap.add_argument("--bootstraps", type=int, default=3) - - ap.add_argument("--data-pcs", type=int, default=100) - ap.add_argument("--context-pcs", type=int, default=100) - ap.add_argument("--subsample-fraction", type=float, default=None) - - ap.add_argument("--devices", type=int, default=1) - ap.add_argument("--label", type=str, default=None) - ap.add_argument("--baseline-time", type=float, default=None) - ap.add_argument("--no-cache", action="store_true") - - return ap.parse_args() - - -# Main -def main(): - args = parse_args() - set_env_defaults() - - world_size = get_world_size() - label = args.label or f"{world_size}gpu_ccn_heavy" - - print_rank0("\n" + "=" * 70) - print_rank0("HEAVY ContextualizedCorrelationNetworks SCALING BENCHMARK (NO CSV)") - print_rank0("=" * 70) - print_rank0(f" World size: {world_size}") - print_rank0(f" Epochs: {args.epochs}") - print_rank0(f" Batch size: {args.batch_size}") - print_rank0(f" Archetypes: {args.archetypes}") - print_rank0(f" Encoder: {args.encoder_width}w × {args.encoder_layers}L") - print_rank0(f" Bootstraps: {args.bootstraps}") - print_rank0(f" Data PCs: {args.data_pcs}") - - C_train, X_train_norm, C_test, X_test_norm, _, _ = load_and_preprocess( - subsample_fraction=args.subsample_fraction, - use_cache=not args.no_cache, - n_data_pcs=args.data_pcs, - n_context_pcs=args.context_pcs, - ) - - barrier() - - _ = run_ccn_benchmark( - label=label, - C_train=C_train, - X_train_norm=X_train_norm, - C_test=C_test, - X_test_norm=X_test_norm, - epochs=args.epochs, - devices=args.devices, - batch_size_per_gpu=args.batch_size, - num_workers=args.num_workers, - num_archetypes=args.archetypes, - encoder_width=args.encoder_width, - encoder_layers=args.encoder_layers, - n_bootstraps=args.bootstraps, - warmup_epochs=args.warmup_epochs, - baseline_time=args.baseline_time, - ) - - -if __name__ == "__main__": - main() diff --git a/networks_pert_scale_bench.py b/networks_pert_scale_bench.py deleted file mode 100644 index edd15c6b..00000000 
--- a/networks_pert_scale_bench.py +++ /dev/null @@ -1,750 +0,0 @@ -#!/usr/bin/env python3 -# Benchmark script that preprocesses unseen_pert data and compares 1-GPU training vs 2-GPU DDP training for a simple MLP regressor. - -import os -import time -import csv -import warnings -from dataclasses import dataclass -from typing import Tuple, Optional, List - -import numpy as np -import pandas as pd -from sklearn.decomposition import PCA -from sklearn.model_selection import train_test_split -from sklearn.preprocessing import StandardScaler - -import torch -import torch.nn as nn -from torch.utils.data import DataLoader, TensorDataset -from torch.utils.data.distributed import DistributedSampler -import torch.multiprocessing as mp -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP - -from rdkit import Chem -from rdkit.Chem import rdFingerprintGenerator - - -# Paths and basic config -BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -DATA_DIR = os.path.join(os.path.dirname(BASE_DIR), "data") - -PATH_L1000 = os.path.join(DATA_DIR, "trt_cp_smiles_qc.csv") -PATH_CTLS = os.path.join(DATA_DIR, "ctrls.csv") - -N_DATA_PCS = 50 -PERTURBATION_HOLDOUT_SIZE = 0.2 -RANDOM_STATE = 42 - -morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=3, fpSize=4096) - - -# Environment and RNG seeding -def set_env_defaults(): - os.environ.setdefault("OMP_NUM_THREADS", "1") - os.environ.setdefault("MKL_NUM_THREADS", "1") - os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - - try: - torch.set_float32_matmul_precision("high") - except Exception: - pass - - np.random.seed(RANDOM_STATE) - torch.manual_seed(RANDOM_STATE) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(RANDOM_STATE) - - -def set_seeds(rank: int): - np.random.seed(RANDOM_STATE + rank) - torch.manual_seed(RANDOM_STATE + rank) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(RANDOM_STATE + rank) - - -# Fingerprint helper -def 
smiles_to_morgan_fp(smiles: str) -> np.ndarray: - try: - mol = Chem.MolFromSmiles(smiles) - if mol is None: - warnings.warn(f"Invalid SMILES: {smiles}") - return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) - fp = morgan_gen.GetFingerprint(mol) - arr = np.array(fp, dtype=np.float32) - return arr - except Exception as e: - warnings.warn(f"Error processing SMILES {smiles}: {e}") - return np.zeros(morgan_gen.GetOptions().fpSize, dtype=np.float32) - - -# Data preprocessing for unseen_pert -def load_and_preprocess( - subsample_fraction: Optional[float] = None, -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - print(f"Reading L1000 data from {PATH_L1000}") - df = pd.read_csv(PATH_L1000, engine="pyarrow") - - df = df[df["pert_type"].isin(["trt_cp"])] - - bad = ( - (df["distil_cc_q75"] < 0.2) - | (df["distil_cc_q75"] == -666) - | (df["distil_cc_q75"].isna()) - | (df["pct_self_rank_q25"] > 5) - | (df["pct_self_rank_q25"] == -666) - | (df["pct_self_rank_q25"].isna()) - ) - df = df[~bad] - - df = df.dropna(subset=["canonical_smiles"]) - df = df[df["canonical_smiles"] != ""] - - print(f"Remaining samples after QC + SMILES filter: {len(df)}") - - if subsample_fraction is not None: - df = df.sample(frac=subsample_fraction, random_state=RANDOM_STATE) - print(f"Subsampled to {len(df)} samples ({subsample_fraction * 100:.1f}% of data)") - - unique_smiles = df["canonical_smiles"].unique() - print(f"Found {len(unique_smiles)} unique perturbations (SMILES)") - smiles_train, smiles_test = train_test_split( - unique_smiles, - test_size=PERTURBATION_HOLDOUT_SIZE, - random_state=RANDOM_STATE, - ) - - df_train = df[df["canonical_smiles"].isin(smiles_train)].copy() - df_test = df[df["canonical_smiles"].isin(smiles_test)].copy() - - print(f"Perturbation split: {len(smiles_train)} train, {len(smiles_test)} test perturbations") - print(f"Sample split: {len(df_train)} train, {len(df_test)} test samples") - - pert_time_mean = None - 
pert_dose_mean = None - - for df_split, split_name in ((df_train, "train"), (df_test, "test")): - df_split["ignore_flag_pert_time"] = (df_split["pert_time"] == -666).astype(int) - df_split["ignore_flag_pert_dose"] = (df_split["pert_dose"] == -666).astype(int) - - for col in ["pert_time", "pert_dose"]: - if split_name == "train": - mean_val = df_split.loc[df_split[col] != -666, col].mean() - if col == "pert_time": - pert_time_mean = mean_val - else: - pert_dose_mean = mean_val - else: - mean_val = pert_time_mean if col == "pert_time" else pert_dose_mean - - df_split[col] = df_split[col].replace(-666, mean_val) - - def process_data_split(df_split, split_name): - numeric_cols = df_split.select_dtypes(include=[np.number]).columns - drop_cols = [ - "pert_dose", - "pert_dose_unit", - "pert_time", - "distil_cc_q75", - "pct_self_rank_q25", - ] - feature_cols = [c for c in numeric_cols if c not in drop_cols] - X_raw = df_split[feature_cols].values.astype(np.float32) - - print(f"[{split_name}] Generating Morgan fingerprints...") - fps = np.stack([smiles_to_morgan_fp(s) for s in df_split["canonical_smiles"]]) - print(f"[{split_name}] Morgan shape: {fps.shape}") - - pert_time = df_split["pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) - pert_dose = df_split["pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) - ignore_time = df_split["ignore_flag_pert_time"].to_numpy().reshape(-1, 1).astype(np.float32) - ignore_dose = df_split["ignore_flag_pert_dose"].to_numpy().reshape(-1, 1).astype(np.float32) - - return X_raw, fps, pert_time, pert_dose, ignore_time, ignore_dose - - (X_raw_train, morgan_train, pt_train, pd_train, ign_t_train, ign_d_train) = process_data_split( - df_train, "train" - ) - (X_raw_test, morgan_test, pt_test, pd_test, ign_t_test, ign_d_test) = process_data_split( - df_test, "test" - ) - - print("Scaling gene expression...") - scaler_genes = StandardScaler() - X_train_scaled = scaler_genes.fit_transform(X_raw_train) - X_test_scaled = 
scaler_genes.transform(X_raw_test) - - morgan_train_scaled = morgan_train.astype(np.float32) - morgan_test_scaled = morgan_test.astype(np.float32) - - print(f"Reading control profiles from {PATH_CTLS}") - ctrls_df = pd.read_csv(PATH_CTLS, index_col=0) - - unique_cells_train = np.sort(df_train["cell_id"].unique()) - unique_cells_test = np.sort(df_test["cell_id"].unique()) - unique_cells_all = np.sort(np.union1d(unique_cells_train, unique_cells_test)) - - ctrls_df = ctrls_df.loc[ctrls_df.index.intersection(unique_cells_all)] - scaler_ctrls = StandardScaler() - ctrls_scaled = scaler_ctrls.fit_transform(ctrls_df.values) - - n_cells = ctrls_scaled.shape[0] - n_ctrl_pcs = min(50, n_cells) - - pca_ctrls = PCA(n_components=n_ctrl_pcs, random_state=RANDOM_STATE) - ctrls_pcs = pca_ctrls.fit_transform(ctrls_scaled) - - cell2vec = dict(zip(ctrls_df.index, ctrls_pcs)) - if not cell2vec: - raise ValueError("No overlapping cell IDs between L1000 and ctrls.csv") - - print(f"Control embeddings for {len(cell2vec)} cells (PCs={n_ctrl_pcs})") - - def build_context_matrix( - df_split, - X_scaled, - morgan_scaled, - pt, - pd, - ign_t, - ign_d, - split_name, - scaler_context=None, - is_train=False, - ): - cell_ids = df_split["cell_id"].to_numpy() - unique_cells_split = np.sort(df_split["cell_id"].unique()) - - all_continuous_context = [] - valid_cells = [] - - for cell_id in unique_cells_split: - if cell_id not in cell2vec: - print(f"[{split_name}] Warning: cell {cell_id} not in control embeddings; skipping") - continue - mask = cell_ids == cell_id - if mask.sum() == 0: - continue - - valid_cells.append(cell_id) - cont = np.hstack( - [ - np.tile(cell2vec[cell_id], (mask.sum(), 1)), - pt[mask], - pd[mask], - ] - ).astype(np.float32) - all_continuous_context.append(cont) - - if is_train: - all_cont = np.vstack(all_continuous_context) - scaler_context = StandardScaler() - scaler_context.fit(all_cont) - print(f"[{split_name}] Context scaler fit on {all_cont.shape} continuous features") - - 
if scaler_context is None: - raise ValueError("scaler_context must be provided for non-training split") - - X_list, C_list, cid_list = [], [], [] - - for i, cell_id in enumerate(valid_cells): - mask = cell_ids == cell_id - X_cell = X_scaled[mask] - cont_scaled = scaler_context.transform(all_continuous_context[i]) - C_cell = np.hstack( - [ - cont_scaled, - morgan_scaled[mask], - ign_t[mask], - ign_d[mask], - ] - ).astype(np.float32) - - X_list.append(X_cell) - C_list.append(C_cell) - cid_list.append(cell_ids[mask]) - - if not X_list: - raise RuntimeError(f"No data for split {split_name}") - - X_final = np.vstack(X_list) - C_final = np.vstack(C_list) - cell_ids_final = np.concatenate(cid_list) - - return X_final, C_final, cell_ids_final, scaler_context - - print("Building context matrices...") - X_train, C_train, cell_ids_train, scaler_context = build_context_matrix( - df_train, - X_train_scaled, - morgan_train_scaled, - pt_train, - pd_train, - ign_t_train, - ign_d_train, - "train", - is_train=True, - ) - X_test, C_test, cell_ids_test, _ = build_context_matrix( - df_test, - X_test_scaled, - morgan_test_scaled, - pt_test, - pd_test, - ign_t_test, - ign_d_test, - "test", - scaler_context=scaler_context, - is_train=False, - ) - - print(f"C_train: {C_train.shape}, X_train: {X_train.shape}") - print(f"C_test: {C_test.shape}, X_test: {X_test.shape}") - - print("PCA + scaling on gene features...") - pca_data = PCA(n_components=N_DATA_PCS, random_state=RANDOM_STATE) - X_train_pca = pca_data.fit_transform(X_train) - X_test_pca = pca_data.transform(X_test) - - pca_scaler = StandardScaler() - X_train_norm = pca_scaler.fit_transform(X_train_pca) - X_test_norm = pca_scaler.transform(X_test_pca) - - print(f"Final X_train_norm: {X_train_norm.shape}, X_test_norm: {X_test_norm.shape}") - - return C_train, X_train_norm, C_test, X_test_norm, cell_ids_train, cell_ids_test - - -@dataclass -class BenchResult: - label: str - wall_seconds: float - samples_total: int - throughput_sps: float 
- train_mse_mean: float - test_mse_mean: float - - -class SimpleRegressor(nn.Module): - def __init__(self, in_dim: int, out_dim: int): - super().__init__() - hidden = 512 - self.net = nn.Sequential( - nn.Linear(in_dim, hidden), - nn.ReLU(), - nn.Linear(hidden, hidden), - nn.ReLU(), - nn.Linear(hidden, out_dim), - ) - - def forward(self, x): - return self.net(x) - - -def run_single_gpu( - epochs: int, - batch_size: int, - num_workers: int, - subsample_fraction: Optional[float], -) -> BenchResult: - label = "1gpu_single" - print("\n================ 1-GPU baseline (single process) ================") - - C_train, X_train_norm, C_test, X_test_norm, _, _ = load_and_preprocess( - subsample_fraction=subsample_fraction - ) - - if torch.cuda.is_available(): - device = torch.device("cuda:0") - print(f"[{label}] Using CUDA on device {device}") - else: - device = torch.device("cpu") - print(f"[{label}] CUDA not available, using CPU") - - C_train_t = torch.from_numpy(C_train).float() - X_train_t = torch.from_numpy(X_train_norm).float() - C_test_t = torch.from_numpy(C_test).float() - X_test_t = torch.from_numpy(X_test_norm).float() - - train_ds = TensorDataset(C_train_t, X_train_t) - test_ds = TensorDataset(C_test_t, X_test_t) - - train_loader = DataLoader( - train_ds, - batch_size=batch_size, - shuffle=True, - num_workers=num_workers, - pin_memory=torch.cuda.is_available(), - ) - test_loader = DataLoader( - test_ds, - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - pin_memory=torch.cuda.is_available(), - ) - - in_dim = C_train.shape[1] - out_dim = X_train_norm.shape[1] - - model = SimpleRegressor(in_dim, out_dim).to(device) - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - criterion = nn.MSELoss() - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.backends.cudnn.benchmark = True - torch.cuda.synchronize() - - n_samples = C_train.shape[0] - t0 = time.time() - - for epoch in range(epochs): - model.train() - epoch_loss = 0.0 - for 
batch_C, batch_X in train_loader: - batch_C = batch_C.to(device, non_blocking=True) - batch_X = batch_X.to(device, non_blocking=True) - - optimizer.zero_grad() - preds = model(batch_C) - loss = criterion(preds, batch_X) - loss.backward() - optimizer.step() - epoch_loss += loss.item() * batch_C.size(0) - - epoch_loss /= n_samples - print(f"[{label}] Epoch {epoch+1}/{epochs} - train MSE {epoch_loss:.6f}") - - if torch.cuda.is_available(): - torch.cuda.synchronize() - wall = time.time() - t0 - - samples_total = n_samples * epochs - throughput = samples_total / max(wall, 1e-9) - - def eval_mse(loader, split_name: str) -> float: - model.eval() - total_loss = 0.0 - count = 0 - with torch.no_grad(): - for batch_C, batch_X in loader: - batch_C = batch_C.to(device, non_blocking=True) - batch_X = batch_X.to(device, non_blocking=True) - preds = model(batch_C) - loss = criterion(preds, batch_X) - bsz = batch_C.size(0) - total_loss += loss.item() * bsz - count += bsz - mse = total_loss / max(count, 1) - print(f"[{label}] {split_name} MSE {mse:.6f}") - return mse - - train_mse = eval_mse(train_loader, "train") - test_mse = eval_mse(test_loader, "test") - - print(f"\n[{label}] run complete") - print(f" wall time (s): {wall:.2f}") - print(f" total samples: {samples_total}") - print(f" throughput (samples/s): {throughput:.2f}") - print(f" final train MSE: {train_mse:.6f}") - print(f" final test MSE: {test_mse:.6f}") - - return BenchResult( - label=label, - wall_seconds=wall, - samples_total=samples_total, - throughput_sps=throughput, - train_mse_mean=train_mse, - test_mse_mean=test_mse, - ) - - -def ddp_worker( - rank: int, - world_size: int, - port: str, - epochs: int, - batch_size: int, - num_workers: int, - subsample_fraction: Optional[float], - result_dict, -): - set_seeds(rank) - - if torch.cuda.is_available(): - torch.cuda.set_device(rank) - device = torch.device(f"cuda:{rank}") - else: - device = torch.device("cpu") - - init_method = f"tcp://127.0.0.1:{port}" - 
dist.init_process_group( - backend="gloo", - init_method=init_method, - world_size=world_size, - rank=rank, - ) - - label = "2gpu_ddp" - if rank == 0: - print("\n================ 2-GPU DDP baseline ================") - print(f"[{label}] world_size={world_size}, backend=gloo, init_method={init_method}") - if torch.cuda.is_available(): - print(f"[{label}] Using GPUs 0 and 1 with DDP") - - C_train, X_train_norm, C_test, X_test_norm, _, _ = load_and_preprocess( - subsample_fraction=subsample_fraction - ) - - C_train_t = torch.from_numpy(C_train).float() - X_train_t = torch.from_numpy(X_train_norm).float() - C_test_t = torch.from_numpy(C_test).float() - X_test_t = torch.from_numpy(X_test_norm).float() - - train_ds = TensorDataset(C_train_t, X_train_t) - test_ds = TensorDataset(C_test_t, X_test_t) - - train_sampler = DistributedSampler( - train_ds, - num_replicas=world_size, - rank=rank, - shuffle=True, - drop_last=False, - ) - - train_loader = DataLoader( - train_ds, - batch_size=batch_size, - sampler=train_sampler, - num_workers=num_workers, - pin_memory=torch.cuda.is_available(), - ) - - in_dim = C_train.shape[1] - out_dim = X_train_norm.shape[1] - - model = SimpleRegressor(in_dim, out_dim).to(device) - ddp_model = DDP(model, device_ids=[rank] if torch.cuda.is_available() else None) - - optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3) - criterion = nn.MSELoss() - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.backends.cudnn.benchmark = True - torch.cuda.synchronize() - - n_samples = C_train.shape[0] - - dist.barrier() - if torch.cuda.is_available(): - torch.cuda.synchronize() - t0 = time.time() - - for epoch in range(epochs): - ddp_model.train() - train_sampler.set_epoch(epoch) - - running_loss = 0.0 - count_seen = 0 - - for batch_C, batch_X in train_loader: - batch_C = batch_C.to(device, non_blocking=True) - batch_X = batch_X.to(device, non_blocking=True) - - optimizer.zero_grad() - preds = ddp_model(batch_C) - loss = criterion(preds, 
batch_X) - loss.backward() - optimizer.step() - - bsz = batch_C.size(0) - running_loss += loss.item() * bsz - count_seen += bsz - - loss_tensor = torch.tensor([running_loss, count_seen], dtype=torch.float64, device=device) - dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM) - if rank == 0: - total_loss, total_count = loss_tensor.tolist() - epoch_loss = total_loss / max(total_count, 1.0) - print(f"[{label}] Epoch {epoch+1}/{epochs} - train MSE {epoch_loss:.6f}") - - dist.barrier() - if torch.cuda.is_available(): - torch.cuda.synchronize() - wall = time.time() - t0 - - if rank == 0: - eval_model = ddp_model.module - eval_model.eval() - - test_loader = DataLoader( - test_ds, - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - pin_memory=torch.cuda.is_available(), - ) - train_loader_full = DataLoader( - train_ds, - batch_size=batch_size, - shuffle=False, - num_workers=num_workers, - pin_memory=torch.cuda.is_available(), - ) - - def eval_mse(loader, split_name: str) -> float: - total_loss = 0.0 - count = 0 - with torch.no_grad(): - for batch_C, batch_X in loader: - batch_C = batch_C.to(device, non_blocking=True) - batch_X = batch_X.to(device, non_blocking=True) - preds = eval_model(batch_C) - loss = criterion(preds, batch_X) - bsz = batch_C.size(0) - total_loss += loss.item() * bsz - count += bsz - mse = total_loss / max(count, 1) - print(f"[{label}] {split_name} MSE {mse:.6f}") - return mse - - train_mse = eval_mse(train_loader_full, "train") - test_mse = eval_mse(test_loader, "test") - - samples_total = n_samples * epochs - throughput = samples_total / max(wall, 1e-9) - - print(f"\n[{label}] run complete") - print(f" wall time (s): {wall:.2f}") - print(f" total samples: {samples_total}") - print(f" throughput (samples/s): {throughput:.2f}") - print(f" final train MSE: {train_mse:.6f}") - print(f" final test MSE: {test_mse:.6f}") - - result_dict["2gpu_ddp"] = BenchResult( - label=label, - wall_seconds=wall, - samples_total=samples_total, - 
throughput_sps=throughput, - train_mse_mean=train_mse, - test_mse_mean=test_mse, - ) - - dist.destroy_process_group() - - -# CSV writer -def save_results_csv(results: List[BenchResult], outdir: str): - os.makedirs(outdir, exist_ok=True) - path = os.path.join(outdir, "scale_results_unseen_ddp.csv") - with open(path, "w", newline="") as f: - writer = csv.writer(f) - writer.writerow( - [ - "label", - "wall_seconds", - "samples_total", - "throughput_samples_per_s", - "train_mse_mean", - "test_mse_mean", - ] - ) - for r in results: - writer.writerow( - [ - r.label, - f"{r.wall_seconds:.6f}", - r.samples_total, - f"{r.throughput_sps:.6f}", - f"{r.train_mse_mean:.6f}", - f"{r.test_mse_mean:.6f}", - ] - ) - print(f"\nSaved CSV → {path}") - - -# CLI and main -def parse_args(): - import argparse - - ap = argparse.ArgumentParser() - ap.add_argument("--epochs", type=int, default=3) - ap.add_argument("--batch-size", type=int, default=256) - ap.add_argument( - "--num-workers", - type=int, - default=0, - help="DataLoader workers (0 is safest on HPC).", - ) - ap.add_argument( - "--subsample-fraction", - type=float, - default=None, - help="Optional fraction of rows to subsample for quick tests", - ) - ap.add_argument( - "--outdir", - type=str, - default="bench_out_unseen", - ) - ap.add_argument( - "--ddp-port", - type=str, - default="29611", - help="TCP port for DDP init_method (tcp://127.0.0.1:PORT).", - ) - return ap.parse_args() - - -def main(): - args = parse_args() - mp.set_start_method("spawn", force=True) - set_env_defaults() - - results: List[BenchResult] = [] - - res_1gpu = run_single_gpu( - epochs=args.epochs, - batch_size=args.batch_size, - num_workers=args.num_workers, - subsample_fraction=args.subsample_fraction, - ) - results.append(res_1gpu) - - if torch.cuda.is_available() and torch.cuda.device_count() >= 2: - world_size = 2 - port = args.ddp_port - - manager = mp.Manager() - result_dict = manager.dict() - - mp.spawn( - ddp_worker, - args=( - world_size, - port, - 
args.epochs, - args.batch_size, - args.num_workers, - args.subsample_fraction, - result_dict, - ), - nprocs=world_size, - join=True, - ) - - if "2gpu_ddp" in result_dict: - results.append(result_dict["2gpu_ddp"]) - else: - print("\n[WARN] DDP finished but no result in result_dict['2gpu_ddp'].") - else: - print("\n[Info] < 2 GPUs visible; skipping 2-GPU DDP benchmark.") - - save_results_csv(results, args.outdir) - - -if __name__ == "__main__": - main() diff --git a/scale_bench.py b/scale_bench.py deleted file mode 100644 index 4e5e96b1..00000000 --- a/scale_bench.py +++ /dev/null @@ -1,508 +0,0 @@ -#!/usr/bin/env python3 -# Single-node strong-scaling benchmark runner for ContextualizedRegression using synthetic batched data. - -import os -import time -import json -import argparse -from dataclasses import dataclass -from datetime import timedelta -from typing import Any, Optional, Tuple - -import numpy as np -import torch -import torch.nn.functional as F -import pytorch_lightning as pl -from torch.utils.data import IterableDataset, DataLoader -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.strategies import DDPStrategy - -from contextualized.regression import ContextualizedRegression - - -# Torchrun helpers -def under_torchrun() -> bool: - e = os.environ - return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) - - -def world_size() -> int: - return int(os.environ.get("WORLD_SIZE", "1")) - - -def global_rank() -> int: - return int(os.environ.get("RANK", "0")) - - -def local_rank() -> int: - return int(os.environ.get("LOCAL_RANK", "0")) - - -def is_global_zero() -> bool: - return global_rank() == 0 - - -# Environment defaults -def set_env_defaults(): - os.environ.setdefault("OMP_NUM_THREADS", "1") - os.environ.setdefault("MKL_NUM_THREADS", "1") - os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - - os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") - os.environ.setdefault("NCCL_DEBUG", "WARN") - 
os.environ.setdefault("NCCL_P2P_DISABLE", "0") - os.environ.setdefault("NCCL_IB_DISABLE", "1") - - if "NCCL_SOCKET_IFNAME" not in os.environ: - try: - ifaces = [d for d in os.listdir("/sys/class/net") if d not in ("lo", "docker0")] - cand = next((i for i in ifaces if i.startswith(("ens", "enp", "eno", "eth", "bond", "ib"))), None) - os.environ["NCCL_SOCKET_IFNAME"] = cand or "^lo,docker0" - except Exception: - os.environ["NCCL_SOCKET_IFNAME"] = "^lo,docker0" - - if torch.cuda.is_available(): - try: - torch.backends.cuda.matmul.allow_tf32 = True - except Exception: - pass - try: - torch.set_float32_matmul_precision("high") - except Exception: - pass - try: - torch.backends.cudnn.benchmark = True - except Exception: - pass - - if under_torchrun() and torch.cuda.is_available(): - try: - torch.cuda.set_device(local_rank()) - except Exception: - pass - - if is_global_zero(): - keys = [ - "NCCL_DEBUG", - "NCCL_IB_DISABLE", - "NCCL_P2P_DISABLE", - "NCCL_SOCKET_IFNAME", - "TORCH_NCCL_ASYNC_ERROR_HANDLING", - ] - print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) - if torch.cuda.is_available(): - print( - "CUDA:", - { - "torch": torch.__version__, - "lightning": pl.__version__, - "gpus_visible": torch.cuda.device_count(), - }, - ) - - -def map_precision(p: str): - p = (p or "").lower() - if p in ("bf16", "bfloat16", "bf16-mixed"): - return "bf16-mixed" - if p in ("fp16", "16", "16-mixed"): - return "16-mixed" - return 32 - - -# Timing callback -class SteadyStateStepTimer(Callback): - def __init__(self, warmup_steps: int, measure_steps: int): - super().__init__() - self.warmup_steps = int(warmup_steps) - self.measure_steps = int(measure_steps) - self._seen = 0 - self.step_times = [] - self._t0 = None - - @staticmethod - def _sync(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - s = self._seen - if self.warmup_steps <= s < self.warmup_steps + self.measure_steps: - self._sync() - 
self._t0 = time.time() - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - s = self._seen - if self.warmup_steps <= s < self.warmup_steps + self.measure_steps: - self._sync() - self.step_times.append(time.time() - (self._t0 or time.time())) - self._seen += 1 - - def measured_wall(self) -> float: - return float(sum(self.step_times)) - - -def dist_max(value: float) -> float: - try: - import torch.distributed as dist - - if dist.is_available() and dist.is_initialized(): - t = torch.tensor([value], device="cuda" if torch.cuda.is_available() else "cpu", dtype=torch.float64) - dist.all_reduce(t, op=dist.ReduceOp.MAX) - return float(t.item()) - except Exception: - pass - return float(value) - - -# Synthetic batched iterable -class SyntheticBatchStream(IterableDataset): - def __init__( - self, - batch_size: int, - c_dim: int, - x_dim: int, - y_dim: int, - buffer_batches: int, - buffer_mult: int, - seed: int, - pin: bool, - ): - super().__init__() - self.batch_size = int(batch_size) - self.c_dim = int(c_dim) - self.x_dim = int(x_dim) - self.y_dim = int(y_dim) - - self.n_batches = int(buffer_batches) * int(buffer_mult) - if self.n_batches <= 0: - raise ValueError("buffer_batches * buffer_mult must be >= 1") - - g = torch.Generator(device="cpu") - g.manual_seed(int(seed) + 1000 * global_rank()) - - self.C = torch.randn((self.n_batches, self.batch_size, self.c_dim), generator=g, device="cpu", dtype=torch.float32) - self.X = torch.randn((self.n_batches, self.batch_size, self.x_dim), generator=g, device="cpu", dtype=torch.float32) - self.Y = torch.randn((self.n_batches, self.batch_size, self.y_dim), generator=g, device="cpu", dtype=torch.float32) - - if pin and torch.cuda.is_available(): - self.C = self.C.pin_memory() - self.X = self.X.pin_memory() - self.Y = self.Y.pin_memory() - - def __iter__(self): - ws = world_size() - r = global_rank() - k = 0 - while True: - b = (k * ws + r) % self.n_batches - yield {"contexts": self.C[b], "predictors": 
self.X[b], "outcomes": self.Y[b]} - k += 1 - - -def _as_2d(t: torch.Tensor) -> torch.Tensor: - # Accept [B, y, 1] or [B, 1, y] and squeeze the singleton dim - if t.ndim == 3: - if t.shape[-1] == 1: - # Convert [B, y, 1] -> [B, y] - t = t.squeeze(-1) - elif t.shape[1] == 1: - # Convert [B, 1, y] -> [B, y] - t = t.squeeze(1) - if t.ndim == 1: - return t.unsqueeze(-1) - if t.ndim == 2: - return t - raise RuntimeError(f"Expected 1D or 2D tensor (or squeezable 3D), got shape {tuple(t.shape)}") - - -def _canonicalize_y(y: torch.Tensor, B: int, y_dim: int, name: str) -> torch.Tensor: - y = _as_2d(y) - if y.shape == (B, y_dim): - return y - if y.shape == (y_dim, B): - return y.transpose(0, 1) - if y_dim == 1 and y.shape == (B,): - return y.view(B, 1) - raise RuntimeError(f"{name} has incompatible shape {tuple(y.shape)}; expected [{B},{y_dim}] or [{y_dim},{B}].") - - -def _extract_mu_hat(out: Any) -> torch.Tensor: - # Prefer mu_hat as y_pred for this benchmark - if torch.is_tensor(out): - return out - - if isinstance(out, dict): - for k in ("mu_hat", "mu", "y_pred", "y_hat", "pred"): - if k in out and torch.is_tensor(out[k]): - return out[k] - raise RuntimeError(f"Forward returned dict without mu_hat/y_hat keys: {list(out.keys())}") - - if isinstance(out, (tuple, list)): - tensors = [t for t in out if torch.is_tensor(t)] - if len(tensors) >= 2: - return tensors[1] - if len(tensors) == 1: - return tensors[0] - raise RuntimeError("Forward returned tuple/list with no tensors.") - - raise RuntimeError(f"Unsupported forward output type: {type(out)}") - - -# Lightning bench module -class BenchModule(pl.LightningModule): - def __init__(self, inner: ContextualizedRegression, lr: float, y_dim: int): - super().__init__() - self.inner = inner - self.lr = float(lr) - self.y_dim = int(y_dim) - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.lr) - - def training_step(self, batch, batch_idx): - device = self.device - - C = 
batch["contexts"].to(device, non_blocking=True) - X = batch["predictors"].to(device, non_blocking=True) - Y = batch["outcomes"].to(device, non_blocking=True) - - B = C.shape[0] - Y_true = _canonicalize_y(Y, B, self.y_dim, "Y_true") - - # Prefer calling with dict to match internal conventions - out = self.inner({"contexts": C, "predictors": X, "outcomes": Y_true}) - - mu_hat = _extract_mu_hat(out) - Y_pred = _canonicalize_y(mu_hat, B, self.y_dim, "Y_pred(mu_hat)") - - loss = F.mse_loss(Y_pred, Y_true) - return loss - - -# Batch sizing -def resolve_batch_sizes(args, ws: int) -> Tuple[int, int]: - if args.global_batch_size is None: - per_gpu = int(args.batch_size) - return per_gpu, per_gpu * ws - gbs = int(args.global_batch_size) - if gbs % ws != 0: - raise ValueError(f"--global-batch-size {gbs} must be divisible by world_size {ws}") - return gbs // ws, gbs - - -# Trainer -def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: - use_cuda = torch.cuda.is_available() and (args.run_device != "cpu") - - if use_cuda: - accelerator = "gpu" - - if under_torchrun(): - devices = 1 - else: - devices = min(int(args.devices), torch.cuda.device_count()) - - strategy = ( - DDPStrategy( - find_unused_parameters=False, - gradient_as_bucket_view=True, - static_graph=True, - timeout=timedelta(seconds=args.ddp_timeout), - ) - if (under_torchrun() or devices > 1) - else "auto" - ) - else: - accelerator = "cpu" - devices = 1 - strategy = "auto" - - max_steps = int(args.warmup_steps) + int(args.steps) - - return pl.Trainer( - accelerator=accelerator, - devices=devices, - strategy=strategy, - precision=map_precision(args.precision), - max_steps=max_steps, - max_epochs=10_000, - logger=False, - enable_checkpointing=False, - enable_progress_bar=False, - num_sanity_val_steps=0, - log_every_n_steps=50, - callbacks=[timer], - inference_mode=False, - enable_model_summary=False, - accumulate_grad_batches=1, - limit_val_batches=0, - ) - - -# Results -@dataclass -class Result: - 
world_size: int - batch_size_per_gpu: int - global_batch_size: int - warmup_steps: int - measured_steps: int - measured_wall_s: float - throughput_samples_per_s: float - per_gpu_throughput_samples_per_s: float - avg_step_s: float - p95_step_s: float - - -def save_result(outdir: str, res: Result): - os.makedirs(outdir, exist_ok=True) - path = os.path.join(outdir, "result.json") - with open(path, "w") as f: - json.dump(res.__dict__, f, indent=2) - return path - - -# Main bench -def run_bench(args) -> Result: - ws = world_size() if under_torchrun() else int(args.devices) - per_gpu_bs, global_bs = resolve_batch_sizes(args, ws) - - pin = args.data_device == "cpu_pinned" - - ds = SyntheticBatchStream( - batch_size=per_gpu_bs, - c_dim=args.context_dim, - x_dim=args.x_dim, - y_dim=args.y_dim, - buffer_batches=args.buffer_batches, - buffer_mult=args.buffer_mult, - seed=args.seed, - pin=pin, - ) - - dl = DataLoader(ds, batch_size=None, num_workers=0, pin_memory=False) - - inner = ContextualizedRegression( - context_dim=args.context_dim, - x_dim=args.x_dim, - y_dim=args.y_dim, - num_archetypes=args.num_archetypes, - encoder_type=args.encoder_type, - encoder_kwargs={"width": args.width, "layers": args.layers, "link_fn": "identity"}, - learning_rate=args.lr, - fit_intercept=True, - link_fn="identity", - loss_fn="mse", - model_regularizer="none", - ) - - model = BenchModule(inner=inner, lr=args.lr, y_dim=args.y_dim) - - timer = SteadyStateStepTimer(args.warmup_steps, args.steps) - trainer = build_trainer(args, timer) - - if is_global_zero(): - buffer_batches_total = int(args.buffer_batches) * int(args.buffer_mult) - buffer_samples_per_rank = int(per_gpu_bs) * buffer_batches_total - print( - "\nConfig:", - json.dumps( - { - "torchrun": under_torchrun(), - "world_size": ws, - "local_rank": local_rank(), - "batch_size_per_gpu": per_gpu_bs, - "global_batch_size": global_bs, - "steps_measured": int(args.steps), - "steps_warmup": int(args.warmup_steps), - "buffer_batches_total": 
buffer_batches_total, - "buffer_samples_per_rank": buffer_samples_per_rank, - "buffer_samples_global_approx": buffer_samples_per_rank * int(ws), - "run_device": args.run_device, - "data_device": args.data_device, - "pin_memory": pin, - "precision": map_precision(args.precision), - }, - indent=2, - ), - ) - - trainer.fit(model, train_dataloaders=dl) - - measured_wall = dist_max(timer.measured_wall()) - - measured_steps = int(args.steps) - samples_total = global_bs * measured_steps - throughput = samples_total / max(measured_wall, 1e-12) - per_gpu_thr = throughput / max(ws, 1) - - step_times = timer.step_times[:] if timer.step_times else [float("nan")] - avg_step = float(np.mean(step_times)) - p95_step = float(np.percentile(step_times, 95)) if len(step_times) > 1 else float("nan") - - return Result( - world_size=int(ws), - batch_size_per_gpu=int(per_gpu_bs), - global_batch_size=int(global_bs), - warmup_steps=int(args.warmup_steps), - measured_steps=int(measured_steps), - measured_wall_s=float(measured_wall), - throughput_samples_per_s=float(throughput), - per_gpu_throughput_samples_per_s=float(per_gpu_thr), - avg_step_s=float(avg_step), - p95_step_s=float(p95_step), - ) - - -def parse_args(): - ap = argparse.ArgumentParser() - ap.add_argument("--steps", type=int, default=400) - ap.add_argument("--warmup-steps", type=int, default=50) - - ap.add_argument("--batch-size", type=int, default=2048, help="Per-GPU batch size (ignored if --global-batch-size set)") - ap.add_argument("--global-batch-size", type=int, default=None, help="Fixed global batch for strong scaling") - - ap.add_argument("--precision", type=str, default="bf16") - - ap.add_argument("--context-dim", type=int, default=16) - ap.add_argument("--x-dim", type=int, default=512) - ap.add_argument("--y-dim", type=int, default=64) - - ap.add_argument("--encoder-type", type=str, default="mlp") - ap.add_argument("--num-archetypes", type=int, default=8) - ap.add_argument("--width", type=int, default=1024) - 
ap.add_argument("--layers", type=int, default=4) - ap.add_argument("--lr", type=float, default=1e-3) - - ap.add_argument("--buffer-batches", type=int, default=16, help="Buffer depth in batches (per rank)") - ap.add_argument("--buffer-mult", type=int, default=4, help="Extra multiplier on buffer size (per rank)") - - ap.add_argument("--data-device", choices=["cpu", "cpu_pinned"], default="cpu_pinned") - ap.add_argument("--run-device", choices=["auto", "cpu"], default="auto") - ap.add_argument("--devices", type=int, default=1, help="Used only when NOT under torchrun") - - ap.add_argument("--ddp-timeout", type=int, default=180) - ap.add_argument("--seed", type=int, default=123) - ap.add_argument("--outdir", type=str, default="bench_out") - - return ap.parse_args() - - -def main(): - set_env_defaults() - args = parse_args() - - if args.run_device == "cpu": - os.environ["CUDA_VISIBLE_DEVICES"] = "" - - res = run_bench(args) - - if is_global_zero(): - path = save_result(args.outdir, res) - print("\nResult:", json.dumps(res.__dict__, indent=2)) - print(f"\nSaved → {path}") - - -if __name__ == "__main__": - main() diff --git a/scale_bench_networks.py b/scale_bench_networks.py deleted file mode 100644 index 286c2edf..00000000 --- a/scale_bench_networks.py +++ /dev/null @@ -1,514 +0,0 @@ -#!/usr/bin/env python3 -# Torchrun-friendly DDP scaling benchmark for Contextualized network LightningModules using synthetic buffered data. 
- -import os -import time -import json -import math -import argparse -from dataclasses import dataclass -from datetime import timedelta -from typing import Dict - -import numpy as np -import torch -import pytorch_lightning as pl -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.strategies import DDPStrategy - -from contextualized.regression.datamodules import ContextualizedRegressionDataModule -from contextualized.regression.lightning_modules import ( - ContextualizedCorrelation, - ContextualizedMarkovGraph, -) -from contextualized.dags.lightning_modules import NOTMAD - - -# Launcher/cluster helpers -def under_torchrun() -> bool: - e = os.environ - return ("LOCAL_RANK" in e) or ("RANK" in e) or ("WORLD_SIZE" in e) - - -def world_size() -> int: - try: - return int(os.environ.get("WORLD_SIZE", "1")) - except Exception: - return 1 - - -def global_rank() -> int: - try: - return int(os.environ.get("RANK", "0")) - except Exception: - return 0 - - -def local_rank() -> int: - try: - return int(os.environ.get("LOCAL_RANK", "0")) - except Exception: - return 0 - - -def is_global_zero() -> bool: - return global_rank() == 0 - - -# Environment defaults and GPU performance flags -def set_env_defaults(): - os.environ.setdefault("OMP_NUM_THREADS", "1") - os.environ.setdefault("MKL_NUM_THREADS", "1") - os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") - - os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") - os.environ.setdefault("NCCL_DEBUG", "WARN") - os.environ.setdefault("NCCL_P2P_DISABLE", "0") - os.environ.setdefault("NCCL_IB_DISABLE", "1") - - if "NCCL_SOCKET_IFNAME" not in os.environ: - try: - ifaces = [ - d - for d in os.listdir("/sys/class/net") - if os.path.isdir(f"/sys/class/net/{d}") - ] - cand = next((i for i in ifaces if i not in ("lo", "docker0")), None) - os.environ["NCCL_SOCKET_IFNAME"] = cand or "^lo,docker0" - except Exception: - os.environ["NCCL_SOCKET_IFNAME"] = "^lo,docker0" - - # TF32 / matmul speedups for throughput 
benchmarking - if torch.cuda.is_available(): - try: - torch.backends.cuda.matmul.allow_tf32 = True - except Exception: - pass - try: - torch.set_float32_matmul_precision("high") - except Exception: - pass - try: - torch.backends.cudnn.benchmark = True - except Exception: - pass - - if under_torchrun() and torch.cuda.is_available(): - try: - torch.cuda.set_device(local_rank()) - except Exception: - pass - - if is_global_zero(): - keys = [ - "NCCL_DEBUG", - "NCCL_IB_DISABLE", - "NCCL_P2P_DISABLE", - "NCCL_SOCKET_IFNAME", - "TORCH_NCCL_ASYNC_ERROR_HANDLING", - ] - print("DDP/NCCL env:", {k: os.environ.get(k) for k in keys}) - if torch.cuda.is_available(): - print( - "CUDA:", - { - "torch": torch.__version__, - "lightning": pl.__version__, - "gpus_visible": torch.cuda.device_count(), - }, - ) - - -def map_precision(p: str): - p = (p or "").lower() - if p in ("bf16", "bfloat16", "bf16-mixed"): - return "bf16-mixed" - if p in ("fp16", "16", "16-mixed"): - return "16-mixed" - return 32 - - -# Timing callback for steady-state step timing -class SteadyStateStepTimer(Callback): - def __init__(self, warmup_steps: int, measure_steps: int): - super().__init__() - self.warmup_steps = int(warmup_steps) - self.measure_steps = int(measure_steps) - self._seen_steps = 0 - self.step_times = [] - self._step_start_t = None - - @staticmethod - def _sync_if_cuda(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): - s = self._seen_steps - if self.warmup_steps <= s < (self.warmup_steps + self.measure_steps): - self._sync_if_cuda() - self._step_start_t = time.time() - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): - s = self._seen_steps - if self.warmup_steps <= s < (self.warmup_steps + self.measure_steps): - self._sync_if_cuda() - dt = time.time() - (self._step_start_t or time.time()) - self.step_times.append(dt) - self._seen_steps += 1 - - def measured_wall_time(self) -> 
float: - return float(sum(self.step_times)) - - -def dist_max(value: float) -> float: - # Return max(value across ranks) if distributed is initialized - try: - import torch.distributed as dist - - if dist.is_available() and dist.is_initialized(): - t = torch.tensor( - [value], device="cuda" if torch.cuda.is_available() else "cpu" - ) - dist.all_reduce(t, op=dist.ReduceOp.MAX) - return float(t.item()) - except Exception: - pass - return float(value) - - -# Synthetic buffer construction -def make_synthetic_tensors( - n: int, - c_dim: int, - x_dim: int, - device: torch.device, - seed: int, -) -> Dict[str, torch.Tensor]: - # Build a fixed synthetic buffer with per-rank seeding - g = torch.Generator(device=device) - g.manual_seed(int(seed) + 1000 * global_rank()) - - C = torch.randn((n, c_dim), generator=g, device=device, dtype=torch.float32) - X = torch.randn((n, x_dim), generator=g, device=device, dtype=torch.float32) - Y = torch.randn((n, x_dim), generator=g, device=device, dtype=torch.float32) - return {"C": C, "X": X, "Y": Y} - - -# Model/datamodule/trainer builders -def build_model(args): - # Instantiate the selected network LightningModule with signature-filtered kwargs - import inspect - - if args.network == "correlation": - model_cls = ContextualizedCorrelation - elif args.network == "markov": - model_cls = ContextualizedMarkovGraph - elif args.network == "bayesian": - model_cls = NOTMAD - else: - raise ValueError(f"Unknown --network {args.network}") - - encoder_kwargs = {"width": args.width, "layers": args.layers, "link_fn": "identity"} - - kw = dict( - context_dim=args.context_dim, - x_dim=args.x_dim, - y_dim=args.x_dim, - univariate=True, - num_archetypes=args.num_archetypes, - encoder_type=args.encoder_type, - encoder_kwargs=encoder_kwargs, - learning_rate=args.lr, - link_fn="identity", - fit_intercept=True, - loss_fn="mse", - model_regularizer="none", - ) - - if args.network == "bayesian": - kw.update( - archetype_loss_params=dict( - l1=0.0, - 
dag=dict(loss_type="notears", params=dict(alpha=1.0, rho=1.0, s=1.0, tol=1e-8)), - init_mat=None, - num_factors=0, - factor_mat_l1=0.0, - num_archetypes=max(1, int(args.num_archetypes)), - ), - sample_specific_loss_params=dict( - l1=0.0, - dag=dict(loss_type="notears", params=dict(alpha=1.0, rho=1.0, s=1.0, tol=1e-8)), - ), - opt_params=dict( - learning_rate=args.lr, - step=50, - ), - ) - - sig = inspect.signature(model_cls.__init__) - accepts_var_kw = any( - p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values() - ) - if accepts_var_kw: - return model_cls(**kw) - - filtered = {k: v for k, v in kw.items() if k in sig.parameters} - required = [ - name - for name, p in sig.parameters.items() - if name != "self" - and p.default is inspect._empty - and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) - ] - missing = [r for r in required if r not in filtered] - if missing: - raise TypeError( - f"{model_cls.__name__}.__init__ missing required args {missing}. " - f"Accepted params in script: {sorted(filtered.keys())}. 
" - f"Signature: {sig}" - ) - return model_cls(**filtered) - - -def build_dm(args, C, X, Y) -> ContextualizedRegressionDataModule: - # Construct the datamodule with a fixed synthetic buffer and deterministic indices - n = int(C.shape[0]) - n_train = max(1, int(0.98 * n)) - train_idx = np.arange(0, n_train, dtype=np.int64) - val_idx = np.arange(n_train, n, dtype=np.int64) - - task_type = args.task_type - if task_type is None: - task_type = "singletask_univariate" - - dm = ContextualizedRegressionDataModule( - C=C, - X=X, - Y=Y, - task_type=task_type, - train_idx=train_idx, - val_idx=val_idx, - test_idx=None, - predict_idx=None, - train_batch_size=args.batch_size, - val_batch_size=args.batch_size, - test_batch_size=args.batch_size, - predict_batch_size=args.batch_size, - num_workers=args.num_workers, - pin_memory=bool(args.pin_memory), - persistent_workers=bool(args.num_workers > 0), - drop_last=True, - shuffle_train=False, - shuffle_eval=False, - dtype=torch.float, - ) - dm.prepare_data() - dm.setup() - return dm - - -def build_trainer(args, timer: SteadyStateStepTimer) -> pl.Trainer: - if torch.cuda.is_available(): - accelerator = "gpu" - devices = 1 if under_torchrun() else min(args.devices, torch.cuda.device_count()) - strategy = ( - DDPStrategy( - find_unused_parameters=False, - gradient_as_bucket_view=True, - static_graph=True, - timeout=timedelta(seconds=args.ddp_timeout), - ) - if (under_torchrun() or devices > 1) - else "auto" - ) - else: - accelerator = "cpu" - devices = 1 - strategy = "auto" - - max_steps = int(args.warmup_steps + args.steps) - - return pl.Trainer( - accelerator=accelerator, - devices=devices, - strategy=strategy, - precision=map_precision(args.precision), - max_steps=max_steps, - max_epochs=10_000, - logger=False, - enable_checkpointing=False, - enable_progress_bar=False, - enable_model_summary=False, - num_sanity_val_steps=0, - log_every_n_steps=50, - callbacks=[timer], - inference_mode=False, - detect_anomaly=False, - 
accumulate_grad_batches=1, - limit_val_batches=0, - use_distributed_sampler=False, - ) - - -# Benchmark runner -@dataclass -class Result: - network: str - world_size: int - batch_size_per_gpu: int - global_batch_size: int - warmup_steps: int - measured_steps: int - measured_wall_s: float - throughput_samples_per_s: float - per_gpu_throughput_samples_per_s: float - avg_step_s: float - p95_step_s: float - data_device: str - - -def run_bench(args) -> Result: - ws = world_size() if under_torchrun() else int(args.devices) - - if args.data_device == "cpu": - dev = torch.device("cpu") - elif args.data_device == "cuda": - dev = torch.device("cuda", local_rank()) if torch.cuda.is_available() else torch.device("cpu") - else: - dev = torch.device("cuda", local_rank()) if torch.cuda.is_available() else torch.device("cpu") - - if dev.type == "cuda" and args.num_workers != 0: - if is_global_zero(): - print("NOTE: forcing --num-workers=0 because data-device is CUDA.") - args.num_workers = 0 - - n = int(args.batch_size * args.buffer_batches) - tensors = make_synthetic_tensors( - n=n, - c_dim=args.context_dim, - x_dim=args.x_dim, - device=dev, - seed=args.seed, - ) - - dm = build_dm(args, tensors["C"], tensors["X"], tensors["Y"]) - model = build_model(args) - - timer = SteadyStateStepTimer(args.warmup_steps, args.steps) - trainer = build_trainer(args, timer) - - if is_global_zero(): - print( - "\nConfig:", - json.dumps( - { - "network": args.network, - "torchrun": under_torchrun(), - "world_size": ws, - "local_rank": local_rank(), - "batch_size_per_gpu": args.batch_size, - "global_batch_size": args.batch_size * ws, - "steps_measured": args.steps, - "steps_warmup": args.warmup_steps, - "buffer_samples_per_rank": n, - "data_device": str(dev), - "precision": map_precision(args.precision), - "task_type": args.task_type or "singletask_univariate", - }, - indent=2, - ), - ) - - trainer.fit(model, train_dataloaders=dm.train_dataloader()) - - measured_wall = timer.measured_wall_time() - 
measured_wall = dist_max(measured_wall) - - measured_steps = int(args.steps) - global_batch = int(args.batch_size * ws) - samples_total = global_batch * measured_steps - throughput = samples_total / max(measured_wall, 1e-12) - per_gpu = throughput / max(ws, 1) - - step_times = timer.step_times[:] if timer.step_times else [float("nan")] - avg_step = float(np.mean(step_times)) - p95_step = float(np.percentile(step_times, 95)) if len(step_times) > 1 else float("nan") - - return Result( - network=str(args.network), - world_size=int(ws), - batch_size_per_gpu=int(args.batch_size), - global_batch_size=int(global_batch), - warmup_steps=int(args.warmup_steps), - measured_steps=int(measured_steps), - measured_wall_s=float(measured_wall), - throughput_samples_per_s=float(throughput), - per_gpu_throughput_samples_per_s=float(per_gpu), - avg_step_s=float(avg_step), - p95_step_s=float(p95_step), - data_device=str(dev), - ) - - -def save_result(outdir: str, res: Result) -> str: - os.makedirs(outdir, exist_ok=True) - path = os.path.join(outdir, "result.json") - with open(path, "w") as f: - json.dump(res.__dict__, f, indent=2) - return path - - -# Entrypoint -def parse_args(): - ap = argparse.ArgumentParser() - ap.add_argument("--network", type=str, choices=["correlation", "markov", "bayesian"], default="correlation") - - ap.add_argument("--steps", type=int, default=400, help="Measured optimizer steps") - ap.add_argument("--warmup-steps", type=int, default=50, help="Warmup steps excluded from timing") - - ap.add_argument("--batch-size", type=int, default=2048, help="Per-GPU batch size") - ap.add_argument("--num-workers", type=int, default=0) - ap.add_argument("--pin-memory", action="store_true", default=False) - ap.add_argument("--precision", type=str, default="bf16") - - ap.add_argument("--context-dim", type=int, default=16) - ap.add_argument("--x-dim", type=int, default=512) - - ap.add_argument("--encoder-type", type=str, default="mlp") - ap.add_argument("--num-archetypes", 
type=int, default=8) - ap.add_argument("--width", type=int, default=1024) - ap.add_argument("--layers", type=int, default=4) - ap.add_argument("--lr", type=float, default=1e-3) - - ap.add_argument("--buffer-batches", type=int, default=32, - help="Per-rank buffer size = batch_size * buffer_batches") - ap.add_argument("--data-device", type=str, choices=["auto", "cpu", "cuda"], default="auto") - - ap.add_argument("--task-type", type=str, default=None, - help="Override task_type if needed (default: singletask_univariate)") - - ap.add_argument("--devices", type=int, default=1, help="Only used when NOT under torchrun") - ap.add_argument("--ddp-timeout", type=int, default=180) - ap.add_argument("--seed", type=int, default=123) - ap.add_argument("--outdir", type=str, default="bench_out") - return ap.parse_args() - - -def main(): - set_env_defaults() - args = parse_args() - - if args.data_device == "cpu": - os.environ["CUDA_VISIBLE_DEVICES"] = "" - - res = run_bench(args) - - if is_global_zero(): - path = save_result(args.outdir, res) - print("\nResult:", json.dumps(res.__dict__, indent=2)) - print(f"\nSaved → {path}") - - -if __name__ == "__main__": - main() From 52934a786a55a33987a6b2541a52e9e1bf764c25 Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Wed, 14 Jan 2026 16:39:28 -0500 Subject: [PATCH 18/19] cleaning old testing files and updating doc strings --- contextualized/__init__.py | 11 +---- contextualized/easy/ContextualGAM.py | 28 +++++++++++++ .../easy/ContextualizedClassifier.py | 30 ++++++++++++++ .../easy/wrappers/SKLearnWrapper.py | 18 +++++---- contextualized/modules.py | 19 +++++---- contextualized/utils/engine.py | 40 ------------------- 6 files changed, 79 insertions(+), 67 deletions(-) delete mode 100644 contextualized/utils/engine.py diff --git a/contextualized/__init__.py b/contextualized/__init__.py index 82574ff5..3c5e4932 100644 --- a/contextualized/__init__.py +++ b/contextualized/__init__.py @@ -6,7 +6,7 @@ if torch.cuda.is_available(): try: - 
torch.set_float32_matmul_precision("high") # use TF32 kernels + torch.set_float32_matmul_precision("high") except Exception: pass from contextualized import analysis @@ -18,12 +18,3 @@ from contextualized.utils import * -import os -os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") -os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "1") -os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1") -# single-node default (disable IB unless you know you need it) -os.environ.setdefault("NCCL_IB_DISABLE", "1") -os.environ.setdefault("NCCL_P2P_DISABLE", "0") -from .utils.engine import pick_engine # optional re-export -__all__ = ["pick_engine"] \ No newline at end of file diff --git a/contextualized/easy/ContextualGAM.py b/contextualized/easy/ContextualGAM.py index f21777d0..c526f103 100644 --- a/contextualized/easy/ContextualGAM.py +++ b/contextualized/easy/ContextualGAM.py @@ -10,6 +10,20 @@ class ContextualGAMClassifier(ContextualizedClassifier): """ + The Contextual GAM Classifier separates and interprets the effect of context in context-varying decisions and classifiers, such as + heterogeneous disease diagnoses. + Implemented as a Contextual Generalized Additive Model with a classifier on top. + Always uses a Neural Additive Model ("ngam") encoder for interpretability. + See `this paper `__ + for more details. + + Args: + n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. + num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which used the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel. + alpha (float, optional): Regularization strength. Defaults to 0.0. + mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. + l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. 
+ Contextual GAM Classifier with a Neural Additive Model ("ngam") encoder. Inherits the sklearn-like API from ContextualizedClassifier. """ @@ -22,6 +36,20 @@ def __init__(self, **kwargs): class ContextualGAMRegressor(ContextualizedRegressor): """ + The Contextual GAM Regressor separates and interprets the effect of context in context-varying relationships, such as heterogeneous + treatment effects. + Implemented as a Contextual Generalized Additive Model with a linear regressor on top. + Always uses a Neural Additive Model ("ngam") encoder for interpretability. + See `this paper `__ + for more details. + + Args: + n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1. + num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which used the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel. + alpha (float, optional): Regularization strength. Defaults to 0.0. + mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. + l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. + Contextual GAM Regressor with a Neural Additive Model ("ngam") encoder. Inherits the sklearn-like API from ContextualizedRegressor. """ diff --git a/contextualized/easy/ContextualizedClassifier.py b/contextualized/easy/ContextualizedClassifier.py index 32e5eb5a..e6a668b6 100644 --- a/contextualized/easy/ContextualizedClassifier.py +++ b/contextualized/easy/ContextualizedClassifier.py @@ -13,6 +13,14 @@ class ContextualizedClassifier(ContextualizedRegressor): """ Contextualized Logistic Regression reveals context-dependent decisions and decision boundaries. Implemented as a ContextualizedRegressor with logistic link function and binary cross-entropy loss. + + Args: + n_bootstraps (int, optional): Number of bootstraps to use. 
Defaults to 1. + num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which used the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel. + encoder_type (str, optional): Type of encoder to use ("mlp", "ngam", "linear"). Defaults to "mlp". + alpha (float, optional): Regularization strength. Defaults to 0.0. + mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. + l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. """ def __init__(self, **kwargs): @@ -21,6 +29,16 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def predict(self, C, X, individual_preds: bool = False, **kwargs): + """Predict binary outcomes from context C and predictors X. + + Args: + C (np.ndarray): Context array of shape (n_samples, n_context_features) + X (np.ndarray): Predictor array of shape (N, n_features) + individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False. + + Returns: + Union[np.ndarray, List[np.ndarray]]: The binary outcomes predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True. + """ out = super().predict(C, X, individual_preds=individual_preds, **kwargs) if out is None: return None @@ -40,6 +58,18 @@ def predict(self, C, X, individual_preds: bool = False, **kwargs): ] def predict_proba(self, C, X, **kwargs): + """ + Predict probabilities of outcomes from context C and predictors X. + + Args: + C (np.ndarray): Context array of shape (n_samples, n_context_features) + X (np.ndarray): Predictor array of shape (N, n_features) + individual_preds (bool, optional): Whether to return individual predictions for each model. Defaults to False. 
+ + Returns: + Union[np.ndarray, List[np.ndarray]]: The outcome probabilities predicted by the context-specific models (n_samples, y_dim, 2). Returned as lists of individual bootstraps if individual_preds is True. + """ + # Returns a np array of shape N samples, K outcomes, 2. probs = super().predict(C, X, **kwargs) if probs is None: return None diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py index 12792e3c..83968ad4 100644 --- a/contextualized/easy/wrappers/SKLearnWrapper.py +++ b/contextualized/easy/wrappers/SKLearnWrapper.py @@ -135,6 +135,7 @@ def __init__( "X_val", "Y_val", "val_split", + "random_state", "num_workers", "pin_memory", "persistent_workers", @@ -143,6 +144,7 @@ def __init__( "shuffle_eval", "dtype", ], + "model": [ "loss_fn", "link_fn", @@ -402,7 +404,7 @@ def _resolve_train_val_arrays( Y_val: Optional[np.ndarray], Y_required: bool, val_split: float, - random_state: int = 42, + random_state: Optional[int] = None, shuffle: bool = True, ) -> Tuple[np.ndarray, np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]: if ( @@ -430,14 +432,14 @@ def _resolve_train_val_arrays( if vs <= 0.0: return C, X, Y, np.arange(n), None - tr_idx, va_idx = train_test_split( - np.arange(n), - test_size=vs, - shuffle=shuffle, - random_state=random_state, - ) + split_kwargs = dict(test_size=vs, shuffle=shuffle) + if random_state is not None: + split_kwargs["random_state"] = random_state + + tr_idx, va_idx = train_test_split(np.arange(n), **split_kwargs) return C, X, Y, tr_idx, va_idx + def _build_datamodule( self, C: np.ndarray, @@ -764,8 +766,10 @@ def _inner(i): Y_val=Y_val, Y_required=True, val_split=val_split, + random_state=organized["data"].get("random_state", None), ) + for b in range(self.n_bootstraps): model_kwargs = dict(organized["model"]) model_kwargs.pop("univariate", None) diff --git a/contextualized/modules.py b/contextualized/modules.py index 52c7c6ff..65dd45c0 100644 --- 
a/contextualized/modules.py +++ b/contextualized/modules.py @@ -7,6 +7,7 @@ from contextualized.functions import LINK_FUNCTIONS + def _resolve_link_fn(maybe_link): """ Accepts either: @@ -78,13 +79,14 @@ def set_archetypes(self, archetypes): class Explainer(SoftSelect): - """ 2D subtype-archetype parameter sharing """ + """ + 2D subtype-archetype parameter sharing + """ + def __init__(self, k, out_shape): super().__init__((k,), out_shape) - - class MLP(nn.Module): """ Multi-layer perceptron @@ -110,7 +112,6 @@ def __init__( self.mlp = nn.Sequential(*mlp_layers) self.link_fn = _resolve_link_fn(link_fn) - def forward(self, X): """Torch Forward pass.""" ret = self.mlp(X) @@ -119,9 +120,7 @@ def forward(self, X): class NGAM(nn.Module): """ - Neural generalized additive model: sum_i f_i(x_i). - Each f_i is an MLP that outputs (B, output_dim). - The final link function is applied once to the summed output. + Neural generalized additive model """ def __init__( @@ -137,6 +136,7 @@ def __init__( self.input_dim = input_dim self.output_dim = output_dim + # Each feature-wise network uses an identity link; the global link is applied once. 
per_feat_link = "identity" self.nams = nn.ModuleList( @@ -155,14 +155,13 @@ def __init__( self.link_fn = _resolve_link_fn(link_fn) def forward(self, X): - """X: (B, input_dim)""" + """Torch Forward pass.""" ret = self.nams[0](X[:, 0].unsqueeze(-1)) for i, nam in enumerate(self.nams[1:], start=1): ret += nam(X[:, i].unsqueeze(-1)) return self.link_fn(ret) - class Linear(nn.Module): """ Linear encoder @@ -179,4 +178,4 @@ def forward(self, X): return self.linear(X) -ENCODERS = {"mlp": MLP, "ngam": NGAM, "linear": Linear} \ No newline at end of file +ENCODERS = {"mlp": MLP, "ngam": NGAM, "linear": Linear} diff --git a/contextualized/utils/engine.py b/contextualized/utils/engine.py deleted file mode 100644 index ebabb468..00000000 --- a/contextualized/utils/engine.py +++ /dev/null @@ -1,40 +0,0 @@ -# contextualized/utils/engine.py -import os, torch -from typing import Tuple, Union - -def _under_torchrun() -> bool: - e = os.environ - return any(k in e for k in ("LOCAL_RANK", "RANK", "WORLD_SIZE")) - -def _visible_gpus() -> int: - return torch.cuda.device_count() if torch.cuda.is_available() else 0 - -def pick_engine( - accelerator: str | None = None, - devices: Union[int, str, list[int]] | None = None, - strategy: str | None = None, - prefer_spawn: bool = True, -) -> Tuple[str, Union[int, str, list[int]], Union[str, object]]: - """ - CPU / 1-GPU / multi-GPU auto-selection WITHOUT requiring torchrun. - - If user passes any of (accelerator/devices/strategy), we respect them. 
- - Else: - GPUs == 0 => cpu, devices='auto' - GPUs == 1 => gpu, devices=1 - GPUs > 1 => - - if launched with torchrun => gpu, devices=1, strategy='ddp' - - else => gpu, devices=, strategy='ddp_spawn' - """ - if accelerator is not None or devices is not None or strategy is not None: - return accelerator or "auto", devices or "auto", strategy or "auto" - - ngpu = _visible_gpus() - if ngpu == 0: - return "cpu", "auto", "auto" - - if ngpu == 1: - return "gpu", 1, "auto" - - if _under_torchrun(): - return "gpu", 1, "ddp" # one proc per GPU (torchrun sets ranks) - return "gpu", ngpu, ("ddp_spawn" if prefer_spawn else "auto") \ No newline at end of file From abb13daff7bbc12adb058de2a5026b64dc0478fa Mon Sep 17 00:00:00 2001 From: Samuel-WM Date: Wed, 14 Jan 2026 16:46:09 -0500 Subject: [PATCH 19/19] reversion of files --- contextualized/easy/ContextualGAM.py | 19 +++--------- .../easy/ContextualizedClassifier.py | 31 ++----------------- 2 files changed, 7 insertions(+), 43 deletions(-) diff --git a/contextualized/easy/ContextualGAM.py b/contextualized/easy/ContextualGAM.py index c526f103..e09ce295 100644 --- a/contextualized/easy/ContextualGAM.py +++ b/contextualized/easy/ContextualGAM.py @@ -4,14 +4,12 @@ for more details. """ -from contextualized.easy import ContextualizedClassifier -from contextualized.easy import ContextualizedRegressor +from contextualized.easy import ContextualizedClassifier, ContextualizedRegressor class ContextualGAMClassifier(ContextualizedClassifier): """ - The Contextual GAM Classifier separates and interprets the effect of context in context-varying decisions and classifiers, such as - heterogeneous disease diagnoses. + The Contextual GAM Classifier separates and interprets the effect of context in context-varying decisions and classifiers, such as heterogeneous disease diagnoses. Implemented as a Contextual Generalized Additive Model with a classifier on top. Always uses a Neural Additive Model ("ngam") encoder for interpretability. 
See `this paper `__ @@ -23,21 +21,16 @@ class ContextualGAMClassifier(ContextualizedClassifier): alpha (float, optional): Regularization strength. Defaults to 0.0. mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. - - Contextual GAM Classifier with a Neural Additive Model ("ngam") encoder. - Inherits the sklearn-like API from ContextualizedClassifier. """ def __init__(self, **kwargs): - # Force interpretability via NAM encoder kwargs["encoder_type"] = "ngam" super().__init__(**kwargs) class ContextualGAMRegressor(ContextualizedRegressor): """ - The Contextual GAM Regressor separates and interprets the effect of context in context-varying relationships, such as heterogeneous - treatment effects. + The Contextual GAM Regressor separates and interprets the effect of context in context-varying relationships, such as heterogeneous treatment effects. Implemented as a Contextual Generalized Additive Model with a linear regressor on top. Always uses a Neural Additive Model ("ngam") encoder for interpretability. See `this paper `__ @@ -49,12 +42,8 @@ class ContextualGAMRegressor(ContextualizedRegressor): alpha (float, optional): Regularization strength. Defaults to 0.0. mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0. l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0. - - Contextual GAM Regressor with a Neural Additive Model ("ngam") encoder. - Inherits the sklearn-like API from ContextualizedRegressor. 
""" def __init__(self, **kwargs): - # Force interpretability via NAM encoder kwargs["encoder_type"] = "ngam" - super().__init__(**kwargs) + super().__init__(**kwargs) \ No newline at end of file diff --git a/contextualized/easy/ContextualizedClassifier.py b/contextualized/easy/ContextualizedClassifier.py index e6a668b6..11bb17ef 100644 --- a/contextualized/easy/ContextualizedClassifier.py +++ b/contextualized/easy/ContextualizedClassifier.py @@ -28,7 +28,7 @@ def __init__(self, **kwargs): kwargs["loss_fn"] = LOSSES["bceloss"] super().__init__(**kwargs) - def predict(self, C, X, individual_preds: bool = False, **kwargs): + def predict(self, C, X, individual_preds=False, **kwargs): """Predict binary outcomes from context C and predictors X. Args: @@ -39,23 +39,7 @@ def predict(self, C, X, individual_preds: bool = False, **kwargs): Returns: Union[np.ndarray, List[np.ndarray]]: The binary outcomes predicted by the context-specific models (n_samples, y_dim). Returned as lists of individual bootstraps if individual_preds is True. """ - out = super().predict(C, X, individual_preds=individual_preds, **kwargs) - if out is None: - return None - - out = np.asarray(out) - - if not individual_preds: - # common binary case: (N, 1, 1) or (N, 1) - if out.ndim == 3 and out.shape[-1] == 1: - out = out[..., 0] - return np.round(out) - - # individual_preds=True: list/array across bootstraps - return [ - np.round(p[..., 0] if (p.ndim == 3 and p.shape[-1] == 1) else p) - for p in out - ] + return np.round(super().predict(C, X, individual_preds, **kwargs)) def predict_proba(self, C, X, **kwargs): """ @@ -71,13 +55,4 @@ def predict_proba(self, C, X, **kwargs): """ # Returns a np array of shape N samples, K outcomes, 2. 
probs = super().predict(C, X, **kwargs) - if probs is None: - return None - - probs = np.asarray(probs) - if probs.ndim == 3 and probs.shape[-1] == 1: - probs = probs[..., 0] - - p1 = probs - p0 = 1.0 - p1 - return np.stack([p0, p1], axis=-1) + return np.array([1 - probs, probs]).T.swapaxes(0, 1) \ No newline at end of file