diff --git a/mypy.ini b/mypy.ini index da4d2d884e64388af62c2ee05c4693fff2b87c7e..56a74240f3c79fb545808dedf9f17d9c0f969844 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,6 +28,9 @@ ignore_missing_imports = True [mypy-lxml.*] ignore_missing_imports = True +[mypy-mercurial.*] +ignore_missing_imports = True + [mypy-pandas.*] ignore_missing_imports = True @@ -61,6 +64,9 @@ ignore_missing_imports = True [mypy-repomd.*] ignore_missing_imports = True +[mypy-subvertpy.*] +ignore_missing_imports = True + [mypy-defusedxml.*] ignore_missing_imports = True # [mypy-add_your_lib_here.*] diff --git a/pyproject.toml b/pyproject.toml index 0cff67c37605349d98711216d9c205b4cc720350..19ef992f5ffe2f416665e3a282fa1d4c390eb4e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ testing = {file = ["requirements-test.txt"]} "lister.bioconductor" = "swh.lister.bioconductor:register" "lister.bitbucket" = "swh.lister.bitbucket:register" "lister.bower" = "swh.lister.bower:register" +"lister.save-bulk" = "swh.lister.save_bulk:register" "lister.cgit" = "swh.lister.cgit:register" "lister.conda" = "swh.lister.conda:register" "lister.cpan" = "swh.lister.cpan:register" diff --git a/requirements.txt b/requirements.txt index dead79e942580b9c2c472ffcac94ee247b111074..be4c6fc62a4f652471e8d534595a81042f925fdf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,19 @@ beautifulsoup4 +breezy >= 3.3.1, < 3.3.5 # use versions with available binary wheels dateparser dulwich iso8601 launchpadlib looseversion lxml +mercurial psycopg2 pyreadr python_debian repomd requests setuptools +subvertpy tenacity >= 8.4.2 testing.postgresql toml diff --git a/swh/lister/save_bulk/__init__.py b/swh/lister/save_bulk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..da37a9915be8141a709a554db8cdbfd02559b5aa --- /dev/null +++ b/swh/lister/save_bulk/__init__.py @@ -0,0 +1,13 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + + +def register(): + from .lister import SaveBulkLister + + return { + "lister": SaveBulkLister, + "task_modules": [f"{__name__}.tasks"], + } diff --git a/swh/lister/save_bulk/lister.py b/swh/lister/save_bulk/lister.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0e2f53c1c7760f93340c4d4c1de28f7a70550a --- /dev/null +++ b/swh/lister/save_bulk/lister.py @@ -0,0 +1,416 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from dataclasses import asdict, dataclass, field +from http import HTTPStatus +import logging +import socket +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypedDict +from urllib.parse import quote, urlparse + +from breezy.builtins import cmd_info +from dulwich.porcelain import ls_remote +from mercurial import hg, ui +from requests import ConnectionError, RequestException +from subvertpy import SubversionException, client +from subvertpy.ra import Auth, get_username_provider + +from swh.lister.utils import is_tarball +from swh.scheduler.interface import SchedulerInterface +from swh.scheduler.model import ListedOrigin + +from ..pattern import CredentialsType, Lister + +logger = logging.getLogger(__name__) + + +def 
_log_invalid_origin_type_for_url(
+    origin_url: str, origin_type: str, err_msg: Optional[str] = None
+):
+    msg = f"Origin URL {origin_url} does not target a {origin_type}."
+    if err_msg:
+        msg += f"\nError details: {err_msg}"
+    logger.info(msg)
+
+
+def is_valid_tarball_url(origin_url: str) -> Tuple[bool, Optional[str]]:
+    """Check if a URL targets a tarball using a set of heuristics.
+
+    Args:
+        origin_url: The URL to check
+
+    Returns:
+        a tuple whose first member indicates if the URL targets a tarball and
+        second member holds an optional error message if the check failed
+    """
+    exc_str = None
+    try:
+        ret, _ = is_tarball([origin_url])
+    except Exception as e:
+        ret = False
+        exc_str = str(e)
+    if not ret:
+        _log_invalid_origin_type_for_url(origin_url, "tarball", exc_str)
+    return ret, exc_str
+
+
+def is_valid_git_url(origin_url: str) -> Tuple[bool, Optional[str]]:
+    """Check if a URL targets a public git repository by attempting to list
+    its remote refs.
+
+    Args:
+        origin_url: The URL to check
+
+    Returns:
+        a tuple whose first member indicates if the URL targets a public git
+        repository and second member holds an error message if the check failed
+    """
+    try:
+        ls_remote(origin_url)
+    except Exception as e:
+        exc_str = str(e)
+        _log_invalid_origin_type_for_url(origin_url, "public git repository", exc_str)
+        return False, exc_str
+    else:
+        return True, None
+
+
+def is_valid_svn_url(origin_url: str) -> Tuple[bool, Optional[str]]:
+    """Check if a URL targets a public subversion repository by attempting to get
+    repository information.
+
+    Args:
+        origin_url: The URL to check
+
+    Returns:
+        a tuple whose first member indicates if the URL targets a public subversion
+        repository and second member holds an error message if the check failed
+    """
+    svn_client = client.Client(auth=Auth([get_username_provider()]))
+    try:
+        svn_client.info(quote(origin_url, safe="/:!$&'()*+,=@").rstrip("/"))
+    except SubversionException as e:
+        exc_str = str(e)
+        _log_invalid_origin_type_for_url(
+            origin_url, "public subversion repository", exc_str
+        )
+        return False, exc_str
+    else:
+        return True, None
+
+
+def is_valid_hg_url(origin_url: str) -> Tuple[bool, Optional[str]]:
+    """Check if a URL targets a public mercurial repository by attempting to connect
+    to the remote repository.
+
+    Args:
+        origin_url: The URL to check
+
+    Returns:
+        a tuple whose first member indicates if the URL targets a public mercurial
+        repository and second member holds an error message if the check failed
+    """
+    hgui = ui.ui()
+    hgui.setconfig(b"ui", b"interactive", False)
+    try:
+        hg.peer(hgui, {}, origin_url.encode())
+    except Exception as e:
+        exc_str = str(e)
+        _log_invalid_origin_type_for_url(
+            origin_url, "public mercurial repository", exc_str
+        )
+        return False, exc_str
+    else:
+        return True, None
+
+
+def is_valid_bzr_url(origin_url: str) -> Tuple[bool, Optional[str]]:
+    """Check if a URL targets a public bazaar repository by attempting to get
+    repository information.
+
+    Args:
+        origin_url: The URL to check
+
+    Returns:
+        a tuple whose first member indicates if the URL targets a public bazaar
+        repository and second member holds an error message if the check failed
+    """
+    try:
+        cmd_info().run_argv_aliases([origin_url])
+    except Exception as e:
+        exc_str = str(e)
+        _log_invalid_origin_type_for_url(
+            origin_url, "public bazaar repository", exc_str
+        )
+        return False, exc_str
+    else:
+        return True, None
+
+
+def is_valid_cvs_url(origin_url: str) -> Tuple[bool, Optional[str]]:
+    """Check if a URL matches one of the formats expected by the CVS loader of
+    Software Heritage.
+
+    Args:
+        origin_url: The URL to check
+
+    Returns:
+        a tuple whose first member indicates if the URL matches one of the formats
+        expected by the CVS loader and second member holds an error message if
+        the check failed.
+    """
+    err_msg = None
+    rsync_url_format = "rsync://<hostname>[.*/]<project_name>/<module_name>"
+    pserver_url_format = (
+        "pserver://<username>@<hostname>[.*/]<project_name>/<module_name>"
+    )
+    err_msg_prefix = (
+        "The origin URL for the CVS repository is malformed, it should match"
+    )
+
+    parsed_url = urlparse(origin_url)
+    ret = (
+        parsed_url.scheme in ("rsync", "pserver")
+        and len(parsed_url.path.strip("/").split("/")) >= 2
+    )
+    if parsed_url.scheme == "rsync":
+        if not ret:
+            err_msg = f"{err_msg_prefix} '{rsync_url_format}'"
+    elif parsed_url.scheme == "pserver":
+        ret = ret and parsed_url.username is not None
+        if not ret:
+            err_msg = f"{err_msg_prefix} '{pserver_url_format}'"
+    else:
+        err_msg = f"{err_msg_prefix} '{rsync_url_format}' or '{pserver_url_format}'"
+
+    if not ret:
+        _log_invalid_origin_type_for_url(origin_url, "CVS", err_msg)
+
+    return ret, err_msg
+
+
+CONNECTION_ERROR = "A connection error occurred when requesting origin URL."
+HTTP_ERROR = "An HTTP error occurred when requesting origin URL"
+HOSTNAME_ERROR = "The hostname could not be resolved."
+
+
+VISIT_TYPE_ERROR: Dict[str, str] = {
+    "tarball-directory": "The origin URL does not target a tarball.",
+    "git": "The origin URL does not target a public git repository.",
+    "svn": "The origin URL does not target a public subversion repository.",
+    "hg": "The origin URL does not target a public mercurial repository.",
+    "bzr": "The origin URL does not target a public bazaar repository.",
+    "cvs": "The origin URL does not target a public CVS repository.",
+}
+
+
+class SubmittedOrigin(TypedDict):
+    origin_url: str
+    visit_type: str
+
+
+@dataclass(frozen=True)
+class RejectedOrigin:
+    origin_url: str
+    visit_type: str
+    reason: str
+    exception: Optional[str]
+
+
+@dataclass
+class SaveBulkListerState:
+    """Stored lister state"""
+
+    rejected_origins: List[RejectedOrigin] = field(default_factory=list)
+    """
+    List of origins rejected by the lister.
+    """
+
+
+SaveBulkListerPage = List[SubmittedOrigin]
+
+
+class SaveBulkLister(Lister[SaveBulkListerState, SaveBulkListerPage]):
+    """The save-bulk lister verifies a list of origins to archive provided by an
+    HTTP endpoint. Its purpose is to avoid polluting the scheduler database with
+    origins that cannot be loaded into the archive.
+
+    Each origin is identified by a URL and a visit type. For a given visit type,
+    the lister checks that the origin URL can be found and that it actually
+    targets an origin of that type.
+
+    The HTTP endpoint must return the list of origins in a paginated way through
+    the use of two integer query parameters: ``page`` indicates the page to fetch
+    and ``per_page`` corresponds to the number of origins in a page.
+    The endpoint must return a JSON list in the following format:
+
+    .. code-block:: JSON
+
+        [
+            {
+                "origin_url": "https://git.example.org/user/project",
+                "visit_type": "git"
+            },
+            {
+                "origin_url": "https://example.org/downloads/project.tar.gz",
+                "visit_type": "tarball-directory"
+            }
+        ]
+
+
+    The supported visit types are those for VCS (``bzr``, ``cvs``, ``hg``, ``git``
+    and ``svn``) plus the one for loading the content of a tarball into the
+    archive (``tarball-directory``).
+
+    Accepted origins are inserted or upserted in the scheduler database.
+
+    Rejected origins are stored in the lister state.
+    """
+
+    LISTER_NAME = "save-bulk"
+
+    def __init__(
+        self,
+        url: str,
+        instance: str,
+        scheduler: SchedulerInterface,
+        credentials: Optional[CredentialsType] = None,
+        max_origins_per_page: Optional[int] = None,
+        max_pages: Optional[int] = None,
+        enable_origins: bool = True,
+        per_page: int = 1000,
+    ):
+        super().__init__(
+            scheduler=scheduler,
+            credentials=credentials,
+            url=url,
+            instance=instance,
+            max_origins_per_page=max_origins_per_page,
+            max_pages=max_pages,
+            enable_origins=enable_origins,
+        )
+        self.rejected_origins: Set[RejectedOrigin] = set()
+        self.per_page = per_page
+
+    def state_from_dict(self, d: Dict[str, Any]) -> SaveBulkListerState:
+        return SaveBulkListerState(
+            rejected_origins=[
+                RejectedOrigin(**rej) for rej in d.get("rejected_origins", [])
+            ]
+        )
+
+    def state_to_dict(self, state: SaveBulkListerState) -> Dict[str, Any]:
+        return {"rejected_origins": [asdict(rej) for rej in state.rejected_origins]}
+
+    def get_pages(self) -> Iterator[SaveBulkListerPage]:
+        current_page = 1
+        origins = self.session.get(
+            self.url, params={"page": current_page, "per_page": self.per_page}
+        ).json()
+        while origins:
+            yield origins
+            current_page += 1
+            origins = self.session.get(
+                self.url, params={"page": current_page, "per_page": self.per_page}
+            ).json()
+
+    def get_origins_from_page(
+        self, origins: SaveBulkListerPage
+    ) -> Iterator[ListedOrigin]:
+        assert self.lister_obj.id is not None
+
+        for origin in origins:
+            origin_url = origin["origin_url"]
+            visit_type = origin["visit_type"]
+
+            logger.info(
+                "Checking origin URL %s for visit type %s.", origin_url, visit_type
+            )
+
+            rejection_details = None
+            rejection_exception = None
+
+            parsed_url = urlparse(origin_url)
+            if rejection_details is None:
+                if parsed_url.scheme in ("http", "https"):
+                    try:
+                        response = self.session.head(origin_url, allow_redirects=True)
+                        response.raise_for_status()
+                    except ConnectionError as e:
+                        logger.info(
+                            "A connection error occurred when requesting %s.",
+                            origin_url,
+                        )
+                        rejection_details = CONNECTION_ERROR
+                        rejection_exception = str(e)
+                    except RequestException as e:
+                        if e.response is not None:
+                            status = e.response.status_code
+                            status_str = f"{status} - {HTTPStatus(status).phrase}"
+                            logger.info(
+                                "An HTTP error occurred when requesting %s: %s",
+                                origin_url,
+                                status_str,
+                            )
+                            rejection_details = f"{HTTP_ERROR}: {status_str}"
+                        else:
+                            logger.info(
+                                "An HTTP error occurred when requesting %s.",
+                                origin_url,
+                            )
+                            rejection_details = f"{HTTP_ERROR}."
+ rejection_exception = str(e) + else: + try: + socket.getaddrinfo(parsed_url.netloc, port=None) + except OSError as e: + logger.info( + "Host name %s could not be resolved.", parsed_url.netloc + ) + rejection_details = HOSTNAME_ERROR + rejection_exception = str(e) + + if rejection_details is None: + visit_type_check_url = globals().get( + f"is_valid_{visit_type.split('-', 1)[0]}_url" + ) + if visit_type_check_url: + url_valid, rejection_exception = visit_type_check_url(origin_url) + if not url_valid: + rejection_details = VISIT_TYPE_ERROR[visit_type] + else: + rejection_details = ( + f"Visit type {visit_type} is not supported " + "for bulk on-demand archival." + ) + logger.info( + "Visit type %s for origin URL %s is not supported", + visit_type, + origin_url, + ) + + if rejection_details is None: + yield ListedOrigin( + lister_id=self.lister_obj.id, + url=origin["origin_url"], + visit_type=origin["visit_type"], + extra_loader_arguments=( + {"checksum_layout": "standard", "checksums": {}} + if origin["visit_type"] == "tarball-directory" + else {} + ), + ) + else: + self.rejected_origins.add( + RejectedOrigin( + origin_url=origin_url, + visit_type=visit_type, + reason=rejection_details, + exception=rejection_exception, + ) + ) + # update scheduler state at each rejected origin to get feedback + # using Web API before end of listing + self.state.rejected_origins = list(self.rejected_origins) + self.set_state_in_scheduler() diff --git a/swh/lister/save_bulk/tasks.py b/swh/lister/save_bulk/tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..17c9d0cf8d4a161f9bd08465efb59d3cf222a99e --- /dev/null +++ b/swh/lister/save_bulk/tasks.py @@ -0,0 +1,19 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from celery import shared_task + +from swh.lister.save_bulk.lister import SaveBulkLister + + +@shared_task(name=__name__ + ".SaveBulkListerTask") +def list_save_bulk(**kwargs): + """Task for save-bulk lister""" + return SaveBulkLister.from_configfile(**kwargs).run().dict() + + +@shared_task(name=__name__ + ".ping") +def _ping(): + return "OK" diff --git a/swh/lister/save_bulk/tests/__init__.py b/swh/lister/save_bulk/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/swh/lister/save_bulk/tests/test_lister.py b/swh/lister/save_bulk/tests/test_lister.py new file mode 100644 index 0000000000000000000000000000000000000000..30594d71510cd7094a191b017804ce8d97add576 --- /dev/null +++ b/swh/lister/save_bulk/tests/test_lister.py @@ -0,0 +1,263 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from operator import attrgetter, itemgetter +import re +import string + +import pytest +import requests + +from swh.lister.pattern import ListerStats +from swh.lister.save_bulk.lister import ( + CONNECTION_ERROR, + HOSTNAME_ERROR, + HTTP_ERROR, + VISIT_TYPE_ERROR, + RejectedOrigin, + SaveBulkLister, + SubmittedOrigin, + is_valid_cvs_url, +) + +URL = "https://example.org/origins/list/" +INSTANCE = "some-instance" + +PER_PAGE = 2 + +SUBMITTED_ORIGINS = [ + SubmittedOrigin(origin_url=origin_url, 
visit_type=visit_type) + for origin_url, visit_type in [ + ("https://example.org/download/tarball.tar.gz", "tarball-directory"), + ("https://git.example.org/user/project.git", "git"), + ("https://svn.example.org/project/trunk", "svn"), + ("https://hg.example.org/projects/test", "hg"), + ("https://bzr.example.org/projects/test", "bzr"), + ("rsync://cvs.example.org/cvsroot/project/module", "cvs"), + ] +] + + +@pytest.fixture(autouse=True) +def origins_list_requests_mock(requests_mock): + nb_pages = len(SUBMITTED_ORIGINS) // PER_PAGE + for i in range(nb_pages): + requests_mock.get( + f"{URL}?page={i+1}&per_page={PER_PAGE}", + json=SUBMITTED_ORIGINS[i * PER_PAGE : (i + 1) * PER_PAGE], + ) + requests_mock.get( + f"{URL}?page={nb_pages+1}&per_page={PER_PAGE}", + json=[], + ) + + +@pytest.mark.parametrize( + "valid_cvs_url", + [ + "rsync://cvs.example.org/project/module", + "pserver://anonymous@cvs.example.org/project/module", + ], +) +def test_is_valid_cvs_url_success(valid_cvs_url): + assert is_valid_cvs_url(valid_cvs_url) == (True, None) + + +@pytest.mark.parametrize( + "invalid_cvs_url", + [ + "rsync://cvs.example.org/project", + "pserver://anonymous@cvs.example.org/project", + "pserver://cvs.example.org/project/module", + "http://cvs.example.org/project/module", + ], +) +def test_is_valid_cvs_url_failure(invalid_cvs_url): + err_msg_prefix = "The origin URL for the CVS repository is malformed" + ret, err_msg = is_valid_cvs_url(invalid_cvs_url) + assert not ret and err_msg.startswith(err_msg_prefix) + + +def test_bulk_lister_valid_origins(swh_scheduler, requests_mock, mocker): + requests_mock.head(re.compile(".*"), status_code=200) + mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").return_value = [ + ("125.25.14.15", 0) + ] + for origin in SUBMITTED_ORIGINS: + visit_type = origin["visit_type"].split("-", 1)[0] + mocker.patch( + f"swh.lister.save_bulk.lister.is_valid_{visit_type}_url" + ).return_value = (True, None) + + lister_bulk = SaveBulkLister( + url=URL, + instance=INSTANCE, + scheduler=swh_scheduler, + per_page=PER_PAGE, + ) + stats = lister_bulk.run() + + expected_nb_origins = len(SUBMITTED_ORIGINS) + assert stats == ListerStats( + pages=expected_nb_origins // PER_PAGE, origins=expected_nb_origins + ) + + state = lister_bulk.get_state_from_scheduler() + + assert sorted( + [ + SubmittedOrigin(origin_url=origin.url, visit_type=origin.visit_type) + for origin in swh_scheduler.get_listed_origins( + lister_bulk.lister_obj.id + ).results + ], + key=itemgetter("visit_type"), + ) == sorted(SUBMITTED_ORIGINS, key=itemgetter("visit_type")) + assert state.rejected_origins == [] + + +def test_bulk_lister_not_found_origins(swh_scheduler, requests_mock, mocker): + requests_mock.head(re.compile(".*"), status_code=404) + mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").side_effect = ( + OSError("Hostname not found") + ) + + lister_bulk = SaveBulkLister( + url=URL, + instance=INSTANCE, + scheduler=swh_scheduler, + per_page=PER_PAGE, + ) + stats = lister_bulk.run() + + assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=0) + + state = lister_bulk.get_state_from_scheduler() + + assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list( + sorted( + [ + RejectedOrigin( + origin_url=o["origin_url"], + visit_type=o["visit_type"], + reason=( + HTTP_ERROR + ": 404 - Not Found" + if o["origin_url"].startswith("http") + else HOSTNAME_ERROR + ), + exception=( + f"404 Client Error: None for url: {o['origin_url']}" + if 
o["origin_url"].startswith("http") + else "Hostname not found" + ), + ) + for o in SUBMITTED_ORIGINS + ], + key=attrgetter("origin_url"), + ) + ) + + +def test_bulk_lister_connection_errors(swh_scheduler, requests_mock, mocker): + requests_mock.head( + re.compile(".*"), + exc=requests.exceptions.ConnectionError("connection error"), + ) + mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").side_effect = ( + OSError("Hostname not found") + ) + + lister_bulk = SaveBulkLister( + url=URL, + instance=INSTANCE, + scheduler=swh_scheduler, + per_page=PER_PAGE, + ) + stats = lister_bulk.run() + + assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=0) + + state = lister_bulk.get_state_from_scheduler() + + assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list( + sorted( + [ + RejectedOrigin( + origin_url=o["origin_url"], + visit_type=o["visit_type"], + reason=( + CONNECTION_ERROR + if o["origin_url"].startswith("http") + else HOSTNAME_ERROR + ), + exception=( + "connection error" + if o["origin_url"].startswith("http") + else "Hostname not found" + ), + ) + for o in SUBMITTED_ORIGINS + ], + key=attrgetter("origin_url"), + ) + ) + + +def test_bulk_lister_invalid_origins(swh_scheduler, requests_mock, mocker): + requests_mock.head(re.compile(".*"), status_code=200) + mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").return_value = [ + ("125.25.14.15", 0) + ] + + exc_msg_template = string.Template( + "error: the origin url does not target a public $visit_type repository." + ) + for origin in SUBMITTED_ORIGINS: + visit_type = origin["visit_type"].split("-", 1)[0] + visit_type_check = mocker.patch( + f"swh.lister.save_bulk.lister.is_valid_{visit_type}_url" + ) + if visit_type == "tarball": + visit_type_check.return_value = (True, None) + else: + visit_type_check.return_value = ( + False, + exc_msg_template.substitute(visit_type=visit_type), + ) + + lister_bulk = SaveBulkLister( + url=URL, + instance=INSTANCE, + scheduler=swh_scheduler, + per_page=PER_PAGE, + ) + stats = lister_bulk.run() + + assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=1) + + assert [ + SubmittedOrigin(origin_url=origin.url, visit_type=origin.visit_type) + for origin in swh_scheduler.get_listed_origins( + lister_bulk.lister_obj.id + ).results + ] == [SUBMITTED_ORIGINS[0]] + + state = lister_bulk.get_state_from_scheduler() + + assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list( + sorted( + [ + RejectedOrigin( + origin_url=o["origin_url"], + visit_type=o["visit_type"], + reason=VISIT_TYPE_ERROR[o["visit_type"]], + exception=exc_msg_template.substitute(visit_type=o["visit_type"]), + ) + for o in SUBMITTED_ORIGINS + if o["visit_type"] != "tarball-directory" + ], + key=attrgetter("origin_url"), + ) + ) diff --git a/swh/lister/save_bulk/tests/test_tasks.py b/swh/lister/save_bulk/tests/test_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..d8564e665a53dcd758c8a7414e73aafcdd26731e --- /dev/null +++ b/swh/lister/save_bulk/tests/test_tasks.py @@ -0,0 +1,38 @@ +# Copyright (C) 2024 The Software Heritage developers +# See the AUTHORS file at the top-level directory of this distribution +# License: GNU General Public License version 3, or any later version +# See top-level LICENSE file for more information + +from swh.lister.pattern import ListerStats + + +def test_save_bulk_ping(swh_scheduler_celery_app, swh_scheduler_celery_worker): + res = 
swh_scheduler_celery_app.send_task("swh.lister.save_bulk.tasks.ping") + assert res + res.wait() + assert res.successful() + assert res.result == "OK" + + +def test_save_bulk_lister_task( + swh_scheduler_celery_app, swh_scheduler_celery_worker, mocker +): + lister = mocker.patch("swh.lister.save_bulk.tasks.SaveBulkLister") + lister.from_configfile.return_value = lister + lister.run.return_value = ListerStats(pages=1, origins=2) + + kwargs = dict( + url="https://example.org/origins/list/", + instance="some-instance", + ) + + res = swh_scheduler_celery_app.send_task( + "swh.lister.save_bulk.tasks.SaveBulkListerTask", + kwargs=kwargs, + ) + assert res + res.wait() + assert res.successful() + + lister.from_configfile.assert_called_once_with(**kwargs) + lister.run.assert_called_once() diff --git a/swh/lister/tests/test_cli.py b/swh/lister/tests/test_cli.py index a5645a610511371984caa2e892e41235da024790..7329424ad0997ef6ad5b45a3a2717c7281fd4368 100644 --- a/swh/lister/tests/test_cli.py +++ b/swh/lister/tests/test_cli.py @@ -49,6 +49,10 @@ lister_args = { "stagit": { "url": "https://git.codemadness.org", }, + "save-bulk": { + "url": "https://example.org/origins/list/", + "instance": "example.org", + }, }