diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 47b9fb9..0000000 Binary files a/.DS_Store and /dev/null differ diff --git a/engiopt/cgan_1d.py b/engiopt/cgan_1d.py index b71512a..0e6c27f 100644 --- a/engiopt/cgan_1d.py +++ b/engiopt/cgan_1d.py @@ -259,7 +259,7 @@ def sample_designs(n_designs: int) -> th.Tensor: if batches_done % args.sample_interval == 0: # Extract 25 designs desired_conds, designs = sample_designs(25) - fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + fig, axes = plt.subplots(5, 5, figsize=(24, 24)) # Flatten axes for easy indexing axes = axes.flatten() @@ -267,9 +267,11 @@ def sample_designs(n_designs: int) -> th.Tensor: # Plot each tensor as a scatter plot for j, tensor in enumerate(designs): x, y = tensor.cpu().numpy() # Extract x and y coordinates - do = desired_conds[j].cpu() + dc = desired_conds[j].cpu() axes[j].scatter(x, y, s=10, alpha=0.7) # Scatter plot - axes[j].title.set_text(f"m1: {do[0]:.2f}, m2: {do[1]:.2f}") + axes[j].title.set_text( + ", ".join(f"{key[:3]}: {val:.2f}" for key, val in zip(problem.conditions_keys, dc)) + ) axes[j].set_xticks([]) # Hide x ticks axes[j].set_yticks([]) # Hide y ticks diff --git a/engiopt/cgan_2d.py b/engiopt/cgan_2d.py index c370227..a58afcc 100644 --- a/engiopt/cgan_2d.py +++ b/engiopt/cgan_2d.py @@ -267,8 +267,8 @@ def sample_designs(n_designs: int) -> th.Tensor: # This saves a grid image of 25 generated designs every sample_interval if batches_done % args.sample_interval == 0: # Extract 25 designs - desired_objs, designs = sample_designs(25) - fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + desired_conds, designs = sample_designs(25) + fig, axes = plt.subplots(5, 5, figsize=(24, 24)) # Flatten axes for easy indexing axes = axes.flatten() @@ -276,9 +276,11 @@ def sample_designs(n_designs: int) -> th.Tensor: # Plot each tensor as a scatter plot for j, tensor in enumerate(designs): img = tensor.cpu().numpy() # Extract x and y coordinates - do = desired_objs[j].cpu() + dc = desired_conds[j].cpu() axes[j].imshow(img) # Scatter plot - axes[j].title.set_text(f"volfrac: {do[0]:.2f}, penal: {do[1]:.2f}") + axes[j].title.set_text( + ", ".join(f"{key[:3]}: {val:.2f}" for key, val in zip(problem.conditions_keys, dc)) + ) axes[j].set_xticks([]) # Hide x ticks axes[j].set_yticks([]) # Hide y ticks diff --git a/engiopt/cgan_cnn_2d.py b/engiopt/cgan_cnn_2d.py index d842b9c..a65c4c2 100644 --- a/engiopt/cgan_cnn_2d.py +++ b/engiopt/cgan_cnn_2d.py @@ -361,7 +361,7 @@ def sample_designs(n_designs: int) -> th.Tensor: if batches_done % args.sample_interval == 0: # Extract 25 designs desired_conds, designs = sample_designs(25) - fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + fig, axes = plt.subplots(5, 5, figsize=(24, 24)) # Flatten axes for easy indexing axes = axes.flatten() @@ -369,9 +369,11 @@ def sample_designs(n_designs: int) -> th.Tensor: # Plot each tensor as a scatter plot for j, tensor in enumerate(designs): img = tensor.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates - do = desired_conds[j].cpu() + dc = desired_conds[j].cpu() axes[j].imshow(img) # Scatter plot - axes[j].title.set_text(f"volfrac: {do[0]:.2f}") + axes[j].title.set_text( + ", ".join(f"{key[:3]}: {val:.2f}" for key, val in zip(problem.conditions_keys, dc)) + ) axes[j].set_xticks([]) # Hide x ticks axes[j].set_yticks([]) # Hide y ticks diff --git a/engiopt/diffusion_1d.py b/engiopt/diffusion_1d.py index 37648d3..235379a 100644 --- a/engiopt/diffusion_1d.py +++ b/engiopt/diffusion_1d.py @@ -147,7 +147,7 @@ class Args: if batches_done % args.sample_interval == 0: # Extract 25 designs designs = diffusion.sample(batch_size=25) - fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + fig, axes = plt.subplots(5, 5, figsize=(24, 24)) # Flatten axes for easy indexing axes = axes.flatten() diff --git a/engiopt/diffusion_2d_cond.py b/engiopt/diffusion_2d_cond.py index 5923118..2fb1834 100644 --- a/engiopt/diffusion_2d_cond.py +++ b/engiopt/diffusion_2d_cond.py @@ -15,7 +15,6 @@ import numpy as np import torch as th from torch.nn import functional -from torchvision import transforms import tqdm import tyro import wandb @@ -245,7 +244,7 @@ def sample_timestep( # Initialize UNet from Huggingface model = UNet2DConditionModel( - sample_size=(100, 100), + sample_size=design_shape, in_channels=1, out_channels=1, cross_attention_dim=64, @@ -262,18 +261,8 @@ def sample_timestep( # Configure data loader training_ds = problem.dataset.with_format("torch", device=device)["train"] - filtered_ds = th.zeros(len(training_ds), 100, 100, device=device) - for i in range(len(training_ds)): - filtered_ds[i] = transforms.Resize((100, 100))( - training_ds[i]["optimal_design"].reshape(1, design_shape[0], design_shape[1]) - ) - filtered_ds_max = filtered_ds.max() - filtered_ds_min = filtered_ds.min() - filtered_ds *= 2 - filtered_ds -= 1 - filtered_ds_norm = (filtered_ds - filtered_ds_min) / (filtered_ds_max - filtered_ds_min) training_ds = th.utils.data.TensorDataset( - filtered_ds_norm.flatten(1), *[training_ds[key] for key, _ in problem.conditions] + training_ds["optimal_design"].flatten(1), *[training_ds[key] for key in problem.conditions_keys] ) vf_min = training_ds.tensors[1].min() vf_max = training_ds.tensors[1].max() @@ -325,7 +314,7 @@ def sample_designs(model: UNet2DConditionModel, n_designs: int = 25) -> tuple[th """Samples n_designs designs.""" model.eval() with th.no_grad(): - dims = (n_designs, 1, 100, 100) + dims = (n_designs, 1, design_shape[0], design_shape[1]) image = th.randn(dims, device=device) # initial image encoder_hidden_states = th.linspace(vf_min, vf_max, n_designs, device=device) encoder_hidden_states = encoder_hidden_states.view(n_designs, 1, 1).expand(n_designs, 1, 32) @@ -343,7 +332,7 @@ def sample_designs(model: UNet2DConditionModel, n_designs: int = 25) -> tuple[th for i, data in enumerate(dataloader): # Zero the parameter gradients optimizer.zero_grad() - designs = data[0].reshape(-1, 1, 100, 100) + designs = data[0].reshape(-1, 1, design_shape[0], design_shape[1]) x = designs.to(device) conds = th.stack((data[1:]), dim=1).reshape(-1, 1, 1) conds_ex = conds.expand(-1, 1, 32) @@ -381,20 +370,21 @@ def sample_designs(model: UNet2DConditionModel, n_designs: int = 25) -> tuple[th # Extract 25 designs designs, hidden_states = sample_designs(model, 25) - fig, axes = plt.subplots(5, 5, figsize=(12, 12)) + fig, axes = plt.subplots(5, 5, figsize=(24, 24)) # Flatten axes for easy indexing axes = axes.flatten() # Plot the image created by each output for j, tensor in enumerate(designs): - img = tensor.cpu().numpy().reshape(100, 100) # Extract x and y coordinates - do = hidden_states[j, 0, 0].cpu() + img = tensor.cpu().numpy() # Extract x and y coordinates + dc = [hidden_states[j, 0, 0].cpu().item()] axes[j].imshow(img.T) # image plot - axes[j].title.set_text(f"volfrac: {do:.2f}") # Set title + axes[j].title.set_text( + ", ".join(f"{key}: {val:.2f}" for key, val in zip(problem.conditions_keys, dc)) + ) axes[j].set_xticks([]) # Hide x ticks axes[j].set_yticks([]) # Hide y ticks - plt.tight_layout() img_fname = f"images/{batches_done}.png" plt.savefig(img_fname)