diff --git a/Makefile b/Makefile
index 9f95fde..8c2f29c 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,6 @@
 # Format code
 format:
 	uv run ruff format src/ tests/
-
 # Lint and fix issues
 lint:
 	uv run ruff check src/ tests/ --fix
@@ -18,6 +17,6 @@ lint-check-unsafe:
 
 # Run tests
 test:
-	uv run pytest tests/
+	AWE_ENV=TEST uv run pytest tests/
 
 # Install the package
diff --git a/README.md b/README.md
index 30083fa..6e42288 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,13 @@
-# AIWebExplorer
+[![WebExplorer CI/CD](https://github.com/thinktwiceco/webexplorer/actions/workflows/ci.yml/badge.svg)](https://github.com/thinktwiceco/webexplorer/actions/workflows/ci.yml)
+
+[![New Version Deployed](https://github.com/thinktwiceco/webexplorer/actions/workflows/version.yml/badge.svg?branch=master)](https://github.com/thinktwiceco/webexplorer/actions/workflows/version.yml)
+
+
+# AIWebExplorer ๐ŸŒ
 
 An agent for agents to explore the web
 
-## Installation
+## ๐Ÿ“ฆ Installation
 
 This project uses `uv` for dependency management.
 
@@ -18,7 +23,7 @@ uv sync
 source .venv/bin/activate
 ```
 
-## Development
+## ๐Ÿ› ๏ธ Development
 
 ```bash
 # Run linting
@@ -31,34 +36,92 @@ uv run ruff format .
 uv run ruff check --select I
 ```
 
-## Environment Variables
+## โš™๏ธ Environment Variables
 
 Copy `.env.example` to `.env` and adjust the values:
 
 ```env
-# Environment setting (DEV, STAGING, PROD, CI)
+# Environment setting (DEV, TEST, CI, PROD)
 AWE_ENV=DEV
 
 # Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
 AWE_LOG_LEVEL=INFO
 
-# Optional: LLM Provider configuration
-# Options: openai, togetherai, deepseek
+# LLM Provider configuration (REQUIRED)
+# Supported providers: openai, togetherai, deepseek
AWE_LLM_PROVIDER=openai
 
-# Optional: LLM Model to use
-# Example: gpt-4, gpt-3.5-turbo, etc.
+# LLM Model to use (REQUIRED)
+# Examples:
+#   - OpenAI: gpt-4, gpt-4-turbo, gpt-3.5-turbo
+#   - TogetherAI: meta-llama/Llama-2-70b-chat-hf, mistralai/Mixtral-8x7B-Instruct-v0.1
+#   - DeepSeek: deepseek-chat, deepseek-coder
 AWE_LLM_MODEL=gpt-4
+
+# API Key for the selected provider (REQUIRED)
+AWE_LLM_API_KEY=your-api-key-here
 ```
 
 ### Configuration Options
 
 - **`AWE_ENV`**: Application environment (default: `DEV`)
+  - Options: `DEV`, `TEST`, `CI`, `PROD`
 - **`AWE_LOG_LEVEL`**: Logging verbosity level (default: `INFO`)
-- **`AWE_LLM_PROVIDER`**: Optional LLM provider selection. If not set, must be specified when creating agents
-- **`AWE_LLM_MODEL`**: Optional default model to use. If not set, must be specified when creating agents
+  - Options: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
+- **`AWE_LLM_PROVIDER`**: LLM provider selection (**REQUIRED**)
+  - Supported providers: `openai`, `togetherai`, `deepseek`
+  - Can be overridden when creating agents programmatically
+- **`AWE_LLM_MODEL`**: Model identifier to use (**REQUIRED**)
+  - Must be compatible with the selected provider
+  - Can be overridden when creating agents programmatically
+- **`AWE_LLM_API_KEY`**: API key for authentication (**REQUIRED**)
+  - Must be valid for the selected provider
+  - Can be overridden when creating agents programmatically
+
+### Supported Providers
+
+#### ๐Ÿค– OpenAI
+- **Provider**: `openai`
+- **Models**: `gpt-4`, `gpt-4-turbo`, `gpt-3.5-turbo`, and more
+- **API Key**: Get from [OpenAI Platform](https://platform.openai.com/)
+
+#### ๐Ÿ”— TogetherAI
+- **Provider**: `togetherai`
+- **Models**: Various open-source models including Llama, Mixtral, etc.
+- **API Key**: Get from [Together.ai](https://together.ai/) + +#### ๐Ÿ” DeepSeek +- **Provider**: `deepseek` +- **Models**: `deepseek-chat`, `deepseek-coder` +- **API Key**: Get from [DeepSeek Platform](https://platform.deepseek.com/) + +## ๐Ÿงช Testing + +For comprehensive testing documentation, including how to run tests, use dependency injection for mocking, and write new tests, see the [Tests README](tests/README.md). + +### Running Tests + +```bash +# Run all tests +pytest + +# Run with verbose output +pytest -v + +# Run specific test file +pytest tests/test_webexplorer_integration.py +``` + +### ๐Ÿ“Š Evaluation Reports + +Performance evaluation reports are available in the [`tests/reports/`](tests/reports/) directory: + +- [**Amazon Extraction Report**](tests/reports/amazon_extraction_report.md) - Evaluation of product information extraction from Amazon +- [**Wikipedia Extraction Report**](tests/reports/wikipedia_extraction_report.md) - Evaluation of information extraction from Wikipedia + +These reports track the accuracy and performance of the WebExplorer across different types of websites and extraction tasks. -## New Features +## โœจ New Features To develop a new feature: @@ -79,7 +142,7 @@ To develop a new feature: 3. **Create a Pull Request to `develop` branch** 4. **After review and merge, delete the feature branch** -## New Versions +## ๐Ÿš€ New Versions ### Option 1: Automated Release (Recommended) diff --git a/src/aiwebexplorer/__init__.py b/src/aiwebexplorer/__init__.py index 494adf0..af1d382 100644 --- a/src/aiwebexplorer/__init__.py +++ b/src/aiwebexplorer/__init__.py @@ -5,3 +5,8 @@ except ImportError: # Fallback for development __version__ = "0.0.0+dev" + +from .agents import get_evaluate_request_agent as get_evaluate_request_agent +from .agents import get_extraction_agent as get_extraction_agent +from .agents import get_finalizer_agent as get_finalizer_agent +from .webexplorer import WebExplorer as WebExplorer diff --git a/src/aiwebexplorer/agent_factory.py b/src/aiwebexplorer/agent_factory.py index bc48d7d..1b1cf45 100644 --- a/src/aiwebexplorer/agent_factory.py +++ b/src/aiwebexplorer/agent_factory.py @@ -1,4 +1,3 @@ -import os from typing import Any, Literal, NewType, TypeVar from agno.agent import Agent @@ -6,9 +5,10 @@ from agno.models.openai import OpenAIChat from agno.models.together import Together +from aiwebexplorer.dependencies import Dependency from aiwebexplorer.interfaces import IAgent -from .config import config +from .config import get_provider_config T = TypeVar("T") @@ -27,47 +27,24 @@ SupportedModelProvider = Together | DeepSeek | OpenAIChat -def _get_api_key(provider: ModelProvider | None = None) -> tuple[str, ModelProvider]: - import dotenv - - dotenv.load_dotenv() - - # If a model provider is provided, only return the api key for that provider - if provider: - expected_key = f"{provider.upper()}_APIKEY" - value = os.environ.get(expected_key) - if value: - return value, provider - - raise ValueError(f"Expected {expected_key} to be set when requesting provider {provider}") - - # If no provider is provided, return the api key for the first provider that is set - for api_key, provider in [OPENAI_APIKEY, TOGETHERAI_APIKEY, DEEPSEEK_APIKEY]: - value = os.environ.get(api_key) - - if value: - return value, provider - - raise ValueError("No api key found for any provider") - - def _get_model( model_id: str | None = None, provider: ModelProvider | None = None, api_key: str | None = None, - model_id_map: ModelIdMap | None = None, ) -> 
SupportedModelProvider: - if api_key is None: - api_key, provider = _get_api_key(provider) - - if model_id is None: - if not model_id_map: - error_message = """ - You didn't provide a model id or a model id map. I'm expecting at least a model id map - in order to figure it out what model to use. - """ - raise ValueError(error_message) - model_id = model_id_map[provider] + provider_config = get_provider_config() + provider = provider or provider_config.AWE_LLM_PROVIDER + model_id = model_id or provider_config.AWE_LLM_MODEL + api_key = api_key or provider_config.AWE_LLM_API_KEY + + if not provider: + raise RuntimeError("Specify a provider either in the configuration or as env variable AWE_LLM_PROVIDER") + + if not model_id: + raise RuntimeError("Specify a model id either in the configuration or as env variable AWE_LLM_MODEL") + + if not api_key: + raise RuntimeError("Specify an api key either in the configuration or as env variable AWE_LLM_API_KEY") if provider == "togetherai": return Together(id=model_id, api_key=api_key) @@ -83,20 +60,22 @@ def get_agent( name: str, instructions: list[str], *, - model_id: str | None = config.AWE_LLM_MODEL, + model_id: str | None = None, api_key: str | None = None, - provider: ModelProvider | None = config.AWE_LLM_PROVIDER, - model_id_map: ModelIdMap | None = None, + provider: ModelProvider | None = None, **kwargs: Any, ) -> IAgent[Any]: """Get an agent with the given name, instructions, content type, model, and other kwargs. See Agent constructor for Agno agent details. """ - model = _get_model(model_id, provider, api_key, model_id_map) + model = _get_model(model_id, provider, api_key) return Agent( name=name, instructions=instructions, model=model, **kwargs, ) + + +get_agent_dependency = Dependency(get_agent) diff --git a/src/aiwebexplorer/agents.py b/src/aiwebexplorer/agents.py index 18e0b2e..8bfe11a 100644 --- a/src/aiwebexplorer/agents.py +++ b/src/aiwebexplorer/agents.py @@ -1,118 +1,119 @@ from dataclasses import dataclass +from typing import Any -from aiwebexplorer.dependencies import get_agent_dep +from aiwebexplorer.agent_factory import get_agent_dependency +from aiwebexplorer.interfaces import IAgent @dataclass class AgentsIds: ... -model_id_map = { - "togetherai": "openai/gpt-oss-20b", - "deepseek": "chat", - "openai": "gpt-4", -} +def get_evaluate_request_agent() -> IAgent[Any]: + return get_agent_dependency( + name="evaluate_request_agent", + instructions=[ + "You will evaluate the request of a user to extract information from a web page.", + "You need to make sure that the request is valid.", + ( + "In order to be valid, the request must contain a url or a popular website name that you can extract " + "the url from." + ), + "Make sure that the question is clear and complete.", + ( + "If any information is missing, return an error message explaining why you are not able to perform " + "the task." + ), + "Refine the question if necessary.", + "Respond with: ", + " - The url of the website where the search will be performed ([URL])", + " - The refined question that will be used to extract the information ([QUESTION])", + ( + " - A comma separated list of keywords that can be used to extract the information from the webpage " + "([KEYWORDS])" + ), + ( + " - A marker at the end of your response to indicate if the response is successful or not. 
The marker "
+                "should be [RESULT]: OK if the response is successful, or [RESULT]: ERROR if you are not able to "
+                "perform the task ([RESULT])"
+            ),
+            (
+                "If, for any reason, you are not able to extract either the url or the question, return the error "
+                "message explaining why you are not able to perform the task ([MESSAGE])"
+            ),
+            "Example of a valid response: ",
+            "[QUESTION]: When was the Space shuttle first launched?",
+            "[URL]: https://en.wikipedia.org/wiki/Space_Shuttle",
+            "[KEYWORDS]: Space shuttle, launch date",
+            "[RESULT]: OK",
+            "Example of an invalid response: ",
+            "[MESSAGE]: The question is not clear and complete. Please refine the question.",
+            "[RESULT]: ERROR",
+            "In order to provide useful keywords, focus on the specific information that you are trying to extract.",
+            "If the request is about the specifications of an electronic device,",
+            "keywords might be 'specs', 'specifications' but also 'memory', 'cpu', 'display size' etc.",
+            "Try to find as many keywords as possible that can be used to target important information in the webpage.",
+            (
+                "For example, if you are trying to extract the launch date of the Space shuttle, the keywords "
+                "should be 'launch date', 'launch', 'launched', 'first launch',"
+            ),
+            "Do not provide obvious keywords, like 'Space shuttle'.",
+            "If the user requires a summary of the content, do not return any keywords.",
+        ],
+    )
 
 
-evaluate_request_agent = get_agent_dep()(
-    name="evaluate_request_agent",
-    instructions=[
-        "You will evaluate the request of a user to extract information from a web page.",
-        "You need to make sure that the request is valid.",
-        (
-            "In order to be valid, the request must contain a url or a popular website name that you can extract "
-            "the url from."
-        ),
-        "Make sure that the question is clear and complete.",
-        "If any information is missing, return an error message explaining why you are not able to perform the task.",
-        "Refine the question if necessary.",
-        "Respond with: ",
-        " - The url of the website where the search will be performed ([URL])",
-        " - The refined question that will be used to extract the information ([QUESTION])",
-        (
-            " - A comma separated list of keywords that can be used to extract the information from the webpage "
-            "([KEYWORDS])"
-        ),
-        (
-            " - A marker at the end of your response to indicate if the response is successful or not. The marker "
-            "should be [RESULT]: OK if the response is successful, or [RESULT]: ERROR if you are not able to "
-            "perform the task ([RESULT])"
-        ),
-        (
-            "If, for any reason, you are not able to extract either the url or the question, return the error "
-            "message explaining why you are not able to perform the task ([MESSAGE])"
-        ),
-        "Example of a valid response: ",
-        "[QUESTION]: When was the Space shuttle first launched?",
-        "[URL]: https://en.wikipedia.org/wiki/Space_Shuttle",
-        "[KEYWORDS]: Space shuttle, launch date",
-        "[RESULT]: OK",
-        "Example of an invalid response: ",
-        "[MESSAGE]: The question is not clear and complete. Please refine the question.",
-        "[RESULT]: ERROR",
-        "In order to provide useful keywords, focus on the specific information that you are trying to extract.",
-        "If the request is about some specifications of an electronic devices",
-        "keywords might be 'specs', 'specifications' but also 'memory', 'cpu', 'display size' etc."
- "Try to find as many keywords as possible that can be used to target important information in the webpage.", - ( - "For example, if you are trying to extract the launch date of the Space shuttle, the keywords should be " - "'launch date', 'launch', 'launched', 'first launch'," - ), - "Do not provide obvious keywords, like 'Space shuttle'", - "If the user requires a summary of the content, do not return any keywords.", - ], - model_id_map=model_id_map, -) -extraction_agent = get_agent_dep()( - name="extraction_agent", - instructions=[ - "Extract ONLY the requested information from the content. Be precise and concise.", - "", - "RULES:", - "1. Return only what's asked - no explanations", - "2. Extract exactly as it appears on the page", - "3. If not found: 'Not available'", - "4. Multiple items: one per line", - "5. Add [SOURCE]: snippet of text where info was found", - "6. End with [CONFIDENCE]: HIGH/MEDIUM/LOW", - "7. If incomplete info: add [PARTIAL]", - "", - "Example:", - "Question: What is the price and model?", - "Response:", - "$299.99", - "iPhone 15", - "[SOURCE]: Price: $299.99", - "[SOURCE]: Model: iPhone 15", - "[CONFIDENCE]: HIGH", - ], - model_id_map=model_id_map, -) +def get_extraction_agent() -> IAgent[Any]: + return get_agent_dependency( + name="extraction_agent", + instructions=[ + "Extract ONLY the requested information from the content. Be precise and concise.", + "", + "RULES:", + "1. Return only what's asked - no explanations", + "2. Extract exactly as it appears on the page", + "3. If not found: 'Not available'", + "4. Multiple items: one per line", + "5. Add [SOURCE]: snippet of text where info was found", + "6. End with [CONFIDENCE]: HIGH/MEDIUM/LOW", + "7. If incomplete info: add [PARTIAL]", + "", + "Example:", + "Question: What is the price and model?", + "Response:", + "$299.99", + "iPhone 15", + "[SOURCE]: Price: $299.99", + "[SOURCE]: Model: iPhone 15", + "[CONFIDENCE]: HIGH", + ], + ) -finalizer_agent = get_agent_dep()( - name="finalizer_agent", - instructions=[ - "You are a finalization agent that synthesizes extracted information into a clear, comprehensive answer.", - "Your role is to take raw extracted information and formulate a proper response to the user's question.", - "", - "CRITICAL RULES:", - "1. Read all provided extracted information carefully", - "2. Synthesize the information into a coherent, well-structured answer", - "3. Answer the user's question directly and completely", - "4. If information is missing or unclear, acknowledge this in your response", - "5. If multiple pieces of information are provided, organize them appropriately", - "6. Do not add information that wasn't in the extracted data", - "7. 
Do not return tables", - "Example:", - "Question: What are the main features of this product?", - "Extracted info: 'Wireless charging, 5G connectivity, 48MP camera, 128GB storage'", - "Response: 'The main features of this product include:", - "- Wireless charging capability", - "- 5G connectivity for fast internet speeds", - "- 48MP camera for high-quality photos", - "- 128GB of storage space'", - "", - "Provide a complete, well-formatted answer that directly addresses the user's question.", - ], - model_id_map=model_id_map, -) + +def get_finalizer_agent() -> IAgent[Any]: + return get_agent_dependency( + name="finalizer_agent", + instructions=[ + "You are a finalization agent that synthesizes extracted information into a clear, comprehensive answer.", + "Your role is to take raw extracted information and formulate a proper response to the user's question.", + "", + "CRITICAL RULES:", + "1. Read all provided extracted information carefully", + "2. Synthesize the information into a coherent, well-structured answer", + "3. Answer the user's question directly and completely", + "4. If information is missing or unclear, acknowledge this in your response", + "5. If multiple pieces of information are provided, organize them appropriately", + "6. Do not add information that wasn't in the extracted data", + "7. Do not return tables", + "Example:", + "Question: What are the main features of this product?", + "Extracted info: 'Wireless charging, 5G connectivity, 48MP camera, 128GB storage'", + "Response: 'The main features of this product include:", + "- Wireless charging capability", + "- 5G connectivity for fast internet speeds", + "- 48MP camera for high-quality photos", + "- 128GB of storage space'", + "", + "Provide a complete, well-formatted answer that directly addresses the user's question.", + ], + ) diff --git a/src/aiwebexplorer/config.py b/src/aiwebexplorer/config.py index df10747..61562f1 100644 --- a/src/aiwebexplorer/config.py +++ b/src/aiwebexplorer/config.py @@ -13,16 +13,29 @@ class Environment(StrEnum): PROD = "PROD" +class ProviderConfig(BaseSettings): + """Provider configuration.""" + + AWE_LLM_PROVIDER: Literal["openai", "togetherai", "deepseek"] | None = None + AWE_LLM_MODEL: str | None = None + AWE_LLM_API_KEY: str | None = None + + model_config = SettingsConfigDict(env_file=".env", env_prefix="", extra="ignore") + + class Config(BaseSettings): """Application settings.""" AWE_ENV: Environment = Environment.DEV AWE_LOG_LEVEL: str = "INFO" # Allowed: DEBUG, INFO, WARNING, ERROR, CRITICAL - AWE_LLM_PROVIDER: Literal["openai", "togetherai", "deepseek"] | None = None - AWE_LLM_MODEL: str | None = None model_config = SettingsConfigDict(env_file=".env", env_prefix="", extra="ignore") # Export a ready-to-use settings object -config = Config() +def get_config() -> Config: + return Config() + + +def get_provider_config() -> ProviderConfig: + return ProviderConfig() diff --git a/src/aiwebexplorer/dependencies.py b/src/aiwebexplorer/dependencies.py index c9cf0c1..6d22f43 100644 --- a/src/aiwebexplorer/dependencies.py +++ b/src/aiwebexplorer/dependencies.py @@ -1,25 +1,37 @@ +import hashlib from collections.abc import Callable +from contextlib import contextmanager from typing import Any -from unittest.mock import Mock -from aiwebexplorer.agent_factory import get_agent -from aiwebexplorer.config import Environment, config -from aiwebexplorer.interfaces import IAgent +class Dependency: + registry: dict[str, Callable] = {} -def get_agent_dep() -> Callable[..., IAgent[Any]]: - """Get agent 
dependency based on environment configuration. + def __init__(self, dependency: Callable) -> None: + self.key = self._get_fn_hash(dependency) + if self.key in self.registry: + raise ValueError(f"Dependency with key {self.key} already exists") + self.registry[self.key] = dependency - Returns: - Callable that returns either a real agent or a mock agent based on environment. - """ - if config.AWE_ENV == Environment.CI: - # Return a function that creates a mock agent for CI environment - def mock_agent_factory(*args: Any, **kwargs: Any) -> IAgent[Any]: - mock_agent = Mock(spec=IAgent[Any]) - return mock_agent + def __call__(self, *args: Any, **kwds: Any) -> Any: + if self.key not in self.registry: + raise ValueError(f"Dependency {self.key} not found") + return self.registry[self.key](*args, **kwds) - return mock_agent_factory - else: - # Return the real get_agent function for all other environments - return get_agent + def _get_fn_hash(self, fn: Callable) -> str: + module = fn.__module__ + qualname = fn.__qualname__ + code = fn.__code__ + + return hashlib.sha256(f"{module}.{qualname}:{code.co_firstlineno}:{code.co_filename}".encode()).hexdigest() + + @contextmanager + def override(self, overrider_dependency: Callable) -> None: + self._override_key = self._get_fn_hash(overrider_dependency) + self._original_dep = self.registry[self.key] + + self.registry[self.key] = overrider_dependency + try: + yield + finally: + self.registry[self.key] = self._original_dep diff --git a/src/aiwebexplorer/webexplorer.py b/src/aiwebexplorer/webexplorer.py index 78f7a3c..567c771 100644 --- a/src/aiwebexplorer/webexplorer.py +++ b/src/aiwebexplorer/webexplorer.py @@ -11,7 +11,7 @@ from playwright.async_api import Browser, async_playwright from playwright_stealth import Stealth -from aiwebexplorer.agents import evaluate_request_agent, extraction_agent, finalizer_agent +from aiwebexplorer.agents import get_evaluate_request_agent, get_extraction_agent, get_finalizer_agent from aiwebexplorer.interfaces import AgentInterfaceError, IAgent, IResponse logger = structlog.get_logger() @@ -95,9 +95,9 @@ def __init__( self.use_lightweight_fetch = use_lightweight_fetch self.site_selectors = site_selectors or self.DEFAULT_SITE_SELECTORS.copy() - self._evaluation_agent: IAgent[str] = _evaluate_request_agent or evaluate_request_agent - self._extractor_agent: IAgent[str] = _extraction_agent or extraction_agent - self._finalizer_agent: IAgent[str] = _finalizer_agent or finalizer_agent + self._evaluation_agent: IAgent[str] = _evaluate_request_agent or get_evaluate_request_agent() + self._extractor_agent: IAgent[str] = _extraction_agent or get_extraction_agent() + self._finalizer_agent: IAgent[str] = _finalizer_agent or get_finalizer_agent() async def arun(self, prompt: str) -> IResponse[str]: """Run web exploration based on the prompt. diff --git a/tests/README.md b/tests/README.md index 9a09263..68f8c38 100644 --- a/tests/README.md +++ b/tests/README.md @@ -35,11 +35,32 @@ This directory contains comprehensive tests for the AI Web Explorer package. ### Prerequisites -Install test dependencies: +#### 1. Install test dependencies: ```bash pip install pytest pytest-asyncio pytest-mock ``` +#### 2. 
Configure environment variables: + +For integration tests that use real LLM providers, you need to set the required environment variables: + +```bash +# Required for real API calls (not needed for mocked unit tests) +export AWE_LLM_PROVIDER=openai +export AWE_LLM_MODEL=gpt-4 +export AWE_LLM_API_KEY=your-api-key-here +``` + +Or create a `.env` file in the WebExplorer directory: + +```env +AWE_LLM_PROVIDER=openai +AWE_LLM_MODEL=gpt-4 +AWE_LLM_API_KEY=your-api-key-here +``` + +**Note**: Most tests use dependency injection to mock agents, so API keys are only needed for integration/evaluation tests. + ### Basic Test Execution ```bash @@ -88,6 +109,124 @@ pytest --cov=aiwebexplorer --cov-report=html open htmlcov/index.html ``` +## Dependency Injection System + +This test suite uses a custom dependency injection system for mocking and testing. The system allows you to override dependencies at runtime without modifying the core code. + +### How It Works + +The `Dependency` class (from `aiwebexplorer.dependencies`) provides a context manager for overriding dependencies: + +```python +from aiwebexplorer.agent_factory import get_agent_dependency + +# Create a mock agent +def mock_agent_factory(*args, **kwargs): + return MockAgent() + +# Override the dependency in a test +with get_agent_dependency.override(mock_agent_factory): + # Code here will use the mock instead of the real agent factory + result = some_function_that_uses_agents() +``` + +### Using Dependency Injection in Tests + +#### Basic Override Pattern + +```python +import pytest +from aiwebexplorer.agent_factory import get_agent_dependency + +@pytest.mark.asyncio +async def test_with_mock_agent(): + """Test using a mock agent via dependency injection.""" + + def get_agent_mock(*args, **kwargs): + mock = AsyncMock() + mock.arun.return_value = "Mocked response" + return mock + + with get_agent_dependency.override(get_agent_mock): + # Your test code here + explorer = WebExplorer() + result = await explorer.arun("test query") + assert "Mocked" in result +``` + +#### Benefits + +1. **No monkey patching**: Clean, explicit dependency replacement +2. **Context-based**: Automatically restores original after the context +3. **Type-safe**: Maintains function signatures +4. **Thread-safe**: Each override is isolated to its context + +### Available Dependencies + +- **`get_agent_dependency`**: Main agent factory used throughout the codebase + - Override this to inject mock agents in tests + - Automatically used by `get_evaluate_request_agent()`, `get_extraction_agent()`, `get_finalizer_agent()` + +### Real-World Examples + +#### Example 1: Testing Site Detection with Mocked Agents + +From `test_site_detection.py`: + +```python +@pytest.mark.asyncio +async def test_amazon_product_extraction(): + """Test extraction from Amazon product page.""" + + def get_agent_mock(*args, **kwargs): + mock = AsyncMock() + if kwargs.get("name") == "evaluate_request_agent": + mock.arun.return_value = """ + [QUESTION]: What is the product title and price? 
+ [URL]: https://www.amazon.com/dp/B08N5WRWNW + [KEYWORDS]: product, title, price + [RESULT]: OK + """ + elif kwargs.get("name") == "extraction_agent": + mock.arun.return_value = """ + Apple AirPods Max + $549.00 + [CONFIDENCE]: HIGH + """ + return mock + + with get_agent_dependency.override(get_agent_mock): + explorer = WebExplorer() + result = await explorer.arun( + "Get product info from https://www.amazon.com/dp/B08N5WRWNW" + ) + assert "AirPods" in result +``` + +#### Example 2: Testing Text Parsing with Keyword Highlighting + +From `test_text_parsing.py`: + +```python +@pytest.mark.asyncio +async def test_keyword_extraction(): + """Test that relevant content is extracted based on keywords.""" + + def get_agent_mock(*args, **kwargs): + mock = AsyncMock() + mock.arun.return_value = """ + [QUESTION]: What are the specs? + [URL]: https://example.com + [KEYWORDS]: specifications, memory, processor + [RESULT]: OK + """ + return mock + + with get_agent_dependency.override(get_agent_mock): + explorer = WebExplorer() + # Test implementation here +``` + ## Test Fixtures ### Available Fixtures @@ -121,9 +260,41 @@ async def test_async_functionality(): assert response.content is not None ``` -### Mocking +### Mocking with Dependency Injection + +The preferred way to mock agents is using the dependency injection system: + +```python +from aiwebexplorer.agent_factory import get_agent_dependency +from unittest.mock import AsyncMock + +@pytest.mark.asyncio +async def test_with_dependency_injection(): + """Test using dependency injection to mock agents.""" + + def get_agent_mock(*args, **kwargs): + mock_agent = AsyncMock() + mock_agent.arun.return_value = """ + [QUESTION]: What is the price? + [URL]: https://example.com + [KEYWORDS]: price, cost + [RESULT]: OK + """ + return mock_agent + + # Override the dependency + with get_agent_dependency.override(get_agent_mock): + explorer = WebExplorer() + result = await explorer.arun("test prompt") + assert "example.com" in result + +### Traditional Mocking (for internal methods) + +For internal methods that don't use dependency injection: ```python +from unittest.mock import patch + @patch.object(WebExplorer, '_explore_fast') async def test_with_mock(mock_explore): mock_explore.return_value = "mocked result" diff --git a/tests/dependencie.py b/tests/dependencie.py new file mode 100644 index 0000000..544d8df --- /dev/null +++ b/tests/dependencie.py @@ -0,0 +1,9 @@ +from typing import Any +from unittest.mock import Mock + +from aiwebexplorer.interfaces import IAgent + + +def get_agent_mock(*args: Any, **kwargs: Any) -> IAgent[Any]: + """Get a mock agent for testing.""" + return Mock(spec=IAgent[Any]) diff --git a/tests/test_webexplorer.py b/tests/eval_test_webexplorer.py similarity index 69% rename from tests/test_webexplorer.py rename to tests/eval_test_webexplorer.py index b127df6..8f2dbae 100644 --- a/tests/test_webexplorer.py +++ b/tests/eval_test_webexplorer.py @@ -1,8 +1,36 @@ +from collections.abc import Coroutine +from pathlib import Path +from typing import Any + from pydantic_evals import Case, Dataset from pydantic_evals.evaluators import Evaluator, EvaluatorContext from aiwebexplorer.webexplorer import WebExplorer +# Create reports directory +REPORTS_DIR = Path(__file__).parent / "reports" +REPORTS_DIR.mkdir(exist_ok=True) + + +def save_report_markdown(report: Any, test_name: str) -> None: + """Save evaluation report as markdown.""" + report_path = REPORTS_DIR / f"{test_name}_report.md" + + # Convert report to markdown format + with 
open(report_path, "w") as f: + f.write(f"# Evaluation Report: {test_name}\n\n") + f.write(str(report)) + + print(f"\nReport saved to: {report_path}") + + +def arun_wrapper(webexplorer: WebExplorer) -> Coroutine[Any, Any, str]: + async def _wrapper(prompt: str) -> str: + response = await webexplorer.arun(prompt) + return response.content + + return _wrapper + async def wikipedia_extraction(): case1 = Case( @@ -13,7 +41,7 @@ async def wikipedia_extraction(): class Evaluator1(Evaluator): def evaluate(self, ctx: EvaluatorContext) -> float: - output = ctx.output.content.lower() + output = ctx.output.lower() expected_output = ctx.expected_output.lower() if output == expected_output: return 1.0 @@ -38,8 +66,9 @@ def evaluate(self, ctx: EvaluatorContext) -> float: dataset = Dataset(cases=[case1], evaluators=[Evaluator1()]) webexplorer = WebExplorer() - report = await dataset.evaluate(webexplorer.arun) + report = await dataset.evaluate(arun_wrapper(webexplorer)) report.print(include_input=True, include_output=True) + save_report_markdown(report, "wikipedia_extraction") async def amazon_extraction(): @@ -56,7 +85,7 @@ async def amazon_extraction(): class Evaluator2(Evaluator): def evaluate(self, ctx: EvaluatorContext) -> float: - output = ctx.output.content.lower() + output = ctx.output.lower() score = 0.0 if "apple iphone 15" in output: @@ -94,15 +123,19 @@ def evaluate(self, ctx: EvaluatorContext) -> float: dataset = Dataset(cases=[case2], evaluators=[Evaluator2()]) webexplorer = WebExplorer(use_lightweight_fetch=False) - report = await dataset.evaluate(webexplorer.arun) + + report = await dataset.evaluate(arun_wrapper(webexplorer)) report.print(include_input=True, include_output=True) + save_report_markdown(report, "amazon_extraction") if __name__ == "__main__": import asyncio - def test(): - # asyncio.run(wikipedia_extraction()) - asyncio.run(amazon_extraction()) + async def test() -> None: + await asyncio.gather( + wikipedia_extraction(), + amazon_extraction(), + ) - test() + asyncio.run(test()) diff --git a/tests/reports/amazon_extraction_report.md b/tests/reports/amazon_extraction_report.md new file mode 100644 index 0000000..ab59e96 --- /dev/null +++ b/tests/reports/amazon_extraction_report.md @@ -0,0 +1,10 @@ +# Evaluation Report: amazon_extraction + + Evaluation Summary: _wrapper +โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“ +โ”ƒ Case ID โ”ƒ Scores โ”ƒ Duration โ”ƒ +โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ +โ”‚ amazon_extraction โ”‚ Evaluator2: 0.420 โ”‚ 46.6s โ”‚ +โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค +โ”‚ Averages โ”‚ Evaluator2: 0.420 โ”‚ 46.6s โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ diff --git a/tests/reports/wikipedia_extraction_report.md b/tests/reports/wikipedia_extraction_report.md new file mode 100644 index 0000000..a28b35c --- /dev/null +++ b/tests/reports/wikipedia_extraction_report.md @@ -0,0 +1,10 @@ +# Evaluation Report: wikipedia_extraction + + Evaluation Summary: _wrapper 
+โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ณโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”“ +โ”ƒ Case ID โ”ƒ Scores โ”ƒ Duration โ”ƒ +โ”กโ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ•‡โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”ฉ +โ”‚ simple_extraction โ”‚ Evaluator1: 0.800 โ”‚ 14.8s โ”‚ +โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ผโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค +โ”‚ Averages โ”‚ Evaluator1: 0.800 โ”‚ 14.8s โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ดโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ diff --git a/tests/test_agent_factory.py b/tests/test_agent_factory.py deleted file mode 100644 index f9325b5..0000000 --- a/tests/test_agent_factory.py +++ /dev/null @@ -1,235 +0,0 @@ -"""Tests for agent factory and model provider logic.""" - -import os -from unittest.mock import Mock, patch - -import pytest - -from aiwebexplorer.agent_factory import _get_api_key, _get_model, get_agent - - -class TestGetApiKey: - """Test API key retrieval functionality.""" - - @patch.dict(os.environ, {"OPENAI_APIKEY": "test-openai-key"}, clear=True) - def test_get_api_key_openai_available(self): - """Test getting OpenAI API key when available.""" - api_key, provider = _get_api_key() - assert api_key == "test-openai-key" - assert provider == "openai" - - @patch("dotenv.load_dotenv") - @patch.dict(os.environ, {"TOGETHERAI_APIKEY": "test-together-key"}, clear=True) - def test_get_api_key_togetherai_available(self, mock_load_dotenv): - """Test getting TogetherAI API key when available.""" - api_key, provider = _get_api_key() - assert api_key == "test-together-key" - assert provider == "togetherai" - - @patch("dotenv.load_dotenv") - @patch.dict(os.environ, {"DEEPSEEK_APIKEY": "test-deepseek-key"}, clear=True) - def test_get_api_key_deepseek_available(self, mock_load_dotenv): - """Test getting DeepSeek API key when available.""" - api_key, provider = _get_api_key() - assert api_key == "test-deepseek-key" - assert provider == "deepseek" - - @patch("dotenv.load_dotenv") - @patch.dict( - os.environ, - { - "OPENAI_APIKEY": "test-openai-key", - "TOGETHERAI_APIKEY": "test-together-key", - "DEEPSEEK_APIKEY": "test-deepseek-key", - }, - clear=True, - ) - def test_get_api_key_priority_order(self, mock_load_dotenv): - """Test API key priority order (OpenAI > TogetherAI > DeepSeek).""" - api_key, provider = _get_api_key() - assert api_key == "test-openai-key" - assert provider == "openai" - - @patch("dotenv.load_dotenv") - @patch.dict(os.environ, {"DEEPSEEK_APIKEY": "test-deepseek-key"}, clear=True) - def test_get_api_key_specific_provider(self, mock_load_dotenv): - """Test getting API key for specific provider.""" - api_key, provider = _get_api_key("deepseek") - assert api_key == "test-deepseek-key" - assert provider == "deepseek" - - @patch("dotenv.load_dotenv") - @patch.dict(os.environ, {"OPENAI_APIKEY": "test-openai-key"}, clear=True) - def test_get_api_key_specific_provider_missing(self, mock_load_dotenv): - """Test getting API key for specific provider when key is missing.""" - with pytest.raises(ValueError, match="Expected DEEPSEEK_APIKEY to be set when requesting provider deepseek"): - _get_api_key("deepseek") - - @patch("dotenv.load_dotenv") - @patch.dict(os.environ, {}, clear=True) - def test_get_api_key_no_keys_available(self, 
mock_load_dotenv): - """Test getting API key when no keys are available.""" - with pytest.raises(ValueError, match="No api key found for any provider"): - _get_api_key() - - -class TestGetModel: - """Test model creation functionality.""" - - @patch("aiwebexplorer.agent_factory.OpenAIChat") - def test_get_model_openai(self, mock_openai): - """Test creating OpenAI model.""" - mock_model = Mock() - mock_openai.return_value = mock_model - - result = _get_model(model_id="gpt-4", provider="openai", api_key="test-key") - - mock_openai.assert_called_once_with(id="gpt-4", api_key="test-key") - assert result == mock_model - - @patch("aiwebexplorer.agent_factory.Together") - def test_get_model_togetherai(self, mock_together): - """Test creating TogetherAI model.""" - mock_model = Mock() - mock_together.return_value = mock_model - - result = _get_model(model_id="openai/gpt-oss-20b", provider="togetherai", api_key="test-key") - - mock_together.assert_called_once_with(id="openai/gpt-oss-20b", api_key="test-key") - assert result == mock_model - - @patch("aiwebexplorer.agent_factory.DeepSeek") - def test_get_model_deepseek(self, mock_deepseek): - """Test creating DeepSeek model.""" - mock_model = Mock() - mock_deepseek.return_value = mock_model - - result = _get_model(model_id="chat", provider="deepseek", api_key="test-key") - - mock_deepseek.assert_called_once_with(id="chat", api_key="test-key") - assert result == mock_model - - def test_get_model_invalid_provider(self): - """Test creating model with invalid provider.""" - with pytest.raises(ValueError, match="Invalid provider: invalid"): - _get_model(model_id="test-model", provider="invalid", api_key="test-key") - - @patch("aiwebexplorer.agent_factory._get_api_key") - @patch("aiwebexplorer.agent_factory.OpenAIChat") - def test_get_model_auto_api_key(self, mock_openai, mock_get_api_key): - """Test model creation with automatic API key retrieval.""" - mock_get_api_key.return_value = ("test-key", "openai") - mock_model = Mock() - mock_openai.return_value = mock_model - - result = _get_model(model_id="gpt-4") - - mock_get_api_key.assert_called_once_with(None) - mock_openai.assert_called_once_with(id="gpt-4", api_key="test-key") - assert result == mock_model - - @patch("aiwebexplorer.agent_factory._get_api_key") - @patch("aiwebexplorer.agent_factory.OpenAIChat") - def test_get_model_with_model_id_map(self, mock_openai, mock_get_api_key): - """Test model creation with model ID map.""" - mock_get_api_key.return_value = ("test-key", "openai") - mock_model = Mock() - mock_openai.return_value = mock_model - - model_id_map = {"openai": "gpt-4"} - result = _get_model(model_id_map=model_id_map) - - mock_openai.assert_called_once_with(id="gpt-4", api_key="test-key") - assert result == mock_model - - @patch.dict(os.environ, {"OPENAI_APIKEY": "test-key"}, clear=True) - def test_get_model_missing_model_id_map(self): - """Test model creation without model ID map.""" - with pytest.raises(ValueError, match="You didn't provide a model id or a model id map"): - _get_model() - - -class TestGetAgent: - """Test agent creation functionality.""" - - @patch("aiwebexplorer.agent_factory._get_model") - @patch("aiwebexplorer.agent_factory.Agent") - def test_get_agent_basic(self, mock_agent_class, mock_get_model): - """Test basic agent creation.""" - mock_model = Mock() - mock_get_model.return_value = mock_model - mock_agent = Mock() - mock_agent_class.return_value = mock_agent - - instructions = ["Test instruction"] - result = get_agent("test-agent", instructions) - - 
mock_get_model.assert_called_once() - mock_agent_class.assert_called_once_with(name="test-agent", instructions=instructions, model=mock_model) - assert result == mock_agent - - @patch("aiwebexplorer.agent_factory._get_model") - @patch("aiwebexplorer.agent_factory.Agent") - def test_get_agent_with_model_id_map(self, mock_agent_class, mock_get_model): - """Test agent creation with model ID map.""" - mock_model = Mock() - mock_get_model.return_value = mock_model - mock_agent = Mock() - mock_agent_class.return_value = mock_agent - - instructions = ["Test instruction"] - model_id_map = {"openai": "gpt-4"} - result = get_agent("test-agent", instructions, model_id_map=model_id_map) - - mock_get_model.assert_called_once_with(None, None, None, model_id_map) - assert result == mock_agent - - @patch("aiwebexplorer.agent_factory._get_model") - @patch("aiwebexplorer.agent_factory.Agent") - def test_get_agent_with_specific_provider(self, mock_agent_class, mock_get_model): - """Test agent creation with specific provider.""" - mock_model = Mock() - mock_get_model.return_value = mock_model - mock_agent = Mock() - mock_agent_class.return_value = mock_agent - - instructions = ["Test instruction"] - result = get_agent("test-agent", instructions, provider="openai", api_key="test-key") - - mock_get_model.assert_called_once_with(None, "openai", "test-key", None) - assert result == mock_agent - - @patch("aiwebexplorer.agent_factory._get_model") - @patch("aiwebexplorer.agent_factory.Agent") - def test_get_agent_with_kwargs(self, mock_agent_class, mock_get_model): - """Test agent creation with additional kwargs.""" - mock_model = Mock() - mock_get_model.return_value = mock_model - mock_agent = Mock() - mock_agent_class.return_value = mock_agent - - instructions = ["Test instruction"] - result = get_agent("test-agent", instructions, temperature=0.7, max_tokens=100) - - mock_agent_class.assert_called_once_with( - name="test-agent", instructions=instructions, model=mock_model, temperature=0.7, max_tokens=100 - ) - assert result == mock_agent - - def test_get_agent_interface_compliance(self): - """Test that created agent implements IAgent interface.""" - with ( - patch("aiwebexplorer.agent_factory._get_model") as mock_get_model, - patch("aiwebexplorer.agent_factory.Agent") as mock_agent_class, - ): - mock_model = Mock() - mock_get_model.return_value = mock_model - mock_agent = Mock() - mock_agent_class.return_value = mock_agent - - instructions = ["Test instruction"] - result = get_agent("test-agent", instructions) - - # Verify the agent has the required interface - assert hasattr(result, "arun") - assert callable(result.arun) diff --git a/tests/test_site_detection.py b/tests/test_site_detection.py index 476fd15..bf605bc 100644 --- a/tests/test_site_detection.py +++ b/tests/test_site_detection.py @@ -2,7 +2,7 @@ import pytest -from aiwebexplorer.webexplorer import WebExplorer +from tests.dependencie import get_agent_mock class TestSiteDetection: @@ -10,7 +10,12 @@ class TestSiteDetection: def setup_method(self): """Set up test fixtures.""" - self.webexplorer = WebExplorer() + from aiwebexplorer.agent_factory import get_agent_dependency + + with get_agent_dependency.override(get_agent_mock): + from aiwebexplorer.webexplorer import WebExplorer + + self.webexplorer = WebExplorer() def test_detect_site_type_amazon(self): """Test detecting Amazon site type.""" @@ -56,19 +61,23 @@ def test_detect_site_type_subdomain(self): def test_detect_site_type_custom_selectors(self): """Test site detection with custom selectors.""" - 
custom_selectors = {"custom": "div.custom-content"} - webexplorer = WebExplorer(site_selectors=custom_selectors) + from aiwebexplorer.agent_factory import get_agent_dependency + from aiwebexplorer.webexplorer import WebExplorer - url = "https://custom-site.com/page" - result = webexplorer._detect_site_type(url) - assert result == "custom" # Should match "custom" in URL + with get_agent_dependency.override(get_agent_mock): + custom_selectors = {"custom": "div.custom-content"} + webexplorer = WebExplorer(site_selectors=custom_selectors) - # Test with custom site in URL - custom_selectors["test"] = "div.test-content" - webexplorer = WebExplorer(site_selectors=custom_selectors) - url = "https://test-site.com/page" - result = webexplorer._detect_site_type(url) - assert result == "test" + url = "https://custom-site.com/page" + result = webexplorer._detect_site_type(url) + assert result == "custom" # Should match "custom" in URL + + # Test with custom site in URL + custom_selectors["test"] = "div.test-content" + webexplorer = WebExplorer(site_selectors=custom_selectors) + url = "https://test-site.com/page" + result = webexplorer._detect_site_type(url) + assert result == "test" class TestHtmlProcessing: @@ -76,7 +85,12 @@ class TestHtmlProcessing: def setup_method(self): """Set up test fixtures.""" - self.webexplorer = WebExplorer() + from aiwebexplorer.agent_factory import get_agent_dependency + + with get_agent_dependency.override(get_agent_mock): + from aiwebexplorer.webexplorer import WebExplorer + + self.webexplorer = WebExplorer() @pytest.mark.asyncio async def test_extract_from_html_basic(self): @@ -119,6 +133,9 @@ async def test_extract_from_html_with_site_selector(self): @pytest.mark.asyncio async def test_extract_from_html_with_custom_selector(self): """Test HTML extraction with custom CSS selector.""" + from aiwebexplorer.agent_factory import get_agent_dependency + from aiwebexplorer.webexplorer import WebExplorer + html = """ @@ -132,11 +149,12 @@ async def test_extract_from_html_with_custom_selector(self): """ - webexplorer = WebExplorer(css_selector=".product-info") - result = await webexplorer._extract_from_html(html, None) - assert "custom product" in result - assert "custom details" in result - assert "should be excluded" not in result + with get_agent_dependency.override(get_agent_mock): + webexplorer = WebExplorer(css_selector=".product-info") + result = await webexplorer._extract_from_html(html, None) + assert "custom product" in result + assert "custom details" in result + assert "should be excluded" not in result @pytest.mark.asyncio async def test_extract_from_html_exclude_tags(self): @@ -181,7 +199,7 @@ async def test_extract_from_html_whitespace_normalization(self):

                <h1>Product   Title</h1>
                <p>Multiple   spaces   and
-
+
                newlines</p>
@@ -252,6 +270,9 @@ async def test_extract_from_html_complex_structure(self): @pytest.mark.asyncio async def test_extract_from_html_selector_not_found(self): """Test HTML extraction when selector doesn't match.""" + from aiwebexplorer.agent_factory import get_agent_dependency + from aiwebexplorer.webexplorer import WebExplorer + html = """ @@ -262,14 +283,18 @@ async def test_extract_from_html_selector_not_found(self): """ # Try to use a selector that doesn't exist - webexplorer = WebExplorer(css_selector=".non-existent") - result = await webexplorer._extract_from_html(html, None) - # Should fall back to full page content - assert "this content should be included" in result + with get_agent_dependency.override(get_agent_mock): + webexplorer = WebExplorer(css_selector=".non-existent") + result = await webexplorer._extract_from_html(html, None) + # Should fall back to full page content + assert "this content should be included" in result @pytest.mark.asyncio async def test_extract_from_html_site_selector_priority(self): """Test that site selector takes priority over custom selector.""" + from aiwebexplorer.agent_factory import get_agent_dependency + from aiwebexplorer.webexplorer import WebExplorer + html = """ @@ -282,15 +307,19 @@ async def test_extract_from_html_site_selector_priority(self): """ - webexplorer = WebExplorer(css_selector=".custom-content") - result = await webexplorer._extract_from_html(html, "amazon") - # Should use Amazon selector, not custom selector - assert "amazon product" in result - assert "custom product" not in result + with get_agent_dependency.override(get_agent_mock): + webexplorer = WebExplorer(css_selector=".custom-content") + result = await webexplorer._extract_from_html(html, "amazon") + # Should use Amazon selector, not custom selector + assert "amazon product" in result + assert "custom product" not in result @pytest.mark.asyncio async def test_extract_from_html_additional_exclude_tags(self): """Test HTML extraction with additional exclude tags.""" + from aiwebexplorer.agent_factory import get_agent_dependency + from aiwebexplorer.webexplorer import WebExplorer + html = """ @@ -302,11 +331,12 @@ async def test_extract_from_html_additional_exclude_tags(self): """ - webexplorer = WebExplorer(exclude_tags=["video", "audio"]) - result = await webexplorer._extract_from_html(html, None) - assert "product title" in result - assert "visible content" in result - assert "video content" not in result - assert "audio content" not in result - # Canvas should still be excluded (default) - assert "canvas content" not in result + with get_agent_dependency.override(get_agent_mock): + webexplorer = WebExplorer(exclude_tags=["video", "audio"]) + result = await webexplorer._extract_from_html(html, None) + assert "product title" in result + assert "visible content" in result + assert "video content" not in result + assert "audio content" not in result + # Canvas should still be excluded (default) + assert "canvas content" not in result diff --git a/tests/test_text_parsing.py b/tests/test_text_parsing.py index 1bcb955..8ea1fc5 100644 --- a/tests/test_text_parsing.py +++ b/tests/test_text_parsing.py @@ -2,8 +2,10 @@ import pytest +from aiwebexplorer.agent_factory import get_agent_dependency from aiwebexplorer.interfaces import AgentInterfaceError from aiwebexplorer.webexplorer import WebExplorer +from tests.dependencie import get_agent_mock class TestTextExtraction: @@ -11,7 +13,8 @@ class TestTextExtraction: def setup_method(self): """Set up test fixtures.""" - 
self.webexplorer = WebExplorer() + with get_agent_dependency.override(get_agent_mock): + self.webexplorer = WebExplorer() def test_extract_url_success(self): """Test successful URL extraction.""" @@ -166,7 +169,10 @@ class TestTextProcessing: def setup_method(self): """Set up test fixtures.""" - self.webexplorer = WebExplorer() + with get_agent_dependency.override(get_agent_mock): + from aiwebexplorer.webexplorer import WebExplorer + + self.webexplorer = WebExplorer() def test_split_text_content_default(self): """Test text splitting with default parameters.""" diff --git a/tests/test_webexplorer_integration.py b/tests/test_webexplorer_integration.py index 5bb33d7..959e482 100644 --- a/tests/test_webexplorer_integration.py +++ b/tests/test_webexplorer_integration.py @@ -1,10 +1,12 @@ """Integration tests for WebExplorer with mocked dependencies.""" +from typing import Any from unittest.mock import AsyncMock, Mock, patch import pytest -from aiwebexplorer.interfaces import AgentInterfaceError +from aiwebexplorer.agent_factory import get_agent_dependency +from aiwebexplorer.interfaces import AgentInterfaceError, IAgent from aiwebexplorer.webexplorer import WebExplorer, WebExplorerResponse @@ -18,12 +20,20 @@ def setup_method(self): self.mock_extraction_agent = AsyncMock() self.mock_finalizer_agent = AsyncMock() - # Create WebExplorer with mocked agents - self.webexplorer = WebExplorer( - _evaluate_request_agent=self.mock_evaluation_agent, - _extraction_agent=self.mock_extraction_agent, - _finalizer_agent=self.mock_finalizer_agent, - ) + # Create a mock agent factory that returns the appropriate mock agent + def mock_agent_factory(name: str, **kwargs: Any) -> IAgent[Any]: + if name == "evaluate_request_agent": + return self.mock_evaluation_agent + elif name == "extraction_agent": + return self.mock_extraction_agent + elif name == "finalizer_agent": + return self.mock_finalizer_agent + else: + return Mock(spec=IAgent[Any]) + + # Override the agent dependency with our mock factory + with get_agent_dependency.override(mock_agent_factory): + self.webexplorer = WebExplorer() @pytest.mark.asyncio async def test_arun_happy_path(self):
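
---

A note on the override pattern used in the `setup_method` hunks above: `Dependency.override` only swaps the registry entry while the `with` block is active, so it works here because `WebExplorer.__init__` resolves all three agents eagerly and the mocks are captured before the context exits. If an override needs to stay active for the whole test body, a pytest fixture can hold the context open instead. A minimal sketch of that alternative — the fixture name and the canned `arun` return value are illustrative, not part of this change:

```python
import pytest
from unittest.mock import AsyncMock

from aiwebexplorer.agent_factory import get_agent_dependency
from aiwebexplorer.webexplorer import WebExplorer


@pytest.fixture
def explorer_with_mocks():
    """Yield a WebExplorer whose agents stay mocked for the fixture's lifetime."""

    def get_agent_mock(*args, **kwargs):
        # Every agent resolved through get_agent_dependency becomes an AsyncMock.
        agent = AsyncMock()
        agent.arun.return_value = "[RESULT]: OK"
        return agent

    # The override is restored automatically when the generator resumes,
    # i.e. after the test that consumed this fixture has finished.
    with get_agent_dependency.override(get_agent_mock):
        yield WebExplorer()
```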