Skip to content

Commit 7b92f55

Browse files
committed
pass dataset instead of dataset info to autonet.get_hyperparameter_search_space()
1 parent 472618e commit 7b92f55

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ autoPyTorch = AutoNetClassification(networks=["resnet", "shapedresnet", "mlpnet"
9595
# Each hyperparameter belongs to a node in Auto-PyTorch's ML Pipeline.
9696
# The names of the hyperparameters are prefixed with the name of the node: NodeName:hyperparameter_name.
9797
# If a hyperparameter belongs to a component: NodeName:component_name:hyperparameter_name.
98-
autoPyTorch.get_hyperparameter_search_space()
98+
# Call with the same arguments as fit.
99+
autoPyTorch.get_hyperparameter_search_space(X_train, y_train, validation_split=0.3)
99100

100101
# You can configure the search space of every hyperparameter of every component:
101102
from autoPyTorch import HyperparameterSearchSpaceUpdates

autoPyTorch/components/training/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def final_eval(self, opt_metric_name, logs, train_loader, valid_loader, minimize
8888
for i, metric in enumerate(self.metrics):
8989
if valid_metric_results:
9090
final_log['val_' + metric.__name__] = valid_metric_results[i]
91-
if self.eval_additional_logs_on_snapshot:
91+
if self.eval_additional_logs_on_snapshot and not refit:
9292
for additional_log in self.log_functions:
9393
final_log[additional_log.__name__] = additional_log(self.model, None)
9494
return final_log

autoPyTorch/core/api.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from autoPyTorch.pipeline.nodes.cross_validation import CrossValidation
1616
from autoPyTorch.pipeline.nodes.metric_selector import MetricSelector
1717
from autoPyTorch.pipeline.nodes.optimization_algorithm import OptimizationAlgorithm
18+
from autoPyTorch.pipeline.nodes.create_dataset_info import CreateDatasetInfo
1819

1920

2021
from autoPyTorch.utils.config.config_file_parser import ConfigFileParser
@@ -68,13 +69,28 @@ def get_current_autonet_config(self):
6869
return self.autonet_config
6970
return self.pipeline.get_pipeline_config(**self.base_config)
7071

71-
def get_hyperparameter_search_space(self, dataset_info=None):
72-
"""Return the hyperparameter search space of AutoNet
72+
def get_hyperparameter_search_space(self, X_train=None, Y_train=None, X_valid=None, Y_valid=None, **autonet_config):
73+
"""Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset!
74+
75+
Keyword Arguments:
76+
X_train {array} -- Training data.
77+
Y_train {array} -- Targets of training data.
78+
X_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
79+
Y_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None})
7380
7481
Returns:
75-
ConfigurationSpace -- The ConfigurationSpace that should be optimized
82+
ConfigurationSpace -- The configuration space that should be optimized.
7683
"""
7784

85+
dataset_info = None
86+
if X_train is not None and Y_train is not None:
87+
dataset_info_node = self.pipeline[CreateDatasetInfo.get_name()]
88+
dataset_info = dataset_info_node.fit(pipeline_config=dict(self.base_config, **autonet_config),
89+
X_train=X_train,
90+
Y_train=Y_train,
91+
X_valid=X_valid,
92+
Y_valid=Y_valid)["dataset_info"]
93+
7894
return self.pipeline.get_hyperparameter_search_space(dataset_info=dataset_info, **self.get_current_autonet_config())
7995

8096
@classmethod

0 commit comments

Comments
 (0)