Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/pytorchfwd/freq_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
if np.iscomplexobj(covmean):
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
m = np.max(np.abs(covmean.imag))
raise ValueError("Imaginary component {}".format(m))
print(f"Warning: got imaginary component {m}")
# raise ValueError("Imaginary component {}".format(m))
covmean = covmean.real

tr_covmean = np.trace(covmean)
Expand Down
48 changes: 37 additions & 11 deletions src/pytorchfwd/fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,29 @@
NUM_PROCESSES = None


def get_packet_shape(
    dataloader: th.utils.data.DataLoader, wavelet: str, max_level: int, log_scale: bool
) -> Tuple[int, int, int, int, int]:
    """Determine the output shape of the wavelet packet transform.

    Draws a single batch from the dataloader, runs the transform on it,
    and reports the shape of the resulting tensor.

    Args:
        dataloader (th.utils.data.DataLoader): Torch dataloader.
        wavelet (str): Choice of wavelet.
        max_level (int): Wavelet decomposition level.
        log_scale (bool): Apply log scale.

    Returns:
        Tuple[int, int, int, int, int]: Shape of wavelet packet transform output.
    """
    target = th.device("cuda:0" if th.cuda.is_available() else "cpu")
    sample = next(iter(dataloader))
    # Datasets with labels yield [images, labels]; keep only the images.
    if isinstance(sample, list):
        sample = sample[0]
    packets = forward_wavelet_packet_transform(
        sample.to(target), wavelet, max_level, log_scale
    )
    # Shape layout: (b_size, num_packets, c, d0, d1)
    return packets.shape


def compute_packet_statistics(
dataloader: th.utils.data.DataLoader, wavelet: str, max_level: int, log_scale: bool
) -> Tuple[np.ndarray, ...]:
Expand All @@ -33,18 +56,21 @@ def compute_packet_statistics(
Returns:
Tuple[np.ndarray, ...]: Mean and sigma for each packet.
"""
packets = []
device = th.device("cuda:0") if th.cuda.is_available() else th.device("cpu")
for img_batch in tqdm(dataloader):

n = len(dataloader.dataset)
b_size, num_packets, c, d0, d1 = get_packet_shape(
dataloader, wavelet, max_level, log_scale
)
packet_tensor = th.zeros(n, num_packets, c, d0, d1)
for k, img_batch in enumerate(tqdm(dataloader)):
if isinstance(img_batch, list):
img_batch = img_batch[0]
img_batch = img_batch.to(device)
packets.append(
forward_wavelet_packet_transform(
img_batch, wavelet, max_level, log_scale
).cpu()
)
packet_tensor = th.cat(packets, dim=0)
packet_ = forward_wavelet_packet_transform(img_batch, wavelet, max_level, log_scale)
start_idx = k * b_size
end_idx = start_idx + packet_.shape[0]
packet_tensor[start_idx:end_idx, :, :, :, :] = packet_.cpu()
packet_tensor = th.permute(packet_tensor, (1, 0, 2, 3, 4))
P, BS, C, H, W = packet_tensor.shape
packet_tensor = th.reshape(packet_tensor, (P, BS, C * H * W))
Expand All @@ -54,9 +80,9 @@ def compute_packet_statistics(
def gpu_cov(tensor_):
return th.cov(tensor_.T).cpu()

sigma = th.stack(
[gpu_cov(packet_tensor[p, :, :].to(device)) for p in range(P)], dim=0
).numpy()
sigma = np.zeros((P, C * H * W, C * H * W))
for p in range(P):
sigma[p, :, :] = gpu_cov(packet_tensor[p, :, :].to(device)).numpy()
return mu, sigma


Expand Down