From ae2ea7a1dd90304f41313303f0f5762b502d482f Mon Sep 17 00:00:00 2001
From: David Douard <david.douard@sdfa3.org>
Date: Wed, 10 Jul 2024 12:30:00 +0200
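Subject: [PATCH] Update to swh.model 6.13 introducing ModelObjectType

Bump the swh.model requirement to >= 6.13.

Exporter.process_object() now takes the object type as a
ModelObjectType instead of a plain string. JournalProcessorWorker
converts the object type string to a ModelObjectType before calling
the exporters, and ExporterDispatch builds the process_<type> method
name from the enum member.

Tests are updated to key their input messages by ModelObjectType.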

---
 requirements-swh.txt                       |  2 +-
 swh/dataset/exporter.py                    |  8 +++--
 swh/dataset/journalprocessor.py            |  4 +--
 swh/dataset/test/test_edges.py             | 23 +++++++-------
 swh/dataset/test/test_journal_processor.py |  4 ++-
 swh/dataset/test/test_orc.py               | 35 ++++++++++++++--------
 6 files changed, 45 insertions(+), 31 deletions(-)

diff --git a/requirements-swh.txt b/requirements-swh.txt
index 1afada1..b92826c 100644
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -1,5 +1,5 @@
 # Add here internal Software Heritage dependencies, one per line.
 swh.core[http] >= 2
 swh.journal >= 0.9
-swh.model >= 4.3
+swh.model >= 6.13
 swh.storage >= 2.3.1
diff --git a/swh/dataset/exporter.py b/swh/dataset/exporter.py
index 07ca8c6..4ae7cec 100644
--- a/swh/dataset/exporter.py
+++ b/swh/dataset/exporter.py
@@ -9,6 +9,8 @@ from types import TracebackType
 from typing import Any, Dict, Optional, Type
 import uuid
 
+from swh.model.model import ModelObjectType
+
 
 class Exporter:
     """
@@ -46,7 +48,7 @@ class Exporter:
     ) -> Optional[bool]:
         return self.exit_stack.__exit__(exc_type, exc_value, traceback)
 
-    def process_object(self, object_type: str, obj: Dict[str, Any]) -> None:
+    def process_object(self, object_type: ModelObjectType, obj: Dict[str, Any]) -> None:
         """
         Process a SWH object to export.
 
@@ -69,7 +71,7 @@ class ExporterDispatch(Exporter):
     (e.g you can override `process_origin(self, object)` to process origins.)
     """
 
-    def process_object(self, object_type: str, obj: Dict[str, Any]) -> None:
-        method_name = "process_" + object_type
+    def process_object(self, object_type: ModelObjectType, obj: Dict[str, Any]) -> None:
+        method_name = "process_" + object_type.name.lower()
         if hasattr(self, method_name):
             getattr(self, method_name)(obj)
diff --git a/swh/dataset/journalprocessor.py b/swh/dataset/journalprocessor.py
index ab62cb1..c4461fc 100644
--- a/swh/dataset/journalprocessor.py
+++ b/swh/dataset/journalprocessor.py
@@ -32,7 +32,7 @@ from swh.dataset.exporter import Exporter
 from swh.dataset.utils import LevelDBSet
 from swh.journal.client import JournalClient
 from swh.journal.serializers import kafka_to_value
-from swh.model.model import Origin
+from swh.model.model import ModelObjectType, Origin
 from swh.model.swhids import ExtendedObjectType, ExtendedSWHID
 from swh.storage.fixer import fix_objects
 
@@ -534,7 +534,7 @@ class JournalProcessorWorker:
 
         for exporter in self.exporters:
             try:
-                exporter.process_object(object_type, obj)
+                exporter.process_object(ModelObjectType(object_type), obj)
             except Exception:
                 logger.exception(
                     "Exporter %s: error while exporting the object: %s",
diff --git a/swh/dataset/test/test_edges.py b/swh/dataset/test/test_edges.py
index bb43369..f2b1ab3 100644
--- a/swh/dataset/test/test_edges.py
+++ b/swh/dataset/test/test_edges.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020  The Software Heritage developers
+# Copyright (C) 2020-2024  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -14,6 +14,7 @@ import pytest
 from swh.dataset.exporters.edges import GraphEdgesExporter, sort_graph_nodes
 from swh.dataset.utils import ZSTFile
 from swh.model.hashutil import MultiHash, hash_to_bytes
+from swh.model.model import ModelObjectType
 
 DATE = {
     "timestamp": {"seconds": 1234567891, "microseconds": 0},
@@ -119,7 +120,7 @@ def b64e(s: str) -> str:
 def test_export_origin(exporter):
     node_writer, edge_writer = exporter(
         {
-            "origin": [
+            ModelObjectType.ORIGIN: [
                 {"url": "ori1"},
                 {"url": "ori2"},
             ]
@@ -135,7 +136,7 @@ def test_export_origin(exporter):
 def test_export_origin_visit_status(exporter):
     node_writer, edge_writer = exporter(
         {
-            "origin_visit_status": [
+            ModelObjectType.ORIGIN_VISIT_STATUS: [
                 {
                     **TEST_ORIGIN_VISIT_STATUS,
                     "origin": "ori1",
@@ -159,7 +160,7 @@ def test_export_origin_visit_status(exporter):
 def test_export_snapshot_simple(exporter):
     node_writer, edge_writer = exporter(
         {
-            "snapshot": [
+            ModelObjectType.SNAPSHOT: [
                 {
                     "id": binhash("snp1"),
                     "branches": {
@@ -235,7 +236,7 @@ def test_export_snapshot_simple(exporter):
 def test_export_snapshot_aliases(exporter):
     node_writer, edge_writer = exporter(
         {
-            "snapshot": [
+            ModelObjectType.SNAPSHOT: [
                 {
                     "id": binhash("snp1"),
                     "branches": {
@@ -296,7 +297,7 @@ def test_export_snapshot_no_pull_requests(exporter):
         },
     }
 
-    node_writer, edge_writer = exporter({"snapshot": [snp]})
+    node_writer, edge_writer = exporter({ModelObjectType.SNAPSHOT: [snp]})
     assert edge_writer.mock_calls == [
         call(
             f"swh:1:snp:{hexhash('snp1')} swh:1:rev:{hexhash('rev1')}"
@@ -321,7 +322,7 @@ def test_export_snapshot_no_pull_requests(exporter):
     ]
 
     node_writer, edge_writer = exporter(
-        {"snapshot": [snp]}, config={"remove_pull_requests": True}
+        {ModelObjectType.SNAPSHOT: [snp]}, config={"remove_pull_requests": True}
     )
     assert edge_writer.mock_calls == [
         call(
@@ -338,7 +339,7 @@ def test_export_snapshot_no_pull_requests(exporter):
 def test_export_releases(exporter):
     node_writer, edge_writer = exporter(
         {
-            "release": [
+            ModelObjectType.RELEASE: [
                 {
                     **TEST_RELEASE,
                     "id": binhash("rel1"),
@@ -383,7 +384,7 @@ def test_export_releases(exporter):
 def test_export_revision(exporter):
     node_writer, edge_writer = exporter(
         {
-            "revision": [
+            ModelObjectType.REVISION: [
                 {
                     **TEST_REVISION,
                     "id": binhash("rev1"),
@@ -414,7 +415,7 @@ def test_export_revision(exporter):
 def test_export_directory(exporter):
     node_writer, edge_writer = exporter(
         {
-            "directory": [
+            ModelObjectType.DIRECTORY: [
                 {
                     "id": binhash("dir1"),
                     "entries": [
@@ -465,7 +466,7 @@ def test_export_directory(exporter):
 def test_export_content(exporter):
     node_writer, edge_writer = exporter(
         {
-            "content": [
+            ModelObjectType.CONTENT: [
                 {**TEST_CONTENT, "sha1_git": binhash("cnt1")},
                 {**TEST_CONTENT, "sha1_git": binhash("cnt2")},
             ]
diff --git a/swh/dataset/test/test_journal_processor.py b/swh/dataset/test/test_journal_processor.py
index 80a9145..3e66cc0 100644
--- a/swh/dataset/test/test_journal_processor.py
+++ b/swh/dataset/test/test_journal_processor.py
@@ -60,7 +60,9 @@ class ListExporter(Exporter):
         self._objects = objects
         super().__init__(*args, **kwargs)
 
-    def process_object(self, object_type: str, obj: Dict[str, Any]) -> None:
+    def process_object(
+        self, object_type: model.ModelObjectType, obj: Dict[str, Any]
+    ) -> None:
         self._objects.append((object_type, obj))
 
 
diff --git a/swh/dataset/test/test_orc.py b/swh/dataset/test/test_orc.py
index c9b551f..cc7adff 100644
--- a/swh/dataset/test/test_orc.py
+++ b/swh/dataset/test/test_orc.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2022  The Software Heritage developers
+# Copyright (C) 2020-2024  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -15,6 +15,7 @@ import pytest
 
 from swh.dataset.exporters import orc
 from swh.dataset.relational import MAIN_TABLES, RELATION_TABLES
+from swh.model.model import ModelObjectType
 from swh.model.tests.swh_model_data import TEST_OBJECTS
 from swh.objstorage.factory import get_objstorage
 
@@ -61,14 +62,14 @@ def exporter(messages, config=None, tmpdir=None):
 
 
 def test_export_origin():
-    obj_type = "origin"
+    obj_type = ModelObjectType.ORIGIN
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (hashlib.sha1(obj.url.encode()).hexdigest(), obj.url) in output[obj_type]
 
 
 def test_export_origin_visit():
-    obj_type = "origin_visit"
+    obj_type = ModelObjectType.ORIGIN_VISIT
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (
@@ -80,7 +81,7 @@ def test_export_origin_visit():
 
 
 def test_export_origin_visit_status():
-    obj_type = "origin_visit_status"
+    obj_type = ModelObjectType.ORIGIN_VISIT_STATUS
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (
@@ -94,7 +95,7 @@ def test_export_origin_visit_status():
 
 
 def test_export_snapshot():
-    obj_type = "snapshot"
+    obj_type = ModelObjectType.SNAPSHOT
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (orc.hash_to_hex_or_none(obj.id),) in output["snapshot"]
@@ -110,7 +111,7 @@ def test_export_snapshot():
 
 
 def test_export_release():
-    obj_type = "release"
+    obj_type = ModelObjectType.RELEASE
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (
@@ -128,7 +129,7 @@ def test_export_release():
 
 
 def test_export_revision():
-    obj_type = "revision"
+    obj_type = ModelObjectType.REVISION
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (
@@ -155,7 +156,7 @@ def test_export_revision():
 
 
 def test_export_directory():
-    obj_type = "directory"
+    obj_type = ModelObjectType.DIRECTORY
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (orc.hash_to_hex_or_none(obj.id), obj.raw_manifest) in output[
@@ -172,7 +173,7 @@ def test_export_directory():
 
 
 def test_export_content():
-    obj_type = "content"
+    obj_type = ModelObjectType.CONTENT
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (
@@ -187,7 +188,7 @@ def test_export_content():
 
 
 def test_export_skipped_content():
-    obj_type = "skipped_content"
+    obj_type = ModelObjectType.SKIPPED_CONTENT
     output = exporter({obj_type: TEST_OBJECTS[obj_type]})
     for obj in TEST_OBJECTS[obj_type]:
         assert (
@@ -256,7 +257,11 @@ def test_export_related_files(max_rows, obj_type, tmpdir):
     config = {"orc": {}}
     if max_rows is not None:
         config["orc"]["max_rows"] = {obj_type: max_rows}
-    exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir)
+    exporter(
+        {ModelObjectType(obj_type): TEST_OBJECTS[obj_type]},
+        config=config,
+        tmpdir=tmpdir,
+    )
     # check there are as many ORC files as objects
     orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")]
     if max_rows is None:
@@ -303,7 +308,7 @@ def test_export_related_files(max_rows, obj_type, tmpdir):
     MAIN_TABLES.keys(),
 )
 def test_export_related_files_separated(obj_type, tmpdir):
-    exporter({obj_type: TEST_OBJECTS[obj_type]}, tmpdir=tmpdir)
+    exporter({ModelObjectType(obj_type): TEST_OBJECTS[obj_type]}, tmpdir=tmpdir)
     # check there are as many ORC files as objects
     orcfiles = [fname for fname in (tmpdir / obj_type).listdir(f"{obj_type}-*.orc")]
     assert len(orcfiles) == 1
@@ -340,7 +345,11 @@ def test_export_content_with_data(monkeypatch, tmpdir):
         },
     }
 
-    output = exporter({obj_type: TEST_OBJECTS[obj_type]}, config=config, tmpdir=tmpdir)
+    output = exporter(
+        {ModelObjectType(obj_type): TEST_OBJECTS[obj_type]},
+        config=config,
+        tmpdir=tmpdir,
+    )
     for obj in TEST_OBJECTS[obj_type]:
         assert (
             orc.hash_to_hex_or_none(obj.sha1),
-- 
GitLab