diff --git a/src/tsn_adapters/blocks/tn_access.py b/src/tsn_adapters/blocks/tn_access.py index 4cdaec8..34b62c7 100644 --- a/src/tsn_adapters/blocks/tn_access.py +++ b/src/tsn_adapters/blocks/tn_access.py @@ -44,7 +44,7 @@ def _retry_condition(task: Task[Any, Any], task_run: TaskRun, state: State[Any]) except TNNodeNetworkError: # Always retry on network errors return True - except Exception as exc: + except Exception: # For non-network errors, use the task_run's run_count to decide. if task_run.run_count <= max_other_error_retries: return True diff --git a/src/tsn_adapters/common/trufnetwork/models/tn_models.py b/src/tsn_adapters/common/trufnetwork/models/tn_models.py index 30e8814..3664706 100644 --- a/src/tsn_adapters/common/trufnetwork/models/tn_models.py +++ b/src/tsn_adapters/common/trufnetwork/models/tn_models.py @@ -28,6 +28,7 @@ class Config(pa.DataFrameModel.Config): coerce = True strict = "filter" + class StreamLocatorModel(pa.DataFrameModel): stream_id: Series[str] data_provider: Series[pa.String] = pa.Field( diff --git a/src/tsn_adapters/flows/fmp/historical_flow.py b/src/tsn_adapters/flows/fmp/historical_flow.py index 2cdd5dc..8b898db 100644 --- a/src/tsn_adapters/flows/fmp/historical_flow.py +++ b/src/tsn_adapters/flows/fmp/historical_flow.py @@ -145,19 +145,19 @@ def get_earliest_data_date(tn_block: TNAccessBlock, stream_id: str) -> Optional[ raise TNQueryError(str(e)) from e -def ensure_unix_timestamp(dt: pd.Series) -> pd.Series: # type: ignore -> pd Series doesn't fit with any +def ensure_unix_timestamp(dt: pd.Series) -> pd.Series: # type: ignore -> pd Series doesn't fit with any """Convert datetime series to Unix timestamp (seconds since epoch). - + This function handles various datetime formats and ensures the output is always in seconds since epoch (Unix timestamp). It explicitly handles the conversion from nanoseconds and validates the output range. - + Args: dt: A pandas Series containing datetime data in various formats - + Returns: A pandas Series containing Unix timestamps (seconds since epoch) - + Raises: ValueError: If the resulting timestamps are outside the valid range or if the conversion results in unexpected units @@ -165,24 +165,24 @@ def ensure_unix_timestamp(dt: pd.Series) -> pd.Series: # type: ignore -> pd Seri # Convert to datetime if not already if not pd.api.types.is_datetime64_any_dtype(dt): dt = pd.to_datetime(dt, utc=True) - + # Get nanoseconds since epoch - ns_timestamps = dt.astype('int64') - + ns_timestamps = dt.astype("int64") + # Convert to seconds (integer division by 1e9 for nanoseconds) second_timestamps = ns_timestamps // 10**9 - + # Validate the range (basic sanity check) # Unix timestamps should be between 1970 and 2100 approximately min_valid_timestamp = 0 # 1970-01-01 max_valid_timestamp = 4102444800 # 2100-01-01 - + if (second_timestamps < min_valid_timestamp).any() or (second_timestamps > max_valid_timestamp).any(): raise ValueError( f"Converted timestamps outside valid range: " f"min={second_timestamps.min()}, max={second_timestamps.max()}" ) - + return second_timestamps @@ -292,7 +292,7 @@ def run_ticker_pipeline( Exceptions during processing are logged and raised. """ - def process_ticker(row: pd.Series) -> TickerResult: # type: ignore -> pd Series doesn't fit with any + def process_ticker(row: pd.Series) -> TickerResult: # type: ignore -> pd Series doesn't fit with any """Process a single ticker row.""" symbol = row["source_id"] stream_id = row["stream_id"] diff --git a/src/tsn_adapters/flows/fmp/real_time_flow.py b/src/tsn_adapters/flows/fmp/real_time_flow.py index b5b0cef..7f2f112 100644 --- a/src/tsn_adapters/flows/fmp/real_time_flow.py +++ b/src/tsn_adapters/flows/fmp/real_time_flow.py @@ -97,21 +97,22 @@ def convert_quotes_to_tn_data( return DataFrame[TnDataRowModel](result_df) + def ensure_unix_timestamp(time: int) -> int: """ Ensure the timestamp is a valid unix timestamp (seconds since epoch). - + This function validates and converts timestamps to ensure they are in seconds since epoch format. It handles cases where the input might be in milliseconds or microseconds. - + Args: time: Integer timestamp that might be in seconds, milliseconds, microseconds, or nanoseconds since epoch - + Returns: Integer timestamp in seconds since epoch - + Raises: ValueError: If the timestamp is invalid or outside the reasonable range """ @@ -121,7 +122,7 @@ def ensure_unix_timestamp(time: int) -> int: # Define valid range for seconds since epoch min_valid_timestamp = 0 # 1970-01-01 max_valid_timestamp = 4102444800 # 2100-01-01 - + # Convert to seconds if in a larger unit converted_time = time if time > max_valid_timestamp: @@ -134,13 +135,11 @@ def ensure_unix_timestamp(time: int) -> int: converted_time = time // 10**3 else: # assume seconds but with some future date converted_time = time // 10**9 # aggressive conversion to be safe - + # Validate the range after conversion if converted_time < min_valid_timestamp or converted_time > max_valid_timestamp: - raise ValueError( - f"Timestamp outside valid range (1970-2100): {converted_time}" - ) - + raise ValueError(f"Timestamp outside valid range (1970-2100): {converted_time}") + return converted_time diff --git a/src/tsn_adapters/flows/stream_deploy_flow.py b/src/tsn_adapters/flows/stream_deploy_flow.py index 038a9ed..75aac68 100644 --- a/src/tsn_adapters/flows/stream_deploy_flow.py +++ b/src/tsn_adapters/flows/stream_deploy_flow.py @@ -8,7 +8,6 @@ concurrency limiter and then waits for confirmation using tn-read concurrency via TNAccessBlock. """ -import time from typing import Literal from pandera.typing import DataFrame @@ -32,6 +31,7 @@ class DeployStreamResult(TypedDict): stream_id: str status: Literal["deployed", "skipped"] + @task(tags=["tn", "tn-write"], retries=3, retry_delay_seconds=5) def check_and_deploy_stream(stream_id: str, tna_block: TNAccessBlock, is_unix: bool = False) -> DeployStreamResult: """ @@ -109,7 +109,7 @@ def deploy_streams_flow( logger.info(f"Found {len(stream_ids)} stream descriptors.") # we will deploy in batches of 500 to avoid infinite threads creation - batches = [stream_ids[i:i+batch_size] for i in range(0, len(stream_ids), batch_size)] + batches = [stream_ids[i : i + batch_size] for i in range(0, len(stream_ids), batch_size)] aggregated_results: list[DeployStreamResult] = [] for batch in batches[start_from_batch:]: diff --git a/src/tsn_adapters/tasks/argentina/__init__.py b/src/tsn_adapters/tasks/argentina/__init__.py index 6a159dc..90d18f7 100644 --- a/src/tsn_adapters/tasks/argentina/__init__.py +++ b/src/tsn_adapters/tasks/argentina/__init__.py @@ -5,9 +5,11 @@ from tsn_adapters.tasks.argentina.models.sepa import SepaWebsiteDataItem from tsn_adapters.tasks.argentina.models.sepa.website_item import SepaWebsiteScraper from tsn_adapters.tasks.argentina.utils.processors import SepaDirectoryProcessor +from tsn_adapters.tasks.argentina.errors import errors __all__ = [ "SepaWebsiteDataItem", "SepaWebsiteScraper", "SepaDirectoryProcessor", + "errors", ] diff --git a/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py b/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py index dfd9e04..cd8a63f 100644 --- a/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py +++ b/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py @@ -6,6 +6,11 @@ """ from tsn_adapters.tasks.argentina.aggregate.uncategorized import get_uncategorized_products +from tsn_adapters.tasks.argentina.errors import ( + EmptyCategoryMapError, + ErrorAccumulator, + UncategorizedProductsError, +) from tsn_adapters.tasks.argentina.types import AggregatedPricesDF, AvgPriceDF, CategoryMapDF, UncategorizedDF from tsn_adapters.utils.logging import get_logger_safe @@ -40,11 +45,16 @@ def aggregate_prices_by_category( 1. Merges product prices with their category assignments 2. Groups the data by category and date 3. Computes the mean price for each category-date combination + + Raises + ------ + EmptyCategoryMapError + If the product category mapping DataFrame is empty """ # Input validation and logging if product_category_map_df.empty: logger.error("Product category mapping DataFrame is empty") - raise ValueError("Cannot aggregate prices: product category mapping is empty") + raise EmptyCategoryMapError(url="") if avg_price_product_df.empty: logger.warning("Average price product DataFrame is empty") @@ -56,17 +66,6 @@ def aggregate_prices_by_category( f"and {len(avg_price_product_df)} price records" ) - # Check for products without categories before merge - unique_products_with_prices = set(avg_price_product_df["id_producto"].unique()) - unique_products_with_categories = set(product_category_map_df["id_producto"].unique()) - - products_without_categories = unique_products_with_prices - unique_products_with_categories - if products_without_categories: - logger.warning( - f"Found {len(products_without_categories)} products with prices but no category mapping. " - "These will be excluded from aggregation." - ) - # Merge the product categories with prices merged_df = avg_price_product_df.merge( product_category_map_df, @@ -94,4 +93,13 @@ def aggregate_prices_by_category( uncategorized_df = get_uncategorized_products(avg_price_product_df, product_category_map_df) + if not uncategorized_df.empty: + logger.warning(f"Found {len(uncategorized_df)} uncategorized products") + accumulator = ErrorAccumulator.get_or_create_from_context() + accumulator.add_error( + UncategorizedProductsError( + count=len(uncategorized_df), + ) + ) + return AggregatedPricesDF(aggregated_df), uncategorized_df diff --git a/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py b/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py index c6c0b49..ab6abe3 100644 --- a/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py +++ b/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py @@ -10,7 +10,4 @@ def get_uncategorized_products( """ diff_df = data[~data["id_producto"].isin(category_map["id_producto"])] - # get data without id_producto (=null) - data[data["id_producto"].isnull()] - return UncategorizedDF(diff_df) diff --git a/src/tsn_adapters/tasks/argentina/errors/__init__.py b/src/tsn_adapters/tasks/argentina/errors/__init__.py new file mode 100644 index 0000000..8d36784 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/__init__.py @@ -0,0 +1,33 @@ +""" +Error handling for Argentina SEPA processing. +""" + +from tsn_adapters.tasks.argentina.errors.accumulator import ErrorAccumulator +from tsn_adapters.tasks.argentina.errors.errors import ( + AccountableRole, + ArgentinaSEPAError, + DateMismatchError, + EmptyCategoryMapError, + InvalidCategorySchemaError, + InvalidCSVSchemaError, + InvalidDateFormatError, + InvalidStructureZIPError, + InvalidProductsError, + MissingProductosCSVError, + UncategorizedProductsError, +) + +__all__ = [ + "ErrorAccumulator", + "AccountableRole", + "ArgentinaSEPAError", + "DateMismatchError", + "EmptyCategoryMapError", + "InvalidCategorySchemaError", + "InvalidCSVSchemaError", + "InvalidDateFormatError", + "InvalidStructureZIPError", + "InvalidProductsError", + "MissingProductosCSVError", + "UncategorizedProductsError", +] \ No newline at end of file diff --git a/src/tsn_adapters/tasks/argentina/errors/accumulator.py b/src/tsn_adapters/tasks/argentina/errors/accumulator.py new file mode 100644 index 0000000..0558847 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/accumulator.py @@ -0,0 +1,71 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any + +from prefect.artifacts import create_markdown_artifact + +from tsn_adapters.tasks.argentina.errors.errors import ArgentinaSEPAError + +error_ctx = ContextVar["ErrorAccumulator | None"]("argentina_sepa_errors") + + +class ErrorAccumulator: + def __init__(self): + self.errors: list[ArgentinaSEPAError] = [] + + def add_error(self, error: ArgentinaSEPAError): + self.errors.append(error) + + def model_dump(self): + return [error.to_dict() for error in self.errors] + + def model_load(self, data: list[dict[str, Any]]): + self.errors = [ArgentinaSEPAError(**error) for error in data] + + @classmethod + def get_or_create_from_context(cls) -> "ErrorAccumulator": + errors = error_ctx.get() + if errors is None: + errors = ErrorAccumulator() + errors.set_to_context() + return errors + + def set_to_context(self): + error_ctx.set(self) + +@contextmanager +def error_collection(): + """Context manager for collecting errors during processing.""" + accumulator = ErrorAccumulator() + accumulator.set_to_context() + try: + yield accumulator + except ArgentinaSEPAError as e: + accumulator.add_error(e) + raise + finally: + # Create error summary if there are errors + if accumulator.errors: + error_summary = [ + "# Processing Errors\n", + "The following errors occurred during processing:\n", + ] + for error in accumulator.errors: + error_dict = error.to_dict() + error_summary.extend( + [ + f"\n## {error_dict['code']}: {error_dict['message']}\n", + f"Responsibility: {error_dict['responsibility'].value}\n", + "Context:\n", + "```json\n", + f"{error_dict['context']}\n", + "```\n", + ] + ) + + # Create markdown artifact with the error summary + create_markdown_artifact( + key="processing-errors", + markdown="".join(error_summary), + description="Errors encountered during SEPA data processing", + ) diff --git a/src/tsn_adapters/tasks/argentina/errors/context.py b/src/tsn_adapters/tasks/argentina/errors/context.py new file mode 100644 index 0000000..4004a24 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/context.py @@ -0,0 +1,47 @@ +from collections.abc import Generator +from contextlib import contextmanager +import contextvars +from typing import Any + +# A context variable to hold a dictionary of key/value pairs for error context. +_error_context_var = contextvars.ContextVar("error_context", default={}) + + +def set_error_context(key: str, value: Any) -> None: + """ + Set a key/value pair in the global error context. + """ + current = _error_context_var.get().copy() + current[key] = value + _error_context_var.set(current) + + +def get_error_context() -> dict[str, Any]: + """ + Retrieve the current global error context. + """ + return _error_context_var.get() + + +def clear_error_context() -> None: + """ + Clear the global error context. + """ + _error_context_var.set({}) + + +@contextmanager +def error_context(**kwargs: Any) -> Generator[None, None, None]: + """ + Context manager to temporarily update the error context. + Example: + with error_context(store_id="STORE-123", user_id="USER-456"): + ...your logic... + """ + current = get_error_context().copy() + current.update(kwargs) + token = _error_context_var.set(current) + try: + yield + finally: + _error_context_var.reset(token) diff --git a/src/tsn_adapters/tasks/argentina/errors/context_helper.py b/src/tsn_adapters/tasks/argentina/errors/context_helper.py new file mode 100644 index 0000000..99a3074 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/context_helper.py @@ -0,0 +1,46 @@ +""" +Helper class to store known context attributes for error management in Argentina flows. + +This class encapsulates common context attributes (e.g., store_id, date, file_key) +which can be used during error reporting. When an error is raised, you can pass +the dictionary produced by this helper to automatically include this extra metadata. +""" + +from typing import Any, Optional + +from tsn_adapters.tasks.argentina.errors.context import get_error_context, set_error_context + + +class ContextProperty: + """Descriptor for managing context properties in a DRY and type-safe manner.""" + + def __init__(self, key: str) -> None: + self.key = key + + def __get__(self, instance: Any, owner: Any) -> Optional[str]: + return get_error_context().get(self.key) + + def __set__(self, instance: Any, value: str) -> None: + set_error_context(self.key, value) + # Optionally update the instance dictionary if needed + instance.__dict__[self.key] = value + + +class ArgentinaErrorContext: + """ + Helper class for managing known context attributes in Argentina flows. + + Attributes + ---------- + store_id : Optional[str] + The identifier for the store, if applicable. + date : Optional[str] + The relevant date (e.g., report date or flow execution date) in ISO format (YYYY-MM-DD). + file_key : Optional[str] + The file key or resource identifier involved in the error. + """ + + # DRY properties using the descriptor: + store_id: ContextProperty = ContextProperty("store_id") + date: ContextProperty = ContextProperty("date") + file_key: ContextProperty = ContextProperty("file_key") diff --git a/src/tsn_adapters/tasks/argentina/errors/errors.py b/src/tsn_adapters/tasks/argentina/errors/errors.py new file mode 100644 index 0000000..93f0947 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/errors.py @@ -0,0 +1,215 @@ +""" +Structured error handling for Argentina SEPA processing. +""" + +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel + +# NEW: Import the global error context helper +from tsn_adapters.tasks.argentina.errors.context_helper import ArgentinaErrorContext + + +class AccountableRole(Enum): + """Roles responsible for different types of errors.""" + + DATA_PROVIDER = "Data Provider" + DATA_ENGINEERING = "Data Engineering" + DEVELOPMENT = "Development" + SYSTEM = "System" + + +class ArgentinaSEPAErrorData(BaseModel): + """Data model for Argentina SEPA errors.""" + + code: str + message: str + responsibility: AccountableRole + context: dict[str, Any] = {} + + +class ArgentinaSEPAError(Exception): + """Base class for all Argentina SEPA processing errors.""" + + def __init__( + self, code: str, message: str, responsibility: AccountableRole, context: Optional[dict[str, Any]] = None + ): + super().__init__(message) + self.data = ArgentinaSEPAErrorData( + code=code, + message=message, + responsibility=responsibility, + context=context or {}, + ) + + @property + def code(self) -> str: + return self.data.code + + @property + def message(self) -> str: + return self.data.message + + @property + def responsibility(self) -> AccountableRole: + return self.data.responsibility + + @property + def context(self) -> dict[str, Any]: + return self.data.context + + def to_dict(self) -> dict[str, Any]: + """Convert error to dictionary format.""" + return { + "code": self.code, + "message": self.message, + "responsibility": self.responsibility, + "context": self.context, + } + + +# -------------------------------------------------- +# Input Validation Errors (100-199) +# -------------------------------------------------- +class InvalidStructureZIPError(ArgentinaSEPAError): + """Invalid ZIP file structure during extraction""" + + def __init__(self, error: str): + ctx = ArgentinaErrorContext() + super().__init__( + code="ARG-100", + message="Invalid ZIP file structure - cannot extract files", + responsibility=AccountableRole.DATA_PROVIDER, + context={ + "date": ctx.date, + "error": error, + }, + ) + + +class InvalidDateFormatError(ArgentinaSEPAError): + """Invalid date format in flow input""" + + def __init__(self, date_str: str): + super().__init__( + code="ARG-101", + message=f"Invalid date format: {date_str} - must be YYYY-MM-DD", + responsibility=AccountableRole.SYSTEM, + context={"invalid_date": date_str}, + ) + + +class MissingProductosCSVError(ArgentinaSEPAError): + """Missing productos.csv in ZIP file""" + + def __init__(self, directory: str, available_files: list[str]): + super().__init__( + code="ARG-102", + message="Missing productos.csv in ZIP archive", + responsibility=AccountableRole.DATA_PROVIDER, + context={"directory": directory, "available_files": ", ".join(available_files)}, + ) + + +# -------------------------------------------------- +# Data Processing Errors (200-299) +# -------------------------------------------------- +class DateMismatchError(ArgentinaSEPAError): + """Filename vs content date mismatch""" + + def __init__(self, internal_date: str): + ctx = ArgentinaErrorContext() + super().__init__( + code="ARG-200", + message=f"Date mismatch: Reported {ctx.date} vs Actual {internal_date}", + responsibility=AccountableRole.DATA_PROVIDER, + context={ + "external_date": ctx.date, + "internal_date": internal_date, + }, + ) + + +class InvalidCSVSchemaError(ArgentinaSEPAError): + """Missing required columns in RAW data""" + + def __init__(self, error: str): + ctx = ArgentinaErrorContext() + super().__init__( + code="ARG-201", + message="Missing required columns", + responsibility=AccountableRole.DATA_PROVIDER, + context={ + "date": ctx.date, + "store_id": ctx.store_id, + "error": error, + }, + ) + + +class InvalidProductsError(ArgentinaSEPAError): + """Null/empty product IDs found in RAW data""" + + def __init__(self, invalid_indexes: list[int]): + ctx = ArgentinaErrorContext() + invalid_indexes_str = ", ".join(str(idx) for idx in invalid_indexes) + super().__init__( + code="ARG-202", + message=f"{len(invalid_indexes)} products with invalid IDs", + responsibility=AccountableRole.DEVELOPMENT, + context={"invalid_indexes": invalid_indexes_str, "date": ctx.date, "store_id": ctx.store_id}, + ) + + +# -------------------------------------------------- +# Category Mapping Errors (300-399) +# -------------------------------------------------- +class EmptyCategoryMapError(ArgentinaSEPAError): + """Empty category mapping DataFrame""" + + def __init__(self, url: str): + super().__init__( + code="ARG-300", + message="Category mapping is empty", + responsibility=AccountableRole.DATA_ENGINEERING, + context={"url": url}, + ) + + +class UncategorizedProductsError(ArgentinaSEPAError): + """Products without category mapping""" + + def __init__(self, count: int): + ctx = ArgentinaErrorContext() + super().__init__( + code="ARG-301", + message=f"{count} uncategorized products found", + responsibility=AccountableRole.DATA_ENGINEERING, + context={"uncategorized_count": str(count), "date": ctx.date}, + ) + + +class InvalidCategorySchemaError(ArgentinaSEPAError): + """Invalid category mapping schema""" + + def __init__(self, error: str, url: str): + super().__init__( + code="ARG-302", + message="Invalid category mapping schema", + responsibility=AccountableRole.DATA_ENGINEERING, + context={"error": error, "url": url}, + ) + + +all_errors = [ + InvalidStructureZIPError, + InvalidDateFormatError, + MissingProductosCSVError, + DateMismatchError, + InvalidCSVSchemaError, + InvalidProductsError, + EmptyCategoryMapError, + UncategorizedProductsError, + InvalidCategorySchemaError, +] diff --git a/src/tsn_adapters/tasks/argentina/flows/base.py b/src/tsn_adapters/tasks/argentina/flows/base.py index d637bb6..f893451 100644 --- a/src/tsn_adapters/tasks/argentina/flows/base.py +++ b/src/tsn_adapters/tasks/argentina/flows/base.py @@ -2,9 +2,13 @@ Base flow controller for Argentina SEPA data processing. """ +from datetime import datetime +import re + from prefect import get_run_logger from prefect_aws import S3Bucket +from tsn_adapters.tasks.argentina.errors.errors import InvalidDateFormatError from tsn_adapters.tasks.argentina.provider import ProcessedDataProvider, RawDataProvider from tsn_adapters.tasks.argentina.types import DateStr @@ -32,9 +36,26 @@ def validate_date(self, date: DateStr) -> None: date: Date string to validate Raises: - ValueError: If date format is invalid + InvalidDateFormatError: If date format is invalid """ - import re + # First check for separators + if "/" in date: + raise InvalidDateFormatError(date, "wrong_separator") + if "-" not in date: + raise InvalidDateFormatError(date, "no_separator") + # Check basic format if not re.match(r"\d{4}-\d{2}-\d{2}", date): - raise ValueError(f"Invalid date format: {date}") + raise InvalidDateFormatError(date, "wrong_format") + + # Check date validity + try: + datetime.strptime(date, "%Y-%m-%d") + except ValueError as e: + error_msg = str(e) + if "month must be in 1..12" in error_msg: + raise InvalidDateFormatError(date, "invalid_month") + if "day is out of range for month" in error_msg or "day must be in" in error_msg: + raise InvalidDateFormatError(date, "invalid_day") + # Only raise wrong_format if it's not a specific month/day error + raise InvalidDateFormatError(date, "wrong_format") diff --git a/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py b/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py index 5dd3e71..93ce5c5 100644 --- a/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py +++ b/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py @@ -17,11 +17,17 @@ from prefect_aws import S3Bucket from tsn_adapters.tasks.argentina.aggregate import aggregate_prices_by_category +from tsn_adapters.tasks.argentina.errors import ( + ArgentinaSEPAError, +) +from tsn_adapters.tasks.argentina.errors.accumulator import error_collection +from tsn_adapters.tasks.argentina.errors.context_helper import ArgentinaErrorContext from tsn_adapters.tasks.argentina.flows.base import ArgentinaFlowController from tsn_adapters.tasks.argentina.models.sepa.sepa_models import SepaAvgPriceProductModel from tsn_adapters.tasks.argentina.provider.s3 import RawDataProvider from tsn_adapters.tasks.argentina.task_wrappers import task_load_category_map from tsn_adapters.tasks.argentina.types import AggregatedPricesDF, CategoryMapDF, DateStr, SepaDF, UncategorizedDF +from tsn_adapters.tasks.argentina.utils.processors import process_raw_data from tsn_adapters.utils import deroutine @@ -33,7 +39,7 @@ def process_raw_data( """Process raw SEPA data. Args: - data_item: Raw data item to process + raw_data: Raw data item to process category_map_df: Category mapping DataFrame Returns: @@ -82,11 +88,20 @@ def run_flow(self) -> None: 3. If not, process the date """ logger = get_run_logger() - for date in self.raw_provider.list_available_keys(): - if self.processed_provider.exists(date): - logger.info(f"Skipping {date} because it already exists") - continue - self.process_date(date) + with error_collection() as accumulator: + for date in self.raw_provider.list_available_keys(): + if self.processed_provider.exists(date): + logger.info(f"Skipping {date} because it already exists") + continue + try: + self.process_date(date) + except Exception as e: + if isinstance(e, ArgentinaSEPAError): + accumulator.add_error(e) + logger.warning(f"Collected error for {date}: {e}") + else: + logger.error(f"Unexpected error processing {date}: {e}") + raise def process_date(self, date: DateStr) -> None: """Process data for a specific date. @@ -97,12 +112,17 @@ def process_date(self, date: DateStr) -> None: Raises: ValueError: If date format is invalid KeyError: If no data available for date + EmptyCategoryMapError: If category map is empty + InvalidCSVSchemaError: If category map has invalid schema """ logger = get_run_logger() self.validate_date(date) logger.info(f"Processing {date}") + # Set up error context + ArgentinaErrorContext().date = date + # Get raw data raw_data = self.raw_provider.get_raw_data_for(date) if raw_data.empty: diff --git a/src/tsn_adapters/tasks/argentina/provider/data_processor.py b/src/tsn_adapters/tasks/argentina/provider/data_processor.py index 2629389..8aff109 100644 --- a/src/tsn_adapters/tasks/argentina/provider/data_processor.py +++ b/src/tsn_adapters/tasks/argentina/provider/data_processor.py @@ -7,24 +7,13 @@ import tempfile from typing import cast -import pandas as pd from prefect.concurrency.sync import concurrency +from tsn_adapters.tasks.argentina.errors.errors import DateMismatchError, InvalidStructureZIPError from tsn_adapters.tasks.argentina.types import DateStr, SepaDF from tsn_adapters.tasks.argentina.utils.processors import SepaDirectoryProcessor -class DatesNotMatchError(Exception): - """Exception raised when the date does not match.""" - - def __init__(self, real_date: DateStr, reported_date: DateStr, source: str): - """Initialize the exception.""" - self.real_date = real_date - self.reported_date = reported_date - self.source = source - super().__init__(f"Real date {real_date} does not match {source} date {reported_date}") - - def process_sepa_zip( zip_reader: Generator[bytes, None, None], reported_date: DateStr, @@ -34,7 +23,7 @@ def process_sepa_zip( Process SEPA data from a data item. Args: - + zip_reader: Generator yielding bytes of the ZIP file source_name: Name of the source (for error messages) reported_date: The date reported by the source @@ -42,7 +31,8 @@ def process_sepa_zip( DataFrame: The processed SEPA data Raises: - DatesNotMatchError: If the date in the data doesn't match the reported date + DateMismatchError: If the date in the data doesn't match the reported date + InvalidStructureZIPError: If the ZIP file structure is invalid ValueError: If the data is invalid """ # Create a temporary directory for extraction @@ -60,17 +50,18 @@ def process_sepa_zip( # Process the data extract_dir = os.path.join(temp_dir, "data") os.makedirs(extract_dir, exist_ok=True) - processor = SepaDirectoryProcessor.from_zip_path(temp_zip_path, extract_dir) - df = processor.get_all_products_data_merged() - - # skip empty dataframes - if df.empty: - return cast(SepaDF, pd.DataFrame()) - - # Validate the date matches - real_date = df["date"].iloc[0] - if reported_date != real_date: - # we need to raise an error so cache is invalidated - raise DatesNotMatchError(real_date, reported_date, source_name) - - return cast(SepaDF, df) + try: + processor = SepaDirectoryProcessor.from_zip_path(temp_zip_path, extract_dir) + data = processor.get_all_products_data_merged() + + # Validate date matches + if not data.empty: + data_date: str = cast(str, data["date"].iloc[0]) + if str(data_date) != str(reported_date): + raise DateMismatchError(internal_date=data_date) + + return data + except Exception as e: + if isinstance(e, DateMismatchError): + raise + raise InvalidStructureZIPError(str(e)) from e diff --git a/src/tsn_adapters/tasks/argentina/target/trufnetwork.py b/src/tsn_adapters/tasks/argentina/target/trufnetwork.py index b886942..55490ac 100644 --- a/src/tsn_adapters/tasks/argentina/target/trufnetwork.py +++ b/src/tsn_adapters/tasks/argentina/target/trufnetwork.py @@ -16,7 +16,6 @@ task_insert_and_wait_for_tx, task_read_records, task_split_and_insert_records, - task_wait_for_tx, ) from tsn_adapters.common.interfaces.target import ITargetClient from tsn_adapters.common.trufnetwork.models.tn_models import TnDataRowModel diff --git a/src/tsn_adapters/tasks/argentina/task_wrappers.py b/src/tsn_adapters/tasks/argentina/task_wrappers.py index 23c3d4d..491194a 100644 --- a/src/tsn_adapters/tasks/argentina/task_wrappers.py +++ b/src/tsn_adapters/tasks/argentina/task_wrappers.py @@ -15,6 +15,7 @@ from tsn_adapters.common.interfaces.target import ITargetClient from tsn_adapters.common.interfaces.transformer import IDataTransformer from tsn_adapters.common.trufnetwork.models.tn_models import TnDataRowModel +from tsn_adapters.tasks.argentina.errors.errors import EmptyCategoryMapError, InvalidCategorySchemaError from tsn_adapters.tasks.argentina.models.category_map import SepaProductCategoryMapModel from tsn_adapters.tasks.argentina.provider.factory import create_sepa_processed_provider from tsn_adapters.tasks.argentina.reconciliation.strategies import create_reconciliation_strategy @@ -243,17 +244,21 @@ def task_get_and_transform_data( @task(retries=3, cache_expiration=timedelta(hours=1), cache_policy=policies.INPUTS + policies.TASK_SOURCE) -def task_load_category_map(url: str) -> pd.DataFrame: +def task_load_category_map(url: str) -> DataFrame[SepaProductCategoryMapModel]: """Load the product category mapping from a URL.""" logger = get_run_logger() logger.info(f"Loading category map from: {url}") try: df = SepaProductCategoryMapModel.from_url(url, sep="|", compression="zip") + if df.empty: + raise EmptyCategoryMapError(url=url) logger.info(f"Loaded {len(df)} category mappings") return df + except EmptyCategoryMapError: + raise except Exception as e: logger.error(f"Failed to load category map: {e}") - raise + raise InvalidCategorySchemaError(error=str(e), url=url) @task diff --git a/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py b/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py index 4c928ee..5d852d9 100644 --- a/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py +++ b/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py @@ -6,12 +6,18 @@ from pathlib import Path import re import tempfile -from typing import ClassVar, cast +from typing import Any, ClassVar import pandas as pd from pandera.typing import DataFrame from pydantic import BaseModel, field_validator +from tsn_adapters.tasks.argentina.errors.accumulator import ErrorAccumulator +from tsn_adapters.tasks.argentina.errors.errors import ( + InvalidCSVSchemaError, + InvalidDateFormatError, + MissingProductosCSVError, +) from tsn_adapters.tasks.argentina.models.sepa.sepa_models import ( SepaProductosAlternativeModel, SepaProductosDataModel, @@ -82,6 +88,9 @@ def get_products_data(self) -> Iterator[DataFrame[SepaProductosDataModel]]: for data_dir in self.get_data_dirs(): try: yield data_dir.load_products_data() + except MissingProductosCSVError as e: + self.logger.error(f"Error loading products data from {data_dir.dir_path}: {e}") + raise except Exception as e: self.logger.error(f"Error loading products data from {data_dir.dir_path}: {e}") @@ -92,7 +101,15 @@ def get_all_products_data_merged(self) -> DataFrame[SepaProductosDataModel]: all_data = list(self.get_products_data()) if not all_data: self.logger.warning("No product data found in any directory") - return cast(DataFrame[SepaProductosDataModel], pd.DataFrame()) # Return empty DataFrame if no data + empty_df = pd.DataFrame( + { + "id_producto": [], + "productos_descripcion": [], + "productos_precio_lista": [], + "date": [], + } + ) + return DataFrame[SepaProductosDataModel](empty_df) merged_df = pd.concat(all_data, ignore_index=True) self.logger.info(f"Successfully merged {len(all_data)} product data frames with {len(merged_df)} total rows") @@ -124,7 +141,7 @@ class SepaDataDirectory(BaseModel): @classmethod def validate_date(cls, v: str) -> str: if not re.match(r"\d{4}-\d{2}-\d{2}", v): - raise ValueError(f"Invalid date format: {v}") + raise InvalidDateFormatError(v) return v @staticmethod @@ -180,7 +197,7 @@ def load_products_data(self) -> DataFrame[SepaProductosDataModel]: break else: self.logger.error(f"No valid product file found in {self.dir_path}") - raise ValueError(f"No valid product file found in {self.dir_path}") + raise MissingProductosCSVError(directory=self.dir_path, available_files=files) def read_file_iterator(): with open(file_path) as file: @@ -191,7 +208,15 @@ def read_file_iterator(): text_lines = list(read_file_iterator()) if not text_lines: self.logger.warning(f"Empty product file found in {self.dir_path}") - return cast(DataFrame[SepaProductosDataModel], pd.DataFrame()) + empty_df = pd.DataFrame( + { + "id_producto": [], + "productos_descripcion": [], + "productos_precio_lista": [], + "date": [], + } + ) + return DataFrame[SepaProductosDataModel](empty_df) # models might be models: list[type[SepaProductosAlternativeModel]] = [ @@ -201,10 +226,16 @@ def read_file_iterator(): for model in models: if model.has_columns(text_lines[0]): - original_model = model.from_csv(self.date, text_lines, self.logger) - if model != SepaProductosDataModel: - self.logger.info(f"Converting {model} to core model") - return model.to_core_model(original_model) - return original_model - - raise ValueError("No valid model found for the given lines") + try: + # Parse and convert the data + raw_df: DataFrame[Any] = model.from_csv(self.date, text_lines, self.logger) + if model != SepaProductosDataModel: + self.logger.info(f"Converting {model} to core model") + return model.to_core_model(raw_df) + return raw_df + except Exception as e: + self.logger.error(f"Failed to parse CSV with model {model}: {e}") + ErrorAccumulator.get_or_create_from_context().add_error(InvalidCSVSchemaError(error=str(e))) + continue + + raise InvalidCSVSchemaError(error="No valid model found") diff --git a/src/tsn_adapters/utils/filter_failures.py b/src/tsn_adapters/utils/filter_failures.py index 70ca2b2..51774a9 100644 --- a/src/tsn_adapters/utils/filter_failures.py +++ b/src/tsn_adapters/utils/filter_failures.py @@ -5,6 +5,9 @@ from pandera.errors import SchemaErrors from pandera.typing import DataFrame +from tsn_adapters.tasks.argentina.errors.accumulator import ErrorAccumulator +from tsn_adapters.tasks.argentina.errors.errors import InvalidProductsError + U = TypeVar("U", bound=pa.DataFrameModel) @@ -25,10 +28,15 @@ def filter_failures(original: pd.DataFrame, model: type[U]) -> DataFrame[U]: return cast(DataFrame[U], validated) except SchemaErrors as exc: if exc.failure_cases is not None: + assert isinstance(exc.failure_cases, pd.DataFrame) # Get the failure indices and drop those rows failure_indices = exc.failure_cases["index"].unique() filtered_df = original.drop(failure_indices) + # Add the error to the error accumulator + failure_indexes_list: list[int] = failure_indices.tolist() + ErrorAccumulator.get_or_create_from_context().add_error(InvalidProductsError(invalid_indexes=failure_indexes_list)) + # Re-validate the filtered DataFrame validated_filtered = model.validate(filtered_df, lazy=False) return cast(DataFrame[U], validated_filtered) diff --git a/tests/argentina/errors/test_error_accumulation.py b/tests/argentina/errors/test_error_accumulation.py new file mode 100644 index 0000000..90b1858 --- /dev/null +++ b/tests/argentina/errors/test_error_accumulation.py @@ -0,0 +1,154 @@ +""" +Test suite for Argentina SEPA error handling system. + +This test suite verifies the error handling system's core functionalities: + +1. Error Accumulation: + - Errors can be collected and accumulated during processing + - Each error maintains its structured data (code, message, responsibility, context) + - Errors are properly serialized for reporting + +2. Context Isolation: + - Multiple concurrent tasks can maintain isolated error contexts + - Errors from one task don't leak into another task's context + - Each task's error context is properly managed and cleaned up + +3. Error Reporting: + - Accumulated errors are properly formatted into markdown artifacts + - Error reports include all necessary information (code, message, responsibility, context) + - Artifacts are created at appropriate points in the flow + +Key Conclusions: +- The error handling system is thread-safe and suitable for concurrent execution +- Error contexts are properly isolated between different tasks and flows +- The system successfully maintains error traceability and accountability +- Error reporting provides clear, structured information for debugging and monitoring + +Test Structure: +- error_accumulation_flow: Tests basic error collection and accumulation +- concurrent_error_contexts_flow: Verifies context isolation in concurrent execution +- test_error_reporting: Ensures proper artifact creation and formatting +- test_context_isolation: Validates complete isolation between concurrent tasks +""" + +from collections.abc import Sequence +from unittest.mock import patch + +from prefect import flow, task +import pytest + +from tsn_adapters.tasks.argentina.errors import DateMismatchError, EmptyCategoryMapError, ErrorAccumulator +from tsn_adapters.tasks.argentina.errors.accumulator import error_collection +from tsn_adapters.tasks.argentina.errors.context_helper import ArgentinaErrorContext + + +@task +def task_that_raises_error(): + """Task that intentionally raises a known error""" + accumulator = ErrorAccumulator.get_or_create_from_context() + # Set up context for DateMismatchError + ctx = ArgentinaErrorContext() + ctx.date = "2024-01-01" + + accumulator.add_error(EmptyCategoryMapError(url="test://invalid-map")) + accumulator.add_error(DateMismatchError(internal_date="2024-01-02")) + + +@task +def task_with_isolated_errors_1(): + """First task with its own error context""" + with error_collection() as task1_accumulator: + task1_accumulator.add_error(EmptyCategoryMapError(url="task1://error")) + return task1_accumulator.model_dump() + + +@task +def task_with_isolated_errors_2(): + """Second task with its own error context""" + with error_collection() as task2_accumulator: + task2_accumulator.add_error(DateMismatchError(internal_date="2024-03-02")) + return task2_accumulator.model_dump() + + +@flow +def error_accumulation_flow(): + """Flow that forces error conditions and reports errors via a markdown artifact""" + with error_collection() as accumulator: + # Use an intermediate variable and add type: ignore to silence linter issues + future = task_that_raises_error.submit(return_state=True) # type: ignore + future.result() + + # Verify errors are collected during processing + assert len(accumulator.errors) == 2, "Should collect 2 errors" + assert isinstance(accumulator.errors[0], EmptyCategoryMapError) + assert isinstance(accumulator.errors[1], DateMismatchError) + + # Let the context manager finish and create the artifact + return accumulator + + +@flow +def concurrent_error_contexts_flow(): + """Flow that tests error context isolation between tasks""" + # Submit both tasks concurrently with type ignore to silence linter warnings + task1_future = task_with_isolated_errors_1.submit(return_state=True) # type: ignore + task2_future = task_with_isolated_errors_2.submit(return_state=True) # type: ignore + + # Get results with type ignore as well + task1_errors = task1_future.result() + task2_errors = task2_future.result() + + # Verify task1 errors + assert len(task1_errors) == 1, "Task 1 should have 1 error" + assert task1_errors[0]["code"] == "ARG-300" + assert "task1://error" in task1_errors[0]["context"]["url"] + + # Verify task2 errors + assert len(task2_errors) == 1, "Task 2 should have 1 error" + assert task2_errors[0]["code"] == "ARG-200" + assert "2024-03-01" in task2_errors[0]["context"]["external_date"] + + +@pytest.mark.usefixtures("prefect_test_fixture") +def test_error_reporting(): + """Test the full error accumulation and reporting flow""" + with patch("tsn_adapters.tasks.argentina.errors.accumulator.create_markdown_artifact") as mock_artifact: + error_accumulation_flow() + + # Verify artifact creation + mock_artifact.assert_called_once() + markdown_content = mock_artifact.call_args[1]["markdown"] + assert "ARG-300" in markdown_content + assert "test://invalid-map" in markdown_content + assert "Date mismatch: Reported 2024-01-01 vs Actual 2024-01-02" in markdown_content + assert mock_artifact.call_args[1]["key"] == "processing-errors" + assert "Errors encountered during SEPA data processing" in mock_artifact.call_args[1]["description"] + + +@pytest.mark.usefixtures("prefect_test_fixture") +def test_context_isolation(): + """Test that error contexts remain isolated between concurrent tasks""" + concurrent_error_contexts_flow() + + +def generate_markdown_report(errors: Sequence[Exception]) -> str: + """Generate a markdown formatted report from a list of errors. + + For an EmptyCategoryMapError, the report includes the url. + For a DateMismatchError, the report provides a formatted date mismatch message. + """ + lines = [] + for err in errors: + if isinstance(err, EmptyCategoryMapError): + # Access url from the error's context + lines.append(f"ARG-300: URL error with {err.context['url']}") + elif isinstance(err, DateMismatchError): + # Access dates from the error's context + external_date = err.context["external_date"] + internal_date = err.context["internal_date"] + lines.append(f"ARG-200: Date mismatch: Reported {external_date} vs Actual {internal_date}") + return "\n".join(lines) + + +if __name__ == "__main__": + test_error_reporting() diff --git a/tests/argentina/errors/test_error_integration.py b/tests/argentina/errors/test_error_integration.py new file mode 100644 index 0000000..0228a5f --- /dev/null +++ b/tests/argentina/errors/test_error_integration.py @@ -0,0 +1,464 @@ +""" +Integration tests for Argentina SEPA error handling system. + +This test suite focuses on critical integration tests and complex error handling scenarios, +validating the complete data processing pipeline and error accumulation mechanisms. + +Key Test Areas: +1. Data Processing Pipeline - Complete flow validation +2. Error Handling - Complex error scenarios and recovery +3. Concurrency - Error isolation and accumulation +4. System Resilience - Partial processing and recovery +""" + +from collections.abc import Generator +import os +import re +from typing import Any, Callable, Optional, TypeVar +from unittest.mock import MagicMock, patch +import zipfile + +import pandas as pd +from prefect import flow +import pytest + +from tsn_adapters.tasks.argentina.base_types import DateStr +from tsn_adapters.tasks.argentina.errors.accumulator import error_collection +from tsn_adapters.tasks.argentina.errors.context_helper import ArgentinaErrorContext +from tsn_adapters.tasks.argentina.errors.errors import ( + AccountableRole, + ArgentinaSEPAError, + DateMismatchError, + EmptyCategoryMapError, + InvalidCSVSchemaError, + InvalidStructureZIPError, + InvalidProductsError, + MissingProductosCSVError, + InvalidDateFormatError, +) +from tsn_adapters.tasks.argentina.provider.data_processor import process_sepa_zip +from tsn_adapters.tasks.argentina.types import SepaDF + +T = TypeVar("T", bound=ArgentinaSEPAError) + +# -------------------------------------------------- +# Fixtures & Test Data +# -------------------------------------------------- + + +@pytest.fixture +def error_context_fixture(): + """Fixture for managing error context and cleanup.""" + with error_collection() as accumulator: + yield accumulator + + +@pytest.fixture +def mock_s3_block(): + """Mock S3 block with configurable responses.""" + mock = MagicMock() + mock.list_available_keys.return_value = ["2024-01-01"] + mock.get_raw_data_for.return_value = pd.DataFrame( + { + "id_producto": ["P1", "P2"], + "productos_descripcion": ["Product 1", "Product 2"], + "productos_precio_lista": [100.0, 200.0], + "date": ["2024-01-01", "2024-01-01"], + } + ) + mock.bucket_name = "mock-bucket" + mock.bucket_folder = "mock-folder" + mock.credentials = MagicMock() + mock.read_path.return_value = b"mock_s3_content" + return mock + + +@pytest.fixture +def valid_sepa_data() -> pd.DataFrame: + """Valid SEPA data fixture with realistic product data.""" + df = pd.DataFrame( + { + "date": ["2024-01-01"] * 5, + "id_producto": ["P1", None, "", None, "P5"], + "productos_descripcion": [ + "Leche Entera 1L", + "Pan Francés 1kg", + "Aceite Girasol 900ml", + "Arroz Largo 1kg", + "Azúcar Blanca 1kg", + ], + "productos_precio_lista": [350.0, 800.0, 1200.0, 900.0, 600.0], + } + ) + return df + + +@pytest.fixture +def mock_zip_factory(): + """Factory fixture for creating test ZIP files with various scenarios.""" + + def create_zip(scenario: str, *, date: Optional[str] = None, df: Optional[pd.DataFrame] = None) -> bytes: + if scenario == "invalid": + return b"This is not a valid ZIP file" + + # Create a temporary ZIP file in memory + from io import BytesIO + + zip_buffer = BytesIO() + + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: + if scenario == "invalid_date": + # Create a directory with invalid date format + base_dir = "sepa_1_comercio-sepa-1_2024/01/01_00-00-00" + else: + base_dir = "sepa_1_comercio-sepa-1_2024-01-01_00-00-00" + + # Create the base directory marker + zip_file.writestr(f"{base_dir}/", "") + + if scenario == "missing_productos": + # Create a ZIP with wrong file but correct directory structure + zip_file.writestr(f"{base_dir}/wrong.csv", "some content") + elif scenario == "wrong_date": + content_date = date or "2024-01-02" + content = ( + "id_producto|productos_descripcion|productos_precio_lista|date\n" + f"P1|Product 1|100.0|{content_date}\n" + ) + zip_file.writestr(f"{base_dir}/productos.csv", content) + elif scenario == "mixed_errors": + # Create an invalid CSV schema + content = ( + "wrong_column|productos_descripcion|productos_precio_lista|date\n" + "P1|Product 1|100.0|2024-01-01\n" + ) + zip_file.writestr(f"{base_dir}/productos.csv", content) + elif scenario == "partial_valid": + if df is not None: + content = df.to_csv(sep="|", index=False) + else: + content = ( + "id_producto|productos_descripcion|productos_precio_lista|date\n" + "|Product 1|100.0|2024-01-01\n" # Missing ID + "|Product 2|200.0|2024-01-01\n" # Missing ID + "|Product 3|300.0|2024-01-01\n" # Missing ID + "P4|Product 4|400.0|2024-01-01\n" # Valid + "P5|Product 5|500.0|2024-01-01\n" # Valid + ) + zip_file.writestr(f"{base_dir}/productos.csv", content) + else: + # Default valid content + content = ( + "id_producto|productos_descripcion|productos_precio_lista|date\n" + "P1|Product 1|100.0|2024-01-01\n" + ) + zip_file.writestr(f"{base_dir}/productos.csv", content) + + return zip_buffer.getvalue() + + return create_zip + + +@pytest.fixture +def mock_filesystem(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]: + """Mock filesystem operations for testing.""" + created_dirs: set[str] = set() + + def mock_makedirs(path: str | os.PathLike[str], exist_ok: bool = False) -> None: + """Mock os.makedirs to track created directories.""" + path_str = str(path) + if path_str not in created_dirs: + created_dirs.add(path_str) + + with patch.object(os, "makedirs", mock_makedirs): + yield + + +@pytest.fixture(autouse=True) +def mock_network(): + """Mock all network operations.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b"mock_content" + mock_response.text = "mock_text" + + with ( + patch("requests.get", return_value=mock_response), + patch("requests.post", return_value=mock_response), + patch("requests.Session", return_value=MagicMock()), + patch("prefect_aws.S3Bucket.read_path", return_value=b"mock_s3_content"), + patch("prefect_aws.S3Bucket.download_object_to_path", return_value=None), + patch("prefect_aws.S3Bucket.list_objects", return_value=["mock_key"]), + ): + yield mock_response + + +# -------------------------------------------------- +# Integration Test Classes +# -------------------------------------------------- + + +class TestDataProcessingPipeline: + """Integration tests for the complete data processing pipeline.""" + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-complete-pipeline") + def test_complete_processing_flow( + self, + mock_s3_block: MagicMock, + valid_sepa_data: pd.DataFrame, + mock_zip_factory: Callable[..., bytes], + error_context_fixture: Any, + ): + """Test the complete data processing pipeline with mixed valid/invalid data.""" + # Setup test data with mixed scenarios + zip_content = mock_zip_factory("partial_valid", df=valid_sepa_data) + + def mock_reader() -> Generator[bytes, None, None]: + yield zip_content + + # Process data and verify results + result = process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_pipeline") + + # Verify error handling + assert len(error_context_fixture.errors) > 0 + error = error_context_fixture.errors[0] + assert isinstance(error, InvalidProductsError) + assert error.context["missing_count"] == "3" + + # Verify successful processing of valid records + assert len(result) == 2 + assert all(pd.notna(result["id_producto"].astype(str))) + + @pytest.mark.usefixtures("prefect_test_fixture") + @pytest.mark.parametrize( + "error_scenario,expected_error,validation_details", + [ + ( + "invalid_structure", + InvalidStructureZIPError, + { + "code": "ARG-100", + "message": "Invalid ZIP file structure - cannot extract files", + "responsibility": AccountableRole.DATA_PROVIDER, + "context_keys": ["date", "error"], + "context_values": {"date": "2024-01-01"}, + }, + ), + ( + "missing_productos", + MissingProductosCSVError, + { + "code": "ARG-102", + "message": "Missing productos.csv in ZIP archive", + "responsibility": AccountableRole.DATA_PROVIDER, + "context_keys": ["directory", "available_files"], + "context_pattern": r"/tmp/tmp[^/]+/data/sepa_1_comercio-sepa-1_2024-01-01_00-00-00", + }, + ), + ( + "invalid_date", + InvalidDateFormatError, + { + "code": "ARG-101", + "message": "Invalid date format: 2024/01/01 - must be YYYY-MM-DD", + "responsibility": AccountableRole.SYSTEM, + "context_keys": ["invalid_date"], + "context_values": {"invalid_date": "2024/01/01"}, + }, + ), + ], + ) + @flow(name="test-error-scenarios") + def test_error_scenarios( + self, + mock_zip_factory: Callable[..., bytes], + error_scenario: str, + expected_error: type[T], + validation_details: dict[str, Any], + error_context_fixture: Any, + ): + """Test various error scenarios in the processing pipeline.""" + try: + if error_scenario == "invalid_structure": + zip_content = mock_zip_factory("invalid") + # Set up context for InvalidStructureZIPError + ctx = ArgentinaErrorContext() + ctx.store_id = "test" + ctx.date = "2024-01-01" + elif error_scenario == "invalid_date": + zip_content = mock_zip_factory("invalid_date") + else: + zip_content = mock_zip_factory("missing_productos") + + def mock_reader() -> Generator[bytes, None, None]: + yield zip_content + + process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test") + pytest.fail(f"Expected {expected_error.__name__} to be raised") + except expected_error as e: + # Verify error code + assert e.code == validation_details["code"] + # Verify error message + assert e.message == validation_details["message"] + # Verify responsibility + assert e.responsibility == validation_details["responsibility"] + + # Verify context has all required keys + for key in validation_details["context_keys"]: + assert key in e.context + + # Verify specific context values if provided + if "context_values" in validation_details: + for key, value in validation_details["context_values"].items(): + assert e.context[key] == value + + # Verify context patterns if provided + if "context_pattern" in validation_details: + assert re.match(validation_details["context_pattern"], e.context["directory"]) + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-invalid-structure-scenarios") + def test_invalid_structure_error_scenarios(self, error_context_fixture: Any): + """Test various scenarios for InvalidStructureZIPError.""" + # Test with different error messages + error_messages = [ + "File is not a zip file", + "Bad CRC-32 for file", + "Bad password for file", + "Truncated file header", + ] + + for error_msg in error_messages: + # Set up context + ctx = ArgentinaErrorContext() + ctx.date = "2024-01-01" + + # Create and verify error + error = InvalidStructureZIPError(error_msg) + + # Verify basic error properties + assert error.code == "ARG-100" + assert error.message == "Invalid ZIP file structure - cannot extract files" + assert error.responsibility == AccountableRole.DATA_PROVIDER + + # Verify context + assert "date" in error.context + assert error.context["date"] == "2024-01-01" + assert "error" in error.context + assert error.context["error"] == error_msg + + # Test without context date + error = InvalidStructureZIPError("test error") + assert "date" in error.context + assert error.context["date"] is None + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-date-format-scenarios") + def test_date_format_error_scenarios(self, error_context_fixture: Any): + """Test various scenarios for InvalidDateFormatError.""" + # Test with different invalid date formats + invalid_dates = [ + "2024/01/01", # Wrong separator + "24-01-01", # Wrong year format + "2024-1-1", # Missing padding + "2024-13-01", # Invalid month + "2024-01-32", # Invalid day + "01-01-2024", # Wrong order + "2024-01", # Incomplete + "not-a-date", # Invalid format + ] + + for invalid_date in invalid_dates: + # Create and verify error + error = InvalidDateFormatError(invalid_date) + + # Verify basic error properties + assert error.code == "ARG-101" + assert error.message == f"Invalid date format: {invalid_date} - must be YYYY-MM-DD" + assert error.responsibility == AccountableRole.SYSTEM + + # Verify context + assert "invalid_date" in error.context + assert error.context["invalid_date"] == invalid_date + + +class TestErrorHandlingSystem: + """Integration tests for the error handling and accumulation system.""" + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-error-accumulation") + def test_error_accumulation_and_isolation(self, error_context_fixture: Any): + """Test error accumulation, ordering, and isolation.""" + # Test concurrent error accumulation + with error_collection() as flow1_errors: + with error_collection() as flow2_errors: + # Add errors to different flows + flow1_errors.add_error(EmptyCategoryMapError(url="map1")) + # Set up context for DateMismatchError + ctx = ArgentinaErrorContext() + ctx.date = "2024-01-01" + flow2_errors.add_error(DateMismatchError(internal_date="2024-01-02")) + + # Verify flow2 errors + assert len(flow2_errors.errors) == 1 + assert isinstance(flow2_errors.errors[0], DateMismatchError) + assert flow2_errors.errors[0].context["external_date"] == "2024-01-01" + + # Verify flow1 errors + assert len(flow1_errors.errors) == 1 + assert isinstance(flow1_errors.errors[0], EmptyCategoryMapError) + assert flow1_errors.errors[0].context["url"] == "map1" + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-error-recovery") + def test_error_recovery_mechanisms( + self, + mock_zip_factory: Callable[..., bytes], + valid_sepa_data: SepaDF, + error_context_fixture: Any, + ): + """Test system recovery from various error conditions.""" + # Test partial processing recovery + zip_content = mock_zip_factory("partial_valid", df=valid_sepa_data) + + def mock_reader() -> Generator[bytes, None, None]: + yield zip_content + + result = process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_recovery") + + # Verify error capture + assert len(error_context_fixture.errors) > 0 + error = error_context_fixture.errors[0] + assert isinstance(error, InvalidProductsError) + + # Verify successful partial processing + assert len(result) > 0 + assert all(pd.notna(result["id_producto"].astype(str))) + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-complex-errors") + def test_complex_error_scenarios( + self, + mock_zip_factory: Callable[..., bytes], + error_context_fixture: Any, + ): + """Test handling of complex error scenarios with multiple error types.""" + # Create data with multiple error types + zip_content = mock_zip_factory("mixed_errors") + + def mock_reader() -> Generator[bytes, None, None]: + yield zip_content + + try: + process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_complex") + pytest.fail("Expected InvalidCSVSchemaError to be raised") + except InvalidCSVSchemaError as e: + assert e.code == "ARG-102" + assert e.responsibility == AccountableRole.DATA_PROVIDER + assert len(error_context_fixture.errors) == 1 + assert error_context_fixture.errors[0] == e + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/argentina/errors/test_global_context_scope.py b/tests/argentina/errors/test_global_context_scope.py new file mode 100644 index 0000000..91cb4e8 --- /dev/null +++ b/tests/argentina/errors/test_global_context_scope.py @@ -0,0 +1,156 @@ +""" +Tests to verify the scoping of the global error context managed +by tsn_adapters/tasks/argentina/errors/context.py. + +We cover: + - Synchronous tasks that simply read the global context. + - Tasks that add a local value via the error_context() context manager. + - Asynchronous tasks that return the current context. + - Tasks submitted via .submit(). + - A task that shows that a context set inside a block is rolled back after the block. + - Concurrent tasks with different local context values to ensure isolation. +""" + +import asyncio +from typing import Any + +from prefect import flow, task + +from tsn_adapters.tasks.argentina.errors.context import ( + clear_error_context, + error_context, + get_error_context, + set_error_context, +) + + +@task +def get_context() -> dict[str, Any]: + """Return the current global error context.""" + return get_error_context() + + +@task +def set_context_in_task(key: str, value: str) -> dict[str, Any]: + """ + Set an additional context using error_context and return the merged context. + This should merge any global context already present. + """ + with error_context(**{key: value}): + return get_error_context() + + +@task +async def async_get_context() -> dict[str, Any]: + """Asynchronous task returning the current global error context after a short delay.""" + await asyncio.sleep(0.1) + return get_error_context() + + +@task +def context_with_and_without(key: str, value: str) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Within a local error_context block, record the context then, + afterwards, record the global context again to ensure the local value is not leaked. + """ + with error_context(**{key: value}): + inside = get_error_context() + after = get_error_context() + return inside, after + + +@flow +def global_context_flow() -> dict[str, Any]: + """ + Flow that exercises various types of tasks to validate that the global error context is properly applied. + + First, we clear any previous context and set a global key. + Then we run: + - A normal task that simply returns the global context. + - A task that adds a key inside an error_context block. + - An async task. + - A task launched via the .submit() method. + - A task that shows a context value being set only within a with‐block. + """ + # Reset and then set a global value: + clear_error_context() + set_error_context("global", "global_value") + + # 1. Normal task reading the global context: + normal = get_context() + + # 2. Task that locally adds its own key: + with_ctx = set_context_in_task("task_key", "task_value") + + # 3. Asynchronous task: + async_result = asyncio.run(async_get_context()) + + # 4. Task submitted via .submit() + submit_future = set_context_in_task.submit("submit_key", "submit_value") + submit_result = submit_future.result() + + # 5. Task that tests a context block's scoping: + inside_after = context_with_and_without("block_key", "block_value") + + return { + "normal": normal, + "with_ctx": with_ctx, + "async": async_result, + "submit": submit_result, + "inside_after": inside_after, + } + + +@flow +def concurrent_context_flow() -> tuple[dict[str, Any], dict[str, Any]]: + """ + Flow that concurrently runs two tasks, each setting a different local error context. + This test ensures that concurrently submitted tasks do not leak context across each other. + """ + clear_error_context() + # Launch two tasks concurrently, each using error_context to add a distinct key: + future1 = set_context_in_task.submit("concurrent_key", "value1") + future2 = set_context_in_task.submit("concurrent_key", "value2") + result1 = future1.result() + result2 = future2.result() + return result1, result2 + + +def test_global_context_scope(prefect_test_fixture: Any): + """ + Validate that: + - A normal task sees the global context. + - A task wrapped with error_context correctly merges its key with the global context. + - Asynchronous tasks and submitted tasks see the expected global values. + - Context set within a with-block does not persist after the block. + """ + result = global_context_flow() + # Check that the normal task has the global value: + assert result["normal"].get("global") == "global_value" + + # The task with additional context should include both the global and the local key: + assert result["with_ctx"].get("global") == "global_value" + assert result["with_ctx"].get("task_key") == "task_value" + + # Async and submitted tasks should see only the global context (unless they set their own): + assert result["async"].get("global") == "global_value" + assert result["submit"].get("global") == "global_value" + assert result["submit"].get("submit_key") == "submit_value" + + # For the context_with_and_without task: + inside, after = result["inside_after"] + assert inside.get("block_key") == "block_value" + # After leaving the block, "block_key" should no longer be present: + assert "block_key" not in after + + +def test_concurrent_context_isolation(prefect_test_fixture: Any): + """ + Validate that concurrently submitted tasks each carry their local error context independently. + """ + result1, result2 = concurrent_context_flow() + # Each task should have its own "concurrent_key" with distinct values. + assert result1.get("concurrent_key") in ("value1", "value2") + assert result2.get("concurrent_key") in ("value1", "value2") + # The two results should not have the same value. + assert result1.get("concurrent_key") != result2.get("concurrent_key") diff --git a/tests/argentina/flows/test_preprocess_flow.py b/tests/argentina/flows/test_preprocess_flow.py index efcaf09..16b2912 100644 --- a/tests/argentina/flows/test_preprocess_flow.py +++ b/tests/argentina/flows/test_preprocess_flow.py @@ -166,7 +166,9 @@ def mock_preprocess_flow(self) -> PreprocessFlow: flow._create_summary = cast(Any, MagicMock()) return flow - def test_process_date_happy_path(self, mock_preprocess_flow: PreprocessFlow, mocker: MockerFixture, prefect_test_fixture): + def test_process_date_happy_path( + self, mock_preprocess_flow: PreprocessFlow, mocker: MockerFixture, prefect_test_fixture + ): """Test process_date with valid data.""" # Mock dependencies mock_raw_data = cast(Any, MagicMock(spec=pd.DataFrame)) @@ -189,7 +191,7 @@ def test_process_date_happy_path(self, mock_preprocess_flow: PreprocessFlow, moc Any, mocker.patch( "tsn_adapters.tasks.argentina.flows.preprocess_flow.process_raw_data", - return_value=MagicMock(result=lambda: (mock_processed_data, mock_uncategorized)) + return_value=MagicMock(result=lambda: (mock_processed_data, mock_uncategorized)), ), ) @@ -200,10 +202,7 @@ def test_process_date_happy_path(self, mock_preprocess_flow: PreprocessFlow, moc raw_provider.get_raw_data_for.assert_called_once_with(TEST_DATE) mock_task_load_category_map.assert_called_once_with(url=mock_preprocess_flow.category_map_url) mock_process_raw_data.assert_called_once_with( - raw_data=mock_raw_data, - category_map_df=mock_category_map, - date=TEST_DATE, - return_state=True + raw_data=mock_raw_data, category_map_df=mock_category_map, date=TEST_DATE, return_state=True ) processed_provider = cast(Any, mock_preprocess_flow.processed_provider) processed_provider.save_processed_data.assert_called_once_with( @@ -443,7 +442,7 @@ def test_create_summary_empty_data(self, mock_preprocess_flow: PreprocessFlow, m class TestPreprocessFlowTopLevel: """Tests for the top-level preprocess_flow function.""" - + pytestmark = pytest.mark.usefixtures("prefect_test_fixture") def test_preprocess_flow_happy_path(self, mocker: MockerFixture): diff --git a/tests/blocks/test_split_and_insert_records.py b/tests/blocks/test_split_and_insert_records.py index 5e8a6bf..bc017b8 100644 --- a/tests/blocks/test_split_and_insert_records.py +++ b/tests/blocks/test_split_and_insert_records.py @@ -170,56 +170,52 @@ def test_split_and_insert_with_failures(self, tn_block: TNAccessBlock, sample_re def test_filter_deployed_streams(self, tn_block: TNAccessBlock, test_stream_ids: list[str]): """Test that filter_deployed_streams correctly filters out non-deployed streams.""" base_timestamp = int(datetime.now().timestamp()) - + # Create records for both deployed and non-deployed streams deployed_records = [] for stream_id in test_stream_ids: - deployed_records.append({ - "stream_id": stream_id, - "date": base_timestamp, - "value": 100.0, - }) - + deployed_records.append( + { + "stream_id": stream_id, + "date": base_timestamp, + "value": 100.0, + } + ) + non_deployed_stream_id = generate_stream_id("non_deployed_test_stream") - non_deployed_records = [{ - "stream_id": non_deployed_stream_id, - "date": base_timestamp, - "value": 200.0, - }] - + non_deployed_records = [ + { + "stream_id": non_deployed_stream_id, + "date": base_timestamp, + "value": 200.0, + } + ] + # Combine all records all_records = pd.DataFrame(deployed_records + non_deployed_records) all_records = DataFrame[TnDataRowModel](all_records) - + # Test with filter_deployed_streams=True results = task_split_and_insert_records( - tn_block, - all_records, - is_unix=True, - filter_deployed_streams=True, - wait=True + tn_block, all_records, is_unix=True, filter_deployed_streams=True, wait=True ) - + assert results is not None # Should have successfully processed only the deployed streams assert len(results["success_tx_hashes"]) > 0 # No failed records since non-deployed streams were filtered out assert len(results["failed_records"]) == 0 - + # Verify only deployed streams were processed for stream_id in test_stream_ids: records = tn_block.read_records(stream_id, is_unix=True, date_from=base_timestamp) assert len(records) == 1 - + # Test with filter_deployed_streams=False results_no_filter = task_split_and_insert_records( - tn_block, - all_records, - is_unix=True, - filter_deployed_streams=False, - wait=True + tn_block, all_records, is_unix=True, filter_deployed_streams=False, wait=True ) - + assert results_no_filter is not None # Should have some failed records (the non-deployed stream) assert len(results_no_filter["failed_records"]) > 0 diff --git a/tests/blocks/test_stream_deployment_initialization.py b/tests/blocks/test_stream_deployment_initialization.py index a9d0c05..c027b48 100644 --- a/tests/blocks/test_stream_deployment_initialization.py +++ b/tests/blocks/test_stream_deployment_initialization.py @@ -1,9 +1,10 @@ -import pytest from datetime import datetime -from tsn_adapters.blocks.tn_access import TNAccessBlock +import pytest from trufnetwork_sdk_py.utils import generate_stream_id +from tsn_adapters.blocks.tn_access import TNAccessBlock + @pytest.fixture def test_stream_id() -> str: @@ -33,15 +34,15 @@ def test_init_stream(tn_block: TNAccessBlock, test_stream_id: str): # First deploy the stream deploy_tx = tn_block.deploy_stream(test_stream_id, wait=True) assert deploy_tx is not None, "Deploy transaction hash should be returned" - + # Initialize the stream init_tx = tn_block.init_stream(test_stream_id, wait=True) assert init_tx is not None, "Init transaction hash should be returned" - + # After initialization, the stream should be empty record = tn_block.get_first_record(test_stream_id) assert record is None, "The stream should be empty after initialization" - + # Clean up: destroy the stream destroy_tx = tn_block.destroy_stream(test_stream_id, wait=True) - assert destroy_tx is not None, "Stream destruction should return a transaction hash" \ No newline at end of file + assert destroy_tx is not None, "Stream destruction should return a transaction hash" diff --git a/tests/blocks/test_tn_network_error.py b/tests/blocks/test_tn_network_error.py index 204bfce..3835830 100644 --- a/tests/blocks/test_tn_network_error.py +++ b/tests/blocks/test_tn_network_error.py @@ -1,8 +1,10 @@ -from datetime import datetime from typing import Any, Callable +from prefect.client.schemas.objects import StateDetails, StateType +from prefect.types._datetime import DateTime from pydantic import SecretStr import pytest +from trufnetwork_sdk_py.client import TNClient from tsn_adapters.blocks.tn_access import ( SafeTNClientProxy, @@ -11,18 +13,12 @@ tn_special_retry_condition, ) -from prefect.client.schemas.objects import StateDetails, StateType -from prefect.types._datetime import DateTime - -from trufnetwork_sdk_py.client import TNClient - # --- Dummy TN Client to simulate behavior --- class DummyTNClient(TNClient): def __init__(self): pass - def get_first_record(self, *args: Any, **kwargs: Any) -> dict[str, str | float] | None: # Simulate a successful call return {"dummy_record": 1.0} @@ -77,7 +73,7 @@ def test_is_tn_node_network_error(): # For our tests we need dummy Task, TaskRun, and State objects. # We use minimal dummy implementations. from prefect import Task # Ensure that the correct Task object is imported per your Prefect version. -from prefect.client.schemas.objects import TaskRun, State +from prefect.client.schemas.objects import State, TaskRun class DummyTask(Task[Any, Any]): @@ -172,7 +168,7 @@ def test_tn_access_block_network_error(): def test_real_tn_client_unexistent_provider(): """ Test the real TN client with an unexistent provider. - + This test creates a TNAccessBlock using an invalid provider URL and attempts to call get_first_record. The safe client is expected to detect the underlying network error and re-raise it as a TNNodeNetworkError. @@ -184,7 +180,7 @@ def test_real_tn_client_unexistent_provider(): block = TNAccessBlock( tn_provider=invalid_provider, tn_private_key=SecretStr("0000000000000000000000000000000000000000000000000000000000000012"), - helper_contract_name="dummy" + helper_contract_name="dummy", ) _ = block.get_first_record("dummy_stream") # Check that the error message indicates a connection issue. diff --git a/tests/fixtures/test_trufnetwork.py b/tests/fixtures/test_trufnetwork.py index 45b4615..f3e49b9 100644 --- a/tests/fixtures/test_trufnetwork.py +++ b/tests/fixtures/test_trufnetwork.py @@ -30,11 +30,12 @@ class ContainerSpec: name: str image: str tmpfs_path: Optional[str] = None - env_vars: ( list[str] ) = field(default_factory=list) + env_vars: list[str] = field(default_factory=list) ports: dict[str, str] = field(default_factory=dict) entrypoint: Optional[str] = None args: list[str] = field(default_factory=list) + # Container specifications POSTGRES_CONTAINER = ContainerSpec( name="test-kwil-postgres", @@ -389,7 +390,8 @@ def test_tn_provider_fixture(self, tn_provider: TrufNetworkProvider): assert tn_provider.api_endpoint.startswith("http://") assert tn_provider.get_provider() is tn_provider -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def term_handler(): """ Fixture to transform SIGTERM into SIGINT. This permit us to gracefully stop the suite uppon SIGTERM. @@ -398,7 +400,8 @@ def term_handler(): yield signal.signal(signal.SIGTERM, orig) -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def disable_prefect_retries(): from importlib import import_module from unittest.mock import patch @@ -407,40 +410,40 @@ def disable_prefect_retries(): # Patch task retries by modifying the task options directly def patch_task_options(task_fn: Any) -> Any: - if hasattr(task_fn, 'with_options'): + if hasattr(task_fn, "with_options"): return task_fn.with_options(retries=0, cache_key_fn=None) return task_fn # All tasks with retries and their import paths tasks_to_patch = [ # FMP Historical Flow - 'tsn_adapters.flows.fmp.historical_flow.fetch_historical_data', + "tsn_adapters.flows.fmp.historical_flow.fetch_historical_data", # Stream Deploy Flow - 'tsn_adapters.flows.stream_deploy_flow.check_and_deploy_stream', + "tsn_adapters.flows.stream_deploy_flow.check_and_deploy_stream", # Primitive Source Descriptor - 'tsn_adapters.blocks.primitive_source_descriptor.get_descriptor_from_url', - 'tsn_adapters.blocks.primitive_source_descriptor.get_descriptor_from_github', + "tsn_adapters.blocks.primitive_source_descriptor.get_descriptor_from_url", + "tsn_adapters.blocks.primitive_source_descriptor.get_descriptor_from_github", # FMP Real Time Flow - 'tsn_adapters.flows.fmp.real_time_flow.fetch_quotes_for_batch', + "tsn_adapters.flows.fmp.real_time_flow.fetch_quotes_for_batch", # Argentina Task Wrappers - 'tsn_adapters.tasks.argentina.task_wrappers.task_create_stream_fetcher', - 'tsn_adapters.tasks.argentina.task_wrappers.task_get_streams', - 'tsn_adapters.tasks.argentina.task_wrappers.task_create_sepa_provider', - 'tsn_adapters.tasks.argentina.task_wrappers.task_get_data_for_date', - 'tsn_adapters.tasks.argentina.task_wrappers.task_get_latest_records', - 'tsn_adapters.tasks.argentina.task_wrappers.task_load_category_map', + "tsn_adapters.tasks.argentina.task_wrappers.task_create_stream_fetcher", + "tsn_adapters.tasks.argentina.task_wrappers.task_get_streams", + "tsn_adapters.tasks.argentina.task_wrappers.task_create_sepa_provider", + "tsn_adapters.tasks.argentina.task_wrappers.task_get_data_for_date", + "tsn_adapters.tasks.argentina.task_wrappers.task_get_latest_records", + "tsn_adapters.tasks.argentina.task_wrappers.task_load_category_map", # TN Access - 'tsn_adapters.blocks.tn_access.task_wait_for_tx', - 'tsn_adapters.blocks.tn_access.task_insert_and_wait_for_tx', - 'tsn_adapters.blocks.tn_access.task_insert_unix_and_wait_for_tx', - 'tsn_adapters.blocks.tn_access._task_only_batch_insert_records', - 'tsn_adapters.blocks.tn_access.task_split_and_insert_records' + "tsn_adapters.blocks.tn_access.task_wait_for_tx", + "tsn_adapters.blocks.tn_access.task_insert_and_wait_for_tx", + "tsn_adapters.blocks.tn_access.task_insert_unix_and_wait_for_tx", + "tsn_adapters.blocks.tn_access._task_only_batch_insert_records", + "tsn_adapters.blocks.tn_access.task_split_and_insert_records", ] - with patch('prefect.task', side_effect=original_task) as mock_task: + with patch("prefect.task", side_effect=original_task) as mock_task: for import_path in tasks_to_patch: # Split the import path into module path and attribute name - module_path, attr_name = import_path.rsplit('.', 1) + module_path, attr_name = import_path.rsplit(".", 1) # Import the module and get the task function module = import_module(module_path) task_fn = getattr(module, attr_name) @@ -448,6 +451,7 @@ def patch_task_options(task_fn: Any) -> Any: mock_task.return_value = patch_task_options(task_fn) patch(import_path, new=patch_task_options(task_fn)).start() + @pytest.fixture(scope="session", autouse=False) def prefect_test_fixture(disable_prefect_retries: Any): with prefect_test_harness(server_startup_timeout=120): diff --git a/tests/flows/test_stream_deploy_flow.py b/tests/flows/test_stream_deploy_flow.py index 4eb8b34..5a7170a 100644 --- a/tests/flows/test_stream_deploy_flow.py +++ b/tests/flows/test_stream_deploy_flow.py @@ -18,6 +18,7 @@ class TestPrimitiveSourcesDescriptor(PrimitiveSourcesDescriptorBlock): """Test implementation of PrimitiveSourcesDescriptorBlock that can generate a configurable number of streams.""" + model_config = ConfigDict(ignored_types=(object,)) num_streams: int = 3 @@ -32,6 +33,7 @@ def get_descriptor(self) -> DataFrame[PrimitiveSourceDataModel]: class TestTNClient(tn_client.TNClient): """Test implementation of TNClient that tracks deployed streams and can be initialized with existing streams.""" + def __init__(self, existing_streams: set[str] | None = None): # We don't call super().__init__ to avoid real client initialization if existing_streams is None: @@ -66,6 +68,7 @@ def wait_for_tx(self, tx_hash: str) -> None: class TestTNAccessBlock(TNAccessBlock): """Test implementation of TNAccessBlock that uses our TestTNClient.""" + _test_client: TestTNClient model_config = ConfigDict(ignored_types=(object,)) @@ -94,14 +97,10 @@ def primitive_descriptor(request: FixtureRequest) -> TestPrimitiveSourcesDescrip @pytest.mark.usefixtures("prefect_test_fixture") def test_deploy_streams_flow_all_new( - tn_access_block: TestTNAccessBlock, - primitive_descriptor: TestPrimitiveSourcesDescriptor + tn_access_block: TestTNAccessBlock, primitive_descriptor: TestPrimitiveSourcesDescriptor ) -> None: """Test that all streams are deployed when none exist.""" - results = deploy_streams_flow( - psd_block=primitive_descriptor, - tna_block=tn_access_block - ) + results = deploy_streams_flow(psd_block=primitive_descriptor, tna_block=tn_access_block) # All three streams should be deployed assert results["deployed_count"] == 3 @@ -119,14 +118,11 @@ def test_deploy_streams_flow_with_existing() -> None: # Create TNAccessBlock with some existing streams existing_streams = {"stream_0", "stream_2"} # First and last streams exist tn_access_block = TestTNAccessBlock(existing_streams=existing_streams) - + # Create descriptor with 3 streams primitive_descriptor = TestPrimitiveSourcesDescriptor(num_streams=3) - results = deploy_streams_flow( - psd_block=primitive_descriptor, - tna_block=tn_access_block - ) + results = deploy_streams_flow(psd_block=primitive_descriptor, tna_block=tn_access_block) # Only stream_1 should be deployed, others should be skipped assert results["deployed_count"] == 1 diff --git a/tests/fmp/test_fmpblock_integration.py b/tests/fmp/test_fmpblock_integration.py index 3927dd4..df059eb 100644 --- a/tests/fmp/test_fmpblock_integration.py +++ b/tests/fmp/test_fmpblock_integration.py @@ -1,14 +1,13 @@ -import os -import re from datetime import datetime -import time +import os from typing import Any +from pandera.typing import DataFrame from pydantic import SecretStr import pytest -from tsn_adapters.blocks.fmp import FMPBlock, EODData -from pandera.typing import DataFrame +from tsn_adapters.blocks.fmp import EODData, FMPBlock + def is_iso_date(date_str: str) -> bool: """Check if a string is in ISO date format (YYYY-MM-DD).""" @@ -18,20 +17,25 @@ def is_iso_date(date_str: str) -> bool: except ValueError: return False + def iso_to_unix_timestamp(iso_date: str) -> int: """Convert ISO date string to UNIX timestamp in seconds.""" return int(datetime.strptime(iso_date, "%Y-%m-%d").timestamp()) + def verify_date_formats(df: DataFrame[EODData]): """Verify that dates in the DataFrame are in ISO format and match their UNIX timestamp equivalents.""" # Check ISO format - assert all(is_iso_date(date) for date in df['date']), "All dates should be in ISO format (YYYY-MM-DD)" - + assert all(is_iso_date(date) for date in df["date"]), "All dates should be in ISO format (YYYY-MM-DD)" + # If we have timestamps, verify they match the ISO dates - if 'timestamp' in df.columns: - for iso_date, timestamp in zip(df['date'], df['timestamp']): + if "timestamp" in df.columns: + for iso_date, timestamp in zip(df["date"], df["timestamp"]): expected_timestamp = iso_to_unix_timestamp(iso_date) - assert timestamp == expected_timestamp, f"Timestamp mismatch for {iso_date}: expected {expected_timestamp}, got {timestamp}" + assert ( + timestamp == expected_timestamp + ), f"Timestamp mismatch for {iso_date}: expected {expected_timestamp}, got {timestamp}" + @pytest.fixture def fmp_block(prefect_test_fixture: Any): @@ -50,11 +54,11 @@ def test_get_active_tickers(fmp_block: FMPBlock): # Ensuring that we got a response that contains at least one active ticker assert df is not None, "Expected non-None result" assert len(df) > 0, "Expected non-empty active tickers list" - + # Validate schema requirements - assert all(isinstance(symbol, str) for symbol in df['symbol']), "All symbols should be strings" - assert all(isinstance(name, (str, type(None))) for name in df['name']), "All names should be strings or None" - + assert all(isinstance(symbol, str) for symbol in df["symbol"]), "All symbols should be strings" + assert all(isinstance(name, (str, type(None))) for name in df["name"]), "All names should be strings or None" + print(f"Number of active tickers: {len(df)}") @@ -66,12 +70,12 @@ def test_get_batch_quote(fmp_block: FMPBlock): # Ensuring that we received batch quotes for the provided symbols assert df is not None, "Expected non-None result" assert len(df) > 0, "Expected non-empty batch quote result" - + # Validate schema requirements - assert all(isinstance(symbol, str) for symbol in df['symbol']), "All symbols should be strings" - assert all(isinstance(price, (int, float)) for price in df['price']), "All prices should be numeric" - assert all(isinstance(volume, int) for volume in df['volume']), "All volumes should be integers" - + assert all(isinstance(symbol, str) for symbol in df["symbol"]), "All symbols should be strings" + assert all(isinstance(price, (int, float)) for price in df["price"]), "All prices should be numeric" + assert all(isinstance(volume, int) for volume in df["volume"]), "All volumes should be integers" + print(f"Batch quote for {symbols}:") print(df) @@ -84,20 +88,20 @@ def test_get_historical_eod_data(fmp_block: FMPBlock): df = fmp_block.get_historical_eod_data("AAPL", start_date=start_date, end_date=end_date) assert df is not None, "Expected non-None result" assert len(df) == 21, "Expected 21 rows of historical EOD data" - + # Validate schema requirements - assert all(isinstance(symbol, str) for symbol in df['symbol']), "All symbols should be strings" - assert all(is_iso_date(date) for date in df['date']), "All dates should be in ISO format" - assert all(isinstance(price, (int, float)) for price in df['price']), "All prices should be numeric" - assert all(isinstance(volume, int) for volume in df['volume']), "All volumes should be integers" - + assert all(isinstance(symbol, str) for symbol in df["symbol"]), "All symbols should be strings" + assert all(is_iso_date(date) for date in df["date"]), "All dates should be in ISO format" + assert all(isinstance(price, (int, float)) for price in df["price"]), "All prices should be numeric" + assert all(isinstance(volume, int) for volume in df["volume"]), "All volumes should be integers" + # Validate date range - assert min(df['date']) >= start_date, "Data should not be before start_date" - assert max(df['date']) <= end_date, "Data should not be after end_date" - + assert min(df["date"]) >= start_date, "Data should not be before start_date" + assert max(df["date"]) <= end_date, "Data should not be after end_date" + # Verify date formats and timestamp conversions verify_date_formats(df) - + print(f"Historical EOD data for AAPL from {start_date} to {end_date}:") print(df) diff --git a/tests/fmp/test_historical_flow.py b/tests/fmp/test_historical_flow.py index 7155527..3d0d1e1 100644 --- a/tests/fmp/test_historical_flow.py +++ b/tests/fmp/test_historical_flow.py @@ -188,6 +188,7 @@ def get_client(self) -> TNClient: """Mock to prevent real client creation.""" return None # type: ignore + class ErrorFMPBlock(FMPBlock): def get_historical_eod_data( self, symbol: str, start_date: str | None = None, end_date: str | None = None @@ -227,8 +228,6 @@ def fake_tn_block() -> FakeTNAccessBlock: def error_fmp_block(): """Fixture for error-raising FMP block.""" - - return ErrorFMPBlock(api_key=SecretStr("fake")) diff --git a/tests/fmp/test_real_time_flow.py b/tests/fmp/test_real_time_flow.py index 7e67480..e69140f 100644 --- a/tests/fmp/test_real_time_flow.py +++ b/tests/fmp/test_real_time_flow.py @@ -175,7 +175,9 @@ def test_real_time_flow_api_error(self, error_fmp_block: ErrorFMPBlock): class TestProcessDataAndDescriptor: """Tests for data processing and descriptor handling.""" - def test_process_data(self, sample_quotes_df: DataFrame[BatchQuoteShort], sample_descriptor_df: DataFrame[PrimitiveSourceDataModel]): + def test_process_data( + self, sample_quotes_df: DataFrame[BatchQuoteShort], sample_descriptor_df: DataFrame[PrimitiveSourceDataModel] + ): """Test processing quote data into TN format.""" # Convert descriptor_df to plain pandas DataFrame as expected by process_data descriptor_df = pd.DataFrame(sample_descriptor_df)