102 changes: 102 additions & 0 deletions docs/cli/classifier.py
@@ -0,0 +1,102 @@
import os
import sys
import glob
import re
from pathlib import Path
from collections import defaultdict

# The repository root is assumed to be one level above the current working
# directory, so the generator must be run from a direct child of the root.
lib_path = Path.cwd().parent
sys.path.insert(0, str(lib_path))


def classify_file_category(path):
    """Map a source file to the trainer that uses it: "linear", "nn", or "general"."""
    relative_path = Path(path).relative_to(lib_path)
    # Drop the top-level package directory (e.g. libmultilabel/) so that
    # "libmultilabel/linear/..." and "linear_trainer.py" both start with "linear".
    filename = "/".join(relative_path.parts[1:]) or relative_path.as_posix()

    if filename.startswith("linear"):
        return "linear"
    if filename.startswith(("torch", "nn")):
        return "nn"
    return "general"


def fetch_option_flags(flags):
    """Normalize raw flags into records with the bare config key ("instruction")."""
    flag_list = []

    for flag in flags:
        flag_list.append(
            {
                "name": flag["name"].replace("\\", ""),
                "instruction": flag["name"].split("-")[-1],
                "description": flag["description"],
            }
        )

    return flag_list


def fetch_all_files():
    """Collect the CLI entry points and every Python file under libmultilabel/."""
    main_files = [
        os.path.join(lib_path, "main.py"),
        os.path.join(lib_path, "linear_trainer.py"),
        os.path.join(lib_path, "torch_trainer.py"),
    ]
    lib_files = glob.glob(os.path.join(lib_path, "libmultilabel/**/*.py"), recursive=True)
    file_set = set(map(os.path.abspath, main_files + lib_files))
    return file_set


def find_config_usages_in_file(file_path, allowed_keys, category_set):
    """Record every `config.<key>` access in file_path under the file's category."""
    pattern = re.compile(r"\bconfig\.([a-zA-Z_][a-zA-Z0-9_]*)")

    with open(file_path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # In main.py, skip everything before `def main(` so that only usages in
    # the entry point are counted, not the argument declarations above it.
    if file_path.endswith("main.py"):
        for idx in range(len(lines)):
            if lines[idx].startswith("def main("):
                lines = lines[idx:]
                break
    all_str = " ".join(lines)
    matches = set(pattern.findall(all_str)) & allowed_keys

    category = classify_file_category(file_path)
    for key in matches:
        category_set[category].add(key)


def move_duplicates_together(data):
    """Promote keys used by more than one category into "general"."""
    duplicates = (data["general"] & data["linear"]) | (data["general"] & data["nn"]) | (data["linear"] & data["nn"])
    data["general"].update(duplicates)
    data["linear"] -= duplicates
    data["nn"] -= duplicates


def classify(raw_flags):
    """Group flags into general/linear/nn by where their config keys are used."""
    category_set = {"general": set(), "linear": set(), "nn": set()}

    flags = fetch_option_flags(raw_flags)
    allowed_keys = set(flag["instruction"] for flag in flags)
    file_set = fetch_all_files()

    for file_path in file_set:
        find_config_usages_in_file(file_path, allowed_keys, category_set)

    move_duplicates_together(category_set)

    result = defaultdict(list)
    for flag in raw_flags:
        instr = flag["name"].replace("\\", "").split("-")[-1]
        flag_name = flag["name"].replace("--", r"\-\-")

        matched = False
        for category, keys in category_set.items():
            if instr in keys:
                result[category].append({"name": flag_name, "description": flag["description"]})
                matched = True
                break

        # Flags whose keys never appear in any scanned file default to "general".
        if not matched:
            result["general"].append({"name": flag_name, "description": flag["description"]})

    return result
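
For orientation, a minimal sketch of how classify is meant to be consumed. The flag dicts below are hypothetical stand-ins for parser.flags built in genflags.py; the RST-escaped names are an assumption consistent with the backslash stripping in fetch_option_flags:

    raw_flags = [
        {"name": r"\-\-linear", "description": "Train linear model"},
        {"name": r"\-\-epochs", "description": "The number of epochs to train"},
    ]
    groups = classify(raw_flags)
    # groups is a defaultdict mapping "general" / "linear" / "nn" to lists of
    # {"name": ..., "description": ...} records; a key referenced only by
    # files under libmultilabel/nn or by torch_trainer.py lands in "nn".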
53 changes: 40 additions & 13 deletions docs/cli/genflags.py
@@ -2,8 +2,11 @@
 import os

 sys.path.insert(1, os.path.join(sys.path[0], "..", ".."))

 import main

+from classifier import classify


 class FakeParser(dict):
     def __init__(self):
@@ -29,21 +32,45 @@ def add_argument(
parser.add_argument("-c", "--config", help="Path to configuration file")
main.add_all_arguments(parser)

classified = classify(parser.flags)


def width_title(key, title):
return max(map(lambda f: len(f[key]), classified[title]))

def width(key):
return max(map(lambda f: len(f[key]), parser.flags))

def print_table(title, flags, intro):
print()
print(intro)
print()

wn = width("name")
wd = width("description")
wn = width_title("name", title)
wd = width_title("description", title)

print(
"""..
Do not modify this file. This file is generated by genflags.py.\n"""
print("=" * wn, "=" * wd)
print("Name".ljust(wn), "Description".ljust(wd))
print("=" * wn, "=" * wd)
for flag in flags:
print(flag["name"].ljust(wn), flag["description"].ljust(wd))
print("=" * wn, "=" * wd)
print()


print_table(
"general",
classified["general"],
intro="**General options**:\n\
Common configurations shared across both linear and neural network trainers.",
)
print_table(
"linear",
classified["linear"],
intro="**Linear options**:\n\
Configurations specific to linear trainer.",
)
print_table(
"nn",
classified["nn"],
intro="**Neural network options**:\n\
Configurations specific to torch (neural networks) trainer.",
)
print("=" * wn, "=" * wd)
print("Name".ljust(wn), "Description".ljust(wd))
print("=" * wn, "=" * wd)
for flag in parser.flags:
print(flag["name"].ljust(wn), flag["description"].ljust(wd))
print("=" * wn, "=" * wd)
1 change: 1 addition & 0 deletions docs/conf.py
@@ -49,6 +49,7 @@
"examples_dirs": "./examples", # path to your example scripts
"gallery_dirs": "auto_examples", # path to where to save gallery generated output
"plot_gallery": False,
"write_computation_times": False,
}

# bibtex files
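A note on the conf.py change: "plot_gallery": False already keeps sphinx-gallery from executing the example scripts, and "write_computation_times": False additionally suppresses the generated computation-times pages, assuming the behavior documented for that sphinx-gallery option; the diff itself only passes the key through sphinx_gallery_conf.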
123 changes: 68 additions & 55 deletions main.py
@@ -11,21 +11,50 @@


 def add_all_arguments(parser):
-    # path / directory
     parser.add_argument(
-        "--result_dir", default="./runs", help="The directory to save checkpoints and logs (default: %(default)s)"
+        "-h",
+        "--help",
+        action="help",
+        help="Quickstart: https://www.csie.ntu.edu.tw/~cjlin/libmultilabel/cli/quickstart.html",
     )

+    parser.add_argument("--seed", type=int, help="Random seed (default: %(default)s)")

+    # choose model (linear / nn)
+    parser.add_argument("--linear", action="store_true", help="Train linear model")

+    # others
+    parser.add_argument("--cpu", action="store_true", help="Disable CUDA")
+    parser.add_argument("--silent", action="store_true", help="Enable silent mode")
+    parser.add_argument(
+        "--data_workers", type=int, default=4, help="Use multi-cpu core for data pre-processing (default: %(default)s)"
+    )
+    parser.add_argument(
+        "--embed_cache_dir",
+        type=str,
+        help="For parameter search only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--eval", action="store_true", help="Only run evaluation on the test set (default: %(default)s)"
+    )
+    parser.add_argument("--checkpoint_path", help="The checkpoint to warm-up with (default: %(default)s)")

     # data
-    parser.add_argument("--data_name", default="unnamed_data", help="Dataset name (default: %(default)s)")
+    parser.add_argument(
+        "--data_name",
+        default="unnamed_data",
+        help="Dataset name for generating the output directory (default: %(default)s)",
+    )
     parser.add_argument("--training_file", help="Path to training data (default: %(default)s)")
     parser.add_argument("--val_file", help="Path to validation data (default: %(default)s)")
-    parser.add_argument("--test_file", help="Path to test data (default: %(default)s")
+    parser.add_argument("--test_file", help="Path to test data (default: %(default)s)")
+    parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")
     parser.add_argument(
         "--val_size",
         type=float,
         default=0.2,
-        help="Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set (default: %(default)s).",
+        help="Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set (default: %(default)s)",
     )
     parser.add_argument(
         "--min_vocab_freq",
@@ -67,8 +96,24 @@ def add_all_arguments(parser):
help="Whether to add the special tokens for inputs of the transformer-based language model. (default: %(default)s)",
)

# model
parser.add_argument("--model_name", default="unnamed_model", help="Model to be used (default: %(default)s)")
parser.add_argument(
"--init_weight", default="kaiming_uniform", help="Weight initialization to be used (default: %(default)s)"
)
parser.add_argument(
"--loss_function", default="binary_cross_entropy_with_logits", help="Loss function (default: %(default)s)"
)

# pretrained vocab / embeddings
parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabuaries (default: %(default)s)")
parser.add_argument(
"--embed_file",
type=str,
help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)",
)

# train
parser.add_argument("--seed", type=int, help="Random seed (default: %(default)s)")
parser.add_argument(
"--epochs", type=int, default=10000, help="The number of epochs to train (default: %(default)s)"
)
@@ -109,15 +154,6 @@ def add_all_arguments(parser):
help="Whether the embeddings of each word is normalized to a unit vector (default: %(default)s)",
)

# model
parser.add_argument("--model_name", default="unnamed_model", help="Model to be used (default: %(default)s)")
parser.add_argument(
"--init_weight", default="kaiming_uniform", help="Weight initialization to be used (default: %(default)s)"
)
parser.add_argument(
"--loss_function", default="binary_cross_entropy_with_logits", help="Loss function (default: %(default)s)"
)

# eval
parser.add_argument(
"--eval_batch_size", type=int, default=256, help="Size of evaluating batches (default: %(default)s)"
@@ -138,28 +174,6 @@
"--val_metric", default="P@1", help="The metric to select the best model for testing (default: %(default)s)"
)

# pretrained vocab / embeddings
parser.add_argument("--vocab_file", type=str, help="Path to a file holding vocabuaries (default: %(default)s)")
parser.add_argument(
"--embed_file", type=str, help="Path to a file holding pre-trained embeddings or the name of the pretrained GloVe embedding (default: %(default)s)"
)
parser.add_argument("--label_file", type=str, help="Path to a file holding all labels (default: %(default)s)")

# log
parser.add_argument(
"--save_k_predictions",
type=int,
nargs="?",
const=100,
default=0,
help="Save top k predictions on test set. k=%(const)s if not specified. (default: %(default)s)",
)
parser.add_argument(
"--predict_out_path",
default="./predictions.txt",
help="Path to the output file holding label results (default: %(default)s)",
)

# auto-test
parser.add_argument(
"--limit_train_batches",
@@ -180,24 +194,27 @@
help="Percentage of test dataset to use for auto-testing (default: %(default)s)",
)

# others
parser.add_argument("--cpu", action="store_true", help="Disable CUDA")
parser.add_argument("--silent", action="store_true", help="Enable silent mode")
# log
parser.add_argument(
"--data_workers", type=int, default=4, help="Use multi-cpu core for data pre-processing (default: %(default)s)"
"--save_k_predictions",
type=int,
nargs="?",
const=100,
default=0,
help="Save top k predictions on test set. k=%(const)s if not specified. (default: %(default)s)",
)
parser.add_argument(
"--embed_cache_dir",
type=str,
help="For parameter search only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)",
"--predict_out_path",
default="./predictions.txt",
help="Path to the output file holding label results (default: %(default)s)",
)

# path / directory
parser.add_argument(
"--eval", action="store_true", help="Only run evaluation on the test set (default: %(default)s)"
"--result_dir", default="./runs", help="The directory to save checkpoints and logs (default: %(default)s)"
)
parser.add_argument("--checkpoint_path", help="The checkpoint to warm-up with (default: %(default)s)")

# linear options
parser.add_argument("--linear", action="store_true", help="Train linear model")
parser.add_argument(
"--data_format",
type=str,
@@ -224,7 +241,10 @@
"--tree_max_depth", type=int, default=10, help="Maximum depth of the tree (default: %(default)s)"
)
parser.add_argument(
"--tree_ensemble_models", type=int, default=1, help="Number of models in the tree ensemble (default: %(default)s)"
"--tree_ensemble_models",
type=int,
default=1,
help="Number of models in the tree ensemble (default: %(default)s)",
)
parser.add_argument(
"--beam_width",
Expand All @@ -239,13 +259,6 @@ def add_all_arguments(parser):
default=8,
help="the maximal number of labels inside a cluster (default: %(default)s)",
)
parser.add_argument(
"-h",
"--help",
action="help",
help="If you are trying to specify network config such as dropout or activation or config of the learning rate scheduler, use a yaml file instead. "
"See example configs in example_config",
)


def get_config():
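
With these changes applied, plausible ways to exercise the two scripts, where the output file name and working directory are assumptions rather than part of this diff:

    # CLI help now leads with the quickstart link instead of the yaml hint:
    python3 main.py -h

    # Regenerate the grouped flag tables; run from docs/ so that
    # Path.cwd().parent in classifier.py resolves to the repository root:
    python3 cli/genflags.py > cli/flags.rst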