-
Notifications
You must be signed in to change notification settings - Fork 13
fix edexplore #276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
fix edexplore #276
Changes from all commits
6619a6e
7916c44
f80a091
879998e
896c10f
957ecce
26a14e6
db1aa2d
ceb3d7b
0abc4d0
8fa8d93
ad875a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,8 @@ | |
| ) | ||
| from edflow import get_obj_from_str | ||
| from edflow.data.dataset_mixin import DatasetMixin | ||
| from edflow.util import contains_key, retrieve | ||
| from edflow.data.util import adjust_support | ||
|
|
||
|
|
||
| def display_default(obj): | ||
|
|
@@ -61,7 +63,11 @@ def display(key, obj): | |
| st.text(obj) | ||
|
|
||
| elif sel == "Image": | ||
| st.image((obj + 1.0) / 2.0) | ||
| try: | ||
| st.image((obj + 1.0) / 2.0) | ||
| except RuntimeError: | ||
| obj = adjust_support(obj, "-1->1", "0->255") | ||
| st.image((obj + 1.0) / 2.0) | ||
|
|
||
| elif sel == "Flow": | ||
| display_flow(obj, key) | ||
|
|
@@ -71,6 +77,19 @@ def display(key, obj): | |
| img = obj[:, :, idx].astype(np.float) | ||
| st.image(img) | ||
|
|
||
| elif sel == "Segmentation Flat": | ||
| idx = st.number_input("Segmentation Index", 0, 255, 0) | ||
| img = (obj[:, :].astype(np.int) == idx).astype(np.float) | ||
| st.image(img) | ||
|
|
||
| elif sel == "IUV": | ||
| # TODO: implement different visualization | ||
| try: | ||
| st.image((obj + 1.0) / 2.0) | ||
| except RuntimeError: | ||
| obj = adjust_support(obj, "-1->1", "0->255") | ||
| st.image((obj + 1.0) / 2.0) | ||
|
|
||
|
|
||
| def selector(key, obj): | ||
| """Show select box to choose display mode of obj in streamlit | ||
|
|
@@ -87,7 +106,16 @@ def selector(key, obj): | |
| str | ||
| Selected display method for item | ||
| """ | ||
| options = ["Auto", "Text", "Image", "Flow", "Segmentation", "None"] | ||
| options = [ | ||
| "Auto", | ||
| "Text", | ||
| "Image", | ||
| "Flow", | ||
| "Segmentation", | ||
| "Segmentation Flat", | ||
| "IUV", | ||
| "None", | ||
| ] | ||
| idx = options.index(display_default(obj)) | ||
| select = st.selectbox("Display {} as".format(key), options, index=idx) | ||
| return select | ||
|
|
@@ -201,8 +229,9 @@ def show_example(dset, idx, config): | |
|
|
||
| # additional visualizations | ||
| default_additional_visualizations = retrieve( | ||
| config, "edexplore/visualizations", default=dict() | ||
| config, "edexplore/visualizations", default={} | ||
| ).keys() | ||
| default_additional_visualizations = list(default_additional_visualizations) | ||
| additional_visualizations = st.sidebar.multiselect( | ||
| "Additional visualizations", | ||
| list(ADDITIONAL_VISUALIZATIONS.keys()), | ||
|
|
@@ -226,7 +255,13 @@ def show_example(dset, idx, config): | |
|
|
||
|
|
||
| def _get_state(config): | ||
| Dataset = get_obj_from_str(config["dataset"]) | ||
| if contains_key(config, "dataset"): | ||
| Dataset = get_obj_from_str(config["dataset"]) | ||
| elif contains_key(config, "datasets/train"): | ||
| module_name = retrieve(config, "datasets/train") | ||
| Dataset = get_obj_from_str(module_name) | ||
| else: | ||
| raise KeyError | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think we should also add the possiblity to explore the validation dataset? Maybe with a selector in the sidebar?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was only fixing a bug that got introduced here. I am not willing to implement a new feature at this point. I think it is not important enough to implement this feature at the moment because you can just hack around it. |
||
| dataset = Dataset(config) | ||
| return dataset | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| # Example | ||
|
|
||
| * explore dataset using `edexplore` | ||
| * on mnist | ||
|  | ||
|
|
||
| * on deepfashion, which has additional annotations such as segmentation and IUV flow. | ||
|  | ||
|
|
||
| ``` | ||
| export STREAMLIT_SERVER_PORT=8080 | ||
| edexplore -b vae/config_explore.yaml | ||
| ``` | ||
|
|
||
|
|
||
| * train and evaluate model | ||
| ``` | ||
| edflow -b vae/config.yaml -t # abort at some point | ||
| edflow -b vae/config.yaml -p logs/xxx # will also trigger evaluation | ||
| ``` | ||
|
|
||
| * will generate FID outputs | ||
|
|
||
|  | ||
|
|
||
|
|
||
| ## Working with MetaDatasets | ||
|
|
||
| * load evaluation outputs using the MetaDataset. Open `ipython` | ||
|
|
||
| ```python | ||
| from edflow.data.believers.meta import MetaDataset | ||
| M = MetaDataset("logs/xxx/eval/yyy/zzz/model_outputs") | ||
| print(M) | ||
| # +--------------------+----------+-----------+ | ||
| # | Name | Type | Content | | ||
| # +====================+==========+===========+ | ||
| # | rec_loss | memmap | (100,) | | ||
| # +--------------------+----------+-----------+ | ||
| # | kl_loss | memmap | (100,) | | ||
| # +--------------------+----------+-----------+ | ||
| # | reconstructions_ | memmap | (100,) | | ||
| # +--------------------+----------+-----------+ | ||
| # | samples_ | memmap | (100,) | | ||
| # +--------------------+----------+-----------+ | ||
|
|
||
| M.num_examples | ||
| >>> 100 | ||
|
|
||
| M[0]["labels_"] | ||
| >>> { | ||
| 'rec_loss': 0.3759992, | ||
| 'kl_loss': 1.5367432, | ||
| 'reconstructions_': 'logs/xxx/eval/yyy/zzz/model_outputs/reconstructions_000000.png', | ||
| 'samples_': 'logs/xxx/eval/yyy/zzz/model_outputs/samples_000000.png' | ||
| } | ||
|
|
||
| # images are loaded lazily | ||
| M[0]["reconstructions"] | ||
| >>> <function edflow.data.believers.meta_loaders.image_loader.<locals>.loader(support='0->255', resize_to=None, root='')> | ||
|
|
||
| M[0]["reconstructions"]().shape | ||
| >>>(256, 256, 3) | ||
| ``` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| datasets: | ||
| train: vae.datasets.Deepfashion_Img | ||
| validation: vae.datasets.Deepfashion_Img_Val | ||
| model: vae.edflow.Model | ||
| iterator: vae.edflow.Iterator | ||
|
|
||
|
|
||
| batch_size: 25 | ||
| num_epochs: 2 | ||
|
|
||
| spatial_size: 256 | ||
| in_channels: 3 | ||
| latent_dim: 128 | ||
| hidden_dims: [32, 64, 64, 128, 256, 256, 512] | ||
|
|
||
|
|
||
| fid: | ||
| batch_size: 50 | ||
| fid_stats: | ||
| pre_calc_stat_path: "fid_stats/deepfashion.npz" | ||
|
|
||
|
|
||
| fixed_example_indices: { | ||
| "train" : [0, 1, 2, 3], | ||
| "validation" : [0, 1, 2, 3] | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| datasets: | ||
| train: vae.datasets.Deepfashion | ||
| validation: vae.datasets.Deepfashion |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| from edflow.data.dataset import PRNGMixin, DatasetMixin | ||
| from edflow.util import retrieve | ||
| import albumentations | ||
| import os | ||
| import pandas as pd | ||
| import cv2 | ||
| from edflow.data.util import adjust_support | ||
| import numpy as np | ||
|
|
||
|
|
||
| COLORS = np.array( | ||
| [ | ||
| [0, 0, 0], | ||
| [0, 0, 255], | ||
| [50, 205, 50], | ||
| [139, 78, 16], | ||
| [144, 238, 144], | ||
| [211, 211, 211], | ||
| [250, 250, 255], | ||
| ] | ||
| ) | ||
| W = np.power(255, [0, 1, 2]) | ||
|
|
||
| HASHES = np.sum(W * COLORS, axis=-1) | ||
| HASH2COLOR = {h: c for h, c in zip(HASHES, COLORS)} | ||
| HASH2IDX = {h: i for i, h in enumerate(HASHES)} | ||
|
|
||
|
|
||
| def rgb2index(segmentation_rgb): | ||
| """ | ||
| turn a 3 channel RGB color to 1 channel index color | ||
| """ | ||
| s_shape = segmentation_rgb.shape | ||
| s_hashes = np.sum(W * segmentation_rgb, axis=-1) | ||
| print(np.unique(segmentation_rgb.reshape((-1, 3)), axis=0)) | ||
| func = lambda x: HASH2IDX[int(x)] | ||
| segmentation_idx = np.apply_along_axis(func, 0, s_hashes.reshape((1, -1))) | ||
| segmentation_idx = segmentation_idx.reshape(s_shape[:2]) | ||
| return segmentation_idx | ||
|
|
||
|
|
||
| class Deepfashion(DatasetMixin, PRNGMixin): | ||
| def __init__(self, config): | ||
| self.size = retrieve(config, "spatial_size", default=256) | ||
| self.root = os.path.join("data", "deepfashion") | ||
| self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) | ||
| self.preprocessor = albumentations.Compose([self.rescaler]) | ||
|
|
||
| df = pd.read_csv( | ||
| os.path.join(self.root, "Anno", "list_bbox_inshop.txt"), | ||
| skiprows=1, | ||
| sep="\s+", | ||
| ) | ||
| self.fnames = list(df.image_name) | ||
|
|
||
| def __len__(self): | ||
| return len(self.fnames) | ||
|
|
||
| def imread(self, path): | ||
| img = cv2.imread(path, -1) | ||
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
| return img | ||
|
|
||
| def get_example(self, i): | ||
| fname = self.fnames[i] | ||
| fname2 = "/".join(fname.split("/")[1:]) | ||
|
|
||
| fname_iuv, _ = os.path.splitext(fname2) | ||
| fname_iuv = fname_iuv + "_IUV.png" | ||
|
|
||
| fname_segmentation, _ = os.path.splitext(fname2) | ||
| fname_segmentation = fname_segmentation + "_segment.png" | ||
|
|
||
| img_path = os.path.join(self.root, "Img", fname) | ||
| segmentation_path = os.path.join( | ||
| self.root, "Anno", "segmentation", "img_highres", fname_segmentation | ||
| ) | ||
| iuv_path = os.path.join(self.root, "Anno", "densepose", "img_iuv", fname_iuv) | ||
|
|
||
| img = self.imread(img_path) | ||
| img = adjust_support(img, "-1->1", "0->255") | ||
|
|
||
| segmentation = cv2.imread(segmentation_path, -1)[:, :, :3] | ||
| segmentation = rgb2index( | ||
| segmentation | ||
| ) # TODO: resizing changes aspect ratio, which might not be okay | ||
| segmentation = cv2.resize( | ||
| segmentation, (self.size, self.size), interpolation=cv2.INTER_NEAREST | ||
| ) | ||
| iuv = self.imread(iuv_path) | ||
|
|
||
| example = {"img": img, "segmentation": segmentation, "iuv": iuv} | ||
|
|
||
| return example | ||
|
|
||
|
|
||
| class Deepfashion_Img(DatasetMixin, PRNGMixin): | ||
| def __init__(self, config): | ||
| self.size = retrieve(config, "spatial_size", default=256) | ||
| self.root = os.path.join("data", "deepfashion") | ||
| self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) | ||
| self.preprocessor = albumentations.Compose([self.rescaler]) | ||
|
|
||
| df = pd.read_csv( | ||
| os.path.join(self.root, "Anno", "list_bbox_inshop.txt"), | ||
| skiprows=1, | ||
| sep="\s+", | ||
| ) | ||
| self.fnames = list(df.image_name) | ||
|
|
||
| def __len__(self): | ||
| return len(self.fnames) | ||
|
|
||
| def imread(self, path): | ||
| img = cv2.imread(path, -1) | ||
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | ||
| return img | ||
|
|
||
| def get_example(self, i): | ||
| fname = self.fnames[i] | ||
|
|
||
| img_path = os.path.join(self.root, "Img", fname) | ||
|
|
||
| img = self.imread(img_path) | ||
| img = adjust_support(img, "-1->1", "0->255") | ||
|
|
||
| example = { | ||
| "img": img.astype(np.float32), | ||
| } | ||
|
|
||
| return example | ||
|
|
||
|
|
||
| class Deepfashion_Img_Val(Deepfashion_Img): | ||
| def __len__(self): | ||
| return 100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://stackoverflow.com/a/5790954
indicates a performance improvement for this change. For consisntency however, I would not change only a single instance.
should we change dict() to {} then everywhere at leas in this file?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this does not matter. I was just fixing a bug when default was not given an empty dict.