
Code for our method Decomposition-Aware Distributional Optimization (DADO). Accompanies our paper, "Leveraging Discrete Function Decomposability for Scientific Design".


james-bowden/DADO


Decomposition-Aware Distributional Optimization (DADO)

Overview

This repository accompanies our paper, "Leveraging Discrete Function Decomposability for Scientific Design". It contains code to replicate our experiments and download the relevant datasets. Synthetic experiments can be run using experiments/synthetic/sweep_hp.py, and protein experiments using experiments/proteins/sweep_hp.py (see each file for its command-line arguments). Protein experiments additionally require fitting a decomposed predictive model beforehand. This model's hyperparameters can be selected by cross-validation with experiments/proteins/cv_decomposed_model.py, though we provide the hyperparameters resulting from our own sweep as defaults.

Distributional Optimization

In general, the most important hyperparameters for a distributional optimization algorithm, specifically an estimation of distribution algorithm (EDA), are the number of iterations to run, the number of samples to draw at each iteration, and the number of gradient steps to take on those samples before drawing fresh samples from the latest search distribution.

You could, in principle, run a while loop over iterations and inspect the results to decide when to stop, or stop once some condition is met. Unfortunately, the former isn't a great option in JAX: compilation generally requires knowing up front which operations will be performed, and JAX-compiled functions and loops don't support side effects, such as writing to a file or printing to the command line. Our code currently only supports pre-specifying the number of iterations. At each iteration we take one gradient step by default, but if sampling from your search model or evaluating samples is expensive, you might reasonably want to take more. This question has been studied extensively in reinforcement learning, particularly the continuum between fully online and partially offline policy optimization algorithms, which you might consult.
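The sketch below makes these three hyperparameters concrete. It is a minimal naive EDA, not DADO itself, with a factorized categorical search distribution, written in plain NumPy for readability. It assumes weights exp(Q / beta) and a weighted maximum-likelihood update. The function and argument names (`run_eda`, `n_iters`, `n_samples`, `n_grad_steps`) are illustrative and do not correspond to the repository's API. The loop length is fixed in advance. A compiled JAX version would need the same, for example one built on `jax.lax.scan`, whose length must be static.

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log-softmax over the last axis.
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def run_eda(objective, n_vars, n_states, n_iters=50, n_samples=256,
            n_grad_steps=1, lr=0.5, beta=1.0, seed=0):
    """Naive EDA with a factorized categorical search distribution (illustrative).

    objective maps an (n_samples, n_vars) integer array to (n_samples,) values.
    The iteration count is fixed up front, as a compiled JAX loop would require.
    """
    rng = np.random.default_rng(seed)
    logits = np.zeros((n_vars, n_states))  # uniform initial search distribution
    best_x, best_q = None, -np.inf
    for _ in range(n_iters):
        # Draw fresh samples from the latest search distribution.
        probs = np.exp(log_softmax(logits))
        x = np.stack([rng.choice(n_states, size=n_samples, p=probs[i])
                      for i in range(n_vars)], axis=1)
        q = np.asarray(objective(x), dtype=float)
        if q.max() > best_q:
            best_q, best_x = q.max(), x[q.argmax()].copy()
        # Shaped, self-normalized weights exp(Q / beta).
        w = np.exp((q - q.max()) / beta)
        w /= w.sum()
        # Weighted frequency of each state for each variable, shape (n_vars, n_states).
        target = np.einsum("n,nik->ik", w, np.eye(n_states)[x])
        # Gradient ascent on the weighted log-likelihood, reusing the same batch.
        for _ in range(n_grad_steps):
            logits = logits + lr * (target - np.exp(log_softmax(logits)))
    return logits, best_x, best_q
```

For example, `run_eda(lambda x: x.sum(axis=1), n_vars=20, n_states=2)` runs it on a toy OneMax objective. Raising `n_grad_steps` reuses each batch for more updates. That saves samples but risks over-fitting the search distribution to a single batch.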

As for how many samples to draw, the short answer is as many as your GPU can parallelize. Beyond that, there is a tradeoff. Drawing more samples per iteration makes each iteration slower, while drawing fewer lets you run more iterations. Notably, some problems become easy for a naive EDA if you draw many samples at each iteration. If your problem is easy in this way and you care primarily about solving it, then by all means do so. For more difficult problems, especially in combinatorially large design spaces, DADO should use samples more efficiently than a naive EDA. Because our method updates via weighted regression, the case to avoid is drawing too few samples and updating the search distribution too drastically. The distribution then collapses to a local optimum before it has a chance to explore the search space meaningfully.
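One heuristic way to spot the collapse described above is to monitor the effective sample size of the normalized weights. This is our suggestion, not something the paper prescribes. The sketch below uses Kish's formula, $(\sum_n w_n)^2 / \sum_n w_n^2$, with weights $w_n = \exp(Q_n / \beta)$. The function name is hypothetical.

```python
import numpy as np


def effective_sample_size(values, beta):
    """Kish effective sample size of the shaped weights exp(values / beta).

    Close to len(values) when the weights are even; close to 1 when a single
    sample dominates, i.e. the update essentially refits to one design.
    """
    values = np.asarray(values, dtype=float)
    w = np.exp((values - values.max()) / beta)  # shift for numerical stability
    return w.sum() ** 2 / np.sum(w ** 2)
```

An effective sample size of only a handful of designs per iteration means the weighted regression is essentially refitting the search distribution to those few samples. Drawing more samples, raising beta, or lowering the learning rate are the natural levers.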

The hyperparameters we sweep are the shaping function and the learning rate for gradient updates. We use exp(Q / beta) as the shaping function, so in practice we sweep its temperature beta. We sweep ranges of both because they are hard to set without empirical tuning.

Other hyperparameters include the search distribution architecture, the optimizer, and everything related to the decomposed predictive model. Our code provides CLI arguments to set the following: the seed, the number of replicates to run during the sweep phase, the number of replicates to run with the chosen hyperparameters, which GPUs to use on a single server, and how many processes to run in parallel. For smaller problems, we can often fit 5 or more runs on a single 16GB GPU.
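The sweep-then-replicate protocol described above can be sketched as follows. This is a hypothetical outline, not the logic of sweep_hp.py. `run_fn` stands in for one full optimization run, and selecting the configuration by mean score over the sweep replicates is an assumption.

```python
import itertools
import statistics


def sweep_then_replicate(run_fn, betas, lrs, n_sweep_reps=3, n_final_reps=5, seed=0):
    """Grid-search (beta, lr), then rerun the best setting on fresh seeds.

    run_fn(beta, lr, seed) -> float is a hypothetical callable that performs one
    optimization run and returns its score (higher is better).
    """
    sweep_scores = {}
    for beta, lr in itertools.product(betas, lrs):
        scores = [run_fn(beta, lr, seed + r) for r in range(n_sweep_reps)]
        sweep_scores[(beta, lr)] = statistics.mean(scores)
    best_beta, best_lr = max(sweep_scores, key=sweep_scores.get)
    # Seeds disjoint from those used in the sweep, so the final scores are not
    # biased by the selection step.
    final_scores = [run_fn(best_beta, best_lr, seed + n_sweep_reps + r)
                    for r in range(n_final_reps)]
    return (best_beta, best_lr), sweep_scores, final_scores
```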

Where do I get a decomposition of my objective function?

For non-synthetic experiments, you must choose a decomposition of the objective function. This can be represented in different ways. Most fundamentally, one must define a graph in which nodes correspond to design variables (i.e., dimensions of $x$) and an edge connects two nodes if they directly interact to influence $f(x)$. For example, you expect there to be a pairwise component function $f_{i,j}(x_i,x_j)$, or some higher-order component function whose scope includes both $i$ and $j$ along with other nodes.
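As a small illustration of this definition, the sketch below builds the functional graph's adjacency matrix from a list of component-function scopes. The helper name and the example decomposition are hypothetical.

```python
import itertools

import numpy as np


def adjacency_from_scopes(n_vars, scopes):
    """Functional graph implied by the scopes of component functions.

    Each scope lists the design variables one component function depends on;
    every pair of variables sharing a scope becomes an edge.
    """
    adj = np.zeros((n_vars, n_vars), dtype=bool)
    for scope in scopes:
        for i, j in itertools.combinations(scope, 2):
            adj[i, j] = adj[j, i] = True
    return adj


# Example: f(x) = f_{0,1}(x_0, x_1) + f_{1,2,3}(x_1, x_2, x_3) + f_4(x_4)
adj = adjacency_from_scopes(5, [(0, 1), (1, 2, 3), (4,)])
```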

For proteins, we provide a method that takes a 3D structure predicted by AlphaFold3 and uses its contact map (with a contact cutoff of 4.5 Å) as the functional graph, interpreting the contact map as an adjacency matrix. This graph is converted into a junction tree, whose component functions are then fit on assay-labeled data. The predicted 3D structures we used are in src/problems/real/data/af3_structures.
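A minimal sketch of the contact-map step follows. It assumes atom coordinates have already been extracted from the predicted structure. Parsing is not shown; a library such as Biopython's `Bio.PDB.MMCIFParser` could supply them. The contact criterion is an assumption: two residues are in contact when any pair of their atoms is closer than 4.5 Å. The repository may use a different atom selection. The junction-tree conversion is not shown.

```python
import numpy as np


def contact_map(residue_coords, cutoff=4.5):
    """Residue contact map used as an adjacency matrix.

    residue_coords: list with one (n_atoms, 3) array of atom coordinates (Å)
    per residue. Two residues are in contact if any pair of their atoms lies
    closer than `cutoff`.
    """
    n = len(residue_coords)
    adj = np.zeros((n, n), dtype=bool)
    for i in range(n):
        for j in range(i + 1, n):
            # Pairwise atom-atom displacement vectors between residues i and j.
            diff = residue_coords[i][:, None, :] - residue_coords[j][None, :, :]
            if np.sqrt((diff ** 2).sum(axis=-1)).min() < cutoff:
                adj[i, j] = adj[j, i] = True
    return adj
```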

Such an approach will not work well for all proteins. It fails particularly when the decomposed model cannot fit the data accurately, which is likely if the protein has multiple conformations that mediate its function. In these cases, and for problems outside of proteins, some other procedure is needed to construct the functional graph. A domain expert might start from a fully connected graph over the design variables and remove edges between variables they believe are not directly coupled with respect to the objective function. Alternatively, one might try to learn the decomposition graph directly from data, via structure learning or by estimating epistasis. Note that, in general, our method can optimize sparser or more tree-like decompositions more efficiently than more densely connected ones.
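To make "tree-like" concrete, the sketch below converts a functional graph into a junction tree. It uses greedy min-degree triangulation followed by a maximum-weight spanning tree over clique intersections. This is a standard textbook construction, not necessarily the one used in the repository. The size of the largest clique minus one upper-bounds the graph's treewidth. That width is 1 for a chain and grows as the graph becomes more densely connected.

```python
import itertools


def junction_tree(adj):
    """Junction tree (forest) of an undirected graph.

    adj maps each node (an int) to the set of its neighbours; neighbour sets must
    be symmetric. The graph is triangulated by greedy min-degree elimination; the
    maximal elimination cliques are then joined by a maximum-weight spanning tree
    whose edge weights are separator sizes (Kruskal's algorithm).
    """
    g = {v: set(nbrs) for v, nbrs in adj.items()}
    elim_cliques = []
    while g:
        v = min(g, key=lambda u: (len(g[u]), u))  # lowest degree, ties by label
        nbrs = g.pop(v)
        elim_cliques.append(frozenset(nbrs | {v}))
        for a in nbrs:
            g[a].discard(v)
        for a, b in itertools.combinations(nbrs, 2):  # fill-in edges
            g[a].add(b)
            g[b].add(a)
    cliques = [c for c in elim_cliques if not any(c < d for d in elim_cliques)]

    # Candidate tree edges between overlapping cliques, heaviest separators first.
    pairs = sorted(((len(a & b), i, j)
                    for (i, a), (j, b) in itertools.combinations(enumerate(cliques), 2)
                    if a & b), reverse=True)
    parent = list(range(len(cliques)))

    def find(i):
        # Union-find root lookup with path halving.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    edges = []
    for _, i, j in pairs:
        ri, rj = find(i), find(j)
        if ri != rj:
            parent[ri] = rj
            edges.append((i, j))
    return cliques, edges
```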

For some classes of problems, the decomposition can be chosen, i.e., the topology is part of the design specification rather than already fixed by the objective. For instance, there is no a priori fixed arrangement for the lenses in a telescope. A telescope designer might therefore have an outer loop that optimizes the topology of the lenses and an inner loop that uses DADO to optimize the individual lens parameters given that topology. Here, the outer-loop optimizer would specify the decomposition, which stays fixed within the inner loop. The objective is decomposable because each lens only passes light to its direct optical neighbors, whatever the configuration; which lenses are neighbors, however, depends on the topology. Similar setups arise in other hardware design problems, such as circuit design. Bayesian optimization is another notable class of problems where the decomposition is chosen. This applies particularly to the literature that either learns a decomposition at each acquisition round or chooses random decompositions. An acquisition function defined on the resulting decomposed surrogate model must then be optimized to determine which designs to evaluate, and our method could be used to optimize it effectively.
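The sketch below makes the outer/inner split concrete for a chain of lenses. The outer loop enumerates lens orderings. Each ordering fixes a chain decomposition $f(x) = \sum_k f_k(x_k, x_{k+1})$ over discretized lens settings. The inner loop uses exact max-sum dynamic programming as a stand-in for DADO. That stand-in is only practical when the decomposition is a chain or tree with few states per variable. `pair_score` is a hypothetical model of how one lens feeds the next, and at least two lenses are assumed.

```python
import itertools

import numpy as np


def chain_max_sum(pair_tables):
    """Exactly maximise sum_k T_k[x_k, x_{k+1}] over a chain (max-sum DP).

    pair_tables: list of (K, K) arrays; T_k[a, b] scores variable k in state a
    next to variable k + 1 in state b. Returns (best value, best states).
    """
    msg = np.zeros(pair_tables[0].shape[0])  # best prefix score for each state
    backptrs = []
    for table in pair_tables:
        scores = msg[:, None] + table  # rows: previous state, cols: next state
        backptrs.append(scores.argmax(axis=0))
        msg = scores.max(axis=0)
    states = [int(msg.argmax())]
    for ptr in reversed(backptrs):
        states.append(int(ptr[states[-1]]))
    return float(msg.max()), states[::-1]


def optimize_topology(n_lenses, pair_score):
    """Outer loop over lens orderings; inner loop over lens settings.

    pair_score(a, b) is a hypothetical model returning the (K, K) table for
    lens a passing light to lens b. Each ordering fixes a chain decomposition.
    """
    best = (-np.inf, None, None)
    for order in itertools.permutations(range(n_lenses)):
        tables = [pair_score(a, b) for a, b in zip(order, order[1:])]
        value, settings = chain_max_sum(tables)
        if value > best[0]:
            best = (value, order, settings)
    return best
```

Exhaustive enumeration of orderings is only feasible for a handful of lenses. A real outer loop would use its own search strategy.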

Extensions

Some interesting directions are not covered substantively in our paper but are compatible with, or already implemented in, our code. They include:
- initializing the search model to some distribution (perhaps the support of your dataset)
- regularizing the search model toward a prior distribution
- reusing samples from previous EDA iterations via a replay buffer and/or some form of importance sampling
- learned value functions
- transformer-based predictive models
- transformer-based search models (factorized models are implemented but haven't been tested and tuned)
- ...
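The sketch below illustrates the replay-buffer direction. It reweights buffered samples drawn under earlier search distributions with self-normalized importance sampling. The shaping term exp(Q / beta) is multiplied by the ratio of current to sampling-time probabilities. It assumes a factorized categorical search model and that each sample's log-probability was recorded when it was drawn. It is not taken from the repository's implementation, and the function names are hypothetical.

```python
import numpy as np


def factorized_log_prob(logits, x):
    """Log-probability of integer designs x, shape (N, L), under a factorized
    categorical distribution with logits of shape (L, K)."""
    z = logits - logits.max(axis=-1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))  # (L, K)
    return log_p[np.arange(x.shape[1]), x].sum(axis=1)


def replay_weights(values, x, log_p_sampled, logits_now, beta):
    """Self-normalized weights for reusing buffered samples.

    values: objective values of the buffered designs x.
    log_p_sampled: log-probability of each design under the search distribution
    that generated it, recorded when it was sampled.
    The shaping term exp(values / beta) is multiplied by the importance ratio
    p_now(x) / p_sampled(x), computed in log space.
    """
    log_w = (np.asarray(values, dtype=float) / beta
             + factorized_log_prob(logits_now, x)
             - np.asarray(log_p_sampled, dtype=float))
    w = np.exp(log_w - log_w.max())  # shift for numerical stability
    return w / w.sum()
```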
