Skip to content
43 changes: 39 additions & 4 deletions edflow/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
)
from edflow import get_obj_from_str
from edflow.data.dataset_mixin import DatasetMixin
from edflow.util import contains_key, retrieve
from edflow.data.util import adjust_support


def display_default(obj):
Expand Down Expand Up @@ -61,7 +63,11 @@ def display(key, obj):
st.text(obj)

elif sel == "Image":
st.image((obj + 1.0) / 2.0)
try:
st.image((obj + 1.0) / 2.0)
except RuntimeError:
obj = adjust_support(obj, "-1->1", "0->255")
st.image((obj + 1.0) / 2.0)

elif sel == "Flow":
display_flow(obj, key)
Expand All @@ -71,6 +77,19 @@ def display(key, obj):
img = obj[:, :, idx].astype(np.float)
st.image(img)

elif sel == "Segmentation Flat":
idx = st.number_input("Segmentation Index", 0, 255, 0)
img = (obj[:, :].astype(np.int) == idx).astype(np.float)
st.image(img)

elif sel == "IUV":
# TODO: implement different visualization
try:
st.image((obj + 1.0) / 2.0)
except RuntimeError:
obj = adjust_support(obj, "-1->1", "0->255")
st.image((obj + 1.0) / 2.0)


def selector(key, obj):
"""Show select box to choose display mode of obj in streamlit
Expand All @@ -87,7 +106,16 @@ def selector(key, obj):
str
Selected display method for item
"""
options = ["Auto", "Text", "Image", "Flow", "Segmentation", "None"]
options = [
"Auto",
"Text",
"Image",
"Flow",
"Segmentation",
"Segmentation Flat",
"IUV",
"None",
]
idx = options.index(display_default(obj))
select = st.selectbox("Display {} as".format(key), options, index=idx)
return select
Expand Down Expand Up @@ -201,8 +229,9 @@ def show_example(dset, idx, config):

# additional visualizations
default_additional_visualizations = retrieve(
config, "edexplore/visualizations", default=dict()
config, "edexplore/visualizations", default={}
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://stackoverflow.com/a/5790954
indicates a performance improvement for this change. For consistency however, I would not change only a single instance.
Should we then change dict() to {} everywhere, at least in this file?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this does not matter. I was just fixing a bug when default was not given an empty dict.

).keys()
default_additional_visualizations = list(default_additional_visualizations)
additional_visualizations = st.sidebar.multiselect(
"Additional visualizations",
list(ADDITIONAL_VISUALIZATIONS.keys()),
Expand All @@ -226,7 +255,13 @@ def show_example(dset, idx, config):


def _get_state(config):
Dataset = get_obj_from_str(config["dataset"])
if contains_key(config, "dataset"):
Dataset = get_obj_from_str(config["dataset"])
elif contains_key(config, "datasets/train"):
module_name = retrieve(config, "datasets/train")
Dataset = get_obj_from_str(module_name)
else:
raise KeyError
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should also add the possibility to explore the validation dataset? Maybe with a selector in the sidebar?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was only fixing a bug that got introduced here. I am not willing to implement a new feature at this point.

I think it is not important enough to implement this feature at the moment because you can just hack around it.

dataset = Dataset(config)
return dataset

Expand Down
64 changes: 64 additions & 0 deletions examples/vae/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Example

* explore dataset using `edexplore`
* on mnist
![assets/edexplore_mnist.gif](assets/edexplore_mnist.gif)

* on deepfashion, which has additional annotations such as segmentation and IUV flow.
![assets/edexplore_df.gif](assets/edexplore_df.gif)

```
export STREAMLIT_SERVER_PORT=8080
edexplore -b vae/config_explore.yaml
```


* train and evaluate model
```
edflow -b vae/config.yaml -t # abort at some point
edflow -b vae/config.yaml -p logs/xxx # will also trigger evaluation
```

* will generate FID outputs

![assets/FID_eval.png](assets/FID_eval.png)


## Working with MetaDatasets

* load evaluation outputs using the MetaDataset. Open `ipython`

```python
from edflow.data.believers.meta import MetaDataset
M = MetaDataset("logs/xxx/eval/yyy/zzz/model_outputs")
print(M)
# +--------------------+----------+-----------+
# | Name | Type | Content |
# +====================+==========+===========+
# | rec_loss | memmap | (100,) |
# +--------------------+----------+-----------+
# | kl_loss | memmap | (100,) |
# +--------------------+----------+-----------+
# | reconstructions_ | memmap | (100,) |
# +--------------------+----------+-----------+
# | samples_ | memmap | (100,) |
# +--------------------+----------+-----------+

M.num_examples
>>> 100

M[0]["labels_"]
>>> {
'rec_loss': 0.3759992,
'kl_loss': 1.5367432,
'reconstructions_': 'logs/xxx/eval/yyy/zzz/model_outputs/reconstructions_000000.png',
'samples_': 'logs/xxx/eval/yyy/zzz/model_outputs/samples_000000.png'
}

# images are loaded lazily
M[0]["reconstructions"]
>>> <function edflow.data.believers.meta_loaders.image_loader.<locals>.loader(support='0->255', resize_to=None, root='')>

M[0]["reconstructions"]().shape
>>> (256, 256, 3)
```
Binary file added examples/vae/assets/FID_eval.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/vae/assets/edexplore_df.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/vae/assets/edexplore_mnist.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
26 changes: 26 additions & 0 deletions examples/vae/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
datasets:
train: vae.datasets.Deepfashion_Img
validation: vae.datasets.Deepfashion_Img_Val
model: vae.edflow.Model
iterator: vae.edflow.Iterator


batch_size: 25
num_epochs: 2

spatial_size: 256
in_channels: 3
latent_dim: 128
hidden_dims: [32, 64, 64, 128, 256, 256, 512]


fid:
batch_size: 50
fid_stats:
pre_calc_stat_path: "fid_stats/deepfashion.npz"


fixed_example_indices: {
"train" : [0, 1, 2, 3],
"validation" : [0, 1, 2, 3]
}
3 changes: 3 additions & 0 deletions examples/vae/config_explore.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
datasets:
train: vae.datasets.Deepfashion
validation: vae.datasets.Deepfashion
136 changes: 136 additions & 0 deletions examples/vae/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from edflow.data.dataset import PRNGMixin, DatasetMixin
from edflow.util import retrieve
import albumentations
import os
import pandas as pd
import cv2
from edflow.data.util import adjust_support
import numpy as np


# Palette of the segmentation colors used in the DeepFashion annotation PNGs.
COLORS = np.array(
    [
        [0, 0, 0],
        [0, 0, 255],
        [50, 205, 50],
        [139, 78, 16],
        [144, 238, 144],
        [211, 211, 211],
        [250, 250, 255],
    ]
)
# Per-channel weights turning an RGB triple into a single integer hash:
# hash = r * 255**0 + g * 255**1 + b * 255**2
W = np.power(255, [0, 1, 2])

HASHES = np.sum(W * COLORS, axis=-1)
HASH2COLOR = {h: c for h, c in zip(HASHES, COLORS)}
HASH2IDX = {h: i for i, h in enumerate(HASHES)}


def rgb2index(segmentation_rgb):
    """Convert a 3-channel RGB segmentation map to a 1-channel index map.

    Each pixel's RGB triple is hashed with ``W`` and looked up in
    ``HASH2IDX``.

    Parameters
    ----------
    segmentation_rgb : np.ndarray
        Array of shape (H, W, 3) whose pixel colors are all members of
        ``COLORS``.

    Returns
    -------
    np.ndarray
        Integer index map of shape (H, W), values in ``range(len(COLORS))``.

    Raises
    ------
    KeyError
        If a pixel color is not present in ``COLORS``.
    """
    s_shape = segmentation_rgb.shape
    s_hashes = np.sum(W * segmentation_rgb, axis=-1)
    # NOTE: removed a leftover debug print of np.unique(...) that ran on
    # every call and spammed stdout during dataset iteration.
    func = lambda x: HASH2IDX[int(x)]
    segmentation_idx = np.apply_along_axis(func, 0, s_hashes.reshape((1, -1)))
    segmentation_idx = segmentation_idx.reshape(s_shape[:2])
    return segmentation_idx


class Deepfashion(DatasetMixin, PRNGMixin):
    """DeepFashion in-shop dataset yielding image, segmentation and IUV maps.

    Expects the dataset under ``data/deepfashion`` with the standard
    ``Img``/``Anno`` layout; the bbox list file is used as the index of
    image file names.
    """

    def __init__(self, config):
        # Target spatial resolution for the (nearest-neighbor resized)
        # segmentation map.
        self.size = retrieve(config, "spatial_size", default=256)
        self.root = os.path.join("data", "deepfashion")
        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        # NOTE(review): self.preprocessor is built but never applied in
        # get_example — confirm whether image rescaling was intended there.
        self.preprocessor = albumentations.Compose([self.rescaler])

        df = pd.read_csv(
            os.path.join(self.root, "Anno", "list_bbox_inshop.txt"),
            skiprows=1,
            # raw string: "\s" is an invalid escape in a plain literal
            sep=r"\s+",
        )
        self.fnames = list(df.image_name)

    def __len__(self):
        return len(self.fnames)

    def imread(self, path):
        """Read the image at ``path`` and return it as an RGB array."""
        img = cv2.imread(path, -1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    def get_example(self, i):
        """Return ``{"img", "segmentation", "iuv"}`` for example index ``i``.

        ``img`` is rescaled to support [-1, 1]; ``segmentation`` is an
        integer index map resized to ``(size, size)``; ``iuv`` is returned
        as read.
        """
        fname = self.fnames[i]
        # Annotation files live under a path without the leading directory
        # component of the image path.
        fname2 = "/".join(fname.split("/")[1:])

        fname_iuv, _ = os.path.splitext(fname2)
        fname_iuv = fname_iuv + "_IUV.png"

        fname_segmentation, _ = os.path.splitext(fname2)
        fname_segmentation = fname_segmentation + "_segment.png"

        img_path = os.path.join(self.root, "Img", fname)
        segmentation_path = os.path.join(
            self.root, "Anno", "segmentation", "img_highres", fname_segmentation
        )
        iuv_path = os.path.join(self.root, "Anno", "densepose", "img_iuv", fname_iuv)

        img = self.imread(img_path)
        img = adjust_support(img, "-1->1", "0->255")

        # Drop a possible alpha channel before hashing colors to indices.
        segmentation = cv2.imread(segmentation_path, -1)[:, :, :3]
        segmentation = rgb2index(
            segmentation
        )  # TODO: resizing changes aspect ratio, which might not be okay
        segmentation = cv2.resize(
            segmentation, (self.size, self.size), interpolation=cv2.INTER_NEAREST
        )
        iuv = self.imread(iuv_path)

        example = {"img": img, "segmentation": segmentation, "iuv": iuv}

        return example


class Deepfashion_Img(DatasetMixin, PRNGMixin):
    """DeepFashion in-shop dataset yielding only the RGB image.

    Same file layout and index as :class:`Deepfashion`, but without
    segmentation/IUV annotations.
    """

    def __init__(self, config):
        # Target spatial resolution used by the (currently unapplied)
        # rescaler below.
        self.size = retrieve(config, "spatial_size", default=256)
        self.root = os.path.join("data", "deepfashion")
        self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
        # NOTE(review): self.preprocessor is built but never applied in
        # get_example — confirm whether image rescaling was intended there.
        self.preprocessor = albumentations.Compose([self.rescaler])

        df = pd.read_csv(
            os.path.join(self.root, "Anno", "list_bbox_inshop.txt"),
            skiprows=1,
            # raw string: "\s" is an invalid escape in a plain literal
            sep=r"\s+",
        )
        self.fnames = list(df.image_name)

    def __len__(self):
        return len(self.fnames)

    def imread(self, path):
        """Read the image at ``path`` and return it as an RGB array."""
        img = cv2.imread(path, -1)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img

    def get_example(self, i):
        """Return ``{"img": float32 array in [-1, 1]}`` for index ``i``."""
        fname = self.fnames[i]

        img_path = os.path.join(self.root, "Img", fname)

        img = self.imread(img_path)
        img = adjust_support(img, "-1->1", "0->255")

        example = {
            "img": img.astype(np.float32),
        }

        return example


class Deepfashion_Img_Val(Deepfashion_Img):
    """Validation split: the first 100 examples of :class:`Deepfashion_Img`."""

    def __len__(self):
        # Cap the validation set at a fixed 100 images.
        return 100
Loading