diff --git a/python/hsfs/feature_group.py b/python/hsfs/feature_group.py
index 2b5f7ff484..e0355f0c0c 100644
--- a/python/hsfs/feature_group.py
+++ b/python/hsfs/feature_group.py
@@ -2108,6 +2108,121 @@ def embedding_index(self) -> Optional["EmbeddingIndex"]:
     def embedding_index(self, embedding_index: Optional["EmbeddingIndex"]) -> None:
         self._embedding_index = embedding_index
 
+    def read_by_primary_key(
+        self,
+        primary_key_values: Dict[str, Union[Any, List[Any]]],
+        event_time_min: Optional[Union[str, datetime, date]] = None,
+        event_time_max: Optional[Union[str, datetime, date]] = None,
+        online: bool = False,
+        dataframe_type: str = "default",
+        read_options: Optional[dict] = None,
+    ) -> Union[
+        pd.DataFrame,
+        np.ndarray,
+        List[List[Any]],
+        TypeVar("pyspark.sql.DataFrame"),
+        TypeVar("pyspark.RDD"),
+        pl.DataFrame,
+    ]:
+        """
+        Retrieve feature data for specific primary key values and an optional event time range.
+
+        !!! example
+            ```python
+            # Single primary key lookup
+            result = fg.read_by_primary_key({"user_id": 123})
+
+            # Composite primary key lookup
+            result = fg.read_by_primary_key({"user_id": 123, "session_id": "abc"})
+
+            # Multiple records for a single primary key column
+            result = fg.read_by_primary_key({"user_id": [123, 456, 789]})
+
+            # With event time filtering
+            result = fg.read_by_primary_key(
+                {"user_id": 123},
+                event_time_min="2023-01-01",
+                event_time_max="2023-01-31",
+            )
+            ```
+
+        # Arguments
+            primary_key_values: Dictionary mapping primary key column names to their values.
+                Values can be single items or lists for IN-style queries.
+            event_time_min: Optional minimum event time for temporal filtering.
+                Strings should be formatted as `%Y-%m-%d`, `%Y-%m-%d %H:%M:%S`, etc.
+            event_time_max: Optional maximum event time for temporal filtering.
+            online: If `True`, read from the online feature store. Defaults to `False`.
+            dataframe_type: The type of the returned dataframe, one of `"default"`,
+                `"pandas"`, `"spark"`, `"polars"`, `"numpy"` or `"python"`.
+            read_options: Additional options as key/value pairs to pass to the execution engine.
+
+        # Returns
+            `DataFrame`: The dataframe containing the matching records.
+            `pyspark.DataFrame`. A Spark DataFrame.
+            `pandas.DataFrame`. A Pandas DataFrame.
+            `polars.DataFrame`. A Polars DataFrame.
+            `numpy.ndarray`. A two-dimensional Numpy array.
+            `list`. A two-dimensional Python list.
+
+        # Raises
+            `hsfs.client.exceptions.FeatureStoreException`: If the provided primary key
+                columns do not match the feature group schema, or if event time filtering
+                is requested but the feature group has no event time column.
+        """
+        if not primary_key_values:
+            raise FeatureStoreException("primary_key_values cannot be empty.")
+
+        # Validate that the provided columns match the feature group's primary keys.
+        fg_primary_keys = set(self.primary_key)
+        provided_keys = set(primary_key_values.keys())
+
+        if not provided_keys.issubset(fg_primary_keys):
+            invalid_keys = provided_keys - fg_primary_keys
+            raise FeatureStoreException(
+                f"Invalid primary key columns: {list(invalid_keys)}. "
+                f"Feature group primary keys are: {self.primary_key}"
+            )
+
+        # Build filter conditions for the primary keys.
+        filters = []
+        for pk_column, pk_value in primary_key_values.items():
+            feature_obj = self.__getattr__(pk_column)
+
+            if isinstance(pk_value, list):
+                # Use an `isin` filter for multiple values.
+                filter_condition = feature_obj.isin(pk_value)
+            else:
+                # Use an equality filter for a single value.
+                filter_condition = feature_obj == pk_value
+
+            filters.append(filter_condition)
+
+        # Add event time filters if specified.
+        if event_time_min is not None or event_time_max is not None:
+            if self.event_time is None:
+                raise FeatureStoreException(
+                    "Event time filtering requested but feature group has no event time column."
+                )
+
+            event_time_feature = self.__getattr__(self.event_time)
+
+            if event_time_min is not None:
+                filters.append(event_time_feature >= event_time_min)
+
+            if event_time_max is not None:
+                filters.append(event_time_feature <= event_time_max)
+
+        # Combine all filters with AND logic.
+        combined_filter = filters[0]
+        for f in filters[1:]:
+            combined_filter = combined_filter & f
+
+        # Execute the lookup through the existing filter and read APIs.
+        return self.filter(combined_filter).read(
+            online=online,
+            dataframe_type=dataframe_type,
+            read_options=read_options or {},
+        )
+
     @property
     def event_time(self) -> Optional[str]:
         """Event time feature in the feature group."""