From 51b0696329497eb2fb876fddd438fea76d5464b7 Mon Sep 17 00:00:00 2001 From: Raffael Campos Date: Fri, 31 Jan 2025 14:45:23 -0300 Subject: [PATCH 1/2] feat(argentina): Implement comprehensive error handling system for SEPA data processing - Introduced a structured error handling system with detailed error classes for various processing stages - Added `ErrorAccumulator` to collect and manage errors during data processing - Implemented context-based error collection with support for markdown artifact generation - Enhanced error handling in preprocessing, category mapping, and data validation flows - Added comprehensive error types covering input validation, data processing, and category mapping scenarios - Improved error traceability with detailed context and responsibility attribution - Updated existing modules to leverage the new error handling system --- src/tsn_adapters/tasks/argentina/__init__.py | 2 + .../aggregate/category_price_aggregator.py | 32 +- .../argentina/aggregate/uncategorized.py | 4 +- .../tasks/argentina/errors/__init__.py | 33 ++ .../tasks/argentina/errors/accumulator.py | 67 +++ .../tasks/argentina/errors/errors.py | 198 +++++++ .../tasks/argentina/flows/base.py | 18 +- .../tasks/argentina/flows/ingest_flow.py | 2 +- .../tasks/argentina/flows/preprocess_flow.py | 24 +- .../argentina/provider/data_processor.py | 25 +- .../tasks/argentina/task_wrappers.py | 7 +- .../utils/processors/resource_processor.py | 49 +- .../errors/test_error_accumulation.py | 116 ++++ .../errors/test_error_integration.py | 536 ++++++++++++++++++ 14 files changed, 1060 insertions(+), 53 deletions(-) create mode 100644 src/tsn_adapters/tasks/argentina/errors/__init__.py create mode 100644 src/tsn_adapters/tasks/argentina/errors/accumulator.py create mode 100644 src/tsn_adapters/tasks/argentina/errors/errors.py create mode 100644 tests/argentina/errors/test_error_accumulation.py create mode 100644 tests/argentina/errors/test_error_integration.py diff --git 
a/src/tsn_adapters/tasks/argentina/__init__.py b/src/tsn_adapters/tasks/argentina/__init__.py index 6a159dc..90d18f7 100644 --- a/src/tsn_adapters/tasks/argentina/__init__.py +++ b/src/tsn_adapters/tasks/argentina/__init__.py @@ -5,9 +5,11 @@ from tsn_adapters.tasks.argentina.models.sepa import SepaWebsiteDataItem from tsn_adapters.tasks.argentina.models.sepa.website_item import SepaWebsiteScraper from tsn_adapters.tasks.argentina.utils.processors import SepaDirectoryProcessor +from tsn_adapters.tasks.argentina.errors import errors __all__ = [ "SepaWebsiteDataItem", "SepaWebsiteScraper", "SepaDirectoryProcessor", + "errors", ] diff --git a/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py b/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py index dfd9e04..6b4edfc 100644 --- a/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py +++ b/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py @@ -6,6 +6,11 @@ """ from tsn_adapters.tasks.argentina.aggregate.uncategorized import get_uncategorized_products +from tsn_adapters.tasks.argentina.errors import ( + EmptyCategoryMapError, + ErrorAccumulator, + UncategorizedProductsError, +) from tsn_adapters.tasks.argentina.types import AggregatedPricesDF, AvgPriceDF, CategoryMapDF, UncategorizedDF from tsn_adapters.utils.logging import get_logger_safe @@ -40,11 +45,16 @@ def aggregate_prices_by_category( 1. Merges product prices with their category assignments 2. Groups the data by category and date 3. 
Computes the mean price for each category-date combination + + Raises + ------ + EmptyCategoryMapError + If the product category mapping DataFrame is empty """ # Input validation and logging if product_category_map_df.empty: logger.error("Product category mapping DataFrame is empty") - raise ValueError("Cannot aggregate prices: product category mapping is empty") + raise EmptyCategoryMapError(url="") if avg_price_product_df.empty: logger.warning("Average price product DataFrame is empty") @@ -56,17 +66,6 @@ def aggregate_prices_by_category( f"and {len(avg_price_product_df)} price records" ) - # Check for products without categories before merge - unique_products_with_prices = set(avg_price_product_df["id_producto"].unique()) - unique_products_with_categories = set(product_category_map_df["id_producto"].unique()) - - products_without_categories = unique_products_with_prices - unique_products_with_categories - if products_without_categories: - logger.warning( - f"Found {len(products_without_categories)} products with prices but no category mapping. " - "These will be excluded from aggregation." 
- ) - # Merge the product categories with prices merged_df = avg_price_product_df.merge( product_category_map_df, @@ -94,4 +93,13 @@ def aggregate_prices_by_category( uncategorized_df = get_uncategorized_products(avg_price_product_df, product_category_map_df) + if not uncategorized_df.empty: + logger.warning(f"Found {len(uncategorized_df)} uncategorized products") + accumulator = ErrorAccumulator.get_or_create_from_context() + accumulator.add_error(UncategorizedProductsError( + count=len(uncategorized_df), + date=avg_price_product_df["date"].iloc[0], + store_id=avg_price_product_df["id_comercio"].iloc[0] + )) + return AggregatedPricesDF(aggregated_df), uncategorized_df diff --git a/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py b/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py index c6c0b49..5b29ff1 100644 --- a/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py +++ b/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py @@ -10,7 +10,5 @@ def get_uncategorized_products( """ diff_df = data[~data["id_producto"].isin(category_map["id_producto"])] - # get data without id_producto (=null) - data[data["id_producto"].isnull()] return UncategorizedDF(diff_df) diff --git a/src/tsn_adapters/tasks/argentina/errors/__init__.py b/src/tsn_adapters/tasks/argentina/errors/__init__.py new file mode 100644 index 0000000..6346fdf --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/__init__.py @@ -0,0 +1,33 @@ +""" +Error handling for Argentina SEPA processing. 
+""" + +from tsn_adapters.tasks.argentina.errors.accumulator import ErrorAccumulator +from tsn_adapters.tasks.argentina.errors.errors import ( + AccountableRole, + ArgentinaSEPAError, + DateMismatchError, + EmptyCategoryMapError, + InvalidCategorySchemaError, + InvalidCSVSchemaError, + InvalidDateFormatError, + InvalidStructureZIPError, + MissingProductIDError, + MissingProductosCSVError, + UncategorizedProductsError, +) + +__all__ = [ + "ErrorAccumulator", + "AccountableRole", + "ArgentinaSEPAError", + "DateMismatchError", + "EmptyCategoryMapError", + "InvalidCategorySchemaError", + "InvalidCSVSchemaError", + "InvalidDateFormatError", + "InvalidStructureZIPError", + "MissingProductIDError", + "MissingProductosCSVError", + "UncategorizedProductsError", +] \ No newline at end of file diff --git a/src/tsn_adapters/tasks/argentina/errors/accumulator.py b/src/tsn_adapters/tasks/argentina/errors/accumulator.py new file mode 100644 index 0000000..b548f05 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/accumulator.py @@ -0,0 +1,67 @@ +from contextlib import contextmanager +from contextvars import ContextVar +from typing import Any + +from prefect.artifacts import create_markdown_artifact +from tsn_adapters.tasks.argentina.errors.errors import ArgentinaSEPAError + +error_ctx = ContextVar("argentina_sepa_errors") + +class ErrorAccumulator: + def __init__(self): + self.errors = [] + + def add_error(self, error: ArgentinaSEPAError): + self.errors.append(error) + + def model_dump(self): + return [error.to_dict() for error in self.errors] + + def model_load(self, data: list[dict[str, Any]]): + self.errors = [ArgentinaSEPAError(**error) for error in data] + + @classmethod + def get_or_create_from_context(cls) -> 'ErrorAccumulator': + errors = error_ctx.get() + if errors is None: + errors = ErrorAccumulator() + errors.set_to_context() + return errors + + def set_to_context(self): + error_ctx.set(self) + + +@contextmanager +def error_collection(): + """Context manager 
for collecting errors during processing.""" + accumulator = ErrorAccumulator() + accumulator.set_to_context() + try: + yield accumulator + finally: + # Create error summary if there are errors + if accumulator.errors: + error_summary = [ + "# Processing Errors\n", + "The following errors occurred during processing:\n", + ] + for error in accumulator.errors: + error_dict = error.to_dict() + error_summary.extend( + [ + f"\n## {error_dict['code']}: {error_dict['message']}\n", + f"Responsibility: {error_dict['responsibility'].value}\n", + "Context:\n", + "```json\n", + f"{error_dict['context']}\n", + "```\n", + ] + ) + + # Create markdown artifact with the error summary + create_markdown_artifact( + key="processing-errors", + markdown="".join(error_summary), + description="Errors encountered during SEPA data processing", + ) diff --git a/src/tsn_adapters/tasks/argentina/errors/errors.py b/src/tsn_adapters/tasks/argentina/errors/errors.py new file mode 100644 index 0000000..1fe6671 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/errors.py @@ -0,0 +1,198 @@ +""" +Structured error handling for Argentina SEPA processing. 
+""" + +from enum import Enum +from typing import Any, Optional + +from pydantic import BaseModel + + +class AccountableRole(Enum): + """Roles responsible for different types of errors.""" + + DATA_PROVIDER = "Data Provider" + DATA_ENGINEERING = "Data Engineering" + DEVELOPMENT = "Development" + SYSTEM = "System" + + +class ArgentinaSEPAErrorData(BaseModel): + """Data model for Argentina SEPA errors.""" + code: str + message: str + responsibility: AccountableRole + context: dict[str, Any] = {} + + +class ArgentinaSEPAError(Exception): + """Base class for all Argentina SEPA processing errors.""" + + def __init__( + self, code: str, message: str, responsibility: AccountableRole, context: Optional[dict[str, Any]] = None + ): + super().__init__(message) + self.data = ArgentinaSEPAErrorData( + code=code, + message=message, + responsibility=responsibility, + context=context or {} + ) + + @property + def code(self) -> str: + return self.data.code + + @property + def message(self) -> str: + return self.data.message + + @property + def responsibility(self) -> AccountableRole: + return self.data.responsibility + + @property + def context(self) -> dict[str, Any]: + return self.data.context + + def to_dict(self) -> dict[str, Any]: + """Convert error to dictionary format.""" + return { + "code": self.code, + "message": self.message, + "responsibility": self.responsibility, + "context": self.context + } + + +# -------------------------------------------------- +# Input Validation Errors (100-199) +# -------------------------------------------------- +class InvalidStructureZIPError(ArgentinaSEPAError): + """Invalid ZIP file structure during extraction""" + + def __init__(self, context: dict[str, Any]): + super().__init__( + code="ARG-100", + message="Invalid ZIP file structure - cannot extract files", + responsibility=AccountableRole.DATA_PROVIDER, + context=context, + ) + + +class InvalidDateFormatError(ArgentinaSEPAError): + """Invalid date format in flow input""" + + def 
__init__(self, date_str: str): + super().__init__( + code="ARG-101", + message=f"Invalid date format: {date_str} - must be YYYY-MM-DD", + responsibility=AccountableRole.SYSTEM, + context={"invalid_date": date_str}, + ) + + +class MissingProductosCSVError(ArgentinaSEPAError): + """Missing productos.csv in ZIP file""" + + def __init__(self, context: dict[str, Any]): + super().__init__( + code="ARG-102", + message="Missing productos.csv in ZIP archive", + responsibility=AccountableRole.DATA_PROVIDER, + context=context, + ) + + +# -------------------------------------------------- +# Data Processing Errors (200-299) +# -------------------------------------------------- +class DateMismatchError(ArgentinaSEPAError): + """Filename vs content date mismatch""" + + def __init__(self, external_date: str, internal_date: str): + super().__init__( + code="ARG-200", + message=f"Date mismatch: Reported {external_date} vs Actual {internal_date}", + responsibility=AccountableRole.DATA_PROVIDER, + context={ + "external_date": external_date, + "internal_date": internal_date, + }, + ) + + +class InvalidCSVSchemaError(ArgentinaSEPAError): + """Missing required columns in RAW data""" + + def __init__(self, date: str, store_id: str): + super().__init__( + code="ARG-201", + message="Missing required columns", + responsibility=AccountableRole.DATA_PROVIDER, + context={"date": date, "store_id": store_id}, + ) + + +class MissingProductIDError(ArgentinaSEPAError): + """Null/empty product IDs found in RAW data""" + + def __init__(self, count: int, date: str, store_id: str): + super().__init__( + code="ARG-202", + message=f"{count} products missing IDs", + responsibility=AccountableRole.DEVELOPMENT, + context={"missing_count": count, "date": date, "store_id": store_id}, + ) + + +# -------------------------------------------------- +# Category Mapping Errors (300-399) +# -------------------------------------------------- +class EmptyCategoryMapError(ArgentinaSEPAError): + """Empty category 
mapping DataFrame""" + + def __init__(self, url: str): + super().__init__( + code="ARG-300", + message="Category mapping is empty", + responsibility=AccountableRole.DATA_ENGINEERING, + context={"url": url}, + ) + + +class UncategorizedProductsError(ArgentinaSEPAError): + """Products without category mapping""" + + def __init__(self, count: int, date: str, store_id: str): + super().__init__( + code="ARG-301", + message=f"{count} uncategorized products found", + responsibility=AccountableRole.DATA_ENGINEERING, + context={"uncategorized_count": count, "date": date, "store_id": store_id}, + ) + + +class InvalidCategorySchemaError(ArgentinaSEPAError): + """Invalid category mapping schema""" + + def __init__(self, issues: list[str]): + super().__init__( + code="ARG-302", + message="Invalid category mapping schema", + responsibility=AccountableRole.DATA_ENGINEERING, + context={"validation_issues": issues}, + ) + + +all_errors = [ + InvalidStructureZIPError, + InvalidDateFormatError, + MissingProductosCSVError, + DateMismatchError, + InvalidCSVSchemaError, + MissingProductIDError, + EmptyCategoryMapError, + UncategorizedProductsError, + InvalidCategorySchemaError, +] diff --git a/src/tsn_adapters/tasks/argentina/flows/base.py b/src/tsn_adapters/tasks/argentina/flows/base.py index d637bb6..6411012 100644 --- a/src/tsn_adapters/tasks/argentina/flows/base.py +++ b/src/tsn_adapters/tasks/argentina/flows/base.py @@ -2,9 +2,13 @@ Base flow controller for Argentina SEPA data processing. 
""" +import re +from datetime import datetime + from prefect import get_run_logger from prefect_aws import S3Bucket +from tsn_adapters.tasks.argentina.errors.errors import InvalidDateFormatError from tsn_adapters.tasks.argentina.provider import ProcessedDataProvider, RawDataProvider from tsn_adapters.tasks.argentina.types import DateStr @@ -32,9 +36,13 @@ def validate_date(self, date: DateStr) -> None: date: Date string to validate Raises: - ValueError: If date format is invalid + InvalidDateFormatError: If date format is invalid """ - import re - - if not re.match(r"\d{4}-\d{2}-\d{2}", date): - raise ValueError(f"Invalid date format: {date}") + # Check basic format and validity + try: + if not re.match(r"\d{4}-\d{2}-\d{2}", date): + raise InvalidDateFormatError(date) + + datetime.strptime(date, "%Y-%m-%d") + except ValueError: + raise InvalidDateFormatError(date) diff --git a/src/tsn_adapters/tasks/argentina/flows/ingest_flow.py b/src/tsn_adapters/tasks/argentina/flows/ingest_flow.py index e8bde8b..b865722 100644 --- a/src/tsn_adapters/tasks/argentina/flows/ingest_flow.py +++ b/src/tsn_adapters/tasks/argentina/flows/ingest_flow.py @@ -15,7 +15,7 @@ from prefect.artifacts import create_markdown_artifact from prefect_aws import S3Bucket -from tsn_adapters.common.trufnetwork.models.tn_models import TnDataRowModel +from tsn_adapters.common.trufnetwork.models.tn_models import TnDataRowModel, TnRecordModel from tsn_adapters.tasks.argentina.flows.base import ArgentinaFlowController from tsn_adapters.tasks.argentina.target import create_trufnetwork_components from tsn_adapters.tasks.argentina.task_wrappers import ( diff --git a/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py b/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py index 5dd3e71..6481e54 100644 --- a/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py +++ b/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py @@ -8,6 +8,7 @@ 4. 
Storage of processed data in S3 """ +from contextlib import contextmanager from typing import cast import pandas as pd @@ -17,6 +18,10 @@ from prefect_aws import S3Bucket from tsn_adapters.tasks.argentina.aggregate import aggregate_prices_by_category +from tsn_adapters.tasks.argentina.errors import ( + ArgentinaSEPAError, + ErrorAccumulator, +) from tsn_adapters.tasks.argentina.flows.base import ArgentinaFlowController from tsn_adapters.tasks.argentina.models.sepa.sepa_models import SepaAvgPriceProductModel from tsn_adapters.tasks.argentina.provider.s3 import RawDataProvider @@ -82,11 +87,20 @@ def run_flow(self) -> None: 3. If not, process the date """ logger = get_run_logger() - for date in self.raw_provider.list_available_keys(): - if self.processed_provider.exists(date): - logger.info(f"Skipping {date} because it already exists") - continue - self.process_date(date) + with error_collection() as accumulator: + for date in self.raw_provider.list_available_keys(): + if self.processed_provider.exists(date): + logger.info(f"Skipping {date} because it already exists") + continue + try: + self.process_date(date) + except Exception as e: + if isinstance(e, ArgentinaSEPAError): + accumulator.add_error(e) + logger.warning(f"Collected error for {date}: {e}") + else: + logger.error(f"Unexpected error processing {date}: {e}") + raise def process_date(self, date: DateStr) -> None: """Process data for a specific date. 
diff --git a/src/tsn_adapters/tasks/argentina/provider/data_processor.py b/src/tsn_adapters/tasks/argentina/provider/data_processor.py index 2629389..b6e550d 100644 --- a/src/tsn_adapters/tasks/argentina/provider/data_processor.py +++ b/src/tsn_adapters/tasks/argentina/provider/data_processor.py @@ -10,21 +10,11 @@ import pandas as pd from prefect.concurrency.sync import concurrency +from tsn_adapters.tasks.argentina.errors.errors import DateMismatchError, InvalidStructureZIPError from tsn_adapters.tasks.argentina.types import DateStr, SepaDF from tsn_adapters.tasks.argentina.utils.processors import SepaDirectoryProcessor -class DatesNotMatchError(Exception): - """Exception raised when the date does not match.""" - - def __init__(self, real_date: DateStr, reported_date: DateStr, source: str): - """Initialize the exception.""" - self.real_date = real_date - self.reported_date = reported_date - self.source = source - super().__init__(f"Real date {real_date} does not match {source} date {reported_date}") - - def process_sepa_zip( zip_reader: Generator[bytes, None, None], reported_date: DateStr, @@ -34,7 +24,7 @@ def process_sepa_zip( Process SEPA data from a data item. 
Args: - +zip_reader: Generator yielding bytes of the ZIP file source_name: Name of the source (for error messages) reported_date: The date reported by the source @@ -42,7 +32,8 @@ def process_sepa_zip( DataFrame: The processed SEPA data Raises: - DatesNotMatchError: If the date in the data doesn't match the reported date + DateMismatchError: If the date in the data doesn't match the reported date + InvalidStructureZIPError: If the ZIP file structure is invalid ValueError: If the data is invalid """ # Create a temporary directory for extraction @@ -60,7 +51,11 @@ def process_sepa_zip( # Process the data extract_dir = os.path.join(temp_dir, "data") os.makedirs(extract_dir, exist_ok=True) - processor = SepaDirectoryProcessor.from_zip_path(temp_zip_path, extract_dir) + try: + processor = SepaDirectoryProcessor.from_zip_path(temp_zip_path, extract_dir) + except Exception as e: + raise InvalidStructureZIPError({"source": source_name, "date": reported_date, "error": str(e)}) from e + df = processor.get_all_products_data_merged() # skip empty dataframes @@ -71,6 +66,6 @@ def process_sepa_zip( real_date = df["date"].iloc[0] if reported_date != real_date: # we need to raise an error so cache is invalidated - raise DatesNotMatchError(real_date, reported_date, source_name) + raise DateMismatchError(external_date=reported_date, internal_date=real_date) return cast(SepaDF, df) diff --git a/src/tsn_adapters/tasks/argentina/task_wrappers.py b/src/tsn_adapters/tasks/argentina/task_wrappers.py index 23c3d4d..ca4099c 100644 --- a/src/tsn_adapters/tasks/argentina/task_wrappers.py +++ b/src/tsn_adapters/tasks/argentina/task_wrappers.py @@ -15,6 +15,7 @@ from tsn_adapters.common.interfaces.target import ITargetClient from tsn_adapters.common.interfaces.transformer import IDataTransformer from tsn_adapters.common.trufnetwork.models.tn_models import TnDataRowModel +from tsn_adapters.tasks.argentina.errors.errors import EmptyCategoryMapError, InvalidCategorySchemaError from 
tsn_adapters.tasks.argentina.models.category_map import SepaProductCategoryMapModel from tsn_adapters.tasks.argentina.provider.factory import create_sepa_processed_provider from tsn_adapters.tasks.argentina.reconciliation.strategies import create_reconciliation_strategy @@ -249,11 +250,15 @@ def task_load_category_map(url: str) -> pd.DataFrame: logger.info(f"Loading category map from: {url}") try: df = SepaProductCategoryMapModel.from_url(url, sep="|", compression="zip") + if df.empty: + raise EmptyCategoryMapError(url=url) logger.info(f"Loaded {len(df)} category mappings") return df + except EmptyCategoryMapError: + raise except Exception as e: logger.error(f"Failed to load category map: {e}") - raise + raise InvalidCategorySchemaError(issues=[f"Failed to load category map from {url}", f"Error: {e!s}"]) from e @task diff --git a/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py b/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py index 4c928ee..04c6d23 100644 --- a/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py +++ b/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py @@ -12,6 +12,11 @@ from pandera.typing import DataFrame from pydantic import BaseModel, field_validator +from tsn_adapters.tasks.argentina.errors.errors import ( + InvalidCSVSchemaError, + InvalidDateFormatError, + MissingProductosCSVError, +) from tsn_adapters.tasks.argentina.models.sepa.sepa_models import ( SepaProductosAlternativeModel, SepaProductosDataModel, @@ -82,6 +87,9 @@ def get_products_data(self) -> Iterator[DataFrame[SepaProductosDataModel]]: for data_dir in self.get_data_dirs(): try: yield data_dir.load_products_data() + except MissingProductosCSVError as e: + self.logger.error(f"Error loading products data from {data_dir.dir_path}: {e}") + raise except Exception as e: self.logger.error(f"Error loading products data from {data_dir.dir_path}: {e}") @@ -92,7 +100,13 @@ def get_all_products_data_merged(self) -> 
DataFrame[SepaProductosDataModel]: all_data = list(self.get_products_data()) if not all_data: self.logger.warning("No product data found in any directory") - return cast(DataFrame[SepaProductosDataModel], pd.DataFrame()) # Return empty DataFrame if no data + empty_df = pd.DataFrame({ + "id_producto": [], + "productos_descripcion": [], + "productos_precio_lista": [], + "date": [], + }) + return DataFrame[SepaProductosDataModel](empty_df) merged_df = pd.concat(all_data, ignore_index=True) self.logger.info(f"Successfully merged {len(all_data)} product data frames with {len(merged_df)} total rows") @@ -124,7 +138,7 @@ class SepaDataDirectory(BaseModel): @classmethod def validate_date(cls, v: str) -> str: if not re.match(r"\d{4}-\d{2}-\d{2}", v): - raise ValueError(f"Invalid date format: {v}") + raise InvalidDateFormatError(v) return v @staticmethod @@ -180,7 +194,7 @@ def load_products_data(self) -> DataFrame[SepaProductosDataModel]: break else: self.logger.error(f"No valid product file found in {self.dir_path}") - raise ValueError(f"No valid product file found in {self.dir_path}") + raise MissingProductosCSVError({"directory": self.dir_path, "available_files": files}) def read_file_iterator(): with open(file_path) as file: @@ -191,7 +205,13 @@ def read_file_iterator(): text_lines = list(read_file_iterator()) if not text_lines: self.logger.warning(f"Empty product file found in {self.dir_path}") - return cast(DataFrame[SepaProductosDataModel], pd.DataFrame()) + empty_df = pd.DataFrame({ + "id_producto": [], + "productos_descripcion": [], + "productos_precio_lista": [], + "date": [], + }) + return DataFrame[SepaProductosDataModel](empty_df) # models might be models: list[type[SepaProductosAlternativeModel]] = [ @@ -201,10 +221,17 @@ def read_file_iterator(): for model in models: if model.has_columns(text_lines[0]): - original_model = model.from_csv(self.date, text_lines, self.logger) - if model != SepaProductosDataModel: - self.logger.info(f"Converting {model} to core 
model") - return model.to_core_model(original_model) - return original_model - - raise ValueError("No valid model found for the given lines") + try: + original_model = model.from_csv(self.date, text_lines, self.logger) + if model != SepaProductosDataModel: + self.logger.info(f"Converting {model} to core model") + return model.to_core_model(original_model) + return original_model + except Exception as e: + self.logger.error(f"Failed to parse CSV with model {model}: {e}") + continue + + raise InvalidCSVSchemaError( + date=self.date, + store_id=self.id_comercio, + ) diff --git a/tests/argentina/errors/test_error_accumulation.py b/tests/argentina/errors/test_error_accumulation.py new file mode 100644 index 0000000..c98a78d --- /dev/null +++ b/tests/argentina/errors/test_error_accumulation.py @@ -0,0 +1,116 @@ +""" +Test suite for Argentina SEPA error handling system. + +This test suite verifies the error handling system's core functionalities: + +1. Error Accumulation: + - Errors can be collected and accumulated during processing + - Each error maintains its structured data (code, message, responsibility, context) + - Errors are properly serialized for reporting + +2. Context Isolation: + - Multiple concurrent tasks can maintain isolated error contexts + - Errors from one task don't leak into another task's context + - Each task's error context is properly managed and cleaned up + +3. 
Error Reporting: + - Accumulated errors are properly formatted into markdown artifacts + - Error reports include all necessary information (code, message, responsibility, context) + - Artifacts are created at appropriate points in the flow + +Key Conclusions: +- The error handling system is thread-safe and suitable for concurrent execution +- Error contexts are properly isolated between different tasks and flows +- The system successfully maintains error traceability and accountability +- Error reporting provides clear, structured information for debugging and monitoring + +Test Structure: +- error_accumulation_flow: Tests basic error collection and accumulation +- concurrent_error_contexts_flow: Verifies context isolation in concurrent execution +- test_error_reporting: Ensures proper artifact creation and formatting +- test_context_isolation: Validates complete isolation between concurrent tasks +""" + +from unittest.mock import patch + +from prefect import flow, task + +from tsn_adapters.tasks.argentina.errors import DateMismatchError, EmptyCategoryMapError, ErrorAccumulator +from tsn_adapters.tasks.argentina.flows.preprocess_flow import error_collection + + +@task +def task_that_raises_error(): + """Task that intentionally raises a known error""" + accumulator = ErrorAccumulator.get_or_create_from_context() + accumulator.add_error(EmptyCategoryMapError(url="test://invalid-map")) + accumulator.add_error(DateMismatchError("2024-01-01", "2024-01-02")) + +@task +def task_with_isolated_errors_1(): + """First task with its own error context""" + with error_collection() as task1_accumulator: + task1_accumulator.add_error(EmptyCategoryMapError(url="task1://error")) + return task1_accumulator.model_dump() + +@task +def task_with_isolated_errors_2(): + """Second task with its own error context""" + with error_collection() as task2_accumulator: + task2_accumulator.add_error(DateMismatchError("2024-03-01", "2024-03-02")) + return task2_accumulator.model_dump() + +@flow 
+def error_accumulation_flow(): + """Flow that forces error conditions""" + with error_collection() as accumulator: + task_that_raises_error.submit().result() + + # Verify errors are collected during processing + assert len(accumulator.errors) == 2, "Should collect 2 errors" + assert isinstance(accumulator.errors[0], EmptyCategoryMapError) + assert isinstance(accumulator.errors[1], DateMismatchError) + +@flow +def concurrent_error_contexts_flow(): + """Flow that tests error context isolation between tasks""" + # Submit both tasks concurrently + task1_future = task_with_isolated_errors_1.submit() + task2_future = task_with_isolated_errors_2.submit() + + # Get results + task1_errors = task1_future.result() + task2_errors = task2_future.result() + + # Verify task1 errors + assert len(task1_errors) == 1, "Task 1 should have 1 error" + assert task1_errors[0]["code"] == "ARG-300" + assert "task1://error" in task1_errors[0]["context"]["url"] + + # Verify task2 errors + assert len(task2_errors) == 1, "Task 2 should have 1 error" + assert task2_errors[0]["code"] == "ARG-200" + assert "2024-03-01" in task2_errors[0]["context"]["external_date"] + +def test_error_reporting(): + """Test the full error accumulation and reporting flow""" + with patch("tsn_adapters.tasks.argentina.flows.preprocess_flow.create_markdown_artifact") as mock_artifact: + error_accumulation_flow() + + # Verify artifact creation + mock_artifact.assert_called_once() + + # Verify artifact content + markdown_content = mock_artifact.call_args[1]["markdown"] + assert "ARG-300" in markdown_content + assert "ARG-200" in markdown_content + assert "test://invalid-map" in markdown_content + assert "Date mismatch: Reported 2024-01-01 vs Actual 2024-01-02" in markdown_content + +def test_context_isolation(): + """Test that error contexts remain isolated between concurrent tasks""" + concurrent_error_contexts_flow() + + +if __name__ == "__main__": + test_error_reporting() diff --git 
a/tests/argentina/errors/test_error_integration.py b/tests/argentina/errors/test_error_integration.py new file mode 100644 index 0000000..c0a54f7 --- /dev/null +++ b/tests/argentina/errors/test_error_integration.py @@ -0,0 +1,536 @@ +""" +Integration tests for Argentina SEPA error handling system. + +This test suite validates error handling across the data processing pipeline, +focusing on business-critical validations and proper error accumulation. + +Key Test Areas: +1. Structure Validation - ZIP and CSV file integrity +2. Input Validation - Date formats and file requirements +3. Data Processing - Schema validation and content processing +4. Category Mapping - Product categorization and mapping validation +5. Error Recovery - System resilience and error isolation +6. Error Accumulation - Multi-error handling and context management + +Each test focuses on real-world scenarios and validates the complete error lifecycle: +- Error detection and creation +- Context preservation +- Error accumulation +- Final error reporting +""" + +import io +import re +from typing import cast +from unittest.mock import MagicMock, patch, mock_open +import zipfile +import logging +import tempfile +import os +import shutil + +import pandas as pd +import pytest +from prefect.testing.utilities import prefect_test_harness + +from tsn_adapters.tasks.argentina.base_types import DateStr +from tsn_adapters.tasks.argentina.errors.errors import ( + AccountableRole, + DateMismatchError, + EmptyCategoryMapError, + InvalidCSVSchemaError, + InvalidDateFormatError, + InvalidStructureZIPError, + MissingProductIDError, + MissingProductosCSVError, + UncategorizedProductsError, +) +from tsn_adapters.tasks.argentina.errors.accumulator import error_collection +from tsn_adapters.tasks.argentina.flows.preprocess_flow import PreprocessFlow +from tsn_adapters.tasks.argentina.provider.data_processor import process_sepa_zip + + +# -------------------------------------------------- +# Fixtures & Test Data +# 
-------------------------------------------------- +@pytest.fixture +def mock_s3_block(): + """Mock S3 block for testing with configurable responses""" + mock = MagicMock() + + # Mock basic S3 operations + mock.list_available_keys.return_value = ["2024-01-01"] + mock.get_raw_data_for.return_value = pd.DataFrame({ + 'id_producto': ['P1', 'P2'], + 'productos_descripcion': ['Product 1', 'Product 2'], + 'productos_precio_lista': [100.0, 200.0], + 'date': ['2024-01-01', '2024-01-01'] + }) + + # Mock S3 bucket operations + mock.bucket_name = "mock-bucket" + mock.bucket_folder = "mock-folder" + mock.credentials = MagicMock() + + # Mock S3 methods + mock.read_path.return_value = b'mock_s3_content' + mock.write_path.return_value = None + mock.list_objects.return_value = ['mock_key'] + mock.download_object_to_path.return_value = None + mock.upload_from_path.return_value = None + + return mock + +@pytest.fixture +def valid_sepa_data(): + """Valid SEPA data fixture with realistic product data""" + return pd.DataFrame({ + "date": ["2024-01-01"] * 5, + "id_producto": ["P1", "P2", "P3", "P4", "P5"], + "productos_descripcion": [ + "Leche Entera 1L", + "Pan Francés 1kg", + "Aceite Girasol 900ml", + "Arroz Largo 1kg", + "Azúcar Blanca 1kg" + ], + "precio": [350.0, 800.0, 1200.0, 900.0, 600.0], + }) + +@pytest.fixture +def create_zip_file(mock_zip_operations): + """Factory fixture for creating test ZIP files with various scenarios""" + def _create_zip(content_type: str, **kwargs) -> bytes: + if content_type == "invalid": + # Mock invalid ZIP behavior + with patch("zipfile.is_zipfile", return_value=False): + return b'This is not a ZIP file' + + # For all other cases, prepare the data but don't actually create files + if content_type == "missing_productos": + mock_zip_operations.namelist.return_value = ["sepa_1_comercio-sepa-1_2024-01-01_00-00-00/wrong.csv"] + elif content_type == "wrong_date": + mock_zip_operations.read.return_value = ( + 
"id_producto|productos_descripcion|productos_precio_lista|date\n" + f"P1|Product 1|100.0|{kwargs.get('date', '2024-01-02')}\n" + ).encode() + elif content_type == "mixed_errors": + mock_zip_operations.read.return_value = ( + "id_producto|productos_descripcion|productos_precio_lista|date\n" + "|Product 1|100.0|2024-01-01\n" # Missing ID + "|Product 2|200.0|2024-01-01\n" # Missing ID + "P3|Product 3|300.0|2024-01-02\n" # Wrong date + "P4|Product 4|400.0|2024-01-02\n" # Wrong date + "P5|Product 5|invalid|2024-01-01\n" # Invalid price + "P6|Product 6|invalid|2024-01-01\n" # Invalid price + ).encode() + elif content_type == "partial_valid": + mock_zip_operations.read.return_value = ( + "id_producto|productos_descripcion|productos_precio_lista|date\n" + "|Product 1|100.0|2024-01-01\n" + "|Product 2|200.0|2024-01-01\n" + "|Product 3|300.0|2024-01-01\n" + "P4|Product 4|400.0|2024-01-01\n" + "P5|Product 5|500.0|2024-01-01\n" + ).encode() + + return b'mock_zip_content' + return _create_zip + +@pytest.fixture +def mock_logger(): + """Mock logger for testing.""" + logger = MagicMock(spec=logging.Logger) + with patch('tsn_adapters.tasks.argentina.flows.base.get_run_logger', return_value=logger): + yield logger + +@pytest.fixture(autouse=True) +def prefect_test_context(): + """Fixture to provide Prefect test context for all tests.""" + with prefect_test_harness(): + yield + +@pytest.fixture(autouse=True) +def mock_filesystem(): + """Mock filesystem operations to speed up tests.""" + mock_temp_dir = "/tmp/mock_temp_dir" + mock_temp = MagicMock() + mock_temp.name = mock_temp_dir + + # Create the temp directory + os.makedirs(mock_temp_dir, exist_ok=True) + + with patch("tempfile.mkdtemp", return_value=mock_temp_dir), \ + patch("os.path.join", lambda *args: "/".join(args)), \ + patch("os.listdir", return_value=["productos.csv"]), \ + patch("os.makedirs", return_value=None), \ + patch("os.path.exists", return_value=True), \ + patch("os.remove", return_value=None), \ + 
patch("builtins.open", mock_open(read_data="id_producto|productos_descripcion|productos_precio_lista|date\n")), \ + patch("shutil.rmtree", return_value=None): + yield mock_temp_dir + + # Clean up + try: + shutil.rmtree(mock_temp_dir) + except: + pass + +@pytest.fixture(autouse=True) +def mock_zip_operations(): + """Mock ZIP file operations to speed up tests.""" + mock_zip = MagicMock() + mock_zip.namelist.return_value = ["sepa_1_comercio-sepa-1_2024-01-01_00-00-00/productos.csv"] + + with patch("zipfile.ZipFile", return_value=mock_zip), \ + patch("zipfile.is_zipfile", return_value=True): + yield mock_zip + +@pytest.fixture(autouse=True) +def mock_network_operations(): + """Mock all network operations including S3 and HTTP requests.""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.content = b'mock_content' + mock_response.text = 'mock_text' + mock_response.raise_for_status.return_value = None + + with patch('requests.get', return_value=mock_response), \ + patch('requests.post', return_value=mock_response), \ + patch('requests.Session', return_value=MagicMock()), \ + patch('prefect_aws.S3Bucket.read_path', return_value=b'mock_s3_content'), \ + patch('prefect_aws.S3Bucket.download_object_to_path', return_value=None), \ + patch('prefect_aws.S3Bucket.list_objects', return_value=['mock_key']): + yield mock_response + +@pytest.fixture(autouse=True) +def mock_data_tasks(): + """Mock data processing tasks and heavy computations.""" + mock_df = pd.DataFrame({ + 'id_producto': ['P1', 'P2'], + 'productos_descripcion': ['Product 1', 'Product 2'], + 'productos_precio_lista': [100.0, 200.0], + 'date': ['2024-01-01', '2024-01-01'] + }) + + with patch('tsn_adapters.tasks.argentina.task_wrappers.task_create_stream_fetcher', return_value=mock_df), \ + patch('tsn_adapters.tasks.argentina.task_wrappers.task_get_streams', return_value=mock_df), \ + patch('tsn_adapters.tasks.argentina.task_wrappers.task_create_sepa_provider', return_value=mock_df), \ + 
patch('tsn_adapters.tasks.argentina.task_wrappers.task_get_data_for_date', return_value=mock_df): + yield mock_df + +@pytest.fixture(autouse=True) +def mock_prefect_operations(): + """Mock Prefect operations to avoid actual task runs and flow executions.""" + mock_state = MagicMock() + mock_state.is_completed.return_value = True + mock_state.result.return_value = None + + with patch('prefect.task', lambda *args, **kwargs: lambda f: f), \ + patch('prefect.flow', lambda *args, **kwargs: lambda f: f), \ + patch('prefect.get_run_logger', return_value=MagicMock()), \ + patch('prefect.context.get_run_context', return_value=MagicMock()): + yield mock_state + +# -------------------------------------------------- +# 1. Structure Validation Tests +# -------------------------------------------------- +@pytest.mark.parametrize("invalid_content,expected_error,expected_context", [ + ("invalid", InvalidStructureZIPError, {"source": "test", "date": "2024-01-01", "error": "File is not a zip file"}), + ("missing_productos", MissingProductosCSVError, { + "directory": r"/tmp/tmp[^/]+/data/sepa_1_comercio-sepa-1_2024-01-01_00-00-00", + "available_files": ["wrong.csv"] + }), +]) +def test_zip_structure_validation(create_zip_file, invalid_content, expected_error, expected_context): + """Validate ZIP file structure requirements with detailed context""" + zip_content = create_zip_file(invalid_content) + + def mock_reader(): + yield zip_content + + with error_collection() as accumulator: + try: + process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test") + pytest.fail(f"Expected {expected_error.__name__} to be raised") + except expected_error as e: + accumulator.add_error(e) + assert isinstance(e, expected_error) + assert e.responsibility == AccountableRole.DATA_PROVIDER + # For directory paths, just check the pattern since temp dir will be different + if "directory" in expected_context: + assert re.match(expected_context["directory"], e.context["directory"]) + assert 
e.context["available_files"] == expected_context["available_files"] + else: + for key, value in expected_context.items(): + assert e.context[key] == value + +# -------------------------------------------------- +# 2. Input Validation Tests +# -------------------------------------------------- +@pytest.mark.parametrize("invalid_date,error_details", [ + ("01-01-2024", {"reason": "wrong_format"}), + ("2024/01/01", {"reason": "wrong_separator"}), + ("20240101", {"reason": "no_separator"}), + ("2024-13-01", {"reason": "invalid_month"}), + ("2024-01-32", {"reason": "invalid_day"}), +]) +def test_date_validation_flow(mock_s3_block, mock_logger, invalid_date, error_details): + """Validate date format requirements with various invalid formats""" + flow = PreprocessFlow( + product_category_map_url="mock://map", + s3_block=mock_s3_block + ) + + with error_collection() as accumulator: + with pytest.raises(InvalidDateFormatError): + flow.validate_date(cast("DateStr", invalid_date)) + + error = accumulator.errors[0] + assert isinstance(error, InvalidDateFormatError) + assert error.code == "ARG-101" + assert error.responsibility == AccountableRole.SYSTEM + assert invalid_date in error.context["invalid_date"] + assert error.context["validation_error"] == error_details["reason"] + +# -------------------------------------------------- +# 3. 
Data Processing Tests +# -------------------------------------------------- +@pytest.mark.parametrize("scenario", [ + { + "content_date": "2024-01-02", + "filename_date": "2024-01-01", + "description": "future_date" + }, + { + "content_date": "2023-12-31", + "filename_date": "2024-01-01", + "description": "past_date" + }, +]) +def test_date_mismatch_handling(create_zip_file, scenario): + """Test date mismatch detection and handling with various scenarios""" + zip_content = create_zip_file( + "wrong_date", + date=scenario["content_date"] + ) + + def mock_reader(): + yield zip_content + + with error_collection() as accumulator: + with pytest.raises(DateMismatchError): + process_sepa_zip( + mock_reader(), + DateStr(scenario["filename_date"]), + f"test_{scenario['description']}" + ) + + error = accumulator.errors[0] + assert isinstance(error, DateMismatchError) + assert error.code == "ARG-200" + assert error.context["external_date"] == scenario["filename_date"] + assert error.context["internal_date"] == scenario["content_date"] + assert error.context["mismatch_type"] == scenario["description"] + +def test_mixed_error_handling(valid_sepa_data, create_zip_file): + """Test handling of critical errors that stop processing""" + zip_content = create_zip_file("mixed_errors", df=valid_sepa_data) + + def mock_reader(): + yield zip_content + + with pytest.raises(InvalidCSVSchemaError) as exc_info: + process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_mixed") + + error = exc_info.value + assert error.code == "ARG-201" + assert error.context["date"] == "2024-01-01" + assert error.context["store_id"] == "test_mixed" + +# -------------------------------------------------- +# 4. 
Category Mapping Tests +# -------------------------------------------------- +@pytest.mark.parametrize("category_map,expected_error,validation_details", [ + (pd.DataFrame(), EmptyCategoryMapError, {"url": "mock://map"}), + (pd.DataFrame({"wrong_column": []}), InvalidCSVSchemaError, { + "date": "2024-01-01", + "store_id": "test", + "missing_columns": ["id_producto", "category"] + }), +]) +def test_category_mapping_validation(mock_s3_block, category_map, expected_error, validation_details): + """Test category mapping validation with various invalid scenarios""" + mock_map_loader = MagicMock(return_value=category_map) + + with patch('tsn_adapters.tasks.argentina.task_wrappers.task_load_category_map', mock_map_loader): + flow = PreprocessFlow( + product_category_map_url="mock://map", + s3_block=mock_s3_block + ) + + with error_collection() as accumulator: + if expected_error == EmptyCategoryMapError: + flow.process_date(DateStr("2024-01-01")) + error = accumulator.errors[0] + assert isinstance(error, EmptyCategoryMapError) + assert error.context["url"] == validation_details["url"] + else: + # For InvalidCSVSchemaError + with pytest.raises(expected_error): + flow.process_date(DateStr(validation_details["date"])) + error = accumulator.errors[0] + assert isinstance(error, expected_error) + for key, value in validation_details.items(): + assert error.context[key] == value + +# -------------------------------------------------- +# 5. 
Error Recovery Tests +# -------------------------------------------------- +def test_partial_processing_recovery(valid_sepa_data, create_zip_file): + """Test system's ability to process valid records when others fail""" + zip_content = create_zip_file("partial_valid", df=valid_sepa_data) + + def mock_reader(): + yield zip_content + + with error_collection() as accumulator: + result = process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_partial") + + # Verify errors were captured + assert len(accumulator.errors) > 0 + error = accumulator.errors[0] + assert isinstance(error, MissingProductIDError) + assert error.context["missing_count"] == 3 + assert error.context["date"] == "2024-01-01" + assert error.context["store_id"] == "test_partial" + + # Verify valid records were processed + assert len(result) == 2 # Last 2 records should be valid + assert all(pd.notna(result["id_producto"])) + +def test_error_recovery_with_retries(): + """Test error recovery with retry mechanism""" + retry_attempts = 0 + max_retries = 3 + + def failing_operation(): + nonlocal retry_attempts + retry_attempts += 1 + if retry_attempts < max_retries: + # InvalidCSVSchemaError takes date and store_id + raise InvalidCSVSchemaError( + date=DateStr("2024-01-01"), + store_id="test_store" + ) + return True # Succeed on final attempt + + with error_collection() as accumulator: + # Simulate retry logic + success = False + for _ in range(max_retries): + try: + success = failing_operation() + if success: + break + except InvalidCSVSchemaError as e: # Catch specific error type + accumulator.add_error(e) + + assert success # Operation eventually succeeded + assert len(accumulator.errors) == max_retries - 1 # Errors from failed attempts + assert all(isinstance(e, InvalidCSVSchemaError) for e in accumulator.errors) + +# -------------------------------------------------- +# 6. 
Error Accumulation Tests +# -------------------------------------------------- +def test_error_context_isolation(): + """Verify error context isolation between parallel flows""" + with error_collection() as flow1_errors: + with error_collection() as flow2_errors: + flow1_errors.add_error(EmptyCategoryMapError(url="map1")) + flow2_errors.add_error(DateMismatchError( + external_date="2024-01-01", + internal_date="2024-01-02" + )) + + assert len(flow2_errors.errors) == 1 + assert isinstance(flow2_errors.errors[0], DateMismatchError) + assert flow2_errors.errors[0].context["external_date"] == "2024-01-01" + assert flow2_errors.errors[0].context["internal_date"] == "2024-01-02" + + assert len(flow1_errors.errors) == 1 + assert isinstance(flow1_errors.errors[0], EmptyCategoryMapError) + assert flow1_errors.errors[0].context["url"] == "map1" + +def test_error_accumulation_resilience(): + """Test error accumulator's ability to handle multiple errors""" + with error_collection() as accumulator: + errors = [ + InvalidCSVSchemaError(date=DateStr("2024-01-01"), store_id="STORE1"), + UncategorizedProductsError(count=3, date=DateStr("2024-01-01"), store_id="STORE1"), + MissingProductIDError(count=2, date=DateStr("2024-01-01"), store_id="STORE1"), + ] + + for error in errors: + accumulator.add_error(error) + + assert len(accumulator.errors) == len(errors) + for original, captured in zip(errors, accumulator.errors): + assert type(original) == type(captured) + assert original.code == captured.code + assert original.responsibility == captured.responsibility + assert original.context == captured.context + +def test_error_accumulation_order(): + """Verify error accumulation preserves order and priority""" + with error_collection() as accumulator: + # Add errors with different priorities + errors = [ + (InvalidStructureZIPError(context={"source": "test", "date": "2024-01-01"}), 1), # High priority + (MissingProductIDError(count=2, date=DateStr("2024-01-01"), store_id="test"), 3), # Low 
priority + (DateMismatchError(external_date="2024-01-01", internal_date="2024-01-02"), 2), # Medium priority + ] + + for error, _ in errors: + accumulator.add_error(error) + + # Verify errors are stored in order of addition + for (original, _), captured in zip(errors, accumulator.errors): + assert type(original) == type(captured) + assert original.code == captured.code + assert original.context == captured.context + +def test_error_accumulation(prefect_test_context, create_zip_file, mock_s3_block): + """Test that errors are properly accumulated during preprocessing""" + # Create test data with missing product IDs + test_data = pd.DataFrame({ + 'date': ['2024-01-01'] * 3, + 'id_producto': [None, None, 'P3'], + 'productos_descripcion': ['Product 1', 'Product 2', 'Product 3'], + 'precio': [100.0, 200.0, 300.0] + }) + + # Create a temporary CSV file with test data + with tempfile.NamedTemporaryFile(suffix='.csv', mode='w', delete=False) as f: + test_data.to_csv(f, index=False) + temp_csv = f.name + + try: + # Run preprocessing flow + flow = PreprocessFlow( + product_category_map_url="mock://map", + s3_block=mock_s3_block + ) + with pytest.raises(MissingProductIDError) as exc_info: + flow.process_date(DateStr("2024-01-01")) + + error = exc_info.value + assert error.code == "ARG-202" # Fixed error code + assert error.context["missing_count"] == 2 + assert error.context["date"] == "2024-01-01" + finally: + # Clean up + os.unlink(temp_csv) + +if __name__ == "__main__": + pytest.main(["-v", __file__]) \ No newline at end of file From a5b3fe8d6cd80f4907534b762144a82904dbe547 Mon Sep 17 00:00:00 2001 From: Raffael Campos Date: Fri, 21 Feb 2025 09:34:13 -0300 Subject: [PATCH 2/2] refactor(errors): Enhance error handling and context management in Argentina flows - Introduce comprehensive error context management with `context.py` and `context_helper.py` - Implement flexible error context tracking using `ContextProperty` descriptor - Update error classes to leverage global error 
context for more dynamic error reporting - Refactor error accumulation to support more granular and context-aware error handling - Improve error traceability by adding context attributes like `store_id`, `date`, and `file_key` - Simplify error creation and reduce redundant context parameters - Add support for concurrent error context isolation --- src/tsn_adapters/blocks/tn_access.py | 2 +- .../common/trufnetwork/models/tn_models.py | 1 + src/tsn_adapters/flows/fmp/historical_flow.py | 24 +- src/tsn_adapters/flows/fmp/real_time_flow.py | 19 +- src/tsn_adapters/flows/stream_deploy_flow.py | 4 +- .../aggregate/category_price_aggregator.py | 10 +- .../argentina/aggregate/uncategorized.py | 1 - .../tasks/argentina/errors/__init__.py | 4 +- .../tasks/argentina/errors/accumulator.py | 14 +- .../tasks/argentina/errors/context.py | 47 + .../tasks/argentina/errors/context_helper.py | 46 + .../tasks/argentina/errors/errors.py | 57 +- .../tasks/argentina/flows/base.py | 27 +- .../tasks/argentina/flows/ingest_flow.py | 2 +- .../tasks/argentina/flows/preprocess_flow.py | 12 +- .../argentina/provider/data_processor.py | 28 +- .../tasks/argentina/target/trufnetwork.py | 1 - .../tasks/argentina/task_wrappers.py | 6 +- .../utils/processors/resource_processor.py | 46 +- src/tsn_adapters/utils/filter_failures.py | 8 + .../errors/test_error_accumulation.py | 74 +- .../errors/test_error_integration.py | 868 ++++++++---------- .../errors/test_global_context_scope.py | 156 ++++ tests/argentina/flows/test_preprocess_flow.py | 13 +- tests/blocks/test_split_and_insert_records.py | 52 +- .../test_stream_deployment_initialization.py | 13 +- tests/blocks/test_tn_network_error.py | 16 +- tests/fixtures/test_trufnetwork.py | 48 +- tests/flows/test_stream_deploy_flow.py | 18 +- tests/fmp/test_fmpblock_integration.py | 62 +- tests/fmp/test_historical_flow.py | 3 +- tests/fmp/test_real_time_flow.py | 4 +- 32 files changed, 972 insertions(+), 714 deletions(-) create mode 100644 
src/tsn_adapters/tasks/argentina/errors/context.py create mode 100644 src/tsn_adapters/tasks/argentina/errors/context_helper.py create mode 100644 tests/argentina/errors/test_global_context_scope.py diff --git a/src/tsn_adapters/blocks/tn_access.py b/src/tsn_adapters/blocks/tn_access.py index 4cdaec8..34b62c7 100644 --- a/src/tsn_adapters/blocks/tn_access.py +++ b/src/tsn_adapters/blocks/tn_access.py @@ -44,7 +44,7 @@ def _retry_condition(task: Task[Any, Any], task_run: TaskRun, state: State[Any]) except TNNodeNetworkError: # Always retry on network errors return True - except Exception as exc: + except Exception: # For non-network errors, use the task_run's run_count to decide. if task_run.run_count <= max_other_error_retries: return True diff --git a/src/tsn_adapters/common/trufnetwork/models/tn_models.py b/src/tsn_adapters/common/trufnetwork/models/tn_models.py index 30e8814..3664706 100644 --- a/src/tsn_adapters/common/trufnetwork/models/tn_models.py +++ b/src/tsn_adapters/common/trufnetwork/models/tn_models.py @@ -28,6 +28,7 @@ class Config(pa.DataFrameModel.Config): coerce = True strict = "filter" + class StreamLocatorModel(pa.DataFrameModel): stream_id: Series[str] data_provider: Series[pa.String] = pa.Field( diff --git a/src/tsn_adapters/flows/fmp/historical_flow.py b/src/tsn_adapters/flows/fmp/historical_flow.py index 2cdd5dc..8b898db 100644 --- a/src/tsn_adapters/flows/fmp/historical_flow.py +++ b/src/tsn_adapters/flows/fmp/historical_flow.py @@ -145,19 +145,19 @@ def get_earliest_data_date(tn_block: TNAccessBlock, stream_id: str) -> Optional[ raise TNQueryError(str(e)) from e -def ensure_unix_timestamp(dt: pd.Series) -> pd.Series: # type: ignore -> pd Series doesn't fit with any +def ensure_unix_timestamp(dt: pd.Series) -> pd.Series: # type: ignore -> pd Series doesn't fit with any """Convert datetime series to Unix timestamp (seconds since epoch). 
- + This function handles various datetime formats and ensures the output is always in seconds since epoch (Unix timestamp). It explicitly handles the conversion from nanoseconds and validates the output range. - + Args: dt: A pandas Series containing datetime data in various formats - + Returns: A pandas Series containing Unix timestamps (seconds since epoch) - + Raises: ValueError: If the resulting timestamps are outside the valid range or if the conversion results in unexpected units @@ -165,24 +165,24 @@ def ensure_unix_timestamp(dt: pd.Series) -> pd.Series: # type: ignore -> pd Seri # Convert to datetime if not already if not pd.api.types.is_datetime64_any_dtype(dt): dt = pd.to_datetime(dt, utc=True) - + # Get nanoseconds since epoch - ns_timestamps = dt.astype('int64') - + ns_timestamps = dt.astype("int64") + # Convert to seconds (integer division by 1e9 for nanoseconds) second_timestamps = ns_timestamps // 10**9 - + # Validate the range (basic sanity check) # Unix timestamps should be between 1970 and 2100 approximately min_valid_timestamp = 0 # 1970-01-01 max_valid_timestamp = 4102444800 # 2100-01-01 - + if (second_timestamps < min_valid_timestamp).any() or (second_timestamps > max_valid_timestamp).any(): raise ValueError( f"Converted timestamps outside valid range: " f"min={second_timestamps.min()}, max={second_timestamps.max()}" ) - + return second_timestamps @@ -292,7 +292,7 @@ def run_ticker_pipeline( Exceptions during processing are logged and raised. 
""" - def process_ticker(row: pd.Series) -> TickerResult: # type: ignore -> pd Series doesn't fit with any + def process_ticker(row: pd.Series) -> TickerResult: # type: ignore -> pd Series doesn't fit with any """Process a single ticker row.""" symbol = row["source_id"] stream_id = row["stream_id"] diff --git a/src/tsn_adapters/flows/fmp/real_time_flow.py b/src/tsn_adapters/flows/fmp/real_time_flow.py index b5b0cef..7f2f112 100644 --- a/src/tsn_adapters/flows/fmp/real_time_flow.py +++ b/src/tsn_adapters/flows/fmp/real_time_flow.py @@ -97,21 +97,22 @@ def convert_quotes_to_tn_data( return DataFrame[TnDataRowModel](result_df) + def ensure_unix_timestamp(time: int) -> int: """ Ensure the timestamp is a valid unix timestamp (seconds since epoch). - + This function validates and converts timestamps to ensure they are in seconds since epoch format. It handles cases where the input might be in milliseconds or microseconds. - + Args: time: Integer timestamp that might be in seconds, milliseconds, microseconds, or nanoseconds since epoch - + Returns: Integer timestamp in seconds since epoch - + Raises: ValueError: If the timestamp is invalid or outside the reasonable range """ @@ -121,7 +122,7 @@ def ensure_unix_timestamp(time: int) -> int: # Define valid range for seconds since epoch min_valid_timestamp = 0 # 1970-01-01 max_valid_timestamp = 4102444800 # 2100-01-01 - + # Convert to seconds if in a larger unit converted_time = time if time > max_valid_timestamp: @@ -134,13 +135,11 @@ def ensure_unix_timestamp(time: int) -> int: converted_time = time // 10**3 else: # assume seconds but with some future date converted_time = time // 10**9 # aggressive conversion to be safe - + # Validate the range after conversion if converted_time < min_valid_timestamp or converted_time > max_valid_timestamp: - raise ValueError( - f"Timestamp outside valid range (1970-2100): {converted_time}" - ) - + raise ValueError(f"Timestamp outside valid range (1970-2100): {converted_time}") + return 
converted_time diff --git a/src/tsn_adapters/flows/stream_deploy_flow.py b/src/tsn_adapters/flows/stream_deploy_flow.py index 038a9ed..75aac68 100644 --- a/src/tsn_adapters/flows/stream_deploy_flow.py +++ b/src/tsn_adapters/flows/stream_deploy_flow.py @@ -8,7 +8,6 @@ concurrency limiter and then waits for confirmation using tn-read concurrency via TNAccessBlock. """ -import time from typing import Literal from pandera.typing import DataFrame @@ -32,6 +31,7 @@ class DeployStreamResult(TypedDict): stream_id: str status: Literal["deployed", "skipped"] + @task(tags=["tn", "tn-write"], retries=3, retry_delay_seconds=5) def check_and_deploy_stream(stream_id: str, tna_block: TNAccessBlock, is_unix: bool = False) -> DeployStreamResult: """ @@ -109,7 +109,7 @@ def deploy_streams_flow( logger.info(f"Found {len(stream_ids)} stream descriptors.") # we will deploy in batches of 500 to avoid infinite threads creation - batches = [stream_ids[i:i+batch_size] for i in range(0, len(stream_ids), batch_size)] + batches = [stream_ids[i : i + batch_size] for i in range(0, len(stream_ids), batch_size)] aggregated_results: list[DeployStreamResult] = [] for batch in batches[start_from_batch:]: diff --git a/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py b/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py index 6b4edfc..cd8a63f 100644 --- a/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py +++ b/src/tsn_adapters/tasks/argentina/aggregate/category_price_aggregator.py @@ -96,10 +96,10 @@ def aggregate_prices_by_category( if not uncategorized_df.empty: logger.warning(f"Found {len(uncategorized_df)} uncategorized products") accumulator = ErrorAccumulator.get_or_create_from_context() - accumulator.add_error(UncategorizedProductsError( - count=len(uncategorized_df), - date=avg_price_product_df["date"].iloc[0], - store_id=avg_price_product_df["id_comercio"].iloc[0] - )) + accumulator.add_error( + UncategorizedProductsError( + 
count=len(uncategorized_df), + ) + ) return AggregatedPricesDF(aggregated_df), uncategorized_df diff --git a/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py b/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py index 5b29ff1..ab6abe3 100644 --- a/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py +++ b/src/tsn_adapters/tasks/argentina/aggregate/uncategorized.py @@ -11,4 +11,3 @@ def get_uncategorized_products( diff_df = data[~data["id_producto"].isin(category_map["id_producto"])] return UncategorizedDF(diff_df) - diff --git a/src/tsn_adapters/tasks/argentina/errors/__init__.py b/src/tsn_adapters/tasks/argentina/errors/__init__.py index 6346fdf..8d36784 100644 --- a/src/tsn_adapters/tasks/argentina/errors/__init__.py +++ b/src/tsn_adapters/tasks/argentina/errors/__init__.py @@ -12,7 +12,7 @@ InvalidCSVSchemaError, InvalidDateFormatError, InvalidStructureZIPError, - MissingProductIDError, + InvalidProductsError, MissingProductosCSVError, UncategorizedProductsError, ) @@ -27,7 +27,7 @@ "InvalidCSVSchemaError", "InvalidDateFormatError", "InvalidStructureZIPError", - "MissingProductIDError", + "InvalidProductsError", "MissingProductosCSVError", "UncategorizedProductsError", ] \ No newline at end of file diff --git a/src/tsn_adapters/tasks/argentina/errors/accumulator.py b/src/tsn_adapters/tasks/argentina/errors/accumulator.py index b548f05..0558847 100644 --- a/src/tsn_adapters/tasks/argentina/errors/accumulator.py +++ b/src/tsn_adapters/tasks/argentina/errors/accumulator.py @@ -3,25 +3,27 @@ from typing import Any from prefect.artifacts import create_markdown_artifact + from tsn_adapters.tasks.argentina.errors.errors import ArgentinaSEPAError -error_ctx = ContextVar("argentina_sepa_errors") +error_ctx = ContextVar["ErrorAccumulator | None"]("argentina_sepa_errors") + class ErrorAccumulator: def __init__(self): - self.errors = [] + self.errors: list[ArgentinaSEPAError] = [] def add_error(self, error: ArgentinaSEPAError): 
self.errors.append(error) def model_dump(self): return [error.to_dict() for error in self.errors] - + def model_load(self, data: list[dict[str, Any]]): self.errors = [ArgentinaSEPAError(**error) for error in data] @classmethod - def get_or_create_from_context(cls) -> 'ErrorAccumulator': + def get_or_create_from_context(cls) -> "ErrorAccumulator": errors = error_ctx.get() if errors is None: errors = ErrorAccumulator() @@ -31,7 +33,6 @@ def get_or_create_from_context(cls) -> 'ErrorAccumulator': def set_to_context(self): error_ctx.set(self) - @contextmanager def error_collection(): """Context manager for collecting errors during processing.""" @@ -39,6 +40,9 @@ def error_collection(): accumulator.set_to_context() try: yield accumulator + except ArgentinaSEPAError as e: + accumulator.add_error(e) + raise finally: # Create error summary if there are errors if accumulator.errors: diff --git a/src/tsn_adapters/tasks/argentina/errors/context.py b/src/tsn_adapters/tasks/argentina/errors/context.py new file mode 100644 index 0000000..4004a24 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/context.py @@ -0,0 +1,47 @@ +from collections.abc import Generator +from contextlib import contextmanager +import contextvars +from typing import Any + +# A context variable to hold a dictionary of key/value pairs for error context. +_error_context_var = contextvars.ContextVar("error_context", default={}) + + +def set_error_context(key: str, value: Any) -> None: + """ + Set a key/value pair in the global error context. + """ + current = _error_context_var.get().copy() + current[key] = value + _error_context_var.set(current) + + +def get_error_context() -> dict[str, Any]: + """ + Retrieve the current global error context. + """ + return _error_context_var.get() + + +def clear_error_context() -> None: + """ + Clear the global error context. 
+ """ + _error_context_var.set({}) + + +@contextmanager +def error_context(**kwargs: Any) -> Generator[None, None, None]: + """ + Context manager to temporarily update the error context. + Example: + with error_context(store_id="STORE-123", user_id="USER-456"): + ...your logic... + """ + current = get_error_context().copy() + current.update(kwargs) + token = _error_context_var.set(current) + try: + yield + finally: + _error_context_var.reset(token) diff --git a/src/tsn_adapters/tasks/argentina/errors/context_helper.py b/src/tsn_adapters/tasks/argentina/errors/context_helper.py new file mode 100644 index 0000000..99a3074 --- /dev/null +++ b/src/tsn_adapters/tasks/argentina/errors/context_helper.py @@ -0,0 +1,46 @@ +""" +Helper class to store known context attributes for error management in Argentina flows. + +This class encapsulates common context attributes (e.g., store_id, date, file_key) +which can be used during error reporting. When an error is raised, you can pass +the dictionary produced by this helper to automatically include this extra metadata. +""" + +from typing import Any, Optional + +from tsn_adapters.tasks.argentina.errors.context import get_error_context, set_error_context + + +class ContextProperty: + """Descriptor for managing context properties in a DRY and type-safe manner.""" + + def __init__(self, key: str) -> None: + self.key = key + + def __get__(self, instance: Any, owner: Any) -> Optional[str]: + return get_error_context().get(self.key) + + def __set__(self, instance: Any, value: str) -> None: + set_error_context(self.key, value) + # Optionally update the instance dictionary if needed + instance.__dict__[self.key] = value + + +class ArgentinaErrorContext: + """ + Helper class for managing known context attributes in Argentina flows. + + Attributes + ---------- + store_id : Optional[str] + The identifier for the store, if applicable. 
+ date : Optional[str] + The relevant date (e.g., report date or flow execution date) in ISO format (YYYY-MM-DD). + file_key : Optional[str] + The file key or resource identifier involved in the error. + """ + + # DRY properties using the descriptor: + store_id: ContextProperty = ContextProperty("store_id") + date: ContextProperty = ContextProperty("date") + file_key: ContextProperty = ContextProperty("file_key") diff --git a/src/tsn_adapters/tasks/argentina/errors/errors.py b/src/tsn_adapters/tasks/argentina/errors/errors.py index 1fe6671..93f0947 100644 --- a/src/tsn_adapters/tasks/argentina/errors/errors.py +++ b/src/tsn_adapters/tasks/argentina/errors/errors.py @@ -7,6 +7,9 @@ from pydantic import BaseModel +# NEW: Import the global error context helper +from tsn_adapters.tasks.argentina.errors.context_helper import ArgentinaErrorContext + class AccountableRole(Enum): """Roles responsible for different types of errors.""" @@ -19,6 +22,7 @@ class AccountableRole(Enum): class ArgentinaSEPAErrorData(BaseModel): """Data model for Argentina SEPA errors.""" + code: str message: str responsibility: AccountableRole @@ -36,7 +40,7 @@ def __init__( code=code, message=message, responsibility=responsibility, - context=context or {} + context=context or {}, ) @property @@ -61,7 +65,7 @@ def to_dict(self) -> dict[str, Any]: "code": self.code, "message": self.message, "responsibility": self.responsibility, - "context": self.context + "context": self.context, } @@ -71,12 +75,16 @@ def to_dict(self) -> dict[str, Any]: class InvalidStructureZIPError(ArgentinaSEPAError): """Invalid ZIP file structure during extraction""" - def __init__(self, context: dict[str, Any]): + def __init__(self, error: str): + ctx = ArgentinaErrorContext() super().__init__( code="ARG-100", message="Invalid ZIP file structure - cannot extract files", responsibility=AccountableRole.DATA_PROVIDER, - context=context, + context={ + "date": ctx.date, + "error": error, + }, ) @@ -95,12 +103,12 @@ def 
__init__(self, date_str: str): class MissingProductosCSVError(ArgentinaSEPAError): """Missing productos.csv in ZIP file""" - def __init__(self, context: dict[str, Any]): + def __init__(self, directory: str, available_files: list[str]): super().__init__( code="ARG-102", message="Missing productos.csv in ZIP archive", responsibility=AccountableRole.DATA_PROVIDER, - context=context, + context={"directory": directory, "available_files": ", ".join(available_files)}, ) @@ -110,13 +118,14 @@ def __init__(self, context: dict[str, Any]): class DateMismatchError(ArgentinaSEPAError): """Filename vs content date mismatch""" - def __init__(self, external_date: str, internal_date: str): + def __init__(self, internal_date: str): + ctx = ArgentinaErrorContext() super().__init__( code="ARG-200", - message=f"Date mismatch: Reported {external_date} vs Actual {internal_date}", + message=f"Date mismatch: Reported {ctx.date} vs Actual {internal_date}", responsibility=AccountableRole.DATA_PROVIDER, context={ - "external_date": external_date, + "external_date": ctx.date, "internal_date": internal_date, }, ) @@ -125,24 +134,31 @@ def __init__(self, external_date: str, internal_date: str): class InvalidCSVSchemaError(ArgentinaSEPAError): """Missing required columns in RAW data""" - def __init__(self, date: str, store_id: str): + def __init__(self, error: str): + ctx = ArgentinaErrorContext() super().__init__( code="ARG-201", message="Missing required columns", responsibility=AccountableRole.DATA_PROVIDER, - context={"date": date, "store_id": store_id}, + context={ + "date": ctx.date, + "store_id": ctx.store_id, + "error": error, + }, ) -class MissingProductIDError(ArgentinaSEPAError): +class InvalidProductsError(ArgentinaSEPAError): """Null/empty product IDs found in RAW data""" - def __init__(self, count: int, date: str, store_id: str): + def __init__(self, invalid_indexes: list[int]): + ctx = ArgentinaErrorContext() + invalid_indexes_str = ", ".join(str(idx) for idx in invalid_indexes) 
super().__init__( code="ARG-202", - message=f"{count} products missing IDs", + message=f"{len(invalid_indexes)} products with invalid IDs", responsibility=AccountableRole.DEVELOPMENT, - context={"missing_count": count, "date": date, "store_id": store_id}, + context={"invalid_indexes": invalid_indexes_str, "date": ctx.date, "store_id": ctx.store_id}, ) @@ -164,24 +180,25 @@ def __init__(self, url: str): class UncategorizedProductsError(ArgentinaSEPAError): """Products without category mapping""" - def __init__(self, count: int, date: str, store_id: str): + def __init__(self, count: int): + ctx = ArgentinaErrorContext() super().__init__( code="ARG-301", message=f"{count} uncategorized products found", responsibility=AccountableRole.DATA_ENGINEERING, - context={"uncategorized_count": count, "date": date, "store_id": store_id}, + context={"uncategorized_count": str(count), "date": ctx.date}, ) class InvalidCategorySchemaError(ArgentinaSEPAError): """Invalid category mapping schema""" - def __init__(self, issues: list[str]): + def __init__(self, error: str, url: str): super().__init__( code="ARG-302", message="Invalid category mapping schema", responsibility=AccountableRole.DATA_ENGINEERING, - context={"validation_issues": issues}, + context={"error": error, "url": url}, ) @@ -191,7 +208,7 @@ def __init__(self, issues: list[str]): MissingProductosCSVError, DateMismatchError, InvalidCSVSchemaError, - MissingProductIDError, + InvalidProductsError, EmptyCategoryMapError, UncategorizedProductsError, InvalidCategorySchemaError, diff --git a/src/tsn_adapters/tasks/argentina/flows/base.py b/src/tsn_adapters/tasks/argentina/flows/base.py index 6411012..f893451 100644 --- a/src/tsn_adapters/tasks/argentina/flows/base.py +++ b/src/tsn_adapters/tasks/argentina/flows/base.py @@ -2,8 +2,8 @@ Base flow controller for Argentina SEPA data processing. 
""" -import re from datetime import datetime +import re from prefect import get_run_logger from prefect_aws import S3Bucket @@ -38,11 +38,24 @@ def validate_date(self, date: DateStr) -> None: Raises: InvalidDateFormatError: If date format is invalid """ - # Check basic format and validity + # First check for separators + if "/" in date: + raise InvalidDateFormatError(date, "wrong_separator") + if "-" not in date: + raise InvalidDateFormatError(date, "no_separator") + + # Check basic format + if not re.match(r"\d{4}-\d{2}-\d{2}", date): + raise InvalidDateFormatError(date, "wrong_format") + + # Check date validity try: - if not re.match(r"\d{4}-\d{2}-\d{2}", date): - raise InvalidDateFormatError(date) - datetime.strptime(date, "%Y-%m-%d") - except ValueError: - raise InvalidDateFormatError(date) + except ValueError as e: + error_msg = str(e) + if "month must be in 1..12" in error_msg: + raise InvalidDateFormatError(date, "invalid_month") + if "day is out of range for month" in error_msg or "day must be in" in error_msg: + raise InvalidDateFormatError(date, "invalid_day") + # Only raise wrong_format if it's not a specific month/day error + raise InvalidDateFormatError(date, "wrong_format") diff --git a/src/tsn_adapters/tasks/argentina/flows/ingest_flow.py b/src/tsn_adapters/tasks/argentina/flows/ingest_flow.py index b865722..e8bde8b 100644 --- a/src/tsn_adapters/tasks/argentina/flows/ingest_flow.py +++ b/src/tsn_adapters/tasks/argentina/flows/ingest_flow.py @@ -15,7 +15,7 @@ from prefect.artifacts import create_markdown_artifact from prefect_aws import S3Bucket -from tsn_adapters.common.trufnetwork.models.tn_models import TnDataRowModel, TnRecordModel +from tsn_adapters.common.trufnetwork.models.tn_models import TnDataRowModel from tsn_adapters.tasks.argentina.flows.base import ArgentinaFlowController from tsn_adapters.tasks.argentina.target import create_trufnetwork_components from tsn_adapters.tasks.argentina.task_wrappers import ( diff --git 
a/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py b/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py index 6481e54..93ce5c5 100644 --- a/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py +++ b/src/tsn_adapters/tasks/argentina/flows/preprocess_flow.py @@ -8,7 +8,6 @@ 4. Storage of processed data in S3 """ -from contextlib import contextmanager from typing import cast import pandas as pd @@ -20,13 +19,15 @@ from tsn_adapters.tasks.argentina.aggregate import aggregate_prices_by_category from tsn_adapters.tasks.argentina.errors import ( ArgentinaSEPAError, - ErrorAccumulator, ) +from tsn_adapters.tasks.argentina.errors.accumulator import error_collection +from tsn_adapters.tasks.argentina.errors.context_helper import ArgentinaErrorContext from tsn_adapters.tasks.argentina.flows.base import ArgentinaFlowController from tsn_adapters.tasks.argentina.models.sepa.sepa_models import SepaAvgPriceProductModel from tsn_adapters.tasks.argentina.provider.s3 import RawDataProvider from tsn_adapters.tasks.argentina.task_wrappers import task_load_category_map from tsn_adapters.tasks.argentina.types import AggregatedPricesDF, CategoryMapDF, DateStr, SepaDF, UncategorizedDF +from tsn_adapters.tasks.argentina.utils.processors import process_raw_data from tsn_adapters.utils import deroutine @@ -38,7 +39,7 @@ def process_raw_data( """Process raw SEPA data. 
Args: - data_item: Raw data item to process + raw_data: Raw data item to process category_map_df: Category mapping DataFrame Returns: @@ -111,12 +112,17 @@ def process_date(self, date: DateStr) -> None: Raises: ValueError: If date format is invalid KeyError: If no data available for date + EmptyCategoryMapError: If category map is empty + InvalidCSVSchemaError: If category map has invalid schema """ logger = get_run_logger() self.validate_date(date) logger.info(f"Processing {date}") + # Set up error context + ArgentinaErrorContext().date = date + # Get raw data raw_data = self.raw_provider.get_raw_data_for(date) if raw_data.empty: diff --git a/src/tsn_adapters/tasks/argentina/provider/data_processor.py b/src/tsn_adapters/tasks/argentina/provider/data_processor.py index b6e550d..8aff109 100644 --- a/src/tsn_adapters/tasks/argentina/provider/data_processor.py +++ b/src/tsn_adapters/tasks/argentina/provider/data_processor.py @@ -7,7 +7,6 @@ import tempfile from typing import cast -import pandas as pd from prefect.concurrency.sync import concurrency from tsn_adapters.tasks.argentina.errors.errors import DateMismatchError, InvalidStructureZIPError @@ -24,7 +23,7 @@ def process_sepa_zip( Process SEPA data from a data item. 
Args: -zip_reader: Generator yielding bytes of the ZIP file + zip_reader: Generator yielding bytes of the ZIP file source_name: Name of the source (for error messages) reported_date: The date reported by the source @@ -53,19 +52,16 @@ def process_sepa_zip( os.makedirs(extract_dir, exist_ok=True) try: processor = SepaDirectoryProcessor.from_zip_path(temp_zip_path, extract_dir) - except Exception as e: - raise InvalidStructureZIPError({"source": source_name, "date": reported_date, "error": str(e)}) from e - - df = processor.get_all_products_data_merged() + data = processor.get_all_products_data_merged() - # skip empty dataframes - if df.empty: - return cast(SepaDF, pd.DataFrame()) + # Validate date matches + if not data.empty: + data_date: str = cast(str, data["date"].iloc[0]) + if str(data_date) != str(reported_date): + raise DateMismatchError(internal_date=data_date) - # Validate the date matches - real_date = df["date"].iloc[0] - if reported_date != real_date: - # we need to raise an error so cache is invalidated - raise DateMismatchError(external_date=reported_date, internal_date=real_date) - - return cast(SepaDF, df) + return data + except Exception as e: + if isinstance(e, DateMismatchError): + raise + raise InvalidStructureZIPError(str(e)) from e diff --git a/src/tsn_adapters/tasks/argentina/target/trufnetwork.py b/src/tsn_adapters/tasks/argentina/target/trufnetwork.py index b886942..55490ac 100644 --- a/src/tsn_adapters/tasks/argentina/target/trufnetwork.py +++ b/src/tsn_adapters/tasks/argentina/target/trufnetwork.py @@ -16,7 +16,6 @@ task_insert_and_wait_for_tx, task_read_records, task_split_and_insert_records, - task_wait_for_tx, ) from tsn_adapters.common.interfaces.target import ITargetClient from tsn_adapters.common.trufnetwork.models.tn_models import TnDataRowModel diff --git a/src/tsn_adapters/tasks/argentina/task_wrappers.py b/src/tsn_adapters/tasks/argentina/task_wrappers.py index ca4099c..491194a 100644 --- 
a/src/tsn_adapters/tasks/argentina/task_wrappers.py +++ b/src/tsn_adapters/tasks/argentina/task_wrappers.py @@ -244,21 +244,21 @@ def task_get_and_transform_data( @task(retries=3, cache_expiration=timedelta(hours=1), cache_policy=policies.INPUTS + policies.TASK_SOURCE) -def task_load_category_map(url: str) -> pd.DataFrame: +def task_load_category_map(url: str) -> DataFrame[SepaProductCategoryMapModel]: """Load the product category mapping from a URL.""" logger = get_run_logger() logger.info(f"Loading category map from: {url}") try: df = SepaProductCategoryMapModel.from_url(url, sep="|", compression="zip") if df.empty: - raise EmptyCategoryMapError() + raise EmptyCategoryMapError(url=url) logger.info(f"Loaded {len(df)} category mappings") return df except EmptyCategoryMapError: raise except Exception as e: logger.error(f"Failed to load category map: {e}") - raise InvalidCategorySchemaError(issues=[f"Failed to load category map from {url}", f"Error: {e!s}"]) + raise InvalidCategorySchemaError(error=str(e), url=url) @task diff --git a/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py b/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py index 04c6d23..5d852d9 100644 --- a/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py +++ b/src/tsn_adapters/tasks/argentina/utils/processors/resource_processor.py @@ -6,12 +6,13 @@ from pathlib import Path import re import tempfile -from typing import ClassVar, cast +from typing import Any, ClassVar import pandas as pd from pandera.typing import DataFrame from pydantic import BaseModel, field_validator +from tsn_adapters.tasks.argentina.errors.accumulator import ErrorAccumulator from tsn_adapters.tasks.argentina.errors.errors import ( InvalidCSVSchemaError, InvalidDateFormatError, @@ -100,12 +101,14 @@ def get_all_products_data_merged(self) -> DataFrame[SepaProductosDataModel]: all_data = list(self.get_products_data()) if not all_data: self.logger.warning("No product data 
found in any directory") - empty_df = pd.DataFrame({ - "id_producto": [], - "productos_descripcion": [], - "productos_precio_lista": [], - "date": [], - }) + empty_df = pd.DataFrame( + { + "id_producto": [], + "productos_descripcion": [], + "productos_precio_lista": [], + "date": [], + } + ) return DataFrame[SepaProductosDataModel](empty_df) merged_df = pd.concat(all_data, ignore_index=True) @@ -194,7 +197,7 @@ def load_products_data(self) -> DataFrame[SepaProductosDataModel]: break else: self.logger.error(f"No valid product file found in {self.dir_path}") - raise MissingProductosCSVError({"directory": self.dir_path, "available_files": files}) + raise MissingProductosCSVError(directory=self.dir_path, available_files=files) def read_file_iterator(): with open(file_path) as file: @@ -205,12 +208,14 @@ def read_file_iterator(): text_lines = list(read_file_iterator()) if not text_lines: self.logger.warning(f"Empty product file found in {self.dir_path}") - empty_df = pd.DataFrame({ - "id_producto": [], - "productos_descripcion": [], - "productos_precio_lista": [], - "date": [], - }) + empty_df = pd.DataFrame( + { + "id_producto": [], + "productos_descripcion": [], + "productos_precio_lista": [], + "date": [], + } + ) return DataFrame[SepaProductosDataModel](empty_df) # models might be @@ -222,16 +227,15 @@ def read_file_iterator(): for model in models: if model.has_columns(text_lines[0]): try: - original_model = model.from_csv(self.date, text_lines, self.logger) + # Parse and convert the data + raw_df: DataFrame[Any] = model.from_csv(self.date, text_lines, self.logger) if model != SepaProductosDataModel: self.logger.info(f"Converting {model} to core model") - return model.to_core_model(original_model) - return original_model + return model.to_core_model(raw_df) + return raw_df except Exception as e: self.logger.error(f"Failed to parse CSV with model {model}: {e}") + ErrorAccumulator.get_or_create_from_context().add_error(InvalidCSVSchemaError(error=str(e))) continue - 
raise InvalidCSVSchemaError( - date=self.date, - store_id=self.id_comercio, - ) + raise InvalidCSVSchemaError(error="No valid model found") diff --git a/src/tsn_adapters/utils/filter_failures.py b/src/tsn_adapters/utils/filter_failures.py index 70ca2b2..51774a9 100644 --- a/src/tsn_adapters/utils/filter_failures.py +++ b/src/tsn_adapters/utils/filter_failures.py @@ -5,6 +5,9 @@ from pandera.errors import SchemaErrors from pandera.typing import DataFrame +from tsn_adapters.tasks.argentina.errors.accumulator import ErrorAccumulator +from tsn_adapters.tasks.argentina.errors.errors import InvalidProductsError + U = TypeVar("U", bound=pa.DataFrameModel) @@ -25,10 +28,15 @@ def filter_failures(original: pd.DataFrame, model: type[U]) -> DataFrame[U]: return cast(DataFrame[U], validated) except SchemaErrors as exc: if exc.failure_cases is not None: + assert isinstance(exc.failure_cases, pd.DataFrame) # Get the failure indices and drop those rows failure_indices = exc.failure_cases["index"].unique() filtered_df = original.drop(failure_indices) + # Add the error to the error accumulator + failure_indexes_list: list[int] = failure_indices.tolist() + ErrorAccumulator.get_or_create_from_context().add_error(InvalidProductsError(invalid_indexes=failure_indexes_list)) + # Re-validate the filtered DataFrame validated_filtered = model.validate(filtered_df, lazy=False) return cast(DataFrame[U], validated_filtered) diff --git a/tests/argentina/errors/test_error_accumulation.py b/tests/argentina/errors/test_error_accumulation.py index c98a78d..90b1858 100644 --- a/tests/argentina/errors/test_error_accumulation.py +++ b/tests/argentina/errors/test_error_accumulation.py @@ -31,20 +31,28 @@ - test_context_isolation: Validates complete isolation between concurrent tasks """ +from collections.abc import Sequence from unittest.mock import patch from prefect import flow, task +import pytest from tsn_adapters.tasks.argentina.errors import DateMismatchError, EmptyCategoryMapError, 
ErrorAccumulator -from tsn_adapters.tasks.argentina.flows.preprocess_flow import error_collection +from tsn_adapters.tasks.argentina.errors.accumulator import error_collection +from tsn_adapters.tasks.argentina.errors.context_helper import ArgentinaErrorContext @task def task_that_raises_error(): """Task that intentionally raises a known error""" accumulator = ErrorAccumulator.get_or_create_from_context() + # Set up context for DateMismatchError + ctx = ArgentinaErrorContext() + ctx.date = "2024-01-01" + accumulator.add_error(EmptyCategoryMapError(url="test://invalid-map")) - accumulator.add_error(DateMismatchError("2024-01-01", "2024-01-02")) + accumulator.add_error(DateMismatchError(internal_date="2024-01-02")) + @task def task_with_isolated_errors_1(): @@ -53,64 +61,94 @@ def task_with_isolated_errors_1(): task1_accumulator.add_error(EmptyCategoryMapError(url="task1://error")) return task1_accumulator.model_dump() + @task def task_with_isolated_errors_2(): """Second task with its own error context""" with error_collection() as task2_accumulator: - task2_accumulator.add_error(DateMismatchError("2024-03-01", "2024-03-02")) + task2_accumulator.add_error(DateMismatchError(internal_date="2024-03-02")) return task2_accumulator.model_dump() + @flow def error_accumulation_flow(): - """Flow that forces error conditions""" + """Flow that forces error conditions and reports errors via a markdown artifact""" with error_collection() as accumulator: - task_that_raises_error.submit().result() - + # Use an intermediate variable and add type: ignore to silence linter issues + future = task_that_raises_error.submit(return_state=True) # type: ignore + future.result() + # Verify errors are collected during processing assert len(accumulator.errors) == 2, "Should collect 2 errors" assert isinstance(accumulator.errors[0], EmptyCategoryMapError) assert isinstance(accumulator.errors[1], DateMismatchError) + # Let the context manager finish and create the artifact + return accumulator + 
+ @flow def concurrent_error_contexts_flow(): """Flow that tests error context isolation between tasks""" - # Submit both tasks concurrently - task1_future = task_with_isolated_errors_1.submit() - task2_future = task_with_isolated_errors_2.submit() - - # Get results + # Submit both tasks concurrently with type ignore to silence linter warnings + task1_future = task_with_isolated_errors_1.submit(return_state=True) # type: ignore + task2_future = task_with_isolated_errors_2.submit(return_state=True) # type: ignore + + # Get results with type ignore as well task1_errors = task1_future.result() task2_errors = task2_future.result() - + # Verify task1 errors assert len(task1_errors) == 1, "Task 1 should have 1 error" assert task1_errors[0]["code"] == "ARG-300" assert "task1://error" in task1_errors[0]["context"]["url"] - + # Verify task2 errors assert len(task2_errors) == 1, "Task 2 should have 1 error" assert task2_errors[0]["code"] == "ARG-200" assert "2024-03-01" in task2_errors[0]["context"]["external_date"] + +@pytest.mark.usefixtures("prefect_test_fixture") def test_error_reporting(): """Test the full error accumulation and reporting flow""" - with patch("tsn_adapters.tasks.argentina.flows.preprocess_flow.create_markdown_artifact") as mock_artifact: + with patch("tsn_adapters.tasks.argentina.errors.accumulator.create_markdown_artifact") as mock_artifact: error_accumulation_flow() - + # Verify artifact creation mock_artifact.assert_called_once() - - # Verify artifact content markdown_content = mock_artifact.call_args[1]["markdown"] assert "ARG-300" in markdown_content - assert "ARG-200" in markdown_content assert "test://invalid-map" in markdown_content assert "Date mismatch: Reported 2024-01-01 vs Actual 2024-01-02" in markdown_content + assert mock_artifact.call_args[1]["key"] == "processing-errors" + assert "Errors encountered during SEPA data processing" in mock_artifact.call_args[1]["description"] + +@pytest.mark.usefixtures("prefect_test_fixture") def 
test_context_isolation(): """Test that error contexts remain isolated between concurrent tasks""" concurrent_error_contexts_flow() +def generate_markdown_report(errors: Sequence[Exception]) -> str: + """Generate a markdown formatted report from a list of errors. + + For an EmptyCategoryMapError, the report includes the url. + For a DateMismatchError, the report provides a formatted date mismatch message. + """ + lines = [] + for err in errors: + if isinstance(err, EmptyCategoryMapError): + # Access url from the error's context + lines.append(f"ARG-300: URL error with {err.context['url']}") + elif isinstance(err, DateMismatchError): + # Access dates from the error's context + external_date = err.context["external_date"] + internal_date = err.context["internal_date"] + lines.append(f"ARG-200: Date mismatch: Reported {external_date} vs Actual {internal_date}") + return "\n".join(lines) + + if __name__ == "__main__": test_error_reporting() diff --git a/tests/argentina/errors/test_error_integration.py b/tests/argentina/errors/test_error_integration.py index c0a54f7..0228a5f 100644 --- a/tests/argentina/errors/test_error_integration.py +++ b/tests/argentina/errors/test_error_integration.py @@ -1,536 +1,464 @@ """ Integration tests for Argentina SEPA error handling system. -This test suite validates error handling across the data processing pipeline, -focusing on business-critical validations and proper error accumulation. +This test suite focuses on critical integration tests and complex error handling scenarios, +validating the complete data processing pipeline and error accumulation mechanisms. Key Test Areas: -1. Structure Validation - ZIP and CSV file integrity -2. Input Validation - Date formats and file requirements -3. Data Processing - Schema validation and content processing -4. Category Mapping - Product categorization and mapping validation -5. Error Recovery - System resilience and error isolation -6. 
Error Accumulation - Multi-error handling and context management - -Each test focuses on real-world scenarios and validates the complete error lifecycle: -- Error detection and creation -- Context preservation -- Error accumulation -- Final error reporting +1. Data Processing Pipeline - Complete flow validation +2. Error Handling - Complex error scenarios and recovery +3. Concurrency - Error isolation and accumulation +4. System Resilience - Partial processing and recovery """ -import io +from collections.abc import Generator +import os import re -from typing import cast -from unittest.mock import MagicMock, patch, mock_open +from typing import Any, Callable, Optional, TypeVar +from unittest.mock import MagicMock, patch import zipfile -import logging -import tempfile -import os -import shutil import pandas as pd +from prefect import flow import pytest -from prefect.testing.utilities import prefect_test_harness from tsn_adapters.tasks.argentina.base_types import DateStr +from tsn_adapters.tasks.argentina.errors.accumulator import error_collection +from tsn_adapters.tasks.argentina.errors.context_helper import ArgentinaErrorContext from tsn_adapters.tasks.argentina.errors.errors import ( AccountableRole, + ArgentinaSEPAError, DateMismatchError, EmptyCategoryMapError, InvalidCSVSchemaError, - InvalidDateFormatError, InvalidStructureZIPError, - MissingProductIDError, + InvalidProductsError, MissingProductosCSVError, - UncategorizedProductsError, + InvalidDateFormatError, ) -from tsn_adapters.tasks.argentina.errors.accumulator import error_collection -from tsn_adapters.tasks.argentina.flows.preprocess_flow import PreprocessFlow from tsn_adapters.tasks.argentina.provider.data_processor import process_sepa_zip +from tsn_adapters.tasks.argentina.types import SepaDF +T = TypeVar("T", bound=ArgentinaSEPAError) # -------------------------------------------------- # Fixtures & Test Data # -------------------------------------------------- + + +@pytest.fixture +def 
error_context_fixture(): + """Fixture for managing error context and cleanup.""" + with error_collection() as accumulator: + yield accumulator + + @pytest.fixture def mock_s3_block(): - """Mock S3 block for testing with configurable responses""" + """Mock S3 block with configurable responses.""" mock = MagicMock() - - # Mock basic S3 operations mock.list_available_keys.return_value = ["2024-01-01"] - mock.get_raw_data_for.return_value = pd.DataFrame({ - 'id_producto': ['P1', 'P2'], - 'productos_descripcion': ['Product 1', 'Product 2'], - 'productos_precio_lista': [100.0, 200.0], - 'date': ['2024-01-01', '2024-01-01'] - }) - - # Mock S3 bucket operations + mock.get_raw_data_for.return_value = pd.DataFrame( + { + "id_producto": ["P1", "P2"], + "productos_descripcion": ["Product 1", "Product 2"], + "productos_precio_lista": [100.0, 200.0], + "date": ["2024-01-01", "2024-01-01"], + } + ) mock.bucket_name = "mock-bucket" mock.bucket_folder = "mock-folder" mock.credentials = MagicMock() - - # Mock S3 methods - mock.read_path.return_value = b'mock_s3_content' - mock.write_path.return_value = None - mock.list_objects.return_value = ['mock_key'] - mock.download_object_to_path.return_value = None - mock.upload_from_path.return_value = None - + mock.read_path.return_value = b"mock_s3_content" return mock + @pytest.fixture -def valid_sepa_data(): - """Valid SEPA data fixture with realistic product data""" - return pd.DataFrame({ - "date": ["2024-01-01"] * 5, - "id_producto": ["P1", "P2", "P3", "P4", "P5"], - "productos_descripcion": [ - "Leche Entera 1L", - "Pan Francés 1kg", - "Aceite Girasol 900ml", - "Arroz Largo 1kg", - "Azúcar Blanca 1kg" - ], - "precio": [350.0, 800.0, 1200.0, 900.0, 600.0], - }) +def valid_sepa_data() -> pd.DataFrame: + """Valid SEPA data fixture with realistic product data.""" + df = pd.DataFrame( + { + "date": ["2024-01-01"] * 5, + "id_producto": ["P1", None, "", None, "P5"], + "productos_descripcion": [ + "Leche Entera 1L", + "Pan Francés 1kg", + 
"Aceite Girasol 900ml", + "Arroz Largo 1kg", + "Azúcar Blanca 1kg", + ], + "productos_precio_lista": [350.0, 800.0, 1200.0, 900.0, 600.0], + } + ) + return df + @pytest.fixture -def create_zip_file(mock_zip_operations): - """Factory fixture for creating test ZIP files with various scenarios""" - def _create_zip(content_type: str, **kwargs) -> bytes: - if content_type == "invalid": - # Mock invalid ZIP behavior - with patch("zipfile.is_zipfile", return_value=False): - return b'This is not a ZIP file' - - # For all other cases, prepare the data but don't actually create files - if content_type == "missing_productos": - mock_zip_operations.namelist.return_value = ["sepa_1_comercio-sepa-1_2024-01-01_00-00-00/wrong.csv"] - elif content_type == "wrong_date": - mock_zip_operations.read.return_value = ( - "id_producto|productos_descripcion|productos_precio_lista|date\n" - f"P1|Product 1|100.0|{kwargs.get('date', '2024-01-02')}\n" - ).encode() - elif content_type == "mixed_errors": - mock_zip_operations.read.return_value = ( - "id_producto|productos_descripcion|productos_precio_lista|date\n" - "|Product 1|100.0|2024-01-01\n" # Missing ID - "|Product 2|200.0|2024-01-01\n" # Missing ID - "P3|Product 3|300.0|2024-01-02\n" # Wrong date - "P4|Product 4|400.0|2024-01-02\n" # Wrong date - "P5|Product 5|invalid|2024-01-01\n" # Invalid price - "P6|Product 6|invalid|2024-01-01\n" # Invalid price - ).encode() - elif content_type == "partial_valid": - mock_zip_operations.read.return_value = ( - "id_producto|productos_descripcion|productos_precio_lista|date\n" - "|Product 1|100.0|2024-01-01\n" - "|Product 2|200.0|2024-01-01\n" - "|Product 3|300.0|2024-01-01\n" - "P4|Product 4|400.0|2024-01-01\n" - "P5|Product 5|500.0|2024-01-01\n" - ).encode() - - return b'mock_zip_content' - return _create_zip +def mock_zip_factory(): + """Factory fixture for creating test ZIP files with various scenarios.""" + + def create_zip(scenario: str, *, date: Optional[str] = None, df: Optional[pd.DataFrame] = 
None) -> bytes: + if scenario == "invalid": + return b"This is not a valid ZIP file" + + # Create a temporary ZIP file in memory + from io import BytesIO + + zip_buffer = BytesIO() + + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: + if scenario == "invalid_date": + # Create a directory with invalid date format + base_dir = "sepa_1_comercio-sepa-1_2024/01/01_00-00-00" + else: + base_dir = "sepa_1_comercio-sepa-1_2024-01-01_00-00-00" + + # Create the base directory marker + zip_file.writestr(f"{base_dir}/", "") + + if scenario == "missing_productos": + # Create a ZIP with wrong file but correct directory structure + zip_file.writestr(f"{base_dir}/wrong.csv", "some content") + elif scenario == "wrong_date": + content_date = date or "2024-01-02" + content = ( + "id_producto|productos_descripcion|productos_precio_lista|date\n" + f"P1|Product 1|100.0|{content_date}\n" + ) + zip_file.writestr(f"{base_dir}/productos.csv", content) + elif scenario == "mixed_errors": + # Create an invalid CSV schema + content = ( + "wrong_column|productos_descripcion|productos_precio_lista|date\n" + "P1|Product 1|100.0|2024-01-01\n" + ) + zip_file.writestr(f"{base_dir}/productos.csv", content) + elif scenario == "partial_valid": + if df is not None: + content = df.to_csv(sep="|", index=False) + else: + content = ( + "id_producto|productos_descripcion|productos_precio_lista|date\n" + "|Product 1|100.0|2024-01-01\n" # Missing ID + "|Product 2|200.0|2024-01-01\n" # Missing ID + "|Product 3|300.0|2024-01-01\n" # Missing ID + "P4|Product 4|400.0|2024-01-01\n" # Valid + "P5|Product 5|500.0|2024-01-01\n" # Valid + ) + zip_file.writestr(f"{base_dir}/productos.csv", content) + else: + # Default valid content + content = ( + "id_producto|productos_descripcion|productos_precio_lista|date\n" + "P1|Product 1|100.0|2024-01-01\n" + ) + zip_file.writestr(f"{base_dir}/productos.csv", content) + + return zip_buffer.getvalue() + + return create_zip + @pytest.fixture -def 
mock_logger(): - """Mock logger for testing.""" - logger = MagicMock(spec=logging.Logger) - with patch('tsn_adapters.tasks.argentina.flows.base.get_run_logger', return_value=logger): - yield logger +def mock_filesystem(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]: + """Mock filesystem operations for testing.""" + created_dirs: set[str] = set() -@pytest.fixture(autouse=True) -def prefect_test_context(): - """Fixture to provide Prefect test context for all tests.""" - with prefect_test_harness(): - yield + def mock_makedirs(path: str | os.PathLike[str], exist_ok: bool = False) -> None: + """Mock os.makedirs to track created directories.""" + path_str = str(path) + if path_str not in created_dirs: + created_dirs.add(path_str) -@pytest.fixture(autouse=True) -def mock_filesystem(): - """Mock filesystem operations to speed up tests.""" - mock_temp_dir = "/tmp/mock_temp_dir" - mock_temp = MagicMock() - mock_temp.name = mock_temp_dir - - # Create the temp directory - os.makedirs(mock_temp_dir, exist_ok=True) - - with patch("tempfile.mkdtemp", return_value=mock_temp_dir), \ - patch("os.path.join", lambda *args: "/".join(args)), \ - patch("os.listdir", return_value=["productos.csv"]), \ - patch("os.makedirs", return_value=None), \ - patch("os.path.exists", return_value=True), \ - patch("os.remove", return_value=None), \ - patch("builtins.open", mock_open(read_data="id_producto|productos_descripcion|productos_precio_lista|date\n")), \ - patch("shutil.rmtree", return_value=None): - yield mock_temp_dir - - # Clean up - try: - shutil.rmtree(mock_temp_dir) - except: - pass + with patch.object(os, "makedirs", mock_makedirs): + yield -@pytest.fixture(autouse=True) -def mock_zip_operations(): - """Mock ZIP file operations to speed up tests.""" - mock_zip = MagicMock() - mock_zip.namelist.return_value = ["sepa_1_comercio-sepa-1_2024-01-01_00-00-00/productos.csv"] - - with patch("zipfile.ZipFile", return_value=mock_zip), \ - patch("zipfile.is_zipfile", 
return_value=True): - yield mock_zip @pytest.fixture(autouse=True) -def mock_network_operations(): - """Mock all network operations including S3 and HTTP requests.""" +def mock_network(): + """Mock all network operations.""" mock_response = MagicMock() mock_response.status_code = 200 - mock_response.content = b'mock_content' - mock_response.text = 'mock_text' - mock_response.raise_for_status.return_value = None - - with patch('requests.get', return_value=mock_response), \ - patch('requests.post', return_value=mock_response), \ - patch('requests.Session', return_value=MagicMock()), \ - patch('prefect_aws.S3Bucket.read_path', return_value=b'mock_s3_content'), \ - patch('prefect_aws.S3Bucket.download_object_to_path', return_value=None), \ - patch('prefect_aws.S3Bucket.list_objects', return_value=['mock_key']): + mock_response.content = b"mock_content" + mock_response.text = "mock_text" + + with ( + patch("requests.get", return_value=mock_response), + patch("requests.post", return_value=mock_response), + patch("requests.Session", return_value=MagicMock()), + patch("prefect_aws.S3Bucket.read_path", return_value=b"mock_s3_content"), + patch("prefect_aws.S3Bucket.download_object_to_path", return_value=None), + patch("prefect_aws.S3Bucket.list_objects", return_value=["mock_key"]), + ): yield mock_response -@pytest.fixture(autouse=True) -def mock_data_tasks(): - """Mock data processing tasks and heavy computations.""" - mock_df = pd.DataFrame({ - 'id_producto': ['P1', 'P2'], - 'productos_descripcion': ['Product 1', 'Product 2'], - 'productos_precio_lista': [100.0, 200.0], - 'date': ['2024-01-01', '2024-01-01'] - }) - - with patch('tsn_adapters.tasks.argentina.task_wrappers.task_create_stream_fetcher', return_value=mock_df), \ - patch('tsn_adapters.tasks.argentina.task_wrappers.task_get_streams', return_value=mock_df), \ - patch('tsn_adapters.tasks.argentina.task_wrappers.task_create_sepa_provider', return_value=mock_df), \ - 
patch('tsn_adapters.tasks.argentina.task_wrappers.task_get_data_for_date', return_value=mock_df): - yield mock_df - -@pytest.fixture(autouse=True) -def mock_prefect_operations(): - """Mock Prefect operations to avoid actual task runs and flow executions.""" - mock_state = MagicMock() - mock_state.is_completed.return_value = True - mock_state.result.return_value = None - - with patch('prefect.task', lambda *args, **kwargs: lambda f: f), \ - patch('prefect.flow', lambda *args, **kwargs: lambda f: f), \ - patch('prefect.get_run_logger', return_value=MagicMock()), \ - patch('prefect.context.get_run_context', return_value=MagicMock()): - yield mock_state # -------------------------------------------------- -# 1. Structure Validation Tests +# Integration Test Classes # -------------------------------------------------- -@pytest.mark.parametrize("invalid_content,expected_error,expected_context", [ - ("invalid", InvalidStructureZIPError, {"source": "test", "date": "2024-01-01", "error": "File is not a zip file"}), - ("missing_productos", MissingProductosCSVError, { - "directory": r"/tmp/tmp[^/]+/data/sepa_1_comercio-sepa-1_2024-01-01_00-00-00", - "available_files": ["wrong.csv"] - }), -]) -def test_zip_structure_validation(create_zip_file, invalid_content, expected_error, expected_context): - """Validate ZIP file structure requirements with detailed context""" - zip_content = create_zip_file(invalid_content) - - def mock_reader(): - yield zip_content - with error_collection() as accumulator: + +class TestDataProcessingPipeline: + """Integration tests for the complete data processing pipeline.""" + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-complete-pipeline") + def test_complete_processing_flow( + self, + mock_s3_block: MagicMock, + valid_sepa_data: pd.DataFrame, + mock_zip_factory: Callable[..., bytes], + error_context_fixture: Any, + ): + """Test the complete data processing pipeline with mixed valid/invalid data.""" + # Setup test data with 
mixed scenarios + zip_content = mock_zip_factory("partial_valid", df=valid_sepa_data) + + def mock_reader() -> Generator[bytes, None, None]: + yield zip_content + + # Process data and verify results + result = process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_pipeline") + + # Verify error handling + assert len(error_context_fixture.errors) > 0 + error = error_context_fixture.errors[0] + assert isinstance(error, InvalidProductsError) + assert error.context["missing_count"] == "3" + + # Verify successful processing of valid records + assert len(result) == 2 + assert all(pd.notna(result["id_producto"].astype(str))) + + @pytest.mark.usefixtures("prefect_test_fixture") + @pytest.mark.parametrize( + "error_scenario,expected_error,validation_details", + [ + ( + "invalid_structure", + InvalidStructureZIPError, + { + "code": "ARG-100", + "message": "Invalid ZIP file structure - cannot extract files", + "responsibility": AccountableRole.DATA_PROVIDER, + "context_keys": ["date", "error"], + "context_values": {"date": "2024-01-01"}, + }, + ), + ( + "missing_productos", + MissingProductosCSVError, + { + "code": "ARG-102", + "message": "Missing productos.csv in ZIP archive", + "responsibility": AccountableRole.DATA_PROVIDER, + "context_keys": ["directory", "available_files"], + "context_pattern": r"/tmp/tmp[^/]+/data/sepa_1_comercio-sepa-1_2024-01-01_00-00-00", + }, + ), + ( + "invalid_date", + InvalidDateFormatError, + { + "code": "ARG-101", + "message": "Invalid date format: 2024/01/01 - must be YYYY-MM-DD", + "responsibility": AccountableRole.SYSTEM, + "context_keys": ["invalid_date"], + "context_values": {"invalid_date": "2024/01/01"}, + }, + ), + ], + ) + @flow(name="test-error-scenarios") + def test_error_scenarios( + self, + mock_zip_factory: Callable[..., bytes], + error_scenario: str, + expected_error: type[T], + validation_details: dict[str, Any], + error_context_fixture: Any, + ): + """Test various error scenarios in the processing pipeline.""" try: + if 
error_scenario == "invalid_structure": + zip_content = mock_zip_factory("invalid") + # Set up context for InvalidStructureZIPError + ctx = ArgentinaErrorContext() + ctx.store_id = "test" + ctx.date = "2024-01-01" + elif error_scenario == "invalid_date": + zip_content = mock_zip_factory("invalid_date") + else: + zip_content = mock_zip_factory("missing_productos") + + def mock_reader() -> Generator[bytes, None, None]: + yield zip_content + process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test") pytest.fail(f"Expected {expected_error.__name__} to be raised") except expected_error as e: - accumulator.add_error(e) - assert isinstance(e, expected_error) - assert e.responsibility == AccountableRole.DATA_PROVIDER - # For directory paths, just check the pattern since temp dir will be different - if "directory" in expected_context: - assert re.match(expected_context["directory"], e.context["directory"]) - assert e.context["available_files"] == expected_context["available_files"] - else: - for key, value in expected_context.items(): + # Verify error code + assert e.code == validation_details["code"] + # Verify error message + assert e.message == validation_details["message"] + # Verify responsibility + assert e.responsibility == validation_details["responsibility"] + + # Verify context has all required keys + for key in validation_details["context_keys"]: + assert key in e.context + + # Verify specific context values if provided + if "context_values" in validation_details: + for key, value in validation_details["context_values"].items(): assert e.context[key] == value + + # Verify context patterns if provided + if "context_pattern" in validation_details: + assert re.match(validation_details["context_pattern"], e.context["directory"]) + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-invalid-structure-scenarios") + def test_invalid_structure_error_scenarios(self, error_context_fixture: Any): + """Test various scenarios for 
InvalidStructureZIPError.""" + # Test with different error messages + error_messages = [ + "File is not a zip file", + "Bad CRC-32 for file", + "Bad password for file", + "Truncated file header", + ] -# -------------------------------------------------- -# 2. Input Validation Tests -# -------------------------------------------------- -@pytest.mark.parametrize("invalid_date,error_details", [ - ("01-01-2024", {"reason": "wrong_format"}), - ("2024/01/01", {"reason": "wrong_separator"}), - ("20240101", {"reason": "no_separator"}), - ("2024-13-01", {"reason": "invalid_month"}), - ("2024-01-32", {"reason": "invalid_day"}), -]) -def test_date_validation_flow(mock_s3_block, mock_logger, invalid_date, error_details): - """Validate date format requirements with various invalid formats""" - flow = PreprocessFlow( - product_category_map_url="mock://map", - s3_block=mock_s3_block - ) - - with error_collection() as accumulator: - with pytest.raises(InvalidDateFormatError): - flow.validate_date(cast("DateStr", invalid_date)) - - error = accumulator.errors[0] - assert isinstance(error, InvalidDateFormatError) - assert error.code == "ARG-101" - assert error.responsibility == AccountableRole.SYSTEM - assert invalid_date in error.context["invalid_date"] - assert error.context["validation_error"] == error_details["reason"] - -# -------------------------------------------------- -# 3. 
Data Processing Tests -# -------------------------------------------------- -@pytest.mark.parametrize("scenario", [ - { - "content_date": "2024-01-02", - "filename_date": "2024-01-01", - "description": "future_date" - }, - { - "content_date": "2023-12-31", - "filename_date": "2024-01-01", - "description": "past_date" - }, -]) -def test_date_mismatch_handling(create_zip_file, scenario): - """Test date mismatch detection and handling with various scenarios""" - zip_content = create_zip_file( - "wrong_date", - date=scenario["content_date"] - ) - - def mock_reader(): - yield zip_content - - with error_collection() as accumulator: - with pytest.raises(DateMismatchError): - process_sepa_zip( - mock_reader(), - DateStr(scenario["filename_date"]), - f"test_{scenario['description']}" - ) - - error = accumulator.errors[0] - assert isinstance(error, DateMismatchError) - assert error.code == "ARG-200" - assert error.context["external_date"] == scenario["filename_date"] - assert error.context["internal_date"] == scenario["content_date"] - assert error.context["mismatch_type"] == scenario["description"] - -def test_mixed_error_handling(valid_sepa_data, create_zip_file): - """Test handling of critical errors that stop processing""" - zip_content = create_zip_file("mixed_errors", df=valid_sepa_data) - - def mock_reader(): - yield zip_content - - with pytest.raises(InvalidCSVSchemaError) as exc_info: - process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_mixed") - - error = exc_info.value - assert error.code == "ARG-201" - assert error.context["date"] == "2024-01-01" - assert error.context["store_id"] == "test_mixed" - -# -------------------------------------------------- -# 4. 
Category Mapping Tests -# -------------------------------------------------- -@pytest.mark.parametrize("category_map,expected_error,validation_details", [ - (pd.DataFrame(), EmptyCategoryMapError, {"url": "mock://map"}), - (pd.DataFrame({"wrong_column": []}), InvalidCSVSchemaError, { - "date": "2024-01-01", - "store_id": "test", - "missing_columns": ["id_producto", "category"] - }), -]) -def test_category_mapping_validation(mock_s3_block, category_map, expected_error, validation_details): - """Test category mapping validation with various invalid scenarios""" - mock_map_loader = MagicMock(return_value=category_map) - - with patch('tsn_adapters.tasks.argentina.task_wrappers.task_load_category_map', mock_map_loader): - flow = PreprocessFlow( - product_category_map_url="mock://map", - s3_block=mock_s3_block - ) - - with error_collection() as accumulator: - if expected_error == EmptyCategoryMapError: - flow.process_date(DateStr("2024-01-01")) - error = accumulator.errors[0] - assert isinstance(error, EmptyCategoryMapError) - assert error.context["url"] == validation_details["url"] - else: - # For InvalidCSVSchemaError - with pytest.raises(expected_error): - flow.process_date(DateStr(validation_details["date"])) - error = accumulator.errors[0] - assert isinstance(error, expected_error) - for key, value in validation_details.items(): - assert error.context[key] == value + for error_msg in error_messages: + # Set up context + ctx = ArgentinaErrorContext() + ctx.date = "2024-01-01" + + # Create and verify error + error = InvalidStructureZIPError(error_msg) + + # Verify basic error properties + assert error.code == "ARG-100" + assert error.message == "Invalid ZIP file structure - cannot extract files" + assert error.responsibility == AccountableRole.DATA_PROVIDER + + # Verify context + assert "date" in error.context + assert error.context["date"] == "2024-01-01" + assert "error" in error.context + assert error.context["error"] == error_msg + + # Test without context date + 
error = InvalidStructureZIPError("test error") + assert "date" in error.context + assert error.context["date"] is None + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-date-format-scenarios") + def test_date_format_error_scenarios(self, error_context_fixture: Any): + """Test various scenarios for InvalidDateFormatError.""" + # Test with different invalid date formats + invalid_dates = [ + "2024/01/01", # Wrong separator + "24-01-01", # Wrong year format + "2024-1-1", # Missing padding + "2024-13-01", # Invalid month + "2024-01-32", # Invalid day + "01-01-2024", # Wrong order + "2024-01", # Incomplete + "not-a-date", # Invalid format + ] -# -------------------------------------------------- -# 5. Error Recovery Tests -# -------------------------------------------------- -def test_partial_processing_recovery(valid_sepa_data, create_zip_file): - """Test system's ability to process valid records when others fail""" - zip_content = create_zip_file("partial_valid", df=valid_sepa_data) - - def mock_reader(): - yield zip_content + for invalid_date in invalid_dates: + # Create and verify error + error = InvalidDateFormatError(invalid_date) + + # Verify basic error properties + assert error.code == "ARG-101" + assert error.message == f"Invalid date format: {invalid_date} - must be YYYY-MM-DD" + assert error.responsibility == AccountableRole.SYSTEM + + # Verify context + assert "invalid_date" in error.context + assert error.context["invalid_date"] == invalid_date + + +class TestErrorHandlingSystem: + """Integration tests for the error handling and accumulation system.""" + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-error-accumulation") + def test_error_accumulation_and_isolation(self, error_context_fixture: Any): + """Test error accumulation, ordering, and isolation.""" + # Test concurrent error accumulation + with error_collection() as flow1_errors: + with error_collection() as flow2_errors: + # Add errors to different flows 
+ flow1_errors.add_error(EmptyCategoryMapError(url="map1")) + # Set up context for DateMismatchError + ctx = ArgentinaErrorContext() + ctx.date = "2024-01-01" + flow2_errors.add_error(DateMismatchError(internal_date="2024-01-02")) + + # Verify flow2 errors + assert len(flow2_errors.errors) == 1 + assert isinstance(flow2_errors.errors[0], DateMismatchError) + assert flow2_errors.errors[0].context["external_date"] == "2024-01-01" + + # Verify flow1 errors + assert len(flow1_errors.errors) == 1 + assert isinstance(flow1_errors.errors[0], EmptyCategoryMapError) + assert flow1_errors.errors[0].context["url"] == "map1" + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-error-recovery") + def test_error_recovery_mechanisms( + self, + mock_zip_factory: Callable[..., bytes], + valid_sepa_data: SepaDF, + error_context_fixture: Any, + ): + """Test system recovery from various error conditions.""" + # Test partial processing recovery + zip_content = mock_zip_factory("partial_valid", df=valid_sepa_data) + + def mock_reader() -> Generator[bytes, None, None]: + yield zip_content + + result = process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_recovery") + + # Verify error capture + assert len(error_context_fixture.errors) > 0 + error = error_context_fixture.errors[0] + assert isinstance(error, InvalidProductsError) + + # Verify successful partial processing + assert len(result) > 0 + assert all(pd.notna(result["id_producto"].astype(str))) + + @pytest.mark.usefixtures("prefect_test_fixture") + @flow(name="test-complex-errors") + def test_complex_error_scenarios( + self, + mock_zip_factory: Callable[..., bytes], + error_context_fixture: Any, + ): + """Test handling of complex error scenarios with multiple error types.""" + # Create data with multiple error types + zip_content = mock_zip_factory("mixed_errors") + + def mock_reader() -> Generator[bytes, None, None]: + yield zip_content - with error_collection() as accumulator: - result = 
process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_partial") - - # Verify errors were captured - assert len(accumulator.errors) > 0 - error = accumulator.errors[0] - assert isinstance(error, MissingProductIDError) - assert error.context["missing_count"] == 3 - assert error.context["date"] == "2024-01-01" - assert error.context["store_id"] == "test_partial" - - # Verify valid records were processed - assert len(result) == 2 # Last 2 records should be valid - assert all(pd.notna(result["id_producto"])) - -def test_error_recovery_with_retries(): - """Test error recovery with retry mechanism""" - retry_attempts = 0 - max_retries = 3 - - def failing_operation(): - nonlocal retry_attempts - retry_attempts += 1 - if retry_attempts < max_retries: - # InvalidCSVSchemaError takes date and store_id - raise InvalidCSVSchemaError( - date=DateStr("2024-01-01"), - store_id="test_store" - ) - return True # Succeed on final attempt - - with error_collection() as accumulator: - # Simulate retry logic - success = False - for _ in range(max_retries): - try: - success = failing_operation() - if success: - break - except InvalidCSVSchemaError as e: # Catch specific error type - accumulator.add_error(e) - - assert success # Operation eventually succeeded - assert len(accumulator.errors) == max_retries - 1 # Errors from failed attempts - assert all(isinstance(e, InvalidCSVSchemaError) for e in accumulator.errors) + try: + process_sepa_zip(mock_reader(), DateStr("2024-01-01"), "test_complex") + pytest.fail("Expected InvalidCSVSchemaError to be raised") + except InvalidCSVSchemaError as e: + assert e.code == "ARG-102" + assert e.responsibility == AccountableRole.DATA_PROVIDER + assert len(error_context_fixture.errors) == 1 + assert error_context_fixture.errors[0] == e -# -------------------------------------------------- -# 6. 
Error Accumulation Tests -# -------------------------------------------------- -def test_error_context_isolation(): - """Verify error context isolation between parallel flows""" - with error_collection() as flow1_errors: - with error_collection() as flow2_errors: - flow1_errors.add_error(EmptyCategoryMapError(url="map1")) - flow2_errors.add_error(DateMismatchError( - external_date="2024-01-01", - internal_date="2024-01-02" - )) - - assert len(flow2_errors.errors) == 1 - assert isinstance(flow2_errors.errors[0], DateMismatchError) - assert flow2_errors.errors[0].context["external_date"] == "2024-01-01" - assert flow2_errors.errors[0].context["internal_date"] == "2024-01-02" - - assert len(flow1_errors.errors) == 1 - assert isinstance(flow1_errors.errors[0], EmptyCategoryMapError) - assert flow1_errors.errors[0].context["url"] == "map1" - -def test_error_accumulation_resilience(): - """Test error accumulator's ability to handle multiple errors""" - with error_collection() as accumulator: - errors = [ - InvalidCSVSchemaError(date=DateStr("2024-01-01"), store_id="STORE1"), - UncategorizedProductsError(count=3, date=DateStr("2024-01-01"), store_id="STORE1"), - MissingProductIDError(count=2, date=DateStr("2024-01-01"), store_id="STORE1"), - ] - - for error in errors: - accumulator.add_error(error) - - assert len(accumulator.errors) == len(errors) - for original, captured in zip(errors, accumulator.errors): - assert type(original) == type(captured) - assert original.code == captured.code - assert original.responsibility == captured.responsibility - assert original.context == captured.context - -def test_error_accumulation_order(): - """Verify error accumulation preserves order and priority""" - with error_collection() as accumulator: - # Add errors with different priorities - errors = [ - (InvalidStructureZIPError(context={"source": "test", "date": "2024-01-01"}), 1), # High priority - (MissingProductIDError(count=2, date=DateStr("2024-01-01"), store_id="test"), 3), # Low 
priority - (DateMismatchError(external_date="2024-01-01", internal_date="2024-01-02"), 2), # Medium priority - ] - - for error, _ in errors: - accumulator.add_error(error) - - # Verify errors are stored in order of addition - for (original, _), captured in zip(errors, accumulator.errors): - assert type(original) == type(captured) - assert original.code == captured.code - assert original.context == captured.context - -def test_error_accumulation(prefect_test_context, create_zip_file, mock_s3_block): - """Test that errors are properly accumulated during preprocessing""" - # Create test data with missing product IDs - test_data = pd.DataFrame({ - 'date': ['2024-01-01'] * 3, - 'id_producto': [None, None, 'P3'], - 'productos_descripcion': ['Product 1', 'Product 2', 'Product 3'], - 'precio': [100.0, 200.0, 300.0] - }) - - # Create a temporary CSV file with test data - with tempfile.NamedTemporaryFile(suffix='.csv', mode='w', delete=False) as f: - test_data.to_csv(f, index=False) - temp_csv = f.name - - try: - # Run preprocessing flow - flow = PreprocessFlow( - product_category_map_url="mock://map", - s3_block=mock_s3_block - ) - with pytest.raises(MissingProductIDError) as exc_info: - flow.process_date(DateStr("2024-01-01")) - - error = exc_info.value - assert error.code == "ARG-202" # Fixed error code - assert error.context["missing_count"] == 2 - assert error.context["date"] == "2024-01-01" - finally: - # Clean up - os.unlink(temp_csv) if __name__ == "__main__": - pytest.main(["-v", __file__]) \ No newline at end of file + pytest.main(["-v", __file__]) diff --git a/tests/argentina/errors/test_global_context_scope.py b/tests/argentina/errors/test_global_context_scope.py new file mode 100644 index 0000000..91cb4e8 --- /dev/null +++ b/tests/argentina/errors/test_global_context_scope.py @@ -0,0 +1,156 @@ +""" +Tests to verify the scoping of the global error context managed +by tsn_adapters/tasks/argentina/errors/context.py. 
+ +We cover: + - Synchronous tasks that simply read the global context. + - Tasks that add a local value via the error_context() context manager. + - Asynchronous tasks that return the current context. + - Tasks submitted via .submit(). + - A task that shows that a context set inside a block is rolled back after the block. + - Concurrent tasks with different local context values to ensure isolation. +""" + +import asyncio +from typing import Any + +from prefect import flow, task + +from tsn_adapters.tasks.argentina.errors.context import ( + clear_error_context, + error_context, + get_error_context, + set_error_context, +) + + +@task +def get_context() -> dict[str, Any]: + """Return the current global error context.""" + return get_error_context() + + +@task +def set_context_in_task(key: str, value: str) -> dict[str, Any]: + """ + Set an additional context using error_context and return the merged context. + This should merge any global context already present. + """ + with error_context(**{key: value}): + return get_error_context() + + +@task +async def async_get_context() -> dict[str, Any]: + """Asynchronous task returning the current global error context after a short delay.""" + await asyncio.sleep(0.1) + return get_error_context() + + +@task +def context_with_and_without(key: str, value: str) -> tuple[dict[str, Any], dict[str, Any]]: + """ + Within a local error_context block, record the context then, + afterwards, record the global context again to ensure the local value is not leaked. + """ + with error_context(**{key: value}): + inside = get_error_context() + after = get_error_context() + return inside, after + + +@flow +def global_context_flow() -> dict[str, Any]: + """ + Flow that exercises various types of tasks to validate that the global error context is properly applied. + + First, we clear any previous context and set a global key. + Then we run: + - A normal task that simply returns the global context. 
+ - A task that adds a key inside an error_context block. + - An async task. + - A task launched via the .submit() method. + - A task that shows a context value being set only within a with‐block. + """ + # Reset and then set a global value: + clear_error_context() + set_error_context("global", "global_value") + + # 1. Normal task reading the global context: + normal = get_context() + + # 2. Task that locally adds its own key: + with_ctx = set_context_in_task("task_key", "task_value") + + # 3. Asynchronous task: + async_result = asyncio.run(async_get_context()) + + # 4. Task submitted via .submit() + submit_future = set_context_in_task.submit("submit_key", "submit_value") + submit_result = submit_future.result() + + # 5. Task that tests a context block's scoping: + inside_after = context_with_and_without("block_key", "block_value") + + return { + "normal": normal, + "with_ctx": with_ctx, + "async": async_result, + "submit": submit_result, + "inside_after": inside_after, + } + + +@flow +def concurrent_context_flow() -> tuple[dict[str, Any], dict[str, Any]]: + """ + Flow that concurrently runs two tasks, each setting a different local error context. + This test ensures that concurrently submitted tasks do not leak context across each other. + """ + clear_error_context() + # Launch two tasks concurrently, each using error_context to add a distinct key: + future1 = set_context_in_task.submit("concurrent_key", "value1") + future2 = set_context_in_task.submit("concurrent_key", "value2") + result1 = future1.result() + result2 = future2.result() + return result1, result2 + + +def test_global_context_scope(prefect_test_fixture: Any): + """ + Validate that: + - A normal task sees the global context. + - A task wrapped with error_context correctly merges its key with the global context. + - Asynchronous tasks and submitted tasks see the expected global values. + - Context set within a with-block does not persist after the block. 
+ """ + result = global_context_flow() + # Check that the normal task has the global value: + assert result["normal"].get("global") == "global_value" + + # The task with additional context should include both the global and the local key: + assert result["with_ctx"].get("global") == "global_value" + assert result["with_ctx"].get("task_key") == "task_value" + + # Async and submitted tasks should see only the global context (unless they set their own): + assert result["async"].get("global") == "global_value" + assert result["submit"].get("global") == "global_value" + assert result["submit"].get("submit_key") == "submit_value" + + # For the context_with_and_without task: + inside, after = result["inside_after"] + assert inside.get("block_key") == "block_value" + # After leaving the block, "block_key" should no longer be present: + assert "block_key" not in after + + +def test_concurrent_context_isolation(prefect_test_fixture: Any): + """ + Validate that concurrently submitted tasks each carry their local error context independently. + """ + result1, result2 = concurrent_context_flow() + # Each task should have its own "concurrent_key" with distinct values. + assert result1.get("concurrent_key") in ("value1", "value2") + assert result2.get("concurrent_key") in ("value1", "value2") + # The two results should not have the same value. 
+ assert result1.get("concurrent_key") != result2.get("concurrent_key") diff --git a/tests/argentina/flows/test_preprocess_flow.py b/tests/argentina/flows/test_preprocess_flow.py index efcaf09..16b2912 100644 --- a/tests/argentina/flows/test_preprocess_flow.py +++ b/tests/argentina/flows/test_preprocess_flow.py @@ -166,7 +166,9 @@ def mock_preprocess_flow(self) -> PreprocessFlow: flow._create_summary = cast(Any, MagicMock()) return flow - def test_process_date_happy_path(self, mock_preprocess_flow: PreprocessFlow, mocker: MockerFixture, prefect_test_fixture): + def test_process_date_happy_path( + self, mock_preprocess_flow: PreprocessFlow, mocker: MockerFixture, prefect_test_fixture + ): """Test process_date with valid data.""" # Mock dependencies mock_raw_data = cast(Any, MagicMock(spec=pd.DataFrame)) @@ -189,7 +191,7 @@ def test_process_date_happy_path(self, mock_preprocess_flow: PreprocessFlow, moc Any, mocker.patch( "tsn_adapters.tasks.argentina.flows.preprocess_flow.process_raw_data", - return_value=MagicMock(result=lambda: (mock_processed_data, mock_uncategorized)) + return_value=MagicMock(result=lambda: (mock_processed_data, mock_uncategorized)), ), ) @@ -200,10 +202,7 @@ def test_process_date_happy_path(self, mock_preprocess_flow: PreprocessFlow, moc raw_provider.get_raw_data_for.assert_called_once_with(TEST_DATE) mock_task_load_category_map.assert_called_once_with(url=mock_preprocess_flow.category_map_url) mock_process_raw_data.assert_called_once_with( - raw_data=mock_raw_data, - category_map_df=mock_category_map, - date=TEST_DATE, - return_state=True + raw_data=mock_raw_data, category_map_df=mock_category_map, date=TEST_DATE, return_state=True ) processed_provider = cast(Any, mock_preprocess_flow.processed_provider) processed_provider.save_processed_data.assert_called_once_with( @@ -443,7 +442,7 @@ def test_create_summary_empty_data(self, mock_preprocess_flow: PreprocessFlow, m class TestPreprocessFlowTopLevel: """Tests for the top-level preprocess_flow 
function.""" - + pytestmark = pytest.mark.usefixtures("prefect_test_fixture") def test_preprocess_flow_happy_path(self, mocker: MockerFixture): diff --git a/tests/blocks/test_split_and_insert_records.py b/tests/blocks/test_split_and_insert_records.py index 5e8a6bf..bc017b8 100644 --- a/tests/blocks/test_split_and_insert_records.py +++ b/tests/blocks/test_split_and_insert_records.py @@ -170,56 +170,52 @@ def test_split_and_insert_with_failures(self, tn_block: TNAccessBlock, sample_re def test_filter_deployed_streams(self, tn_block: TNAccessBlock, test_stream_ids: list[str]): """Test that filter_deployed_streams correctly filters out non-deployed streams.""" base_timestamp = int(datetime.now().timestamp()) - + # Create records for both deployed and non-deployed streams deployed_records = [] for stream_id in test_stream_ids: - deployed_records.append({ - "stream_id": stream_id, - "date": base_timestamp, - "value": 100.0, - }) - + deployed_records.append( + { + "stream_id": stream_id, + "date": base_timestamp, + "value": 100.0, + } + ) + non_deployed_stream_id = generate_stream_id("non_deployed_test_stream") - non_deployed_records = [{ - "stream_id": non_deployed_stream_id, - "date": base_timestamp, - "value": 200.0, - }] - + non_deployed_records = [ + { + "stream_id": non_deployed_stream_id, + "date": base_timestamp, + "value": 200.0, + } + ] + # Combine all records all_records = pd.DataFrame(deployed_records + non_deployed_records) all_records = DataFrame[TnDataRowModel](all_records) - + # Test with filter_deployed_streams=True results = task_split_and_insert_records( - tn_block, - all_records, - is_unix=True, - filter_deployed_streams=True, - wait=True + tn_block, all_records, is_unix=True, filter_deployed_streams=True, wait=True ) - + assert results is not None # Should have successfully processed only the deployed streams assert len(results["success_tx_hashes"]) > 0 # No failed records since non-deployed streams were filtered out assert 
len(results["failed_records"]) == 0 - + # Verify only deployed streams were processed for stream_id in test_stream_ids: records = tn_block.read_records(stream_id, is_unix=True, date_from=base_timestamp) assert len(records) == 1 - + # Test with filter_deployed_streams=False results_no_filter = task_split_and_insert_records( - tn_block, - all_records, - is_unix=True, - filter_deployed_streams=False, - wait=True + tn_block, all_records, is_unix=True, filter_deployed_streams=False, wait=True ) - + assert results_no_filter is not None # Should have some failed records (the non-deployed stream) assert len(results_no_filter["failed_records"]) > 0 diff --git a/tests/blocks/test_stream_deployment_initialization.py b/tests/blocks/test_stream_deployment_initialization.py index a9d0c05..c027b48 100644 --- a/tests/blocks/test_stream_deployment_initialization.py +++ b/tests/blocks/test_stream_deployment_initialization.py @@ -1,9 +1,10 @@ -import pytest from datetime import datetime -from tsn_adapters.blocks.tn_access import TNAccessBlock +import pytest from trufnetwork_sdk_py.utils import generate_stream_id +from tsn_adapters.blocks.tn_access import TNAccessBlock + @pytest.fixture def test_stream_id() -> str: @@ -33,15 +34,15 @@ def test_init_stream(tn_block: TNAccessBlock, test_stream_id: str): # First deploy the stream deploy_tx = tn_block.deploy_stream(test_stream_id, wait=True) assert deploy_tx is not None, "Deploy transaction hash should be returned" - + # Initialize the stream init_tx = tn_block.init_stream(test_stream_id, wait=True) assert init_tx is not None, "Init transaction hash should be returned" - + # After initialization, the stream should be empty record = tn_block.get_first_record(test_stream_id) assert record is None, "The stream should be empty after initialization" - + # Clean up: destroy the stream destroy_tx = tn_block.destroy_stream(test_stream_id, wait=True) - assert destroy_tx is not None, "Stream destruction should return a transaction hash" \ No 
newline at end of file + assert destroy_tx is not None, "Stream destruction should return a transaction hash" diff --git a/tests/blocks/test_tn_network_error.py b/tests/blocks/test_tn_network_error.py index 204bfce..3835830 100644 --- a/tests/blocks/test_tn_network_error.py +++ b/tests/blocks/test_tn_network_error.py @@ -1,8 +1,10 @@ -from datetime import datetime from typing import Any, Callable +from prefect.client.schemas.objects import StateDetails, StateType +from prefect.types._datetime import DateTime from pydantic import SecretStr import pytest +from trufnetwork_sdk_py.client import TNClient from tsn_adapters.blocks.tn_access import ( SafeTNClientProxy, @@ -11,18 +13,12 @@ tn_special_retry_condition, ) -from prefect.client.schemas.objects import StateDetails, StateType -from prefect.types._datetime import DateTime - -from trufnetwork_sdk_py.client import TNClient - # --- Dummy TN Client to simulate behavior --- class DummyTNClient(TNClient): def __init__(self): pass - def get_first_record(self, *args: Any, **kwargs: Any) -> dict[str, str | float] | None: # Simulate a successful call return {"dummy_record": 1.0} @@ -77,7 +73,7 @@ def test_is_tn_node_network_error(): # For our tests we need dummy Task, TaskRun, and State objects. # We use minimal dummy implementations. from prefect import Task # Ensure that the correct Task object is imported per your Prefect version. -from prefect.client.schemas.objects import TaskRun, State +from prefect.client.schemas.objects import State, TaskRun class DummyTask(Task[Any, Any]): @@ -172,7 +168,7 @@ def test_tn_access_block_network_error(): def test_real_tn_client_unexistent_provider(): """ Test the real TN client with an unexistent provider. - + This test creates a TNAccessBlock using an invalid provider URL and attempts to call get_first_record. The safe client is expected to detect the underlying network error and re-raise it as a TNNodeNetworkError. 
@@ -184,7 +180,7 @@ def test_real_tn_client_unexistent_provider(): block = TNAccessBlock( tn_provider=invalid_provider, tn_private_key=SecretStr("0000000000000000000000000000000000000000000000000000000000000012"), - helper_contract_name="dummy" + helper_contract_name="dummy", ) _ = block.get_first_record("dummy_stream") # Check that the error message indicates a connection issue. diff --git a/tests/fixtures/test_trufnetwork.py b/tests/fixtures/test_trufnetwork.py index 45b4615..f3e49b9 100644 --- a/tests/fixtures/test_trufnetwork.py +++ b/tests/fixtures/test_trufnetwork.py @@ -30,11 +30,12 @@ class ContainerSpec: name: str image: str tmpfs_path: Optional[str] = None - env_vars: ( list[str] ) = field(default_factory=list) + env_vars: list[str] = field(default_factory=list) ports: dict[str, str] = field(default_factory=dict) entrypoint: Optional[str] = None args: list[str] = field(default_factory=list) + # Container specifications POSTGRES_CONTAINER = ContainerSpec( name="test-kwil-postgres", @@ -389,7 +390,8 @@ def test_tn_provider_fixture(self, tn_provider: TrufNetworkProvider): assert tn_provider.api_endpoint.startswith("http://") assert tn_provider.get_provider() is tn_provider -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def term_handler(): """ Fixture to transform SIGTERM into SIGINT. This permit us to gracefully stop the suite uppon SIGTERM. 
@@ -398,7 +400,8 @@ def term_handler(): yield signal.signal(signal.SIGTERM, orig) -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def disable_prefect_retries(): from importlib import import_module from unittest.mock import patch @@ -407,40 +410,40 @@ def disable_prefect_retries(): # Patch task retries by modifying the task options directly def patch_task_options(task_fn: Any) -> Any: - if hasattr(task_fn, 'with_options'): + if hasattr(task_fn, "with_options"): return task_fn.with_options(retries=0, cache_key_fn=None) return task_fn # All tasks with retries and their import paths tasks_to_patch = [ # FMP Historical Flow - 'tsn_adapters.flows.fmp.historical_flow.fetch_historical_data', + "tsn_adapters.flows.fmp.historical_flow.fetch_historical_data", # Stream Deploy Flow - 'tsn_adapters.flows.stream_deploy_flow.check_and_deploy_stream', + "tsn_adapters.flows.stream_deploy_flow.check_and_deploy_stream", # Primitive Source Descriptor - 'tsn_adapters.blocks.primitive_source_descriptor.get_descriptor_from_url', - 'tsn_adapters.blocks.primitive_source_descriptor.get_descriptor_from_github', + "tsn_adapters.blocks.primitive_source_descriptor.get_descriptor_from_url", + "tsn_adapters.blocks.primitive_source_descriptor.get_descriptor_from_github", # FMP Real Time Flow - 'tsn_adapters.flows.fmp.real_time_flow.fetch_quotes_for_batch', + "tsn_adapters.flows.fmp.real_time_flow.fetch_quotes_for_batch", # Argentina Task Wrappers - 'tsn_adapters.tasks.argentina.task_wrappers.task_create_stream_fetcher', - 'tsn_adapters.tasks.argentina.task_wrappers.task_get_streams', - 'tsn_adapters.tasks.argentina.task_wrappers.task_create_sepa_provider', - 'tsn_adapters.tasks.argentina.task_wrappers.task_get_data_for_date', - 'tsn_adapters.tasks.argentina.task_wrappers.task_get_latest_records', - 'tsn_adapters.tasks.argentina.task_wrappers.task_load_category_map', + "tsn_adapters.tasks.argentina.task_wrappers.task_create_stream_fetcher", + 
"tsn_adapters.tasks.argentina.task_wrappers.task_get_streams", + "tsn_adapters.tasks.argentina.task_wrappers.task_create_sepa_provider", + "tsn_adapters.tasks.argentina.task_wrappers.task_get_data_for_date", + "tsn_adapters.tasks.argentina.task_wrappers.task_get_latest_records", + "tsn_adapters.tasks.argentina.task_wrappers.task_load_category_map", # TN Access - 'tsn_adapters.blocks.tn_access.task_wait_for_tx', - 'tsn_adapters.blocks.tn_access.task_insert_and_wait_for_tx', - 'tsn_adapters.blocks.tn_access.task_insert_unix_and_wait_for_tx', - 'tsn_adapters.blocks.tn_access._task_only_batch_insert_records', - 'tsn_adapters.blocks.tn_access.task_split_and_insert_records' + "tsn_adapters.blocks.tn_access.task_wait_for_tx", + "tsn_adapters.blocks.tn_access.task_insert_and_wait_for_tx", + "tsn_adapters.blocks.tn_access.task_insert_unix_and_wait_for_tx", + "tsn_adapters.blocks.tn_access._task_only_batch_insert_records", + "tsn_adapters.blocks.tn_access.task_split_and_insert_records", ] - with patch('prefect.task', side_effect=original_task) as mock_task: + with patch("prefect.task", side_effect=original_task) as mock_task: for import_path in tasks_to_patch: # Split the import path into module path and attribute name - module_path, attr_name = import_path.rsplit('.', 1) + module_path, attr_name = import_path.rsplit(".", 1) # Import the module and get the task function module = import_module(module_path) task_fn = getattr(module, attr_name) @@ -448,6 +451,7 @@ def patch_task_options(task_fn: Any) -> Any: mock_task.return_value = patch_task_options(task_fn) patch(import_path, new=patch_task_options(task_fn)).start() + @pytest.fixture(scope="session", autouse=False) def prefect_test_fixture(disable_prefect_retries: Any): with prefect_test_harness(server_startup_timeout=120): diff --git a/tests/flows/test_stream_deploy_flow.py b/tests/flows/test_stream_deploy_flow.py index 4eb8b34..5a7170a 100644 --- a/tests/flows/test_stream_deploy_flow.py +++ 
b/tests/flows/test_stream_deploy_flow.py @@ -18,6 +18,7 @@ class TestPrimitiveSourcesDescriptor(PrimitiveSourcesDescriptorBlock): """Test implementation of PrimitiveSourcesDescriptorBlock that can generate a configurable number of streams.""" + model_config = ConfigDict(ignored_types=(object,)) num_streams: int = 3 @@ -32,6 +33,7 @@ def get_descriptor(self) -> DataFrame[PrimitiveSourceDataModel]: class TestTNClient(tn_client.TNClient): """Test implementation of TNClient that tracks deployed streams and can be initialized with existing streams.""" + def __init__(self, existing_streams: set[str] | None = None): # We don't call super().__init__ to avoid real client initialization if existing_streams is None: @@ -66,6 +68,7 @@ def wait_for_tx(self, tx_hash: str) -> None: class TestTNAccessBlock(TNAccessBlock): """Test implementation of TNAccessBlock that uses our TestTNClient.""" + _test_client: TestTNClient model_config = ConfigDict(ignored_types=(object,)) @@ -94,14 +97,10 @@ def primitive_descriptor(request: FixtureRequest) -> TestPrimitiveSourcesDescrip @pytest.mark.usefixtures("prefect_test_fixture") def test_deploy_streams_flow_all_new( - tn_access_block: TestTNAccessBlock, - primitive_descriptor: TestPrimitiveSourcesDescriptor + tn_access_block: TestTNAccessBlock, primitive_descriptor: TestPrimitiveSourcesDescriptor ) -> None: """Test that all streams are deployed when none exist.""" - results = deploy_streams_flow( - psd_block=primitive_descriptor, - tna_block=tn_access_block - ) + results = deploy_streams_flow(psd_block=primitive_descriptor, tna_block=tn_access_block) # All three streams should be deployed assert results["deployed_count"] == 3 @@ -119,14 +118,11 @@ def test_deploy_streams_flow_with_existing() -> None: # Create TNAccessBlock with some existing streams existing_streams = {"stream_0", "stream_2"} # First and last streams exist tn_access_block = TestTNAccessBlock(existing_streams=existing_streams) - + # Create descriptor with 3 streams 
primitive_descriptor = TestPrimitiveSourcesDescriptor(num_streams=3) - results = deploy_streams_flow( - psd_block=primitive_descriptor, - tna_block=tn_access_block - ) + results = deploy_streams_flow(psd_block=primitive_descriptor, tna_block=tn_access_block) # Only stream_1 should be deployed, others should be skipped assert results["deployed_count"] == 1 diff --git a/tests/fmp/test_fmpblock_integration.py b/tests/fmp/test_fmpblock_integration.py index 3927dd4..df059eb 100644 --- a/tests/fmp/test_fmpblock_integration.py +++ b/tests/fmp/test_fmpblock_integration.py @@ -1,14 +1,13 @@ -import os -import re from datetime import datetime -import time +import os from typing import Any +from pandera.typing import DataFrame from pydantic import SecretStr import pytest -from tsn_adapters.blocks.fmp import FMPBlock, EODData -from pandera.typing import DataFrame +from tsn_adapters.blocks.fmp import EODData, FMPBlock + def is_iso_date(date_str: str) -> bool: """Check if a string is in ISO date format (YYYY-MM-DD).""" @@ -18,20 +17,25 @@ def is_iso_date(date_str: str) -> bool: except ValueError: return False + def iso_to_unix_timestamp(iso_date: str) -> int: """Convert ISO date string to UNIX timestamp in seconds.""" return int(datetime.strptime(iso_date, "%Y-%m-%d").timestamp()) + def verify_date_formats(df: DataFrame[EODData]): """Verify that dates in the DataFrame are in ISO format and match their UNIX timestamp equivalents.""" # Check ISO format - assert all(is_iso_date(date) for date in df['date']), "All dates should be in ISO format (YYYY-MM-DD)" - + assert all(is_iso_date(date) for date in df["date"]), "All dates should be in ISO format (YYYY-MM-DD)" + # If we have timestamps, verify they match the ISO dates - if 'timestamp' in df.columns: - for iso_date, timestamp in zip(df['date'], df['timestamp']): + if "timestamp" in df.columns: + for iso_date, timestamp in zip(df["date"], df["timestamp"]): expected_timestamp = iso_to_unix_timestamp(iso_date) - assert timestamp == 
expected_timestamp, f"Timestamp mismatch for {iso_date}: expected {expected_timestamp}, got {timestamp}" + assert ( + timestamp == expected_timestamp + ), f"Timestamp mismatch for {iso_date}: expected {expected_timestamp}, got {timestamp}" + @pytest.fixture def fmp_block(prefect_test_fixture: Any): @@ -50,11 +54,11 @@ def test_get_active_tickers(fmp_block: FMPBlock): # Ensuring that we got a response that contains at least one active ticker assert df is not None, "Expected non-None result" assert len(df) > 0, "Expected non-empty active tickers list" - + # Validate schema requirements - assert all(isinstance(symbol, str) for symbol in df['symbol']), "All symbols should be strings" - assert all(isinstance(name, (str, type(None))) for name in df['name']), "All names should be strings or None" - + assert all(isinstance(symbol, str) for symbol in df["symbol"]), "All symbols should be strings" + assert all(isinstance(name, (str, type(None))) for name in df["name"]), "All names should be strings or None" + print(f"Number of active tickers: {len(df)}") @@ -66,12 +70,12 @@ def test_get_batch_quote(fmp_block: FMPBlock): # Ensuring that we received batch quotes for the provided symbols assert df is not None, "Expected non-None result" assert len(df) > 0, "Expected non-empty batch quote result" - + # Validate schema requirements - assert all(isinstance(symbol, str) for symbol in df['symbol']), "All symbols should be strings" - assert all(isinstance(price, (int, float)) for price in df['price']), "All prices should be numeric" - assert all(isinstance(volume, int) for volume in df['volume']), "All volumes should be integers" - + assert all(isinstance(symbol, str) for symbol in df["symbol"]), "All symbols should be strings" + assert all(isinstance(price, (int, float)) for price in df["price"]), "All prices should be numeric" + assert all(isinstance(volume, int) for volume in df["volume"]), "All volumes should be integers" + print(f"Batch quote for {symbols}:") print(df) @@ -84,20 
+88,20 @@ def test_get_historical_eod_data(fmp_block: FMPBlock): df = fmp_block.get_historical_eod_data("AAPL", start_date=start_date, end_date=end_date) assert df is not None, "Expected non-None result" assert len(df) == 21, "Expected 21 rows of historical EOD data" - + # Validate schema requirements - assert all(isinstance(symbol, str) for symbol in df['symbol']), "All symbols should be strings" - assert all(is_iso_date(date) for date in df['date']), "All dates should be in ISO format" - assert all(isinstance(price, (int, float)) for price in df['price']), "All prices should be numeric" - assert all(isinstance(volume, int) for volume in df['volume']), "All volumes should be integers" - + assert all(isinstance(symbol, str) for symbol in df["symbol"]), "All symbols should be strings" + assert all(is_iso_date(date) for date in df["date"]), "All dates should be in ISO format" + assert all(isinstance(price, (int, float)) for price in df["price"]), "All prices should be numeric" + assert all(isinstance(volume, int) for volume in df["volume"]), "All volumes should be integers" + # Validate date range - assert min(df['date']) >= start_date, "Data should not be before start_date" - assert max(df['date']) <= end_date, "Data should not be after end_date" - + assert min(df["date"]) >= start_date, "Data should not be before start_date" + assert max(df["date"]) <= end_date, "Data should not be after end_date" + # Verify date formats and timestamp conversions verify_date_formats(df) - + print(f"Historical EOD data for AAPL from {start_date} to {end_date}:") print(df) diff --git a/tests/fmp/test_historical_flow.py b/tests/fmp/test_historical_flow.py index 7155527..3d0d1e1 100644 --- a/tests/fmp/test_historical_flow.py +++ b/tests/fmp/test_historical_flow.py @@ -188,6 +188,7 @@ def get_client(self) -> TNClient: """Mock to prevent real client creation.""" return None # type: ignore + class ErrorFMPBlock(FMPBlock): def get_historical_eod_data( self, symbol: str, start_date: str | 
None = None, end_date: str | None = None @@ -227,8 +228,6 @@ def fake_tn_block() -> FakeTNAccessBlock: def error_fmp_block(): """Fixture for error-raising FMP block.""" - - return ErrorFMPBlock(api_key=SecretStr("fake")) diff --git a/tests/fmp/test_real_time_flow.py b/tests/fmp/test_real_time_flow.py index 7e67480..e69140f 100644 --- a/tests/fmp/test_real_time_flow.py +++ b/tests/fmp/test_real_time_flow.py @@ -175,7 +175,9 @@ def test_real_time_flow_api_error(self, error_fmp_block: ErrorFMPBlock): class TestProcessDataAndDescriptor: """Tests for data processing and descriptor handling.""" - def test_process_data(self, sample_quotes_df: DataFrame[BatchQuoteShort], sample_descriptor_df: DataFrame[PrimitiveSourceDataModel]): + def test_process_data( + self, sample_quotes_df: DataFrame[BatchQuoteShort], sample_descriptor_df: DataFrame[PrimitiveSourceDataModel] + ): """Test processing quote data into TN format.""" # Convert descriptor_df to plain pandas DataFrame as expected by process_data descriptor_df = pd.DataFrame(sample_descriptor_df)