diff --git a/.gitignore b/.gitignore index 4983e60..a59e560 100755 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,107 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.DS_Store + logs -codes/__pycache__ \ No newline at end of file diff --git a/README.md b/README.md index c983f16..e946b4b 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,32 @@ Official implementation of [Query2box: Reasoning over Knowledge Graphs in Vector [Hongyu Ren*](http://hyren.me), [Weihua Hu*](http://web.stanford.edu/~weihuahu/), [Jure Leskovec](https://cs.stanford.edu/people/jure/index.html), ICLR 2020. 
-## Requirements +## Installation + +To install in development mode, clone from GitHub +with the following: + +```bash +git clone https://github.com/hyren/query2box +cd query2box +pip install --editable . ``` -torch==1.2.0 -tensorboadX==1.6 + +`--editable` means that the code is symlinked into your +Python's `site-packages` so it doesn't need to be reinstalled +every time the code is changed. + +## Command Line Interface + +The `query2box` command line interface is installed automatically. It can be used +as follows: + +```bash +$ CUDA_VISIBLE_DEVICES=0 query2box --do_train --cuda --do_valid --do_test \ + --data_path data/FB15k --model BoxTransE -n 128 -b 512 -d 400 -g 24 -a 1.0 \ + -lr 0.0001 --max_steps 300000 --cpu_num 1 --test_batch_size 16 --center_reg 0.02 \ + --geo box --task 1c.2c.3c.2i.3i.ic.ci.2u.uc --stepsforpath 300000 --offset_deepsets inductive \ + --center_deepsets eleattention --print_on_screen ``` ## Run diff --git a/example.sh b/example.sh index af5e3dd..2a4a5b5 100755 --- a/example.sh +++ b/example.sh @@ -1,18 +1,18 @@ #!/bin/bash - CUDA_VISIBLE_DEVICES=0 python3.5 -u codes/run.py --do_train --cuda --do_valid --do_test \ + CUDA_VISIBLE_DEVICES=0 query2box --do_train --cuda --do_valid --do_test \ --data_path data/FB15k --model BoxTransE -n 128 -b 512 -d 400 -g 24 -a 1.0 \ -lr 0.0001 --max_steps 300000 --cpu_num 1 --test_batch_size 16 --center_reg 0.02 \ --geo box --task 1c.2c.3c.2i.3i.ic.ci.2u.uc --stepsforpath 300000 --offset_deepsets inductive --center_deepsets eleattention \ --print_on_screen - CUDA_VISIBLE_DEVICES=1 python3.5 -u codes/run.py --do_train --cuda --do_valid --do_test \ + CUDA_VISIBLE_DEVICES=1 query2box --do_train --cuda --do_valid --do_test \ --data_path data/FB15k-237 --model BoxTransE -n 128 -b 512 -d 400 -g 24 -a 1.0 \ -lr 0.0001 --max_steps 300000 --cpu_num 1 --test_batch_size 16 --center_reg 0.02 \ --geo box --task 1c.2c.3c.2i.3i.ic.ci.2u.uc --stepsforpath 300000 --offset_deepsets inductive --center_deepsets 
eleattention \ --print_on_screen - CUDA_VISIBLE_DEVICES=2 python3.5 -u codes/run.py --do_train --cuda --do_valid --do_test \ + CUDA_VISIBLE_DEVICES=2 query2box --do_train --cuda --do_valid --do_test \ --data_path data/NELL --model BoxTransE -n 128 -b 512 -d 400 -g 24 -a 1.0 \ -lr 0.0001 --max_steps 300000 --cpu_num 1 --test_batch_size 16 --center_reg 0.02 \ --geo box --task 1c.2c.3c.2i.3i.ic.ci.2u.uc --stepsforpath 300000 --offset_deepsets inductive --center_deepsets eleattention \ diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..e6728ac --- /dev/null +++ b/setup.cfg @@ -0,0 +1,46 @@ +########################## +# Setup.py Configuration # +########################## +[metadata] +name = query2box +version = 1.0.0 +long_description = file: README.md +long_description_content_type = text/markdown + +# URLs associated with the project +url = https://github.com/hyren/query2box +download_url = https://github.com/hyren/query2box/releases +project_urls = + Bug Tracker = https://github.com/hyren/query2box/issues + Source Code = https://github.com/hyren/query2box + +# Author information +author = Hongyu Ren +# author_email = ... 
+ +# License Information +license = MIT +license_file = LICENSE + +[options] +install_requires = + torch + tensorboardX + tqdm + +# Random options +zip_safe = false +include_package_data = True +python_requires = >=3.5 + +# Where is my code +packages = find: +package_dir = + = src + +[options.packages.find] +where = src + +[options.entry_points] +console_scripts = + query2box = query2box.run:main diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..a78fbfd --- /dev/null +++ b/setup.py @@ -0,0 +1,8 @@ +# -*- coding: utf-8 -*- + +"""Setup module.""" + +import setuptools + +if __name__ == '__main__': + setuptools.setup() diff --git a/src/query2box/__main__.py b/src/query2box/__main__.py new file mode 100644 index 0000000..df412a6 --- /dev/null +++ b/src/query2box/__main__.py @@ -0,0 +1,4 @@ +from query2box.run import main + +if __name__ == '__main__': + main() diff --git a/codes/dataloader.py b/src/query2box/dataloader.py similarity index 100% rename from codes/dataloader.py rename to src/query2box/dataloader.py diff --git a/codes/model.py b/src/query2box/model.py similarity index 98% rename from codes/model.py rename to src/query2box/model.py index f44d1a5..6d4bb40 100644 --- a/codes/model.py +++ b/src/query2box/model.py @@ -10,10 +10,11 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader -from dataloader import * +from query2box.dataloader import * import random import pickle import math + def Identity(x): return x diff --git a/codes/run.py b/src/query2box/run.py similarity index 99% rename from codes/run.py rename to src/query2box/run.py index f293b5a..51c7a8f 100644 --- a/codes/run.py +++ b/src/query2box/run.py @@ -13,10 +13,11 @@ import numpy as np import torch +from tqdm import tqdm, trange from torch.utils.data import DataLoader -from model import Query2box -from dataloader import * +from query2box.model import Query2box +from query2box.dataloader import * from tensorboardX import SummaryWriter import 
time import pickle @@ -172,8 +173,11 @@ def log_metrics(mode, step, metrics): ''' for metric in metrics: logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric])) - -def main(args): + +def main(): + main_helper(parse_args()) + +def main_helper(args): set_global_seed(args.seed) args.test_batch_size = 1 assert args.bn in ['no', 'before', 'after'] @@ -500,7 +504,7 @@ def main(args): num_params += np.prod(param.size()) logging.info('Parameter Number: %d' % num_params) - if args.cuda: + if args.cuda and torch.cuda.is_available(): query2box = query2box.cuda() if args.do_train: @@ -825,7 +829,7 @@ def evaluate_train(): else: begin_pq_step = args.max_steps - args.stepsforpath #Training Loop - for step in range(init_step, args.max_steps): + for step in trange(init_step, args.max_steps, desc='Training', unit='step', unit_scale=True): # print ("begining training step", step) # if step == 100: # exit(-1) @@ -901,7 +905,7 @@ def evaluate_train(): training_logs = [] if args.do_valid and step % args.valid_steps == 0: - logging.info('Evaluating on Valid Dataset...') + tqdm.write('Evaluating on Valid Dataset...') evaluate_val() save_variable_list = { @@ -933,4 +937,4 @@ def evaluate_train(): if __name__ == '__main__': - main(parse_args()) \ No newline at end of file + main() \ No newline at end of file