From 1f2725069476411fb588d6cee544177f82107e5b Mon Sep 17 00:00:00 2001
From: "Antoine R. Dumont (@ardumont)" <ardumont@softwareheritage.org>
Date: Tue, 1 Aug 2023 10:59:16 +0200
Subject: [PATCH] lister.pattern: Make batch record parametric and test it

This adds a test around the batch recording behavior to ensure it's not dropped by
mistake.
---
 swh/lister/pattern.py            |  9 ++++++---
 swh/lister/tests/test_pattern.py | 29 +++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 3 deletions(-)

diff --git a/swh/lister/pattern.py b/swh/lister/pattern.py
index 852ab174..ca133358 100644
--- a/swh/lister/pattern.py
+++ b/swh/lister/pattern.py
@@ -89,6 +89,7 @@ class Lister(Generic[StateType, PageType]):
       max_pages: the maximum number of pages listed in a full listing operation
       max_origins_per_page: the maximum number of origins processed per page
       enable_origins: whether the created origins should be enabled or not
+      record_batch_size: maximum number of records to flush to the scheduler at once.
 
     Generic types:
       - *StateType*: concrete lister type; should usually be a :class:`dataclass` for
@@ -111,6 +112,7 @@ class Lister(Generic[StateType, PageType]):
         max_pages: Optional[int] = None,
         enable_origins: bool = True,
         with_github_session: bool = False,
+        record_batch_size: int = 1000,
     ):
         if not self.LISTER_NAME:
             raise ValueError("Must set the LISTER_NAME attribute on Lister classes")
@@ -165,6 +167,7 @@ class Lister(Generic[StateType, PageType]):
         self.max_pages = max_pages
         self.max_origins_per_page = max_origins_per_page
         self.enable_origins = enable_origins
+        self.record_batch_size = record_batch_size
 
     def build_url(self, instance: str) -> str:
         """Optionally build the forge url to list. When the url is not provided in the
@@ -344,8 +347,8 @@ class Lister(Generic[StateType, PageType]):
         Returns:
           the list of origin URLs recorded in scheduler database
         """
-        recorded_origins = []
-        for origins in grouper(origins, n=1000):
+        recorded_origins: List[str] = []
+        for origins in grouper(origins, n=self.record_batch_size):
             valid_origins = []
             for origin in origins:
                 if is_valid_origin_url(origin.url):
@@ -354,7 +357,7 @@ class Lister(Generic[StateType, PageType]):
                     logger.warning("Skipping invalid origin: %s", origin.url)
 
             ret = self.scheduler.record_listed_origins(valid_origins)
-            recorded_origins += [origin.url for origin in ret]
+            recorded_origins.extend(origin.url for origin in ret)
 
         return recorded_origins
 
diff --git a/swh/lister/tests/test_pattern.py b/swh/lister/tests/test_pattern.py
index 33dc3e7a..88fd2b3e 100644
--- a/swh/lister/tests/test_pattern.py
+++ b/swh/lister/tests/test_pattern.py
@@ -3,6 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from itertools import tee
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List
 
 import pytest
@@ -334,3 +335,31 @@ def test_lister_enable_origins(swh_scheduler, enable_origins, expected):
     assert origins
 
     assert all(origin.enabled == expected for origin in origins)
+
+
+@pytest.mark.parametrize("batch_size", [5, 10, 20])
+def test_lister_send_origins_with_stream_is_flushed_regularly(
+    swh_scheduler, mocker, batch_size
+):
+    """Ensure the send_origins method is flushing regularly records to the scheduler"""
+    lister = RunnableLister(
+        scheduler=swh_scheduler,
+        url="https://example.com",
+        instance="example.com",
+        record_batch_size=batch_size,
+    )
+
+    def iterate_origins(lister: pattern.Lister) -> Iterator[ListedOrigin]:
+        """Basic origin iteration to ease testing."""
+        for page in lister.get_pages():
+            for origin in lister.get_origins_from_page(page):
+                yield origin
+
+    all_origins, iterator_origins = tee(iterate_origins(lister))
+
+    spy = mocker.spy(lister, "scheduler")
+    lister.send_origins(iterator_origins)
+
+    expected_nb_origins = len(list(all_origins))
+
+    assert len(spy.method_calls) == expected_nb_origins / batch_size
-- 
GitLab