diff --git a/src/pytorchfwd/freq_math.py b/src/pytorchfwd/freq_math.py
index 4676db8..662b338 100644
--- a/src/pytorchfwd/freq_math.py
+++ b/src/pytorchfwd/freq_math.py
@@ -183,7 +183,8 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
     if np.iscomplexobj(covmean):
         if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
             m = np.max(np.abs(covmean.imag))
-            raise ValueError("Imaginary component {}".format(m))
+            print(f"Warning: got imaginary component {m}")
+            # raise ValueError("Imaginary component {}".format(m))
         covmean = covmean.real
 
     tr_covmean = np.trace(covmean)
diff --git a/src/pytorchfwd/fwd.py b/src/pytorchfwd/fwd.py
index 565db3e..47c3c61 100644
--- a/src/pytorchfwd/fwd.py
+++ b/src/pytorchfwd/fwd.py
@@ -19,6 +19,29 @@
 NUM_PROCESSES = None
 
 
+def get_packet_shape(
+    dataloader: th.utils.data.DataLoader, wavelet: str, max_level: int, log_scale: bool
+) -> Tuple[int, int, int, int, int]:
+    """Get shape of wavelet packet transform output.
+
+    Args:
+        dataloader (th.utils.data.DataLoader): Torch dataloader.
+        wavelet (str): Choice of wavelet.
+        max_level (int): Wavelet decomposition level.
+        log_scale (bool): Apply log scale.
+
+    Returns:
+        Tuple[int, int, int, int, int]: Shape of wavelet packet transform output.
+    """
+    device = th.device("cuda:0") if th.cuda.is_available() else th.device("cpu")
+    img_batch = next(iter(dataloader))
+    if isinstance(img_batch, list):
+        img_batch = img_batch[0]
+    img_batch = img_batch.to(device)
+    packet_ = forward_wavelet_packet_transform(img_batch, wavelet, max_level, log_scale)
+    return packet_.shape  # (b_size, num_packets, c, d0, d1)
+
+
 def compute_packet_statistics(
     dataloader: th.utils.data.DataLoader, wavelet: str, max_level: int, log_scale: bool
 ) -> Tuple[np.ndarray, ...]:
@@ -33,18 +56,21 @@ def compute_packet_statistics(
     Returns:
         Tuple[np.ndarray, ...]: Mean and sigma for each packet.
     """
-    packets = []
     device = th.device("cuda:0") if th.cuda.is_available() else th.device("cpu")
-    for img_batch in tqdm(dataloader):
+
+    n = len(dataloader.dataset)
+    b_size, num_packets, c, d0, d1 = get_packet_shape(
+        dataloader, wavelet, max_level, log_scale
+    )
+    packet_tensor = th.zeros(n, num_packets, c, d0, d1)
+    for k, img_batch in enumerate(tqdm(dataloader)):
         if isinstance(img_batch, list):
             img_batch = img_batch[0]
         img_batch = img_batch.to(device)
-        packets.append(
-            forward_wavelet_packet_transform(
-                img_batch, wavelet, max_level, log_scale
-            ).cpu()
-        )
-    packet_tensor = th.cat(packets, dim=0)
+        packet_ = forward_wavelet_packet_transform(img_batch, wavelet, max_level, log_scale)
+        start_idx = k * b_size
+        end_idx = start_idx + packet_.shape[0]
+        packet_tensor[start_idx:end_idx, :, :, :, :] = packet_.cpu()
     packet_tensor = th.permute(packet_tensor, (1, 0, 2, 3, 4))
     P, BS, C, H, W = packet_tensor.shape
     packet_tensor = th.reshape(packet_tensor, (P, BS, C * H * W))
@@ -54,9 +80,9 @@ def compute_packet_statistics(
 
     def gpu_cov(tensor_):
         return th.cov(tensor_.T).cpu()
 
-    sigma = th.stack(
-        [gpu_cov(packet_tensor[p, :, :].to(device)) for p in range(P)], dim=0
-    ).numpy()
+    sigma = np.zeros((P, C * H * W, C * H * W))
+    for p in range(P):
+        sigma[p, :, :] = gpu_cov(packet_tensor[p, :, :].to(device)).numpy()
     return mu, sigma
 