Commit af24960b authored by Antoine Lambert

Add save-bulk lister to check origins prior to their insertion in database

This new, special-purpose lister verifies a list of origins to archive
provided by users (for instance through the Web API).

Its purpose is to avoid polluting the scheduler database with origins that
cannot be loaded into the archive.

Each origin is identified by a URL and a visit type. For a given visit type,
the lister checks whether the origin URL can be reached and whether the visit
type is valid.

The supported visit types are those for version control systems (bzr, cvs, hg,
git and svn) plus the one for loading tarball content into the archive.

Accepted origins are inserted or upserted into the scheduler database.

Rejected origins are stored in the lister state.

Related to #4709
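
For illustration, here is a minimal sketch of how the new lister could be driven programmatically, mirroring the constructor arguments exercised by the tests added in this commit. The endpoint URL, instance name and scheduler settings are placeholders, and the ``get_scheduler`` call is only assumed here as one way to obtain a scheduler backend.

.. code-block:: python

    from swh.lister.save_bulk.lister import SaveBulkLister
    from swh.scheduler import get_scheduler  # assumed helper to build a scheduler backend

    # Placeholder scheduler backend and endpoint URL, for illustration only.
    scheduler = get_scheduler(cls="remote", url="http://localhost:5008/")

    lister = SaveBulkLister(
        url="https://example.org/origins/list/",  # paginated endpoint serving submitted origins
        instance="example.org",
        scheduler=scheduler,
        per_page=1000,
    )

    # run() pages through the endpoint, records accepted origins in the
    # scheduler database and keeps rejected ones in the lister state.
    stats = lister.run()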
parent 6618cf34
1 merge request: !528 Add save-bulk lister to check origins prior to their insertion in database
Pipeline #10682 passed
@@ -28,6 +28,9 @@ ignore_missing_imports = True
[mypy-lxml.*]
ignore_missing_imports = True
[mypy-mercurial.*]
ignore_missing_imports = True
[mypy-pandas.*]
ignore_missing_imports = True
@@ -61,6 +64,9 @@ ignore_missing_imports = True
[mypy-repomd.*]
ignore_missing_imports = True
[mypy-subvertpy.*]
ignore_missing_imports = True
[mypy-defusedxml.*]
ignore_missing_imports = True
# [mypy-add_your_lib_here.*]
@@ -34,6 +34,7 @@ testing = {file = ["requirements-test.txt"]}
"lister.bioconductor" = "swh.lister.bioconductor:register"
"lister.bitbucket" = "swh.lister.bitbucket:register"
"lister.bower" = "swh.lister.bower:register"
"lister.save-bulk" = "swh.lister.save_bulk:register"
"lister.cgit" = "swh.lister.cgit:register"
"lister.conda" = "swh.lister.conda:register"
"lister.cpan" = "swh.lister.cpan:register"
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
def register():
from .lister import SaveBulkLister
return {
"lister": SaveBulkLister,
"task_modules": [f"{__name__}.tasks"],
}
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from dataclasses import asdict, dataclass, field
from http import HTTPStatus
import logging
import socket
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, TypedDict
from urllib.parse import quote, urlparse
from breezy.builtins import cmd_info
from dulwich.porcelain import ls_remote
from mercurial import hg, ui
from requests import ConnectionError, RequestException
from subvertpy import SubversionException, client
from subvertpy.ra import Auth, get_username_provider
from swh.lister.utils import is_tarball
from swh.scheduler.interface import SchedulerInterface
from swh.scheduler.model import ListedOrigin
from ..pattern import CredentialsType, Lister
logger = logging.getLogger(__name__)
def _log_invalid_origin_type_for_url(
origin_url: str, origin_type: str, err_msg: Optional[str] = None
):
msg = f"Origin URL {origin_url} does not target a {origin_type}."
if err_msg:
msg += f"\nError details: {err_msg}"
logger.info(msg)
def is_valid_tarball_url(origin_url: str) -> Tuple[bool, Optional[str]]:
"""Checks if an URL targets a tarball using a set of heuritiscs.
Args:
origin_url: The URL to check
Returns:
a tuple whose first member indicates whether the URL targets a tarball and
whose second member holds an optional error message if the check failed
"""
exc_str = None
try:
ret, _ = is_tarball([origin_url])
except Exception as e:
ret = False
exc_str = str(e)
if not ret:
_log_invalid_origin_type_for_url(origin_url, "tarball", exc_str)
return ret, exc_str
def is_valid_git_url(origin_url: str) -> Tuple[bool, Optional[str]]:
"""Check if an URL targets a public git repository by attempting to list
its remote refs.
Args:
origin_url: The URL to check
Returns:
a tuple whose first member indicates whether the URL targets a public git
repository and whose second member holds an error message if the check failed
"""
try:
ls_remote(origin_url)
except Exception as e:
exc_str = str(e)
_log_invalid_origin_type_for_url(origin_url, "public git repository", exc_str)
return False, exc_str
else:
return True, None
def is_valid_svn_url(origin_url: str) -> Tuple[bool, Optional[str]]:
"""Check if an URL targets a public subversion repository by attempting to get
repository information.
Args:
origin_url: The URL to check
Returns:
a tuple whose first member indicates whether the URL targets a public subversion
repository and whose second member holds an error message if the check failed
"""
svn_client = client.Client(auth=Auth([get_username_provider()]))
try:
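# Percent-encode characters subversion does not accept in URLs and drop any
# trailing slash before querying the repository information.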
svn_client.info(quote(origin_url, safe="/:!$&'()*+,=@").rstrip("/"))
except SubversionException as e:
exc_str = str(e)
_log_invalid_origin_type_for_url(
origin_url, "public subversion repository", exc_str
)
return False, exc_str
else:
return True, None
def is_valid_hg_url(origin_url: str) -> Tuple[bool, Optional[str]]:
"""Check if an URL targets a public mercurial repository by attempting to connect
to the remote repository.
Args:
origin_url: The URL to check
Returns:
a tuple whose first member indicates whether the URL targets a public mercurial
repository and whose second member holds an error message if the check failed
"""
hgui = ui.ui()
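# Disable interactive prompting so the connection attempt never blocks asking for credentials.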
hgui.setconfig(b"ui", b"interactive", False)
try:
hg.peer(hgui, {}, origin_url.encode())
except Exception as e:
exc_str = str(e)
_log_invalid_origin_type_for_url(
origin_url, "public mercurial repository", exc_str
)
return False, exc_str
else:
return True, None
def is_valid_bzr_url(origin_url: str) -> Tuple[bool, Optional[str]]:
"""Check if an URL targets a public bazaar repository by attempting to get
repository information.
Args:
origin_url: The URL to check
Returns:
a tuple whose first member indicates whether the URL targets a public bazaar
repository and whose second member holds an error message if the check failed
"""
try:
cmd_info().run_argv_aliases([origin_url])
except Exception as e:
exc_str = str(e)
_log_invalid_origin_type_for_url(
origin_url, "public bazaar repository", exc_str
)
return False, exc_str
else:
return True, None
def is_valid_cvs_url(origin_url: str) -> Tuple[bool, Optional[str]]:
"""Check if an URL matches one of the formats expected by the CVS loader of
Software Heritage.
Args:
origin_url: The URL to check
Returns:
a tuple whose first member indicates whether the URL matches one of the formats
expected by the CVS loader and whose second member holds an error message if
the check failed.
"""
err_msg = None
rsync_url_format = "rsync://<hostname>[.*/]<project_name>/<module_name>"
pserver_url_format = (
"pserver://<usernmame>@<hostname>[.*/]<project_name>/<module_name>"
)
err_msg_prefix = (
"The origin URL for the CVS repository is malformed, it should match"
)
parsed_url = urlparse(origin_url)
ret = (
parsed_url.scheme in ("rsync", "pserver")
and len(parsed_url.path.strip("/").split("/")) >= 2
)
if parsed_url.scheme == "rsync":
if not ret:
err_msg = f"{err_msg_prefix} '{rsync_url_format}'"
elif parsed_url.scheme == "pserver":
ret = ret and parsed_url.username is not None
if not ret:
err_msg = f"{err_msg_prefix} '{pserver_url_format}'"
else:
err_msg = f"{err_msg_prefix} '{rsync_url_format}' or '{pserver_url_format}'"
if not ret:
_log_invalid_origin_type_for_url(origin_url, "CVS", err_msg)
return ret, err_msg
CONNECTION_ERROR = "A connection error occurred when requesting origin URL."
HTTP_ERROR = "An HTTP error occurred when requesting origin URL"
HOSTNAME_ERROR = "The hostname could not be resolved."
VISIT_TYPE_ERROR: Dict[str, str] = {
"tarball-directory": "The origin URL does not target a tarball.",
"git": "The origin URL does not target a public git repository.",
"svn": "The origin URL does not target a public subversion repository.",
"hg": "The origin URL does not target a public mercurial repository.",
"bzr": "The origin URL does not target a public bazaar repository.",
"cvs": "The origin URL does not target a public CVS repository.",
}
class SubmittedOrigin(TypedDict):
origin_url: str
visit_type: str
@dataclass(frozen=True)
class RejectedOrigin:
origin_url: str
visit_type: str
reason: str
exception: Optional[str]
@dataclass
class SaveBulkListerState:
"""Stored lister state"""
rejected_origins: List[RejectedOrigin] = field(default_factory=list)
"""
List of origins rejected by the lister.
"""
SaveBulkListerPage = List[SubmittedOrigin]
class SaveBulkLister(Lister[SaveBulkListerState, SaveBulkListerPage]):
"""The save-bulk lister enables to verify a list of origins to archive provided
by an HTTP endpoint. Its purpose is to avoid polluting the scheduler database with
origins that cannot be loaded into the archive.
Each origin is identified by a URL and a visit type. For a given visit type, the
lister checks whether the origin URL can be reached and whether the visit type is valid.
The HTTP endpoint must return an origins list in a paginated way through the use
of two integer query parameters: ``page`` indicates the page to fetch and ``per_page``
corresponds to the number of origins per page.
The endpoint must return a JSON list in the following format:
.. code-block:: JSON
[
{
"origin_url": "https://git.example.org/user/project",
"visit_type": "git"
},
{
"origin_url": "https://example.org/downloads/project.tar.gz",
"visit_type": "tarball-directory"
}
]
The supported visit types are those for VCS (``bzr``, ``cvs``, ``hg``, ``git``
and ``svn``) plus the one for loading tarball content into the archive
(``tarball-directory``).
Accepted origins are inserted or upserted into the scheduler database.
Rejected origins are stored in the lister state.
"""
LISTER_NAME = "save-bulk"
def __init__(
self,
url: str,
instance: str,
scheduler: SchedulerInterface,
credentials: Optional[CredentialsType] = None,
max_origins_per_page: Optional[int] = None,
max_pages: Optional[int] = None,
enable_origins: bool = True,
per_page: int = 1000,
):
super().__init__(
scheduler=scheduler,
credentials=credentials,
url=url,
instance=instance,
max_origins_per_page=max_origins_per_page,
max_pages=max_pages,
enable_origins=enable_origins,
)
self.rejected_origins: Set[RejectedOrigin] = set()
self.per_page = per_page
def state_from_dict(self, d: Dict[str, Any]) -> SaveBulkListerState:
return SaveBulkListerState(
rejected_origins=[
RejectedOrigin(**rej) for rej in d.get("rejected_origins", [])
]
)
def state_to_dict(self, state: SaveBulkListerState) -> Dict[str, Any]:
return {"rejected_origins": [asdict(rej) for rej in state.rejected_origins]}
def get_pages(self) -> Iterator[SaveBulkListerPage]:
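# Fetch submitted origins page by page until the endpoint returns an empty list.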
current_page = 1
origins = self.session.get(
self.url, params={"page": current_page, "per_page": self.per_page}
).json()
while origins:
yield origins
current_page += 1
origins = self.session.get(
self.url, params={"page": current_page, "per_page": self.per_page}
).json()
def get_origins_from_page(
self, origins: SaveBulkListerPage
) -> Iterator[ListedOrigin]:
assert self.lister_obj.id is not None
for origin in origins:
origin_url = origin["origin_url"]
visit_type = origin["visit_type"]
logger.info(
"Checking origin URL %s for visit type %s.", origin_url, visit_type
)
rejection_details = None
rejection_exception = None
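# First check that the origin URL is reachable: a HEAD request is sent for
# http/https URLs, otherwise only the host name is resolved.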
parsed_url = urlparse(origin_url)
if rejection_details is None:
if parsed_url.scheme in ("http", "https"):
try:
response = self.session.head(origin_url, allow_redirects=True)
response.raise_for_status()
except ConnectionError as e:
logger.info(
"A connection error occurred when requesting %s.",
origin_url,
)
rejection_details = CONNECTION_ERROR
rejection_exception = str(e)
except RequestException as e:
if e.response is not None:
status = e.response.status_code
status_str = f"{status} - {HTTPStatus(status).phrase}"
logger.info(
"An HTTP error occurred when requesting %s: %s",
origin_url,
status_str,
)
rejection_details = f"{HTTP_ERROR}: {status_str}"
else:
logger.info(
"An HTTP error occurred when requesting %s.",
origin_url,
)
rejection_details = f"{HTTP_ERROR}."
rejection_exception = str(e)
else:
try:
socket.getaddrinfo(parsed_url.netloc, port=None)
except OSError as e:
logger.info(
"Host name %s could not be resolved.", parsed_url.netloc
)
rejection_details = HOSTNAME_ERROR
rejection_exception = str(e)
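# Then check that the URL actually targets what the submitted visit type
# expects, using the is_valid_<visit_type>_url helpers defined above.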
if rejection_details is None:
visit_type_check_url = globals().get(
f"is_valid_{visit_type.split('-', 1)[0]}_url"
)
if visit_type_check_url:
url_valid, rejection_exception = visit_type_check_url(origin_url)
if not url_valid:
rejection_details = VISIT_TYPE_ERROR[visit_type]
else:
rejection_details = (
f"Visit type {visit_type} is not supported "
"for bulk on-demand archival."
)
logger.info(
"Visit type %s for origin URL %s is not supported",
visit_type,
origin_url,
)
if rejection_details is None:
yield ListedOrigin(
lister_id=self.lister_obj.id,
url=origin["origin_url"],
visit_type=origin["visit_type"],
extra_loader_arguments=(
{"checksum_layout": "standard", "checksums": {}}
if origin["visit_type"] == "tarball-directory"
else {}
),
)
else:
self.rejected_origins.add(
RejectedOrigin(
origin_url=origin_url,
visit_type=visit_type,
reason=rejection_details,
exception=rejection_exception,
)
)
# Update the scheduler state after each rejected origin so that feedback is
# available through the Web API before the listing ends.
self.state.rejected_origins = list(self.rejected_origins)
self.set_state_in_scheduler()
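
Because rejected origins are persisted in the lister state after every rejection, they can be inspected even before the listing completes. A minimal sketch, reusing the lister instance from the earlier example and the ``get_state_from_scheduler`` helper exercised by the tests below:

.. code-block:: python

    # `lister` is the SaveBulkLister instance from the earlier sketch.
    state = lister.get_state_from_scheduler()
    for rejected in state.rejected_origins:
        # Each entry is a RejectedOrigin holding the submitted URL and visit
        # type, a human-readable rejection reason and the optional message of
        # the underlying exception.
        print(rejected.origin_url, rejected.visit_type, rejected.reason)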
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from celery import shared_task
from swh.lister.save_bulk.lister import SaveBulkLister
@shared_task(name=__name__ + ".SaveBulkListerTask")
def list_save_bulk(**kwargs):
"""Task for save-bulk lister"""
return SaveBulkLister.from_configfile(**kwargs).run().dict()
@shared_task(name=__name__ + ".ping")
def _ping():
return "OK"
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from operator import attrgetter, itemgetter
import re
import string
import pytest
import requests
from swh.lister.pattern import ListerStats
from swh.lister.save_bulk.lister import (
CONNECTION_ERROR,
HOSTNAME_ERROR,
HTTP_ERROR,
VISIT_TYPE_ERROR,
RejectedOrigin,
SaveBulkLister,
SubmittedOrigin,
is_valid_cvs_url,
)
URL = "https://example.org/origins/list/"
INSTANCE = "some-instance"
PER_PAGE = 2
SUBMITTED_ORIGINS = [
SubmittedOrigin(origin_url=origin_url, visit_type=visit_type)
for origin_url, visit_type in [
("https://example.org/download/tarball.tar.gz", "tarball-directory"),
("https://git.example.org/user/project.git", "git"),
("https://svn.example.org/project/trunk", "svn"),
("https://hg.example.org/projects/test", "hg"),
("https://bzr.example.org/projects/test", "bzr"),
("rsync://cvs.example.org/cvsroot/project/module", "cvs"),
]
]
@pytest.fixture(autouse=True)
def origins_list_requests_mock(requests_mock):
nb_pages = len(SUBMITTED_ORIGINS) // PER_PAGE
for i in range(nb_pages):
requests_mock.get(
f"{URL}?page={i+1}&per_page={PER_PAGE}",
json=SUBMITTED_ORIGINS[i * PER_PAGE : (i + 1) * PER_PAGE],
)
requests_mock.get(
f"{URL}?page={nb_pages+1}&per_page={PER_PAGE}",
json=[],
)
@pytest.mark.parametrize(
"valid_cvs_url",
[
"rsync://cvs.example.org/project/module",
"pserver://anonymous@cvs.example.org/project/module",
],
)
def test_is_valid_cvs_url_success(valid_cvs_url):
assert is_valid_cvs_url(valid_cvs_url) == (True, None)
@pytest.mark.parametrize(
"invalid_cvs_url",
[
"rsync://cvs.example.org/project",
"pserver://anonymous@cvs.example.org/project",
"pserver://cvs.example.org/project/module",
"http://cvs.example.org/project/module",
],
)
def test_is_valid_cvs_url_failure(invalid_cvs_url):
err_msg_prefix = "The origin URL for the CVS repository is malformed"
ret, err_msg = is_valid_cvs_url(invalid_cvs_url)
assert not ret and err_msg.startswith(err_msg_prefix)
def test_bulk_lister_valid_origins(swh_scheduler, requests_mock, mocker):
requests_mock.head(re.compile(".*"), status_code=200)
mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").return_value = [
("125.25.14.15", 0)
]
for origin in SUBMITTED_ORIGINS:
visit_type = origin["visit_type"].split("-", 1)[0]
mocker.patch(
f"swh.lister.save_bulk.lister.is_valid_{visit_type}_url"
).return_value = (True, None)
lister_bulk = SaveBulkLister(
url=URL,
instance=INSTANCE,
scheduler=swh_scheduler,
per_page=PER_PAGE,
)
stats = lister_bulk.run()
expected_nb_origins = len(SUBMITTED_ORIGINS)
assert stats == ListerStats(
pages=expected_nb_origins // PER_PAGE, origins=expected_nb_origins
)
state = lister_bulk.get_state_from_scheduler()
assert sorted(
[
SubmittedOrigin(origin_url=origin.url, visit_type=origin.visit_type)
for origin in swh_scheduler.get_listed_origins(
lister_bulk.lister_obj.id
).results
],
key=itemgetter("visit_type"),
) == sorted(SUBMITTED_ORIGINS, key=itemgetter("visit_type"))
assert state.rejected_origins == []
def test_bulk_lister_not_found_origins(swh_scheduler, requests_mock, mocker):
requests_mock.head(re.compile(".*"), status_code=404)
mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").side_effect = (
OSError("Hostname not found")
)
lister_bulk = SaveBulkLister(
url=URL,
instance=INSTANCE,
scheduler=swh_scheduler,
per_page=PER_PAGE,
)
stats = lister_bulk.run()
assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=0)
state = lister_bulk.get_state_from_scheduler()
assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list(
sorted(
[
RejectedOrigin(
origin_url=o["origin_url"],
visit_type=o["visit_type"],
reason=(
HTTP_ERROR + ": 404 - Not Found"
if o["origin_url"].startswith("http")
else HOSTNAME_ERROR
),
exception=(
f"404 Client Error: None for url: {o['origin_url']}"
if o["origin_url"].startswith("http")
else "Hostname not found"
),
)
for o in SUBMITTED_ORIGINS
],
key=attrgetter("origin_url"),
)
)
def test_bulk_lister_connection_errors(swh_scheduler, requests_mock, mocker):
requests_mock.head(
re.compile(".*"),
exc=requests.exceptions.ConnectionError("connection error"),
)
mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").side_effect = (
OSError("Hostname not found")
)
lister_bulk = SaveBulkLister(
url=URL,
instance=INSTANCE,
scheduler=swh_scheduler,
per_page=PER_PAGE,
)
stats = lister_bulk.run()
assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=0)
state = lister_bulk.get_state_from_scheduler()
assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list(
sorted(
[
RejectedOrigin(
origin_url=o["origin_url"],
visit_type=o["visit_type"],
reason=(
CONNECTION_ERROR
if o["origin_url"].startswith("http")
else HOSTNAME_ERROR
),
exception=(
"connection error"
if o["origin_url"].startswith("http")
else "Hostname not found"
),
)
for o in SUBMITTED_ORIGINS
],
key=attrgetter("origin_url"),
)
)
def test_bulk_lister_invalid_origins(swh_scheduler, requests_mock, mocker):
requests_mock.head(re.compile(".*"), status_code=200)
mocker.patch("swh.lister.save_bulk.lister.socket.getaddrinfo").return_value = [
("125.25.14.15", 0)
]
exc_msg_template = string.Template(
"error: the origin url does not target a public $visit_type repository."
)
for origin in SUBMITTED_ORIGINS:
visit_type = origin["visit_type"].split("-", 1)[0]
visit_type_check = mocker.patch(
f"swh.lister.save_bulk.lister.is_valid_{visit_type}_url"
)
if visit_type == "tarball":
visit_type_check.return_value = (True, None)
else:
visit_type_check.return_value = (
False,
exc_msg_template.substitute(visit_type=visit_type),
)
lister_bulk = SaveBulkLister(
url=URL,
instance=INSTANCE,
scheduler=swh_scheduler,
per_page=PER_PAGE,
)
stats = lister_bulk.run()
assert stats == ListerStats(pages=len(SUBMITTED_ORIGINS) // PER_PAGE, origins=1)
assert [
SubmittedOrigin(origin_url=origin.url, visit_type=origin.visit_type)
for origin in swh_scheduler.get_listed_origins(
lister_bulk.lister_obj.id
).results
] == [SUBMITTED_ORIGINS[0]]
state = lister_bulk.get_state_from_scheduler()
assert list(sorted(state.rejected_origins, key=attrgetter("origin_url"))) == list(
sorted(
[
RejectedOrigin(
origin_url=o["origin_url"],
visit_type=o["visit_type"],
reason=VISIT_TYPE_ERROR[o["visit_type"]],
exception=exc_msg_template.substitute(visit_type=o["visit_type"]),
)
for o in SUBMITTED_ORIGINS
if o["visit_type"] != "tarball-directory"
],
key=attrgetter("origin_url"),
)
)
# Copyright (C) 2024 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from swh.lister.pattern import ListerStats
def test_save_bulk_ping(swh_scheduler_celery_app, swh_scheduler_celery_worker):
res = swh_scheduler_celery_app.send_task("swh.lister.save_bulk.tasks.ping")
assert res
res.wait()
assert res.successful()
assert res.result == "OK"
def test_save_bulk_lister_task(
swh_scheduler_celery_app, swh_scheduler_celery_worker, mocker
):
lister = mocker.patch("swh.lister.save_bulk.tasks.SaveBulkLister")
lister.from_configfile.return_value = lister
lister.run.return_value = ListerStats(pages=1, origins=2)
kwargs = dict(
url="https://example.org/origins/list/",
instance="some-instance",
)
res = swh_scheduler_celery_app.send_task(
"swh.lister.save_bulk.tasks.SaveBulkListerTask",
kwargs=kwargs,
)
assert res
res.wait()
assert res.successful()
lister.from_configfile.assert_called_once_with(**kwargs)
lister.run.assert_called_once()
@@ -49,6 +49,10 @@ lister_args = {
"stagit": {
"url": "https://git.codemadness.org",
},
"save-bulk": {
"url": "https://example.org/origins/list/",
"instance": "example.org",
},
}