
Commit 460f8d6

Merge branch 'develop'
2 parents: e326bcb + c5f2628

File tree: 143 files changed, +3362 / -767 lines changed


.gitignore

Lines changed: 0 additions & 6 deletions
@@ -30,12 +30,6 @@ jobs.txt
 *autonet.egg-info
 *.simg
 
-# Metalearning data
-/metalearning_data/
-/metalearning_comparison_results/
-/meta_outputs/
-/metamodels/
-
 
 # Datasets
 /datasets/

README.md

Lines changed: 70 additions & 3 deletions
@@ -22,10 +22,11 @@ $ git checkout develop
 
 Install pytorch:
 https://pytorch.org/
-
-Install autonet
+
+Install Auto-PyTorch:
 
 ```sh
+$ cat requirements.txt | xargs -n 1 -L 1 pip install
 $ python setup.py install
 ```
 
@@ -46,7 +47,12 @@ X_train, X_test, y_train, y_test = \
     sklearn.model_selection.train_test_split(X, y, random_state=1)
 
 # running Auto-PyTorch
-autoPyTorch = AutoNetClassification(log_level='info', max_runtime=300, min_budget=30, max_budget=90)
+autoPyTorch = AutoNetClassification("tiny_cs",  # config preset
+                                    log_level='info',
+                                    max_runtime=300,
+                                    min_budget=30,
+                                    max_budget=90)
+
 autoPyTorch.fit(X_train, y_train, validation_split=0.3)
 y_pred = autoPyTorch.predict(X_test)
 
@@ -57,6 +63,67 @@ More examples with datasets:
 
 ```sh
 $ cd examples/
+
+```
+
+## Configuration
+
+How to configure Auto-PyTorch for your needs:
+
+```py
+
+# Print all possible configuration options.
+AutoNetClassification().print_help()
+
+# You can use the constructor to configure Auto-PyTorch.
+autoPyTorch = AutoNetClassification(log_level='info', max_runtime=300, min_budget=30, max_budget=90)
+
+# You can overwrite this configuration in each fit call.
+autoPyTorch.fit(X_train, y_train, log_level='debug', max_runtime=900, min_budget=50, max_budget=150)
+
+# You can use presets to configure the config space.
+# Available presets: full_cs, medium_cs (default), tiny_cs.
+# These are defined in autoPyTorch/core/presets.
+# tiny_cs is recommended if you want fast results with few resources.
+# full_cs is recommended if you have many resources and a very high search budget.
+autoPyTorch = AutoNetClassification("full_cs")
+
+# Enable or disable components using the Auto-PyTorch config:
+autoPyTorch = AutoNetClassification(networks=["resnet", "shapedresnet", "mlpnet", "shapedmlpnet"])
+
+# You can take a look at the search space.
+# Each hyperparameter belongs to a node in Auto-PyTorch's ML Pipeline.
+# The names of the hyperparameters are prefixed with the name of the node: NodeName:hyperparameter_name.
+# If a hyperparameter belongs to a component: NodeName:component_name:hyperparameter_name.
+autoPyTorch.get_hyperparameter_search_space()
+
+# You can configure the search space of every hyperparameter of every component:
+from autoPyTorch import HyperparameterSearchSpaceUpdates
+search_space_updates = HyperparameterSearchSpaceUpdates()
+
+search_space_updates.append(node_name="NetworkSelector",
+                            hyperparameter="shapedresnet:activation",
+                            value_range=["relu", "sigmoid"])
+search_space_updates.append(node_name="NetworkSelector",
+                            hyperparameter="shapedresnet:blocks_per_group",
+                            value_range=[2,5],
+                            log=False)
+autoPyTorch = AutoNetClassification(hyperparameter_search_space_updates=search_space_updates)
+```
+
+Enable ensemble building:
+
+```py
+from autoPyTorch import AutoNetEnsemble
+autoPyTorchEnsemble = AutoNetEnsemble(AutoNetClassification, "tiny_cs", max_runtime=300, min_budget=30, max_budget=90)
+
+```
+
+Disable pynisher if you experience issues when using cuda:
+
+```py
+autoPyTorch = AutoNetClassification("tiny_cs", log_level='info', max_runtime=300, min_budget=30, max_budget=90, cuda=True, use_pynisher=False)
+
 ```
 
 ## License
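The quickstart added in the README hunk above assumes `X` and `y` are already loaded. A minimal end-to-end sketch, assuming a small sklearn toy dataset (the dataset choice and the final print line are illustrative additions, not part of this commit):

```py
# Hedged, self-contained version of the README quickstart (illustrative values).
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection
from autoPyTorch import AutoNetClassification

# Any small classification dataset works; digits is just an example here.
X, y = sklearn.datasets.load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = \
    sklearn.model_selection.train_test_split(X, y, random_state=1)

# "tiny_cs" keeps the configuration space small so the run finishes quickly.
autoPyTorch = AutoNetClassification("tiny_cs",
                                    log_level='info',
                                    max_runtime=300,
                                    min_budget=30,
                                    max_budget=90)
autoPyTorch.fit(X_train, y_train, validation_split=0.3)
y_pred = autoPyTorch.predict(X_test)
print("Accuracy:", sklearn.metrics.accuracy_score(y_test, y_pred))
```

The preset and budget values mirror the snippet added in the diff; the larger presets mainly enlarge the search space rather than change the API.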

autoPyTorch/components/ensembles/__init__.py

Whitespace-only changes.

autoPyTorch/components/ensembles/abstract_ensemble.py

Lines changed: 0 additions & 2 deletions
@@ -27,7 +27,6 @@ def fit(self, base_models_predictions, true_targets, model_identifiers):
         self
 
         """
-        pass
 
     @abstractmethod
     def predict(self, base_models_predictions):
@@ -42,7 +41,6 @@ def predict(self, base_models_predictions):
         -------
         array : [n_data_points]
         """
-        self
 
     @abstractmethod
     def get_models_with_weights(self, models):
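The two deletions drop no-op bodies: a method decorated with `@abstractmethod` whose body is only a docstring is already valid Python, and subclasses must still override it before the class can be instantiated. A minimal sketch of that behaviour (the class names below are illustrative, not taken from this commit):

```py
from abc import ABC, abstractmethod

class ExampleAbstractEnsemble(ABC):
    @abstractmethod
    def predict(self, base_models_predictions):
        """Combine base model predictions.

        The docstring alone is a valid function body, so no `pass` is needed.
        """

class AveragingEnsemble(ExampleAbstractEnsemble):
    def predict(self, base_models_predictions):
        # Concrete override; without it, instantiation raises TypeError.
        return sum(base_models_predictions) / len(base_models_predictions)
```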

autoPyTorch/components/lr_scheduler/lr_schedulers.py

Lines changed: 33 additions & 29 deletions
@@ -4,14 +4,14 @@
 This file contains the different learning rate schedulers of AutoNet.
 """
 
+from autoPyTorch.utils.config_space_hyperparameter import add_hyperparameter, get_hyperparameter
+
 import torch
 import torch.optim.lr_scheduler as lr_scheduler
 
 import ConfigSpace as CS
 import ConfigSpace.hyperparameters as CSH
 
-from autoPyTorch.components.lr_scheduler.lr_schedulers_config import CSConfig
-
 __author__ = "Max Dippel, Michael Burkart and Matthias Urban"
 __version__ = "0.0.1"
 __license__ = "BSD"
@@ -29,8 +29,7 @@ def _get_scheduler(self, optimizer, config):
         raise ValueError('Override the method _get_scheduler and do not call the base class implementation')
 
     @staticmethod
-    def get_config_space(*args, **kwargs):
-        # currently no use but might come in handy in the future
+    def get_config_space():
        return CS.ConfigurationSpace()
 
 class SchedulerNone(AutoNetLearningRateSchedulerBase):
@@ -44,12 +43,13 @@ def _get_scheduler(self, optimizer, config):
         return lr_scheduler.StepLR(optimizer=optimizer, step_size=config['step_size'], gamma=config['gamma'], last_epoch=-1)
 
     @staticmethod
-    def get_config_space(*args, **kwargs):
+    def get_config_space(
+        step_size=(1, 10),
+        gamma=(0.001, 0.9)
+    ):
         cs = CS.ConfigurationSpace()
-        config = CSConfig['step_lr']
-        cs.add_hyperparameter(CSH.UniformIntegerHyperparameter('step_size', lower=config['step_size'][0], upper=config['step_size'][1]))
-        cs.add_hyperparameter(CSH.UniformFloatHyperparameter('gamma', lower=config['gamma'][0], upper=config['gamma'][1]))
-        cs.add_configuration_space(prefix='', delimiter='', configuration_space=AutoNetLearningRateSchedulerBase.get_config_space(*args, **kwargs))
+        add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'step_size', step_size)
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'gamma', gamma)
         return cs
 
 class SchedulerExponentialLR(AutoNetLearningRateSchedulerBase):
@@ -58,11 +58,11 @@ def _get_scheduler(self, optimizer, config):
         return lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=config['gamma'], last_epoch=-1)
 
     @staticmethod
-    def get_config_space(*args, **kwargs):
+    def get_config_space(
+        gamma=(0.8, 0.9999)
+    ):
         cs = CS.ConfigurationSpace()
-        config = CSConfig['exponential_lr']
-        cs.add_hyperparameter(CSH.UniformFloatHyperparameter('gamma', lower=config['gamma'][0], upper=config['gamma'][1]))
-        cs.add_configuration_space(prefix='', delimiter='', configuration_space=AutoNetLearningRateSchedulerBase.get_config_space(*args, **kwargs))
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'gamma', gamma)
         return cs
 
 class SchedulerReduceLROnPlateau(AutoNetLearningRateSchedulerBase):
@@ -71,12 +71,13 @@ def _get_scheduler(self, optimizer, config):
         return lr_scheduler.ReduceLROnPlateau(optimizer=optimizer)
 
     @staticmethod
-    def get_config_space(*args, **kwargs):
+    def get_config_space(
+        factor=(0.05, 0.5),
+        patience=(3, 10)
+    ):
         cs = CS.ConfigurationSpace()
-        config = CSConfig['reduce_on_plateau']
-        cs.add_hyperparameter(CSH.UniformFloatHyperparameter('factor', lower=config['factor'][0], upper=config['factor'][1]))
-        cs.add_hyperparameter(CSH.UniformIntegerHyperparameter('patience', lower=config['patience'][0], upper=config['patience'][1]))
-        cs.add_configuration_space(prefix='', delimiter='', configuration_space=AutoNetLearningRateSchedulerBase.get_config_space(*args, **kwargs))
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'factor', factor)
+        add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'patience', patience)
         return cs
 
 class SchedulerCyclicLR(AutoNetLearningRateSchedulerBase):
@@ -96,13 +97,15 @@ def l(epoch):
         return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=l, last_epoch=-1)
 
     @staticmethod
-    def get_config_space(*args, **kwargs):
+    def get_config_space(
+        max_factor=(1.0, 2),
+        min_factor=(0.001, 1.0),
+        cycle_length=(3, 10)
+    ):
         cs = CS.ConfigurationSpace()
-        config = CSConfig['cyclic_lr']
-        cs.add_hyperparameter(CSH.UniformFloatHyperparameter('max_factor', lower=config['max_factor'][0], upper=config['max_factor'][1]))
-        cs.add_hyperparameter(CSH.UniformFloatHyperparameter('min_factor', lower=config['min_factor'][0], upper=config['min_factor'][1]))
-        cs.add_hyperparameter(CSH.UniformIntegerHyperparameter('cycle_length', lower=config['cycle_length'][0], upper=config['cycle_length'][1]))
-        cs.add_configuration_space(prefix='', delimiter='', configuration_space=AutoNetLearningRateSchedulerBase.get_config_space(*args, **kwargs))
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'max_factor', max_factor)
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'min_factor', min_factor)
+        add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'cycle_length', cycle_length)
         return cs
 
 class SchedulerCosineAnnealingWithRestartsLR(AutoNetLearningRateSchedulerBase):
@@ -114,12 +117,13 @@ def _get_scheduler(self, optimizer, config):
         return scheduler
 
     @staticmethod
-    def get_config_space(*args, **kwargs):
+    def get_config_space(
+        T_max=(1, 20),
+        T_mult=(1.0, 2.0)
+    ):
         cs = CS.ConfigurationSpace()
-        config = CSConfig['cosine_annealing_lr']
-        cs.add_hyperparameter(CSH.UniformIntegerHyperparameter('T_max', lower=config['T_max'][0], upper=config['T_max'][1]))
-        cs.add_hyperparameter(CSH.UniformFloatHyperparameter('T_mult', lower=config['T_mult'][0], upper=config['T_mult'][1]))
-        cs.add_configuration_space(prefix='', delimiter='', configuration_space=AutoNetLearningRateSchedulerBase.get_config_space(*args, **kwargs))
+        add_hyperparameter(cs, CSH.UniformIntegerHyperparameter, 'T_max', T_max)
+        add_hyperparameter(cs, CSH.UniformFloatHyperparameter, 'T_mult', T_mult)
         return cs
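Across all schedulers the refactor moves the hard-coded `CSConfig` ranges into keyword defaults and funnels construction through `get_hyperparameter`/`add_hyperparameter` from `autoPyTorch.utils.config_space_hyperparameter`. Those helpers are not part of this diff; the sketch below is only an assumption about the convention the call sites imply (a plain `(lower, upper)` range, a `((lower, upper), log)` pair, or a tuple of categorical choices):

```py
# Hedged sketch of the helper convention implied by the call sites above;
# the real autoPyTorch.utils.config_space_hyperparameter may differ.
import ConfigSpace.hyperparameters as CSH

def get_hyperparameter(hyper_type, name, value_range):
    # ((lower, upper), log) -> ranged hyperparameter with log scaling
    if len(value_range) == 2 and isinstance(value_range[0], (tuple, list)):
        (lower, upper), log = value_range
        return hyper_type(name, lower=lower, upper=upper, log=log)
    # Tuple of choices, e.g. (True, False) or ('relu', 'tanh')
    if hyper_type is CSH.CategoricalHyperparameter:
        return hyper_type(name, choices=list(value_range))
    # Plain (lower, upper) range
    lower, upper = value_range
    return hyper_type(name, lower=lower, upper=upper)

def add_hyperparameter(cs, hyper_type, name, value_range):
    # Convenience wrapper: build the hyperparameter and register it in cs.
    hp = get_hyperparameter(hyper_type, name, value_range)
    cs.add_hyperparameter(hp)
    return hp
```

With such a convention, a call like `SchedulerStepLR.get_config_space(step_size=(1, 5))` would narrow a single range without touching the other defaults, which appears to be the point of the new keyword-argument signatures.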

autoPyTorch/components/metrics/balanced_accuracy.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 
 def balanced_accuracy(y_pred, y_true):
-    return _balanced_accuracy(np.argmax(y_pred, axis=1), np.argmax(y_true, axis=1)) * 100
+    return _balanced_accuracy(np.argmax(y_true, axis=1), np.argmax(y_pred, axis=1)) * 100
 
 
 def _balanced_accuracy(solution, prediction):
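The one-line fix passes the ground truth as the first argument (`solution`) and the prediction as the second, matching the `_balanced_accuracy(solution, prediction)` signature and the usual `(y_true, y_pred)` convention. Order matters because the metric averages recall over the classes that appear in the ground truth; a small illustration using sklearn's equivalent metric (an example for intuition, not code from this commit):

```py
# Why argument order matters for balanced accuracy (illustrative values).
import numpy as np
from sklearn.metrics import balanced_accuracy_score

y_true = np.array([0, 0, 0, 0, 1])   # imbalanced ground truth
y_pred = np.array([0, 0, 0, 0, 0])   # classifier that always predicts 0

# Correct order (y_true, y_pred): mean recall over the two true classes = 0.5
print(balanced_accuracy_score(y_true, y_pred))

# Swapped order treats the predictions as ground truth: only class 0 is
# "present", giving 0.8 (plus a warning) -- a misleadingly high score.
print(balanced_accuracy_score(y_pred, y_true))
```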

autoPyTorch/components/networks/feature/embedding.py

Lines changed: 12 additions & 4 deletions
@@ -11,6 +11,7 @@
 import torch.nn as nn
 import numpy as np
 
+from autoPyTorch.utils.config_space_hyperparameter import get_hyperparameter, add_hyperparameter
 from autoPyTorch.components.preprocessing.preprocessor_base import PreprocessorBase
 
 __author__ = "Max Dippel, Michael Burkart and Matthias Urban"
@@ -76,16 +77,23 @@ def _create_ee_layers(self, in_features):
         return layers
 
     @staticmethod
-    def get_config_space(categorical_features=None):
+    def get_config_space(
+        categorical_features=None,
+        min_unique_values_for_embedding=((3, 300), True),
+        dimension_reduction=(0, 1),
+        **kwargs
+    ):
         # dimension of entity embedding layer is a hyperparameter
         if categorical_features is None or not any(categorical_features):
             return CS.ConfigurationSpace()
         cs = CS.ConfigurationSpace()
-        min_hp = CSH.UniformIntegerHyperparameter("min_unique_values_for_embedding", lower=3, upper=300, default_value=3, log=True)
+        min_hp = get_hyperparameter(CSH.UniformIntegerHyperparameter, "min_unique_values_for_embedding", min_unique_values_for_embedding)
         cs.add_hyperparameter(min_hp)
         for i in range(len([x for x in categorical_features if x])):
-            ee_dimensions = CSH.UniformFloatHyperparameter("dimension_reduction_" + str(i), lower=0, upper=1, default_value=1, log=False)
-            cs.add_hyperparameter(ee_dimensions)
+            ee_dimensions_hp = get_hyperparameter(CSH.UniformFloatHyperparameter, "dimension_reduction_" + str(i),
+                                                  kwargs.pop("dimension_reduction_" + str(i), dimension_reduction))
+            cs.add_hyperparameter(ee_dimensions_hp)
+        assert len(kwargs) == 0, "Invalid hyperparameter updates for learned embedding: %s" % str(kwargs)
         return cs
 
 class NoEmbedding(nn.Module):
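The new signature also threads per-feature overrides through `**kwargs`: each `dimension_reduction_<i>` key is popped individually, and the trailing `assert` rejects any leftover keys so that misspelled search-space updates fail loudly instead of being silently ignored. A hedged usage sketch of that pattern (the class name `LearnedEntityEmbedding` and the values are assumptions for illustration):

```py
# Illustrative call; assumes LearnedEntityEmbedding is the class owning this
# get_config_space and that two of the three input features are categorical.
cs = LearnedEntityEmbedding.get_config_space(
    categorical_features=[True, False, True],
    min_unique_values_for_embedding=((3, 300), True),  # ((lower, upper), log)
    dimension_reduction=(0, 1),        # default range for every feature
    dimension_reduction_0=(0.0, 0.5),  # tighter range just for feature 0
)

# A misspelled key would survive kwargs.pop() and trigger the assert:
# LearnedEntityEmbedding.get_config_space(categorical_features=[True],
#     dimension_reducton_0=(0.0, 0.5))  # AssertionError
```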

autoPyTorch/components/networks/feature/mlpnet.py

Lines changed: 33 additions & 29 deletions
@@ -11,6 +11,7 @@
 import torch.nn as nn
 
 from autoPyTorch.components.networks.base_net import BaseFeatureNet
+from autoPyTorch.utils.config_space_hyperparameter import add_hyperparameter, get_hyperparameter
 
 __author__ = "Max Dippel, Michael Burkart and Matthias Urban"
 __version__ = "0.0.1"
@@ -46,35 +47,38 @@ def _add_layer(self, layers, in_features, out_features, layer_id):
         layers.append(nn.Dropout(self.config["dropout_%d" % layer_id]))
 
     @staticmethod
-    def get_config_space(user_updates=None):
+    def get_config_space(
+        num_layers=((1, 15), False),
+        num_units=((10, 1024), True),
+        activation=('sigmoid', 'tanh', 'relu'),
+        dropout=(0.0, 0.8),
+        use_dropout=(True, False),
+        **kwargs
+    ):
         cs = CS.ConfigurationSpace()
-        range_num_layers=(1, 15)
-        range_num_units=(10, 1024)
-        possible_activations=('sigmoid', 'tanh', 'relu')
-        range_dropout=(0.0, 0.8)
-
-        if user_updates is not None and 'num_layers' in user_updates:
-            range_num_layers = user_updates['num_layers']
-
-        num_layers = CSH.UniformIntegerHyperparameter('num_layers', lower=range_num_layers[0], upper=range_num_layers[1])
-        cs.add_hyperparameter(num_layers)
-        use_dropout = cs.add_hyperparameter(CS.CategoricalHyperparameter("use_dropout", [True, False], default_value=True))
-
-        for i in range(1, range_num_layers[1] + 1):
-            n_units = CSH.UniformIntegerHyperparameter("num_units_%d" % i,
-                lower=range_num_units[0], upper=range_num_units[1], log=True)
-            cs.add_hyperparameter(n_units)
-            dropout = CSH.UniformFloatHyperparameter("dropout_%d" % i, lower=range_dropout[0], upper=range_dropout[1])
-            cs.add_hyperparameter(dropout)
-            dropout_condition_1 = CS.EqualsCondition(dropout, use_dropout, True)
-
-            if i > range_num_layers[0]:
-                cs.add_condition(CS.GreaterThanCondition(n_units, num_layers, i - 1))
-
-                dropout_condition_2 = CS.GreaterThanCondition(dropout, num_layers, i - 1)
-                cs.add_condition(CS.AndConjunction(dropout_condition_1, dropout_condition_2))
-            else:
-                cs.add_condition(dropout_condition_1)
+
+        num_layers_hp = get_hyperparameter(CSH.UniformIntegerHyperparameter, 'num_layers', num_layers)
+        cs.add_hyperparameter(num_layers_hp)
+        use_dropout_hp = add_hyperparameter(cs, CS.CategoricalHyperparameter, "use_dropout", use_dropout)
+
+        for i in range(1, num_layers[0][1] + 1):
+            n_units_hp = get_hyperparameter(CSH.UniformIntegerHyperparameter, "num_units_%d" % i, kwargs.pop("num_units_%d" % i, num_units))
+            cs.add_hyperparameter(n_units_hp)
+
+            if i > num_layers[0][0]:
+                cs.add_condition(CS.GreaterThanCondition(n_units_hp, num_layers_hp, i - 1))
+
+            if True in use_dropout:
+                dropout_hp = get_hyperparameter(CSH.UniformFloatHyperparameter, "dropout_%d" % i, kwargs.pop("dropout_%d" % i, dropout))
+                cs.add_hyperparameter(dropout_hp)
+                dropout_condition_1 = CS.EqualsCondition(dropout_hp, use_dropout_hp, True)
+
+                if i > num_layers[0][0]:
+                    dropout_condition_2 = CS.GreaterThanCondition(dropout_hp, num_layers_hp, i - 1)
+                    cs.add_condition(CS.AndConjunction(dropout_condition_1, dropout_condition_2))
+                else:
+                    cs.add_condition(dropout_condition_1)
 
-        cs.add_hyperparameter(CSH.CategoricalHyperparameter('activation', possible_activations))
+        add_hyperparameter(cs, CSH.CategoricalHyperparameter,'activation', activation)
+        assert len(kwargs) == 0, "Invalid hyperparameter updates for mlpnet: %s" % str(kwargs)
         return(cs)
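The rewritten `get_config_space` preserves the conditional structure of the search space: `num_units_i` is only active once `num_layers` can reach layer `i`, and each `dropout_i` is additionally gated on `use_dropout`. A small self-contained ConfigSpace sketch of the same pattern (names and ranges are illustrative, not the mlpnet defaults):

```py
# Minimal illustration of the conditional-hyperparameter pattern used above.
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH

cs = CS.ConfigurationSpace()
num_layers = CSH.UniformIntegerHyperparameter("num_layers", lower=1, upper=3)
use_dropout = CSH.CategoricalHyperparameter("use_dropout", [True, False])
cs.add_hyperparameters([num_layers, use_dropout])

for i in range(1, 4):
    n_units = CSH.UniformIntegerHyperparameter("num_units_%d" % i, lower=10, upper=1024, log=True)
    dropout = CSH.UniformFloatHyperparameter("dropout_%d" % i, lower=0.0, upper=0.8)
    cs.add_hyperparameters([n_units, dropout])

    # dropout_i is only meaningful when dropout is enabled at all ...
    cond_use = CS.EqualsCondition(dropout, use_dropout, True)
    if i > 1:
        # ... and, like num_units_i, only when the network has at least i layers.
        cs.add_condition(CS.GreaterThanCondition(n_units, num_layers, i - 1))
        cs.add_condition(CS.AndConjunction(cond_use, CS.GreaterThanCondition(dropout, num_layers, i - 1)))
    else:
        cs.add_condition(cond_use)

# Inactive hyperparameters are simply omitted from sampled configurations.
print(cs.sample_configuration())
```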
