diff --git a/pylearn2/cross_validation/__init__.py b/pylearn2/cross_validation/__init__.py index e5dbc60113..fcfbbbe451 100644 --- a/pylearn2/cross_validation/__init__.py +++ b/pylearn2/cross_validation/__init__.py @@ -34,6 +34,9 @@ class TrainCV(object): Training model. If list, training model for each fold. algorithm : TrainingAlgorithm Training algorithm. + algorithm_monitoring_datasets : list or None + Subsets of the dataset to be monitored. + Leave as None to monitor all subsets. save_path : str or None Output filename for trained models. Also used (with modification) for individual models if save_folds is True. @@ -49,6 +52,7 @@ class TrainCV(object): TrainCVExtension objects for the parent TrainCV object. """ def __init__(self, dataset_iterator, model, algorithm=None, + algorithm_monitoring_datasets=None, save_path=None, save_freq=0, extensions=None, allow_overwrite=True, save_folds=False, cv_extensions=None): self.dataset_iterator = dataset_iterator @@ -76,7 +80,13 @@ def __init__(self, dataset_iterator, model, algorithm=None, # setup monitoring datasets this_algorithm = deepcopy(algorithm) - this_algorithm._set_monitoring_dataset(datasets) + if algorithm_monitoring_datasets is None: + monitoring_datasets = datasets + else: + monitoring_datasets = dict( + (k, v) for (k, v) in datasets.iteritems() + if k in algorithm_monitoring_datasets) + this_algorithm._set_monitoring_dataset(monitoring_datasets) # extensions this_extensions = deepcopy(extensions)