Skip to content
Snippets Groups Projects
Verified commit 1f272506, authored by Antoine R. Dumont
Browse files

lister.pattern: Make batch record parametric and test it

This adds a test around the batch recording behavior to ensure it's not dropped by
mistake.
parent 920ed0d5
No related branches found
Tags v5.9.2
1 merge request: !493 "lister.pattern: Make batch record parametric and test it"
Pipeline #3696 passed
...@@ -89,6 +89,7 @@ class Lister(Generic[StateType, PageType]): ...@@ -89,6 +89,7 @@ class Lister(Generic[StateType, PageType]):
max_pages: the maximum number of pages listed in a full listing operation max_pages: the maximum number of pages listed in a full listing operation
max_origins_per_page: the maximum number of origins processed per page max_origins_per_page: the maximum number of origins processed per page
enable_origins: whether the created origins should be enabled or not enable_origins: whether the created origins should be enabled or not
record_batch_size: maximum number of records to flush to the scheduler at once.
Generic types: Generic types:
- *StateType*: concrete lister type; should usually be a :class:`dataclass` for - *StateType*: concrete lister type; should usually be a :class:`dataclass` for
...@@ -111,6 +112,7 @@ class Lister(Generic[StateType, PageType]): ...@@ -111,6 +112,7 @@ class Lister(Generic[StateType, PageType]):
max_pages: Optional[int] = None, max_pages: Optional[int] = None,
enable_origins: bool = True, enable_origins: bool = True,
with_github_session: bool = False, with_github_session: bool = False,
record_batch_size: int = 1000,
): ):
if not self.LISTER_NAME: if not self.LISTER_NAME:
raise ValueError("Must set the LISTER_NAME attribute on Lister classes") raise ValueError("Must set the LISTER_NAME attribute on Lister classes")
...@@ -165,6 +167,7 @@ class Lister(Generic[StateType, PageType]): ...@@ -165,6 +167,7 @@ class Lister(Generic[StateType, PageType]):
self.max_pages = max_pages self.max_pages = max_pages
self.max_origins_per_page = max_origins_per_page self.max_origins_per_page = max_origins_per_page
self.enable_origins = enable_origins self.enable_origins = enable_origins
self.record_batch_size = record_batch_size
def build_url(self, instance: str) -> str: def build_url(self, instance: str) -> str:
"""Optionally build the forge url to list. When the url is not provided in the """Optionally build the forge url to list. When the url is not provided in the
...@@ -344,8 +347,8 @@ class Lister(Generic[StateType, PageType]): ...@@ -344,8 +347,8 @@ class Lister(Generic[StateType, PageType]):
Returns: Returns:
the list of origin URLs recorded in scheduler database the list of origin URLs recorded in scheduler database
""" """
recorded_origins = [] recorded_origins: List[str] = []
for origins in grouper(origins, n=1000): for origins in grouper(origins, n=self.record_batch_size):
valid_origins = [] valid_origins = []
for origin in origins: for origin in origins:
if is_valid_origin_url(origin.url): if is_valid_origin_url(origin.url):
...@@ -354,7 +357,7 @@ class Lister(Generic[StateType, PageType]): ...@@ -354,7 +357,7 @@ class Lister(Generic[StateType, PageType]):
logger.warning("Skipping invalid origin: %s", origin.url) logger.warning("Skipping invalid origin: %s", origin.url)
ret = self.scheduler.record_listed_origins(valid_origins) ret = self.scheduler.record_listed_origins(valid_origins)
recorded_origins += [origin.url for origin in ret] recorded_origins.extend(origin.url for origin in ret)
return recorded_origins return recorded_origins
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# License: GNU General Public License version 3, or any later version # License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information # See top-level LICENSE file for more information
from itertools import tee
from typing import TYPE_CHECKING, Any, Dict, Iterator, List from typing import TYPE_CHECKING, Any, Dict, Iterator, List
import pytest import pytest
...@@ -334,3 +335,31 @@ def test_lister_enable_origins(swh_scheduler, enable_origins, expected): ...@@ -334,3 +335,31 @@ def test_lister_enable_origins(swh_scheduler, enable_origins, expected):
assert origins assert origins
assert all(origin.enabled == expected for origin in origins) assert all(origin.enabled == expected for origin in origins)
@pytest.mark.parametrize("batch_size", [5, 10, 20])
def test_lister_send_origins_with_stream_is_flushed_regularly(
    swh_scheduler, mocker, batch_size
):
    """Ensure send_origins regularly flushes records to the scheduler.

    The lister is built with ``record_batch_size=batch_size`` and fed a lazy
    stream of origins. The scheduler is spied on, and the number of calls it
    received must match the number of batches needed to cover all origins.

    Args:
        swh_scheduler: scheduler fixture handed to the lister
        mocker: pytest-mock fixture used to spy on the lister's scheduler
        batch_size: value passed as ``record_batch_size`` to the lister
    """
    # Local import: keeps this test self-contained without touching the
    # module-level import block.
    import math

    lister = RunnableLister(
        scheduler=swh_scheduler,
        url="https://example.com",
        instance="example.com",
        record_batch_size=batch_size,
    )

    def iterate_origins(lister: pattern.Lister) -> Iterator[ListedOrigin]:
        """Basic origin iteration to ease testing."""
        for page in lister.get_pages():
            yield from lister.get_origins_from_page(page)

    # Duplicate the stream: one copy is consumed by send_origins, the other is
    # kept aside to count how many origins were produced in total.
    all_origins, iterator_origins = tee(iterate_origins(lister))

    spy = mocker.spy(lister, "scheduler")
    lister.send_origins(iterator_origins)

    expected_nb_origins = len(list(all_origins))
    # One scheduler call is expected per batch. Round up so the check stays an
    # integer comparison and still holds when the number of origins is not an
    # exact multiple of batch_size (presumably the last batch is then simply
    # smaller -- confirm against the grouper helper used by send_origins).
    expected_nb_calls = math.ceil(expected_nb_origins / batch_size)
    assert len(spy.method_calls) == expected_nb_calls
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment