diff --git a/swh/lister/save_bulk/lister.py b/swh/lister/save_bulk/lister.py index f57766fea2d373bcd5635c93236724569dcb165c..f0844c0ef357bc40dc0d2a8d9102d7654b86af5a 100644 --- a/swh/lister/save_bulk/lister.py +++ b/swh/lister/save_bulk/lister.py @@ -3,11 +3,12 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import asdict, dataclass, field from http import HTTPStatus import logging import socket -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypedDict +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypedDict, Union from urllib.parse import quote, urlparse from breezy.builtins import cmd_info @@ -280,6 +281,7 @@ class SaveBulkLister(Lister[SaveBulkListerState, SaveBulkListerPage]): max_pages: Optional[int] = None, enable_origins: bool = True, per_page: int = 1000, + max_workers: int = 4, ): super().__init__( scheduler=scheduler, @@ -293,6 +295,7 @@ class SaveBulkLister(Lister[SaveBulkListerState, SaveBulkListerPage]): ) self.rejected_origins: Set[RejectedOrigin] = set() self.per_page = per_page + self.executor = ThreadPoolExecutor(max_workers=max_workers) def state_from_dict(self, d: Dict[str, Any]) -> SaveBulkListerState: return SaveBulkListerState( @@ -316,103 +319,112 @@ class SaveBulkLister(Lister[SaveBulkListerState, SaveBulkListerPage]): self.url, params={"page": current_page, "per_page": self.per_page} ).json() - def get_origins_from_page( - self, origins: SaveBulkListerPage - ) -> Iterator[ListedOrigin]: - assert self.lister_obj.id is not None + def check_origin( + self, origin_url: str, visit_type: str + ) -> Union[ListedOrigin, RejectedOrigin]: + logger.info("Checking origin URL %s for visit type %s.", origin_url, visit_type) - for origin in origins: - origin_url = origin["origin_url"] - visit_type = origin["visit_type"] + assert self.lister_obj.id - logger.info( - "Checking origin URL %s for visit type %s.", origin_url, visit_type - ) - - rejection_details = None - rejection_exception = None + rejection_details = None + rejection_exception = None - parsed_url = urlparse(origin_url) - if rejection_details is None: - if parsed_url.scheme in ("http", "https"): - try: - response = self.session.head(origin_url, allow_redirects=True) - response.raise_for_status() - except ConnectionError as e: + parsed_url = urlparse(origin_url) + if rejection_details is None: + if parsed_url.scheme in ("http", "https"): + try: + self.http_request(origin_url, method="HEAD", allow_redirects=True) + except ConnectionError as e: + logger.info( + "A connection error occurred when requesting %s.", + origin_url, + ) + rejection_details = CONNECTION_ERROR + rejection_exception = str(e) + except RequestException as e: + if e.response is not None: + status = e.response.status_code + status_str = f"{status} - {HTTPStatus(status).phrase}" logger.info( - "A connection error occurred when requesting %s.", + "An HTTP error occurred when requesting %s: %s", origin_url, + status_str, ) - rejection_details = CONNECTION_ERROR - rejection_exception = str(e) - except RequestException as e: - if e.response is not None: - status = e.response.status_code - status_str = f"{status} - {HTTPStatus(status).phrase}" - logger.info( - "An HTTP error occurred when requesting %s: %s", - origin_url, - status_str, - ) - rejection_details = f"{HTTP_ERROR}: {status_str}" - else: - logger.info( - "An HTTP error occurred when requesting %s.", - origin_url, - ) - rejection_details = f"{HTTP_ERROR}." - rejection_exception = str(e) - else: - try: - socket.getaddrinfo(parsed_url.netloc, port=None) - except OSError as e: + rejection_details = f"{HTTP_ERROR}: {status_str}" + else: logger.info( - "Host name %s could not be resolved.", parsed_url.netloc + "An HTTP error occurred when requesting %s.", + origin_url, ) - rejection_details = HOSTNAME_ERROR - rejection_exception = str(e) - - if rejection_details is None: - visit_type_check_url = globals().get( - f"is_valid_{visit_type.split('-', 1)[0]}_url" - ) - if visit_type_check_url: - url_valid, rejection_exception = visit_type_check_url(origin_url) - if not url_valid: - rejection_details = VISIT_TYPE_ERROR[visit_type] - else: - rejection_details = ( - f"Visit type {visit_type} is not supported " - "for bulk on-demand archival." - ) + rejection_details = f"{HTTP_ERROR}." + rejection_exception = str(e) + else: + try: + socket.getaddrinfo(parsed_url.netloc, port=None) + except OSError as e: logger.info( - "Visit type %s for origin URL %s is not supported", - visit_type, - origin_url, + "Host name %s could not be resolved.", parsed_url.netloc ) + rejection_details = HOSTNAME_ERROR + rejection_exception = str(e) - if rejection_details is None: - yield ListedOrigin( - lister_id=self.lister_obj.id, - url=origin["origin_url"], - visit_type=origin["visit_type"], - extra_loader_arguments=( - {"checksum_layout": "standard", "checksums": {}} - if origin["visit_type"] == "tarball-directory" - else {} - ), - ) + if rejection_details is None: + visit_type_check_url = globals().get( + f"is_valid_{visit_type.split('-', 1)[0]}_url" + ) + if visit_type_check_url: + url_valid, rejection_exception = visit_type_check_url(origin_url) + if not url_valid: + rejection_details = VISIT_TYPE_ERROR[visit_type] else: - self.rejected_origins.add( - RejectedOrigin( - origin_url=origin_url, - visit_type=visit_type, - reason=rejection_details, - exception=rejection_exception, - ) + rejection_details = ( + f"Visit type {visit_type} is not supported " + "for bulk on-demand archival." + ) + logger.info( + "Visit type %s for origin URL %s is not supported", + visit_type, + origin_url, ) - # update scheduler state at each rejected origin to get feedback - # using Web API before end of listing - self.state.rejected_origins = list(self.rejected_origins) - self.updated = True - self.set_state_in_scheduler() + + if rejection_details is None: + return ListedOrigin( + lister_id=self.lister_obj.id, + url=origin_url, + visit_type=visit_type, + extra_loader_arguments=( + {"checksum_layout": "standard", "checksums": {}} + if visit_type == "tarball-directory" + else {} + ), + ) + else: + return RejectedOrigin( + origin_url=origin_url, + visit_type=visit_type, + reason=rejection_details, + exception=rejection_exception, + ) + + def get_origins_from_page( + self, origins: SaveBulkListerPage + ) -> Iterator[ListedOrigin]: + assert self.lister_obj.id is not None + + for future in as_completed( + self.executor.submit( + self.check_origin, origin["origin_url"], origin["visit_type"] + ) + for origin in origins + ): + match origin := future.result(): + case ListedOrigin(): + yield origin + case RejectedOrigin(): + self.rejected_origins.add(origin) + + # update scheduler state after each processed page to get feedback + # using Web API before end of listing + self.state.rejected_origins = list(self.rejected_origins) + self.updated = True + self.set_state_in_scheduler() diff --git a/swh/lister/save_bulk/tests/test_lister.py b/swh/lister/save_bulk/tests/test_lister.py index 5b3a8e6dca39c063b31bb30a2b7058d6b29955f8..2bc7551667a31995744511dc149000a79f9499f3 100644 --- a/swh/lister/save_bulk/tests/test_lister.py +++ b/swh/lister/save_bulk/tests/test_lister.py @@ -165,9 +165,9 @@ def test_bulk_lister_not_found_origins(swh_scheduler, requests_mock, mocker): ) ) - # check scheduler state is updated at each not found origin + # check scheduler state is updated after each page # plus at the end of the listing process to set the termination date - expected_calls = [mocker.call()] * len(SUBMITTED_ORIGINS) + expected_calls = [mocker.call()] * stats.pages expected_calls.append(mocker.call(with_listing_finished_date=True)) assert set_state_in_scheduler.mock_calls == expected_calls @@ -219,9 +219,9 @@ def test_bulk_lister_connection_errors(swh_scheduler, requests_mock, mocker): ) ) - # check scheduler state is updated at each origin connection error + # check scheduler state is updated after each page # plus at the end of the listing process to set the termination date - expected_calls = [mocker.call()] * len(SUBMITTED_ORIGINS) + expected_calls = [mocker.call()] * stats.pages expected_calls.append(mocker.call(with_listing_finished_date=True)) assert set_state_in_scheduler.mock_calls == expected_calls @@ -286,9 +286,9 @@ def test_bulk_lister_invalid_origins(swh_scheduler, requests_mock, mocker): ) ) - # check scheduler state is updated at each invalid origin + # check scheduler state is updated after each page # plus at the end of the listing process to set the termination date - expected_calls = [mocker.call()] * (len(SUBMITTED_ORIGINS) - 1) + expected_calls = [mocker.call()] * stats.pages expected_calls.append(mocker.call(with_listing_finished_date=True)) assert set_state_in_scheduler.mock_calls == expected_calls