From 0b9df1a11c798767beacf09dfed6179ddc593419 Mon Sep 17 00:00:00 2001
From: David Douard <david.douard@sdfa3.org>
Date: Fri, 9 Dec 2022 15:06:16 +0100
Subject: [PATCH] Extract the journal writer part from the
 ProvenanceStorageJournal class

This allows to use the journal writing part independently from the
ProvenanceStorage proxy class, eg. for the backfiller mechanism.
---
 swh/provenance/storage/journal.py             | 104 +++++++++++-------
 .../tests/test_provenance_journal_writer.py   |  64 +++++++----
 2 files changed, 104 insertions(+), 64 deletions(-)

diff --git a/swh/provenance/storage/journal.py b/swh/provenance/storage/journal.py
index 9d55000..c29f236 100644
--- a/swh/provenance/storage/journal.py
+++ b/swh/provenance/storage/journal.py
@@ -44,10 +44,67 @@ class JournalMessage:
             return self.value
 
 
+class ProvenanceStorageJournalWriter:
+    def __init__(self, journal):
+        self.journal = journal
+
+    def content_add(self, cnts: Dict[Sha1Git, datetime]) -> None:
+        self.journal.write_additions(
+            "content", [JournalMessage(key, value) for (key, value) in cnts.items()]
+        )
+
+    def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> None:
+        self.journal.write_additions(
+            "directory",
+            [
+                JournalMessage(key, value.date)
+                for (key, value) in dirs.items()
+                if value.date is not None
+            ],
+        )
+
+    def origin_add(self, orgs: Dict[Sha1Git, str]) -> None:
+        self.journal.write_additions(
+            "origin", [JournalMessage(key, value) for (key, value) in orgs.items()]
+        )
+
+    def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> None:
+        self.journal.write_additions(
+            "revision",
+            [
+                JournalMessage(key, value.date)
+                for (key, value) in revs.items()
+                if value.date is not None
+            ],
+        )
+
+    def relation_add(
+        self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]]
+    ) -> None:
+        messages = []
+        for src, relations in data.items():
+            for reldata in relations:
+                key = hashlib.sha1(src + reldata.dst + (reldata.path or b"")).digest()
+                messages.append(
+                    JournalMessage(
+                        key,
+                        {
+                            "src": src,
+                            "dst": reldata.dst,
+                            "path": reldata.path,
+                            "dst_date": reldata.dst_date,
+                        },
+                        add_id=False,
+                    )
+                )
+
+        self.journal.write_additions(relation.value, messages)
+
+
 class ProvenanceStorageJournal:
     def __init__(self, storage, journal):
         self.storage = storage
-        self.journal = journal
+        self.journal_writer = ProvenanceStorageJournalWriter(journal)
 
     def __enter__(self) -> ProvenanceStorageInterface:
         self.storage.__enter__()
@@ -68,9 +125,7 @@ class ProvenanceStorageJournal:
         self.storage.close()
 
     def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool:
-        self.journal.write_additions(
-            "content", [JournalMessage(key, value) for (key, value) in cnts.items()]
-        )
+        self.journal_writer.content_add(cnts)
         return self.storage.content_add(cnts)
 
     def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]:
@@ -85,14 +140,7 @@ class ProvenanceStorageJournal:
         return self.storage.content_get(ids)
 
     def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool:
-        self.journal.write_additions(
-            "directory",
-            [
-                JournalMessage(key, value.date)
-                for (key, value) in dirs.items()
-                if value.date is not None
-            ],
-        )
+        self.journal_writer.directory_add(dirs)
         return self.storage.directory_add(dirs)
 
     def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]:
@@ -113,23 +161,14 @@ class ProvenanceStorageJournal:
         return self.storage.location_get_all()
 
     def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool:
-        self.journal.write_additions(
-            "origin", [JournalMessage(key, value) for (key, value) in orgs.items()]
-        )
+        self.journal_writer.origin_add(orgs)
         return self.storage.origin_add(orgs)
 
     def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]:
         return self.storage.origin_get(ids)
 
     def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool:
-        self.journal.write_additions(
-            "revision",
-            [
-                JournalMessage(key, value.date)
-                for (key, value) in revs.items()
-                if value.date is not None
-            ],
-        )
+        self.journal_writer.revision_add(revs)
         return self.storage.revision_add(revs)
 
     def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]:
@@ -138,24 +177,7 @@ class ProvenanceStorageJournal:
     def relation_add(
         self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]]
     ) -> bool:
-        messages = []
-        for src, relations in data.items():
-            for reldata in relations:
-                key = hashlib.sha1(src + reldata.dst + (reldata.path or b"")).digest()
-                messages.append(
-                    JournalMessage(
-                        key,
-                        {
-                            "src": src,
-                            "dst": reldata.dst,
-                            "path": reldata.path,
-                            "dst_date": reldata.dst_date,
-                        },
-                        add_id=False,
-                    )
-                )
-
-        self.journal.write_additions(relation.value, messages)
+        self.journal_writer.relation_add(relation, data)
         return self.storage.relation_add(relation, data)
 
     def relation_get(
diff --git a/swh/provenance/tests/test_provenance_journal_writer.py b/swh/provenance/tests/test_provenance_journal_writer.py
index 4a77690..8a2c1b2 100644
--- a/swh/provenance/tests/test_provenance_journal_writer.py
+++ b/swh/provenance/tests/test_provenance_journal_writer.py
@@ -3,7 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import Generator
+from typing import Dict, Generator
 
 import pytest
 
@@ -19,7 +19,7 @@ from .test_provenance_storage import TestProvenanceStorage as _TestProvenanceSto
 
 @pytest.fixture()
 def provenance_storage(
-    provenance_postgresqldb: str,
+    provenance_postgresqldb: Dict[str, str],
 ) -> Generator[ProvenanceStorageInterface, None, None]:
     cfg = {
         "storage": {
@@ -38,52 +38,64 @@ def provenance_storage(
 class TestProvenanceStorageJournal(_TestProvenanceStorage):
     def test_provenance_storage_content(self, provenance_storage):
         super().test_provenance_storage_content(provenance_storage)
-        assert provenance_storage.journal
-        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert provenance_storage.journal_writer.journal
+        objtypes = {
+            objtype
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
+        }
         assert objtypes == {"content"}
 
         journal_objs = {
             obj.id
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "content"
         }
         assert provenance_storage.entity_get_all(EntityType.CONTENT) == journal_objs
 
     def test_provenance_storage_directory(self, provenance_storage):
         super().test_provenance_storage_directory(provenance_storage)
-        assert provenance_storage.journal
-        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert provenance_storage.journal_writer.journal
+        objtypes = {
+            objtype
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
+        }
         assert objtypes == {"directory"}
 
         journal_objs = {
             obj.id
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "directory"
         }
         assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == journal_objs
 
     def test_provenance_storage_origin(self, provenance_storage):
         super().test_provenance_storage_origin(provenance_storage)
-        assert provenance_storage.journal
-        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert provenance_storage.journal_writer.journal
+        objtypes = {
+            objtype
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
+        }
         assert objtypes == {"origin"}
 
         journal_objs = {
             obj.id
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "origin"
         }
         assert provenance_storage.entity_get_all(EntityType.ORIGIN) == journal_objs
 
     def test_provenance_storage_revision(self, provenance_storage):
         super().test_provenance_storage_revision(provenance_storage)
-        assert provenance_storage.journal
-        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert provenance_storage.journal_writer.journal
+        objtypes = {
+            objtype
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
+        }
         assert objtypes == {"revision", "origin"}
 
         journal_objs = {
             obj.id
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "revision"
         }
         all_revisions = provenance_storage.revision_get(
@@ -96,8 +108,11 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage):
 
     def test_provenance_storage_relation_revision_layer(self, provenance_storage):
         super().test_provenance_storage_relation_revision_layer(provenance_storage)
-        assert provenance_storage.journal
-        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert provenance_storage.journal_writer.journal
+        objtypes = {
+            objtype
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
+        }
         assert objtypes == {
             "content",
             "directory",
@@ -108,7 +123,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage):
 
         journal_rels = {
             tuple(obj.value[k] for k in ("src", "dst", "path"))
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "content_in_revision"
         }
         prov_rels = {
@@ -122,7 +137,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage):
 
         journal_rels = {
             tuple(obj.value[k] for k in ("src", "dst", "path"))
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "content_in_directory"
         }
         prov_rels = {
@@ -136,7 +151,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage):
 
         journal_rels = {
             tuple(obj.value[k] for k in ("src", "dst", "path"))
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "directory_in_revision"
         }
         prov_rels = {
@@ -150,8 +165,11 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage):
 
     def test_provenance_storage_relation_origin_layer(self, provenance_storage):
         super().test_provenance_storage_relation_origin_layer(provenance_storage)
-        assert provenance_storage.journal
-        objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects}
+        assert provenance_storage.journal_writer.journal
+        objtypes = {
+            objtype
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
+        }
         assert objtypes == {
             "origin",
             "revision_in_origin",
@@ -160,7 +178,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage):
 
         journal_rels = {
             tuple(obj.value[k] for k in ("src", "dst", "path"))
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "revision_in_origin"
         }
         prov_rels = {
@@ -174,7 +192,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage):
 
         journal_rels = {
             tuple(obj.value[k] for k in ("src", "dst", "path"))
-            for (objtype, obj) in provenance_storage.journal.objects
+            for (objtype, obj) in provenance_storage.journal_writer.journal.objects
             if objtype == "revision_before_revision"
         }
         prov_rels = {
-- 
GitLab