diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..262859f --- /dev/null +++ b/.gitignore @@ -0,0 +1,140 @@ + +# Created by https://www.toptal.com/developers/gitignore/api/python +# Edit at https://www.toptal.com/developers/gitignore?templates=python + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +.idea/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# End of https://www.toptal.com/developers/gitignore/api/python diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..90f8cc1 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,10 @@ +[LOGGING] + +# Format style used to check logging format string. `old` means using % +# formatting, while `new` is for `{}` formatting. +#logging-format-style=fstr + +[MESSAGE CONTROL] +disable= + logging-fstring-interpolation, + logging-format-interpolation diff --git a/README-dev.md b/README-dev.md new file mode 100644 index 0000000..a2afe84 --- /dev/null +++ b/README-dev.md @@ -0,0 +1,17 @@ +# DSen2 +## Setup +Create a new virtual environment e.g. 
using [virtualenvwrapper](https://virtualenvwrapper.readthedocs.io/en/latest/): +```bash +mkvirtualenv --python=$(which python3.7) dsen2 +``` +Install requirements: +```bash +pip install -r requirements.txt +pip install -r requirements-dev.txt +``` + +## Testing +To run tests: +```bash +bash test.sh +``` diff --git a/S2L1C_test_img_ids.txt b/S2L1C_test_img_ids.txt new file mode 100644 index 0000000..2c4318f --- /dev/null +++ b/S2L1C_test_img_ids.txt @@ -0,0 +1,15 @@ +S2B_MSIL1C_20200131T050029_N0208_R119_T45RTK_20200131T074329 +S2B_MSIL1C_20200101T082239_N0208_R121_T34HCJ_20200101T102510 +S2B_MSIL1C_20200103T024319_N0208_R003_T49JFM_20200103T054632 +S2B_MSIL1C_20200112T143449_N0208_R139_T23VMG_20200112T162018 +S2B_MSIL1C_20200116T155609_N0208_R054_T18TVK_20200116T192635 +S2B_MSIL1C_20200112T144659_N0208_R139_T19LED_20200112T181024 +S2B_MSIL1C_20200210T064019_N0209_R120_T43WEP_20200210T102521 +S2B_MSIL1C_20200105T132229_N0208_R038_T22JEP_20200105T140318 +S2A_MSIL1C_20200102T223701_N0208_R072_T59GLM_20200102T233015 +S2B_MSIL1C_20200117T102249_N0208_R065_T33UVA_20200117T113858 +S2B_MSIL1C_20200112T111329_N0208_R137_T30UWB_20200112T114220 +S2B_MSIL1C_20200131T050029_N0208_R119_T45RTK_20200131T074329 +S2A_MSIL1C_20200104T160641_N0208_R097_T16PET_20200104T193742 +S2A_MSIL1C_20200105T021051_N0208_R103_T52SEC_20200105T040308 +S2A_MSIL1C_20200104T074311_N0208_R092_T37NDF_20200104T092554 diff --git a/S2L1C_training_img_ids.txt b/S2L1C_training_img_ids.txt new file mode 100644 index 0000000..cef58e1 --- /dev/null +++ b/S2L1C_training_img_ids.txt @@ -0,0 +1,44 @@ +S2B_MSIL1C_20200708T070629_N0209_R106_T39RVJ_20200708T100303 +S2A_MSIL1C_20200620T183921_N0209_R070_T11UQV_20200620T222330 +S2B_MSIL1C_20200702T032539_N0209_R018_T47NQG_20200702T070403 +S2A_MSIL1C_20200622T155911_N0209_R097_T17RMN_20200622T211519 +S2A_MSIL1C_20200228T135111_N0209_R024_T22MDE_20200228T153108 +S2B_MSIL1C_20200509T022329_N0209_R103_T51PVN_20200509T035328 
+S2A_MSIL1C_20200702T092031_N0209_R093_T34SFH_20200702T105153 +S2A_MSIL1C_20200513T023551_N0209_R089_T50RQV_20200513T042651 +S2A_MSIL1C_20200623T103031_N0209_R108_T32TLT_20200623T124659 +S2B_MSIL1C_20200322T110639_N0209_R137_T28PGQ_20200322T131534 +S2A_MSIL1C_20200527T185921_N0209_R013_T10TER_20200528T002151 +S2A_MSIL1C_20200707T071211_N0209_R020_T38KNF_20200707T095401 +S2B_MSIL1C_20200706T180919_N0209_R084_T12SUF_20200706T214131 +S2B_MSIL1C_20200707T092029_N0209_R093_T33QYE_20200707T114047 +S2B_MSIL1C_20200518T143729_N0209_R096_T19HCC_20200518T175905 +S2A_MSIL1C_20200703T170901_N0209_R112_T14SPB_20200703T204911 +S2B_MSIL1C_20200119T131859_N0208_R095_T19DFF_20200119T142732 +S2A_MSIL1C_20200628T000251_N0209_R030_T56HKH_20200628T012541 +S2A_MSIL1C_20200427T035541_N0209_R004_T48UXU_20200427T065744 +S2A_MSIL1C_20200708T075611_N0209_R035_T38UMU_20200708T093146 +S2B_MSIL1C_20200624T040549_N0209_R047_T47SNB_20200624T065011 +S2B_MSIL1C_20200302T093029_N0209_R136_T32MPE_20200302T115907 +S2A_MSIL1C_20200705T074621_N0209_R135_T37QED_20200705T095040 +S2A_MSIL1C_20200610T003711_N0209_R059_T55LBD_20200610T021551 +S2A_MSIL1C_20200610T003711_N0209_R059_T54HUG_20200610T021551 +S2A_MSIL1C_20200518T101031_N0209_R022_T33UVR_20200518T121146 +S2B_MSIL1C_20200703T161829_N0209_R040_T17TLH_20200703T195539 +S2A_MSIL1C_20200706T141741_N0209_R010_T18FXG_20200706T174147 +S2A_MSIL1C_20200704T151711_N0209_R125_T19PCN_20200704T183937 +S2B_MSIL1C_20200705T140059_N0209_R067_T21LYJ_20200705T171900 +S2B_MSIL1C_20200620T060639_N0209_R134_T42SWH_20200620T085814 +S2B_MSIL1C_20200707T105619_N0209_R094_T30SVH_20200707T130404 +S2B_MSIL1C_20200705T165849_N0209_R069_T14RNN_20200705T203718 +S2B_MSIL1C_20200703T161829_N0209_R040_T17UPT_20200703T195539 +S2A_MSIL1C_20200627T032541_N0209_R018_T48QWH_20200627T062952 +S2B_MSIL1C_20200708T170849_N0209_R112_T15TTF_20200708T203645 +S2A_MSIL1C_20200628T013721_N0209_R031_T52LEH_20200628T030413 +S2A_MSIL1C_20200126T032011_N0208_R118_T48NUG_20200126T061348 
+S2A_MSIL1C_20200524T052651_N0209_R105_T43QCA_20200524T090700 +S2B_MSIL1C_20200703T112119_N0209_R037_T29SNC_20200703T132109 +S2A_MSIL1C_20200702T092031_N0209_R093_T34SDA_20200702T112247 +S2A_MSIL1C_20200622T105631_N0209_R094_T30TXQ_20200622T130553 +S2A_MSIL1C_20200402T050601_N0209_R076_T44NNN_20200402T074005 +S2B_MSIL1C_20200509T102559_N0209_R108_T32TMT_20200509T124301 \ No newline at end of file diff --git a/S2L2A_test_img_ids.txt b/S2L2A_test_img_ids.txt new file mode 100644 index 0000000..0bb00d0 --- /dev/null +++ b/S2L2A_test_img_ids.txt @@ -0,0 +1,15 @@ +S2B_MSIL2A_20200805T044709_N0214_R076_T45RTK_20200805T090405 +S2B_MSIL2A_20200808T081609_N0214_R121_T34HDK_20200808T122511 +S2A_MSIL2A_20200805T024331_N0214_R003_T49JEN_20200805T042709 +S2B_MSIL2A_20200806T141739_N0214_R096_T23VNH_20200806T164315 +S2A_MSIL2A_20200805T153911_N0214_R011_T18TWK_20200805T200607 +S2B_MSIL2A_20200809T144729_N0214_R139_T19LFE_20200809T190255 +S2A_MSIL2A_20200806T064631_N0214_R020_T43WEP_20200806T083125 +S2A_MSIL2A_20200807T132241_N0214_R038_T22JFP_20200807T154317 +S2A_MSIL2A_20200809T223721_N0214_R072_T59GMN_20200810T002854 +S2A_MSIL2A_20200809T102031_N0214_R065_T33UWA_20200809T130506 +S2B_MSIL2A_20200809T110629_N0214_R137_T30UXC_20200809T140506 +S2B_MSIL2A_20200805T044709_N0214_R076_T45RTK_20200805T090405 +S2A_MSIL2A_20200804T160911_N0214_R140_T16PET_20200804T215526 +S2B_MSIL2A_20200804T015659_N0214_R060_T52SED_20200804T044656 +S2A_MSIL2A_20200801T073621_N0214_R092_T37NCE_20200801T102406 diff --git a/S2L2A_training_img_ids.txt b/S2L2A_training_img_ids.txt new file mode 100644 index 0000000..4c65194 --- /dev/null +++ b/S2L2A_training_img_ids.txt @@ -0,0 +1,44 @@ +S2B_MSIL2A_20200119T131859_N0213_R095_T19DFF_20200119T145349 +S2A_MSIL2A_20200126T032011_N0213_R118_T48NUG_20200126T071110 +S2A_MSIL2A_20200228T135111_N0214_R024_T22MDE_20200228T162124 +S2B_MSIL2A_20200302T093029_N0214_R136_T32MPE_20200302T124106 +S2B_MSIL2A_20200322T110639_N0214_R137_T28PGQ_20200322T140428 
+S2A_MSIL2A_20200402T050601_N0214_R076_T44NNN_20200402T080805 +S2A_MSIL2A_20200427T035541_N0214_R004_T48UXU_20200427T080955 +S2A_MSIL2A_20200513T023551_N0214_R089_T50RQV_20200513T045417 +S2A_MSIL2A_20200518T101031_N0214_R022_T33UVR_20200518T130110 +S2B_MSIL2A_20200518T143729_N0214_R096_T19HCC_20200518T184649 +S2A_MSIL2A_20200524T052651_N0214_R105_T43QCA_20200524T094211 +S2A_MSIL2A_20200527T185921_N0214_R013_T10TER_20200528T010342 +S2B_MSIL2A_20200509T022329_N0214_R103_T51PVN_20200509T041832 +S2B_MSIL2A_20200509T102559_N0214_R108_T32TMT_20200509T135055 +S2A_MSIL2A_20200610T003711_N0214_R059_T54HUG_20200610T032520 +S2A_MSIL2A_20200610T003711_N0214_R059_T55LBD_20200610T032520 +S2A_MSIL2A_20200620T183921_N0214_R070_T11UQV_20200620T232224 +S2B_MSIL2A_20200620T060639_N0214_R134_T42SWH_20200620T095034 +S2A_MSIL2A_20200622T105631_N0214_R094_T30TXQ_20200622T135503 +S2A_MSIL2A_20200622T155911_N0214_R097_T17RMN_20200622T222624 +S2A_MSIL2A_20200623T103031_N0214_R108_T32TLT_20200623T142851 +S2B_MSIL2A_20200624T040549_N0214_R047_T47SNB_20200624T074104 +S2A_MSIL2A_20200627T032541_N0214_R018_T48QWH_20200627T073329 +S2A_MSIL2A_20200628T000251_N0214_R030_T56HKH_20200628T020102 +S2A_MSIL2A_20200628T013721_N0214_R031_T52LEH_20200628T033123 +S2A_MSIL2A_20200702T092031_N0214_R093_T34SDA_20200702T121048 +S2A_MSIL2A_20200702T092031_N0214_R093_T34SFH_20200702T114806 +S2B_MSIL2A_20200702T032539_N0214_R018_T47NQG_20200702T073845 +S2A_MSIL2A_20200703T170901_N0214_R112_T14SPB_20200703T213059 +S2B_MSIL2A_20200703T112119_N0214_R037_T29SNC_20200703T140926 +S2B_MSIL2A_20200703T161829_N0214_R040_T17TLH_20200703T203651 +S2B_MSIL2A_20200703T161829_N0214_R040_T17UPT_20200703T203651 +S2A_MSIL2A_20200704T151711_N0214_R125_T19PCN_20200704T192121 +S2A_MSIL2A_20200705T074621_N0214_R135_T37QED_20200705T103234 +S2B_MSIL2A_20200705T140059_N0214_R067_T21LYJ_20200705T180214 +S2B_MSIL2A_20200705T165849_N0214_R069_T14RNN_20200705T211456 +S2A_MSIL2A_20200706T141741_N0214_R010_T18FXG_20200706T183007 
+S2B_MSIL2A_20200706T180919_N0214_R084_T12SUF_20200706T222223 +S2A_MSIL2A_20200707T071211_N0214_R020_T38KNF_20200707T102832 +S2B_MSIL2A_20200707T092029_N0214_R093_T33QYE_20200707T122716 +S2B_MSIL2A_20200707T105619_N0214_R094_T30SVH_20200707T135135 +S2A_MSIL2A_20200708T075611_N0214_R035_T38UMU_20200708T111246 +S2B_MSIL2A_20200708T070629_N0214_R106_T39RVJ_20200708T111714 +S2B_MSIL2A_20200708T170849_N0214_R112_T15TTF_20200708T210844 \ No newline at end of file diff --git a/models/aesr_20m_s2_038_lr_1e-04.hdf5 b/models/aesr_20m_s2_038_lr_1e-04.hdf5 new file mode 100644 index 0000000..844ead5 Binary files /dev/null and b/models/aesr_20m_s2_038_lr_1e-04.hdf5 differ diff --git a/models/aesr_60m_s2_038_lr_1e-04.hdf5 b/models/aesr_60m_s2_038_lr_1e-04.hdf5 new file mode 100644 index 0000000..6a36197 Binary files /dev/null and b/models/aesr_60m_s2_038_lr_1e-04.hdf5 differ diff --git a/pylintrc b/pylintrc new file mode 100644 index 0000000..1df3ed1 --- /dev/null +++ b/pylintrc @@ -0,0 +1,44 @@ +[MASTER] +init-hook='import glob; [sys.path.append(d) for d in glob.glob("*/") if not d.startswith("_")]' + +[MESSAGE CONTROL] +disable= + missing-docstring, + no-else-return, + too-few-public-methods, + missing-final-newline, + too-many-boolean-expressions, + bad-continuation, + invalid-name, + super-init-not-called, + inconsistent-return-statements, + too-many-arguments, + too-many-locals, + protected-access, + redefined-outer-name, + too-many-instance-attributes, + fixme, + wrong-import-position, + logging-fstring-interpolation, + logging-format-interpolation + +[FORMAT] +max-line-length=120 +single-line-if-stmt=yes +include-naming-hint=yes +function-rgx=[a-z_][a-z0-9_]*$ +argument-rgx=[a-z_][a-z0-9_]*$ +variable-rgx=[a-z_][a-z0-9_]*$ +# "logger" and "api" are common module-level globals, and not true 'constants' +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__)|logger|api|_api)$ + +[DESIGN] +max-args=6 +ignored-argument-names=_.*|self + +[SIMILARITIES] +# Minimum lines number of a similarity. 
+min-similarity-lines=20 # TODO: Reset lower when pylint bug fixed #214. +ignore-comments=yes +ignore-docstrings=yes +ignore-imports=no diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..e82eadc --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,10 @@ +black +pylint==2.5.0 +pytest +pytest-pylint +pytest-sugar +mypy +mypy-extensions +pytest-cov +pytest-mypy +coverage-badge diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..3483002 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +numpy +tensorflow +keras +scikit-image +imageio +rasterio +pyproj +matplotlib +pydot +up42-blockutils +image-similarity-measures diff --git a/test.sh b/test.sh new file mode 100644 index 0000000..f828ca1 --- /dev/null +++ b/test.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +rm -rf .pytest_cache +black . +# Record the pytest exit status immediately; running coverage-badge first would overwrite it. +python -m pytest --pylint --pylint-rcfile=pylintrc --mypy --mypy-ignore-missing-imports --durations=5 +RET_VALUE=$? +coverage-badge -f -o coverage.svg 
+exit $RET_VALUE diff --git a/testing/demoDSen2.py b/testing/demoDSen2.py index 079dba1..bbb6958 100644 --- a/testing/demoDSen2.py +++ b/testing/demoDSen2.py @@ -1,26 +1,29 @@ from __future__ import absolute_import +import sys + import h5py import matplotlib.pyplot as plt import numpy as np -from supres import DSen2_20, DSen2_60 -import sys -sys.path.append('../') -from utils.imresize import imresize +from supres import dsen2_20, dsen2_60 -DATA_PATH = '../data/' +sys.path.append("../") +from utils.imresize import imresize + +DATA_PATH = "../data/" +# pylint: disable=unbalanced-tuple-unpacking def readh5(fname, im60=False, imGT=False): - with h5py.File(DATA_PATH+fname, 'r') as f: - d10 = f['im10'][()].transpose() - d20 = f['im20'][()].transpose() + with h5py.File(DATA_PATH + fname, "r") as f: + d10 = f["im10"][()].transpose() + d20 = f["im20"][()].transpose() if im60: - d60 = f['im60'][()].transpose() + d60 = f["im60"][()].transpose() if not imGT: return d10, d20, d60 if imGT: - dGT = f['imGT'][()].transpose() + dGT = f["imGT"][()].transpose() if im60: return d10, d20, d60, dGT else: @@ -29,104 +32,106 @@ def readh5(fname, im60=False, imGT=False): def RMSE(x1, x2): - diff = x1.astype(np.float64)-x2.astype(np.float64) + diff = x1.astype(np.float64) - x2.astype(np.float64) rms = np.sqrt(np.mean(np.power(diff, 2))) - print('RMSE: {:.4f}'.format(rms)) + print("RMSE: {:.4f}".format(rms)) return rms -if __name__ == '__main__': +if __name__ == "__main__": # Siberia, same area of Fig. 
8 in the paper - print('Siberia') - im10, im20, imGT = readh5('S2B_MSIL1C_20170725_T43WFQ.mat', imGT=True) - SR20 = DSen2_20(im10, im20) + print("Siberia") + im10, im20, imGT = readh5("S2B_MSIL1C_20170725_T43WFQ.mat", imGT=True) + SR20 = dsen2_20(im10, im20) # Evaluation against the ground truth on the 20m resolution bands (simulated) - print('DSen2:') + print("DSen2:") RMSE(SR20, imGT) - print('Bicubic:') + print("Bicubic:") RMSE(imresize(im20, 2), imGT) fig = plt.figure(1) ax = fig.add_subplot(111) cax = ax.imshow(SR20[:, :, 2]) fig.colorbar(cax) - ax.set_title('Super-resolved band B6') + ax.set_title("Super-resolved band B6") fig = plt.figure(2) ax = fig.add_subplot(111) - cax = plt.imshow(np.abs(SR20[:, :, 4]-imGT[:, :, 4]), vmin=0, vmax=200) + cax = plt.imshow(np.abs(SR20[:, :, 4] - imGT[:, :, 4]), vmin=0, vmax=200) fig.colorbar(cax) - ax.set_title('Absolute differences to the GT, band B11') + ax.set_title("Absolute differences to the GT, band B11") plt.show(block=False) # # South Africa, same area of Fig. 9 in the paper - print('S. Africa') - im10, im20, im60, imGT = readh5('S2A_MSIL1C_20171028_T34HCH.mat', im60=True, imGT=True) - SR60 = DSen2_60(im10, im20, im60) + print("S. Africa") + im10, im20, im60, imGT = readh5( + "S2A_MSIL1C_20171028_T34HCH.mat", im60=True, imGT=True + ) + SR60 = dsen2_60(im10, im20, im60) # Evaluation against the ground truth on the 60m resolution bands (simulated) - print('DSen2:') + print("DSen2:") RMSE(SR60, imGT) - print('Bicubic:') + print("Bicubic:") RMSE(imresize(im60, 6), imGT) fig = plt.figure(3) ax = fig.add_subplot(111) - cax = plt.imshow(np.abs(SR60[:, :, 1]-imGT[:, :, 1]), vmin=0, vmax=200) + cax = plt.imshow(np.abs(SR60[:, :, 1] - imGT[:, :, 1]), vmin=0, vmax=200) fig.colorbar(cax) - ax.set_title('Absolute differences to the GT, band B9') + ax.set_title("Absolute differences to the GT, band B9") plt.show(block=False) # # New York, same area of Fig. 
10 (bottom) in the paper # Here using the very deep variable (VDSen2) - print('New York') - im10, im20, imGT = readh5('S2B_MSIL1C_20170928_T18TWL.mat', im60=False, imGT=True) - SR20 = DSen2_20(im10, im20, deep=False) + print("New York") + im10, im20, imGT = readh5("S2B_MSIL1C_20170928_T18TWL.mat", im60=False, imGT=True) + SR20 = dsen2_20(im10, im20) # Evaluation against the ground truth on the 20m resolution bands (simulated) - print('DSen2:') - RMSE(SR20,imGT) - print('Bicubic:') + print("DSen2:") + RMSE(SR20, imGT) + print("Bicubic:") RMSE(imresize(im20, 2), imGT) # # Malmo, Sweden, same area of Fig. 10 (top) in the paper - print('Malmo, no ground truth') - im10, im20, im60 = readh5('S2A_MSIL1C_20170527_T33UUB.mat', im60=True, imGT=False) + print("Malmo, no ground truth") + im10, im20, im60 = readh5("S2A_MSIL1C_20170527_T33UUB.mat", im60=True, imGT=False) - SR20 = DSen2_20(im10, im20) - SR60 = DSen2_60(im10, im20, im60) + SR20 = dsen2_20(im10, im20) + SR60 = dsen2_60(im10, im20, im60) # No ground truth available, no simulation. Comparison to the low-res input fig = plt.figure(4) ax1 = fig.add_subplot(121) plt.imshow(im60[:, :, 0], vmin=np.min(im60[:, :, 0]), vmax=np.max(im60[:, :, 0])) - ax1.set_title('Band B1, input 60m') + ax1.set_title("Band B1, input 60m") ax2 = fig.add_subplot(122) plt.imshow(SR60[:, :, 0], vmin=np.min(im60[:, :, 0]), vmax=np.max(im60[:, :, 0])) - ax2.set_title('Band B1, 10m super-resolution') + ax2.set_title("Band B1, 10m super-resolution") plt.show(block=False) fig = plt.figure(5) ax1 = fig.add_subplot(121) plt.imshow(im20[:, :, 1], vmin=np.min(im20[:, :, 1]), vmax=np.max(im20[:, :, 1])) - ax1.set_title('Band B6, input 20m') + ax1.set_title("Band B6, input 20m") ax2 = fig.add_subplot(122) plt.imshow(SR20[:, :, 1], vmin=np.min(im20[:, :, 1]), vmax=np.max(im20[:, :, 1])) - ax2.set_title('Band B6, 10m super-resolution') + ax2.set_title("Band B6, 10m super-resolution") plt.show(block=False) # # Shark bay, Australia, same area of Fig. 
10 (middle) in the paper - print('Shark Bay, no ground truth') - im10, im20, im60 = readh5('S2B_MSIL1C_20171022_T49JGM.mat', im60=True, imGT=False) - SR20 = DSen2_20(im10, im20) - SR60 = DSen2_60(im10, im20, im60) + print("Shark Bay, no ground truth") + im10, im20, im60 = readh5("S2B_MSIL1C_20171022_T49JGM.mat", im60=True, imGT=False) + SR20 = dsen2_20(im10, im20) + SR60 = dsen2_60(im10, im20, im60) # Stretching the image for better visualization for i in range(SR60.shape[2]): @@ -141,10 +146,10 @@ def RMSE(x1, x2): fig = plt.figure(6) ax1 = fig.add_subplot(121) plt.imshow(im60s) - ax1.set_title('Color composite (B1,B9,B1) \n 60m input') + ax1.set_title("Color composite (B1,B9,B1) \n 60m input") ax2 = fig.add_subplot(122) plt.imshow(imSR) - ax2.set_title('Color composite (B1,B9,B1) \n 10m super-resolution') + ax2.set_title("Color composite (B1,B9,B1) \n 10m super-resolution") # Stretching the image for better visualization imSR = SR20[:, :, [5, 3, 0]] @@ -159,9 +164,9 @@ def RMSE(x1, x2): fig = plt.figure(7) ax1 = fig.add_subplot(121) plt.imshow(im20s) - ax1.set_title('Color composite (B12,B8a,B5) \n 20m input') + ax1.set_title("Color composite (B12,B8a,B5) \n 20m input") ax2 = fig.add_subplot(122) plt.imshow(imSR) - ax2.set_title('Color composite (B12,B8a,B5) \n 10m super-resolution') + ax2.set_title("Color composite (B12,B8a,B5) \n 10m super-resolution") plt.show() diff --git a/testing/s2_tiles_supres.py b/testing/s2_tiles_supres.py index 01aed21..7e144de 100644 --- a/testing/s2_tiles_supres.py +++ b/testing/s2_tiles_supres.py @@ -1,421 +1,213 @@ from __future__ import division -import argparse -import numpy as np import os -import re import sys -from osgeo import gdal, osr -from collections import defaultdict -from supres import DSen2_20, DSen2_60 - -# This code is adapted from this repository http://nicolas.brodu.net/code/superres and is distributed under the same -# license. 
- -parser = argparse.ArgumentParser(description="Perform super-resolution on Sentinel-2 with DSen2. Code based on superres" - " by Nicolas Brodu.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("data_file", - help="An input sentinel-2 data file. This can be either the original ZIP file, or the S2A[...].xml " - "file in a SAFE directory extracted from that ZIP.") -parser.add_argument("output_file", nargs="?", - help="A target data file. See also the --save_prefix option, and the --output_file_format option " - "(default is GTiff).") -parser.add_argument("--roi_lon_lat", default="", - help="Sets the region of interest to extract, WGS84, decimal notation. Use this syntax: lon_1," - "lat_1,lon_2,lat_2. The order of points 1 and 2 does not matter: the region of interest " - "extends to the min/max in each direction. " - "Example: --roi_lon_lat=-1.12132,44.72408,-0.90350,44.58646") -parser.add_argument("--roi_x_y", default="", - help="Sets the region of interest to extract as pixels locations on the 10m bands. Use this " - "syntax: x_1,y_1,x_2,y_2. The order of points 1 and 2 does not matter: the region of interest " - "extends to the min/max in each direction and to nearby 60m pixel boundaries.") -parser.add_argument("--list_bands", action="store_true", - help="List bands in the input file subdata set matching the selected UTM zone, and exit.") -parser.add_argument("--run_60", action="store_true", - help="Select which bands to process and include in the output file. If this flag is set it will " - "super-resolve the 20m and 60m bands (B1,B2,B3,B4,B5,B6,B7,B8,B8A,B9,B11,B12). If it is not " - "set it will only super-resolve the 20m bands (B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12). 
Band B10 " - "is to noisy and is not super-resolved.") -parser.add_argument("--list_UTM", action="store_true", - help="List all UTM zones present in the input file, together with their coverage of the ROI in " - "10m x 10m pixels.") -parser.add_argument("--select_UTM", default="", - help="Select a UTM zone. The default is to select the zone with the largest coverage of the ROI.") -parser.add_argument("--list_output_file_formats", action="store_true", - help="If specified, list all supported raster output file formats declared by GDAL and exit. Some " - "of these formats may be inappropriate for storing Sentinel-2 multispectral data.") -parser.add_argument("--output_file_format", default="GTiff", - help="Speficies the name of a GDAL driver that supports file creation, like ENVI or GTiff. If no " - "such driver exists, or if the format is \"npz\", then save all bands instead as a compressed " - "python/numpy file") -parser.add_argument("--copy_original_bands", action="store_true", - help="The default is not to copy the original selected 10m bands into the output file in addition " - "to the super-resolved bands. If this flag is used, the output file may be used as a 10m " - "version of the original Sentinel-2 file.") -parser.add_argument("--save_prefix", default="", - help="If set, speficies the name of a prefix for all output files. Use a trailing / to save into a " - "directory. The default of no prefix will save into the current directory. 
" - "Example: --save_prefix result/") - - -args = parser.parse_args() -globals().update(args.__dict__) - -if list_output_file_formats: - dcount = gdal.GetDriverCount() - for didx in range(dcount): - driver = gdal.GetDriver(didx) - if driver: - metadata = driver.GetMetadata() - if (gdal.DCAP_CREATE in (driver and metadata) and metadata[gdal.DCAP_CREATE] == 'YES' and - gdal.DCAP_RASTER in metadata and metadata[gdal.DCAP_RASTER] == 'YES'): - name = driver.GetDescription() - if "DMD_LONGNAME" in metadata: - name += ": " + metadata["DMD_LONGNAME"] - else: - name = driver.GetDescription() - if "DMD_EXTENSIONS" in metadata: name += " (" + metadata["DMD_EXTENSIONS"] + ")" - print(name) - sys.exit(0) - -if run_60: - select_bands = 'B1,B2,B3,B4,B5,B6,B7,B8,B8A,B9,B11,B12' -else: - select_bands = 'B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12' - -# convert comma separated band list into a list -select_bands = [x for x in re.split(',', select_bands)] - -if roi_lon_lat: - roi_lon1, roi_lat1, roi_lon2, roi_lat2 = [float(x) for x in re.split(',', roi_lon_lat)] -else: - roi_lon1, roi_lat1, roi_lon2, roi_lat2 = -180, -90, 180, 90 - -if roi_x_y: - roi_x1, roi_y1, roi_x2, roi_y2 = [float(x) for x in re.split(',', roi_x_y)] - -raster = gdal.Open(data_file) - - -datasets = raster.GetSubDatasets(); -tenMsets = [] -twentyMsets = [] -sixtyMsets = [] -unknownMsets = [] -for (dsname, dsdesc) in datasets: - if '10m resolution' in dsdesc: - tenMsets += [(dsname, dsdesc)] - elif '20m resolution' in dsdesc: - twentyMsets += [(dsname, dsdesc)] - elif '60m resolution' in dsdesc: - sixtyMsets += [(dsname, dsdesc)] - else: - unknownMsets += [(dsname, dsdesc)] - -# case where we have several UTM in the data set -# => select the one with maximal coverage of the study zone -utm_idx = 0 -utm = select_UTM -all_utms = defaultdict(int) -xmin, ymin, xmax, ymax = 0, 0, 0, 0 -largest_area = -1 -# process even if there is only one 10m set, in order to get roi -> pixels -for (tmidx, (dsname, dsdesc)) in enumerate(tenMsets 
+ unknownMsets): - ds = gdal.Open(dsname) - if roi_x_y: - tmxmin = max(min(roi_x1, roi_x2, ds.RasterXSize - 1), 0) - tmxmax = min(max(roi_x1, roi_x2, 0), ds.RasterXSize - 1) - tmymin = max(min(roi_y1, roi_y2, ds.RasterYSize - 1), 0) - tmymax = min(max(roi_y1, roi_y2, 0), ds.RasterYSize - 1) - # enlarge to the nearest 60 pixel boundary for the super-resolution - tmxmin = int(tmxmin / 6) * 6 - tmxmax = int((tmxmax + 1) / 6) * 6 - 1 - tmymin = int(tmymin / 6) * 6 - tmymax = int((tmymax + 1) / 6) * 6 - 1 - elif not roi_lon_lat: - tmxmin = 0 - tmxmax = ds.RasterXSize - 1 - tmymin = 0 - tmymax = ds.RasterYSize - 1 - else: - xoff, a, b, yoff, d, e = ds.GetGeoTransform() - srs = osr.SpatialReference() - srs.ImportFromWkt(ds.GetProjection()) - srsLatLon = osr.SpatialReference() - srsLatLon.SetWellKnownGeogCS("WGS84"); - ct = osr.CoordinateTransformation(srsLatLon, srs) - - - def to_xy(lon, lat): - (xp, yp, h) = ct.TransformPoint(lon, lat, 0.) - xp -= xoff - yp -= yoff - # matrix inversion - det_inv = 1. 
/ (a * e - d * b) - x = (e * xp - b * yp) * det_inv - y = (-d * xp + a * yp) * det_inv - return (int(x), int(y)) - - - x1, y1 = to_xy(roi_lon1, roi_lat1) - x2, y2 = to_xy(roi_lon2, roi_lat2) - tmxmin = max(min(x1, x2, ds.RasterXSize - 1), 0) - tmxmax = min(max(x1, x2, 0), ds.RasterXSize - 1) - tmymin = max(min(y1, y2, ds.RasterYSize - 1), 0) - tmymax = min(max(y1, y2, 0), ds.RasterYSize - 1) - # enlarge to the nearest 60 pixel boundary for the super-resolution - tmxmin = int(tmxmin / 6) * 6 - tmxmax = int((tmxmax + 1) / 6) * 6 - 1 - tmymin = int(tmymin / 6) * 6 - tmymax = int((tmymax + 1) / 6) * 6 - 1 - area = (tmxmax - tmxmin + 1) * (tmymax - tmymin + 1) - current_utm = dsdesc[dsdesc.find("UTM"):] - if area > all_utms[current_utm]: - all_utms[current_utm] = area - if current_utm == select_UTM: - xmin, ymin, xmax, ymax = tmxmin, tmymin, tmxmax, tmymax - utm_idx = tmidx - utm = current_utm - break - if area > largest_area: - xmin, ymin, xmax, ymax = tmxmin, tmymin, tmxmax, tmymax - largest_area = area - utm_idx = tmidx - utm = dsdesc[dsdesc.find("UTM"):] - -if list_UTM: - print("List of UTM zones (with ROI coverage in pixels):") - for u in all_utms: - print("%s (%d)" % (u, all_utms[u])) - sys.exit(0) - -print("Selected UTM Zone:", utm) -print("Selected pixel region: xmin=%d, ymin=%d, xmax=%d, ymax=%d:" % (xmin, ymin, xmax, ymax)) -print("Image size: width=%d x height=%d" % (xmax - xmin + 1, ymax - ymin + 1)) - -if xmax < xmin or ymax < ymin: - print("Invalid region of interest / UTM Zone combination") - sys.exit(0) - -selected_10m_data_set = None -if not tenMsets: - selected_10m_data_set = unknownMsets[0] -else: - selected_10m_data_set = tenMsets[utm_idx] -selected_20m_data_set = None -for (dsname, dsdesc) in enumerate(twentyMsets): - if utm in dsdesc: - selected_20m_data_set = (dsname, dsdesc) -# if not found, assume the listing is in the same order -# => OK if only one set -if not selected_20m_data_set: selected_20m_data_set = twentyMsets[utm_idx] 
-selected_60m_data_set = None -for (dsname, dsdesc) in enumerate(sixtyMsets): - if utm in dsdesc: - selected_60m_data_set = (dsname, dsdesc) -if not selected_60m_data_set: selected_60m_data_set = sixtyMsets[utm_idx] - -ds10 = gdal.Open(selected_10m_data_set[0]) -ds20 = gdal.Open(selected_20m_data_set[0]) -ds60 = gdal.Open(selected_60m_data_set[0]) - - -def validate_description(description): - m = re.match("(.*?), central wavelength (\d+) nm", description) - if m: - return m.group(1) + " (" + m.group(2) + " nm)" - # Some HDR restrictions... ENVI band names should not include commas - if output_file_format == 'ENVI' and ',' in description: - pos = description.find(',') - return description[:pos] + description[(pos + 1):] - return description - - -if list_bands: - print("\n10m bands:") - for b in range(0, ds10.RasterCount): - print("- " + validate_description(ds10.GetRasterBand(b + 1).GetDescription())) - print("\n20m bands:") - for b in range(0, ds20.RasterCount): - print("- " + validate_description(ds20.GetRasterBand(b + 1).GetDescription())) - print("\n60m bands:") - for b in range(0, ds60.RasterCount): - print("- " + validate_description(ds60.GetRasterBand(b + 1).GetDescription())) - print("") - - -def get_band_short_name(description): - if ',' in description: - return description[:description.find(',')] - if ' ' in description: - return description[:description.find(' ')] - return description[:3] - - -validated_10m_bands = [] -validated_10m_indices = [] -validated_20m_bands = [] -validated_20m_indices = [] -validated_60m_bands = [] -validated_60m_indices = [] -validated_descriptions = defaultdict(str) - -sys.stdout.write("Selected 10m bands:") -for b in range(0, ds10.RasterCount): - desc = validate_description(ds10.GetRasterBand(b + 1).GetDescription()) - shortname = get_band_short_name(desc) - if shortname in select_bands: - sys.stdout.write(" " + shortname) - select_bands.remove(shortname) - validated_10m_bands += [shortname] - validated_10m_indices += [b] - 
validated_descriptions[shortname] = desc -sys.stdout.write("\nSelected 20m bands:") -for b in range(0, ds20.RasterCount): - desc = validate_description(ds20.GetRasterBand(b + 1).GetDescription()) - shortname = get_band_short_name(desc) - if shortname in select_bands: - sys.stdout.write(" " + shortname) - select_bands.remove(shortname) - validated_20m_bands += [shortname] - validated_20m_indices += [b] - validated_descriptions[shortname] = desc -sys.stdout.write("\nSelected 60m bands:") -for b in range(0, ds60.RasterCount): - desc = validate_description(ds60.GetRasterBand(b + 1).GetDescription()) - shortname = get_band_short_name(desc) - if shortname in select_bands: - sys.stdout.write(" " + shortname) - select_bands.remove(shortname) - validated_60m_bands += [shortname] - validated_60m_indices += [b] - validated_descriptions[shortname] = desc -sys.stdout.write("\n") - -if list_bands: - sys.exit(0) - -# All query options are processed, we now require an output file -if not output_file: - print("Error: you must provide the name of an output file. I will set it identical to the input...") - output_file = os.path.split(data_file)[1] + '.tif' - # sys.exit(1) - - -output_file = save_prefix + output_file -# Some HDR restrictions... 
ENVI file name should be the .bin, not the .hdr -if output_file_format == 'ENVI' and (output_file[-4:] == '.hdr' or output_file[-4:] == '.HDR'): - output_file = output_file[:-4] + '.bin' - - -if validated_10m_indices: - print("Loading selected data from: %s" % selected_10m_data_set[1]) - data10 = np.rollaxis( - ds10.ReadAsArray(xoff=xmin, yoff=ymin, xsize=xmax - xmin + 1, ysize=ymax - ymin + 1, buf_xsize=xmax - xmin + 1, - buf_ysize=ymax - ymin + 1), 0, 3)[:, :, validated_10m_indices] - -if validated_20m_indices: - print("Loading selected data from: %s" % selected_20m_data_set[1]) - data20 = np.rollaxis( - ds20.ReadAsArray(xoff=xmin // 2, yoff=ymin // 2, xsize=(xmax - xmin + 1) // 2, ysize=(ymax - ymin + 1) // 2, - buf_xsize=(xmax - xmin + 1) // 2, buf_ysize=(ymax - ymin + 1) // 2), 0, 3)[:, :, - validated_20m_indices] - -if validated_60m_indices: - print("Loading selected data from: %s" % selected_60m_data_set[1]) - data60 = np.rollaxis( - ds60.ReadAsArray(xoff=xmin // 6, yoff=ymin // 6, xsize=(xmax - xmin + 1) // 6, ysize=(ymax - ymin + 1) // 6, - buf_xsize=(xmax - xmin + 1) // 6, buf_ysize=(ymax - ymin + 1) // 6), 0, 3)[:, :, - validated_60m_indices] - - -if validated_60m_bands and validated_20m_bands and validated_10m_bands: - print("Super-resolving the 60m data into 10m bands") - sr60 = DSen2_60(data10, data20, data60, deep=False) -else: - sr60 = None - -if validated_10m_bands and validated_20m_bands: - print("Super-resolving the 20m data into 10m bands") - sr20 = DSen2_20(data10, data20, deep=False) -else: - sr20 = None - -sr_band_names = [] - -if sr20 is None: - print("No super-resolution performed, exiting") - sys.exit(0) - -if output_file_format != "npz": - revert_to_npz = True - driver = gdal.GetDriverByName(output_file_format) - if driver: - metadata = driver.GetMetadata() - if gdal.DCAP_CREATE in metadata and metadata[gdal.DCAP_CREATE] == 'YES': - revert_to_npz = False - if revert_to_npz: - print("Gdal doesn't support creating %s files" % 
output_file_format) - print("Writing to npz as a fallback") - output_file_format = "npz" - bands = None -else: - bands = dict() - result_dataset = None - -bidx = 0 -all_descriptions = [] -source_band = dict() - - -def write_band_data(data, description, shortname=None): - global all_descriptions - global bidx - all_descriptions += [description] - if output_file_format == "npz": - bands[description] = data - else: - bidx += 1 - result_dataset.GetRasterBand(bidx).SetDescription(description) - result_dataset.GetRasterBand(bidx).WriteArray(data) - - -if sr60 is not None: - sr = np.concatenate((sr20, sr60), axis=2) - validated_sr_bands = validated_20m_bands + validated_60m_bands -else: - sr = sr20 - validated_sr_bands = validated_20m_bands - -if copy_original_bands: - out_dims = data10.shape[2] + sr.shape[2] -else: - out_dims = sr.shape[2] - - -sys.stdout.write("Writing") -result_dataset = driver.Create(output_file, data10.shape[1], data10.shape[0], out_dims, gdal.GDT_Float64) - -# Translate the image upper left corner. We multiply x10 to transform from pixel position in the 10m_band to meters. 
-geot = list(ds10.GetGeoTransform()) -geot[0] += xmin * 10 -geot[3] -= ymin * 10 -result_dataset.SetGeoTransform(tuple(geot)) -result_dataset.SetProjection(ds10.GetProjection()) - -if copy_original_bands: - sys.stdout.write(" the original 10m bands and") - # Write the original 10m bands - for bi, bn in enumerate(validated_10m_bands): - write_band_data(data10[:, :, bi], validated_descriptions[bn]) -print(" the super-resolved bands in %s" % output_file) -for bi, bn in enumerate(validated_sr_bands): - write_band_data(sr[:, :, bi], "SR" + validated_descriptions[bn], "SR" + bn) - +import gc +from typing import Tuple +import argparse -for desc in all_descriptions: - print(desc) +import rasterio +from rasterio import Affine as A -if output_file_format == "npz": - np.savez(output_file, bands=bands) +import numpy as np +from utils.data_utils import DATA_UTILS, get_logger +from supres import dsen2_20, dsen2_60 + +LOGGER = get_logger(__name__) + +# pylint: disable-msg=too-many-arguments +def save_result( + model_output, output_bands, valid_desc, output_profile, image_name, +): + """ + This method saves the feature collection meta data and the + image with high resolution for desired bands to the provided location. + :param model_output: The high resolution image. + :param output_bands: The associated bands for the output image. + :param valid_desc: The valid description of the existing bands. + :param output_profile: The georeferencing for the output image. + :param output_features: The meta data for the output image. + :param image_name: The name of the output image. 
+ + """ + + with rasterio.open(image_name, "w", **output_profile) as d_s: + for b_i, b_n in enumerate(output_bands): + d_s.write(model_output[:, :, b_i], indexes=b_i + 1) + d_s.set_band_description(b_i + 1, "SR " + valid_desc[b_n]) + + +# pylint: disable-msg=too-many-arguments +def update(pr_10m, size_10m: Tuple, model_output: np.ndarray, xmi: int, ymi: int): + """ + This method creates the proper georeferencing for the output image. + :param data: The raster file for 10m resolution. + + """ + # Here based on the params.json file, the output image dimension will be calculated. + out_dims = model_output.shape[2] + + new_transform = pr_10m["transform"] * A.translation(xmi, ymi) + pr_10m.update(dtype=rasterio.float32) + pr_10m.update(driver="GTiff") + pr_10m.update(width=size_10m[1]) + pr_10m.update(height=size_10m[0]) + pr_10m.update(count=out_dims) + pr_10m.update(transform=new_transform) + return pr_10m + + +class Superresolution(DATA_UTILS): + def __init__(self, data_file_path, clip_to_aoi, copy_original_bands, output_dir): + self.data_file_path = data_file_path + self.clip_to_aoi = clip_to_aoi + self.copy_original_bands = copy_original_bands + self.output_dir = output_dir + self.data_name = os.path.basename(data_file_path) + + super().__init__(data_file_path) + + # pylint: disable=attribute-defined-outside-init + def start(self): + data_list = self.get_data() + + for dsdesc in data_list: + if "10m" in dsdesc: + if self.clip_to_aoi: + xmin, ymin, xmax, ymax, interest_area = self.area_of_interest( + dsdesc, self.clip_to_aoi + ) + else: + # Get the pixel bounds of the full scene + xmin, ymin, xmax, ymax, interest_area = self.get_max_min( + 0, 0, 20000, 20000, dsdesc + ) + LOGGER.info("Selected pixel region:") + LOGGER.info("xmin = %s", xmin) + LOGGER.info("ymin = %s", ymin) + LOGGER.info("xmax = %s", xmax) + LOGGER.info("ymax = %s", ymax) + LOGGER.info("The area of selected region = %s", interest_area) + self.check_size(dims=(xmin, ymin, xmax, ymax)) + + for dsdesc 
in data_list: + if "10m" in dsdesc: + LOGGER.info("Selected 10m bands:") + ( + self.validated_10m_bands, + validated_10m_indices, + dic_10m, + ) = self.validate(dsdesc) + data10 = self.data_final( + dsdesc, validated_10m_indices, xmin, ymin, xmax, ymax, 1, 1 + ) + with rasterio.open(dsdesc) as d_s: + pr_10m = d_s.profile + + if "20m" in dsdesc: + LOGGER.info("Selected 20m bands:") + ( + self.validated_20m_bands, + validated_20m_indices, + dic_20m, + ) = self.validate(dsdesc) + data20 = self.data_final( + dsdesc, validated_20m_indices, xmin, ymin, xmax, ymax, 1, 2, + ) + if "60m" in dsdesc: + LOGGER.info("Selected 60m bands:") + ( + self.validated_60m_bands, + validated_60m_indices, + dic_60m, + ) = self.validate(dsdesc) + data60 = self.data_final( + dsdesc, validated_60m_indices, xmin, ymin, xmax, ymax, 1, 6, + ) + + self.validated_descriptions_all = {**dic_10m, **dic_20m, **dic_60m} + return data10, data20, data60, [xmin, ymin, xmax, ymax], pr_10m + + def inference(self, data10, data20, data60, coord, pr_10m): + + if ( + self.validated_60m_bands + and self.validated_20m_bands + and self.validated_10m_bands + ): + LOGGER.info("Super-resolving the 60m data into 10m bands") + sr60 = dsen2_60(data10, data20, data60) + LOGGER.info("Super-resolving the 20m data into 10m bands") + sr20 = dsen2_20(data10, data20) + else: + LOGGER.info("No super-resolution performed, exiting") + sys.exit(0) + + if self.copy_original_bands: + sr_final = np.concatenate((data10, sr20, sr60), axis=2) + validated_sr_final_bands = ( + self.validated_10m_bands + + self.validated_20m_bands + + self.validated_60m_bands + ) + else: + sr_final = np.concatenate((sr20, sr60), axis=2) + validated_sr_final_bands = ( + self.validated_20m_bands + self.validated_60m_bands + ) + + pr_10m_updated = update(pr_10m, data10.shape, sr_final, coord[0], coord[1]) + + path_to_output_img = self.data_name.split(".")[0] + "_superresolution.tif" + filename = os.path.join(self.output_dir, path_to_output_img) + + 
LOGGER.info("Now writing the super-resolved bands") + save_result( + sr_final, + validated_sr_final_bands, + self.validated_descriptions_all, + pr_10m_updated, + filename, + ) + del sr_final + LOGGER.info("This is for releasing memory: %s", gc.collect()) + LOGGER.info("Writing the super-resolved bands is finished.") + + def process(self): + data10, data20, data60, coord, pr_10m = self.start() + self.inference(data10, data20, data60, coord, pr_10m) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Perform super-resolution on Sentinel-2 with DSen2. Code based on superres" + " by Nicolas Brodu.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "data_file_path", + help="An input sentinel-2 data file. This can be either the original ZIP file, or the S2A[...].xml " + "file in a SAFE directory extracted from that ZIP.", + ) + parser.add_argument( + "--clip_to_aoi", + default="", + help=( + "Sets the region of interest to extract as pixels locations on the 10m" + 'bands. Use this syntax: x_1,y_1,x_2,y_2. E.g. --roi_x_y "2000,2000,3200,3200"' + ), + ) + parser.add_argument( + "--copy_original_bands", + action="store_true", + help="The default is not to copy the original selected 10m bands into the output file in addition " + "to the super-resolved bands. 
If this flag is used, the output file may be used as a 10m " + "version of the original Sentinel-2 file.", + ) + parser.add_argument( + "--output_dir", default="", help="Directory to the final output", + ) + args = parser.parse_args() + Superresolution( + args.data_file_path, args.clip_to_aoi, args.copy_original_bands, args.output_dir + ).process() diff --git a/testing/supres.py b/testing/supres.py index f721b18..26ab6a7 100644 --- a/testing/supres.py +++ b/testing/supres.py @@ -1,18 +1,20 @@ from __future__ import division -# import numpy as np -# import argparse -# from skimage.transform import resize -import sys -sys.path.append('../') -from utils.DSen2Net import s2model + +import tensorflow as tf +from tensorflow import keras from utils.patches import get_test_patches, get_test_patches60, recompose_images SCALE = 2000 -MDL_PATH = '../models/' +MDL_PATH = "./models/" + +MDL_PATH_20m_AESR = MDL_PATH + "aesr_20m_s2_038_lr_1e-04.hdf5" +MDL_PATH_60m_AESR = MDL_PATH + "aesr_60m_s2_038_lr_1e-04.hdf5" + +STRATEGY = tf.distribute.MirroredStrategy() -def DSen2_20(d10, d20, deep=False): +def dsen2_20(d10, d20): # Input to the funcion must be of shape: # d10: [x,y,4] (B2, B3, B4, B8) # d20: [x/2,y/4,6] (B5, B6, B7, B8a, B11, B12) @@ -23,14 +25,13 @@ def DSen2_20(d10, d20, deep=False): p10 /= SCALE p20 /= SCALE test = [p10, p20] - input_shape = ((4, None, None), (6, None, None)) - prediction = _predict(test, input_shape, deep=deep) + prediction = _predict(test, model_filename=MDL_PATH_20m_AESR) images = recompose_images(prediction, border=border, size=d10.shape) images *= SCALE return images -def DSen2_60(d10, d20, d60, deep=False): +def dsen2_60(d10, d20, d60): # Input to the funcion must be of shape: # d10: [x,y,4] (B2, B3, B4, B8) # d20: [x/2,y/4,6] (B5, B6, B7, B8a, B11, B12) @@ -43,25 +44,17 @@ def DSen2_60(d10, d20, d60, deep=False): p20 /= SCALE p60 /= SCALE test = [p10, p20, p60] - input_shape = ((4, None, None), (6, None, None), (2, None, None)) - prediction = 
_predict(test, input_shape, deep=deep, run_60=True) + prediction = _predict(test, model_filename=MDL_PATH_60m_AESR) images = recompose_images(prediction, border=border, size=d10.shape) images *= SCALE return images -def _predict(test, input_shape, deep=False, run_60=False): +def _predict(test, model_filename): # create model - if deep: - model = s2model(input_shape, num_layers=32, feature_size=256) - predict_file = MDL_PATH+'s2_034_lr_1e-04.hdf5' if run_60 else MDL_PATH+'s2_033_lr_1e-04.hdf5' - else: - model = s2model(input_shape, num_layers=6, feature_size=128) - predict_file = MDL_PATH+'s2_030_lr_1e-05.hdf5' if run_60 else MDL_PATH+'s2_032_lr_1e-04.hdf5' - print('Symbolic Model Created.') - - model.load_weights(predict_file) - print("Predicting using file: {}".format(predict_file)) + with STRATEGY.scope(): + model = keras.models.load_model(model_filename) + print("Symbolic Model Created.") + print("Predicting using file: {}".format(model_filename)) prediction = model.predict(test, verbose=1) return prediction - diff --git a/tests/context.py b/tests/context.py new file mode 100644 index 0000000..b1b7ba4 --- /dev/null +++ b/tests/context.py @@ -0,0 +1,19 @@ +""" +This module is used in test_s2_tiles_supres script. 
+""" +import sys +import os + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../utils/")) +) +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../testing/")) +) + +# pylint: disable=unused-import,wrong-import-position +from data_utils import DATA_UTILS, get_logger +from s2_tiles_supres import Superresolution +import patches +import DSen2Net diff --git a/tests/mock_data/data_10.npy b/tests/mock_data/data_10.npy new file mode 100644 index 0000000..c8c4057 Binary files /dev/null and b/tests/mock_data/data_10.npy differ diff --git a/tests/mock_data/data_20.npy b/tests/mock_data/data_20.npy new file mode 100644 index 0000000..48f66b9 Binary files /dev/null and b/tests/mock_data/data_20.npy differ diff --git a/tests/mock_data/data_60.npy b/tests/mock_data/data_60.npy new file mode 100644 index 0000000..d7b6712 Binary files /dev/null and b/tests/mock_data/data_60.npy differ diff --git a/tests/mock_data/test_10m.tif b/tests/mock_data/test_10m.tif new file mode 100644 index 0000000..508479e Binary files /dev/null and b/tests/mock_data/test_10m.tif differ diff --git a/tests/mock_data/test_20m.tif b/tests/mock_data/test_20m.tif new file mode 100644 index 0000000..b2bbd08 Binary files /dev/null and b/tests/mock_data/test_20m.tif differ diff --git a/tests/mock_data/test_60m.tif b/tests/mock_data/test_60m.tif new file mode 100644 index 0000000..48153f6 Binary files /dev/null and b/tests/mock_data/test_60m.tif differ diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py new file mode 100644 index 0000000..cd14586 --- /dev/null +++ b/tests/test_data_utils.py @@ -0,0 +1,273 @@ +""" +This module include multiple test cases to check the performance of the s2_tiles_supres script. 
+""" +import os +from pathlib import Path +import tempfile +import logging + +import numpy as np +import rasterio +from rasterio.transform import from_origin +from blockutils.syntheticimage import SyntheticImage + +from context import DATA_UTILS, get_logger + +LOGGER = get_logger(__name__) + + +def test_get_max_min(): + """ + This method checks the get_min_max method. + """ + dsr_xmin_exm, dsr_ymin_exm, dsr_xmax_exm, dsr_ymax_exm, dsr_area_exm = ( + 0, + 0, + 5, + 5, + 36, + ) + + test_dir = Path(tempfile.mkdtemp()) + valid_desc = [ + "B4, central wavelength 665 nm", + "B3, central wavelength 560 nm", + "B2, central wavelength 490 nm", + "B8, central wavelength 842 nm", + ] + transform = from_origin(1470996, 6914001, 10.0, 10.0) + + test_img, _ = SyntheticImage(40, 40, 4, "uint16", test_dir, 32640).create( + seed=45, transform=transform, band_desc=valid_desc + ) + # dsr = rasterio.open(test_img) + dsr_xmin, dsr_ymin, dsr_xmax, dsr_ymax, dsr_area = DATA_UTILS.get_max_min( + 0, 0, 10, 10, test_img + ) + + assert dsr_xmin == dsr_xmin_exm + assert dsr_ymin == dsr_ymin_exm + assert dsr_xmax == dsr_xmax_exm + assert dsr_ymax == dsr_ymax_exm + assert dsr_area == dsr_area_exm + + +def test_to_xy(): + """ + This method checks to_xy method. + """ + dsr_x_exm = -575834 + dsr_y_exm = 66564 + test_dir = Path(tempfile.mkdtemp()) + valid_desc = [ + "B4, central wavelength 665 nm", + "B3, central wavelength 560 nm", + "B2, central wavelength 490 nm", + "B8, central wavelength 842 nm", + ] + transform = from_origin(1470996, 6914001, 10.0, 10.0) + + test_img, _ = SyntheticImage(20, 18, 4, "uint16", test_dir, 32640).create( + seed=45, transform=transform, band_desc=valid_desc + ) + # dsr = rasterio.open(test_img) + data_file_path_test = "/test/a.SAFE" + dsr_x, dsr_y = DATA_UTILS(data_file_path_test).to_xy(lon=1, lat=40, data=test_img) + + assert dsr_x == dsr_x_exm + assert dsr_y == dsr_y_exm + + +def test_get_utm(): + """ + This method check the get_utm methods. 
+ """ + utm_exm = "epsg:32640" + test_dir = Path(tempfile.mkdtemp()) + valid_desc = [ + "B4, central wavelength 665 nm", + "B3, central wavelength 560 nm", + "B2, central wavelength 490 nm", + "B8, central wavelength 842 nm", + ] + transform = from_origin(1470996, 6914001, 10.0, 10.0) + + test_img, _ = SyntheticImage(40, 40, 4, "uint16", test_dir, 32640).create( + seed=45, transform=transform, band_desc=valid_desc + ) + + dsr_utm = DATA_UTILS.get_utm(test_img) + + assert dsr_utm == utm_exm + + +# pylint: disable-msg=too-many-locals +def test_area_of_interest(): + """ + this method checks the area_of_interest methods. + """ + dsr_xmin_exm, dsr_ymin_exm, dsr_xmax_exm, dsr_ymax_exm, dsr_area_exm = ( + 0, + 18, + 17, + 35, + 324, + ) + + test_dir = Path(tempfile.mkdtemp()) + valid_desc = [ + "B4, central wavelength 665 nm", + "B3, central wavelength 560 nm", + "B2, central wavelength 490 nm", + "B8, central wavelength 842 nm", + ] + transform = from_origin(1470996, 6914001, 10.0, 10.0) + + test_img, _ = SyntheticImage(40, 40, 4, "uint16", test_dir, 32640).create( + seed=45, transform=transform, band_desc=valid_desc + ) + + data_file_path_test = "/test/a.SAFE" + clip_to_aoi_test = "75.192123,61.127161,75.195960,61.127993" + dsr_xmin, dsr_ymin, dsr_xmax, dsr_ymax, dsr_area = DATA_UTILS( + data_file_path_test + ).area_of_interest(test_img, clip_to_aoi_test) + print(dsr_xmin, dsr_ymin, dsr_xmax, dsr_ymax, dsr_area) + assert dsr_xmin == dsr_xmin_exm + assert dsr_ymin == dsr_ymin_exm + assert dsr_xmax == dsr_xmax_exm + assert dsr_ymax == dsr_ymax_exm + assert dsr_area == dsr_area_exm + + +def test_validate_description(): + """ + this method checks the validate_description methods. 
+ """ + valid_desc_exm = ["B4 (665 nm)", "B3 (560 nm)", "B2 (490 nm)", "B8 (842 nm)"] + + test_dir = Path(tempfile.mkdtemp()) + valid_desc = [ + "B4, central wavelength 665 nm", + "B3, central wavelength 560 nm", + "B2, central wavelength 490 nm", + "B8, central wavelength 842 nm", + ] + transform = from_origin(1470996, 6914001, 10.0, 10.0) + + test_img, _ = SyntheticImage(20, 18, 4, "uint16", test_dir).create( + seed=45, transform=transform, band_desc=valid_desc + ) + + dsr = rasterio.open(test_img) + valid_desc = [] + print(dsr.count) + for i in range(dsr.count): + valid_desc.append(DATA_UTILS.validate_description(dsr.descriptions[i])) + + assert set(valid_desc) == set(valid_desc_exm) + + +def test_get_band_short_name(): + """ + This method checks the functionality of get_short_name methods. + """ + short_desc_exm = ["B4", "B3", "B2", "B8"] + + test_dir = Path(tempfile.mkdtemp()) + valid_desc = [ + "B4, central wavelength 665 nm", + "B3, central wavelength 560 nm", + "B2, central wavelength 490 nm", + "B8, central wavelength 842 nm", + ] + transform = from_origin(1470996, 6914001, 10.0, 10.0) + + test_img, _ = SyntheticImage(20, 18, 4, "uint16", test_dir).create( + seed=45, transform=transform, band_desc=valid_desc + ) + + dsr = rasterio.open(test_img) + short_desc = [] + + for i in range(dsr.count): + desc = DATA_UTILS.validate_description(dsr.descriptions[i]) + short_desc.append(DATA_UTILS.get_band_short_name(desc)) + + assert set(short_desc) == set(short_desc_exm) + + +# pylint: disable-msg=too-many-locals +def test_validate(): + """ + This method check whether validate function defined in the s2_tiles_supres + file produce the correct results. 
+ """ + test_dir = Path(tempfile.mkdtemp()) + valid_desc_10 = [ + "B4, central wavelength 665 nm", + "B3, central wavelength 560 nm", + "B2, central wavelength 490 nm", + "B8, central wavelength 842 nm", + ] + + transform = from_origin(1470996, 6914001, 10.0, 10.0) + + test_img_10, _ = SyntheticImage(20, 18, 4, "uint16", test_dir).create( + seed=45, transform=transform, band_desc=valid_desc_10 + ) + + validated_10m_indices_exm = [0, 1, 2, 3] + validated_10m_bands_exm = ["B2", "B3", "B4", "B8"] + data_file_path_test = "/test/a.SAFE" + validated_10m_bands, validated_10m_indices, _ = DATA_UTILS( + data_file_path_test + ).validate(data=test_img_10) + + assert set(validated_10m_bands) == set(validated_10m_bands_exm) + assert validated_10m_indices == validated_10m_indices_exm + + +def test_data_final(): + """ + This method checks the functionality of data_final method. + """ + test_dir = Path(tempfile.mkdtemp()) + valid_desc = [ + "B4, central wavelength 665 nm", + "B3, central wavelength 560 nm", + "B2, central wavelength 490 nm", + "B8, central wavelength 842 nm", + ] + valid_indices = [0, 1, 2, 3] + + transform = from_origin(1470996, 6914001, 10.0, 10.0) + + test_img, _ = SyntheticImage(20, 18, 4, "uint16", test_dir).create( + seed=45, transform=transform, band_desc=valid_desc + ) + + d_final = DATA_UTILS.data_final(test_img, valid_indices, 0, 0, 5, 5, 1, 1) + assert d_final.shape == (6, 6, 4) + + +def test_save_band(): + save_prefix = "/tmp/" + name = "a" + + array = np.zeros([100, 200, 3], dtype=np.uint8) + array[:, :100] = [255, 128, 0] # Orange left side + array[:, 100:] = [0, 0, 255] # Blue right side + + DATA_UTILS.save_band(save_prefix, array, name) + assert os.path.isfile(save_prefix + name + ".png") + + +def test_check_size(caplog): + dim_exm = 180, 180, 199, 199 + with caplog.at_level(logging.DEBUG): + DATA_UTILS.check_size(dim_exm) + assert ( + "AOI too small. 
Try again with a larger AOI (minimum pixel width or heigh of 192)" + in caplog.text + ) diff --git a/tests/test_dsen2.py b/tests/test_dsen2.py new file mode 100644 index 0000000..c3387fb --- /dev/null +++ b/tests/test_dsen2.py @@ -0,0 +1,75 @@ +import pytest + +from keras.models import Model + +from context import DSen2Net + + +@pytest.fixture +def input_shape_20(): + return ((4, None, None), (6, None, None)) + + +@pytest.fixture +def input_shape_60(): + return ((4, None, None), (6, None, None), (2, None, None)) + + +def test_s2model_20(input_shape_20): + m = DSen2Net.s2model(input_shape_20) + assert m.layers[-1].output_shape[1:] == input_shape_20[-1] + assert isinstance(m, Model) + + +def test_s2model_60(input_shape_60): + m = DSen2Net.s2model(input_shape_60) + assert m.layers[-1].output_shape[1:] == input_shape_60[-1] + assert isinstance(m, Model) + + +def test_aesrmodel_20(input_shape_20): + m = DSen2Net.aesrmodel(input_shape_20) + assert m.layers[-1].output_shape[1:] == input_shape_20[-1] + assert isinstance(m, Model) + + +def test_aesrmodel_60(input_shape_60): + m = DSen2Net.aesrmodel(input_shape_60) + assert m.layers[-1].output_shape[1:] == input_shape_60[-1] + assert isinstance(m, Model) + + +def test_srcnn_20(input_shape_20): + m = DSen2Net.srcnn(input_shape_20) + assert m.layers[-1].output_shape[1:] == input_shape_20[-1] + assert isinstance(m, Model) + + +def test_srcnn_60(input_shape_60): + m = DSen2Net.srcnn(input_shape_60) + assert m.layers[-1].output_shape[1:] == input_shape_60[-1] + assert isinstance(m, Model) + + +def test_rednetsr_20(input_shape_20): + m = DSen2Net.rednetsr(input_shape_20) + assert m.layers[-1].output_shape[1:] == input_shape_20[-1] + assert isinstance(m, Model) + + +def test_rednetsr_60(input_shape_60): + m = DSen2Net.rednetsr(input_shape_60) + assert m.layers[-1].output_shape[1:] == input_shape_60[-1] + assert isinstance(m, Model) + + +def test_resnetsr_20(input_shape_20): + m = DSen2Net.resnetsr(input_shape_20) + assert 
m.layers[-1].output_shape[1:] == input_shape_20[-1] + assert isinstance(m, Model) + + +def test_resnetsr_60(input_shape_60): + m = DSen2Net.resnetsr(input_shape_60) + assert m.layers[-1].output_shape[1:] == input_shape_60[-1] + assert isinstance(m, Model) diff --git a/tests/test_patches.py b/tests/test_patches.py new file mode 100644 index 0000000..7fa4230 --- /dev/null +++ b/tests/test_patches.py @@ -0,0 +1,147 @@ +from pathlib import Path +import tempfile +import pytest + +import numpy as np + +from context import patches + + +@pytest.fixture() +def dset_10(): + return np.ones((10980, 10980, 4)) + + +@pytest.fixture() +def dset_20(): + return np.ones((5490, 5490, 5)) + + +@pytest.fixture() +def dset_60(): + return np.ones((1830, 1830, 3)) + + +@pytest.fixture() +def scale_20(): + return 2 + + +@pytest.fixture() +def scale_60(): + return 6 + + +def test_get_test_patches(dset_10, dset_20): + r = patches.get_test_patches(dset_10, dset_20, 128, 8) + assert len(r) == 2 + assert r[0].shape == (9801, 4, 128, 128) + assert r[1].shape == (9801, 5, 128, 128) + + +def test_get_test_patches60(dset_10, dset_20, dset_60): + r = patches.get_test_patches60(dset_10, dset_20, dset_60, 192, 12) + assert len(r) == 3 + assert r[0].shape == (4356, 4, 192, 192) + assert r[1].shape == (4356, 5, 192, 192) + assert r[2].shape == (4356, 3, 192, 192) + + +@pytest.mark.skip(reason="too long test") +def save_test_patches(dset_10, dset_20): + with tempfile.TemporaryDirectory() as tmpdir: + patches.save_test_patches(dset_10, dset_20, tmpdir + "/", 128, 8) + f = Path(tmpdir).glob("*/**") + assert len(list(f)) == 2 + + +@pytest.mark.skip(reason="too long test") +def save_test_patches60(dset_10, dset_20, dset_60): + with tempfile.TemporaryDirectory() as tmpdir: + patches.save_test_patches60(dset_10, dset_20, dset_60, tmpdir + "/", 192, 12) + f = Path(tmpdir).glob("*/**") + assert len(list(f)) == 3 + + +def test_get_random_patches(dset_10, dset_20, scale_20): + data10_lr = 
patches.downPixelAggr(dset_10, SCALE=scale_20) + data20_lr = patches.downPixelAggr(dset_20, SCALE=scale_20) + r = patches.get_random_patches(dset_20, data10_lr, data20_lr, 8000) + assert len(r) == 3 + assert r[0].shape == (8000, 4, 32, 32) + assert r[1].shape == (8000, 5, 32, 32) + assert r[2].shape == (8000, 5, 32, 32) + + +def test_get_random_patches60(dset_10, dset_20, dset_60, scale_60): + data10_lr = patches.downPixelAggr(dset_10, SCALE=scale_60) + data20_lr = patches.downPixelAggr(dset_20, SCALE=scale_60) + data60_lr = patches.downPixelAggr(dset_60, SCALE=scale_60) + r = patches.get_random_patches60(dset_60, data10_lr, data20_lr, data60_lr, 8000) + assert len(r) == 4 + assert r[0].shape == (8000, 4, 96, 96) + assert r[1].shape == (8000, 3, 96, 96) + assert r[2].shape == (8000, 5, 96, 96) + assert r[3].shape == (8000, 3, 96, 96) + + +def test_get_crop_window(): + w = patches.get_crop_window(100, 50, 25) + assert w == [100, 50, 125, 75] + w = patches.get_crop_window(100, 50, 25, 2) + assert w == [200, 100, 250, 150] + + +def test_crop_array_to_window(): + ar = np.ones(shape=(100, 100, 4)) + w = patches.get_crop_window(50, 50, 25) + assert patches.crop_array_to_window(ar, w).shape == (4, 25, 25) + assert patches.crop_array_to_window(ar, w, False).shape == (25, 25, 4) + + +@pytest.mark.skip(reason="too long test") +def test_save_random_patches(dset_10, dset_20, scale_20): + with tempfile.TemporaryDirectory() as tmpdir: + patches.save_random_patches( + dset_20, + patches.downPixelAggr(dset_10, scale_20), + patches.downPixelAggr(dset_20, scale_20), + tmpdir + "/", + ) + f = Path(tmpdir).glob("*/**") + assert len(list(f)) == 2 + + +@pytest.mark.skip(reason="too long test") +def test_save_random_patches60(dset_10, dset_20, dset_60, scale_60): + with tempfile.TemporaryDirectory() as tmpdir: + patches.save_random_patches60( + dset_60, + patches.downPixelAggr(dset_10, scale_60), + patches.downPixelAggr(dset_20, scale_60), + patches.downPixelAggr(dset_60, scale_60), + 
tmpdir + "/", + ) + f = Path(tmpdir).glob("*/**") + assert len(list(f)) == 3 + + +def test_downPixelAggr(dset_10, scale_20): + r = patches.downPixelAggr(dset_10, scale_20) + assert r.shape == (5490, 5490, 4) + + dset_20_w = np.ones((5489, 5489, 6)) + r = patches.downPixelAggr(dset_20_w, scale_20) + assert r.shape == (2744, 2744, 6) + + +def test_recompose_images(dset_10, dset_20): + p = patches.get_test_patches(dset_10, dset_20, 128, 8) + r_p = patches.recompose_images(p[0], 8, dset_10.shape) + assert dset_10.shape == r_p.shape + + +@pytest.mark.skip(reason="too long test") +def test_interp_patches(dset_20, dset_10): + r = patches.interp_patches(dset_20, dset_10.shape) + assert r diff --git a/tests/test_s2_tiles_superres.py b/tests/test_s2_tiles_superres.py new file mode 100644 index 0000000..1c46dd8 --- /dev/null +++ b/tests/test_s2_tiles_superres.py @@ -0,0 +1,125 @@ +import os +import numpy as np + +import pytest +import rasterio +from rasterio import Affine +from rasterio.crs import CRS + +import tensorflow as tf + +from blockutils.common import ensure_data_directories_exist +from context import Superresolution + + +# pylint: disable=redefined-outer-name +@pytest.fixture(scope="session") +def fixture_superresolution_clip(): + ensure_data_directories_exist() + return Superresolution( + "a.SAFE", "50.550671,26.15174,50.596161,26.19195", True, "/tmp/output" + ) + + +def test_start(fixture_superresolution_clip, monkeypatch): + _location_ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + data_10m = os.path.join(_location_, "mock_data/test_10m.tif") + data_20m = os.path.join(_location_, "mock_data/test_20m.tif") + data_60m = os.path.join(_location_, "mock_data/test_60m.tif") + expected_final_dset = [data_10m, data_20m, data_60m] + + def _mock_getdata(self): + return expected_final_dset + + monkeypatch.setattr(Superresolution, "get_data", _mock_getdata) + ( + data10, + data20, + data60, + [xmin, ymin, xmax, ymax], + pr, + ) = 
fixture_superresolution_clip.start() + assert data10.shape == (444, 456, 4) + assert data20.shape == (222, 228, 6) + assert data60.shape == (74, 76, 2) + assert [xmin, ymin, xmax, ymax] == [48, 174, 503, 617] + assert pr == { + "driver": "GTiff", + "dtype": "uint16", + "nodata": None, + "width": 1584, + "height": 1762, + "count": 4, + "crs": CRS.from_epsg(32639), + "transform": Affine(10.0, 0.0, 454590.0, 0.0, -10.0, 2898770.0), + "blockxsize": 128, + "blockysize": 128, + "tiled": True, + "interleave": "pixel", + } + + +@pytest.mark.skipif( + len(tf.config.list_physical_devices("GPU")) == 0, + reason="Conv2D op requires GPU for channels first configuration.", +) +def test_inference(fixture_superresolution_clip): + _location_ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__))) + fixture_superresolution_clip.validated_10m_bands = ["B4", "B3", "B2", "B8"] + fixture_superresolution_clip.validated_20m_bands = [ + "B5", + "B6", + "B7", + "B8A", + "B11", + "B12", + ] + fixture_superresolution_clip.validated_60m_bands = ["B1", "B9"] + fixture_superresolution_clip.validated_descriptions_all = { + "B4": "B4 (665 nm)", + "B3": "B3 (560 nm)", + "B2": "B2 (490 nm)", + "B8": "B8 (842 nm)", + "B5": "B5 (705 nm)", + "B6": "B6 (740 nm)", + "B7": "B7 (783 nm)", + "B8A": "B8A (865 nm)", + "B11": "B11 (1610 nm)", + "B12": "B12 (2190 nm)", + "B1": "B1 (443 nm)", + "B9": "B9 (945 nm)", + } + fixture_superresolution_clip.data_name = ( + "S2B_MSIL1C_20200708T070629_N0209_R106_T39RVJ_20200708T100303.SAFE" + ) + + data10 = np.load(os.path.join(_location_, "mock_data/data_10.npy")) + data20 = np.load(os.path.join(_location_, "mock_data/data_20.npy")) + data60 = np.load(os.path.join(_location_, "mock_data/data_60.npy")) + + coord = [48, 174, 503, 617] + pr = { + "driver": "GTiff", + "dtype": "uint16", + "nodata": None, + "width": 1584, + "height": 1762, + "count": 4, + "crs": CRS.from_epsg(32639), + "transform": Affine(10.0, 0.0, 454590.0, 0.0, -10.0, 2898770.0), + 
"blockxsize": 128, + "blockysize": 128, + "tiled": True, + "interleave": "pixel", + } + + fixture_superresolution_clip.inference(data10, data20, data60, coord, pr) + result_path = os.path.join( + "/tmp/output", + fixture_superresolution_clip.data_name.split(".")[0] + "_superresolution.tif", + ) + assert os.path.isfile(result_path) + with rasterio.open(result_path) as src: + assert src.count == 12 + assert src.transform == Affine(10.0, 0.0, 455070.0, 0.0, -10.0, 2897030.0) + assert src.profile["driver"] == "GTiff" diff --git a/training/create_patches.py b/training/create_patches.py index b21c02f..31c89a9 100644 --- a/training/create_patches.py +++ b/training/create_patches.py @@ -1,354 +1,353 @@ from __future__ import division -import argparse -import numpy as np -from osgeo import gdal -import sys -from collections import defaultdict -import re import os -import imageio +import sys +import argparse import json -sys.path.append('../') -from utils.patches import downPixelAggr, save_test_patches, save_random_patches, save_random_patches60, save_test_patches60 - - -data_filename = '/MTD_MSIL1C.xml' - -# sleep(randint(0, 20)) - -def readS2fromFile(data_file, - test_data=False, - roi_x_y=None, - save_prefix="../data/", - write_images=False, - run_60=False, - true_data=False): - - if run_60: - select_bands = "B1,B2,B3,B4,B5,B6,B7,B8,B8A,B9,B11,B12" - else: - select_bands = "B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12" - - raster = gdal.Open(data_file + data_filename) - - datasets = raster.GetSubDatasets() - tenMsets = [] - twentyMsets = [] - sixtyMsets = [] - unknownMsets = [] - for (dsname, dsdesc) in datasets: - if '10m resolution' in dsdesc: - tenMsets += [ (dsname, dsdesc) ] - elif '20m resolution' in dsdesc: - twentyMsets += [ (dsname, dsdesc) ] - elif '60m resolution' in dsdesc: - sixtyMsets += [ (dsname, dsdesc) ] - else: - unknownMsets += [ (dsname, dsdesc) ] - - if roi_x_y: - roi_x1, roi_y1, roi_x2, roi_y2 = [float(x) for x in re.split(',', args.roi_x_y)] - - # case where 
we have several UTM in the data set - # => select the one with maximal coverage of the study zone - utm_idx = 0 - utm = "" - all_utms = defaultdict(int) - xmin, ymin, xmax, ymax = 0, 0, 0, 0 - largest_area = -1 - # process even if there is only one 10m set, in order to get roi -> pixels - for (tmidx, (dsname, dsdesc)) in enumerate(tenMsets + unknownMsets): - ds = gdal.Open(dsname) - if roi_x_y: - tmxmin = max(min(roi_x1, roi_x2, ds.RasterXSize - 1), 0) - tmxmax = min(max(roi_x1, roi_x2, 0), ds.RasterXSize - 1) - tmymin = max(min(roi_y1, roi_y2, ds.RasterYSize - 1), 0) - tmymax = min(max(roi_y1, roi_y2, 0), ds.RasterYSize - 1) - # enlarge to the nearest 60 pixel boundary for the super-resolution - tmxmin = int(tmxmin / 36) * 36 - tmxmax = int((tmxmax + 1) / 36) * 36 - 1 - tmymin = int(tmymin / 36) * 36 - tmymax = int((tmymax + 1) / 36) * 36 - 1 - else: - tmxmin = 0 - tmxmax = ds.RasterXSize - 1 - tmymin = 0 - tmymax = ds.RasterYSize - 1 - - area = (tmxmax - tmxmin + 1) * (tmymax - tmymin + 1) - current_utm = dsdesc[dsdesc.find("UTM"):] - if area > all_utms[current_utm]: - all_utms[current_utm] = area - if area > largest_area: - xmin, ymin, xmax, ymax = tmxmin, tmymin, tmxmax, tmymax - largest_area = area - utm_idx = tmidx - utm = dsdesc[dsdesc.find("UTM"):] - - # convert comma separated band list into a list - select_bands = [x for x in re.split(',',select_bands) ] - - print("Selected UTM Zone:".format(utm)) - print("Selected pixel region: xmin=%d, ymin=%d, xmax=%d, ymax=%d:" % (xmin, ymin, xmax, ymax)) - print("Selected pixel region: tmxmin=%d, tmymin=%d, tmxmax=%d, tmymax=%d:" % (tmxmin, tmymin, tmxmax, tmymax)) - print("Image size: width=%d x height=%d" % (xmax - xmin + 1, ymax - ymin + 1)) - - if xmax < xmin or ymax < ymin: - print("Invalid region of interest / UTM Zone combination") - sys.exit(0) - - selected_10m_data_set = None - if not tenMsets: - selected_10m_data_set = unknownMsets[0] - else: - selected_10m_data_set = tenMsets[utm_idx] - 
selected_20m_data_set = None - for (dsname, dsdesc) in enumerate(twentyMsets): - if utm in dsdesc: - selected_20m_data_set = (dsname, dsdesc) - # if not found, assume the listing is in the same order - # => OK if only one set - if not selected_20m_data_set: selected_20m_data_set = twentyMsets[utm_idx] - selected_60m_data_set = None - for (dsname, dsdesc) in enumerate(sixtyMsets): - if utm in dsdesc: - selected_60m_data_set = (dsname, dsdesc) - if not selected_60m_data_set: selected_60m_data_set = sixtyMsets[utm_idx] - - ds10 = gdal.Open(selected_10m_data_set[0]) - ds20 = gdal.Open(selected_20m_data_set[0]) - ds60 = gdal.Open(selected_60m_data_set[0]) - - def validate_description(description): - m = re.match("(.*?), central wavelength (\d+) nm", description) - if m: - return m.group(1) + " (" + m.group(2) + " nm)" - # Some HDR restrictions... ENVI band names should not include commas - - pos = description.find(',') - return description[:pos] + description[(pos + 1):] - - def get_band_short_name(description): - if ',' in description: - return description[:description.find(',')] - if ' ' in description: - return description[:description.find(' ')] - return description[:3] - - validated_10m_bands = [] - validated_10m_indices = [] - validated_20m_bands = [] - validated_20m_indices = [] - validated_60m_bands = [] - validated_60m_indices = [] - validated_descriptions = defaultdict(str) - - sys.stdout.write("Selected 10m bands:") - for b in range(0, ds10.RasterCount): - desc = validate_description(ds10.GetRasterBand(b + 1).GetDescription()) - shortname = get_band_short_name(desc) - if shortname in select_bands: - sys.stdout.write(" " + shortname) - select_bands.remove(shortname) - validated_10m_bands += [shortname] - validated_10m_indices += [b] - validated_descriptions[shortname] = desc - sys.stdout.write("\nSelected 20m bands:") - for b in range(0, ds20.RasterCount): - desc = validate_description(ds20.GetRasterBand(b + 1).GetDescription()) - shortname = 
get_band_short_name(desc) - if shortname in select_bands: - sys.stdout.write(" " + shortname) - select_bands.remove(shortname) - validated_20m_bands += [shortname] - validated_20m_indices += [b] - validated_descriptions[shortname] = desc - sys.stdout.write("\nSelected 60m bands:") - for b in range(0, ds60.RasterCount): - desc = validate_description(ds60.GetRasterBand(b + 1).GetDescription()) - shortname = get_band_short_name(desc) - if shortname in select_bands: - sys.stdout.write(" " + shortname) - select_bands.remove(shortname) - validated_60m_bands += [shortname] - validated_60m_indices += [b] - validated_descriptions[shortname] = desc - sys.stdout.write("\n") - - if validated_10m_indices: - print("Loading selected data from: %s" % selected_10m_data_set[1]) - data10 = np.rollaxis( - ds10.ReadAsArray(xoff=xmin, yoff=ymin, xsize=xmax - xmin + 1, ysize=ymax - ymin + 1, buf_xsize=xmax - xmin + 1, - buf_ysize=ymax - ymin + 1), 0, 3)[:, :, validated_10m_indices] - - if validated_20m_indices: - print("Loading selected data from: %s" % selected_20m_data_set[1]) - data20 = np.rollaxis( - ds20.ReadAsArray(xoff=xmin // 2, yoff=ymin // 2, xsize=(xmax - xmin + 1) // 2, ysize=(ymax - ymin + 1) // 2, - buf_xsize=(xmax - xmin + 1) // 2, buf_ysize=(ymax - ymin + 1) // 2), 0, 3)[:, :, - validated_20m_indices] - - if validated_60m_indices: - print("Loading selected data from: %s" % selected_60m_data_set[1]) - data60 = np.rollaxis( - ds60.ReadAsArray(xoff=xmin // 6, yoff=ymin // 6, xsize=(xmax - xmin + 1) // 6, ysize=(ymax - ymin + 1) // 6, - buf_xsize=(xmax - xmin + 1) // 6, buf_ysize=(ymax - ymin + 1) // 6), 0, 3)[:, :, - validated_60m_indices] - - # The percentile_data argument is used to plot superresolved and original data - # with a comparable black/white scale - def save_band(data, name, percentile_data=None): - if percentile_data is None: - percentile_data = data - mi, ma = np.percentile(percentile_data, (1, 99)) - band_data = np.clip(data, mi, ma) - band_data = (band_data 
- mi) / (ma - mi) - imageio.imsave(save_prefix + name + ".png", band_data) # img_as_uint(band_data)) - - chan3 = data10[:, :, 0] - vis = (chan3 < 1).astype(np.int) - if np.sum(vis) > 0: - print('The selected image has some blank pixels') - # sys.exit() - - scale20 = 2 - scale60 = 6 - - data10_gt = data10 - data20_gt = data20 - - if not true_data: - if run_60: - data60_gt = data60 - data10_lr = downPixelAggr(data10_gt, SCALE=scale60) - data20_lr = downPixelAggr(data20_gt, SCALE=scale60) - data60_lr = downPixelAggr(data60_gt, SCALE=scale60) - else: - data10_lr = downPixelAggr(data10_gt, SCALE=scale20) - data20_lr = downPixelAggr(data20_gt, SCALE=scale20) - if scale20 > 2: - data20_lr = downPixelAggr(data20_gt, SCALE=scale20//2) - - if data_file.endswith('/'): - tmp = os.path.split(data_file)[0] - data_file = os.path.split(tmp)[1] - else: - data_file = os.path.split(data_file)[1] - print(data_file) - - if test_data: - if run_60: - out_per_image0 = save_prefix + 'test60/' - out_per_image = save_prefix + 'test60/' + data_file + '/' - else: - out_per_image0 = save_prefix + 'test/' - out_per_image = save_prefix + 'test/' + data_file + '/' - if not os.path.isdir(out_per_image0): - os.mkdir(out_per_image0) - if not os.path.isdir(out_per_image): - os.mkdir(out_per_image) - - print('Writing files for testing to:{}'.format(out_per_image)) - if run_60: - save_test_patches60(data10_lr, data20_lr, data60_lr, out_per_image) - with open(out_per_image + 'roi.json', 'w') as f: - json.dump([tmxmin // scale60, tmymin // scale60, (tmxmax + 1) // scale60, (tmymax + 1) // scale60], f) - else: - save_test_patches(data10_lr, data20_lr, out_per_image) - with open(out_per_image + 'roi.json', 'w') as f: - json.dump([tmxmin // scale20, tmymin // scale20, (tmxmax+1) // scale20, (tmymax+1) // scale20], f) - if not os.path.isdir(out_per_image + 'no_tiling/'): - os.mkdir(out_per_image + 'no_tiling/') +from typing import Tuple +import numpy as np - print("Now saving the whole image without 
tiling...") - if run_60: - np.save(out_per_image + 'no_tiling/' + 'data60_gt', data60_gt.astype(np.float32)) - np.save(out_per_image + 'no_tiling/' + 'data60', data60_lr.astype(np.float32)) +sys.path.append("..") + +from utils.data_utils import DATA_UTILS, get_logger +from utils.patches import ( + downPixelAggr, + save_test_patches, + save_random_patches, + save_random_patches60, + save_test_patches60, +) + +LOGGER = get_logger(__name__) + + +def parser_common(parser): + parser.add_argument( + "--test_data", + default=False, + action="store_true", + help="Store test patches in a separate dir.", + ) + parser.add_argument( + "--rgb_images", + default=False, + action="store_true", + help=( + "If set, write PNG images for the original and the superresolved bands," + " together with a composite rgb image (first three 10m bands), all with a " + "quick and dirty clipping to 99%% of the original bands dynamic range and " + "a quantization of the values to 256 levels." + ), + ) + parser.add_argument( + "--save_prefix", + default="../data/", + help=( + "If set, speficies the name of a prefix for all output files. " + "Use a trailing / to save into a directory. The default of no prefix will " + "save into the current directory. Example: --save_prefix result/" + ), + ) + parser.add_argument( + "--run_60", + default=False, + action="store_true", + help="If set, it will create patches also from the 60m channels.", + ) + parser.add_argument( + "--true_data", + default=False, + action="store_true", + help=( + "If set, it will create patches for S2 without GT. This option is not " + "really useful here, please check the testing folder for predicting S2 images." 
+ ), + ) + parser.add_argument( + "--train_data", + default=False, + action="store_true", + help="Store train patches in a separate dir", + ) + return parser + + +# pylint: disable=unbalanced-tuple-unpacking +class readS2fromFile(DATA_UTILS): + def __init__( + self, + data_file_path, + clip_to_aoi=None, + save_prefix="../data/", + rgb_images=False, + run_60=False, + true_data=False, + test_data=False, + train_data=False, + ): + self.data_file_path = data_file_path + self.test_data = test_data + self.clip_to_aoi = clip_to_aoi + self.save_prefix = save_prefix + self.rgb_images = rgb_images + self.run_60 = run_60 + self.true_data = true_data + self.train_data = train_data + self.data_name = os.path.basename(data_file_path) + + super().__init__(data_file_path) + + def get_original_image(self) -> Tuple: + + data_list = self.get_data() + for dsdesc in data_list: + if "10m" in dsdesc: + xmin, ymin, xmax, ymax, interest_area = self.area_of_interest( + dsdesc, self.clip_to_aoi + ) + LOGGER.info("Selected pixel region:") + LOGGER.info("xmin = %s", xmin) + LOGGER.info("ymin = %s", ymin) + LOGGER.info("xmax = %s", xmax) + LOGGER.info("ymax = %s", ymax) + LOGGER.info("The area of selected region = %s", interest_area) + self.check_size(dims=(xmin, ymin, xmax, ymax)) + + for dsdesc in data_list: + if "10m" in dsdesc: + LOGGER.info("Selected 10m bands:") + _, validated_10m_indices, _ = self.validate(dsdesc) + data10 = self.data_final( + dsdesc, validated_10m_indices, xmin, ymin, xmax, ymax, 1, 1 + ) + if "20m" in dsdesc: + LOGGER.info("Selected 20m bands:") + _, validated_20m_indices, _ = self.validate(dsdesc) + data20 = self.data_final( + dsdesc, validated_20m_indices, xmin, ymin, xmax, ymax, 1, 2 + ) + if "60m" in dsdesc: + LOGGER.info("Selected 60m bands:") + _, validated_60m_indices, _ = self.validate(dsdesc) + data60 = self.data_final( + dsdesc, validated_60m_indices, xmin, ymin, xmax, ymax, 1, 6 + ) + + return data10, data20, data60, xmin, ymin, xmax, ymax + + def 
get_downsampled_images(self, data10, data20, data60) -> Tuple: + if self.run_60: + data10_lr = downPixelAggr(data10, SCALE=6) + data20_lr = downPixelAggr(data20, SCALE=6) + data60_lr = downPixelAggr(data60, SCALE=6) + return data10_lr, data20_lr, data60_lr else: - np.save(out_per_image + 'no_tiling/' + 'data20_gt', data20_gt.astype(np.float32)) - save_band(data10_lr[:, :, 0:3], '/test/' + data_file + '/RGB') - np.save(out_per_image + 'no_tiling/' + 'data10', data10_lr.astype(np.float32)) - np.save(out_per_image + 'no_tiling/' + 'data20', data20_lr.astype(np.float32)) - - elif write_images: - print('Creating RGB images...') - save_band(data10_lr[:, :, 0:3], '/raw/rgbs/' + data_file + 'RGB') - save_band(data20_lr[:, :, 0:3], '/raw/rgbs/' + data_file + 'RGB20') - - elif true_data: - out_per_image0 = save_prefix + 'true/' - out_per_image = save_prefix + 'true/' + data_file + '/' - if not os.path.isdir(out_per_image0): - os.mkdir(out_per_image0) - if not os.path.isdir(out_per_image): - os.mkdir(out_per_image) - - print('Writing files for testing to:{}'.format(out_per_image)) - save_test_patches60(data10_gt, data20_gt, data60_gt, out_per_image, patchSize=384, border=12) + data10_lr = downPixelAggr(data10, SCALE=2) + data20_lr = downPixelAggr(data20, SCALE=2) - with open(out_per_image + 'roi.json', 'w') as f: - json.dump([tmxmin, tmymin, tmxmax+1, tmymax+1], f) + return data10_lr, data20_lr - if not os.path.isdir(out_per_image + 'no_tiling/'): - os.mkdir(out_per_image + 'no_tiling/') + def process_patches(self): + if self.run_60: + scale = 6 + else: + scale = 2 + + # self.name = self.data_name.split(".")[0] + + data10, data20, data60, xmin, ymin, xmax, ymax = self.get_original_image() + + if self.test_data: + out_per_image = self.saving_test_data(data10, data20, data60) + with open(out_per_image + "roi.json", "w") as f: + json.dump( + [ + xmin // scale, + ymin // scale, + (xmax + 1) // scale, + (ymax + 1) // scale, + ], + f, + ) + + if self.rgb_images: + 
self.create_rgb_images(data10, data20, data60) + + if self.true_data: + out_per_image = self.saving_true_data(data10, data20, data60) + with open(out_per_image + "roi.json", "w") as f: + json.dump( + [ + xmin // scale, + ymin // scale, + (xmax + 1) // scale, + (ymax + 1) // scale, + ], + f, + ) + + if self.train_data: + self.saving_train_data(data10, data20, data60) + + LOGGER.info("Success.") + + def saving_test_data(self, data10, data20, data60): + # if test_data: + if self.run_60: + data10_lr, data20_lr, data60_lr = self.get_downsampled_images( + data10, data20, data60 + ) + out_per_image0 = self.save_prefix + "test60/" + out_per_image = self.save_prefix + "test60/" + self.data_name + "/" + if not os.path.isdir(out_per_image0): + os.mkdir(out_per_image0) + if not os.path.isdir(out_per_image): + os.mkdir(out_per_image) + + LOGGER.info(f"Writing files for testing to:{out_per_image}") + save_test_patches60(data10_lr, data20_lr, data60_lr, out_per_image) - print("Now saving the whole image without tiling...") - np.save(out_per_image + 'no_tiling/' + 'data10', data10_gt.astype(np.float32)) - np.save(out_per_image + 'no_tiling/' + 'data20', data20_gt.astype(np.float32)) - np.save(out_per_image + 'no_tiling/' + 'data60', data60_gt.astype(np.float32)) + else: + data10_lr, data20_lr = self.get_downsampled_images(data10, data20, data60) + out_per_image0 = self.save_prefix + "test/" + out_per_image = self.save_prefix + "test/" + self.data_name + "/" + if not os.path.isdir(out_per_image0): + os.mkdir(out_per_image0) + if not os.path.isdir(out_per_image): + os.mkdir(out_per_image) + + LOGGER.info( + f"Writing files for testing to:{out_per_image}" + ) # pylint: disable=logging-fstring-interpolation + save_test_patches(data10_lr, data20_lr, out_per_image) - else: - if run_60: - out_per_image0 = save_prefix + 'train60/' - out_per_image = save_prefix + 'train60/' + data_file + '/' + if not os.path.isdir(out_per_image + "no_tiling/"): + os.mkdir(out_per_image + "no_tiling/") + + 
LOGGER.info("Now saving the whole image without tiling...") + if self.run_60: + np.save( + out_per_image + "no_tiling/" + "data60_gt", data60.astype(np.float32) + ) + np.save( + out_per_image + "no_tiling/" + "data60", data60_lr.astype(np.float32) + ) else: - out_per_image0 = save_prefix + 'train/' - out_per_image = save_prefix + 'train/' + data_file + '/' + np.save( + out_per_image + "no_tiling/" + "data20_gt", data20.astype(np.float32) + ) + self.save_band( + self.save_prefix, + data10_lr[:, :, 0:3], + "test/" + self.data_name + "/RGB", + ) + np.save(out_per_image + "no_tiling/" + "data10", data10_lr.astype(np.float32)) + np.save(out_per_image + "no_tiling/" + "data20", data20_lr.astype(np.float32)) + return out_per_image + + def create_rgb_images(self, data10, data20, data60): + # elif write_images + data10_lr, data20_lr = self.get_downsampled_images(data10, data20, data60) + LOGGER.info("Creating RGB images...") + self.save_band( + self.save_prefix, + data10_lr[:, :, 0:3], + "/raw/rgbs/" + self.data_name + "RGB", + ) + self.save_band( + self.save_prefix, + data20_lr[:, :, 0:3], + "/raw/rgbs/" + self.data_name + "RGB20", + ) + + def saving_true_data(self, data10, data20, data60): + # elif true_data: + out_per_image0 = self.save_prefix + "true/" + out_per_image = self.save_prefix + "true/" + self.data_name + "/" if not os.path.isdir(out_per_image0): os.mkdir(out_per_image0) if not os.path.isdir(out_per_image): os.mkdir(out_per_image) - print('Writing files for training to:{}'.format(out_per_image)) - if run_60: - save_random_patches60(data60_gt, data10_lr, data20_lr, data60_lr, out_per_image) + + # pylint: disable=logging-fstring-interpolation + LOGGER.info(f"Writing files for testing to:{out_per_image}") + save_test_patches60( + data10, data20, data60, out_per_image, patchSize=384, border=12 + ) + + if not os.path.isdir(out_per_image + "no_tiling/"): + os.mkdir(out_per_image + "no_tiling/") + + LOGGER.info("Now saving the whole image without tiling...") + 
np.save(out_per_image + "no_tiling/" + "data10", data10.astype(np.float32)) + np.save(out_per_image + "no_tiling/" + "data20", data20.astype(np.float32)) + np.save(out_per_image + "no_tiling/" + "data60", data60.astype(np.float32)) + return out_per_image + + def saving_train_data(self, data10, data20, data60): + # if train_data + if self.run_60: + out_per_image0 = self.save_prefix + "train60/" + out_per_image = self.save_prefix + "train60/" + self.data_name + "/" + if not os.path.isdir(out_per_image0): + os.mkdir(out_per_image0) + if not os.path.isdir(out_per_image): + os.mkdir(out_per_image) + LOGGER.info( + f"Writing files for training to:{out_per_image}" + ) # pylint: disable=logging-fstring-interpolation + data10_lr, data20_lr, data60_lr = self.get_downsampled_images( + data10, data20, data60 + ) + save_random_patches60( + data60, data10_lr, data20_lr, data60_lr, out_per_image + ) else: - save_random_patches(data20_gt, data10_lr, data20_lr, out_per_image) - - print("Success.") - - -parser = argparse.ArgumentParser(description="Read Sentinel-2 data. The code was adapted from N. Brodu.", formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("data_file", help="An input Sentinel-2 data file. This can be either the original ZIP file, or the S2A[...].xml file in a SAFE directory extracted from that ZIP.") -parser.add_argument("--roi_x_y", default="", - help="Sets the region of interest to extract as pixels locations on the 10m bands. Use this syntax: x_1,y_1,x_2,y_2. E.g. 
--roi_x_y \"2000,2000,3200,3200\"") -parser.add_argument("--test_data", default=False, action="store_true", help="Store test patches in a separate dir.") -parser.add_argument("--write_images", default=False, action="store_true", help="If set, write PNG images for the original and the superresolved bands, together with a composite rgb image (first three 10m bands), all with a quick and dirty clipping to 99%% of the original bands dynamic range and a quantization of the values to 256 levels.") -parser.add_argument("--save_prefix", default="../data/", help="If set, speficies the name of a prefix for all output files. Use a trailing / to save into a directory. The default of no prefix will save into the current directory. Example: --save_prefix result/") -parser.add_argument("--run_60", default=False, action="store_true", help="If set, it will create patches also from the 60m channels.") -parser.add_argument("--true_data", default=False, action="store_true", help="If set, it will create patches for S2 without GT. 
This option is not really useful here, please check the testing folder for predicting S2 images.") -args = parser.parse_args() - -# args.data_file = sorted(glob.glob(data_prefix + 'S2*' + data_filename)) - -if __name__ == '__main__': - # if type(args.data_file) is list: - # fileList = args.data_file - # for s2file in fileList: - # args.data_file = os.path.split(os.path.split(s2file)[0])[1] - # readS2fromFile(args.data_file, - # args.test_data, - # args.roi_x_y, - # args.save_prefix, - # args.write_images, - # args.run_60, - # args.true_data) - # else: - print('I will proceed with file {}'.format(args.data_file)) - readS2fromFile(args.data_file, - args.test_data, - args.roi_x_y, - args.save_prefix, - args.write_images, - args.run_60, - args.true_data) + out_per_image0 = self.save_prefix + "train/" + out_per_image = self.save_prefix + "train/" + self.data_name + "/" + if not os.path.isdir(out_per_image0): + os.mkdir(out_per_image0) + if not os.path.isdir(out_per_image): + os.mkdir(out_per_image) + LOGGER.info( + f"Writing files for training to:{out_per_image}" + ) # pylint: disable=logging-fstring-interpolation + data10_lr, data20_lr = self.get_downsampled_images(data10, data20, data60) + save_random_patches(data20, data10_lr, data20_lr, out_per_image) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Read Sentinel-2 data. The code was adapted from N. Brodu.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "data_file_path", + help=( + "An input Sentinel-2 data file. This can be either the original ZIP file," + " or the S2A[...].xml file in a SAFE directory extracted from that ZIP." + ), + ) + parser.add_argument( + "clip_to_aoi", + help=( + "The original image will be clipped to the given area. 
ex: '12.211,52.291,12.513,52.521'" + ), + ) + parser = parser_common(parser) + + args = parser.parse_args() + + LOGGER.info( + f"I will proceed with file {args.data_file_path}" + ) # pylint: disable=logging-fstring-interpolation + readS2fromFile( + args.data_file_path, + args.clip_to_aoi, + args.save_prefix, + args.rgb_images, + args.run_60, + args.true_data, + args.test_data, + args.train_data, + ).process_patches() diff --git a/training/create_random.py b/training/create_random.py deleted file mode 100644 index f61e219..0000000 --- a/training/create_random.py +++ /dev/null @@ -1,27 +0,0 @@ -from random import randrange -import numpy as np - -# The `val_index.npy` must be created every time the number of training patches changes. It defines (and keeps set) -# which of the patches will be used for validation. - -# This file must be changed if the DSen2_60 net is trained! (change the `path` and size of patches) - -# Size: number of S2 tiles (times) patches per tile -size = 45*8000 -ratio = .1 -nb = int(size * ratio) - -index = np.zeros(size).astype(np.bool) -i = 0 -while np.sum(index.astype(np.int)) < nb: - x = randrange(0, size) - index[x] = True - i += 1 - -path = '../data/train/' -np.save(path + 'val_index.npy', index) - -print('Full no of samples: {}'.format(size)) -print('Validation samples: {}'.format(np.sum(index.astype(np.int)))) - -print("Number of iterations: {}".format(i)) diff --git a/training/create_validation_set.py b/training/create_validation_set.py new file mode 100644 index 0000000..7673214 --- /dev/null +++ b/training/create_validation_set.py @@ -0,0 +1,55 @@ +import os +import glob +import argparse + +from random import randrange +import numpy as np + + +def get_args(): + parser = argparse.ArgumentParser( + description="Create train validation index split file" + ) + parser.add_argument("--path", help="Path of data. Only relevant if set.") + parser.add_argument( + "--run_60", action="store_true", help="Generate val_index for 60m patches." 
+ ) + args = parser.parse_args() + return args + + +def main(args): + # The `val_index.npy` must be created every time the number of training patches changes. It defines (and keeps set) + # which of the patches will be used for validation. + + # This file must be changed if the DSen2_60 net is trained! (change the `path` and size of patches) + + # Size: number of S2 tiles (times) patches per tile + n_scenes = len( + [os.path.basename(x) for x in sorted(glob.glob(args.path + "*SAFE"))] + ) + n_patches = 8000 + if args.run_60: + n_patches = 500 + size = n_scenes * n_patches + ratio = 0.1 + nb = int(size * ratio) + + index = np.zeros(size).astype(np.bool) + i = 0 + while np.sum(index.astype(np.int)) < nb: + x = randrange(0, size) + index[x] = True + i += 1 + + np.save(args.path + "val_index.npy", index) + + print("Full no of samples: {}".format(size)) + print("Validation samples: {}".format(np.sum(index.astype(np.int)))) + + print("Number of iterations: {}".format(i)) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/training/generate_patches.md b/training/generate_patches.md new file mode 100644 index 0000000..14fcfee --- /dev/null +++ b/training/generate_patches.md @@ -0,0 +1,38 @@ +# Create train datasets +Place L1C and L2A data into train folder, and then run: +```bash +python generate_patches.py L1C --save_prefix "../data/l1c" --run_60 --train_data +python generate_patches.py L1C --save_prefix "../data/l1c" --train_data +python generate_patches.py L2A --save_prefix "../data/l2a" --run_60 --train_data +python generate_patches.py L2A --save_prefix "../data/l2a" --train_data +``` +Create validation set: +```bash +python training/create_validation_set.py --path "l1ctrain/" +python training/create_validation_set.py --path "l1ctrain60/" --run_60 +python training/create_validation_set.py --path "l2atrain/" +python training/create_validation_set.py --path "l2atrain60/" --run_60 +``` +# Create test datasets +```bash +python generate_patches.py 
L1C_test --save_prefix "l1c" --run_60 --test_data +python generate_patches.py L1C_test --save_prefix "l1c" --test_data +python generate_patches.py L2A_test --save_prefix "../data/l2a" --run_60 --test_data +python generate_patches.py L2A_test --save_prefix "../data/l2a" --test_data +``` + +# Upload datasets +Test: +```bash +aws s3 cp --recursive l1ctest s3://s2-l1c-training-imgs/v2/l1ctest/ +aws s3 cp --recursive l1ctest60 s3://s2-l1c-training-imgs/v2/l1ctest60/ +aws s3 cp --recursive l2atest s3://s2-l1c-training-imgs/v2/l2atest/ +aws s3 cp --recursive l2atest60 s3://s2-l1c-training-imgs/v2/l2atest60/ +``` +Train: +```bash +aws s3 cp --recursive l1ctrain s3://s2-l1c-training-imgs/v2/l1ctrain/ +aws s3 cp --recursive l1ctrain60 s3://s2-l1c-training-imgs/v2/l1ctrain60/ +aws s3 cp --recursive l2atrain s3://s2-l1c-training-imgs/v2/l2atrain/ +aws s3 cp --recursive l2atrain60 s3://s2-l1c-training-imgs/v2/l2atrain60/ +``` diff --git a/training/generate_patches.py b/training/generate_patches.py new file mode 100644 index 0000000..54e2303 --- /dev/null +++ b/training/generate_patches.py @@ -0,0 +1,48 @@ +import sys +import argparse +from pathlib import Path + +sys.path.append("..") + +from create_patches import readS2fromFile +from create_patches import parser_common + + +from utils.data_utils import get_logger + +LOGGER = get_logger(__name__) + + +def arg_parse(): + parser = argparse.ArgumentParser( + description="Read Sentinel-2 data. The code was adapted from N. 
Brodu.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "data_folder_path", help=("Path to folder with S2 SAFE files."), + ) + parser = parser_common(parser) + args = parser.parse_args() + return args + + +def main(args): + # pylint: disable=logging-fstring-interpolation + LOGGER.info(f"I will proceed with file {args.data_folder_path}") + + for file_path in Path(args.data_folder_path).glob("S2*"): + LOGGER.info(f"Processing {file_path}") + readS2fromFile( + str(file_path), + "", + args.save_prefix, + args.rgb_images, + args.run_60, + args.true_data, + args.test_data, + args.train_data, + ).process_patches() + + +if __name__ == "__main__": + main(arg_parse()) diff --git a/training/supres_train.py b/training/supres_train.py index c91d894..5f340f4 100644 --- a/training/supres_train.py +++ b/training/supres_train.py @@ -1,231 +1,298 @@ from __future__ import division -import numpy as np -import datetime +import os import glob +import sys +import datetime import time + import argparse -import os -import sys +import numpy as np + import matplotlib as mpl -mpl.use('Agg') +import matplotlib.pyplot as plt + +mpl.use("Agg") +import tensorflow as tf from keras.optimizers import Nadam from keras.callbacks import ModelCheckpoint, Callback, ReduceLROnPlateau from keras.utils import plot_model import keras.backend as K -sys.path.append('../') +sys.path.append("../") from utils.patches import recompose_images, OpenDataFilesTest, OpenDataFiles -from utils.DSen2Net import s2model - -K.set_image_data_format('channels_first') +from utils.DSen2Net import s2model, aesrmodel, srcnn, rednetsr, resnetsr # Define file prefix for new training, must be 7 characters of this form: -model_nr = 's2_038_' +model_nr = "s2_038_" SCALE = 2000 lr = 1e-4 -path = '../data/' +path = "../data/" if not os.path.isdir(path): os.mkdir(path) -out_path = '../data/network_data/' +out_path = "../data/network_data/" if not os.path.isdir(out_path): os.mkdir(out_path) 
+STRATEGY = tf.distribute.MirroredStrategy() + class PlotLosses(Callback): def __init__(self, model_nr, lr): self.model_nr = model_nr self.lr = lr + self.losses = [] + self.val_losses = [] + self.i = 0 + self.x = [] + self.filename = out_path + self.model_nr + "_lr_{:.1e}.txt".format(self.lr) def on_train_begin(self, logs=None): self.losses = [] self.val_losses = [] self.i = 0 self.x = [] - self.filename = out_path + self.model_nr + '_lr_{:.1e}.txt'.format(self.lr) - open(self.filename, 'w').close() + self.filename = out_path + self.model_nr + "_lr_{:.1e}.txt".format(self.lr) + open(self.filename, "w").close() def on_epoch_end(self, epoch, logs=None): - import matplotlib.pyplot as plt + plt.ioff() lr = float(K.get_value(self.model.optimizer.lr)) # data = np.loadtxt("training.log", skiprows=1, delimiter=',') - self.losses.append(logs.get('loss')) - self.val_losses.append(logs.get('val_loss')) + self.losses.append(logs.get("loss")) + self.val_losses.append(logs.get("val_loss")) self.x.append(self.i) self.i += 1 try: - with open(self.filename, 'a') as self.f: - self.f.write('Finished epoch {:5d}: loss {:.3e}, valid: {:.3e}, lr: {:.1e}\n' - .format(epoch, logs.get('loss'), logs.get('val_loss'), lr)) + with open(self.filename, "a") as f: + f.write( + "Finished epoch {:5d}: loss {:.3e}, valid: {:.3e}, lr: {:.1e}\n".format( + epoch, logs.get("loss"), logs.get("val_loss"), lr + ) + ) if epoch > 500: plt.clf() - plt.plot(self.x[475:], self.losses[475:], label='loss') - plt.plot(self.x[475:], self.val_losses[475:], label='val_loss') + plt.plot(self.x[475:], self.losses[475:], label="loss") + plt.plot(self.x[475:], self.val_losses[475:], label="val_loss") plt.legend() - plt.xlabel('epochs') + plt.xlabel("epochs") # plt.waitforbuttonpress(0) - plt.savefig(out_path + self.model_nr + '_loss4.png') + plt.savefig(out_path + self.model_nr + "_loss4.png") elif epoch > 250: plt.clf() - plt.plot(self.x[240:], self.losses[240:], label='loss') - plt.plot(self.x[240:], 
self.val_losses[240:], label='val_loss') + plt.plot(self.x[240:], self.losses[240:], label="loss") + plt.plot(self.x[240:], self.val_losses[240:], label="val_loss") plt.legend() - plt.xlabel('epochs') + plt.xlabel("epochs") # plt.waitforbuttonpress(0) - plt.savefig(out_path + self.model_nr + '_loss3.png') + plt.savefig(out_path + self.model_nr + "_loss3.png") elif epoch > 100: plt.clf() - plt.plot(self.x[85:], self.losses[85:], label='loss') - plt.plot(self.x[85:], self.val_losses[85:], label='val_loss') + plt.plot(self.x[85:], self.losses[85:], label="loss") + plt.plot(self.x[85:], self.val_losses[85:], label="val_loss") plt.legend() - plt.xlabel('epochs') + plt.xlabel("epochs") # plt.waitforbuttonpress(0) - plt.savefig(out_path + self.model_nr + '_loss2.png') + plt.savefig(out_path + self.model_nr + "_loss2.png") elif epoch > 50: plt.clf() - plt.plot(self.x[50:], self.losses[50:], label='loss') - plt.plot(self.x[50:], self.val_losses[50:], label='val_loss') + plt.plot(self.x[50:], self.losses[50:], label="loss") + plt.plot(self.x[50:], self.val_losses[50:], label="val_loss") plt.legend() - plt.xlabel('epochs') + plt.xlabel("epochs") # plt.waitforbuttonpress(0) - plt.savefig(out_path + self.model_nr + '_loss1.png') + plt.savefig(out_path + self.model_nr + "_loss1.png") else: plt.clf() - plt.plot(self.x[0:], self.losses[0:], label='loss') - plt.plot(self.x[0:], self.val_losses[0:], label='val_loss') + plt.plot(self.x[0:], self.losses[0:], label="loss") + plt.plot(self.x[0:], self.val_losses[0:], label="val_loss") plt.legend() - plt.xlabel('epochs') + plt.xlabel("epochs") # plt.waitforbuttonpress(0) - plt.savefig(out_path + self.model_nr + '_loss0.png') + plt.savefig(out_path + self.model_nr + "_loss0.png") except IOError: - print('Network drive unavailable.') + print("Network drive unavailable.") print(datetime.datetime.now().time()) -if __name__ == '__main__': +if __name__ == "__main__": - parser = argparse.ArgumentParser(description='SupResS2.') - 
parser.add_argument('--predict', action='store', dest='predict_file', help='Predict.') - parser.add_argument('--resume', action='store', dest='resume_file', help='Resume training.') - parser.add_argument('--true', action='store_true', help='Use true scale data. No simulation or different resolutions.') - parser.add_argument('--run_60', action='store_true', help='Whether to run a 60->10m network. Default 20->10m.') - parser.add_argument('--deep', action='store_true', help='.') - parser.add_argument('--path', help='Path of data. Only relevant if set.') + parser = argparse.ArgumentParser(description="SupResS2.") + parser.add_argument( + "--predict", action="store", dest="predict_file", help="Predict." + ) + parser.add_argument( + "--resume", action="store", dest="resume_file", help="Resume training." + ) + parser.add_argument( + "--true", + action="store_true", + help="Use true scale data. No simulation or different resolutions.", + ) + parser.add_argument( + "--run_60", + action="store_true", + help="Whether to run a 60->10m network. Default 20->10m.", + ) + parser.add_argument( + "--model", + default="dsen2", + choices=["vdsen2", "aesr", "srcnn", "rednet", "resnet"], + help="Model architecture to use.", + ) + parser.add_argument( + "--epochs", default=8 * 1024, type=int, help="Number of epochs to train with." + ) + parser.add_argument("--path", help="Path of data. 
Only relevant if set.") args = parser.parse_args() if args.path is not None: path = args.path - # input_shape = ((4,32,32),(6,16,16)) - if args.run_60: - input_shape = ((4, None, None), (6, None, None), (2, None, None)) - else: - input_shape = ((4, None, None), (6, None, None)) - # create model - if args.deep: - model = s2model(input_shape, num_layers=32, feature_size=256) - batch_size = 8 - else: - model = s2model(input_shape, num_layers=6, feature_size=128) - batch_size = 128 - print('Symbolic Model Created.') - - nadam = Nadam(lr=lr, - beta_1=0.9, - beta_2=0.999, - epsilon=1e-8, - schedule_decay=0.004) - # clipvalue=0.000005) - - model.compile(optimizer=nadam, loss='mean_absolute_error', metrics=['mean_squared_error']) - print('Model compiled.') - model.count_params() - # model.summary() + with STRATEGY.scope(): + # input_shape = ((4,32,32),(6,16,16)) + if args.run_60: + input_shape = ((4, None, None), (6, None, None), (2, None, None)) # type: ignore + else: + input_shape = ((4, None, None), (6, None, None)) # type: ignore + # create model + print( + "================================================================" + f"Using {args.model}" + "================================================================" + ) + if args.model == "dsen2": + model = s2model(input_shape, num_layers=6, feature_size=128) + batch_size = 128 + elif args.model == "vdsen2": + model = s2model(input_shape, num_layers=32, feature_size=256) + batch_size = 8 + elif args.model == "aesr": + model = aesrmodel(input_shape) + batch_size = 128 + elif args.model == "srcnn": + model = srcnn(input_shape) + batch_size = 128 + elif args.model == "rednet": + model = rednetsr(input_shape) + batch_size = 128 + elif args.model == "resnet": + model = resnetsr(input_shape) + batch_size = 128 + + else: + print("No model selected!!") + + print("Symbolic Model Created.") + + nadam = Nadam( + lr=lr, beta_1=0.9, beta_2=0.999, epsilon=1e-8, schedule_decay=0.004 + ) + # clipvalue=0.000005) + + model.compile( + 
optimizer=nadam, loss="mean_absolute_error", metrics=["mean_squared_error"] + ) + print("Model compiled.") + model.count_params() + # model.summary() if args.predict_file: if args.true: - folder = 'true/' + folder = "true/" border = 12 elif args.run_60: - folder = 'test60/' + folder = "test60/" border = 12 else: - folder = 'test/' + folder = "test/" border = 4 model_nr = args.predict_file[-20:-13] - print('Changing the model number to: {}'.format(model_nr)) + print("Changing the model number to: {}".format(model_nr)) model.load_weights(args.predict_file) print("Predicting using file: {}".format(args.predict_file)) - fileList = [os.path.basename(x) for x in sorted(glob.glob(path + folder + '*SAFE'))] + fileList = [ + os.path.basename(x) for x in sorted(glob.glob(path + folder + "*SAFE")) + ] for dset in fileList: start = time.time() print("Timer started.") print("Predicting: {}.".format(dset)) - train, image_size = OpenDataFilesTest(path + folder + dset, args.run_60, SCALE, args.true) - prediction = model.predict(train, - batch_size=8, - verbose=1) - prediction_file = model_nr + '-predict' + train, image_size = OpenDataFilesTest( + path + folder + dset, args.run_60, SCALE, args.true + ) + prediction = model.predict(train, batch_size=8, verbose=1) + prediction_file = model_nr + "-predict" # np.save(path + 'test/' + dset + '/' + prediction_file + 'pat', prediction * SCALE) images = recompose_images(prediction, border=border, size=image_size) - print('Writing to file...') - np.save(path + folder + dset + '/' + prediction_file, images * SCALE) + print("Writing to file...") + np.save(path + folder + dset + "/" + prediction_file, images * SCALE) end = time.time() - print('Elapsed time: {}.'.format(end - start)) + print("Elapsed time: {}.".format(end - start)) sys.exit(0) if args.resume_file: print("Will resume from the weights {}".format(args.resume_file)) model.load_weights(args.resume_file) model_nr = args.resume_file[-20:-13] - print('Changing the model number to: 
{}'.format(model_nr)) + print("Changing the model number to: {}".format(model_nr)) else: - print('Model number is {}'.format(model_nr)) - plot_model(model, to_file=out_path + model_nr+'model.png', show_shapes=True, show_layer_names=True) + print("Model number is {}".format(model_nr)) + plot_model( + model, + to_file=out_path + model_nr + "model.png", + show_shapes=True, + show_layer_names=True, + ) model_yaml = model.to_yaml() - with open(out_path + model_nr + "model.yaml", 'w') as yaml_file: + with open(out_path + model_nr + "model.yaml", "w") as yaml_file: yaml_file.write(model_yaml) - filepath = out_path + model_nr + 'lr_{:.0e}.hdf5'.format(lr) - checkpoint = ModelCheckpoint(filepath, - monitor='val_loss', - verbose=1, - save_best_only=True, - save_weights_only=False, - mode='auto') + filepath = out_path + model_nr + "lr_{:.0e}.hdf5".format(lr) + checkpoint = ModelCheckpoint( + filepath, + monitor="val_loss", + verbose=1, + save_best_only=True, + save_weights_only=False, + mode="auto", + ) plot_losses = PlotLosses(model_nr, lr) - LRreducer = ReduceLROnPlateau(monitor='val_loss', - factor=0.5, - patience=5, - verbose=1, - epsilon=1e-6, - cooldown=20, - min_lr=1e-5) + LRreducer = ReduceLROnPlateau( + monitor="val_loss", + factor=0.5, + patience=5, + verbose=1, + epsilon=1e-6, + cooldown=20, + min_lr=1e-5, + ) callbacks_list = [checkpoint, plot_losses, LRreducer] - print('Loading the training data...') + print("Loading the training data...") train, label, val_tr, val_lb = OpenDataFiles(path, args.run_60, SCALE) - print('Training starts...') - - model.fit(x=train, - y=label, - batch_size=batch_size, - epochs=8 * 1024, - verbose=1, - callbacks=callbacks_list, - validation_split=0., - validation_data=(val_tr, val_lb), - shuffle=True, - class_weight=None, - sample_weight=None, - initial_epoch=0, - validation_steps=None) + print("Training starts...") + model.fit( + x=train, + y=label, + batch_size=batch_size, + epochs=args.epochs, + verbose=1, + 
callbacks=callbacks_list, + validation_split=0.0, + validation_data=(val_tr, val_lb), + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + validation_steps=None, + ) diff --git a/utils/DSen2Net.py b/utils/DSen2Net.py index 9d6d381..b1cd5a2 100644 --- a/utils/DSen2Net.py +++ b/utils/DSen2Net.py @@ -1,44 +1,219 @@ from __future__ import division +import keras from keras.models import Model, Input -from keras.layers import Conv2D, Concatenate, Activation, Lambda, Add -import keras.backend as K +from keras.layers import ( + Conv2D, + Conv2DTranspose, + Concatenate, + Activation, + Lambda, + Add, + BatchNormalization, + ReLU, +) -K.set_image_data_format('channels_first') +keras.backend.set_image_data_format("channels_first") -def resBlock(x, channels, kernel_size=[3, 3], scale=0.1): - tmp = Conv2D(channels, kernel_size, kernel_initializer='he_uniform', padding='same')(x) - tmp = Activation('relu')(tmp) - tmp = Conv2D(channels, kernel_size, kernel_initializer='he_uniform', padding='same')(tmp) +def resBlock(x, channels, kernel_size, scale=0.1): + """ + A residual block v = ResBlock(z, f ) is defined as a series of layers that operate on + an input image z to generate an output z4, then adds that output to the input image as follows: + z1 = conv(z, f ) #convolution (5a) + z2 = max(z1, 0) #ReLU layer (5b) + z3 = conv(z2, f ) #convolution (5c) + z4 = lamda*z3 #residual scaling (5d) + v = z4 + z #skip connection + + Conv2D: 2D convolution layer + Activation function: A function that is added into an artificial neural network in order + to help the network learn complex patterns in the data. 
+ Add function: Simply adding two layers + + """ + tmp = Conv2D( + channels, kernel_size, kernel_initializer="he_uniform", padding="same", + )(x) + tmp = Activation("relu")(tmp) + tmp = Conv2D( + channels, kernel_size, kernel_initializer="he_uniform", padding="same", + )(tmp) tmp = Lambda(lambda x: x * scale)(tmp) return Add()([x, tmp]) -def s2model(input_shape, num_layers=32, feature_size=256): - +def init(input_shape): + """ + Input function: is used to instantiate a Keras tensor. + """ input10 = Input(shape=input_shape[0]) input20 = Input(shape=input_shape[1]) + res = [input10, input20] + channels = input_shape[1][0] if len(input_shape) == 3: input60 = Input(shape=input_shape[2]) x = Concatenate(axis=1)([input10, input20, input60]) + res.append(input60) + channels = input_shape[2][0] else: x = Concatenate(axis=1)([input10, input20]) + return x, res, channels + + +def aesrmodel(input_shape, n1=64): + x, _input, channels = init(input_shape) + + level1_1 = Conv2D(n1, (3, 3), activation="relu", padding="same")(x) + level2_1 = Conv2D(n1, (3, 3), activation="relu", padding="same")(level1_1) + + level2_2 = Conv2DTranspose(n1, (3, 3), activation="relu", padding="same")(level2_1) + level2 = Add()([level2_1, level2_2]) + + level1_2 = Conv2DTranspose(n1, (3, 3), activation="relu", padding="same")(level2) + level1 = Add()([level1_1, level1_2]) + + decoded = Conv2D(channels, (5, 5), activation="linear", padding="same",)(level1) + + model = Model(inputs=_input, outputs=decoded) + return model + + +def srcnn(input_shape): + f1 = 9 + f2 = 1 + f3 = 5 + + n1 = 64 + n2 = 32 + x, _input, channels = init(input_shape) + + x = Conv2D(n1, (f1, f1), activation="relu", padding="same", name="level1")(x) + x = Conv2D(n2, (f2, f2), activation="relu", padding="same", name="level2")(x) + + out = Conv2D(channels, (f3, f3), padding="same", name="output")(x) + + model = Model(inputs=_input, outputs=out) + return model + + +def rednetsr(input_shape): + def _build_layer_list(model): + 
model_outputs = [layer.output for layer in model.layers] + return model_outputs + + n_conv_layers = 15 + n_deconv_layers = 15 + n_skip = 2 + n = 32 + + x, _input, channels = init(input_shape) + + for i in range(n_conv_layers): + conv_idx = i + 1 + if conv_idx == 1: + conv = Conv2D(n, (3, 3), activation="relu", padding="same")(x) + else: + conv = Conv2D(n, (3, 3), activation="relu", padding="same")(conv) + + encoded = conv + encoder = Model(inputs=_input, outputs=encoded, name="encoder") + # Create encoder layer and output lists + encoder_outputs = _build_layer_list(encoder) + + # CREATE AUTOENCODER MODEL + for i, skip in enumerate(reversed(encoder_outputs[len(_input) + 1 :])): + + deconv_idx = i + 1 + deconv_filters = n + if deconv_idx == n_deconv_layers: + deconv_filters = channels + + if deconv_idx == 1: + deconv = Conv2DTranspose( + deconv_filters, (3, 3), activation="relu", padding="same" + )(encoded) + else: + deconv = Conv2DTranspose( + deconv_filters, (3, 3), activation="relu", padding="same" + )(deconv) + + if deconv_idx % n_skip == 0: + deconv = Add()([deconv, skip]) + ReLU()(deconv) + + decoded = deconv # (decoder_inputs) + model = Model(inputs=_input, outputs=decoded) + return model + + +def resnetsr(input_shape): + def _residual_block(ip, _id): + channel_axis = 1 # channels first + init = ip + + x = Conv2D(n, (3, 3), padding="same", name="sr_res_conv_" + str(_id) + "_1")(ip) + x = BatchNormalization( + momentum=0.5, axis=channel_axis, name="sr_res_batchnorm_" + str(_id) + "_1" + )(x) + x = ReLU()(x) + x = Conv2D(n, (3, 3), padding="same", name="sr_res_conv_" + str(_id) + "_2")(x) + x = BatchNormalization( + momentum=0.5, axis=channel_axis, name="sr_res_batchnorm_" + str(_id) + "_2" + )(x) + + m = Add(name="sr_res_merge_" + str(_id))([x, init]) + return m + + n = 64 + x, _input, channels = init(input_shape) + + x0 = Conv2D(n, (9, 9), padding="same", name="sr_res_conv1")(x) + x0 = ReLU()(x0) + x = x0 + + nb_residual = 16 + for i in range(nb_residual): + 
x0 = _residual_block(x0, i + 1) + + x0 = Conv2D(filters=n, kernel_size=3, strides=1, padding="same")(x0) + x0 = BatchNormalization(axis=1, momentum=0.5)(x0) + x0 = Add()([x, x0]) + + x0 = Conv2D(channels, (9, 9), padding="same", name="sr_res_conv_final")(x0) + x0 = Activation("tanh")(x0) + model = Model(inputs=_input, outputs=x0) + + return model + + +def s2model(input_shape, num_layers=32, feature_size=256): + """ + This function contains the model architecture which contains a resBlock and 2 extra Conv2D layer. + """ + + x, _input, _ = init(input_shape) # Treat the concatenation - x = Conv2D(feature_size, (3, 3), kernel_initializer='he_uniform', activation='relu', padding='same')(x) + x = Conv2D( + feature_size, + (3, 3), + kernel_initializer="he_uniform", + activation="relu", + padding="same", + )(x) - for i in range(num_layers): - x = resBlock(x, feature_size) + for _ in range(num_layers): + x = resBlock(x, feature_size, [3, 3]) # One more convolution, and then we add the output of our first conv layer - x = Conv2D(input_shape[-1][0], (3, 3), kernel_initializer='he_uniform', padding='same')(x) - # x = Dropout(0.3)(x) + x = Conv2D( + input_shape[-1][0], (3, 3), kernel_initializer="he_uniform", padding="same", + )(x) if len(input_shape) == 3: - x = Add()([x, input60]) - model = Model(inputs=[input10, input20, input60], outputs=x) + x = Add()([x, _input[2]]) + model = Model(inputs=_input, outputs=x) else: - x = Add()([x, input20]) - model = Model(inputs=[input10, input20], outputs=x) + x = Add()([x, _input[1]]) + model = Model(inputs=_input, outputs=x) return model - diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100644 index 0000000..64a26f4 --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,286 @@ +from __future__ import division +import os +import sys +import re +import glob +import logging + +from collections import defaultdict +from typing import List, Tuple +import numpy as np +import imageio + +import rasterio +from rasterio.windows 
import Window +import pyproj as proj + +LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + +def get_logger(name: str, level=logging.DEBUG) -> logging.Logger: + """ + Instantiate a logger with a given level and name. + + Example: + ```python + logger = get_logger(__name__) + # __name__ is the name of the file where the logger is instantiated. + ``` + + Arguments: + name: A name for the logger - is included in all the logging messages. + level: A logging level (i.e. logging.DEBUG). + + Returns: + A logger instance. + """ + logger = logging.getLogger(name) + logger.setLevel(level) + # create console handler and set level to debug + ch = logging.StreamHandler() + ch.setLevel(level) + formatter = logging.Formatter(LOG_FORMAT) + ch.setFormatter(formatter) + logger.addHandler(ch) + return logger + + +LOGGER = get_logger(__name__) + + +class DATA_UTILS: + def __init__(self, data_file_path): + self.data_file_path = data_file_path + + # pylint: disable=attribute-defined-outside-init + def get_data(self) -> list: + """ + This method returns the raster data set of original image for + all the available resolutions. + """ + data_folder = "MTD*.xml" + for file in glob.iglob(os.path.join(self.data_file_path, data_folder)): + data_path = file + + LOGGER.info(f"Data path is {data_path}") + + raster_data = rasterio.open(data_path) + datasets = raster_data.subdatasets + + return datasets + + @staticmethod + def get_max_min(x_1: int, y_1: int, x_2: int, y_2: int, data) -> Tuple: + """ + This method gets pixels' location for the region of interest on the 10m bands + and returns the min/max in each direction and to nearby 60m pixel boundaries and the area + associated to the region of interest. 
+ **Example** + >>> get_max_min(0,0,400,400) + (0, 0, 395, 395, 156816) + + """ + with rasterio.open(data) as d_s: + d_width = d_s.width + d_height = d_s.height + + tmxmin = max(min(x_1, x_2, d_width - 1), 0) + tmxmax = min(max(x_1, x_2, 0), d_width - 1) + tmymin = max(min(y_1, y_2, d_height - 1), 0) + tmymax = min(max(y_1, y_2, 0), d_height - 1) + # enlarge to the nearest 60 pixel boundary for the super-resolution + tmxmin = int(tmxmin / 6) * 6 + tmxmax = int((tmxmax + 1) / 6) * 6 - 1 + tmymin = int(tmymin / 6) * 6 + tmymax = int((tmymax + 1) / 6) * 6 - 1 + area = (tmxmax - tmxmin + 1) * (tmymax - tmymin + 1) + return tmxmin, tmymin, tmxmax, tmymax, area + + # pylint: disable-msg=too-many-locals + def to_xy(self, lon: float, lat: float, data) -> Tuple: + """ + This method gets the longitude and the latitude of a given point and projects it + into pixel location in the new coordinate system. + :param lon: The longitude of a chosen point + :param lat: The longitude of a chosen point + :return: The pixel location in the coordinate system of the input image + """ + # get the image's coordinate system. + with rasterio.open(data) as d_s: + coor = d_s.transform + a_t, b_t, xoff, d_t, e_t, yoff = [coor[x] for x in range(6)] + + # transform the lat and lon into x and y position which are defined in + # the world's coordinate system. + local_crs = self.get_utm(data) + crs_wgs = proj.Proj(init="epsg:4326") # WGS 84 geographic coordinate system + crs_bng = proj.Proj(init=local_crs) # use a locally appropriate projected CRS + x_p, y_p = proj.transform(crs_wgs, crs_bng, lon, lat) + x_p -= xoff + y_p -= yoff + + # matrix inversion + # get the x and y position in image's coordinate system. + det_inv = 1.0 / (a_t * e_t - d_t * b_t) + x_n = (e_t * x_p - b_t * y_p) * det_inv + y_n = (-d_t * x_p + a_t * y_p) * det_inv + return int(x_n), int(y_n) + + @staticmethod + def get_utm(data) -> str: + """ + This method returns the utm of the input image. 
+ :param data: The raster file for a specific resolution. + :return: UTM of the selected raster file. + """ + with rasterio.open(data) as d_s: + data_crs = d_s.crs.to_dict() + utm = data_crs["init"] + return utm + + # pylint: disable-msg=too-many-locals + def area_of_interest(self, data, clip_to_aoi) -> Tuple: + """ + This method returns the coordinates that define the desired area of interest. + """ + if clip_to_aoi: + roi_lon1, roi_lat1, roi_lon2, roi_lat2 = [ + float(x) for x in re.split(",", clip_to_aoi) + ] + x_1, y_1 = self.to_xy(roi_lon1, roi_lat1, data) + x_2, y_2 = self.to_xy(roi_lon2, roi_lat2, data) + else: + x_1, y_1, x_2, y_2 = 0, 0, 20000, 20000 + + xmi, ymi, xma, yma, area = self.get_max_min(x_1, y_1, x_2, y_2, data) + return xmi, ymi, xma, yma, area + + @staticmethod + def validate_description(description: str) -> str: + """ + This method rewrites the description of each band in the given data set. + :param description: The actual description of a chosen band. + + **Example** + >>> ds10.descriptions[0] + 'B4, central wavelength 665 nm' + >>> validate_description(ds10.descriptions[0]) + 'B4 (665 nm)' + """ + m_re = re.match(r"(.*?), central wavelength (\d+) nm", description) + if m_re: + return m_re.group(1) + " (" + m_re.group(2) + " nm)" + return description + + @staticmethod + def get_band_short_name(description: str) -> str: + """ + This method returns only the name of the bands at a chosen resolution. + + :param description: This is the output of the validate_description method. + + **Example** + >>> desc = validate_description(ds10.descriptions[0]) + >>> desc + 'B4 (665 nm)' + >>> get_band_short_name(desc) + 'B4' + """ + if "," in description: + return description[: description.find(",")] + if " " in description: + return description[: description.find(" ")] + return description[:3] + + def validate(self, data) -> Tuple: + """ + This method takes the short name of the bands for each + separate resolution and returns three lists. 
The validated_ + bands and validated_indices contain the name of the bands and + the indices related to them respectively. + The validated_descriptions is a list of descriptions for each band + obtained from the validate_description method. + :param data: The raster file for a specific resolution. + **Example** + >>> validated_10m_bands, validated_10m_indices, \ + >>> dic_10m = validate(ds10) + >>> validated_10m_bands + ['B4', 'B3', 'B2', 'B8'] + >>> validated_10m_indices + [0, 1, 2, 3] + >>> dic_10m + defaultdict(, {'B4': 'B4 (665 nm)', + 'B3': 'B3 (560 nm)', 'B2': 'B2 (490 nm)', 'B8': 'B8 (842 nm)'}) + """ + input_select_bands = "B1,B2,B3,B4,B5,B6,B7,B8,B8A,B9,B11,B12" # type: str + select_bands = re.split(",", input_select_bands) # type: List[str] + validated_bands = [] # type: list + validated_indices = [] # type: list + validated_descriptions = defaultdict(str) # type: defaultdict + with rasterio.open(data) as d_s: + for i in range(0, d_s.count): + desc = self.validate_description(d_s.descriptions[i]) + name = self.get_band_short_name(desc) + if name in select_bands: + select_bands.remove(name) + validated_bands += [name] + validated_indices += [i] + validated_descriptions[name] = desc + return validated_bands, validated_indices, validated_descriptions + + # pylint: disable-msg=too-many-arguments + @staticmethod + def data_final( + data, term: List, x_mi: int, y_mi: int, x_ma: int, y_ma: int, n_res, scale + ) -> np.ndarray: + """ + This method takes the raster file at a specific + resolution and uses the output of get_max_min + to specify the area of interest. + Then it returns an numpy array of values + for all the pixels inside the area of interest. + :param data: The raster file for a specific resolution. + :param term: The validate indices of the + bands obtained from the validate method. + :return: The numpy array of pixels' value. 
+ """ + if term: + LOGGER.info(term) + with rasterio.open(data) as d_s: + d_final = np.rollaxis( + d_s.read( + window=Window( + col_off=x_mi // scale, + row_off=y_mi // scale, + width=(x_ma - x_mi + n_res) // scale, + height=(y_ma - y_mi + n_res) // scale, + ) + ), + 0, + 3, + )[:, :, term] + return d_final + + @staticmethod + def save_band(save_prefix: str, data: np.ndarray, name: str, percentile_data=None): + # The percentile_data argument is used to plot superresolved and original data + # with a comparable black/white scale + if percentile_data is None: + percentile_data = data + mi, ma = np.percentile(percentile_data, (1, 99)) + band_data = np.clip(data, mi, ma) + band_data = (band_data - mi) / (ma - mi) + imageio.imsave(save_prefix + name + ".png", band_data) + + @staticmethod + def check_size(dims): + xmin, ymin, xmax, ymax = dims + if xmax < xmin or ymax < ymin: + LOGGER.error("Invalid region of interest / UTM Zone combination") + sys.exit(1) + + if (xmax - xmin) < 192 or (ymax - ymin) < 192: + LOGGER.error( + "AOI too small. 
Try again with a larger AOI (minimum pixel width or heigh of 192)" + ) + # sys.exit(1) diff --git a/utils/evaluation.py b/utils/evaluation.py new file mode 100644 index 0000000..257cf00 --- /dev/null +++ b/utils/evaluation.py @@ -0,0 +1,174 @@ +from __future__ import print_function, division +import os + +import time +import argparse +from glob import glob + +from tensorflow import keras +import numpy as np +import skimage.transform + +# For usage of eval in evaluation +# pylint: disable=unused-import +import image_similarity_measures +from image_similarity_measures.quality_metrics import ( + psnr, + sam, + sre, + ssim, + issm, + fsim, +) + +from data_utils import get_logger +from patches import recompose_images, OpenDataFilesTest + +logger = get_logger(__name__) + +SCALE = 2000 +MODEL_PATH = "../models/" + +def uiq(org_img: np.ndarray, pred_img: np.ndarray, win_size=1024, step=1024//2): + """ + Universal Image Quality index + """ + return image_similarity_measures.quality_metrics.uiq(org_img, pred_img, step_size=step, window_size=win_size) + +def rmse(org_img: np.ndarray, pred_img: np.ndarray): + """ + Root Mean Squared Error + """ + return image_similarity_measures.quality_metrics.rmse(org_img, pred_img, max_p=1) + + +def write_final_dict(metric, metric_dict): + # Create a directory to save the text file of including evaluation values. 
+ predict_path = "val_predict/" + if not os.path.exists(predict_path): + os.makedirs(predict_path) + + with open(os.path.join(predict_path, metric + ".txt"), "w") as f: + f.writelines("{}:{}\n".format(k, v) for k, v in metric_dict.items()) + + +def predict_downsampled_img(path, model_path, folder, dset, border, final_name): + + model = keras.models.load_model(model_path) + + start = time.time() + print("Timer started.") + print("Predicting: {}.".format(dset)) + train, image_size = OpenDataFilesTest( + os.path.join(path, folder + dset), args.run_60, SCALE, False + ) + logger.info("Predicting ...") + prediction = model.predict(train, batch_size=8, verbose=1) + + images = recompose_images(prediction, border=border, size=image_size) + print("Writing to file...") + np.save(os.path.join(final_name), images * SCALE) + end = time.time() + logger.info(f"Elapsed time: {end - start}.") + + +def evaluation(org_img, pred_img, metric, bic=False): + org_img_array = np.load(org_img) + pred_img_array = np.load(pred_img) + print("eval %d %d %d" % pred_img_array.shape) + if bic: + pred_img_array = skimage.transform.resize(org_img_array, org_img_array.shape) + + print("eval %d %d %d" % org_img_array.shape) + print("eval %d %d %d" % pred_img_array.shape) + org_img_shape = org_img_array[:, :, :].shape + pred_img_shape = pred_img_array[:, :, :].shape + if org_img_shape != pred_img_shape: + pred_img_array = pred_img_array[: org_img_shape[0], : org_img_shape[1]] + + # Fo usage of eval + # pylint: disable=eval-used + result = eval(f"{metric}(org_img_array, pred_img_array)") + return result + + +def process(path, model_path, metric): + if args.l1c: + prefix = "l1c" + if args.l2a: + prefix = "l2a" + if args.run_60: + folder = prefix + "test60/" + border = 12 + else: + folder = prefix + "test/" + border = 4 + + path_to_patches = os.path.join(path, folder) + + fileList = [os.path.basename(x) for x in sorted(glob(path_to_patches + "*SAFE"))] + + mean_eval_value = [] + metric_dict = {} + + for 
dset in fileList: + if args.run_60: + org_img_path = os.path.join( + path_to_patches, dset + "/no_tiling/data60_gt.npy" + ) + bic_img_path = os.path.join(path_to_patches, dset + "/no_tiling/data60.npy") + pred_img_path = os.path.join( + path_to_patches, dset + "/no_tiling/data60_predicted.npy" + ) + else: + org_img_path = os.path.join( + path_to_patches, dset + "/no_tiling/data20_gt.npy" + ) + bic_img_path = os.path.join(path_to_patches, dset + "/no_tiling/data20.npy") + pred_img_path = os.path.join( + path_to_patches, dset + "/no_tiling/data20_predicted.npy" + ) + + predict_downsampled_img(path, model_path, folder, dset, border, pred_img_path) + print(org_img_path, bic_img_path, pred_img_path) + eval_value = evaluation(org_img_path, pred_img_path, metric) + if args.bic: + eval_value_bic = evaluation(org_img_path, bic_img_path, metric, bic=True) + print(f"Bicubic: {eval_value_bic}") + metric_dict[dset] = eval_value + mean_eval_value.append(eval_value) + print(f"NN: {eval_value}") + + metric_dict["mean"] = sum(mean_eval_value) / len(mean_eval_value) + + write_final_dict(metric, metric_dict) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser( + description="Evaluates an Image Super Resolution Model" + ) + parser.add_argument("--path", type=str, help="Path to image for evaluation") + parser.add_argument("--model_path", type=str, help="Path to model weights") + parser.add_argument("--l1c", action="store_true", help="Getting L1C samples") + parser.add_argument("--l2a", action="store_true", help="Getting L2A samples") + parser.add_argument("--bic", action="store_true", help="Compare bicubic result") + parser.add_argument( + "--metric", + type=str, + default="psnr", + help="Use psnr, uiq, sam or sre as evaluation metric", + ) + parser.add_argument( + "--run_60", + action="store_true", + help="Whether to run a 60->10m network. 
Default 20->10m.", + ) + + args = parser.parse_args() + path = args.path + model_path = args.model_path + metric = args.metric + + process(path, model_path, metric) diff --git a/utils/imresize.py b/utils/imresize.py index 7e94bfb..b8fd494 100644 --- a/utils/imresize.py +++ b/utils/imresize.py @@ -1,5 +1,6 @@ +from math import ceil + import numpy as np -from math import ceil, floor # This code was cloned from https://github.com/fatheral/matlab_imresize # It is used to get the MATLAB bicubic upsampling. @@ -11,20 +12,25 @@ def deriveSizeFromScale(img_shape, scale): output_shape.append(int(ceil(scale[k] * img_shape[k]))) return output_shape + def deriveScaleFromSize(img_shape_in, img_shape_out): scale = [] for k in range(2): scale.append(1.0 * img_shape_out[k] / img_shape_in[k]) return scale + def cubic(x): x = np.array(x).astype(np.float64) absx = np.absolute(x) absx2 = np.multiply(absx, absx) absx3 = np.multiply(absx2, absx) - f = np.multiply(1.5*absx3 - 2.5*absx2 + 1, absx <= 1) + np.multiply(-0.5*absx3 + 2.5*absx2 - 4*absx + 2, (1 < absx) & (absx <= 2)) + f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply( + -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2, (absx > 1) & (absx <= 2) + ) return f + def contributions(in_length, out_length, scale, kernel, k_width): if scale < 1: h = lambda x: scale * kernel(scale * x) @@ -32,21 +38,24 @@ def contributions(in_length, out_length, scale, kernel, k_width): else: h = kernel kernel_width = k_width - x = np.arange(1, out_length+1).astype(np.float64) + x = np.arange(1, out_length + 1).astype(np.float64) u = x / scale + 0.5 * (1 - 1 / scale) left = np.floor(u - kernel_width / 2) P = int(ceil(kernel_width)) + 2 - ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0 + ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0 indices = ind.astype(np.int32) - weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0 + weights = 
h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0 weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1)) - aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32) + aux = np.concatenate( + (np.arange(in_length), np.arange(in_length - 1, -1, step=-1)) + ).astype(np.int32) indices = aux[np.mod(indices, aux.size)] ind2store = np.nonzero(np.any(weights, axis=0)) weights = weights[:, ind2store] indices = indices[:, ind2store] return weights, indices + def imresizemex(inimg, weights, indices, dim): in_shape = inimg.shape w_shape = weights.shape @@ -59,24 +68,30 @@ def imresizemex(inimg, weights, indices, dim): w = weights[i_w, :] ind = indices[i_w, :] im_slice = inimg[ind, i_img].astype(np.float64) - outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) + outimg[i_w, i_img] = np.sum( + np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0 + ) elif dim == 1: for i_img in range(in_shape[0]): for i_w in range(w_shape[0]): w = weights[i_w, :] ind = indices[i_w, :] im_slice = inimg[i_img, ind].astype(np.float64) - outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) + outimg[i_img, i_w] = np.sum( + np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0 + ) if inimg.dtype == np.uint8: outimg = np.clip(outimg, 0, 255) return np.around(outimg).astype(np.uint8) else: return outimg + def resizeAlongDim(A, dim, weights, indices): out = imresizemex(A, weights, indices, dim) return out + def imresize(I, scalar_scale=None, output_shape=None): kernel = cubic kernel_width = 4.0 @@ -89,14 +104,16 @@ def imresize(I, scalar_scale=None, output_shape=None): scale = deriveScaleFromSize(I.shape, output_shape) output_size = list(output_shape) else: - print('Error: scalar_scale OR output_shape should be defined!') + print("Error: scalar_scale OR output_shape should be defined!") return scale_np = np.array(scale) order = np.argsort(scale_np) 
weights = [] indices = [] for k in range(2): - w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width) + w, ind = contributions( + I.shape[k], output_size[k], scale[k], kernel, kernel_width + ) weights.append(w) indices.append(ind) B = np.copy(I) @@ -111,8 +128,8 @@ def imresize(I, scalar_scale=None, output_shape=None): B = np.squeeze(B, axis=2) return B + def convertDouble2Byte(I): B = np.clip(I, 0.0, 1.0) - B = 255*B + B = 255 * B return np.around(B).astype(np.uint8) - diff --git a/utils/patches.py b/utils/patches.py index d18dbed..b70b5cd 100644 --- a/utils/patches.py +++ b/utils/patches.py @@ -1,75 +1,116 @@ from __future__ import division -import numpy as np -from random import randrange -from skimage.transform import resize import os import glob import json from math import ceil +from random import randrange +from typing import Tuple, List -def interp_patches(image_20, image_10_shape): - data20_interp = np.zeros((image_20.shape[0:2] + image_10_shape[2:4])).astype(np.float32) +import numpy as np +from skimage.transform import resize +import skimage.measure +from scipy.ndimage.filters import gaussian_filter + + +def interp_patches( + image_20: np.ndarray, image_10_shape: Tuple[int, int, int, int] +) -> np.ndarray: + """Upsample patches to shape of higher resolution""" + data20_interp = np.zeros((image_20.shape[0:2] + image_10_shape[2:4])).astype( + np.float32 + ) for k in range(image_20.shape[0]): for w in range(image_20.shape[1]): - data20_interp[k, w] = resize(image_20[k, w] / 30000, image_10_shape[2:4], mode='reflect') * 30000 # bilinear + data20_interp[k, w] = ( + resize(image_20[k, w] / 30000, image_10_shape[2:4], mode="reflect") + * 30000 + ) # bilinear return data20_interp -def get_test_patches(dset_10, dset_20, patchSize=128, border=4, interp=True): - - PATCH_SIZE_HR = (patchSize, patchSize) - PATCH_SIZE_LR = [p//2 for p in PATCH_SIZE_HR] - BORDER_HR = border - BORDER_LR = BORDER_HR//2 - - # Mirror the data at the borders to 
have the same dimensions as the input - dset_10 = np.pad(dset_10, ((BORDER_HR, BORDER_HR), (BORDER_HR, BORDER_HR), (0, 0)), mode='symmetric') - dset_20 = np.pad(dset_20, ((BORDER_LR, BORDER_LR), (BORDER_LR, BORDER_LR), (0, 0)), mode='symmetric') - - BANDS10 = dset_10.shape[2] - BANDS20 = dset_20.shape[2] - patchesAlongi = (dset_20.shape[0] - 2 * BORDER_LR) // (PATCH_SIZE_LR[0] - 2 * BORDER_LR) - patchesAlongj = (dset_20.shape[1] - 2 * BORDER_LR) // (PATCH_SIZE_LR[1] - 2 * BORDER_LR) - - nr_patches = (patchesAlongi + 1) * (patchesAlongj + 1) - - label_20 = np.zeros((nr_patches, BANDS20) + PATCH_SIZE_HR).astype(np.float32) - image_20 = np.zeros((nr_patches, BANDS20) + tuple(PATCH_SIZE_LR)).astype(np.float32) - image_10 = np.zeros((nr_patches, BANDS10) + PATCH_SIZE_HR).astype(np.float32) - - # print(label_20.shape) - # print(image_20.shape) - # print(image_10.shape) - - range_i = np.arange(0, (dset_20.shape[0] - 2 * BORDER_LR) // (PATCH_SIZE_LR[0] - 2 * BORDER_LR)) * ( - PATCH_SIZE_LR[0] - 2 * BORDER_LR) - range_j = np.arange(0, (dset_20.shape[1] - 2 * BORDER_LR) // (PATCH_SIZE_LR[1] - 2 * BORDER_LR)) * ( - PATCH_SIZE_LR[1] - 2 * BORDER_LR) - - if not (np.mod(dset_20.shape[0] - 2 * BORDER_LR, PATCH_SIZE_LR[0] - 2 * BORDER_LR) == 0): - range_i = np.append(range_i, (dset_20.shape[0] - PATCH_SIZE_LR[0])) - if not (np.mod(dset_20.shape[1] - 2 * BORDER_LR, PATCH_SIZE_LR[1] - 2 * BORDER_LR) == 0): - range_j = np.append(range_j, (dset_20.shape[1] - PATCH_SIZE_LR[1])) - - # print(range_i) - # print(range_j) - - pCount = 0 +def get_patches( + dset: np.ndarray, + patch_size: int, + border: int, + patches_along_i: int, + patches_along_j: int, +) -> np.ndarray: + n_bands = dset.shape[2] + + # array index + nr_patches = (patches_along_i + 1) * (patches_along_j + 1) + range_i = np.arange(0, patches_along_i) * (patch_size - 2 * border) + range_j = np.arange(0, patches_along_j) * (patch_size - 2 * border) + + patches = np.zeros((nr_patches, n_bands) + (patch_size, 
patch_size)).astype( + np.float32 + ) + + # if height and width are divisible by patch size - border * 2, or if + # range_i \and range_j are smaller than size + # add one extra patch at the end of the image + if ( + np.mod(dset.shape[0] - 2 * border, patch_size - 2 * border) != 0 + or dset.shape[0] - 2 * border / patch_size - 2 * border > patches_along_i + ): + range_i = np.append(range_i, (dset.shape[0] - patch_size)) + if ( + np.mod(dset.shape[1] - 2 * border, patch_size - 2 * border) != 0 + or dset.shape[1] - 2 * border / patch_size - 2 * border > patches_along_j + ): + range_j = np.append(range_j, (dset.shape[1] - patch_size)) + + patch_count = 0 for ii in range_i.astype(int): for jj in range_j.astype(int): upper_left_i = ii upper_left_j = jj - crop_point_lr = [upper_left_i, - upper_left_j, - upper_left_i + PATCH_SIZE_LR[0], - upper_left_j + PATCH_SIZE_LR[1]] - crop_point_hr = [p*2 for p in crop_point_lr] - image_20[pCount] = np.rollaxis(dset_20[crop_point_lr[0]:crop_point_lr[2], - crop_point_lr[1]:crop_point_lr[3]], 2) - image_10[pCount] = np.rollaxis(dset_10[crop_point_hr[0]:crop_point_hr[2], - crop_point_hr[1]:crop_point_hr[3]], 2) - pCount += 1 + # make shape (p, c, w, h) + patches[patch_count] = crop_array_to_window( + dset, + get_crop_window(upper_left_i, upper_left_j, patch_size, 1), + rollaxis=True, + ) + patch_count += 1 + + assert patch_count == nr_patches == patches.shape[0] + return patches + + +def get_test_patches( + dset_10: np.ndarray, + dset_20: np.ndarray, + patchSize: int = 128, + border: int = 4, + interp: bool = True, +) -> Tuple[np.ndarray, np.ndarray]: + """Used for inference. 
Creates patches of specific size in the whole image (10m and 20m)""" + + patch_size_lr = patchSize // 2 + border_lr = border // 2 + + # Mirror the data at the borders to have the same dimensions as the input + dset_10 = np.pad( + dset_10, ((border, border), (border, border), (0, 0)), mode="symmetric", + ) + dset_20 = np.pad( + dset_20, + ((border_lr, border_lr), (border_lr, border_lr), (0, 0)), + mode="symmetric", + ) + + patchesAlongi = (dset_20.shape[0] - 2 * border_lr) // ( + patch_size_lr - 2 * border_lr + ) + patchesAlongj = (dset_20.shape[1] - 2 * border_lr) // ( + patch_size_lr - 2 * border_lr + ) + + image_10 = get_patches(dset_10, patchSize, border, patchesAlongi, patchesAlongj) + image_20 = get_patches( + dset_20, patch_size_lr, border_lr, patchesAlongi, patchesAlongj + ) image_10_shape = image_10.shape @@ -80,68 +121,50 @@ def get_test_patches(dset_10, dset_20, patchSize=128, border=4, interp=True): return image_10, data20_interp -def get_test_patches60(dset_10, dset_20, dset_60, patchSize=128, border=8, interp=True): +def get_test_patches60( + dset_10: np.ndarray, + dset_20: np.ndarray, + dset_60: np.ndarray, + patchSize: int = 192, + border: int = 12, + interp: bool = True, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Used for inference. 
Creates patches of specific size in the whole image (10m, 20m and 60m)""" - PATCH_SIZE_10 = (patchSize, patchSize) - PATCH_SIZE_20 = [p//2 for p in PATCH_SIZE_10] - PATCH_SIZE_60 = [p//6 for p in PATCH_SIZE_10] - BORDER_10 = border - BORDER_20 = BORDER_10//2 - BORDER_60 = BORDER_10//6 + patch_size_20 = patchSize // 2 + patch_size_60 = patchSize // 6 + border_20 = border // 2 + border_60 = border // 6 # Mirror the data at the borders to have the same dimensions as the input - dset_10 = np.pad(dset_10, ((BORDER_10, BORDER_10), (BORDER_10, BORDER_10), (0, 0)), mode='symmetric') - dset_20 = np.pad(dset_20, ((BORDER_20, BORDER_20), (BORDER_20, BORDER_20), (0, 0)), mode='symmetric') - dset_60 = np.pad(dset_60, ((BORDER_60, BORDER_60), (BORDER_60, BORDER_60), (0, 0)), mode='symmetric') - - - BANDS10 = dset_10.shape[2] - BANDS20 = dset_20.shape[2] - BANDS60 = dset_60.shape[2] - patchesAlongi = (dset_60.shape[0] - 2 * BORDER_60) // (PATCH_SIZE_60[0] - 2 * BORDER_60) - patchesAlongj = (dset_60.shape[1] - 2 * BORDER_60) // (PATCH_SIZE_60[1] - 2 * BORDER_60) - - nr_patches = (patchesAlongi + 1) * (patchesAlongj + 1) - - image_10 = np.zeros((nr_patches, BANDS10) + PATCH_SIZE_10).astype(np.float32) - image_20 = np.zeros((nr_patches, BANDS20) + tuple(PATCH_SIZE_20)).astype(np.float32) - image_60 = np.zeros((nr_patches, BANDS60) + tuple(PATCH_SIZE_60)).astype(np.float32) - - # print(image_10.shape) - # print(image_20.shape) - # print(image_60.shape) - - range_i = np.arange(0, (dset_60.shape[0] - 2 * BORDER_60) // (PATCH_SIZE_60[0] - 2 * BORDER_60)) * ( - PATCH_SIZE_60[0] - 2 * BORDER_60) - range_j = np.arange(0, (dset_60.shape[1] - 2 * BORDER_60) // (PATCH_SIZE_60[1] - 2 * BORDER_60)) * ( - PATCH_SIZE_60[1] - 2 * BORDER_60) - - if not (np.mod(dset_60.shape[0] - 2 * BORDER_60, PATCH_SIZE_60[0] - 2 * BORDER_60) == 0): - range_i = np.append(range_i, (dset_60.shape[0] - PATCH_SIZE_60[0])) - if not (np.mod(dset_60.shape[1] - 2 * BORDER_60, PATCH_SIZE_60[1] - 2 * BORDER_60) == 0): - 
range_j = np.append(range_j, (dset_60.shape[1] - PATCH_SIZE_60[1])) - - # print(range_i) - # print(range_j) - - pCount = 0 - for ii in range_i.astype(int): - for jj in range_j.astype(int): - upper_left_i = ii - upper_left_j = jj - crop_point_60 = [upper_left_i, - upper_left_j, - upper_left_i + PATCH_SIZE_60[0], - upper_left_j + PATCH_SIZE_60[1]] - crop_point_10 = [p*6 for p in crop_point_60] - crop_point_20 = [p*3 for p in crop_point_60] - image_10[pCount] = np.rollaxis(dset_10[crop_point_10[0]:crop_point_10[2], - crop_point_10[1]:crop_point_10[3]], 2) - image_20[pCount] = np.rollaxis(dset_20[crop_point_20[0]:crop_point_20[2], - crop_point_20[1]:crop_point_20[3]], 2) - image_60[pCount] = np.rollaxis(dset_60[crop_point_60[0]:crop_point_60[2], - crop_point_60[1]:crop_point_60[3]], 2) - pCount += 1 + dset_10 = np.pad( + dset_10, ((border, border), (border, border), (0, 0)), mode="symmetric", + ) + dset_20 = np.pad( + dset_20, + ((border_20, border_20), (border_20, border_20), (0, 0)), + mode="symmetric", + ) + dset_60 = np.pad( + dset_60, + ((border_60, border_60), (border_60, border_60), (0, 0)), + mode="symmetric", + ) + + patchesAlongi = (dset_60.shape[0] - 2 * border_60) // ( + patch_size_60 - 2 * border_60 + ) + patchesAlongj = (dset_60.shape[1] - 2 * border_60) // ( + patch_size_60 - 2 * border_60 + ) + + image_10 = get_patches(dset_10, patchSize, border, patchesAlongi, patchesAlongj) + image_20 = get_patches( + dset_20, patch_size_20, border_20, patchesAlongi, patchesAlongj + ) + image_60 = get_patches( + dset_60, patch_size_60, border_60, patchesAlongi, patchesAlongj + ) image_10_shape = image_10.shape @@ -156,127 +179,241 @@ def get_test_patches60(dset_10, dset_20, dset_60, patchSize=128, border=8, inter return image_10, data20_interp, data60_interp -def save_test_patches(dset_10, dset_20, file, patchSize=128, border=4, interp=True): - image_10, data20_interp = get_test_patches(dset_10, dset_20, patchSize=patchSize, border=border, interp=interp) +def 
save_test_patches( + dset_10: np.ndarray, + dset_20: np.ndarray, + file: str, + patchSize: int = 128, + border: int = 4, + interp: bool = True, +): + """Save patches for inference into files (10 and 20m)""" + image_10, data20_interp = get_test_patches( + dset_10, dset_20, patchSize=patchSize, border=border, interp=interp + ) print("Saving to file {}".format(file)) - np.save(file + 'data10', image_10) - np.save(file + 'data20', data20_interp) - print('Done!') - - -def save_test_patches60(dset_10, dset_20, dset_60, file, patchSize=192, border=12, interp=True): - - image_10, data20_interp, data60_interp = get_test_patches60(dset_10, dset_20, dset_60, patchSize=patchSize, - border=border, interp=interp) + np.save(file + "data10", image_10) + np.save(file + "data20", data20_interp) + print("Done!") + + +def save_test_patches60( + dset_10: np.ndarray, + dset_20: np.ndarray, + dset_60: np.ndarray, + file: str, + patchSize: int = 192, + border: int = 12, + interp: bool = True, +): + """Save patches for inference into files (10m, 20m and 60m)""" + image_10, data20_interp, data60_interp = get_test_patches60( + dset_10, dset_20, dset_60, patchSize=patchSize, border=border, interp=interp + ) print("Saving to file {}".format(file)) - np.save(file + 'data10', image_10) - np.save(file + 'data20', data20_interp) - np.save(file + 'data60', data60_interp) - print('Done!') - + np.save(file + "data10", image_10) + np.save(file + "data20", data20_interp) + np.save(file + "data60", data60_interp) + print("Done!") + + +def get_crop_window( + upper_left_x: int, upper_left_y: int, patch_size: int, scale: int = 1 +) -> List[int]: + """From a x,y coordinate pair and patch size return a list ofpixel coordinates + defining a window in an array. 
Optionally pass a scale factor.""" + crop_window = [ + upper_left_x, + upper_left_y, + upper_left_x + patch_size, + upper_left_y + patch_size, + ] + crop_window = [p * scale for p in crop_window] + return crop_window + + +def crop_array_to_window( + array: np.ndarray, crop_window: List[int], rollaxis: bool = True +) -> np.ndarray: + """Return a subset of a numpy array. Rollaxis optional from channels last + to channels first and vice versa. """ + cropped_array = array[ + crop_window[0] : crop_window[2], crop_window[1] : crop_window[3] + ] + if rollaxis: + return np.rollaxis(cropped_array, 2,) + else: + return cropped_array -def save_random_patches(dset_20gt, dset_10, dset_20, file, NR_CROP=8000): - PATCH_SIZE_HR = (32, 32) - PATCH_SIZE_LR = (16, 16) +def get_random_patches( + dset_20gt: np.ndarray, + dset_10: np.ndarray, + dset_20: np.ndarray, + nr_patches: int = 8000, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Returns a set number of patches randomly select from a 10m and 20m resolution.""" + patch_size = 16 BANDS10 = dset_10.shape[2] BANDS20 = dset_20.shape[2] - label_20 = np.zeros((NR_CROP, BANDS20) + PATCH_SIZE_HR).astype(np.float32) - image_20 = np.zeros((NR_CROP, BANDS20) + PATCH_SIZE_LR).astype(np.float32) - image_10 = np.zeros((NR_CROP, BANDS10) + PATCH_SIZE_HR).astype(np.float32) - - # print(label_20.shape) - # print(image_20.shape) - # print(image_10.shape) - - i = 0 - for crop in range(0, NR_CROP): + label_20 = np.zeros( + (nr_patches, BANDS20) + (patch_size * 2, patch_size * 2) + ).astype(np.float32) + image_20 = np.zeros((nr_patches, BANDS20) + (patch_size, patch_size)).astype( + np.float32 + ) + image_10 = np.zeros( + (nr_patches, BANDS10) + (patch_size * 2, patch_size * 2) + ).astype(np.float32) + + for i in range(0, nr_patches): # while True: - upper_left_x = randrange(0, dset_20.shape[0] - PATCH_SIZE_LR[0]) - upper_left_y = randrange(0, dset_20.shape[1] - PATCH_SIZE_LR[1]) - crop_point_lr = [upper_left_x, - upper_left_y, - upper_left_x + 
PATCH_SIZE_LR[0], - upper_left_y + PATCH_SIZE_LR[1]] - crop_point_hr = [p*2 for p in crop_point_lr] - label_20[i] = np.rollaxis(dset_20gt[crop_point_hr[0]:crop_point_hr[2], crop_point_hr[1]:crop_point_hr[3]], 2) - image_20[i] = np.rollaxis(dset_20[crop_point_lr[0]:crop_point_lr[2], crop_point_lr[1]:crop_point_lr[3]], 2) - image_10[i] = np.rollaxis(dset_10[crop_point_hr[0]:crop_point_hr[2], crop_point_hr[1]:crop_point_hr[3]], 2) - i += 1 - np.save(file + 'data10', image_10) - image_10_shape = image_10.shape - del image_10 - np.save(file + 'data20_gt', label_20) - del label_20 - - data20_interp = interp_patches(image_20, image_10_shape) - np.save(file + 'data20', data20_interp) - - print('Done!') - - -def save_random_patches60(dset_60gt, dset_10, dset_20, dset_60, file, NR_CROP=500): - - PATCH_SIZE_10 = (96, 96) - PATCH_SIZE_20 = (48, 48) - PATCH_SIZE_60 = (16, 16) + upper_left_x = randrange(0, dset_20.shape[0] - patch_size) + upper_left_y = randrange(0, dset_20.shape[1] - patch_size) + + label_20[i] = crop_array_to_window( + dset_20gt, + get_crop_window(upper_left_x, upper_left_y, patch_size, 2), + rollaxis=True, + ) + image_20[i] = crop_array_to_window( + dset_20, + get_crop_window(upper_left_x, upper_left_y, patch_size), + rollaxis=True, + ) + image_10[i] = crop_array_to_window( + dset_10, + get_crop_window(upper_left_x, upper_left_y, patch_size, 2), + rollaxis=True, + ) + + image_20 = interp_patches(image_20, image_10.shape) + + return image_10, label_20, image_20 + + +def get_random_patches60( + dset_60gt: np.ndarray, + dset_10: np.ndarray, + dset_20: np.ndarray, + dset_60: np.ndarray, + nr_patches: int = 500, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Returns a set number of patches randomly select from a 10m, 20m and 60m resolution.""" + + patch_size = 16 BANDS10 = dset_10.shape[2] BANDS20 = dset_20.shape[2] BANDS60 = dset_60.shape[2] - label_60 = np.zeros((NR_CROP, BANDS60) + PATCH_SIZE_10).astype(np.float32) - image_10 = 
np.zeros((NR_CROP, BANDS10) + PATCH_SIZE_10).astype(np.float32) - image_20 = np.zeros((NR_CROP, BANDS20) + PATCH_SIZE_20).astype(np.float32) - image_60 = np.zeros((NR_CROP, BANDS60) + PATCH_SIZE_60).astype(np.float32) - - print(label_60.shape) - print(image_10.shape) - print(image_20.shape) - print(image_60.shape) - - i = 0 - for crop in range(0, NR_CROP): - # while True: - upper_left_x = randrange(0, dset_60.shape[0] - PATCH_SIZE_60[0]) - upper_left_y = randrange(0, dset_60.shape[1] - PATCH_SIZE_60[1]) - crop_point_lr = [upper_left_x, - upper_left_y, - upper_left_x + PATCH_SIZE_60[0], - upper_left_y + PATCH_SIZE_60[1]] - crop_point_hr20 = [p*3 for p in crop_point_lr] - crop_point_hr60 = [p*6 for p in crop_point_lr] - - label_60[i] = np.rollaxis(dset_60gt[crop_point_hr60[0]:crop_point_hr60[2], crop_point_hr60[1]:crop_point_hr60[3]], 2) - image_10[i] = np.rollaxis(dset_10[crop_point_hr60[0]:crop_point_hr60[2], crop_point_hr60[1]:crop_point_hr60[3]], 2) - image_20[i] = np.rollaxis(dset_20[crop_point_hr20[0]:crop_point_hr20[2], crop_point_hr20[1]:crop_point_hr20[3]], 2) - image_60[i] = np.rollaxis(dset_60[crop_point_lr[0]:crop_point_lr[2], crop_point_lr[1]:crop_point_lr[3]], 2) - i += 1 - np.save(file + 'data10', image_10) - image_10_shape = image_10.shape + label_60 = np.zeros( + (nr_patches, BANDS60) + (patch_size * 6, patch_size * 6) + ).astype(np.float32) + image_10 = np.zeros( + (nr_patches, BANDS10) + (patch_size * 6, patch_size * 6) + ).astype(np.float32) + image_20 = np.zeros( + (nr_patches, BANDS20) + (patch_size * 3, patch_size * 3) + ).astype(np.float32) + image_60 = np.zeros((nr_patches, BANDS60) + (patch_size, patch_size)).astype( + np.float32 + ) + + for i in range(0, nr_patches): + upper_left_x = randrange(0, dset_60.shape[0] - patch_size) + upper_left_y = randrange(0, dset_60.shape[1] - patch_size) + + label_60[i] = crop_array_to_window( + dset_60gt, + get_crop_window(upper_left_x, upper_left_y, patch_size, 6), + rollaxis=True, + ) + image_10[i] = 
crop_array_to_window( + dset_10, + get_crop_window(upper_left_x, upper_left_y, patch_size, 6), + rollaxis=True, + ) + image_20[i] = crop_array_to_window( + dset_20, + get_crop_window(upper_left_x, upper_left_y, patch_size, 3), + rollaxis=True, + ) + image_60[i] = crop_array_to_window( + dset_60, + get_crop_window(upper_left_x, upper_left_y, patch_size, 1), + rollaxis=True, + ) + + image_20 = interp_patches(image_20, image_10.shape) + image_60 = interp_patches(image_60, image_10.shape) + return image_10, label_60, image_20, image_60 + + +def save_random_patches( + dset_20gt: np.ndarray, + dset_10: np.ndarray, + dset_20: np.ndarray, + file: str, + NR_CROP: int = 8000, +): + """Save patches into file for training (10 and 20m)""" + image_10, label_20, image_20 = get_random_patches( + dset_20gt, dset_10, dset_20, NR_CROP + ) + + np.save(file + "data10", image_10) + del image_10 + np.save(file + "data20_gt", label_20) + del label_20 + np.save(file + "data20", image_20) + del image_20 + print("Done!") + + +def save_random_patches60( + dset_60gt: np.ndarray, + dset_10: np.ndarray, + dset_20: np.ndarray, + dset_60: np.ndarray, + file: str, + NR_CROP: int = 500, +): + """Save patches into file for training (10, 20m and 60m)""" + + image_10, label_60, image_20, image_60 = get_random_patches60( + dset_60gt, dset_10, dset_20, dset_60, NR_CROP + ) + np.save(file + "data10", image_10) del image_10 - np.save(file + 'data60_gt', label_60) + np.save(file + "data60_gt", label_60) del label_60 - data20_interp = interp_patches(image_20, image_10_shape) - np.save(file + 'data20', data20_interp) - del data20_interp + np.save(file + "data20", image_20) + del image_20 - data60_interp = interp_patches(image_60, image_10_shape) - np.save(file + 'data60', data60_interp) + np.save(file + "data60", image_60) + del image_60 - print('Done!') + print("Done!") -def splitTrainVal(train_path, train, label): +def splitTrainVal( + train_path: str, train: List[np.ndarray], label: np.ndarray +) -> 
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Create test validation split from val_index.npy file generated from create_validation_set.""" + # val_ind is numpy array + # pylint: disable=invalid-unary-operand-type try: - val_ind = np.load(train_path + 'val_index.npy') + val_ind = np.load(train_path + "val_index.npy") except IOError: - print("Please define the validation split indices, usually located in .../data/test/. To generate this file use" - " createRandom.py") + print( + "Please define the validation split indices, usually located in .../data/test/. To generate this file use" + " createRandom.py" + ) val_tr = [p[val_ind] for p in train] train = [p[~val_ind] for p in train] val_lb = label[val_ind] @@ -285,30 +422,44 @@ def splitTrainVal(train_path, train, label): return train, label, val_tr, val_lb -def OpenDataFiles(path, run_60, SCALE): - if run_60: - train_path = path + 'train60/' - else: - train_path = path + 'train/' - # Initialize in able to concatenate - data20_gt = data60_gt = data10 = data20 = data60 = None - # train = label = None +def OpenDataFiles( + path: str, run_60: bool, SCALE: int +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """From path with train patches, return numpy array with train and val patches.""" + train_path = path + # Create list from path - fileList = [os.path.basename(x) for x in sorted(glob.glob(train_path + '*SAFE'))] - for dset in fileList: - data10_new = np.load(train_path + dset + '/data10.npy') - data20_new = np.load(train_path + dset + '/data20.npy') - data10 = np.concatenate((data10, data10_new)) if data10 is not None else data10_new - data20 = np.concatenate((data20, data20_new)) if data20 is not None else data20_new - if run_60: - data60_gt_new = np.load(train_path + dset + '/data60_gt.npy') - data60_new = np.load(train_path + dset + '/data60.npy') - data60_gt = np.concatenate((data60_gt, data60_gt_new)) if data60_gt is not None else data60_gt_new - data60 = np.concatenate((data60, 
data60_new)) if data60 is not None else data60_new + fileList = [os.path.basename(x) for x in sorted(glob.glob(train_path + "*SAFE"))] + for i, dset in enumerate(fileList): + print(f"Loading {dset}...") + if i == 0: + data10 = np.load(train_path + dset + "/data10.npy") + data20 = np.load(train_path + dset + "/data20.npy") else: - data20_gt_new = np.load(train_path + dset + '/data20_gt.npy') - data20_gt = np.concatenate((data20_gt, data20_gt_new)) if data20_gt is not None else data20_gt_new + data10_new = np.load(train_path + dset + "/data10.npy") + data20_new = np.load(train_path + dset + "/data20.npy") + data10 = np.concatenate((data10, data10_new)) + data20 = np.concatenate((data20, data20_new)) + if run_60: + if i == 0: + data60_gt = np.load(train_path + dset + "/data60_gt.npy") + data60 = np.load(train_path + dset + "/data60.npy") + else: + data60_gt_new = np.load(train_path + dset + "/data60_gt.npy") + data60_new = np.load(train_path + dset + "/data60.npy") + data60_gt = np.concatenate((data60_gt, data60_gt_new)) + data60 = np.concatenate((data60, data60_new)) + + else: + if i == 0: + data20_gt = np.load(train_path + dset + "/data20_gt.npy") + else: + data20_gt_new = np.load(train_path + dset + "/data20_gt.npy") + data20_gt = np.concatenate((data20_gt, data20_gt_new)) + + print(f"Loaded!") + if SCALE: data10 /= SCALE data20 /= SCALE @@ -324,22 +475,25 @@ def OpenDataFiles(path, run_60, SCALE): return splitTrainVal(train_path, [data10, data20], data20_gt) -def OpenDataFilesTest(path, run_60, SCALE, true_scale=False): +def OpenDataFilesTest( + path: str, run_60: bool, SCALE: int, true_scale: bool = False +) -> Tuple[np.ndarray, np.ndarray]: + """From path with patches, return numpy array with patches used for inference.""" if not SCALE: SCALE = 1 - data10 = np.load(path + '/data10.npy') - data20 = np.load(path + '/data20.npy') + data10 = np.load(path + "/data10.npy") + data20 = np.load(path + "/data20.npy") data10 /= SCALE data20 /= SCALE if run_60: - data60 = 
np.load(path + '/data60.npy') + data60 = np.load(path + "/data60.npy") data60 /= SCALE train = [data10, data20, data60] else: train = [data10, data20] - with open(path + '/roi.json') as data_file: + with open(path + "/roi.json") as data_file: data = json.load(data_file) image_size = [(data[2] - data[0]), (data[3] - data[1])] @@ -350,39 +504,43 @@ def OpenDataFilesTest(path, run_60, SCALE, true_scale=False): return train, image_size -def downPixelAggr(img, SCALE=2): - from scipy import signal - import skimage.measure - from scipy.ndimage.filters import gaussian_filter - +def downPixelAggr(img: np.ndarray, SCALE: int = 2) -> np.ndarray: + """Down-scale array by scale factor. Apply gaussian blur and block reduce. """ if len(img.shape) == 2: img = np.expand_dims(img, axis=-1) img_blur = np.zeros(img.shape) # Filter the image with a Gaussian filter for i in range(0, img.shape[2]): - img_blur[:, :, i] = gaussian_filter(img[:, :, i], 1/SCALE) + img_blur[:, :, i] = gaussian_filter(img[:, :, i], 1 / SCALE) # New image dims - new_dims = tuple(s//SCALE for s in img.shape) - img_lr = np.zeros(new_dims[0:2]+(img.shape[-1],)) + new_dims = tuple(s // SCALE for s in img.shape) + img_lr = np.zeros(new_dims[0:2] + (img.shape[-1],)) # Iterate through all the image channels with avg pooling (pixel aggregation) for i in range(0, img.shape[2]): - img_lr[:, :, i] = skimage.measure.block_reduce(img_blur[:, :, i], (SCALE, SCALE), np.mean) + reduced = skimage.measure.block_reduce( + img_blur[:, :, i], (SCALE, SCALE), np.mean + ) + img_lr_shape = img_lr[:, :, i].shape + if reduced.shape != img_lr_shape: + reduced = reduced[: img_lr_shape[0], : img_lr_shape[1]] + img_lr[:, :, i] = reduced return np.squeeze(img_lr) -def recompose_images(a, border, size=None): +def recompose_images(a: np.ndarray, border: int, size=None) -> np.ndarray: + """ From array with patches recompose original image.""" if a.shape[0] == 1: images = a[0] else: # # This is done because we do not mirror the data at the image 
border # size = [s - border * 2 for s in size] - patch_size = a.shape[2]-border*2 + patch_size = a.shape[2] - border * 2 # print('Patch has dimension {}'.format(patch_size)) # print('Prediction has shape {}'.format(a.shape)) - x_tiles = int(ceil(size[1]/float(patch_size))) - y_tiles = int(ceil(size[0]/float(patch_size))) + x_tiles = int(ceil(size[1] / float(patch_size))) + y_tiles = int(ceil(size[0] / float(patch_size))) # print('Tiles per image {} {}'.format(x_tiles, y_tiles)) # Initialize image @@ -399,7 +557,14 @@ def recompose_images(a, border, size=None): xpoint = x * patch_size if xpoint > size[1] - patch_size: xpoint = size[1] - patch_size - images[:, ypoint:ypoint+patch_size, xpoint:xpoint+patch_size] = a[current_patch, :, border:a.shape[2]-border, border:a.shape[3]-border] + images[ + :, ypoint : ypoint + patch_size, xpoint : xpoint + patch_size + ] = a[ + current_patch, + :, + border : a.shape[2] - border, + border : a.shape[3] - border, + ] current_patch += 1 return images.transpose((1, 2, 0))