From 65db031263bea088c954bc80c28a1e987fbc9a76 Mon Sep 17 00:00:00 2001
From: Antoine Lambert <anlambert@softwareheritage.org>
Date: Mon, 25 Mar 2024 16:35:13 +0100
Subject: [PATCH] objstorage_checker: Add support for consuming content ids
 from journal

Add an ObjectStorageCheckerJournal class that consumes content ids from a
Kafka topic, checks that the corresponding objects are present in a given
object storage, and verifies their integrity by fetching their bytes and
recomputing their checksums.
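
A minimal usage sketch, assuming a ScrubberDb instance (scrubber_db), a
scrubber configuration entry (config_id) registered for content objects and
an ObjStorageInterface instance (objstorage) are already available; the
journal client settings mirror the kafka test fixture, with placeholder
broker, group and prefix values:

    from swh.scrubber.objstorage_checker import ObjectStorageCheckerJournal

    journal_client_config = {
        "cls": "kafka",
        # placeholder values for an actual deployment
        "brokers": ["kafka1.example.org:9092"],
        "group_id": "swh.scrubber.objstorage-checker",
        "prefix": "swh.journal.objects",
        # stop once the end of the topic is reached (as in the tests)
        "on_eof": "stop",
    }

    checker = ObjectStorageCheckerJournal(
        scrubber_db,            # ScrubberDb instance (assumed available)
        config_id,              # config entry for objects of type content
        journal_client_config,
        objstorage,             # object storage to check (assumed available)
    )
    # Consume content ids from the topic, check their presence in the object
    # storage, recompute checksums and record missing or corrupt objects in
    # the scrubber database.
    checker.run()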

Related to #4694.
---
 swh/scrubber/base_checker.py                  |  23 +--
 swh/scrubber/journal_checker.py               |  37 ++---
 swh/scrubber/objstorage_checker.py            | 145 +++++++++++++----
 swh/scrubber/tests/conftest.py                |  25 +++
 .../tests/objstorage_checker_tests.py         | 149 +++++++++++++++---
 swh/scrubber/tests/test_journal_kafka.py      | 113 ++++---------
 6 files changed, 329 insertions(+), 163 deletions(-)

diff --git a/swh/scrubber/base_checker.py b/swh/scrubber/base_checker.py
index 5cf7fc3..3e7c36f 100644
--- a/swh/scrubber/base_checker.py
+++ b/swh/scrubber/base_checker.py
@@ -33,6 +33,12 @@ class BaseChecker(ABC):
 
         self._config: Optional[ConfigEntry] = None
         self._statsd: Optional[Statsd] = None
+        self.statsd_constant_tags = {
+            "object_type": self.object_type.name.lower(),
+            "datastore_package": self.datastore.package,
+            "datastore_cls": self.datastore.cls,
+            "datastore_instance": self.datastore.instance,
+        }
 
     @property
     def config(self) -> ConfigEntry:
@@ -59,6 +65,11 @@ class BaseChecker(ABC):
             )
         return self._statsd
 
+    @property
+    def object_type(self) -> swhids.ObjectType:
+        """Returns the type of object being checked."""
+        return self.config.object_type
+
     @property
     def check_hashes(self) -> bool:
         return self.config.check_hashes
@@ -84,17 +95,7 @@ class BasePartitionChecker(BaseChecker):
     ):
         super().__init__(db=db, config_id=config_id)
         self.limit = limit
-        self.statsd_constant_tags = {
-            "object_type": self.object_type,
-            "nb_partitions": self.nb_partitions,
-            "datastore_package": self.datastore.package,
-            "datastore_cls": self.datastore.cls,
-        }
-
-    @property
-    def object_type(self) -> swhids.ObjectType:
-        """Returns the type of object being checked."""
-        return self.config.object_type
+        self.statsd_constant_tags["nb_partitions"] = self.nb_partitions
 
     @property
     def nb_partitions(self) -> int:
diff --git a/swh/scrubber/journal_checker.py b/swh/scrubber/journal_checker.py
index dae1aae..fc574bf 100644
--- a/swh/scrubber/journal_checker.py
+++ b/swh/scrubber/journal_checker.py
@@ -21,18 +21,19 @@ from .db import Datastore, ScrubberDb
 logger = logging.getLogger(__name__)
 
 
-def get_datastore(journal_cfg) -> Datastore:
+def get_datastore(journal_cfg: Dict[str, Any]) -> Datastore:
     if journal_cfg.get("cls") == "kafka":
+        package = "journal"
+        cls = "kafka"
+        instance = {
+            "brokers": journal_cfg["brokers"],
+            "group_id": journal_cfg["group_id"],
+            "prefix": journal_cfg["prefix"],
+        }
         datastore = Datastore(
-            package="journal",
-            cls="kafka",
-            instance=json.dumps(
-                {
-                    "brokers": journal_cfg["brokers"],
-                    "group_id": journal_cfg["group_id"],
-                    "prefix": journal_cfg["prefix"],
-                }
-            ),
+            package=package,
+            cls=cls,
+            instance=json.dumps(instance),
         )
     else:
         raise NotImplementedError(
@@ -46,13 +47,14 @@ class JournalChecker(BaseChecker):
     reports errors in a separate database."""
 
     def __init__(
-        self, db: ScrubberDb, config_id: int, journal_client_config: Dict[str, Any]
+        self,
+        db: ScrubberDb,
+        config_id: int,
+        journal_client_config: Dict[str, Any],
     ):
         super().__init__(db=db, config_id=config_id)
-        self.statsd_constant_tags = {
-            "datastore_package": self.datastore.package,
-            "datastore_cls": self.datastore.cls,
-        }
+
+        object_type = self.object_type.name.lower()
 
         if self.config.check_references:
             raise ValueError(
@@ -66,9 +68,7 @@ class JournalChecker(BaseChecker):
                 "The journal_client configuration entry should not define the "
                 "object_types field; this is handled by the scrubber configuration entry"
             )
-        self.journal_client_config["object_types"] = [
-            self.config.object_type.name.lower()
-        ]
+        self.journal_client_config["object_types"] = [object_type]
         self.journal_client = get_journal_client(
             **self.journal_client_config,
             # Remove default deserializer; so process_kafka_values() gets the message
@@ -109,6 +109,7 @@ class JournalChecker(BaseChecker):
                 else:
                     object_ = cls.from_dict(kafka_to_value(message))
                     has_duplicate_dir_entries = False
+
                 real_id = object_.compute_hash()
                 if object_.id != real_id or has_duplicate_dir_entries:
                     self.db.corrupt_object_add(object_.swhid(), self.config, message)
diff --git a/swh/scrubber/objstorage_checker.py b/swh/scrubber/objstorage_checker.py
index de08a88..97cb0f0 100644
--- a/swh/scrubber/objstorage_checker.py
+++ b/swh/scrubber/objstorage_checker.py
@@ -5,17 +5,19 @@
 
 import json
 import logging
-from typing import Iterable, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
-from swh.journal.serializers import value_to_kafka
+from swh.core.statsd import Statsd
+from swh.journal.client import get_journal_client
+from swh.journal.serializers import kafka_to_value, value_to_kafka
 from swh.model.model import Content
 from swh.model.swhids import ObjectType
 from swh.objstorage.exc import ObjNotFoundError
 from swh.objstorage.interface import ObjStorageInterface, objid_from_dict
 from swh.storage.interface import StorageInterface
 
-from .base_checker import BasePartitionChecker
-from .db import Datastore, ScrubberDb
+from .base_checker import BaseChecker, BasePartitionChecker
+from .db import ConfigEntry, Datastore, ScrubberDb
 
 logger = logging.getLogger(__name__)
 
@@ -29,12 +31,40 @@ def get_objstorage_datastore(objstorage_config):
     )
 
 
-class ObjectStorageChecker(BasePartitionChecker):
-    """A checker to detect missing and corrupted contents in an object storage.
+def check_content(
+    content: Content,
+    objstorage: ObjStorageInterface,
+    db: ScrubberDb,
+    config: ConfigEntry,
+    statsd: Statsd,
+) -> None:
+    content_hashes = objid_from_dict(content.hashes())
+    try:
+        content_bytes = objstorage.get(content_hashes)
+    except ObjNotFoundError:
+        if config.check_references:
+            statsd.increment("missing_object_total")
+            db.missing_object_add(id=content.swhid(), reference_ids={}, config=config)
+    else:
+        if config.check_hashes:
+            recomputed_hashes = objid_from_dict(
+                Content.from_data(content_bytes).hashes()
+            )
+            if content_hashes != recomputed_hashes:
+                statsd.increment("hash_mismatch_total")
+                db.corrupt_object_add(
+                    id=content.swhid(),
+                    config=config,
+                    serialized_object=value_to_kafka(content.to_dict()),
+                )
+
 
-    It iterates on content objects referenced in a storage instance, check they
-    are available in a given object storage instance then retrieve their bytes
-    from it in order to recompute checksums and detect corruptions."""
+class ObjectStorageCheckerPartition(BasePartitionChecker):
+    """A partition based checker to detect missing and corrupted contents in an object storage.
+
+    It iterates on content objects referenced in a storage instance, check they are available
+    in a given object storage instance then retrieve their bytes from it in order to recompute
+    checksums and detect corruptions."""
 
     def __init__(
         self,
@@ -49,15 +79,21 @@ class ObjectStorageChecker(BasePartitionChecker):
         self.objstorage = (
             objstorage if objstorage is not None else getattr(storage, "objstorage")
         )
-        self.statsd_constant_tags["datastore_instance"] = self.datastore.instance
 
-    def check_partition(self, object_type: ObjectType, partition_id: int) -> None:
-        if object_type != ObjectType.CONTENT:
+        object_type = self.object_type.name.lower()
+
+        if object_type != "content":
+            raise ValueError(
+                "ObjectStorageCheckerPartition can only check objects of type content,"
+                f"checking objects of type {object_type} is not supported."
+            )
+
+        if self.objstorage is None:
             raise ValueError(
-                "ObjectStorageChecker can only check objects of type content,"
-                f"checking objects of type {object_type.name.lower()} is not supported."
+                "An object storage must be provided to ObjectStorageCheckerPartition."
             )
 
+    def check_partition(self, object_type: ObjectType, partition_id: int) -> None:
         page_token = None
         while True:
             page = self.storage.content_get_partition(
@@ -79,24 +115,63 @@ class ObjectStorageChecker(BasePartitionChecker):
 
     def check_contents(self, contents: Iterable[Content]) -> None:
         for content in contents:
-            content_hashes = objid_from_dict(content.hashes())
-            try:
-                content_bytes = self.objstorage.get(content_hashes)
-            except ObjNotFoundError:
-                if self.check_references:
-                    self.statsd.increment("missing_object_total")
-                    self.db.missing_object_add(
-                        id=content.swhid(), reference_ids={}, config=self.config
-                    )
-            else:
-                if self.check_hashes:
-                    recomputed_hashes = objid_from_dict(
-                        Content.from_data(content_bytes).hashes()
-                    )
-                    if content_hashes != recomputed_hashes:
-                        self.statsd.increment("hash_mismatch_total")
-                        self.db.corrupt_object_add(
-                            id=content.swhid(),
-                            config=self.config,
-                            serialized_object=value_to_kafka(content.to_dict()),
-                        )
+            check_content(
+                content,
+                self.objstorage,
+                self.db,
+                self.config,
+                self.statsd,
+            )
+
+
+class ObjectStorageCheckerJournal(BaseChecker):
+    """A journal based checker to detect missing and corrupted contents in an object storage.
+
+    It iterates on content objects referenced in a kafka topic, check they are available
+    in a given object storage instance then retrieve their bytes from it in order to
+    recompute checksums and detect corruptions."""
+
+    def __init__(
+        self,
+        db: ScrubberDb,
+        config_id: int,
+        journal_client_config: Dict[str, Any],
+        objstorage: ObjStorageInterface,
+    ):
+        super().__init__(db=db, config_id=config_id)
+        self.objstorage = objstorage
+
+        object_type = self.object_type.name.lower()
+
+        if object_type != "content":
+            raise ValueError(
+                "ObjectStorageCheckerJournal can only check objects of type content,"
+                f"checking objects of type {object_type} is not supported."
+            )
+
+        self.journal_client_config = journal_client_config.copy()
+        if "object_types" in self.journal_client_config:
+            raise ValueError(
+                "The journal_client configuration entry should not define the "
+                "object_types field; this is handled by the scrubber configuration entry"
+            )
+        self.journal_client_config["object_types"] = [object_type]
+        self.journal_client = get_journal_client(
+            **self.journal_client_config,
+            # Remove the default deserializer so that process_kafka_messages() gets the
+            # message verbatim and can archive it with as few modifications as possible.
+            value_deserializer=lambda obj_type, msg: msg,
+        )
+
+    def run(self) -> None:
+        self.journal_client.process(self.process_kafka_messages)
+
+    def process_kafka_messages(self, all_messages: Dict[str, List[bytes]]) -> None:
+        for message in all_messages["content"]:
+            check_content(
+                Content.from_dict(kafka_to_value(message)),
+                self.objstorage,
+                self.db,
+                self.config,
+                self.statsd,
+            )
diff --git a/swh/scrubber/tests/conftest.py b/swh/scrubber/tests/conftest.py
index 56bf2c8..592fa10 100644
--- a/swh/scrubber/tests/conftest.py
+++ b/swh/scrubber/tests/conftest.py
@@ -12,6 +12,7 @@ from pytest_postgresql import factories
 
 from swh.core.db.db_utils import initialize_database_for_module
 from swh.journal.serializers import value_to_kafka
+from swh.journal.writer import get_journal_writer
 from swh.model.hashutil import hash_to_bytes
 from swh.model.model import Directory, DirectoryEntry
 from swh.model.swhids import ObjectType
@@ -110,3 +111,27 @@ def corrupt_object(scrubber_db, config_entry):
         first_occurrence=datetime.datetime.now(tz=datetime.timezone.utc),
         object_=value_to_kafka(CORRUPT_DIRECTORY.to_dict()),
     )
+
+
+@pytest.fixture
+def journal_client_config(
+    kafka_server: str, kafka_prefix: str, kafka_consumer_group: str
+):
+    return dict(
+        cls="kafka",
+        brokers=kafka_server,
+        group_id=kafka_consumer_group,
+        prefix=kafka_prefix,
+        on_eof="stop",
+    )
+
+
+@pytest.fixture
+def journal_writer(kafka_server: str, kafka_prefix: str):
+    return get_journal_writer(
+        cls="kafka",
+        brokers=[kafka_server],
+        client_id="kafka_writer",
+        prefix=kafka_prefix,
+        anonymize=False,
+    )
diff --git a/swh/scrubber/tests/objstorage_checker_tests.py b/swh/scrubber/tests/objstorage_checker_tests.py
index 73fc095..8ccbe6d 100644
--- a/swh/scrubber/tests/objstorage_checker_tests.py
+++ b/swh/scrubber/tests/objstorage_checker_tests.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 from datetime import datetime, timedelta, timezone
+import json
 
 import attr
 import pytest
@@ -12,7 +13,8 @@ from swh.journal.serializers import kafka_to_value
 from swh.model.swhids import CoreSWHID, ObjectType
 from swh.model.tests import swh_model_data
 from swh.scrubber.objstorage_checker import (
-    ObjectStorageChecker,
+    ObjectStorageCheckerJournal,
+    ObjectStorageCheckerPartition,
     get_objstorage_datastore,
 )
 
@@ -32,35 +34,49 @@ def datastore(swh_objstorage_config):
 
 
 @pytest.fixture
-def objstorage_checker(swh_storage, swh_objstorage, scrubber_db, datastore):
+def objstorage_checker_partition(swh_storage, swh_objstorage, scrubber_db, datastore):
     nb_partitions = len(EXPECTED_PARTITIONS)
     config_id = scrubber_db.config_add(
-        "cfg_objstorage_checker", datastore, ObjectType.CONTENT, nb_partitions
+        "cfg_objstorage_checker_partition", datastore, ObjectType.CONTENT, nb_partitions
+    )
+    return ObjectStorageCheckerPartition(
+        scrubber_db, config_id, swh_storage, swh_objstorage
+    )
+
+
+@pytest.fixture
+def objstorage_checker_journal(
+    journal_client_config, swh_objstorage, scrubber_db, datastore
+):
+    config_id = scrubber_db.config_add(
+        "cfg_objstorage_checker_journal", datastore, ObjectType.CONTENT, nb_partitions=1
+    )
+    return ObjectStorageCheckerJournal(
+        scrubber_db, config_id, journal_client_config, swh_objstorage
     )
-    return ObjectStorageChecker(scrubber_db, config_id, swh_storage, swh_objstorage)
 
 
-def test_objstorage_checker_no_corruption(
-    swh_storage, swh_objstorage, objstorage_checker
+def test_objstorage_checker_partition_no_corruption(
+    swh_storage, swh_objstorage, objstorage_checker_partition
 ):
     swh_storage.content_add(swh_model_data.CONTENTS)
     swh_objstorage.add_batch({c.sha1: c.data for c in swh_model_data.CONTENTS})
 
-    objstorage_checker.run()
+    objstorage_checker_partition.run()
 
-    scrubber_db = objstorage_checker.db
+    scrubber_db = objstorage_checker_partition.db
     assert list(scrubber_db.corrupt_object_iter()) == []
 
     assert_checked_ranges(
         scrubber_db,
-        [(ObjectType.CONTENT, objstorage_checker.config_id)],
+        [(ObjectType.CONTENT, objstorage_checker_partition.config_id)],
         EXPECTED_PARTITIONS,
     )
 
 
 @pytest.mark.parametrize("missing_idx", range(0, len(swh_model_data.CONTENTS), 5))
-def test_objstorage_checker_missing_content(
-    swh_storage, swh_objstorage, objstorage_checker, missing_idx
+def test_objstorage_checker_partition_missing_content(
+    swh_storage, swh_objstorage, objstorage_checker_partition, missing_idx
 ):
     contents = list(swh_model_data.CONTENTS)
     swh_storage.content_add(contents)
@@ -69,15 +85,15 @@ def test_objstorage_checker_missing_content(
     )
 
     before_date = datetime.now(tz=timezone.utc)
-    objstorage_checker.run()
+    objstorage_checker_partition.run()
     after_date = datetime.now(tz=timezone.utc)
 
-    scrubber_db = objstorage_checker.db
+    scrubber_db = objstorage_checker_partition.db
 
     missing_objects = list(scrubber_db.missing_object_iter())
     assert len(missing_objects) == 1
     assert missing_objects[0].id == contents[missing_idx].swhid()
-    assert missing_objects[0].config.datastore == objstorage_checker.datastore
+    assert missing_objects[0].config.datastore == objstorage_checker_partition.datastore
     assert (
         before_date - timedelta(seconds=5)
         <= missing_objects[0].first_occurrence
@@ -86,7 +102,7 @@ def test_objstorage_checker_missing_content(
 
     assert_checked_ranges(
         scrubber_db,
-        [(ObjectType.CONTENT, objstorage_checker.config_id)],
+        [(ObjectType.CONTENT, objstorage_checker_partition.config_id)],
         EXPECTED_PARTITIONS,
         before_date,
         after_date,
@@ -94,8 +110,8 @@ def test_objstorage_checker_missing_content(
 
 
 @pytest.mark.parametrize("corrupt_idx", range(0, len(swh_model_data.CONTENTS), 5))
-def test_objstorage_checker_corrupt_content(
-    swh_storage, swh_objstorage, objstorage_checker, corrupt_idx
+def test_objstorage_checker_partition_corrupt_content(
+    swh_storage, swh_objstorage, objstorage_checker_partition, corrupt_idx
 ):
     contents = list(swh_model_data.CONTENTS)
     contents[corrupt_idx] = attr.evolve(contents[corrupt_idx], sha1_git=b"\x00" * 20)
@@ -103,17 +119,17 @@ def test_objstorage_checker_corrupt_content(
     swh_objstorage.add_batch({c.sha1: c.data for c in contents})
 
     before_date = datetime.now(tz=timezone.utc)
-    objstorage_checker.run()
+    objstorage_checker_partition.run()
     after_date = datetime.now(tz=timezone.utc)
 
-    scrubber_db = objstorage_checker.db
+    scrubber_db = objstorage_checker_partition.db
 
     corrupt_objects = list(scrubber_db.corrupt_object_iter())
     assert len(corrupt_objects) == 1
     assert corrupt_objects[0].id == CoreSWHID.from_string(
         "swh:1:cnt:0000000000000000000000000000000000000000"
     )
-    assert corrupt_objects[0].config.datastore == objstorage_checker.datastore
+    assert corrupt_objects[0].config.datastore == objstorage_checker_partition.datastore
     assert (
         before_date - timedelta(seconds=5)
         <= corrupt_objects[0].first_occurrence
@@ -126,8 +142,99 @@ def test_objstorage_checker_corrupt_content(
 
     assert_checked_ranges(
         scrubber_db,
-        [(ObjectType.CONTENT, objstorage_checker.config_id)],
+        [(ObjectType.CONTENT, objstorage_checker_partition.config_id)],
         EXPECTED_PARTITIONS,
         before_date,
         after_date,
     )
+
+
+def test_objstorage_checker_journal_contents_no_corruption(
+    scrubber_db,
+    journal_writer,
+    journal_client_config,
+    objstorage_checker_journal,
+):
+    journal_writer.write_additions("content", swh_model_data.CONTENTS)
+
+    gid = journal_client_config["group_id"] + "_"
+
+    object_type = "content"
+    journal_client_config["group_id"] = gid + object_type
+
+    objstorage_checker_journal.objstorage.add_batch(
+        {c.sha1: c.data for c in swh_model_data.CONTENTS}
+    )
+    objstorage_checker_journal.run()
+    objstorage_checker_journal.journal_client.close()
+
+    assert list(scrubber_db.corrupt_object_iter()) == []
+
+
+@pytest.mark.parametrize("corrupt_idx", range(0, len(swh_model_data.CONTENTS), 5))
+def test_objstorage_checker_journal_corrupt_content(
+    scrubber_db,
+    journal_writer,
+    objstorage_checker_journal,
+    swh_objstorage_config,
+    corrupt_idx,
+):
+    contents = list(swh_model_data.CONTENTS)
+    contents[corrupt_idx] = attr.evolve(contents[corrupt_idx], sha1_git=b"\x00" * 20)
+
+    journal_writer.write_additions("content", contents)
+
+    before_date = datetime.now(tz=timezone.utc)
+
+    objstorage_checker_journal.objstorage.add_batch({c.sha1: c.data for c in contents})
+    objstorage_checker_journal.run()
+    after_date = datetime.now(tz=timezone.utc)
+
+    corrupt_objects = list(scrubber_db.corrupt_object_iter())
+    assert len(corrupt_objects) == 1
+    assert corrupt_objects[0].id == CoreSWHID.from_string(
+        "swh:1:cnt:0000000000000000000000000000000000000000"
+    )
+    assert corrupt_objects[0].config.datastore.package == "objstorage"
+    assert corrupt_objects[0].config.datastore.cls == swh_objstorage_config.pop("cls")
+    assert corrupt_objects[0].config.datastore.instance == json.dumps(
+        swh_objstorage_config
+    )
+    assert (
+        before_date - timedelta(seconds=5)
+        <= corrupt_objects[0].first_occurrence
+        <= after_date + timedelta(seconds=5)
+    )
+    corrupted_content = contents[corrupt_idx].to_dict()
+    corrupted_content.pop("data")
+    assert kafka_to_value(corrupt_objects[0].object_) == corrupted_content
+
+
+@pytest.mark.parametrize("missing_idx", range(0, len(swh_model_data.CONTENTS), 5))
+def test_objstorage_checker_journal_missing_content(
+    scrubber_db,
+    journal_writer,
+    objstorage_checker_journal,
+    missing_idx,
+):
+    contents = list(swh_model_data.CONTENTS)
+
+    journal_writer.write_additions("content", contents)
+
+    before_date = datetime.now(tz=timezone.utc)
+
+    objstorage_checker_journal.objstorage.add_batch(
+        {c.sha1: c.data for i, c in enumerate(contents) if i != missing_idx}
+    )
+    objstorage_checker_journal.run()
+    after_date = datetime.now(tz=timezone.utc)
+
+    missing_objects = list(scrubber_db.missing_object_iter())
+    assert len(missing_objects) == 1
+    assert missing_objects[0].id == contents[missing_idx].swhid()
+    assert missing_objects[0].config.datastore == objstorage_checker_journal.datastore
+    assert (
+        before_date - timedelta(seconds=5)
+        <= missing_objects[0].first_occurrence
+        <= after_date + timedelta(seconds=5)
+    )
diff --git a/swh/scrubber/tests/test_journal_kafka.py b/swh/scrubber/tests/test_journal_kafka.py
index 6f5096f..206d881 100644
--- a/swh/scrubber/tests/test_journal_kafka.py
+++ b/swh/scrubber/tests/test_journal_kafka.py
@@ -10,7 +10,6 @@ import attr
 import pytest
 
 from swh.journal.serializers import kafka_to_value
-from swh.journal.writer import get_journal_writer
 from swh.model import model, swhids
 from swh.model.swhids import ObjectType
 from swh.model.tests import swh_model_data
@@ -18,55 +17,26 @@ from swh.scrubber.db import Datastore
 from swh.scrubber.journal_checker import JournalChecker, get_datastore
 
 
-def journal_client_config(
-    kafka_server: str, kafka_prefix: str, kafka_consumer_group: str
-):
-    return dict(
-        cls="kafka",
-        brokers=kafka_server,
-        group_id=kafka_consumer_group,
-        prefix=kafka_prefix,
-        on_eof="stop",
-    )
-
-
-def journal_writer(kafka_server: str, kafka_prefix: str):
-    return get_journal_writer(
-        cls="kafka",
-        brokers=[kafka_server],
-        client_id="kafka_writer",
-        prefix=kafka_prefix,
-        anonymize=False,
-    )
-
-
 @pytest.fixture
-def datastore(
-    kafka_server: str, kafka_prefix: str, kafka_consumer_group: str
-) -> Datastore:
-    journal_config = journal_client_config(
-        kafka_server, kafka_prefix, kafka_consumer_group
-    )
-    datastore = get_datastore(journal_config)
-    return datastore
+def datastore(journal_client_config) -> Datastore:
+    return get_datastore(journal_client_config)
 
 
 def test_no_corruption(
-    scrubber_db, kafka_server, kafka_prefix, kafka_consumer_group, datastore
+    scrubber_db,
+    datastore,
+    journal_writer,
+    journal_client_config,
 ):
-    writer = journal_writer(kafka_server, kafka_prefix)
-    writer.write_additions("directory", swh_model_data.DIRECTORIES)
-    writer.write_additions("revision", swh_model_data.REVISIONS)
-    writer.write_additions("release", swh_model_data.RELEASES)
-    writer.write_additions("snapshot", swh_model_data.SNAPSHOTS)
+    journal_writer.write_additions("directory", swh_model_data.DIRECTORIES)
+    journal_writer.write_additions("revision", swh_model_data.REVISIONS)
+    journal_writer.write_additions("release", swh_model_data.RELEASES)
+    journal_writer.write_additions("snapshot", swh_model_data.SNAPSHOTS)
 
-    journal_cfg = journal_client_config(
-        kafka_server, kafka_prefix, kafka_consumer_group
-    )
-    gid = journal_cfg["group_id"] + "_"
+    gid = journal_client_config["group_id"] + "_"
 
     for object_type in ("directory", "revision", "release", "snapshot"):
-        journal_cfg["group_id"] = gid + object_type
+        journal_client_config["group_id"] = gid + object_type
         config_id = scrubber_db.config_add(
             name=f"cfg_{object_type}",
             datastore=datastore,
@@ -77,7 +47,7 @@ def test_no_corruption(
         jc = JournalChecker(
             db=scrubber_db,
             config_id=config_id,
-            journal_client_config=journal_cfg,
+            journal_client_config=journal_client_config,
         )
         jc.run()
         jc.journal_client.close()
@@ -88,10 +58,9 @@ def test_no_corruption(
 @pytest.mark.parametrize("corrupt_idx", range(len(swh_model_data.SNAPSHOTS)))
 def test_corrupt_snapshot(
     scrubber_db,
-    kafka_server,
-    kafka_prefix,
-    kafka_consumer_group,
     datastore,
+    journal_writer,
+    journal_client_config,
     corrupt_idx,
 ):
     config_id = scrubber_db.config_add(
@@ -104,16 +73,13 @@ def test_corrupt_snapshot(
     snapshots = list(swh_model_data.SNAPSHOTS)
     snapshots[corrupt_idx] = attr.evolve(snapshots[corrupt_idx], id=b"\x00" * 20)
 
-    writer = journal_writer(kafka_server, kafka_prefix)
-    writer.write_additions("snapshot", snapshots)
+    journal_writer.write_additions("snapshot", snapshots)
 
     before_date = datetime.datetime.now(tz=datetime.timezone.utc)
     JournalChecker(
         db=scrubber_db,
         config_id=config_id,
-        journal_client_config=journal_client_config(
-            kafka_server, kafka_prefix, kafka_consumer_group
-        ),
+        journal_client_config=journal_client_config,
     ).run()
     after_date = datetime.datetime.now(tz=datetime.timezone.utc)
 
@@ -136,10 +102,9 @@ def test_corrupt_snapshot(
 
 def test_corrupt_snapshots(
     scrubber_db,
-    kafka_server,
-    kafka_prefix,
-    kafka_consumer_group,
     datastore,
+    journal_writer,
+    journal_client_config,
 ):
     config_id = scrubber_db.config_add(
         name="cfg_snapshot",
@@ -152,15 +117,12 @@ def test_corrupt_snapshots(
     for i in (0, 1):
         snapshots[i] = attr.evolve(snapshots[i], id=bytes([i]) * 20)
 
-    writer = journal_writer(kafka_server, kafka_prefix)
-    writer.write_additions("snapshot", snapshots)
+    journal_writer.write_additions("snapshot", snapshots)
 
     JournalChecker(
         db=scrubber_db,
         config_id=config_id,
-        journal_client_config=journal_client_config(
-            kafka_server, kafka_prefix, kafka_consumer_group
-        ),
+        journal_client_config=journal_client_config,
     ).run()
 
     corrupt_objects = list(scrubber_db.corrupt_object_iter())
@@ -176,10 +138,10 @@ def test_corrupt_snapshots(
 
 def test_duplicate_directory_entries(
     scrubber_db,
-    kafka_server,
-    kafka_prefix,
-    kafka_consumer_group,
     datastore,
+    journal_writer,
+    kafka_prefix,
+    journal_client_config,
 ):
     config_id = scrubber_db.config_add(
         name="cfg_directory",
@@ -213,23 +175,22 @@ def test_duplicate_directory_entries(
         + b"0 filename\x00"
         + b"\x02" * 20
     )
-    dupe_directory = {
+    dup_directory = {
         "id": hashlib.sha1(raw_manifest).digest(),
         "entries": corrupt_directory["entries"],
         "raw_manifest": raw_manifest,
     }
 
-    writer = journal_writer(kafka_server, kafka_prefix)
-    writer.send(f"{kafka_prefix}.directory", directory.id, directory.to_dict())
-    writer.send(f"{kafka_prefix}.directory", corrupt_directory["id"], corrupt_directory)
-    writer.send(f"{kafka_prefix}.directory", dupe_directory["id"], dupe_directory)
+    journal_writer.send(f"{kafka_prefix}.directory", directory.id, directory.to_dict())
+    journal_writer.send(
+        f"{kafka_prefix}.directory", corrupt_directory["id"], corrupt_directory
+    )
+    journal_writer.send(f"{kafka_prefix}.directory", dup_directory["id"], dup_directory)
 
     JournalChecker(
         db=scrubber_db,
         config_id=config_id,
-        journal_client_config=journal_client_config(
-            kafka_server, kafka_prefix, kafka_consumer_group
-        ),
+        journal_client_config=journal_client_config,
     ).run()
 
     corrupt_objects = list(scrubber_db.corrupt_object_iter())
@@ -238,16 +199,14 @@ def test_duplicate_directory_entries(
         swhids.CoreSWHID.from_string(swhid)
         for swhid in [
             "swh:1:dir:0000000000000000000000000000000000000000",
-            f"swh:1:dir:{dupe_directory['id'].hex()}",
+            f"swh:1:dir:{dup_directory['id'].hex()}",
         ]
     }
 
 
 def test_check_references_raises(
     scrubber_db,
-    kafka_server,
-    kafka_prefix,
-    kafka_consumer_group,
+    journal_client_config,
     datastore,
 ):
     config_id = scrubber_db.config_add(
@@ -257,12 +216,10 @@ def test_check_references_raises(
         nb_partitions=1,
         check_references=True,
     )
-    journal_config = journal_client_config(
-        kafka_server, kafka_prefix, kafka_consumer_group
-    )
+
     with pytest.raises(ValueError):
         JournalChecker(
             db=scrubber_db,
             config_id=config_id,
-            journal_client_config=journal_config,
+            journal_client_config=journal_client_config,
         )
-- 
GitLab