diff --git a/swh/lister/pattern.py b/swh/lister/pattern.py index 852ab174e43069eb1a4450ffd6221fddff3f7265..ca1333582531842f697db7b3d1822accc61ea845 100644 --- a/swh/lister/pattern.py +++ b/swh/lister/pattern.py @@ -89,6 +89,7 @@ class Lister(Generic[StateType, PageType]): max_pages: the maximum number of pages listed in a full listing operation max_origins_per_page: the maximum number of origins processed per page enable_origins: whether the created origins should be enabled or not + record_batch_size: maximum number of records to flush to the scheduler at once. Generic types: - *StateType*: concrete lister type; should usually be a :class:`dataclass` for @@ -111,6 +112,7 @@ class Lister(Generic[StateType, PageType]): max_pages: Optional[int] = None, enable_origins: bool = True, with_github_session: bool = False, + record_batch_size: int = 1000, ): if not self.LISTER_NAME: raise ValueError("Must set the LISTER_NAME attribute on Lister classes") @@ -165,6 +167,7 @@ class Lister(Generic[StateType, PageType]): self.max_pages = max_pages self.max_origins_per_page = max_origins_per_page self.enable_origins = enable_origins + self.record_batch_size = record_batch_size def build_url(self, instance: str) -> str: """Optionally build the forge url to list. When the url is not provided in the @@ -344,8 +347,8 @@ class Lister(Generic[StateType, PageType]): Returns: the list of origin URLs recorded in scheduler database """ - recorded_origins = [] - for origins in grouper(origins, n=1000): + recorded_origins: List[str] = [] + for origins in grouper(origins, n=self.record_batch_size): valid_origins = [] for origin in origins: if is_valid_origin_url(origin.url): @@ -354,7 +357,7 @@ class Lister(Generic[StateType, PageType]): logger.warning("Skipping invalid origin: %s", origin.url) ret = self.scheduler.record_listed_origins(valid_origins) - recorded_origins += [origin.url for origin in ret] + recorded_origins.extend(origin.url for origin in ret) return recorded_origins diff --git a/swh/lister/tests/test_pattern.py b/swh/lister/tests/test_pattern.py index 33dc3e7a3e6a2623b283c5acad091ce20f31b16e..88fd2b3e3d8a8476e360b1f3b8e52388941220cf 100644 --- a/swh/lister/tests/test_pattern.py +++ b/swh/lister/tests/test_pattern.py @@ -3,6 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information +from itertools import tee from typing import TYPE_CHECKING, Any, Dict, Iterator, List import pytest @@ -334,3 +335,31 @@ def test_lister_enable_origins(swh_scheduler, enable_origins, expected): assert origins assert all(origin.enabled == expected for origin in origins) + + +@pytest.mark.parametrize("batch_size", [5, 10, 20]) +def test_lister_send_origins_with_stream_is_flushed_regularly( + swh_scheduler, mocker, batch_size +): + """Ensure the send_origins method is flushing regularly records to the scheduler""" + lister = RunnableLister( + scheduler=swh_scheduler, + url="https://example.com", + instance="example.com", + record_batch_size=batch_size, + ) + + def iterate_origins(lister: pattern.Lister) -> Iterator[ListedOrigin]: + """Basic origin iteration to ease testing.""" + for page in lister.get_pages(): + for origin in lister.get_origins_from_page(page): + yield origin + + all_origins, iterator_origins = tee(iterate_origins(lister)) + + spy = mocker.spy(lister, "scheduler") + lister.send_origins(iterator_origins) + + expected_nb_origins = len(list(all_origins)) + + assert len(spy.method_calls) == expected_nb_origins / batch_size