From fe189c5861a85ac65b8dd9afce740d689a12b5f8 Mon Sep 17 00:00:00 2001 From: Dennis Goldfarb Date: Sun, 10 Aug 2025 22:41:31 -0500 Subject: [PATCH] Detach charge embedding for tracing --- altimeter/export.py | 18 +++++++++--------- altimeter/model.py | 3 ++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/altimeter/export.py b/altimeter/export.py index 7bf20a3..74a82e7 100644 --- a/altimeter/export.py +++ b/altimeter/export.py @@ -61,30 +61,30 @@ def main(): model = LitFlipyFlopy.load_from_checkpoint( args.model_ckpt, config=config, model_config=model_config ) - + model.eval() input_seq = torch.zeros((1, channels, D.seq_len), dtype=torch.float32, device=device) input_ch = torch.zeros((1,1), dtype=torch.float32, device=device) - + input_sample = [input_seq, input_ch] input_names = ["inp", "inpch"] output_names = ["coefficients", "knots", "AUCs"] - + #print(model.model.get_knots()) - - script = torch.jit.trace( - lambda seq, ch: model.forward_coef([seq, ch]), - (input_seq, input_ch) - ) + with torch.no_grad(): + script = torch.jit.trace( + lambda seq, ch: model.forward_coef([seq, ch]), + (input_seq, input_ch) + ) torch.jit.save(script, args.altimeter_outpath) # repeat for splines model2 = LitBSplineNN() input_coef = torch.zeros((1, 4, 380), dtype=torch.float32, device=device) - input_knots = model.model.get_knots().unsqueeze(0).to(device) + input_knots = model.model.get_knots().unsqueeze(0).to(device).detach() input_ce = torch.zeros((1,1), dtype=torch.float32, device=device) input_sample = (input_coef, input_knots, input_ce) y = model2(*input_sample) diff --git a/altimeter/model.py b/altimeter/model.py index 2bdc0fc..3eadf4c 100755 --- a/altimeter/model.py +++ b/altimeter/model.py @@ -462,8 +462,9 @@ def compute_coefficients(self, inp:tuple[torch.Tensor, torch.Tensor]): elif len(inpch.shape) == 2: inpch = inpch.squeeze(1) + ce_feat = self.embedCE(inpch, self.cesz, 10.0).detach() ch_embed = nn.functional.silu( - self.denseCH(self.embedCE(inpch, self.cesz, 10.0)) + self.denseCH(ce_feat) ) embed = self.postcat(torch.cat([ch_embed],-1))