From b815737054acdbebf0d71036af3b713e866beed6 Mon Sep 17 00:00:00 2001
From: Nicolas Dandrimont <nicolas@dandrimont.eu>
Date: Mon, 5 Dec 2022 14:20:19 +0100
Subject: [PATCH] Add built-in page and origin count limit to listers

This will allow more automation of the staging add forge now process:
for known-good listers, we can limit the number of origins being
processed and reduce the amount of manual steps taken for each instance.
---
 swh/lister/pattern.py            | 22 ++++++++++-
 swh/lister/tests/test_pattern.py | 67 ++++++++++++++++++++++++++++++++
 2 files changed, 88 insertions(+), 1 deletion(-)

diff --git a/swh/lister/pattern.py b/swh/lister/pattern.py
index 8a1b497a..c52a7469 100644
--- a/swh/lister/pattern.py
+++ b/swh/lister/pattern.py
@@ -84,6 +84,8 @@ class Lister(Generic[StateType, PageType]):
         identifies the :attr:`LISTER_NAME`, the second level the lister
         :attr:`instance`. The final level is a list of dicts containing the
         expected credentials for the given instance of that lister.
+      max_pages: the maximum number of pages listed in a full listing operation
+      max_origins_per_page: the maximum number of origins processed per page
 
     Generic types:
       - *StateType*: concrete lister type; should usually be a :class:`dataclass` for
@@ -102,6 +104,8 @@ class Lister(Generic[StateType, PageType]):
         url: str,
         instance: Optional[str] = None,
         credentials: CredentialsType = None,
+        max_origins_per_page: Optional[int] = None,
+        max_pages: Optional[int] = None,
         with_github_session: bool = False,
     ):
         if not self.LISTER_NAME:
@@ -140,6 +144,8 @@ class Lister(Generic[StateType, PageType]):
         )
 
         self.recorded_origins: Set[str] = set()
+        self.max_pages = max_pages
+        self.max_origins_per_page = max_origins_per_page
 
     @http_retry(before_sleep=before_sleep_log(logger, logging.WARNING))
     def http_request(self, url: str, method="GET", **kwargs) -> requests.Response:
@@ -172,11 +178,25 @@ class Lister(Generic[StateType, PageType]):
         try:
             for page in self.get_pages():
                 full_stats.pages += 1
-                origins = self.get_origins_from_page(page)
+                origins = list(self.get_origins_from_page(page))
+                if (
+                    self.max_origins_per_page
+                    and len(origins) > self.max_origins_per_page
+                ):
+                    logger.info(
+                        "Max origins per page set, truncated %s page results down to %s",
+                        len(origins),
+                        self.max_origins_per_page,
+                    )
+                    origins = origins[: self.max_origins_per_page]
                 sent_origins = self.send_origins(origins)
                 self.recorded_origins.update(sent_origins)
                 full_stats.origins = len(self.recorded_origins)
                 self.commit_page(page)
+
+                if self.max_pages and full_stats.pages >= self.max_pages:
+                    logger.info("Reached page limit of %s, terminating", self.max_pages)
+                    break
         finally:
             self.finalize()
             if self.updated:
diff --git a/swh/lister/tests/test_pattern.py b/swh/lister/tests/test_pattern.py
index 554a8d1b..d59ba63a 100644
--- a/swh/lister/tests/test_pattern.py
+++ b/swh/lister/tests/test_pattern.py
@@ -215,3 +215,70 @@ def test_listed_origins_count(swh_scheduler):
 
     assert run_result.pages == 2
     assert run_result.origins == 1
+
+
+class ListerWithALotOfPagesWithALotOfOrigins(RunnableStatelessLister):
+    def get_pages(self) -> Iterator[PageType]:
+        for page in range(10):
+            yield [
+                {"url": f"https://example.org/page{page}/origin{origin}"}
+                for origin in range(10)
+            ]
+
+
+@pytest.mark.parametrize(
+    "max_pages,expected_pages",
+    [
+        (2, 2),
+        (10, 10),
+        (100, 10),
+        # The default returns all 10 pages
+        (None, 10),
+    ],
+)
+def test_lister_max_pages(swh_scheduler, max_pages, expected_pages):
+    extra_kwargs = {}
+    if max_pages is not None:
+        extra_kwargs["max_pages"] = max_pages
+
+    lister = ListerWithALotOfPagesWithALotOfOrigins(
+        scheduler=swh_scheduler,
+        url="https://example.org",
+        instance="example.org",
+        **extra_kwargs,
+    )
+
+    run_result = lister.run()
+
+    assert run_result.pages == expected_pages
+    assert run_result.origins == expected_pages * 10
+
+
+@pytest.mark.parametrize(
+    "max_origins_per_page,expected_origins_per_page",
+    [
+        (2, 2),
+        (10, 10),
+        (100, 10),
+        # The default returns all 10 origins per page
+        (None, 10),
+    ],
+)
+def test_lister_max_origins_per_page(
+    swh_scheduler, max_origins_per_page, expected_origins_per_page
+):
+    extra_kwargs = {}
+    if max_origins_per_page is not None:
+        extra_kwargs["max_origins_per_page"] = max_origins_per_page
+
+    lister = ListerWithALotOfPagesWithALotOfOrigins(
+        scheduler=swh_scheduler,
+        url="https://example.org",
+        instance="example.org",
+        **extra_kwargs,
+    )
+
+    run_result = lister.run()
+
+    assert run_result.pages == 10
+    assert run_result.origins == 10 * expected_origins_per_page
-- 
GitLab