|
15 | 15 | from autoPyTorch.pipeline.nodes.cross_validation import CrossValidation |
16 | 16 | from autoPyTorch.pipeline.nodes.metric_selector import MetricSelector |
17 | 17 | from autoPyTorch.pipeline.nodes.optimization_algorithm import OptimizationAlgorithm |
| 18 | +from autoPyTorch.pipeline.nodes.create_dataset_info import CreateDatasetInfo |
18 | 19 |
|
19 | 20 |
|
20 | 21 | from autoPyTorch.utils.config.config_file_parser import ConfigFileParser |
@@ -68,13 +69,28 @@ def get_current_autonet_config(self): |
68 | 69 | return self.autonet_config |
69 | 70 | return self.pipeline.get_pipeline_config(**self.base_config) |
70 | 71 |
|
71 | | - def get_hyperparameter_search_space(self, dataset_info=None): |
72 | | - """Return the hyperparameter search space of AutoNet |
| 72 | + def get_hyperparameter_search_space(self, X_train=None, Y_train=None, X_valid=None, Y_valid=None, **autonet_config): |
| 73 | + """Return hyperparameter search space of Auto-PyTorch. Does depend on the dataset! |
| 74 | + |
| 75 | + Keyword Arguments: |
| 76 | + X_train {array} -- Training data. |
| 77 | + Y_train {array} -- Targets of training data. |
| 78 | + X_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None}) |
| 79 | + Y_valid {array} -- Validation data. Will be ignored if cv_splits > 1. (default: {None}) |
73 | 80 | |
74 | 81 | Returns: |
75 | | - ConfigurationSpace -- The ConfigurationSpace that should be optimized |
| 82 | + ConfigurationSpace -- The configuration space that should be optimized. |
76 | 83 | """ |
77 | 84 |
|
| 85 | + dataset_info = None |
| 86 | + if X_train is not None and Y_train is not None: |
| 87 | + dataset_info_node = self.pipeline[CreateDatasetInfo.get_name()] |
| 88 | + dataset_info = dataset_info_node.fit(pipeline_config=dict(self.base_config, **autonet_config), |
| 89 | + X_train=X_train, |
| 90 | + Y_train=Y_train, |
| 91 | + X_valid=X_valid, |
| 92 | + Y_valid=Y_valid)["dataset_info"] |
| 93 | + |
78 | 94 | return self.pipeline.get_hyperparameter_search_space(dataset_info=dataset_info, **self.get_current_autonet_config()) |
79 | 95 |
|
80 | 96 | @classmethod |
|
0 commit comments