Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed .DS_Store
Binary file not shown.
8 changes: 5 additions & 3 deletions engiopt/cgan_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,17 +259,19 @@ def sample_designs(n_designs: int) -> th.Tensor:
if batches_done % args.sample_interval == 0:
# Extract 25 designs
desired_conds, designs = sample_designs(25)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))
fig, axes = plt.subplots(5, 5, figsize=(24, 24))

# Flatten axes for easy indexing
axes = axes.flatten()

# Plot each tensor as a scatter plot
for j, tensor in enumerate(designs):
x, y = tensor.cpu().numpy() # Extract x and y coordinates
do = desired_conds[j].cpu()
dc = desired_conds[j].cpu()
axes[j].scatter(x, y, s=10, alpha=0.7) # Scatter plot
axes[j].title.set_text(f"m1: {do[0]:.2f}, m2: {do[1]:.2f}")
axes[j].title.set_text(
", ".join(f"{key[:3]}: {val:.2f}" for key, val in zip(problem.conditions_keys, dc))
)
axes[j].set_xticks([]) # Hide x ticks
axes[j].set_yticks([]) # Hide y ticks

Expand Down
10 changes: 6 additions & 4 deletions engiopt/cgan_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,18 +267,20 @@ def sample_designs(n_designs: int) -> th.Tensor:
# This saves a grid image of 25 generated designs every sample_interval
if batches_done % args.sample_interval == 0:
# Extract 25 designs
desired_objs, designs = sample_designs(25)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))
desired_conds, designs = sample_designs(25)
fig, axes = plt.subplots(5, 5, figsize=(24, 24))

# Flatten axes for easy indexing
axes = axes.flatten()

# Plot each tensor as a scatter plot
for j, tensor in enumerate(designs):
img = tensor.cpu().numpy() # Extract x and y coordinates
do = desired_objs[j].cpu()
dc = desired_conds[j].cpu()
axes[j].imshow(img) # Scatter plot
axes[j].title.set_text(f"volfrac: {do[0]:.2f}, penal: {do[1]:.2f}")
axes[j].title.set_text(
", ".join(f"{key[:3]}: {val:.2f}" for key, val in zip(problem.conditions_keys, dc))
)
axes[j].set_xticks([]) # Hide x ticks
axes[j].set_yticks([]) # Hide y ticks

Expand Down
8 changes: 5 additions & 3 deletions engiopt/cgan_cnn_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,17 +361,19 @@ def sample_designs(n_designs: int) -> th.Tensor:
if batches_done % args.sample_interval == 0:
# Extract 25 designs
desired_conds, designs = sample_designs(25)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))
fig, axes = plt.subplots(5, 5, figsize=(24, 24))

# Flatten axes for easy indexing
axes = axes.flatten()

# Plot each tensor as a scatter plot
for j, tensor in enumerate(designs):
img = tensor.cpu().numpy().reshape(design_shape[0], design_shape[1]) # Extract x and y coordinates
do = desired_conds[j].cpu()
dc = desired_conds[j].cpu()
axes[j].imshow(img) # Scatter plot
axes[j].title.set_text(f"volfrac: {do[0]:.2f}")
axes[j].title.set_text(
", ".join(f"{key[:3]}: {val:.2f}" for key, val in zip(problem.conditions_keys, dc))
)
axes[j].set_xticks([]) # Hide x ticks
axes[j].set_yticks([]) # Hide y ticks

Expand Down
2 changes: 1 addition & 1 deletion engiopt/diffusion_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class Args:
if batches_done % args.sample_interval == 0:
# Extract 25 designs
designs = diffusion.sample(batch_size=25)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))
fig, axes = plt.subplots(5, 5, figsize=(24, 24))

# Flatten axes for easy indexing
axes = axes.flatten()
Expand Down
30 changes: 10 additions & 20 deletions engiopt/diffusion_2d_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import numpy as np
import torch as th
from torch.nn import functional
from torchvision import transforms
import tqdm
import tyro
import wandb
Expand Down Expand Up @@ -245,7 +244,7 @@ def sample_timestep(

# Initialize UNet from Huggingface
model = UNet2DConditionModel(
sample_size=(100, 100),
sample_size=design_shape,
in_channels=1,
out_channels=1,
cross_attention_dim=64,
Expand All @@ -262,18 +261,8 @@ def sample_timestep(

# Configure data loader
training_ds = problem.dataset.with_format("torch", device=device)["train"]
filtered_ds = th.zeros(len(training_ds), 100, 100, device=device)
for i in range(len(training_ds)):
filtered_ds[i] = transforms.Resize((100, 100))(
training_ds[i]["optimal_design"].reshape(1, design_shape[0], design_shape[1])
)
filtered_ds_max = filtered_ds.max()
filtered_ds_min = filtered_ds.min()
filtered_ds *= 2
filtered_ds -= 1
filtered_ds_norm = (filtered_ds - filtered_ds_min) / (filtered_ds_max - filtered_ds_min)
training_ds = th.utils.data.TensorDataset(
filtered_ds_norm.flatten(1), *[training_ds[key] for key, _ in problem.conditions]
training_ds["optimal_design"].flatten(1), *[training_ds[key] for key in problem.conditions_keys]
)
vf_min = training_ds.tensors[1].min()
vf_max = training_ds.tensors[1].max()
Expand Down Expand Up @@ -325,7 +314,7 @@ def sample_designs(model: UNet2DConditionModel, n_designs: int = 25) -> tuple[th
"""Samples n_designs designs."""
model.eval()
with th.no_grad():
dims = (n_designs, 1, 100, 100)
dims = (n_designs, 1, design_shape[0], design_shape[1])
image = th.randn(dims, device=device) # initial image
encoder_hidden_states = th.linspace(vf_min, vf_max, n_designs, device=device)
encoder_hidden_states = encoder_hidden_states.view(n_designs, 1, 1).expand(n_designs, 1, 32)
Expand All @@ -343,7 +332,7 @@ def sample_designs(model: UNet2DConditionModel, n_designs: int = 25) -> tuple[th
for i, data in enumerate(dataloader):
# Zero the parameter gradients
optimizer.zero_grad()
designs = data[0].reshape(-1, 1, 100, 100)
designs = data[0].reshape(-1, 1, design_shape[0], design_shape[1])
x = designs.to(device)
conds = th.stack((data[1:]), dim=1).reshape(-1, 1, 1)
conds_ex = conds.expand(-1, 1, 32)
Expand Down Expand Up @@ -381,20 +370,21 @@ def sample_designs(model: UNet2DConditionModel, n_designs: int = 25) -> tuple[th
# Extract 25 designs

designs, hidden_states = sample_designs(model, 25)
fig, axes = plt.subplots(5, 5, figsize=(12, 12))
fig, axes = plt.subplots(5, 5, figsize=(24, 24))

# Flatten axes for easy indexing
axes = axes.flatten()

# Plot the image created by each output
for j, tensor in enumerate(designs):
img = tensor.cpu().numpy().reshape(100, 100) # Extract x and y coordinates
do = hidden_states[j, 0, 0].cpu()
img = tensor.cpu().numpy() # Extract x and y coordinates
dc = [hidden_states[j, 0, 0].cpu().item()]
axes[j].imshow(img.T) # image plot
axes[j].title.set_text(f"volfrac: {do:.2f}") # Set title
axes[j].title.set_text(
", ".join(f"{key}: {val:.2f}" for key, val in zip(problem.conditions_keys, dc))
)
axes[j].set_xticks([]) # Hide x ticks
axes[j].set_yticks([]) # Hide y ticks

plt.tight_layout()
img_fname = f"images/{batches_done}.png"
plt.savefig(img_fname)
Expand Down
Loading