diff --git a/CHANGELOG.md b/CHANGELOG.md index ed053347..f663ac25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ - Store `dataset_id` in `tables_index` table of CSV database [#341](https://github.com/datagouv/hydra/pull/341) - Enable to analyse and insert parquet files [#342](https://github.com/datagouv/hydra/pull/342) [#346](https://github.com/datagouv/hydra/pull/346) - New `tag_version.sh` script to replace Bump'x [#342](https://github.com/datagouv/hydra/pull/342) +- Reuse HTTP connections between sessions [#325](https://github.com/datagouv/hydra/pull/325) ## 2.4.1 (2025-09-03) diff --git a/udata_hydra/crawl/check_resources.py b/udata_hydra/crawl/check_resources.py index 9cb3d572..533d7568 100644 --- a/udata_hydra/crawl/check_resources.py +++ b/udata_hydra/crawl/check_resources.py @@ -16,7 +16,7 @@ ) from udata_hydra.crawl.preprocess_check_data import preprocess_check_data from udata_hydra.db.resource import Resource -from udata_hydra.utils import queue +from udata_hydra.utils import get_http_client, queue RESOURCE_RESPONSE_STATUSES = { "OK": "ok", @@ -35,23 +35,20 @@ async def check_batch_resources(to_parse: list[Record]) -> None: """Check a batch of resources""" context.monitor().set_status("Checking resources...") tasks: list = [] - async with aiohttp.ClientSession( - timeout=None, - headers={"user-agent": config.USER_AGENT_FULL}, - ) as session: - for row in to_parse: - tasks.append( - check_resource( - url=row["url"], - resource=row, - session=session, - worker_priority="default" if row["priority"] else "low", - ) + session = await get_http_client() + for row in to_parse: + tasks.append( + check_resource( + url=row["url"], + resource=row, + session=session, + worker_priority="default" if row["priority"] else "low", ) - for task in asyncio.as_completed(tasks): - result = await task - results[result] += 1 - context.monitor().refresh(results) + ) + for task in asyncio.as_completed(tasks): + result = await task + results[result] += 1 + 
context.monitor().refresh(results) async def check_resource( diff --git a/udata_hydra/routes/checks.py b/udata_hydra/routes/checks.py index 88981240..f1fc6da9 100644 --- a/udata_hydra/routes/checks.py +++ b/udata_hydra/routes/checks.py @@ -1,16 +1,15 @@ import json from datetime import date -import aiohttp from aiohttp import web from asyncpg import Record -from udata_hydra import config, context +from udata_hydra import context from udata_hydra.crawl.check_resources import check_resource from udata_hydra.db.check import Check from udata_hydra.db.resource import Resource from udata_hydra.schemas import CheckGroupBy, CheckSchema -from udata_hydra.utils import get_request_params +from udata_hydra.utils import get_http_client, get_request_params async def get_latest_check(request: web.Request) -> web.Response: @@ -76,18 +75,15 @@ async def create_check(request: web.Request) -> web.Response: context.monitor().set_status(f'Crawling url "{url}"...') - async with aiohttp.ClientSession( - timeout=None, - headers={"user-agent": config.USER_AGENT_FULL}, - ) as session: - status: str = await check_resource( - url=url, - resource=resource, - force_analysis=force_analysis, - session=session, - worker_priority="high", - ) - context.monitor().refresh(status) + session = await get_http_client() + status: str = await check_resource( + url=url, + resource=resource, + force_analysis=force_analysis, + session=session, + worker_priority="high", + ) + context.monitor().refresh(status) check: Record | None = await Check.get_latest(url, resource_id) if not check: diff --git a/udata_hydra/utils/__init__.py b/udata_hydra/utils/__init__.py index bbb2be64..d56eeb54 100644 --- a/udata_hydra/utils/__init__.py +++ b/udata_hydra/utils/__init__.py @@ -9,7 +9,7 @@ remove_remainders, ) from .geojson import detect_geojson_from_headers -from .http import UdataPayload, get_request_params, send +from .http import UdataPayload, get_http_client, get_request_params, send from .parquet import 
detect_parquet_from_headers from .queue import enqueue from .timer import Timer @@ -28,6 +28,7 @@ "detect_geojson_from_headers", "UdataPayload", "get_request_params", + "get_http_client", "send", "detect_parquet_from_headers", "enqueue", diff --git a/udata_hydra/utils/file.py b/udata_hydra/utils/file.py index 61ddf989..c49f3225 100644 --- a/udata_hydra/utils/file.py +++ b/udata_hydra/utils/file.py @@ -11,7 +11,8 @@ import magic from udata_hydra import config -from udata_hydra.utils import IOException +from udata_hydra.utils.errors import IOException +from udata_hydra.utils.http import get_http_client log = logging.getLogger("udata-hydra") @@ -59,18 +60,15 @@ async def download_resource( i = 0 too_large, download_error = False, None try: - async with aiohttp.ClientSession( - headers={"user-agent": config.USER_AGENT_FULL}, - raise_for_status=True, - ) as session: - async with session.get(url, allow_redirects=True) as response: - async for chunk in response.content.iter_chunked(chunk_size): - if max_size_allowed is None or i * chunk_size < max_size_allowed: - tmp_file.write(chunk) - else: - too_large = True - break - i += 1 + session = await get_http_client() + async with session.get(url, allow_redirects=True) as response: + async for chunk in response.content.iter_chunked(chunk_size): + if max_size_allowed is None or i * chunk_size < max_size_allowed: + tmp_file.write(chunk) + else: + too_large = True + break + i += 1 except aiohttp.ClientResponseError as e: download_error = e finally: @@ -107,13 +105,13 @@ async def download_resource( async def download_file(url: str, fd): """Download a file from URL to a file descriptor""" - async with aiohttp.ClientSession() as session: - async with session.get(url) as resp: - while True: - chunk = await resp.content.read(1024) - if not chunk: - break - fd.write(chunk) + session = await get_http_client() + async with session.get(url) as resp: + while True: + chunk = await resp.content.read(1024) + if not chunk: + break + 
fd.write(chunk) def remove_remainders(resource_id: str, extensions: list[str]) -> None: diff --git a/udata_hydra/utils/http.py b/udata_hydra/utils/http.py index 876a0bcd..770f030f 100644 --- a/udata_hydra/utils/http.py +++ b/udata_hydra/utils/http.py @@ -10,6 +10,8 @@ log = logging.getLogger("udata-hydra") +_http_client: aiohttp.ClientSession | None = None + class UdataPayload: HYDRA_UDATA_METADATA = { @@ -84,16 +86,58 @@ async def send(dataset_id: str, resource_id: str, document: UdataPayload) -> Non "X-API-KEY": config.UDATA_URI_API_KEY, } - async with aiohttp.ClientSession() as session: - async with session.put(uri, json=document.payload, headers=headers) as resp: - # we're raising since we should be in a worker thread - if resp.status == 404: - pass - elif resp.status == 410: - raise IOException( - "Resource has been deleted on udata", resource_id=resource_id, url=uri - ) - if resp.status == 502: - raise IOException("Udata is unreachable", resource_id=resource_id, url=uri) - else: - resp.raise_for_status() + session = await get_http_client() + async with session.put(uri, json=document.payload, headers=headers) as resp: + # we're raising since we should be in a worker thread + if resp.status == 404: + pass + elif resp.status == 410: + raise IOException( + "Resource has been deleted on udata", resource_id=resource_id, url=uri + ) + if resp.status == 502: + raise IOException("Udata is unreachable", resource_id=resource_id, url=uri) + else: + resp.raise_for_status() + + +async def get_http_client( + follow_redirects: bool = True, timeout: float | None = None +) -> aiohttp.ClientSession: + """Get a shared aiohttp ClientSession instance for performance optimization. 
+ + Args: + follow_redirects: Whether to follow redirects + timeout: Request timeout in seconds + + Returns: + Shared aiohttp ClientSession instance + """ + global _http_client + + if _http_client is None or _http_client.closed: + # Create a new client session + timeout_obj = aiohttp.ClientTimeout(total=timeout) if timeout else None + + # Prepare headers + headers = {} + if config.USER_AGENT_FULL: + headers["User-Agent"] = config.USER_AGENT_FULL + + _http_client = aiohttp.ClientSession( + timeout=timeout_obj, + headers=headers, + ) + log.debug("Created new aiohttp ClientSession") + + return _http_client + + +async def close_http_client(): + """Close the shared aiohttp ClientSession instance.""" + global _http_client + + if _http_client and not _http_client.closed: + await _http_client.close() + _http_client = None + log.debug("Closed aiohttp ClientSession")