From af24960bc22d098eadec1740238168a6f20babe3 Mon Sep 17 00:00:00 2001 From: Antoine Lambert <anlambert@softwareheritage.org> Date: Wed, 22 May 2024 17:42:56 +0200 Subject: [PATCH] Add save-bulk lister to check origins prior their insertion in database This new and special lister enables to verify a list of origins to archive provided by users (for instance through the Web API). Its purpose is to avoid polluting the scheduler database with origins that cannot be loaded into the archive. Each origin is identified by an URL and a visit type. For a given visit type the lister is checking if the origin URL can be found and if the visit type is valid. The supported visit types are those for VCS (bzr, cvs, hg, git and svn) plus the one for loading a tarball content into the archive. Accepted origins are inserted or upserted in the scheduler database. Rejected origins are stored in the lister state. Related to #4709 --- mypy.ini | 6 + pyproject.toml | 1 + requirements.txt | 3 + swh/lister/save_bulk/__init__.py | 13 + swh/lister/save_bulk/lister.py | 416 ++++++++++++++++++++++ swh/lister/save_bulk/tasks.py | 19 + swh/lister/save_bulk/tests/__init__.py | 0 swh/lister/save_bulk/tests/test_lister.py | 263 ++++++++++++++ swh/lister/save_bulk/tests/test_tasks.py | 38 ++ swh/lister/tests/test_cli.py | 4 + 10 files changed, 763 insertions(+) create mode 100644 swh/lister/save_bulk/__init__.py create mode 100644 swh/lister/save_bulk/lister.py create mode 100644 swh/lister/save_bulk/tasks.py create mode 100644 swh/lister/save_bulk/tests/__init__.py create mode 100644 swh/lister/save_bulk/tests/test_lister.py create mode 100644 swh/lister/save_bulk/tests/test_tasks.py diff --git a/mypy.ini b/mypy.ini index da4d2d88..56a74240 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,6 +28,9 @@ ignore_missing_imports = True [mypy-lxml.*] ignore_missing_imports = True +[mypy-mercurial.*] +ignore_missing_imports = True + [mypy-pandas.*] ignore_missing_imports = True @@ -61,6 +64,9 @@ ignore_missing_imports = True [mypy-repomd.*] ignore_missing_imports = True +[mypy-subvertpy.*] +ignore_missing_imports = True + [mypy-defusedxml.*] ignore_missing_imports = True # [mypy-add_your_lib_here.*] diff --git a/pyproject.toml b/pyproject.toml index 0cff67c3..19ef992f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ testing = {file = ["requirements-test.txt"]} "lister.bioconductor" = "swh.lister.bioconductor:register" "lister.bitbucket" = "swh.lister.bitbucket:register" "lister.bower" = "swh.lister.bower:register" +"lister.save-bulk" = "swh.lister.save_bulk:register" "lister.cgit" = "swh.lister.cgit:register" "lister.conda" = "swh.lister.conda:register" "lister.cpan" = "swh.lister.cpan:register" diff --git a/requirements.txt b/requirements.txt index dead79e9..be4c6fc6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,19 @@ beautifulsoup4 +breezy >= 3.3.1, < 3.3.5 # use versions with available binary wheels dateparser dulwich iso8601 launchpadlib looseversion lxml +mercurial psycopg2 pyreadr python_debian repomd requests setuptools +subvertpy tenacity >= 8.4.2 testing.postgresql toml diff --git a/swh/lister/save_bulk/__init__.py b/swh/lister/save_bulk/__init__.py new file mode 100644 index 00000000..da37a991 --- /dev/null +++ b/swh/lister/save_bulk/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + + +def register(): + from .lister import SaveBulkLister + + return { + "lister": SaveBulkLister, + "task_modules": [f"{__name__}.tasks"], + } diff --git a/swh/lister/save_bulk/lister.py b/swh/lister/save_bulk/lister.py new file mode 100644 index 00000000..4b0e2f53 --- /dev/null +++ b/swh/lister/save_bulk/lister.py @@ -0,0 +1,416 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from dataclasses import asdict, dataclass, field +from http import HTTPStatus +import logging +import socket +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypedDict +from urllib.parse import quote, urlparse + +from breezy.builtins import cmd_info +from dulwich.porcelain import ls_remote +from mercurial import hg, ui +from requests import ConnectionError, RequestException +from subvertpy import SubversionException, client +from subvertpy.ra import Auth, get_username_provider + +from swh.lister.utils import is_tarball +from swh.scheduler.interface import SchedulerInterface +from swh.scheduler.model import ListedOrigin + +from ..pattern import CredentialsType, Lister + +logger = logging.getLogger(__name__) + + +def _log_invalid_origin_type_for_url( + origin_url: str, origin_type: str, err_msg: Optional[str] = None +): + msg = f"Origin URL {origin_url} does not target a {origin_type}." + if err_msg: + msg += f"\nError details: {err_msg}" + logger.info(msg) + + +def is_valid_tarball_url(origin_url: str) -> Tuple[bool, Optional[str]]: + """Checks if an URL targets a tarball using a set of heuritiscs. + + Args: + origin_url: The URL to check + + Returns: + a tuple whose first member indicates if the URL targets a tarball and + second member holds an optional error message if check failed + """ + exc_str = None + try: + ret, _ = is_tarball([origin_url]) + except Exception as e: + ret = False + exc_str = str(e) + if not ret: + _log_invalid_origin_type_for_url(origin_url, "tarball", exc_str) + return ret, exc_str + + +def is_valid_git_url(origin_url: str) -> Tuple[bool, Optional[str]]: + """Check if an URL targets a public git repository by attempting to list + its remote refs. + + Args: + origin_url: The URL to check + + Returns: + a tuple whose first member indicates if the URL targets a public git + repository and second member holds an error message if check failed + """ + try: + ls_remote(origin_url) + except Exception as e: + exc_str = str(e) + _log_invalid_origin_type_for_url(origin_url, "public git repository", exc_str) + return False, exc_str + else: + return True, None + + +def is_valid_svn_url(origin_url: str) -> Tuple[bool, Optional[str]]: + """Check if an URL targets a public subversion repository by attempting to get + repository information. + + Args: + origin_url: The URL to check + + Returns: + a tuple whose first member indicates if the URL targets a public subversion + repository and second member holds an error message if check failed + """ + svn_client = client.Client(auth=Auth([get_username_provider()])) + try: + svn_client.info(quote(origin_url, safe="/:!$&'()*+,=@").rstrip("/")) + except SubversionException as e: + exc_str = str(e) + _log_invalid_origin_type_for_url( + origin_url, "public subversion repository", exc_str + ) + return False, exc_str + else: + return True, None + + +def is_valid_hg_url(origin_url: str) -> Tuple[bool, Optional[str]]: + """Check if an URL targets a public mercurial repository by attempting to connect + to the remote repository. + + Args: + origin_url: The URL to check + + Returns: + a tuple whose first member indicates if the URL targets a public mercurial + repository and second member holds an error message if check failed + """ + hgui = ui.ui() + hgui.setconfig(b"ui", b"interactive", False) + try: + hg.peer(hgui, {}, origin_url.encode()) + except Exception as e: + exc_str = str(e) + _log_invalid_origin_type_for_url( + origin_url, "public mercurial repository", exc_str + ) + return False, exc_str + else: + return True, None + + +def is_valid_bzr_url(origin_url: str) -> Tuple[bool, Optional[str]]: + """Check if an URL targets a public bazaar repository by attempting to get + repository information. + + Args: + origin_url: The URL to check + + Returns: + a tuple whose first member indicates if the URL targets a public bazaar + repository and second member holds an error message if check failed + """ + try: + cmd_info().run_argv_aliases([origin_url]) + except Exception as e: + exc_str = str(e) + _log_invalid_origin_type_for_url( + origin_url, "public bazaar repository", exc_str + ) + return False, exc_str + else: + return True, None + + +def is_valid_cvs_url(origin_url: str) -> Tuple[bool, Optional[str]]: + """Check if an URL matches one of the formats expected by the CVS loader of + Software Heritage. + + Args: + origin_url: The URL to check + + Returns: + a tuple whose first member indicates if the URL matches one of the formats + expected by the CVS loader and second member holds an error message if + check failed. + """ + err_msg = None + rsync_url_format = "rsync://<hostname>[.*/]<project_name>/<module_name>" + pserver_url_format = ( + "pserver://<usernmame>@<hostname>[.*/]<project_name>/<module_name>" + ) + err_msg_prefix = ( + "The origin URL for the CVS repository is malformed, it should match" + ) + + parsed_url = urlparse(origin_url) + ret = ( + parsed_url.scheme in ("rsync", "pserver") + and len(parsed_url.path.strip("/").split("/")) >= 2 + ) + if parsed_url.scheme == "rsync": + if not ret: + err_msg = f"{err_msg_prefix} '{rsync_url_format}'" + elif parsed_url.scheme == "pserver": + ret = ret and parsed_url.username is not None + if not ret: + err_msg = f"{err_msg_prefix} '{pserver_url_format}'" + else: + err_msg = f"{err_msg_prefix} '{rsync_url_format}' or '{pserver_url_format}'" + + if not ret: + _log_invalid_origin_type_for_url(origin_url, "CVS", err_msg) + + return ret, err_msg + + +CONNECTION_ERROR = "A connection error occurred when requesting origin URL." +HTTP_ERROR = "An HTTP error occurred when requesting origin URL" +HOSTNAME_ERROR = "The hostname could not be resolved." + + +VISIT_TYPE_ERROR: Dict[str, str] = { + "tarball-directory": "The origin URL does not target a tarball.", + "git": "The origin URL does not target a public git repository.", + "svn": "The origin URL does not target a public subversion repository.", + "hg": "The origin URL does not target a public mercurial repository.", + "bzr": "The origin URL does not target a public bazaar repository.", + "cvs": "The origin URL does not target a public CVS repository.", +} + + +class SubmittedOrigin(TypedDict): + origin_url: str + visit_type: str + + +@dataclass(frozen=True) +class RejectedOrigin: + origin_url: str + visit_type: str + reason: str + exception: Optional[str] + + +@dataclass +class SaveBulkListerState: + """Stored lister state""" + + rejected_origins: List[RejectedOrigin] = field(default_factory=list) + """ + List of origins rejected by the lister. + """ + + +SaveBulkListerPage = List[SubmittedOrigin] + + +class SaveBulkLister(Lister[SaveBulkListerState, SaveBulkListerPage]): + """The save-bulk lister enables to verify a list of origins to archive provided + by an HTTP endpoint. Its purpose is to avoid polluting the scheduler database with + origins that cannot be loaded into the archive. + + Each origin is identified by an URL and a visit type. For a given visit type the + lister is checking if the origin URL can be found and if the visit type is valid. + + The HTTP endpoint must return an origins list in a paginated way through the use + of two integer query parameters: ``page`` indicates the page to fetch and `per_page` + corresponds the number of origins in a page. + The endpoint must return a JSON list in the following format: + + .. code-block:: JSON + + [ + { + "origin_url": "https://git.example.org/user/project", + "visit_type": "git" + }, + { + "origin_url": "https://example.org/downloads/project.tar.gz", + "visit_type": "tarball-directory" + } + ] + + + The supported visit types are those for VCS (``bzr``, ``cvs``, ``hg``, ``git`` + and ``svn``) plus the one for loading a tarball content into the archive + (``tarball-directory``). + + Accepted origins are inserted or upserted in the scheduler database. + + Rejected origins are stored in the lister state. + """ + + LISTER_NAME = "save-bulk" + + def __init__( + self, + url: str, + instance: str, + scheduler: SchedulerInterface, + credentials: Optional[CredentialsType] = None, + max_origins_per_page: Optional[int] = None, + max_pages: Optional[int] = None, + enable_origins: bool = True, + per_page: int = 1000, + ): + super().__init__( + scheduler=scheduler, + credentials=credentials, + url=url, + instance=instance, + max_origins_per_page=max_origins_per_page, + max_pages=max_pages, + enable_origins=enable_origins, + ) + self.rejected_origins: Set[RejectedOrigin] = set() + self.per_page = per_page + + def state_from_dict(self, d: Dict[str, Any]) -> SaveBulkListerState: + return SaveBulkListerState( + rejected_origins=[ + RejectedOrigin(**rej) for rej in d.get("rejected_origins", []) + ] + ) + + def state_to_dict(self, state: SaveBulkListerState) -> Dict[str, Any]: + return {"rejected_origins": [asdict(rej) for rej in state.rejected_origins]} + + def get_pages(self) -> Iterator[SaveBulkListerPage]: + current_page = 1 + origins = self.session.get( + self.url, params={"page": current_page, "per_page": self.per_page} + ).json() + while origins: + yield origins + current_page += 1 + origins = self.session.get( + self.url, params={"page": current_page, "per_page": self.per_page} + ).json() + + def get_origins_from_page( + self, origins: SaveBulkListerPage + ) -> Iterator[ListedOrigin]: + assert self.lister_obj.id is not None + + for origin in origins: + origin_url = origin["origin_url"] + visit_type = origin["visit_type"] + + logger.info( + "Checking origin URL %s for visit type %s.", origin_url, visit_type + ) + + rejection_details = None + rejection_exception = None + + parsed_url = urlparse(origin_url) + if rejection_details is None: + if parsed_url.scheme in ("http", "https"): + try: + response = self.session.head(origin_url, allow_redirects=True) + response.raise_for_status() + except ConnectionError as e: + logger.info( + "A connection error occurred when requesting %s.", + origin_url, + ) + rejection_details = CONNECTION_ERROR + rejection_exception = str(e) + except RequestException as e: + if e.response is not None: + status = e.response.status_code + status_str = f"{status} - {HTTPStatus(status).phrase}" + logger.info( + "An HTTP error occurred when requesting %s: %s", + origin_url, + status_str, + ) + rejection_details = f"{HTTP_ERROR}: {status_str}" + else: + logger.info( + "An HTTP error occurred when requesting %s.", + origin_url, + ) + rejection_details = f"{HTTP_ERROR}." + rejection_exception = str(e) + else: + try: + socket.getaddrinfo(parsed_url.netloc, port=None) + except OSError as e: + logger.info( + "Host name %s could not be resolved.", parsed_url.netloc + ) + rejection_details = HOSTNAME_ERROR + rejection_exception = str(e) + + if rejection_details is None: + visit_type_check_url = globals().get( + f"is_valid_{visit_type.split('-', 1)[0]}_url" + ) + if visit_type_check_url: + url_valid, rejection_exception = visit_type_check_url(origin_url) + if not url_valid: + rejection_details = VISIT_TYPE_ERROR[visit_type] + else: + rejection_details = ( + f"Visit type {visit_type} is not supported " + "for bulk on-demand archival." + ) + logger.info( + "Visit type %s for origin URL %s is not supported", + visit_type, + origin_url, + ) + + if rejection_details is None: + yield ListedOrigin( + lister_id=self.lister_obj.id, + url=origin["origin_url"], + visit_type=origin["visit_type"], + extra_loader_arguments=( + {"checksum_layout": "standard", "checksums": {}} + if origin["visit_type"] == "tarball-directory" + else {} + ), + ) + else: + self.rejected_origins.add( + RejectedOrigin( + origin_url=origin_url, + visit_type=visit_type, + reason=rejection_details, + exception=rejection_exception, + ) + ) + # update scheduler state at each rejected origin to get feedback + # using Web API before end of listing + self.state.rejected_origins = list(self.rejected_origins) + self.set_state_in_scheduler() diff --git a/swh/lister/save_bulk/tasks.py b/swh/lister/save_bulk/tasks.py new file mode 100644 index 00000000..17c9d0cf --- /dev/null +++ b/swh/lister/save_bulk/tasks.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from celery import shared_task + +from swh.lister.save_bulk.lister import SaveBulkLister + + +@shared_task(name=__name__ + ".SaveBulkListerTask") +def list_save_bulk(**kwargs): + """Task for save-bulk lister""" + return SaveBulkLister.from_configfile(**kwargs).run().dict() + + +@shared_task(name=__name__ + ".ping") +def _ping(): + return "OK" diff --git a/swh/lister/save_bulk/tests/__init__.py b/swh/lister/save_bulk/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/swh/lister/save_bulk/tests/test_lister.py b/swh/lister/save_bulk/tests/test_lister.py new file mode 100644 index 00000000..30594d71 --- /dev/null +++ b/swh/lister/save_bulk/tests/test_lister.py @@ -0,0 +1,263 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from operator import attrgetter, itemgetter +import re +import string + +import pytest +import requests + +from swh.lister.pattern import ListerStats +from swh.lister.save_bulk.lister import ( + CONNECTION_ERROR, + HOSTNAME_ERROR, + HTTP_ERROR, + VISIT_TYPE_ERROR, + RejectedOrigin, + SaveBulkLister, + SubmittedOrigin, + is_valid_cvs_url, +) + +URL = "https://example.org/origins/list/" +INSTANCE = "some-instance" + +PER_PAGE = 2 + +SUBMITTED_ORIGINS = [ + SubmittedOrigin(origin_url=origin_url, visit_type=visit_type) + for origin_url, visit_type in [ + ("https://example.org/download/tarball.tar.gz", "tarball-directory"), + ("https://git.example.org/user/project.git", "git"), + ("https://svn.example.org/project/trunk", "svn"), + ("https://hg.example.org/projects/test", "hg"), + ("https://bzr.example.org/projects/test", "bzr"), + ("rsync://cvs.example.org/cvsroot/project/module", "cvs"), + ] +] + + +@pytest.fixture(autouse=True) +def origins_list_requests_mock(requests_mock): + nb_pages = len(SUBMITTED_ORIGINS) // PER_PAGE + for i in range(nb_pages): + requests_mock.get( + f"{URL}?page={i+1}&per_page={PER_PAGE}", + json=SUBMITTED_ORIGINS[i * PER_PAGE : (i + 1) * PER_PAGE], + ) + requests_mock.get( + f"{URL}?page={nb_pages+1}&per_page={PER_PAGE}", + json=[], + ) + + +@pytest.mark.parametrize( + "valid_cvs_url", + [ + "rsync://cvs.example.org/project/module", + "pserver://anonymous@cvs.example.org/project/module", + ], +) +def test_is_valid_cvs_url_success(valid_cvs_url): + assert is_valid_cvs_url(valid_cvs_url) == (True, None) + + +@pytest.mark.parametrize( + "invalid_cvs_url", + [ + "rsync://cvs.example.org/project", + "pserver://anonymous@cvs.example.org/project", + "pserver://cvs.example.org/project/module", + "http://cvs.example.org/project/module", + ], +) +def test_is_valid_cvs_url_failure(invalid_cvs_url): + err_msg_prefix = "The origin URL for the CVS repository is malformed" + ret, err_msg = is_valid_cvs_url(invalid_cvs_url) + assert not ret and err_msg.startswith(err_msg_prefix) + + +def test_bulk_lister_valid_origins(swh_scheduler, requests_mock, mocker): + requests_mock.head(re.compile(".*"), status_code=200) + mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").return_value = [ + ("125.25.14.15", 0) + ] + for origin in SUBMITTED_ORIGINS: + visit_type = origin["visit_type"].split("-", 1)[0] + mocker.patch( + f"swh.lister.save_bulk.lister.is_valid_{visit_type}_url" + ).return_value = (True, None) + + lister_bulk = SaveBulkLister( + url=URL, + instance=INSTANCE, + scheduler=swh_scheduler, + per_page=PER_PAGE, + ) + stats = lister_bulk.run() + + expected_nb_origins = len(SUBMITTED_ORIGINS) + assert stats == ListerStats( + pages=expected_nb_origins // PER_PAGE, origins=expected_nb_origins + ) + + state = lister_bulk.get_state_from_scheduler() + + assert sorted( + [ + SubmittedOrigin(origin_url=origin.url, visit_type=origin.visit_type) + for origin in swh_scheduler.get_listed_origins( + lister_bulk.lister_obj.id + ).results + ], + key=itemgetter("visit_type"), + ) == sorted(SUBMITTED_ORIGINS, key=itemgetter("visit_type")) + assert state.rejected_origins == [] + + +def test_bulk_lister_not_found_origins(swh_scheduler, requests_mock, mocker): + requests_mock.head(re.compile(".*"), status_code=404) + mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").side_effect = ( + OSError("Hostname not found") + ) + + lister_bulk = SaveBulkLister( + url=URL, + instance=INSTANCE, + scheduler=swh_scheduler, + per_page=PER_PAGE, + ) + stats = lister_bulk.run() + + assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=0) + + state = lister_bulk.get_state_from_scheduler() + + assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list( + sorted( + [ + RejectedOrigin( + origin_url=o["origin_url"], + visit_type=o["visit_type"], + reason=( + HTTP_ERROR + ": 404 - Not Found" + if o["origin_url"].startswith("http") + else HOSTNAME_ERROR + ), + exception=( + f"404 Client Error: None for url: {o['origin_url']}" + if o["origin_url"].startswith("http") + else "Hostname not found" + ), + ) + for o in SUBMITTED_ORIGINS + ], + key=attrgetter("origin_url"), + ) + ) + + +def test_bulk_lister_connection_errors(swh_scheduler, requests_mock, mocker): + requests_mock.head( + re.compile(".*"), + exc=requests.exceptions.ConnectionError("connection error"), + ) + mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").side_effect = ( + OSError("Hostname not found") + ) + + lister_bulk = SaveBulkLister( + url=URL, + instance=INSTANCE, + scheduler=swh_scheduler, + per_page=PER_PAGE, + ) + stats = lister_bulk.run() + + assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=0) + + state = lister_bulk.get_state_from_scheduler() + + assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list( + sorted( + [ + RejectedOrigin( + origin_url=o["origin_url"], + visit_type=o["visit_type"], + reason=( + CONNECTION_ERROR + if o["origin_url"].startswith("http") + else HOSTNAME_ERROR + ), + exception=( + "connection error" + if o["origin_url"].startswith("http") + else "Hostname not found" + ), + ) + for o in SUBMITTED_ORIGINS + ], + key=attrgetter("origin_url"), + ) + ) + + +def test_bulk_lister_invalid_origins(swh_scheduler, requests_mock, mocker): + requests_mock.head(re.compile(".*"), status_code=200) + mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").return_value = [ + ("125.25.14.15", 0) + ] + + exc_msg_template = string.Template( + "error: the origin url does not target a public $visit_type repository." + ) + for origin in SUBMITTED_ORIGINS: + visit_type = origin["visit_type"].split("-", 1)[0] + visit_type_check = mocker.patch( + f"swh.lister.save_bulk.lister.is_valid_{visit_type}_url" + ) + if visit_type == "tarball": + visit_type_check.return_value = (True, None) + else: + visit_type_check.return_value = ( + False, + exc_msg_template.substitute(visit_type=visit_type), + ) + + lister_bulk = SaveBulkLister( + url=URL, + instance=INSTANCE, + scheduler=swh_scheduler, + per_page=PER_PAGE, + ) + stats = lister_bulk.run() + + assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=1) + + assert [ + SubmittedOrigin(origin_url=origin.url, visit_type=origin.visit_type) + for origin in swh_scheduler.get_listed_origins( + lister_bulk.lister_obj.id + ).results + ] == [SUBMITTED_ORIGINS[0]] + + state = lister_bulk.get_state_from_scheduler() + + assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list( + sorted( + [ + RejectedOrigin( + origin_url=o["origin_url"], + visit_type=o["visit_type"], + reason=VISIT_TYPE_ERROR[o["visit_type"]], + exception=exc_msg_template.substitute(visit_type=o["visit_type"]), + ) + for o in SUBMITTED_ORIGINS + if o["visit_type"] != "tarball-directory" + ], + key=attrgetter("origin_url"), + ) + ) diff --git a/swh/lister/save_bulk/tests/test_tasks.py b/swh/lister/save_bulk/tests/test_tasks.py new file mode 100644 index 00000000..d8564e66 --- /dev/null +++ b/swh/lister/save_bulk/tests/test_tasks.py @@ -0,0 +1,38 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from swh.lister.pattern import ListerStats + + +def test_save_bulk_ping(swh_scheduler_celery_app, swh_scheduler_celery_worker): + res = swh_scheduler_celery_app.send_task("swh.lister.save_bulk.tasks.ping") + assert res + res.wait() + assert res.successful() + assert res.result == "OK" + + +def test_save_bulk_lister_task( + swh_scheduler_celery_app, swh_scheduler_celery_worker, mocker +): + lister = mocker.patch("swh.lister.save_bulk.tasks.SaveBulkLister") + lister.from_configfile.return_value = lister + lister.run.return_value = ListerStats(pages=1, origins=2) + + kwargs = dict( + url="https://example.org/origins/list/", + instance="some-instance", + ) + + res = swh_scheduler_celery_app.send_task( + "swh.lister.save_bulk.tasks.SaveBulkListerTask", + kwargs=kwargs, + ) + assert res + res.wait() + assert res.successful() + + lister.from_configfile.assert_called_once_with(**kwargs) + lister.run.assert_called_once() diff --git a/swh/lister/tests/test_cli.py b/swh/lister/tests/test_cli.py index a5645a61..7329424a 100644 --- a/swh/lister/tests/test_cli.py +++ b/swh/lister/tests/test_cli.py @@ -49,6 +49,10 @@ lister_args = { "stagit": { "url": "https://git.codemadness.org", }, + "save-bulk": { + "url": "https://example.org/origins/list/", + "instance": "example.org", + }, } -- GitLab