This is the first of the three repositories accompanying the paper *Induced Model Matching: Restricted Models Help Train Full-Featured Models* (NeurIPS 2024):
```bibtex
@inproceedings{muneeb2024induced,
  title = {Induced Model Matching: Restricted Models Help Train Full-Featured Models},
  author = {Usama Muneeb and Mesrob I Ohannessian},
  booktitle = {The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year = {2024},
  url = {https://openreview.net/forum?id=iW0wXE0VyR}
}
```

This repository demonstrates Induced Model Matching (IMM) in the training of a simple logistic regression classifier as a proof of concept. The full-featured dataset has three features.
Repositories for other implementations of IMM: IMM in Language Modeling | IMM in learning MDPs (REINFORCE)
If run with all the default options, `train_and_test.py` will train a model without any restricted model information and print the test accuracy after 5000 (default) epochs of training. To replicate each of the reported curves in the paper, options are provided in this file that can be set accordingly.
While we document the relevant options of `train_and_test.py` below, please do not run it directly: the plots require 300 Monte Carlo runs for each configuration, and running them sequentially is prohibitively slow. Instead, we provide a `run_all.sh` Bash script that parallelizes these 300 runs using multiprocessing.
To make the most of multiprocessing, you are encouraged to edit `PARALLEL_JOBS` inside `run_all.sh`, as per the provided instructions, before calling it as:

```shell
./run_all.sh
```

Alternatively, if you simply want to generate plots from cached files, you can skip directly to the plotting section below (cached CSV files for 300 runs are also provided in this repository).
**Important**
**Secondary Objective Coefficient.** One of the `--lambda_ratio` or `--lambda_param` parameters can be used to set the coefficient $\lambda$ of the secondary objective (i.e. IMM or noising). While `--lambda_param` sets $\lambda$ directly, `--lambda_ratio` specifies it relative to the primary objective, and `train_and_test.py` will determine the corresponding $\lambda$.
**Overall Objective.** The overall objective function minimized is the primary loss plus $\lambda$ times the secondary loss, i.e. $\mathcal{L}_{\text{primary}} + \lambda \, \mathcal{L}_{\text{secondary}}$.
**Default Coefficient.** If neither `--lambda_ratio` nor `--lambda_param` is set, a default value is used.
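As a hypothetical illustration of such a combined objective for binary logistic regression, the sketch below adds a $\lambda$-weighted secondary term to the usual negative log-likelihood. The function names and the dummy secondary loss are our own, not the repository's implementation:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def primary_loss(w, X, y):
    # Standard negative log-likelihood of logistic regression.
    p = sigmoid(X @ w)
    return -np.mean(y * np.log(p + 1e-12) + (1 - y) * np.log(1 - p + 1e-12))

def combined_loss(w, X, y, secondary, lam):
    # Overall objective: primary loss plus lambda times the secondary
    # (e.g. IMM or noising) loss. `secondary` is any callable w -> float.
    return primary_loss(w, X, y) + lam * secondary(w)

# Toy usage with a dummy secondary objective (illustrative only):
rng = np.random.default_rng(0)
X = rng.normal(size=(10, 3))
y = (X[:, 0] > 0).astype(float)
w = np.zeros(3)
dummy_secondary = lambda w: float(np.sum(w ** 2))
loss = combined_loss(w, X, y, dummy_secondary, lam=0.5)
```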
No Restricted Model Information (or $\lambda = 0$)

```shell
python train_and_test.py
```

The `--n_train` parameter sets the dataset size (if not specified, it defaults to 10), and an optional `--dataset_seed` value may also be specified for reproducibility (the default is no seed).
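As a rough illustration of what this baseline amounts to, the toy sketch below trains a plain logistic regression by gradient descent, with no restricted model involved. The data generation, learning rate, and helper names are our own assumptions, mirroring only the defaults mentioned above (`--n_train 10`, 5000 epochs), not the repository's actual code:

```python
import numpy as np

def train_logreg(X, y, lr=0.1, epochs=5000):
    # Plain logistic regression via gradient descent on the
    # negative log-likelihood -- no restricted model information.
    w = np.zeros(X.shape[1])
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-(X @ w)))
        w -= lr * (X.T @ (p - y)) / len(y)
    return w

rng = np.random.default_rng(0)   # analogous to --dataset_seed
X = rng.normal(size=(10, 3))     # n_train = 10, three features
y = (X @ np.array([1.0, -1.0, 0.5]) > 0).astype(float)
w = train_logreg(X, y)
train_acc = np.mean(((X @ w) > 0) == y)
```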
Incorporating Restricted Model via Interpolation
Since our baseline (i.e. Xie et al.) also discusses interpolation, we include it in our plots. Interpolation likewise trains without any restricted model information; the restricted model is instead incorporated during inference. Performing interpolation requires an additional `--transfer_type interpolate` argument. The provided command is:

```shell
python train_and_test.py --transfer_type interpolate --lambda_ratio -1
```

**Tip**
**Automatic coefficient selection.** Setting `--lambda_ratio` to any negative value instructs the code to automatically determine the coefficient using rules hardcoded in `train_and_test.py`. Unless we are running `train_and_test.py` to determine these rules via cross validation (more details in the paper appendices), we always use `--lambda_ratio -1`.
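Conceptually, interpolation as described above combines the two models only at inference time. A minimal sketch, assuming a mixing weight `lam` between the full and restricted models' predicted probabilities (our own notation, not the repository's implementation):

```python
import numpy as np

def interpolate_predict(p_full, p_restricted, lam):
    # Convex combination of the full-featured model's predicted
    # probability and the restricted model's predicted probability.
    return (1 - lam) * p_full + lam * p_restricted

p_full = np.array([0.9, 0.2, 0.6])        # full model P(y=1 | x)
p_restricted = np.array([0.7, 0.4, 0.5])  # restricted model P(y=1 | restricted x)
p_mix = interpolate_predict(p_full, p_restricted, lam=0.5)
```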
Incorporating Restricted Model via Noising
While our main baseline (i.e. Xie et al.) performs noising in the context of language modeling, we do something similar for logistic regression. We discuss in the paper how single-sample IMM is equivalent to the noising approach of Xie et al. In logistic regression, we do not perform multiset-based sampling; rather, we use a soft nearest neighbor density estimate (with a hyperparameter `--alpha`).
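One way to picture the soft nearest-neighbor weighting is a softmax over negatively scaled squared distances, where a large `alpha` concentrates all weight on the single nearest sample. This is a sketch under our own assumptions; the repository's estimator may differ in details such as the distance and normalization:

```python
import numpy as np

def soft_nn_weights(x, X, alpha):
    # Softmax over negative squared distances, scaled by alpha.
    # Small alpha spreads weight over many neighbors; large alpha
    # concentrates it on the single nearest neighbor.
    d2 = np.sum((X - x) ** 2, axis=1)
    logits = -alpha * d2
    logits -= logits.max()          # numerical stability
    w = np.exp(logits)
    return w / w.sum()

X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
x = np.array([0.1, 0.0])
w_soft = soft_nn_weights(x, X, alpha=1.0)    # spread over neighbors
w_hard = soft_nn_weights(x, X, alpha=100.0)  # nearly one-hot
```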
Just like for interpolation above, we pass `--lambda_ratio -1`. Instead of `--transfer_type interpolate`, we use the default value of `--transfer_type`, which is `imm`. Lastly, we override the default value of `--alpha` (which is 1) to a very large value, in our case 100, to mimic single-sample IMM (or "noising"):

```shell
python train_and_test.py --lambda_ratio -1 --alpha 100
```

Sampled IMM
Similar to the noising command above, except that we remove `--alpha 100` to revert to the default `--alpha 1`:

```shell
python train_and_test.py --lambda_ratio -1
```

Serialized IMM
For serialized IMM, we also override the `--imm_algorithm` parameter (whose default value is `'sampled'`):

```shell
python train_and_test.py --lambda_ratio -1 --imm_algorithm 'serialized'
```

We repeat the sampled IMM command above, but add an extra `--target_noise_ratio 0.2` flag to add 20% Bernoulli noise to the restricted targets (which so far were Bayes optimal).
**Caution**

The word "noise" here shouldn't be confused with the previous usage (which is one of the ways to transfer knowledge of the restricted model to the larger model). This Bernoulli noise is only meant to reduce the quality of the target model.
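The effect of `--target_noise_ratio` can be pictured with the following sketch, under the assumed semantics that each binary restricted target is flipped independently with the given probability (illustrative only, not the repository's code):

```python
import numpy as np

def add_bernoulli_noise(targets, noise_ratio, rng):
    # Flip each binary target independently with probability noise_ratio,
    # degrading the quality of the restricted model's targets.
    flips = rng.random(targets.shape) < noise_ratio
    return np.where(flips, 1 - targets, targets)

rng = np.random.default_rng(0)
targets = np.zeros(1000, dtype=int)         # all-zero targets for clarity
noisy = add_bernoulli_noise(targets, 0.2, rng)
flip_fraction = noisy.mean()                # roughly 0.2 in expectation
```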
(Sampled) IMM using Medium Quality Restricted Model
With 20% noise, what we get is the medium quality model:

```shell
python train_and_test.py --target_noise_ratio 0.2 --lambda_ratio -1
```

(Sampled) IMM using Low Quality Restricted Model
With 50% noise, what we get is the low quality model:

```shell
python train_and_test.py --target_noise_ratio 0.5 --lambda_ratio -1
```

The CSVs for the runs are generated by the `run_all.sh` script included in the repository. These CSVs can be used to generate the plots in the paper.
- To generate Figure 1, run `python plot_main.py`.
- To generate Figure 8, run `python plot_main.py --plot_type 'quality'`.
- To generate Figure 10, run `python plot_main.py --plot_type 'computation'`.
By default, the cached CSVs in the `csv_cached` directory will be used. This can be overridden by adding an extra `--csv_dir csv` flag.
This is used to obtain Figure 4 in the paper, which shows the performance of the induced model on the feature-restricted task after being trained with varying $\lambda$:

```shell
./schedulers/measure_restricted_perf.sh
```

To visualize the generated CSVs:
```shell
python plot_restricted_perf.py
```

Again, by default, the cached CSVs in the `csv_cached/restricted_perf` directory will be used. This can be overridden by adding an extra `--csv_dir csv/restricted_perf` flag.
For completeness, we also demonstrate how we obtain the plots in the paper that we use to create the hardcoded rules for automatically determining $\lambda$:

```shell
./schedulers/tune_lambda.sh
```

To visualize the generated CSVs:
```shell
python plot_lambda_search.py
```

Again, by default, the cached CSVs in the `csv_cached/lambda_search` directory will be used. This can be overridden by adding an extra `--csv_dir csv/lambda_search` flag.