From 64267f8f50a801302cde9aadd20dc8036a382d1f Mon Sep 17 00:00:00 2001
From: Nicolas Dandrimont <nicolas@dandrimont.eu>
Date: Mon, 5 Dec 2022 14:20:31 +0100
Subject: [PATCH] Add a flag to not enable origins listed by a lister

This removes one more manual step from the "Add Forge Now" validation
process: the relevant origins can now be added to the staging scheduler
without being enabled at all.
---
 swh/lister/pattern.py            |  9 +++++++++
 swh/lister/tests/test_pattern.py | 34 ++++++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+)

diff --git a/swh/lister/pattern.py b/swh/lister/pattern.py
index c52a7469..621b643e 100644
--- a/swh/lister/pattern.py
+++ b/swh/lister/pattern.py
@@ -10,6 +10,7 @@ import logging
 from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Set, TypeVar
 from urllib.parse import urlparse
 
+import attr
 import requests
 from tenacity.before_sleep import before_sleep_log
 
@@ -86,6 +87,7 @@ class Lister(Generic[StateType, PageType]):
         expected credentials for the given instance of that lister.
       max_pages: the maximum number of pages listed in a full listing operation
       max_origins_per_page: the maximum number of origins processed per page
+      enable_origins: whether the created origins should be enabled or not
 
     Generic types:
       - *StateType*: concrete lister type; should usually be a :class:`dataclass` for
@@ -106,6 +108,7 @@ class Lister(Generic[StateType, PageType]):
         credentials: CredentialsType = None,
         max_origins_per_page: Optional[int] = None,
         max_pages: Optional[int] = None,
+        enable_origins: bool = True,
         with_github_session: bool = False,
     ):
         if not self.LISTER_NAME:
@@ -146,6 +149,7 @@ class Lister(Generic[StateType, PageType]):
         self.recorded_origins: Set[str] = set()
         self.max_pages = max_pages
         self.max_origins_per_page = max_origins_per_page
+        self.enable_origins = enable_origins
 
     @http_retry(before_sleep=before_sleep_log(logger, logging.WARNING))
     def http_request(self, url: str, method="GET", **kwargs) -> requests.Response:
@@ -189,6 +193,11 @@ class Lister(Generic[StateType, PageType]):
                         self.max_origins_per_page,
                     )
                     origins = origins[: self.max_origins_per_page]
+                if not self.enable_origins:
+                    logger.info(
+                        "Disabling origins before sending them to the scheduler"
+                    )
+                    origins = [attr.evolve(origin, enabled=False) for origin in origins]
                 sent_origins = self.send_origins(origins)
                 self.recorded_origins.update(sent_origins)
                 full_stats.origins = len(self.recorded_origins)
diff --git a/swh/lister/tests/test_pattern.py b/swh/lister/tests/test_pattern.py
index d59ba63a..6dcd1d5a 100644
--- a/swh/lister/tests/test_pattern.py
+++ b/swh/lister/tests/test_pattern.py
@@ -282,3 +282,37 @@ def test_lister_max_origins_per_page(
 
     assert run_result.pages == 10
     assert run_result.origins == 10 * expected_origins_per_page
+
+
+@pytest.mark.parametrize(
+    "enable_origins,expected",
+    [
+        (True, True),
+        (False, False),
+        # None: argument omitted, so the default applies (origins are enabled)
+        (None, True),
+    ],
+)
+def test_lister_enable_origins(swh_scheduler, enable_origins, expected):
+    extra_kwargs = {}  # stays empty for None so the constructor default is used
+    if enable_origins is not None:
+        extra_kwargs["enable_origins"] = enable_origins
+
+    lister = ListerWithALotOfPagesWithALotOfOrigins(
+        scheduler=swh_scheduler,
+        url="https://example.org",
+        instance="example.org",
+        **extra_kwargs,
+    )
+
+    run_result = lister.run()
+    assert run_result.pages == 10
+    assert run_result.origins == 100
+    # NOTE(review): presumably enabled=None returns disabled origins too; confirm
+    origins = swh_scheduler.get_listed_origins(
+        lister_id=lister.lister_obj.id, enabled=None
+    ).results
+
+    assert origins  # all() below would pass vacuously on an empty result
+
+    assert all(origin.enabled == expected for origin in origins)
-- 
GitLab