From 0b9df1a11c798767beacf09dfed6179ddc593419 Mon Sep 17 00:00:00 2001 From: David Douard <david.douard@sdfa3.org> Date: Fri, 9 Dec 2022 15:06:16 +0100 Subject: [PATCH] Extract the journal writer part from the ProvenanceStorageJournal class This allows using the journal writing part independently from the ProvenanceStorage proxy class, e.g. for the backfiller mechanism. --- swh/provenance/storage/journal.py | 104 +++++++++++------- .../tests/test_provenance_journal_writer.py | 64 +++++++---- 2 files changed, 104 insertions(+), 64 deletions(-) diff --git a/swh/provenance/storage/journal.py b/swh/provenance/storage/journal.py index 9d55000..c29f236 100644 --- a/swh/provenance/storage/journal.py +++ b/swh/provenance/storage/journal.py @@ -44,10 +44,67 @@ class JournalMessage: return self.value +class ProvenanceStorageJournalWriter: + def __init__(self, journal): + self.journal = journal + + def content_add(self, cnts: Dict[Sha1Git, datetime]) -> None: + self.journal.write_additions( + "content", [JournalMessage(key, value) for (key, value) in cnts.items()] + ) + + def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> None: + self.journal.write_additions( + "directory", + [ + JournalMessage(key, value.date) + for (key, value) in dirs.items() + if value.date is not None + ], + ) + + def origin_add(self, orgs: Dict[Sha1Git, str]) -> None: + self.journal.write_additions( + "origin", [JournalMessage(key, value) for (key, value) in orgs.items()] + ) + + def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> None: + self.journal.write_additions( + "revision", + [ + JournalMessage(key, value.date) + for (key, value) in revs.items() + if value.date is not None + ], + ) + + def relation_add( + self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] + ) -> None: + messages = [] + for src, relations in data.items(): + for reldata in relations: + key = hashlib.sha1(src + reldata.dst + (reldata.path or b"")).digest() + messages.append( + JournalMessage( 
+ key, + { + "src": src, + "dst": reldata.dst, + "path": reldata.path, + "dst_date": reldata.dst_date, + }, + add_id=False, + ) + ) + + self.journal.write_additions(relation.value, messages) + + class ProvenanceStorageJournal: def __init__(self, storage, journal): self.storage = storage - self.journal = journal + self.journal_writer = ProvenanceStorageJournalWriter(journal) def __enter__(self) -> ProvenanceStorageInterface: self.storage.__enter__() @@ -68,9 +125,7 @@ class ProvenanceStorageJournal: self.storage.close() def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool: - self.journal.write_additions( - "content", [JournalMessage(key, value) for (key, value) in cnts.items()] - ) + self.journal_writer.content_add(cnts) return self.storage.content_add(cnts) def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: @@ -85,14 +140,7 @@ class ProvenanceStorageJournal: return self.storage.content_get(ids) def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool: - self.journal.write_additions( - "directory", - [ - JournalMessage(key, value.date) - for (key, value) in dirs.items() - if value.date is not None - ], - ) + self.journal_writer.directory_add(dirs) return self.storage.directory_add(dirs) def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, DirectoryData]: @@ -113,23 +161,14 @@ class ProvenanceStorageJournal: return self.storage.location_get_all() def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: - self.journal.write_additions( - "origin", [JournalMessage(key, value) for (key, value) in orgs.items()] - ) + self.journal_writer.origin_add(orgs) return self.storage.origin_add(orgs) def origin_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, str]: return self.storage.origin_get(ids) def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool: - self.journal.write_additions( - "revision", - [ - JournalMessage(key, value.date) - for (key, value) in revs.items() - if value.date is not None - ], - ) + 
self.journal_writer.revision_add(revs) return self.storage.revision_add(revs) def revision_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, RevisionData]: @@ -138,24 +177,7 @@ class ProvenanceStorageJournal: def relation_add( self, relation: RelationType, data: Dict[Sha1Git, Set[RelationData]] ) -> bool: - messages = [] - for src, relations in data.items(): - for reldata in relations: - key = hashlib.sha1(src + reldata.dst + (reldata.path or b"")).digest() - messages.append( - JournalMessage( - key, - { - "src": src, - "dst": reldata.dst, - "path": reldata.path, - "dst_date": reldata.dst_date, - }, - add_id=False, - ) - ) - - self.journal.write_additions(relation.value, messages) + self.journal_writer.relation_add(relation, data) return self.storage.relation_add(relation, data) def relation_get( diff --git a/swh/provenance/tests/test_provenance_journal_writer.py b/swh/provenance/tests/test_provenance_journal_writer.py index 4a77690..8a2c1b2 100644 --- a/swh/provenance/tests/test_provenance_journal_writer.py +++ b/swh/provenance/tests/test_provenance_journal_writer.py @@ -3,7 +3,7 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Generator +from typing import Dict, Generator import pytest @@ -19,7 +19,7 @@ from .test_provenance_storage import TestProvenanceStorage as _TestProvenanceSto @pytest.fixture() def provenance_storage( - provenance_postgresqldb: str, + provenance_postgresqldb: Dict[str, str], ) -> Generator[ProvenanceStorageInterface, None, None]: cfg = { "storage": { @@ -38,52 +38,64 @@ def provenance_storage( class TestProvenanceStorageJournal(_TestProvenanceStorage): def test_provenance_storage_content(self, provenance_storage): super().test_provenance_storage_content(provenance_storage) - assert provenance_storage.journal - objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects} + assert provenance_storage.journal_writer.journal + objtypes 
= { + objtype + for (objtype, obj) in provenance_storage.journal_writer.journal.objects + } assert objtypes == {"content"} journal_objs = { obj.id - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "content" } assert provenance_storage.entity_get_all(EntityType.CONTENT) == journal_objs def test_provenance_storage_directory(self, provenance_storage): super().test_provenance_storage_directory(provenance_storage) - assert provenance_storage.journal - objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects} + assert provenance_storage.journal_writer.journal + objtypes = { + objtype + for (objtype, obj) in provenance_storage.journal_writer.journal.objects + } assert objtypes == {"directory"} journal_objs = { obj.id - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "directory" } assert provenance_storage.entity_get_all(EntityType.DIRECTORY) == journal_objs def test_provenance_storage_origin(self, provenance_storage): super().test_provenance_storage_origin(provenance_storage) - assert provenance_storage.journal - objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects} + assert provenance_storage.journal_writer.journal + objtypes = { + objtype + for (objtype, obj) in provenance_storage.journal_writer.journal.objects + } assert objtypes == {"origin"} journal_objs = { obj.id - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "origin" } assert provenance_storage.entity_get_all(EntityType.ORIGIN) == journal_objs def test_provenance_storage_revision(self, provenance_storage): super().test_provenance_storage_revision(provenance_storage) - assert provenance_storage.journal - objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects} + 
assert provenance_storage.journal_writer.journal + objtypes = { + objtype + for (objtype, obj) in provenance_storage.journal_writer.journal.objects + } assert objtypes == {"revision", "origin"} journal_objs = { obj.id - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "revision" } all_revisions = provenance_storage.revision_get( @@ -96,8 +108,11 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage): def test_provenance_storage_relation_revision_layer(self, provenance_storage): super().test_provenance_storage_relation_revision_layer(provenance_storage) - assert provenance_storage.journal - objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects} + assert provenance_storage.journal_writer.journal + objtypes = { + objtype + for (objtype, obj) in provenance_storage.journal_writer.journal.objects + } assert objtypes == { "content", "directory", @@ -108,7 +123,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage): journal_rels = { tuple(obj.value[k] for k in ("src", "dst", "path")) - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "content_in_revision" } prov_rels = { @@ -122,7 +137,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage): journal_rels = { tuple(obj.value[k] for k in ("src", "dst", "path")) - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "content_in_directory" } prov_rels = { @@ -136,7 +151,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage): journal_rels = { tuple(obj.value[k] for k in ("src", "dst", "path")) - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "directory_in_revision" } prov_rels = { @@ -150,8 +165,11 
@@ class TestProvenanceStorageJournal(_TestProvenanceStorage): def test_provenance_storage_relation_origin_layer(self, provenance_storage): super().test_provenance_storage_relation_origin_layer(provenance_storage) - assert provenance_storage.journal - objtypes = {objtype for (objtype, obj) in provenance_storage.journal.objects} + assert provenance_storage.journal_writer.journal + objtypes = { + objtype + for (objtype, obj) in provenance_storage.journal_writer.journal.objects + } assert objtypes == { "origin", "revision_in_origin", @@ -160,7 +178,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage): journal_rels = { tuple(obj.value[k] for k in ("src", "dst", "path")) - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "revision_in_origin" } prov_rels = { @@ -174,7 +192,7 @@ class TestProvenanceStorageJournal(_TestProvenanceStorage): journal_rels = { tuple(obj.value[k] for k in ("src", "dst", "path")) - for (objtype, obj) in provenance_storage.journal.objects + for (objtype, obj) in provenance_storage.journal_writer.journal.objects if objtype == "revision_before_revision" } prov_rels = { -- GitLab