diff --git a/csrank/choicefunction/baseline.py b/csrank/choicefunction/baseline.py
index 71f9a0b3..a3c69439 100644
--- a/csrank/choicefunction/baseline.py
+++ b/csrank/choicefunction/baseline.py
@@ -18,7 +18,7 @@ def __init__(self, **kwargs):
         """
 
     def fit(self, X, Y, **kwd):
-        pass
+        self._pre_fit()
 
     def _predict_scores_fixed(self, X, Y, **kwargs):
         return np.zeros_like(Y) + Y.mean()
diff --git a/csrank/choicefunction/cmpnet_choice.py b/csrank/choicefunction/cmpnet_choice.py
index c300dac6..4bde0cd7 100644
--- a/csrank/choicefunction/cmpnet_choice.py
+++ b/csrank/choicefunction/cmpnet_choice.py
@@ -154,6 +154,7 @@ def fit(
         **kwd :
             Keyword arguments for the fit function
         """
+        self._pre_fit()
         if tune_size > 0:
             X_train, X_val, Y_train, Y_val = train_test_split(
                 X, Y, test_size=tune_size, random_state=self.random_state
diff --git a/csrank/choicefunction/fate_choice.py b/csrank/choicefunction/fate_choice.py
index eb318d6f..2154468d 100644
--- a/csrank/choicefunction/fate_choice.py
+++ b/csrank/choicefunction/fate_choice.py
@@ -163,6 +163,7 @@ def fit(
             documentation of :func:`~csrank.core.FATENetwork.fit` for more
            information.
         """
+        self._pre_fit()
         if tune_size > 0:
             X_train, X_val, Y_train, Y_val = train_test_split(
                 X, Y, test_size=tune_size, random_state=self.random_state
diff --git a/csrank/choicefunction/fatelinear_choice.py b/csrank/choicefunction/fatelinear_choice.py
index 18e1110e..a36666aa 100644
--- a/csrank/choicefunction/fatelinear_choice.py
+++ b/csrank/choicefunction/fatelinear_choice.py
@@ -73,6 +73,7 @@ def fit(
         verbose=0,
         **kwd,
     ):
+        self._pre_fit()
         if tune_size > 0:
             X_train, X_val, Y_train, Y_val = train_test_split(
                 X, Y, test_size=tune_size, random_state=self.random_state
diff --git a/csrank/choicefunction/feta_choice.py b/csrank/choicefunction/feta_choice.py
index 7757aa19..e5a2a738 100644
--- a/csrank/choicefunction/feta_choice.py
+++ b/csrank/choicefunction/feta_choice.py
@@ -286,6 +286,7 @@ def fit(
         **kwd :
             Keyword arguments for the fit function
         """
+        self._pre_fit()
         if tune_size > 0:
             X_train, X_val, Y_train, Y_val = train_test_split(
                 X, Y, test_size=tune_size, random_state=self.random_state
diff --git a/csrank/choicefunction/fetalinear_choice.py b/csrank/choicefunction/fetalinear_choice.py
index faf0ad29..ec24015b 100644
--- a/csrank/choicefunction/fetalinear_choice.py
+++ b/csrank/choicefunction/fetalinear_choice.py
@@ -71,6 +71,7 @@ def fit(
         verbose=0,
         **kwd,
     ):
+        self._pre_fit()
         if tune_size > 0:
             X_train, X_val, Y_train, Y_val = train_test_split(
                 X, Y, test_size=tune_size, random_state=self.random_state
diff --git a/csrank/choicefunction/generalized_linear_model.py b/csrank/choicefunction/generalized_linear_model.py
index 2adec3e0..10ba751e 100644
--- a/csrank/choicefunction/generalized_linear_model.py
+++ b/csrank/choicefunction/generalized_linear_model.py
@@ -158,6 +158,10 @@ def construct_model(self, X, Y):
             BinaryCrossEntropyLikelihood("yl", p=self.p_, observed=self.Yt_)
         logger.info("Model construction completed")
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+
     def fit(
         self,
         X,
@@ -217,7 +221,7 @@ def fit(
         **kwargs :
             Keyword arguments for the fit function
         """
-        self.random_state_ = check_random_state(self.random_state)
+        self._pre_fit()
         if tune_size > 0:
             X_train, X_val, Y_train, Y_val = train_test_split(
                 X, Y, test_size=tune_size, random_state=self.random_state_
diff --git a/csrank/choicefunction/pairwise_choice.py b/csrank/choicefunction/pairwise_choice.py
index 2706dccc..36450e32 100644
--- a/csrank/choicefunction/pairwise_choice.py
+++ b/csrank/choicefunction/pairwise_choice.py
@@ -96,6 +96,7 @@ def fit(self, X, Y, tune_size=0.1, thin_thresholds=1, verbose=0, **kwd):
             Keyword arguments for the fit function
 
         """
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         if tune_size > 0:
             X_train, X_val, Y_train, Y_val = train_test_split(
diff --git a/csrank/choicefunction/ranknet_choice.py b/csrank/choicefunction/ranknet_choice.py
index fe1ec3e5..c687fb87 100644
--- a/csrank/choicefunction/ranknet_choice.py
+++ b/csrank/choicefunction/ranknet_choice.py
@@ -141,6 +141,7 @@ def fit(
         **kwd :
             Keyword arguments for the fit function
         """
+        self._pre_fit()
         if tune_size > 0:
             X_train, X_val, Y_train, Y_val = train_test_split(
                 X, Y, test_size=tune_size, random_state=self.random_state
diff --git a/csrank/core/cmpnet_core.py b/csrank/core/cmpnet_core.py
index b5d2c408..40c8e74c 100644
--- a/csrank/core/cmpnet_core.py
+++ b/csrank/core/cmpnet_core.py
@@ -95,8 +95,6 @@ def construct_model(self):
         model: keras :class:`Model`
             Neural network to learn the CmpNet utility score
         """
-        self._initialize_optimizer()
-        self._initialize_regularizer()
         x1x2 = concatenate([self.x1, self.x2])
         x2x1 = concatenate([self.x2, self.x1])
         logger.debug("Creating the model")
@@ -116,6 +114,12 @@ def construct_model(self):
         )
         return model
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+        self._initialize_optimizer()
+        self._initialize_regularizer()
+
     def fit(
         self, X, Y, epochs=10, callbacks=None, validation_split=0.1, verbose=0, **kwd
     ):
@@ -151,8 +155,7 @@ def fit(
         **kwd :
             Keyword arguments for the fit function
         """
-        self.random_state_ = check_random_state(self.random_state)
-        self._initialize_regularizer()
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
 
         x1, x2, y_double = self._convert_instances_(X, Y)
diff --git a/csrank/core/fate_linear.py b/csrank/core/fate_linear.py
index 5ca3fbc4..6b2f83dc 100644
--- a/csrank/core/fate_linear.py
+++ b/csrank/core/fate_linear.py
@@ -86,10 +86,14 @@ def step_decay(self, epoch):
             self.loss_
         )
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+
     def fit(
         self, X, Y, epochs=10, callbacks=None, validation_split=0.1, verbose=0, **kwd
     ):
-        self.random_state_ = check_random_state(self.random_state)
+        self._pre_fit()
         # Global Variables Initializer
         n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         self._construct_model_(self.n_objects_fit_)
diff --git a/csrank/core/fate_network.py b/csrank/core/fate_network.py
index c8965929..1de78df1 100644
--- a/csrank/core/fate_network.py
+++ b/csrank/core/fate_network.py
@@ -70,9 +70,6 @@ def __init__(
         self.kernel_regularizer = kernel_regularizer
         self.batch_size = batch_size
         self.optimizer = optimizer
-        self._initialize_optimizer()
-        self._initialize_regularizer()
-        self._construct_layers()
         self._store_kwargs(
             kwargs, {"optimizer__", "kernel_regularizer__", "hidden_dense_layer__"}
         )
@@ -148,6 +145,12 @@ def join_input_layers(self, input_layer, *layers, n_layers, n_objects):
 
         return scores
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self._initialize_optimizer()
+        self._initialize_regularizer()
+        self._construct_layers()
+
 
 class FATENetwork(FATENetworkCore):
     def __init__(self, n_hidden_set_layers=1, n_hidden_set_units=1, **kwargs):
@@ -168,12 +171,6 @@ def __init__(self, n_hidden_set_layers=1, n_hidden_set_units=1, **kwargs):
 
         self.n_hidden_set_layers = n_hidden_set_layers
         self.n_hidden_set_units = n_hidden_set_units
-        self.set_layer = None
-        self._create_set_layers(
-            activation=self.activation,
-            kernel_initializer=self.kernel_initializer,
-            kernel_regularizer=self.kernel_regularizer_,
-        )
 
     def _create_set_layers(self, **kwargs):
         """
@@ -186,11 +183,11 @@
             )
         )
         if self.n_hidden_set_layers >= 1:
-            self.set_layer = DeepSet(
+            self.set_layer_ = DeepSet(
                 units=self.n_hidden_set_units, layers=self.n_hidden_set_layers, **kwargs
             )
         else:
-            self.set_layer = None
+            self.set_layer_ = None
 
     @staticmethod
     def _bucket_frequencies(X, min_bucket_size=32):
@@ -308,6 +305,7 @@ def _fit(
         **kwargs :
             Keyword arguments for the fit function
         """
+        self._pre_fit()
         if optimizer is not None:
             self.optimizer = optimizer
         if isinstance(X, dict):
@@ -422,7 +420,7 @@ def construct_model(self, n_features, n_objects):
         """
         input_layer = Input(shape=(n_objects, n_features), name="input_node")
 
-        set_repr = self.set_layer(input_layer)
+        set_repr = self.set_layer_(input_layer)
         scores = self.join_input_layers(
             input_layer,
             set_repr,
@@ -438,6 +436,17 @@ def construct_model(self, n_features, n_objects):
         )
         return model
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+        self._initialize_optimizer()
+        self._initialize_regularizer()
+        self._create_set_layers(
+            activation=self.activation,
+            kernel_initializer=self.kernel_initializer,
+            kernel_regularizer=self.kernel_regularizer_,
+        )
+
     def fit(
         self,
         X,
@@ -494,10 +503,7 @@ def fit(
         **kwargs :
             Keyword arguments for the fit function
         """
-        self.random_state_ = check_random_state(self.random_state)
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
-        self._initialize_optimizer()
-        self._initialize_regularizer()
         self._fit(
             X=X,
             Y=Y,
@@ -596,9 +602,9 @@ def _get_context_representation(self, X, kwargs):
             shape=(n_objects, self.n_object_features_fit_), name="input_node"
         )
         if self.n_hidden_set_layers >= 1:
-            self.set_layer(input_layer_scorer)
-            fr = self.set_layer.cached_models[n_objects].predict(X, **kwargs)
-            del self.set_layer.cached_models[n_objects]
+            self.set_layer_(input_layer_scorer)
+            fr = self.set_layer_.cached_models[n_objects].predict(X, **kwargs)
+            del self.set_layer_.cached_models[n_objects]
             X_n = np.empty(
                 (fr.shape[0], n_objects, fr.shape[1] + self.n_object_features_fit_),
                 dtype="float",
diff --git a/csrank/core/feta_linear.py b/csrank/core/feta_linear.py
index ce90849a..1e5f38a8 100644
--- a/csrank/core/feta_linear.py
+++ b/csrank/core/feta_linear.py
@@ -136,6 +136,10 @@ def step_decay(self, epoch):
             self.loss
         )
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+
     def fit(
         self, X, Y, epochs=10, callbacks=None, validation_split=0.1, verbose=0, **kwd
     ):
@@ -155,7 +159,7 @@ def fit(
             predict the target variables and adjust its parameters by gradient
            descent `epochs` times.
         """
-        self.random_state_ = check_random_state(self.random_state)
+        self._pre_fit()
         # Global Variables Initializer
         n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         self._construct_model_(self.n_objects_fit_)
diff --git a/csrank/core/feta_network.py b/csrank/core/feta_network.py
index ae88809b..6c66eaf2 100644
--- a/csrank/core/feta_network.py
+++ b/csrank/core/feta_network.py
@@ -264,6 +264,12 @@ def create_input_lambda(i):
         )
         return model
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self._initialize_optimizer()
+        self._initialize_regularizer()
+        self.random_state_ = check_random_state(self.random_state)
+
     def fit(
         self, X, Y, epochs=10, callbacks=None, validation_split=0.1, verbose=0, **kwd
     ):
@@ -290,13 +296,11 @@ def fit(
         **kwd :
             Keyword arguments for the fit function
         """
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
 
-        self._initialize_optimizer()
-        self._initialize_regularizer()
         self._construct_layers()
         logger.debug("Enter fit function...")
 
-        self.random_state_ = check_random_state(self.random_state)
         X, Y = self.sub_sampling(X, Y)
         self.model_ = self.construct_model()
diff --git a/csrank/core/pairwise_svm.py b/csrank/core/pairwise_svm.py
index 675b410a..4e6eac94 100644
--- a/csrank/core/pairwise_svm.py
+++ b/csrank/core/pairwise_svm.py
@@ -54,6 +54,10 @@ def __init__(
         self.random_state = random_state
         self.fit_intercept = fit_intercept
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+
     def fit(self, X, Y, **kwargs):
         """
         Fit a generic preference learning model on a provided set of queries.
@@ -69,7 +73,7 @@ def fit(self, X, Y, **kwargs):
             Keyword arguments for the fit function
 
         """
-        self.random_state_ = check_random_state(self.random_state)
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         x_train, y_single = self._convert_instances_(X, Y)
         if self.use_logistic_regression:
diff --git a/csrank/core/ranknet_core.py b/csrank/core/ranknet_core.py
index d56c9779..0d0908fe 100644
--- a/csrank/core/ranknet_core.py
+++ b/csrank/core/ranknet_core.py
@@ -107,6 +107,12 @@ def construct_model(self):
     def _convert_instances_(self, X, Y):
         raise NotImplementedError
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+        self._initialize_optimizer()
+        self._initialize_regularizer()
+
     def fit(
         self, X, Y, epochs=10, callbacks=None, validation_split=0.1, verbose=0, **kwd
     ):
@@ -139,15 +145,13 @@ def fit(
         **kwd :
             Keyword arguments for the fit function
         """
-        self.random_state_ = check_random_state(self.random_state)
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
 
         X1, X2, Y_single = self._convert_instances_(X, Y)
         logger.debug("Instances created {}".format(X1.shape[0]))
 
         logger.debug("Creating the model")
-        self._initialize_optimizer()
-        self._initialize_regularizer()
         self._construct_layers()
 
         # Model with input as two objects and output as probability of x1>x2
diff --git a/csrank/discretechoice/baseline.py b/csrank/discretechoice/baseline.py
index 88fef4c4..6f8585b6 100644
--- a/csrank/discretechoice/baseline.py
+++ b/csrank/discretechoice/baseline.py
@@ -18,9 +18,13 @@ def __init__(self, random_state=None, **kwargs):
 
         self.random_state = random_state
 
-    def fit(self, X, Y, **kwd):
+    def _pre_fit(self):
+        super()._pre_fit()
         self.random_state_ = check_random_state(self.random_state)
 
+    def fit(self, X, Y, **kwd):
+        self._pre_fit()
+
     def _predict_scores_fixed(self, X, **kwargs):
         n_instances, n_objects, n_features = X.shape
         return self.random_state_.rand(n_instances, n_objects)
diff --git a/csrank/discretechoice/generalized_nested_logit.py b/csrank/discretechoice/generalized_nested_logit.py
index 29b08fc9..7ba6cdb4 100644
--- a/csrank/discretechoice/generalized_nested_logit.py
+++ b/csrank/discretechoice/generalized_nested_logit.py
@@ -328,6 +328,7 @@ def fit(
         **kwargs :
             Keyword arguments for the fit function of :meth:`pymc3.fit`or :meth:`pymc3.sample`
         """
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         if self.n_nests is None:
             # TODO this looks like a bug to me, but it was already done this way
diff --git a/csrank/discretechoice/mixed_logit_model.py b/csrank/discretechoice/mixed_logit_model.py
index 1c9ecb9c..a1dc687a 100644
--- a/csrank/discretechoice/mixed_logit_model.py
+++ b/csrank/discretechoice/mixed_logit_model.py
@@ -214,6 +214,7 @@ def fit(
         **kwargs :
             Keyword arguments for the fit function of :meth:`pymc3.fit`or :meth:`pymc3.sample`
         """
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         self.construct_model(X, Y)
         fit_pymc3_model(self, sampler, draws, tune, vi_params, **kwargs)
diff --git a/csrank/discretechoice/model_selector.py b/csrank/discretechoice/model_selector.py
index bee8a8b0..078d9719 100644
--- a/csrank/discretechoice/model_selector.py
+++ b/csrank/discretechoice/model_selector.py
@@ -58,6 +58,7 @@ def __init__(
         self.models = dict()
 
     def fit(self, X, Y):
+        self._pre_fit()
         model_args = dict()
         for param_key in self.parameter_keys:
             model_args[param_key] = self.uniform_prior
diff --git a/csrank/discretechoice/multinomial_logit_model.py b/csrank/discretechoice/multinomial_logit_model.py
index 99b03d74..7db65b46 100644
--- a/csrank/discretechoice/multinomial_logit_model.py
+++ b/csrank/discretechoice/multinomial_logit_model.py
@@ -213,6 +213,7 @@ def fit(
         **kwargs :
             Keyword arguments for the fit function of :meth:`pymc3.fit`or :meth:`pymc3.sample`
         """
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         self.construct_model(X, Y)
         fit_pymc3_model(self, sampler, draws, tune, vi_params, **kwargs)
diff --git a/csrank/discretechoice/nested_logit_model.py b/csrank/discretechoice/nested_logit_model.py
index abf92b27..5b4e1d46 100644
--- a/csrank/discretechoice/nested_logit_model.py
+++ b/csrank/discretechoice/nested_logit_model.py
@@ -385,6 +385,7 @@ def fit(
         **kwargs :
             Keyword arguments for the fit function of :meth:`pymc3.fit`or :meth:`pymc3.sample`
         """
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         if self.n_nests is None:
             self.n_nests = int(self.n_objects_fit_ / 2)
diff --git a/csrank/discretechoice/paired_combinatorial_logit.py b/csrank/discretechoice/paired_combinatorial_logit.py
index d533bbb1..54d90269 100644
--- a/csrank/discretechoice/paired_combinatorial_logit.py
+++ b/csrank/discretechoice/paired_combinatorial_logit.py
@@ -270,6 +270,10 @@ def construct_model(self, X, Y):
             )
         logger.info("Model construction completed")
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+
     def fit(
         self,
         X,
@@ -320,7 +324,7 @@ def fit(
         **kwargs :
             Keyword arguments for the fit function of :meth:`pymc3.fit`or :meth:`pymc3.sample`
         """
-        self.random_state_ = check_random_state(self.random_state)
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         self.nests_indices = np.array(
             list(combinations(np.arange(self.n_objects_fit_), 2))
diff --git a/csrank/discretechoice/pairwise_discrete_choice.py b/csrank/discretechoice/pairwise_discrete_choice.py
index 6ee1dec9..b50d678f 100644
--- a/csrank/discretechoice/pairwise_discrete_choice.py
+++ b/csrank/discretechoice/pairwise_discrete_choice.py
@@ -74,5 +74,6 @@ def _convert_instances_(self, X, Y):
         return x_train, y_single
 
     def fit(self, X, Y, **kwd):
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         super().fit(X, Y, **kwd)
diff --git a/csrank/dyadranking/fate_dyad_ranker.py b/csrank/dyadranking/fate_dyad_ranker.py
index e0356d14..37c10fcd 100644
--- a/csrank/dyadranking/fate_dyad_ranker.py
+++ b/csrank/dyadranking/fate_dyad_ranker.py
@@ -5,7 +5,7 @@ class FATEDyadRanker(FATENetwork, DyadRanker):
 
     def fit(self, Xo, Xc, Y, **kwargs):
-        pass
+        self._pre_fit()
 
     def predict_scores(self, Xo, Xc, **kwargs):
         return self.model_.predict([Xo, Xc], **kwargs)
 
diff --git a/csrank/learner.py b/csrank/learner.py
index c5054a4e..a4140258 100644
--- a/csrank/learner.py
+++ b/csrank/learner.py
@@ -75,6 +75,21 @@ def fit(self, X, Y, **kwargs):
         """
         raise NotImplementedError
 
+    def _pre_fit(self):
+        """Perform stateful initialization before fitting.
+
+        This function is for initialization that does not depend on the data,
+        but still requires some processing and therefore should not happen in
+        __init__. Examples include initialization of optimizers, construction
+        of NeuralNetwork layers (if it can be done without knowledge of the
+        data) etc.
+
+        You should always call this function before fit, even if you do not
+        override it. If you override it, you should call the super method first
+        so that general initializations can be inherited.
+        """
+        pass
+
     @abstractmethod
     def _predict_scores_fixed(self, X, **kwargs):
         """
diff --git a/csrank/objectranking/baseline.py b/csrank/objectranking/baseline.py
index 12ed77eb..acb5a6b1 100644
--- a/csrank/objectranking/baseline.py
+++ b/csrank/objectranking/baseline.py
@@ -18,9 +18,13 @@ def __init__(self, random_state=None, **kwargs):
 
         self.random_state = (random_state,)
 
-    def fit(self, X, Y, **kwd):
+    def _pre_fit(self):
+        super()._pre_fit()
         self.random_state_ = check_random_state(self.random_state)
 
+    def fit(self, X, Y, **kwd):
+        self._pre_fit()
+
     def _predict_scores_fixed(self, X, **kwargs):
         n_instances, n_objects, n_features = X.shape
         return self.random_state_.rand(n_instances, n_objects)
diff --git a/csrank/objectranking/expected_rank_regression.py b/csrank/objectranking/expected_rank_regression.py
index 2eb09e5a..71dfde37 100644
--- a/csrank/objectranking/expected_rank_regression.py
+++ b/csrank/objectranking/expected_rank_regression.py
@@ -72,6 +72,10 @@ def __init__(
         self.fit_intercept = fit_intercept
         self.random_state = random_state
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+
     def fit(self, X, Y, **kwargs):
         """
         Fit an ExpectedRankRegression on the provided set of queries X and preferences Y of those objects.
@@ -88,7 +92,7 @@ def fit(self, X, Y, **kwargs):
         **kwargs
             Keyword arguments for the fit function
         """
-        self.random_state_ = check_random_state(self.random_state)
+        self._pre_fit()
         logger.debug("Creating the Dataset")
         x_train, y_train = complete_linear_regression_dataset(X, Y)
         logger.debug("Finished the Dataset")
diff --git a/csrank/objectranking/list_net.py b/csrank/objectranking/list_net.py
index 492f44e9..c5fc19cd 100644
--- a/csrank/objectranking/list_net.py
+++ b/csrank/objectranking/list_net.py
@@ -139,6 +139,12 @@ def _create_topk(self, X, Y):
         Y_topk = Y[mask].reshape(n_inst, self.n_top)
         return X_topk, Y_topk
 
+    def _pre_fit(self):
+        super()._pre_fit()
+        self.random_state_ = check_random_state(self.random_state)
+        self._initialize_optimizer()
+        self._initialize_regularizer()
+
     def fit(
         self, X, Y, epochs=10, callbacks=None, validation_split=0.1, verbose=0, **kwd
     ):
@@ -172,10 +178,8 @@ def fit(
         **kwd
             Keyword arguments for the fit function
         """
-        self.random_state_ = check_random_state(self.random_state)
+        self._pre_fit()
         _n_instances, _n_objects, self.n_object_features_fit_ = X.shape
-        self._initialize_optimizer()
-        self._initialize_regularizer()
         self._construct_layers()
         logger.debug("Creating top-k dataset")
         X, Y = self._create_topk(X, Y)
diff --git a/csrank/objectranking/rank_svm.py b/csrank/objectranking/rank_svm.py
index 1318f19f..89a56339 100644
--- a/csrank/objectranking/rank_svm.py
+++ b/csrank/objectranking/rank_svm.py
@@ -58,6 +58,7 @@ def __init__(
 
         logger.info("Initializing network")
 
     def fit(self, X, Y, **kwargs):
+        self._pre_fit()
         _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
         super().fit(X, Y, **kwargs)
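
Note on the pattern (not part of the patch itself): the `_pre_fit` hook added in `csrank/learner.py` centralizes data-independent, stateful initialization that previously happened in `__init__` or ad hoc at the top of each `fit`. A minimal, hypothetical sketch of how a learner is expected to use the hook, assuming only `sklearn.utils.check_random_state` and class names invented for illustration, might look like this:

```python
from sklearn.utils import check_random_state


class Learner:
    def _pre_fit(self):
        # Data-independent, stateful setup; subclasses extend this.
        pass

    def fit(self, X, Y, **kwargs):
        raise NotImplementedError


class RandomScorer(Learner):
    """Hypothetical toy learner mirroring the baseline estimators in the patch."""

    def __init__(self, random_state=None):
        self.random_state = random_state

    def _pre_fit(self):
        super()._pre_fit()  # run general initializations first
        self.random_state_ = check_random_state(self.random_state)

    def fit(self, X, Y, **kwargs):
        self._pre_fit()  # always called before any fitted state is used
        _n_instances, self.n_objects_fit_, self.n_object_features_fit_ = X.shape
        return self

    def _predict_scores_fixed(self, X, **kwargs):
        n_instances, n_objects, _n_features = X.shape
        return self.random_state_.rand(n_instances, n_objects)
```

The design choice is the usual scikit-learn convention: `__init__` only stores constructor arguments, while trailing-underscore attributes (`random_state_`, optimizers, Keras layers) are created in `_pre_fit`/`fit`, so cloning and repeated fitting remain safe.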