diff --git a/pysurvival/models/semi_parametric.py b/pysurvival/models/semi_parametric.py index 80c21df..7e70a36 100644 --- a/pysurvival/models/semi_parametric.py +++ b/pysurvival/models/semi_parametric.py @@ -602,6 +602,8 @@ def fit(self, X, T, E, init_method = 'glorot_uniform', # Scaling data if self.auto_scaler: X_original = self.scaler.fit_transform( X ) + else: + X_original = X # Sorting X, T, E in descending order according to T order = np.argsort(-T)