# Called from predict(): computes the combined backbone + diffusion training loss.
def _forward(self, frames_in, frames_gt):
    """Compute the training loss for one batch.

    Args:
        frames_in: input frames, shape (B, T_in, C, H, W).
        frames_gt: ground-truth future frames, shape (B, T_out, C, H, W);
            T_out is assumed to be a multiple of T_in.

    Returns:
        Scalar loss: weighted combination of the backbone loss and the
        average per-segment diffusion (noise-prediction) loss.
    """
    B, T_in, c, h, w = frames_in.shape
    T_out = frames_gt.shape[1]
    device = frames_in.device
    # Deterministic prediction via the backbone (SimVP).
    backbone_output, backbone_loss = self.backbone_net.predict(frames_in, frames_gt=frames_gt,
                                                               compute_loss=True)
    # Normalize inputs, targets and backbone output into the same value range.
    frames_in = self.normalize(frames_in)
    frames_gt = self.normalize(frames_gt)
    backbone_output = self.normalize(backbone_output)
    # Residual r = y - mu (eq. 7) and contexts h (eq. 14).
    residual = frames_gt - backbone_output
    global_ctx, local_ctx = self.ctx_net.scan_ctx(torch.cat((frames_in, backbone_output), dim=1))
    pre_frag = frames_in
    pre_mu = None
    diff_loss = 0.
    # One random diffusion timestep per sample, shared across all segments.
    # NOTE(review): was `self.device`; switched to the inputs' device for
    # consistency with `idx` below — confirm self.device always matches.
    t = torch.randint(0, self.num_timesteps, (B,), device=device).long()
    # Loop over the T_out // T_in segments of length T_in.
    for frag_idx in range(T_out // T_in):
        # Current segment's deterministic prediction ^mu_j and residual ^s_j.
        mu = backbone_output[:, frag_idx * T_in: (frag_idx + 1) * T_in]
        res = residual[:, frag_idx * T_in: (frag_idx + 1) * T_in]
        # Condition on the previous segment's residual s_{j-1}; for j = 0
        # there is none, so zeros are used.
        # NOTE(review): original comment said "use frames_in instead", but the
        # code conditions on zeros — confirm which is intended.
        cond = pre_frag - pre_mu if pre_mu is not None else torch.zeros_like(pre_frag)
        # Predict the noise from (s_{j-1}, h, t).
        # NOTE(review): the first segment uses local_ctx, later ones
        # global_ctx — confirm this matches the paper's formulation.
        _, noise_loss = self.p_losses(res, t, cond=cond, ctx=global_ctx if frag_idx > 0 else local_ctx,
                                      idx=torch.full((B,), frag_idx, device=device, dtype=torch.long))
        diff_loss += noise_loss
        pre_frag = frames_gt[:, frag_idx * T_in: (frag_idx + 1) * T_in]
        pre_mu = mu
    diff_loss /= (T_out // T_in)
    # Weighted combination of deterministic and diffusion losses.
    loss = (1 - self.loss_weight_factor) * backbone_loss + self.loss_weight_factor * diff_loss
    return loss
@autocast(enabled = False)
def q_sample(self, x_start, t, noise = None):
    """Forward-diffuse x_start to timestep t:
    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
    """
    if noise is None:
        noise = torch.randn_like(x_start)
    scale_x0 = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
    scale_eps = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return scale_x0 * x_start + scale_eps * noise
def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, ctx=None, idx=None, cond=None):
b, _, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
# offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
if offset_noise_strength > 0.:
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
# noise sample
x = self.q_sample(x_start=x_start, t=t, noise=noise) # Use q_sample here for updating
model_out = self.model(x, t, cond=cond, ctx=ctx, idx=idx)
if self.objective == 'pred_noise':
target = noise
elif self.objective == 'pred_x0':
target = x_start
elif self.objective == 'pred_v':
v = self.predict_v(x_start, t, noise)
target = v
else:
raise ValueError(f'unknown objective {self.objective}')
loss = F.mse_loss(model_out, target, reduction = 'none')
loss = reduce(loss, 'b ... -> b', 'mean')
loss = loss * extract(self.loss_weight, t, loss.shape)
return model_out, loss.mean()
After 200K iterations on the SHANGHAI dataset (5 input : 20 output frames), I got the following results:
01/05/2025 12:52:33 - INFO - root - ****************************** < Evaluation Results: > ******************************
01/05/2025 12:52:33 - INFO - root - Total 850 samples with 20 seq_len.
01/05/2025 12:52:33 - INFO - root - ******************************************************************************************
01/05/2025 12:52:33 - INFO - root - ====================Threshold: 20 with melthod 1====================
01/05/2025 12:52:33 - INFO - root - <CSI> : 0.565800888469066; [0.82626148 0.77370377 0.73132469 0.6950084 0.65983306 0.63396121
0.61126816 0.5897937 0.57044672 0.54923247 0.53140962 0.5152367
0.5000041 0.48699576 0.47042913 0.4575324 0.44525369 0.43425388
0.42381216 0.41025667]
01/05/2025 12:52:33 - INFO - root - <FAR> : 0.24224349243045568; [0.09140799 0.11713036 0.13861137 0.1577845 0.17509894 0.19173357
0.20543915 0.21761054 0.23058481 0.2454496 0.26172361 0.27388339
0.28266903 0.29216861 0.30459843 0.31655573 0.32447319 0.33138028
0.3382836 0.34828315]
01/05/2025 12:52:33 - INFO - root - <POD> : 0.6789452268045424; [0.90117156 0.86220701 0.82886745 0.79904948 0.76730174 0.74617523
0.7260127 0.70553094 0.68807796 0.66870393 0.65475824 0.63952266
0.62269332 0.60951819 0.59252213 0.58056485 0.56638467 0.55334772
0.54103347 0.52546129]
01/05/2025 12:52:33 - INFO - root - <HSS> : 0.6890801163422913; [0.89509785 0.85943835 0.82916844 0.80209205 0.77483725 0.7540032
0.7352728 0.71710628 0.70032196 0.68142543 0.66503864 0.65001531
0.63570107 0.62322894 0.60698766 0.59403597 0.58174211 0.57058332
0.55987288 0.54563282]
01/05/2025 12:52:33 - INFO - root - < CSI_POOL 4x4 > : 0.6325025462842695; CSI_POOL 16x16: 0.7372783605343171
01/05/2025 12:52:33 - INFO - root - ====================Threshold: 30 with melthod 1====================
01/05/2025 12:52:33 - INFO - root - <CSI> : 0.47518983292786804; [0.77611761 0.70332587 0.65081106 0.60978278 0.57104518 0.54241503
0.51670373 0.49470957 0.4747346 0.45419573 0.4378177 0.42049936
0.40426451 0.3901161 0.37451299 0.36072241 0.3478775 0.33526671
0.32560534 0.31327288]
01/05/2025 12:52:33 - INFO - root - <FAR> : 0.2974831751010489; [0.11926539 0.15896182 0.18961822 0.21324625 0.23454383 0.25044013
0.26632791 0.27872468 0.29262943 0.30590132 0.31843402 0.33102171
0.339345 0.34945696 0.36111365 0.37224913 0.38123591 0.38774515
0.39555952 0.40384346]
01/05/2025 12:52:33 - INFO - root - <POD> : 0.5818251720470691; [0.86726605 0.81115542 0.76772015 0.73052289 0.69215437 0.66247519
0.63599545 0.61163767 0.59075331 0.56786671 0.5504043 0.53098057
0.51020965 0.4935439 0.47507953 0.45887849 0.44278149 0.42564184
0.41377368 0.39766279]
01/05/2025 12:52:33 - INFO - root - <HSS> : 0.6179291139958538; [0.86733634 0.81679291 0.7776323 0.74534312 0.71336996 0.68876311
0.66590397 0.6457995 0.62703107 0.60721697 0.59098745 0.57350071
0.5568494 0.54201544 0.52525985 0.51014044 0.49587965 0.48172341
0.4706891 0.45634757]
01/05/2025 12:52:33 - INFO - root - < CSI_POOL 4x4 > : 0.5486872387090821; CSI_POOL 16x16: 0.650950245431883
01/05/2025 12:52:33 - INFO - root - ====================Threshold: 35 with melthod 1====================
01/05/2025 12:52:33 - INFO - root - <CSI> : 0.39606485937913605; [0.73081963 0.63969667 0.57626237 0.52955149 0.48813026 0.45701573
0.42922055 0.40402249 0.38486699 0.36413771 0.35018367 0.3348004
0.32014697 0.30861241 0.29448463 0.2841865 0.27228398 0.26026559
0.25226034 0.2403488 ]
01/05/2025 12:52:33 - INFO - root - <FAR> : 0.3634073054154245; [0.1444574 0.20044106 0.24029416 0.27191867 0.29757277 0.31559764
0.33603606 0.35141051 0.36686866 0.38152131 0.39263914 0.40685528
0.41472505 0.42344394 0.43519667 0.44320855 0.45207288 0.4567756
0.46447454 0.47263621]
01/05/2025 12:52:33 - INFO - root - <POD> : 0.4973780625681573; [0.83369612 0.76187494 0.70471085 0.66010192 0.61538519 0.57904561
0.5483362 0.51724965 0.49533172 0.46962838 0.45265733 0.4346078
0.41408439 0.39906242 0.38091826 0.36726745 0.35117589 0.33318127
0.32291046 0.30633542]
01/05/2025 12:52:33 - INFO - root - <HSS> : 0.5443081733634114; [0.83988494 0.77385992 0.72345734 0.68371708 0.64648701 0.61720833
0.58996796 0.56440731 0.54435543 0.52208137 0.50671081 0.48939319
0.4726161 0.45916685 0.44233217 0.42988669 0.41527263 0.40032962
0.39016709 0.37486163]
01/05/2025 12:52:33 - INFO - root - < CSI_POOL 4x4 > : 0.48441086275194206; CSI_POOL 16x16: 0.5865675293493181
01/05/2025 12:52:33 - INFO - root - ====================Threshold: 40 with melthod 1====================
01/05/2025 12:52:33 - INFO - root - <CSI> : 0.285276937459521; [0.6645734 0.54933477 0.47211479 0.41623437 0.37060638 0.33353102
0.30549998 0.27927691 0.25895973 0.24064859 0.22956794 0.21639686
0.20275106 0.1936537 0.18488311 0.1772157 0.16702052 0.15529105
0.14945041 0.13852845]
01/05/2025 12:52:33 - INFO - root - <FAR> : 0.4689316897550597; [0.18273019 0.26169754 0.31643544 0.36079712 0.39470287 0.41781393
0.44302741 0.46294865 0.48355582 0.49967694 0.51101887 0.52326946
0.53594818 0.54682181 0.55620708 0.56156442 0.56961588 0.57769891
0.58252461 0.59057862]
01/05/2025 12:52:33 - INFO - root - <POD> : 0.36574409829472676; [0.7805557 0.68216308 0.6041532 0.54405659 0.48871093 0.43848919
0.40356718 0.36782879 0.34184685 0.31678369 0.30202902 0.28380725
0.26474463 0.2527027 0.24064379 0.22925242 0.21441824 0.19717867
0.18882775 0.17312233]
01/05/2025 12:52:33 - INFO - root - <HSS> : 0.4211213455260009; [0.7956374 0.70508913 0.63656133 0.58233788 0.53489373 0.49405709
0.461629 0.43007983 0.40470491 0.38120264 0.3666254 0.34901006
0.33035745 0.3176708 0.30529602 0.29437251 0.27963301 0.26240511
0.2536806 0.23718301]
01/05/2025 12:52:33 - INFO - root - < CSI_POOL 4x4 > : 0.40419628442688127; CSI_POOL 16x16: 0.51684827582216
01/05/2025 12:52:33 - INFO - root - ********************Overall Avg Metrics on Thresholds [20, 30, 35, 40]********************
01/05/2025 12:52:33 - INFO - root - [ avg_csi ] : 0.43058312955889777; [ avg_far ] : 0.3430164156754972; [ avg_pod ] : 0.5309731399286239; [ avg_hss] : 0.5681096873068894
01/05/2025 12:52:33 - INFO - root - [ avg_csi_pool 4x4 ] : 0.5174492330430437; [ avg_csi_pool 16x16 ]: 0.6229111027844196
Do these results make sense? Could you please take a look and give me some suggestions?
Thanks a lot!
Hi,
thank you for your great open-source work.
I have written the training code based on the discussion in #4.
The code is as follows:
After 200K iterations on the SHANGHAI dataset (5 input : 20 output frames), I got the following results:
The sample output frames are as follows:

Do these results make sense? Could you please take a look and give me some suggestions?
Thanks a lot!
Sincerely,
QC