Skip to content
Snippets Groups Projects
Verified Commit 1f272506 authored by Antoine R. Dumont's avatar Antoine R. Dumont
Browse files

lister.pattern: Make batch record parametric and test it

This adds a test around the batch recording behavior to ensure it's not dropped by
mistake.
parent 920ed0d5
No related branches found
Tags v5.9.2
1 merge request!493lister.pattern: Make batch record parametric and test it
Pipeline #3696 passed
......@@ -89,6 +89,7 @@ class Lister(Generic[StateType, PageType]):
max_pages: the maximum number of pages listed in a full listing operation
max_origins_per_page: the maximum number of origins processed per page
enable_origins: whether the created origins should be enabled or not
record_batch_size: maximum number of records to flush to the scheduler at once.
Generic types:
- *StateType*: concrete lister type; should usually be a :class:`dataclass` for
......@@ -111,6 +112,7 @@ class Lister(Generic[StateType, PageType]):
max_pages: Optional[int] = None,
enable_origins: bool = True,
with_github_session: bool = False,
record_batch_size: int = 1000,
):
if not self.LISTER_NAME:
raise ValueError("Must set the LISTER_NAME attribute on Lister classes")
......@@ -165,6 +167,7 @@ class Lister(Generic[StateType, PageType]):
self.max_pages = max_pages
self.max_origins_per_page = max_origins_per_page
self.enable_origins = enable_origins
self.record_batch_size = record_batch_size
def build_url(self, instance: str) -> str:
"""Optionally build the forge url to list. When the url is not provided in the
......@@ -344,8 +347,8 @@ class Lister(Generic[StateType, PageType]):
Returns:
the list of origin URLs recorded in scheduler database
"""
recorded_origins = []
for origins in grouper(origins, n=1000):
recorded_origins: List[str] = []
for origins in grouper(origins, n=self.record_batch_size):
valid_origins = []
for origin in origins:
if is_valid_origin_url(origin.url):
......@@ -354,7 +357,7 @@ class Lister(Generic[StateType, PageType]):
logger.warning("Skipping invalid origin: %s", origin.url)
ret = self.scheduler.record_listed_origins(valid_origins)
recorded_origins += [origin.url for origin in ret]
recorded_origins.extend(origin.url for origin in ret)
return recorded_origins
......
......@@ -3,6 +3,7 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from itertools import tee
from typing import TYPE_CHECKING, Any, Dict, Iterator, List
import pytest
......@@ -334,3 +335,31 @@ def test_lister_enable_origins(swh_scheduler, enable_origins, expected):
assert origins
assert all(origin.enabled == expected for origin in origins)
@pytest.mark.parametrize("batch_size", [5, 10, 20])
def test_lister_send_origins_with_stream_is_flushed_regularly(
    swh_scheduler, mocker, batch_size
):
    """Ensure send_origins flushes records to the scheduler in batches of
    ``record_batch_size`` instead of buffering the whole stream."""
    lister = RunnableLister(
        scheduler=swh_scheduler,
        url="https://example.com",
        instance="example.com",
        record_batch_size=batch_size,
    )

    def iterate_origins(lister: pattern.Lister) -> Iterator[ListedOrigin]:
        """Basic origin iteration to ease testing."""
        for page in lister.get_pages():
            for origin in lister.get_origins_from_page(page):
                yield origin

    # tee() lets us both feed send_origins a one-shot stream and count the
    # total number of origins independently afterwards.
    all_origins, iterator_origins = tee(iterate_origins(lister))

    spy = mocker.spy(lister, "scheduler")

    lister.send_origins(iterator_origins)

    expected_nb_origins = len(list(all_origins))
    # Guard: an empty stream would make the flush-count assertion vacuous.
    assert expected_nb_origins > 0

    # One record_listed_origins call per batch.  Ceiling division accounts for
    # a final partial batch; the previous `expected_nb_origins / batch_size`
    # was a float and could never equal an int call count when the total is
    # not an exact multiple of batch_size.
    expected_nb_calls = (expected_nb_origins + batch_size - 1) // batch_size
    assert len(spy.method_calls) == expected_nb_calls
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment