diff --git a/README.md b/README.md
index c41b0b0..e97f4ae 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,21 @@
 # HoVpred
+## Installation
 
-usage:
-```python
+The environment requires **Linux with CUDA 11.0**.
+
+```bash
+conda env create -f environment.yml
+conda activate hovpred
+```
+
+## Data
+
+The training data originates from NIST and DIPPR databases, which are not freely available. The CSV files in `data/` contain only the molecular identifiers (SMILES) and temperatures, with the enthalpy of vaporization values redacted. As a result, **only the prediction workflow can be reproduced** using the provided pre-trained model weights; model training and cross-validation require access to the original databases.
+
+## Usage
+
+```
 python main.py [-h] [-predict] [-watsoneq] [-K_fold] [-maxatoms MAXATOMS]
  [-lr LR] [-epoch EPOCH] [-batchsize BATCHSIZE] [-layers LAYERS] [-heads HEADS] [-residcon] [-explicitH] [-dropout DROPOUT]
@@ -10,8 +23,9 @@ python main.py [-h] [-predict] [-watsoneq] [-K_fold] [-maxatoms MAXATOMS]
  [-loss LOSS] [-sw_thr SW_THR] [-sw_decay SW_DECAY]
 ```
 
-optional arguments:
-```python
+Optional arguments:
+
+```
  -h, --help show this help message and exit
  -predict If specified, prediction is carried out (default=False)
@@ -35,14 +49,17 @@ optional arguments:
  kl_div_normal
 ```
+### Prediction
+
+Prepare a `molecules_to_predict.csv` file with two columns: `smiles` and `temperature`, then run:
-example commands:
-```python
-# HoV prediction with uncertainty quantification - 'molecules_to_predict.csv' file is needed as an input which contains two columns: 'smiles' and 'temperature'
+```bash
 python main.py -predict -modelname best_211007 -loss kl_div_normal
+```
-# model training
-python main.py -modelname test_model -loss kl_div_normal
-# non-default hyperparameters can also be tested by adding more arguments
+### Training (requires NIST/DIPPR data)
+```bash
+python main.py -modelname test_model -loss kl_div_normal
 ```
+non-default hyperparameters can also be tested by adding more arguments
\ No newline at end of file
diff --git a/data/README.md b/data/README.md
index c61a3e3..0eea2c7 100644
--- a/data/README.md
+++ b/data/README.md
@@ -1,5 +1,6 @@
 # HoV databases
 
-- 'Data.csv': The dataframe with predefined training/validation/test set splits (to do data splits consistent with those in literature and compare the accuracy with literature)
-- 'Data_for_kfold.csv' for 10-fold cross-validation: The 'main.py' performs random data split into 10 training/validation folds + one held-out test set
-- Values from NIST-WTT were redacted
+- `Data.csv`: The dataframe with predefined training/validation/test set splits (to do data splits consistent with those in literature and compare the accuracy with literature)
+- `Data_for_kfold.csv` for 10-fold cross-validation: `main.py` performs random data split into 10 training/validation folds + one held-out test set
+
+**Note:** The enthalpy of vaporization values originating from NIST and DIPPR have been redacted because these databases are not freely redistributable. The CSV files retain SMILES and temperature columns so that the file structure is preserved, but training and cross-validation cannot be run without populating the missing values from the original sources.
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000..be7bbf3
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,228 @@
+name: hovpred
+channels:
+  - dglteam
+  - conda-forge
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=4.5=1_gnu
+  - anyio=3.6.1=py38h578d9bd_0
+  - argon2-cffi=21.3.0=pyhd8ed1ab_0
+  - argon2-cffi-bindings=21.2.0=py38h7f8727e_0
+  - asttokens=2.0.5=pyhd8ed1ab_0
+  - attrs=21.4.0=pyhd8ed1ab_0
+  - babel=2.10.1=pyhd8ed1ab_0
+  - backcall=0.2.0=pyh9f0ad1d_0
+  - backports=1.0=py_2
+  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
+  - beautifulsoup4=4.11.1=pyha770c72_0
+  - blas=1.0=mkl
+  - bleach=5.0.0=pyhd8ed1ab_0
+  - boost=1.74.0=py38hc10631b_3
+  - boost-cpp=1.74.0=h312852a_4
+  - brotlipy=0.7.0=py38h497a2fe_1001
+  - bzip2=1.0.8=h7f98852_4
+  - ca-certificates=2022.5.18.1=ha878542_0
+  - cairo=1.16.0=h6cf1ce9_1008
+  - certifi=2022.5.18.1=py38h578d9bd_0
+  - cffi=1.15.0=py38hd667e15_1
+  - cryptography=37.0.1=py38h9ce1e76_0
+  - cycler=0.10.0=py_2
+  - decorator=5.1.1=pyhd8ed1ab_0
+  - defusedxml=0.7.1=pyhd8ed1ab_0
+  - dgl-cuda11.0=0.7.0=py38_0
+  - entrypoints=0.4=pyhd8ed1ab_0
+  - executing=0.8.3=pyhd8ed1ab_0
+  - flit-core=3.7.1=pyhd8ed1ab_0
+  - fontconfig=2.13.1=hba837de_1005
+  - freetype=2.10.4=h0708190_1
+  - gettext=0.19.8.1=h0b5b191_1005
+  - icu=68.1=h58526e2_0
+  - importlib-metadata=4.11.4=py38h578d9bd_0
+  - importlib_metadata=4.11.4=hd8ed1ab_0
+  - importlib_resources=5.7.1=pyhd8ed1ab_1
+  - intel-openmp=2021.3.0=h06a4308_3350
+  - ipykernel=5.5.5=py38hd0cf306_0
+  - ipython=8.4.0=py38h578d9bd_0
+  - ipython_genutils=0.2.0=py_1
+  - jbig=2.1=h7f98852_2003
+  - jedi=0.18.1=py38h578d9bd_1
+  - jinja2=3.1.2=pyhd8ed1ab_1
+  - joblib=1.0.1=pyhd8ed1ab_0
+  - jpeg=9d=h36c2ea0_0
+  - json5=0.9.5=pyh9f0ad1d_0
+  - jsonschema=4.6.0=pyhd8ed1ab_0
+  - jupyter_client=7.0.6=pyhd8ed1ab_0
+  - jupyter_core=4.10.0=py38h578d9bd_0
+  - jupyter_server=1.17.1=pyhd8ed1ab_0
+  - jupyterlab=3.4.3=pyhd8ed1ab_0
+  - jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
+  - jupyterlab_server=2.14.0=pyhd8ed1ab_0
+  - kiwisolver=1.3.1=py38h1fd1430_1
+  - lcms2=2.12=hddcbb42_0
+  - ld_impl_linux-64=2.35.1=h7274673_9
+  - lerc=2.2.1=h9c3ff4c_0
+  - libblas=3.9.0=11_linux64_mkl
+  - libcblas=3.9.0=11_linux64_mkl
+  - libdeflate=1.7=h7f98852_5
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=9.3.0=h5101ec6_17
+  - libgfortran-ng=7.5.0=ha8ba4b0_17
+  - libgfortran4=7.5.0=ha8ba4b0_17
+  - libglib=2.68.3=h3e27bee_0
+  - libgomp=9.3.0=h5101ec6_17
+  - libiconv=1.16=h516909a_0
+  - liblapack=3.9.0=11_linux64_mkl
+  - libpng=1.6.37=h21135ba_2
+  - libsodium=1.0.18=h36c2ea0_1
+  - libstdcxx-ng=9.3.0=hd4cf53a_17
+  - libtiff=4.3.0=hf544144_1
+  - libuuid=2.32.1=h7f98852_1000
+  - libwebp-base=1.2.0=h7f98852_2
+  - libxcb=1.13=h7f98852_1003
+  - libxml2=2.9.12=h72842e0_0
+  - lz4-c=1.9.3=h9c3ff4c_1
+  - markupsafe=2.0.1=py38h497a2fe_0
+  - matplotlib-base=3.3.4=py38h0efea84_0
+  - matplotlib-inline=0.1.3=pyhd8ed1ab_0
+  - mistune=0.8.4=py38h497a2fe_1004
+  - mkl=2021.3.0=h06a4308_520
+  - mkl-service=2.4.0=py38h7f8727e_0
+  - mkl_fft=1.3.0=py38h42c9631_2
+  - mkl_random=1.2.2=py38h51133e4_0
+  - nbclassic=0.3.7=pyhd8ed1ab_0
+  - nbclient=0.6.4=pyhd8ed1ab_1
+  - nbconvert=6.5.0=pyhd8ed1ab_0
+  - nbconvert-core=6.5.0=pyhd8ed1ab_0
+  - nbconvert-pandoc=6.5.0=pyhd8ed1ab_0
+  - nbformat=5.4.0=pyhd8ed1ab_0
+  - ncurses=6.2=he6710b0_1
+  - nest-asyncio=1.5.5=pyhd8ed1ab_0
+  - networkx=2.6.2=pyhd3eb1b0_0
+  - notebook=6.4.12=pyha770c72_0
+  - notebook-shim=0.1.0=pyhd8ed1ab_0
+  - olefile=0.46=pyh9f0ad1d_1
+  - openjpeg=2.4.0=hb52868f_1
+  - openssl=1.1.1k=h7f98852_0
+  - packaging=21.3=pyhd8ed1ab_0
+  - pandas=1.3.1=py38h1abd341_0
+  - pandoc=2.18=ha770c72_0
+  - pandocfilters=1.5.0=pyhd8ed1ab_0
+  - parso=0.8.3=pyhd8ed1ab_0
+  - pcre=8.45=h9c3ff4c_0
+  - pexpect=4.8.0=pyh9f0ad1d_2
+  - pickleshare=0.7.5=py_1003
+  - pillow=8.3.1=py38h8e6f84c_0
+  - pip=21.0.1=py38h06a4308_0
+  - pixman=0.40.0=h36c2ea0_0
+  - prometheus_client=0.14.1=pyhd8ed1ab_0
+  - prompt-toolkit=3.0.29=pyha770c72_0
+  - pthread-stubs=0.4=h36c2ea0_1001
+  - ptyprocess=0.7.0=pyhd3deb0d_0
+  - pure_eval=0.2.2=pyhd8ed1ab_0
+  - pycairo=1.20.1=py38hf61ee4a_0
+  - pycparser=2.21=pyhd8ed1ab_0
+  - pygments=2.12.0=pyhd8ed1ab_0
+  - pyopenssl=22.0.0=pyhd8ed1ab_0
+  - pyparsing=2.4.7=pyh9f0ad1d_0
+  - pyrsistent=0.18.0=py38heee7806_0
+  - pysocks=1.7.1=py38h578d9bd_5
+  - python=3.8.11=h12debd9_0_cpython
+  - python-dateutil=2.8.2=pyhd8ed1ab_0
+  - python-fastjsonschema=2.15.3=pyhd8ed1ab_0
+  - python_abi=3.8=2_cp38
+  - pytz=2021.1=pyhd8ed1ab_0
+  - pyzmq=19.0.2=py38ha71036d_2
+  - rdkit=2021.03.4=py38hf8acc3d_0
+  - readline=8.1=h27cfd23_0
+  - reportlab=3.5.68=py38hadf75a6_0
+  - scikit-learn=0.24.2=py38hdc147b9_0
+  - scipy=1.6.2=py38had2a1c9_1
+  - send2trash=1.8.0=pyhd8ed1ab_0
+  - setuptools=52.0.0=py38h06a4308_0
+  - sniffio=1.2.0=py38h578d9bd_3
+  - soupsieve=2.3.1=pyhd8ed1ab_0
+  - sqlalchemy=1.3.23=py38h497a2fe_0
+  - sqlite=3.36.0=hc218d9a_0
+  - stack_data=0.2.0=pyhd8ed1ab_0
+  - terminado=0.15.0=py38h578d9bd_0
+  - threadpoolctl=2.2.0=pyh8a188c0_0
+  - tinycss2=1.1.1=pyhd8ed1ab_0
+  - tk=8.6.10=hbc83047_0
+  - tornado=6.1=py38h497a2fe_1
+  - traitlets=5.2.2.post1=pyhd8ed1ab_0
+  - wcwidth=0.2.5=pyh9f0ad1d_2
+  - webencodings=0.5.1=py_1
+  - websocket-client=1.3.2=pyhd8ed1ab_0
+  - wheel=0.37.0=pyhd3eb1b0_0
+  - xorg-kbproto=1.0.7=h7f98852_1002
+  - xorg-libice=1.0.10=h7f98852_0
+  - xorg-libsm=1.2.3=hd9c2040_1000
+  - xorg-libx11=1.7.2=h7f98852_0
+  - xorg-libxau=1.0.9=h7f98852_0
+  - xorg-libxdmcp=1.1.3=h7f98852_0
+  - xorg-libxext=1.3.4=h7f98852_1
+  - xorg-libxrender=0.9.10=h7f98852_1003
+  - xorg-renderproto=0.11.1=h7f98852_1002
+  - xorg-xextproto=7.3.0=h7f98852_1002
+  - xorg-xproto=7.0.31=h7f98852_1007
+  - xz=5.2.5=h7b6447c_0
+  - zeromq=4.3.4=h9c3ff4c_0
+  - zipp=3.8.0=pyhd8ed1ab_0
+  - zlib=1.2.11=h7b6447c_3
+  - zstd=1.5.0=ha95c52a_0
+  - pip:
+    - absl-py==0.15.0
+    - aqme==1.3.0
+    - ase==3.22.1
+    - astunparse==1.6.3
+    - cachetools==4.2.2
+    - cclib==1.7.2
+    - charset-normalizer==2.0.4
+    - cloudpickle==2.1.0
+    - flatbuffers==1.12
+    - gast==0.3.3
+    - google-auth==1.34.0
+    - google-auth-oauthlib==0.4.5
+    - google-pasta==0.2.0
+    - grpcio==1.32.0
+    - h5py==2.10.0
+    - idna==3.2
+    - keras==2.9.0
+    - keras-preprocessing==1.1.2
+    - libclang==14.0.6
+    - llvmlite==0.38.1
+    - markdown==3.3.4
+    - nfp==0.3.0
+    - numba==0.55.2
+    - numpy==1.19.5
+    - oauthlib==3.1.1
+    - opt-einsum==3.3.0
+    - periodictable==1.6.1
+    - progress==1.6
+    - protobuf==3.17.3
+    - pyasn1==0.4.8
+    - pyasn1-modules==0.2.8
+    - pyyaml==6.0
+    - requests==2.26.0
+    - requests-oauthlib==1.3.0
+    - rsa==4.7.2
+    - seaborn==0.11.2
+    - shap==0.40.0
+    - six==1.15.0
+    - slicer==0.0.7
+    - tensorboard==2.9.1
+    - tensorboard-data-server==0.6.1
+    - tensorboard-plugin-wit==1.8.0
+    - tensorflow==2.4.0
+    - tensorflow-addons==0.14.0
+    - tensorflow-estimator==2.4.0
+    - tensorflow-gpu==2.4.0
+    - tensorflow-io-gcs-filesystem==0.26.0
+    - termcolor==1.1.0
+    - tqdm==4.62.3
+    - typeguard==2.13.3
+    - typing-extensions==3.7.4.3
+    - urllib3==1.26.6
+    - werkzeug==2.0.1
+    - wrapt==1.12.1
\ No newline at end of file
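
Note on the prediction input: the README change in this diff asks for a `molecules_to_predict.csv` file with two columns, `smiles` and `temperature`. The following is a minimal sketch of how such a file might be produced with the Python standard library. The example molecules, the helper name `write_prediction_input`, and the assumption that temperatures are given in kelvin are illustrative only; the diff does not state the expected unit or show a sample file.

```python
import csv
from pathlib import Path

# Hypothetical example rows (not from the repository). Temperatures are
# assumed to be in kelvin; the README does not specify the unit.
ROWS = [
    ("CCO", 298.15),       # ethanol
    ("c1ccccc1", 298.15),  # benzene
    ("CCCCCC", 320.0),     # n-hexane
]


def write_prediction_input(path, rows):
    """Write a CSV with the two columns named in the README: smiles, temperature."""
    path = Path(path)
    with path.open("w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["smiles", "temperature"])
        writer.writerows(rows)
    return path


if __name__ == "__main__":
    out = write_prediction_input("molecules_to_predict.csv", ROWS)
    print(f"wrote {len(ROWS)} rows to {out}")
```

With the file in place, the prediction command from the README (`python main.py -predict -modelname best_211007 -loss kl_div_normal`) would presumably be run from the same directory.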