diff --git a/edflow/explore.py b/edflow/explore.py
index 03acdfe..274e1a5 100644
--- a/edflow/explore.py
+++ b/edflow/explore.py
@@ -18,6 +18,8 @@
 )
 from edflow import get_obj_from_str
 from edflow.data.dataset_mixin import DatasetMixin
+from edflow.util import contains_key, retrieve
+from edflow.data.util import adjust_support
 
 
 def display_default(obj):
@@ -61,7 +63,11 @@ def display(key, obj):
         st.text(obj)
 
     elif sel == "Image":
-        st.image((obj + 1.0) / 2.0)
+        try:
+            st.image((obj + 1.0) / 2.0)
+        except RuntimeError:
+            obj = adjust_support(obj, "-1->1", "0->255")
+            st.image((obj + 1.0) / 2.0)
 
     elif sel == "Flow":
         display_flow(obj, key)
@@ -71,6 +77,19 @@ def display(key, obj):
         img = obj[:, :, idx].astype(np.float)
         st.image(img)
 
+    elif sel == "Segmentation Flat":
+        idx = st.number_input("Segmentation Index", 0, 255, 0)
+        img = (obj[:, :].astype(np.int) == idx).astype(np.float)
+        st.image(img)
+
+    elif sel == "IUV":
+        # TODO: implement different visualization
+        try:
+            st.image((obj + 1.0) / 2.0)
+        except RuntimeError:
+            obj = adjust_support(obj, "-1->1", "0->255")
+            st.image((obj + 1.0) / 2.0)
+
 
 def selector(key, obj):
     """Show select box to choose display mode of obj in streamlit
@@ -87,7 +106,16 @@ def selector(key, obj):
     str
         Selected display method for item
     """
-    options = ["Auto", "Text", "Image", "Flow", "Segmentation", "None"]
+    options = [
+        "Auto",
+        "Text",
+        "Image",
+        "Flow",
+        "Segmentation",
+        "Segmentation Flat",
+        "IUV",
+        "None",
+    ]
     idx = options.index(display_default(obj))
     select = st.selectbox("Display {} as".format(key), options, index=idx)
     return select
@@ -201,8 +229,9 @@ def show_example(dset, idx, config):
 
     # additional visualizations
     default_additional_visualizations = retrieve(
-        config, "edexplore/visualizations", default=dict()
+        config, "edexplore/visualizations", default={}
     ).keys()
+    default_additional_visualizations = list(default_additional_visualizations)
     additional_visualizations = st.sidebar.multiselect(
         "Additional visualizations",
         list(ADDITIONAL_VISUALIZATIONS.keys()),
@@ -226,7 +255,13 @@ def show_example(dset, idx, config):
 
 
 def _get_state(config):
-    Dataset = get_obj_from_str(config["dataset"])
+    if contains_key(config, "dataset"):
+        Dataset = get_obj_from_str(config["dataset"])
+    elif contains_key(config, "datasets/train"):
+        module_name = retrieve(config, "datasets/train")
+        Dataset = get_obj_from_str(module_name)
+    else:
+        raise KeyError
     dataset = Dataset(config)
     return dataset
 
diff --git a/examples/vae/README.md b/examples/vae/README.md
new file mode 100644
index 0000000..1daae4e
--- /dev/null
+++ b/examples/vae/README.md
@@ -0,0 +1,64 @@
+# Example
+
+* explore dataset using `edexplore`
+* on mnist
+![assets/edexplore_mnist.gif](assets/edexplore_mnist.gif)
+
+* on deepfashion, which has additional annotations such as segmentation and IUV flow.
+![assets/edexplore_df.gif](assets/edexplore_df.gif)
+
+```
+export STREAMLIT_SERVER_PORT=8080
+edexplore -b vae/config_explore.yaml
+```
+
+
+* train and evaluate model
+```
+edflow -b vae/config.yaml -t # abort at some point
+edflow -b vae/config.yaml -p logs/xxx # will also trigger evaluation
+```
+
+* will generate FID outputs
+
+![assets/FID_eval.png](assets/FID_eval.png)
+
+
+## Working with MetaDatasets
+
+* load evaluation outputs using the MetaDataset. Open `ipython`
+
+```python
+from edflow.data.believers.meta import MetaDataset
+M = MetaDataset("logs/xxx/eval/yyy/zzz/model_outputs")
+print(M)
+# +--------------------+----------+-----------+
+# | Name               | Type     | Content   |
+# +====================+==========+===========+
+# | rec_loss           | memmap   | (100,)    |
+# +--------------------+----------+-----------+
+# | kl_loss            | memmap   | (100,)    |
+# +--------------------+----------+-----------+
+# | reconstructions_   | memmap   | (100,)    |
+# +--------------------+----------+-----------+
+# | samples_           | memmap   | (100,)    |
+# +--------------------+----------+-----------+
+
+M.num_examples
+>>> 100
+
+M[0]["labels_"]
+>>> {
+    'rec_loss': 0.3759992,
+    'kl_loss': 1.5367432,
+    'reconstructions_': 'logs/xxx/eval/yyy/zzz/model_outputs/reconstructions_000000.png',
+    'samples_': 'logs/xxx/eval/yyy/zzz/model_outputs/samples_000000.png'
+}
+
+# images are loaded lazily
+M[0]["reconstructions"]
+>>> .loader(support='0->255', resize_to=None, root='')>
+
+M[0]["reconstructions"]().shape
+>>>(256, 256, 3)
+```
diff --git a/examples/vae/assets/FID_eval.png b/examples/vae/assets/FID_eval.png
new file mode 100644
index 0000000..6d4c25a
Binary files /dev/null and b/examples/vae/assets/FID_eval.png differ
diff --git a/examples/vae/assets/edexplore_df.gif b/examples/vae/assets/edexplore_df.gif
new file mode 100644
index 0000000..1d2f13b
Binary files /dev/null and b/examples/vae/assets/edexplore_df.gif differ
diff --git a/examples/vae/assets/edexplore_mnist.gif b/examples/vae/assets/edexplore_mnist.gif
new file mode 100644
index 0000000..d439d02
Binary files /dev/null and b/examples/vae/assets/edexplore_mnist.gif differ
diff --git a/examples/vae/config.yaml b/examples/vae/config.yaml
new file mode 100644
index 0000000..8cc77fa
--- /dev/null
+++ b/examples/vae/config.yaml
@@ -0,0 +1,26 @@
+datasets:
+  train: vae.datasets.Deepfashion_Img
+  validation: vae.datasets.Deepfashion_Img_Val
+model: vae.edflow.Model
+iterator: vae.edflow.Iterator
+
+
+batch_size: 25
+num_epochs: 2
+
+spatial_size: 256
+in_channels: 3
+latent_dim: 128
+hidden_dims: [32, 64, 64, 128, 256, 256, 512]
+
+
+fid:
+  batch_size: 50
+fid_stats:
+  pre_calc_stat_path: "fid_stats/deepfashion.npz"
+
+
+fixed_example_indices: {
+  "train" : [0, 1, 2, 3],
+  "validation" : [0, 1, 2, 3]
+}
\ No newline at end of file
diff --git a/examples/vae/config_explore.yaml b/examples/vae/config_explore.yaml
new file mode 100644
index 0000000..84e7608
--- /dev/null
+++ b/examples/vae/config_explore.yaml
@@ -0,0 +1,3 @@
+datasets:
+  train: vae.datasets.Deepfashion
+  validation: vae.datasets.Deepfashion
diff --git a/examples/vae/datasets.py b/examples/vae/datasets.py
new file mode 100644
index 0000000..9bd41ba
--- /dev/null
+++ b/examples/vae/datasets.py
@@ -0,0 +1,136 @@
+from edflow.data.dataset import PRNGMixin, DatasetMixin
+from edflow.util import retrieve
+import albumentations
+import os
+import pandas as pd
+import cv2
+from edflow.data.util import adjust_support
+import numpy as np
+
+
+COLORS = np.array(
+    [
+        [0, 0, 0],
+        [0, 0, 255],
+        [50, 205, 50],
+        [139, 78, 16],
+        [144, 238, 144],
+        [211, 211, 211],
+        [250, 250, 255],
+    ]
+)
+W = np.power(255, [0, 1, 2])
+
+HASHES = np.sum(W * COLORS, axis=-1)
+HASH2COLOR = {h: c for h, c in zip(HASHES, COLORS)}
+HASH2IDX = {h: i for i, h in enumerate(HASHES)}
+
+
+def rgb2index(segmentation_rgb):
+    """
+    turn a 3 channel RGB color to 1 channel index color
+    """
+    s_shape = segmentation_rgb.shape
+    s_hashes = np.sum(W * segmentation_rgb, axis=-1)
+    print(np.unique(segmentation_rgb.reshape((-1, 3)), axis=0))
+    func = lambda x: HASH2IDX[int(x)]
+    segmentation_idx = np.apply_along_axis(func, 0, s_hashes.reshape((1, -1)))
+    segmentation_idx = segmentation_idx.reshape(s_shape[:2])
+    return segmentation_idx
+
+
+class Deepfashion(DatasetMixin, PRNGMixin):
+    def __init__(self, config):
+        self.size = retrieve(config, "spatial_size", default=256)
+        self.root = os.path.join("data", "deepfashion")
+        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
+        self.preprocessor = albumentations.Compose([self.rescaler])
+
+        df = pd.read_csv(
+            os.path.join(self.root, "Anno", "list_bbox_inshop.txt"),
+            skiprows=1,
+            sep="\s+",
+        )
+        self.fnames = list(df.image_name)
+
+    def __len__(self):
+        return len(self.fnames)
+
+    def imread(self, path):
+        img = cv2.imread(path, -1)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        return img
+
+    def get_example(self, i):
+        fname = self.fnames[i]
+        fname2 = "/".join(fname.split("/")[1:])
+
+        fname_iuv, _ = os.path.splitext(fname2)
+        fname_iuv = fname_iuv + "_IUV.png"
+
+        fname_segmentation, _ = os.path.splitext(fname2)
+        fname_segmentation = fname_segmentation + "_segment.png"
+
+        img_path = os.path.join(self.root, "Img", fname)
+        segmentation_path = os.path.join(
+            self.root, "Anno", "segmentation", "img_highres", fname_segmentation
+        )
+        iuv_path = os.path.join(self.root, "Anno", "densepose", "img_iuv", fname_iuv)
+
+        img = self.imread(img_path)
+        img = adjust_support(img, "-1->1", "0->255")
+
+        segmentation = cv2.imread(segmentation_path, -1)[:, :, :3]
+        segmentation = rgb2index(
+            segmentation
+        )  # TODO: resizing changes aspect ratio, which might not be okay
+        segmentation = cv2.resize(
+            segmentation, (self.size, self.size), interpolation=cv2.INTER_NEAREST
+        )
+        iuv = self.imread(iuv_path)
+
+        example = {"img": img, "segmentation": segmentation, "iuv": iuv}
+
+        return example
+
+
+class Deepfashion_Img(DatasetMixin, PRNGMixin):
+    def __init__(self, config):
+        self.size = retrieve(config, "spatial_size", default=256)
+        self.root = os.path.join("data", "deepfashion")
+        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
+        self.preprocessor = albumentations.Compose([self.rescaler])
+
+        df = pd.read_csv(
+            os.path.join(self.root, "Anno", "list_bbox_inshop.txt"),
+            skiprows=1,
+            sep="\s+",
+        )
+        self.fnames = list(df.image_name)
+
+    def __len__(self):
+        return len(self.fnames)
+
+    def imread(self, path):
+        img = cv2.imread(path, -1)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        return img
+
+    def get_example(self, i):
+        fname = self.fnames[i]
+
+        img_path = os.path.join(self.root, "Img", fname)
+
+        img = self.imread(img_path)
+        img = adjust_support(img, "-1->1", "0->255")
+
+        example = {
+            "img": img.astype(np.float32),
+        }
+
+        return example
+
+
+class Deepfashion_Img_Val(Deepfashion_Img):
+    def __len__(self):
+        return 100
diff --git a/examples/vae/edflow.py b/examples/vae/edflow.py
new file mode 100644
index 0000000..ee885c1
--- /dev/null
+++ b/examples/vae/edflow.py
@@ -0,0 +1,296 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch import Tensor
+
+import numpy as np
+from edflow import TemplateIterator, get_logger
+from typing import *
+from edflow.data.util import adjust_support
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from edflow.util import retrieve
+from fid import fid_callback
+
+
+def rec_fid_callback(*args, **kwargs):
+    return fid_callback.fid(
+        *args,
+        **kwargs,
+        im_in_key="img",
+        im_in_support="-1->1",
+        im_out_key="reconstructions",
+        im_out_support="0->255",
+        name="fid_recons"
+    )
+
+
+def sample_fid_callback(*args, **kwargs):
+    return fid_callback.fid(
+        *args,
+        **kwargs,
+        im_in_key="img",
+        im_in_support="-1->1",
+        im_out_key="samples",
+        im_out_support="0->255",
+        name="fid_samples"
+    )
+
+
+def reconstruction_callback(root, data_in, data_out, config):
+    log = {"scalars": dict()}
+    log["scalars"]["rec_loss"] = np.mean(data_out.labels["rec_loss"])
+    log["scalars"]["kl_loss"] = np.mean(data_out.labels["kl_loss"])
+    return log
+
+
+class KLDLoss(nn.Module):
+    def __init__(self, reduction="sum"):
+        super(KLDLoss, self).__init__()
+        self.reduction = reduction
+
+    def forward(self, mean, logvar):
+        # KLD loss
+        kld_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), 1)
+        # Size average
+        if self.reduction == "mean":
+            kld_loss = torch.mean(kld_loss)
+        elif self.reduction == "sum":
+            kld_loss = torch.sum(kld_loss)
+        return kld_loss
+
+
+class Model(torch.nn.Module):
+    def __init__(self, config, **kwargs) -> None:
+        super(Model, self).__init__()
+
+        in_channels = config["in_channels"]
+        latent_dim = config["latent_dim"]
+        hidden_dims = config["hidden_dims"]
+        beta = config.get("beta", 1)
+
+        self.latent_dim = latent_dim
+        self.beta = beta
+
+        modules = []
+        if hidden_dims is None:
+            hidden_dims = [32, 64, 128, 256, 512]
+
+        # Build Encoder
+        for h_dim in hidden_dims:
+            modules.append(
+                nn.Sequential(
+                    nn.Conv2d(
+                        in_channels,
+                        out_channels=h_dim,
+                        kernel_size=3,
+                        stride=2,
+                        padding=1,
+                    ),
+                    nn.BatchNorm2d(h_dim),
+                    nn.LeakyReLU(),
+                )
+            )
+            in_channels = h_dim
+
+        self.encoder = nn.Sequential(*modules)
+        self.fc_mu = nn.Linear(hidden_dims[-1] * 4, latent_dim)
+        self.fc_var = nn.Linear(hidden_dims[-1] * 4, latent_dim)
+
+        # Build Decoder
+        modules = []
+
+        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 4)
+
+        hidden_dims.reverse()
+
+        for i in range(len(hidden_dims) - 1):
+            modules.append(
+                nn.Sequential(
+                    nn.ConvTranspose2d(
+                        hidden_dims[i],
+                        hidden_dims[i + 1],
+                        kernel_size=3,
+                        stride=2,
+                        padding=1,
+                        output_padding=1,
+                    ),
+                    nn.BatchNorm2d(hidden_dims[i + 1]),
+                    nn.LeakyReLU(),
+                )
+            )
+
+        self.decoder = nn.Sequential(*modules)
+
+        self.final_layer = nn.Sequential(
+            nn.ConvTranspose2d(
+                hidden_dims[-1],
+                hidden_dims[-1],
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                output_padding=1,
+            ),
+            nn.BatchNorm2d(hidden_dims[-1]),
+            nn.LeakyReLU(),
+            nn.Conv2d(hidden_dims[-1], out_channels=3, kernel_size=3, padding=1),
+            nn.Tanh(),
+        )
+
+    def encode(self, input: Tensor) -> List[Tensor]:
+        """
+        Encodes the input by passing through the encoder network
+        and returns the latent codes.
+        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
+        :return: (Tensor) List of latent codes
+        """
+        result = self.encoder(input)
+        result = torch.flatten(result, start_dim=1)
+
+        # Split the result into mu and var components
+        # of the latent Gaussian distribution
+        mu = self.fc_mu(result)
+        log_var = self.fc_var(result)
+
+        return [mu, log_var]
+
+    def decode(self, z: Tensor) -> Tensor:
+        result = self.decoder_input(z)
+        result = result.view(-1, 512, 2, 2)
+        result = self.decoder(result)
+        result = self.final_layer(result)
+        return result
+
+    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return eps * std + mu
+
+    def forward(self, input: Tensor, **kwargs) -> Tensor:
+        mu, log_var = self.encode(input)
+        z = self.reparameterize(mu, log_var)
+        return [self.decode(z), input, mu, log_var]
+
+    def loss_function(self, recons, input, mu, log_var) -> dict:
+        # make batch of losses
+        recons_loss = F.mse_loss(recons, input, reduction="none")
+        recons_loss = recons_loss.mean(dim=[1, 2, 3])
+
+        # batch of losses
+        kld_loss = -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1)
+
+        loss = recons_loss + self.beta * kld_loss
+
+        return {"loss": loss, "Reconstruction_Loss": recons_loss, "KLD": kld_loss}
+
+    def sample(self, num_samples: int, current_device: int, **kwargs) -> Tensor:
+        """
+        Samples from the latent space and return the corresponding
+        image space map.
+        :param num_samples: (Int) Number of samples
+        :param current_device: (Int) Device to run the model
+        :return: (Tensor)
+        """
+        z = torch.randn(num_samples, self.latent_dim)
+
+        z = z.to(current_device)
+
+        samples = self.decode(z)
+        return samples
+
+    def generate(self, x: Tensor, **kwargs) -> Tensor:
+        """
+        Given an input image x, returns the reconstructed image
+        :param x: (Tensor) [B x C x H x W]
+        :return: (Tensor) [B x C x H x W]
+        """
+
+        return self.forward(x)[0]
+
+
+class Iterator(TemplateIterator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # loss and optimizer
+        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)
+
+    @property
+    def callbacks(self):
+        return {}
+
+    def save(self, checkpoint_path):
+        state = {
+            "model": self.model.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+        }
+        torch.save(state, checkpoint_path)
+
+    def restore(self, checkpoint_path):
+        state = torch.load(checkpoint_path)
+        self.model.load_state_dict(state["model"])
+        self.optimizer.load_state_dict(state["optimizer"])
+
+    def step_op(self, model, **kwargs):
+        inputs = kwargs["img"]
+        # inputs = adjust_support(
+        #     inputs, "-1->1", "0->1"
+        # )  # make sure adjust support preservers datatype
+        inputs = inputs.astype(np.float32)
+        inputs = torch.tensor(inputs)
+        inputs = inputs.permute(0, 3, 1, 2)
+
+        def train_op():
+            # compute loss
+            recons, _, mu, log_var = model(inputs)
+            loss_dict = model.loss_function(recons, inputs, mu, log_var)
+            loss = loss_dict["loss"].mean()
+            self.optimizer.zero_grad()
+            loss.backward()
+            self.optimizer.step()
+
+        def log_op():
+            with torch.no_grad():
+                recons, _, mu, log_var = model(inputs)
+                loss_dict = model.loss_function(recons, inputs, mu, log_var)
+                loss = loss_dict["loss"].mean()
+                loss_rec = loss_dict["Reconstruction_Loss"].mean()
+                loss_kld = loss_dict["KLD"].mean()
+
+            image_logs = {
+                "inputs": inputs.detach().permute(0, 2, 3, 1).numpy(),
+                "recons": recons.detach().permute(0, 2, 3, 1).numpy(),
+            }
+            scalar_logs = {"loss": loss, "loss_rec": loss_rec, "loss_kld": loss_kld}
+
+            return {"images": image_logs, "scalars": scalar_logs}
+
+        def eval_op():
+            with torch.no_grad():
+                recons, _, mu, log_var = model(inputs)
+                samples = model.sample(inputs.shape[0], recons.device)
+                loss_dict = model.loss_function(recons, inputs, mu, log_var)
+                loss_rec = loss_dict["Reconstruction_Loss"]
+                loss_kld = loss_dict["KLD"]
+            return {
+                "reconstructions": recons.detach().permute(0, 2, 3, 1).numpy(),
+                "samples": samples.detach()
+                .permute(0, 2, 3, 1)
+                .numpy(),  # TODO: replace with samples
+                "labels": {
+                    "rec_loss": loss_rec.detach().numpy(),
+                    "kl_loss": loss_kld.detach().numpy(),
+                },
+            }
+
+        return {"train_op": train_op, "log_op": log_op, "eval_op": eval_op}
+
+    @property
+    def callbacks(self):
+        cbs = {"eval_op": {"reconstruction": reconstruction_callback}}
+        cbs["eval_op"]["fid_reconstruction"] = rec_fid_callback
+        cbs["eval_op"]["fid_samples"] = sample_fid_callback
+        return cbs