diff --git a/python/hsfs/core/feature_view_engine.py b/python/hsfs/core/feature_view_engine.py index d4b7259c6b..2db7c6c762 100644 --- a/python/hsfs/core/feature_view_engine.py +++ b/python/hsfs/core/feature_view_engine.py @@ -335,8 +335,32 @@ def get_batch_query( dataframe_features = engine.get_instance().parse_schema_feature_group( spine.dataframe ) + + all_left_feature_group_features = query._left_feature_group.features + selected_features = query.features + # Using list comprehension + selected_left_feature_group_features = [ + feature + for feature in all_left_feature_group_features + if feature in selected_features + ] + + # remove label features as they are not necessary to check + labels = [feat for feat in feature_view_obj.features if feat.label] + left_feature_group_features = [ + feature + for feature in selected_left_feature_group_features + if not any( + ( + feature.name == label.name + and feature._feature_group_id == label._feature_group.id + ) + for label in labels + ) + ] + spine._feature_group_engine._verify_schema_compatibility( - query._left_feature_group.features, dataframe_features + left_feature_group_features, dataframe_features ) query._left_feature_group = spine elif isinstance(query._left_feature_group, feature_group.SpineGroup): @@ -478,7 +502,8 @@ def get_training_data( # forcing dataframe type to default here since dataframe operations are required for training data split. dataframe_type="default" if dataframe_type.lower() in ["numpy", "python"] - else dataframe_type, # forcing dataframe type to default here since dataframe operations are required for training data split. + else dataframe_type, + # forcing dataframe type to default here since dataframe operations are required for training data split. ) else: self._check_feature_group_accessibility(feature_view_obj) diff --git a/python/hsfs/feature_group.py b/python/hsfs/feature_group.py index f65b12c7a0..8985cb5968 100644 --- a/python/hsfs/feature_group.py +++ b/python/hsfs/feature_group.py @@ -456,8 +456,13 @@ def select_features( # Returns `Query`. A query object with all features of the feature group. """ - select_features = self.primary_key + self.foreign_key + [self.event_time] - if not isinstance(self, ExternalFeatureGroup): + select_features = self.primary_key + [self.event_time] + + if not isinstance(self, SpineGroup): + select_features = select_features + self.foreign_key + if not isinstance(self, ExternalFeatureGroup) and not isinstance( + self, SpineGroup + ): select_features = select_features + self.partition_key query = self.select_except(select_features)