From 7d6ead58a4300b2cefc3df939a5ae1845d6f5f71 Mon Sep 17 00:00:00 2001
From: David Douard <david.douard@sdfa3.org>
Date: Fri, 3 Apr 2020 15:03:32 +0200
Subject: [PATCH] writer: make the writer use swh.model objects instead of
 (possibly) dicts

And annotate the code as we go.
---
 swh/journal/serializers.py             |  10 +-
 swh/journal/tests/test_kafka_writer.py |  27 ++---
 swh/journal/writer/inmemory.py         |  11 +-
 swh/journal/writer/kafka.py            | 152 ++++++++++++++++---------
 4 files changed, 127 insertions(+), 73 deletions(-)

diff --git a/swh/journal/serializers.py b/swh/journal/serializers.py
index ea23898..247b6a9 100644
--- a/swh/journal/serializers.py
+++ b/swh/journal/serializers.py
@@ -3,12 +3,14 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+from typing import Any, Dict, Union
+
 import msgpack
 
 from swh.core.api.serializers import msgpack_dumps, msgpack_loads
 
 
-def key_to_kafka(key):
+def key_to_kafka(key: Union[bytes, Dict]) -> bytes:
     """Serialize a key, possibly a dict, in a predictable way"""
     p = msgpack.Packer(use_bin_type=True)
     if isinstance(key, dict):
@@ -17,16 +19,16 @@ def key_to_kafka(key):
         return p.pack(key)
 
 
-def kafka_to_key(kafka_key):
+def kafka_to_key(kafka_key: bytes) -> Union[bytes, Dict]:
     """Deserialize a key"""
     return msgpack.loads(kafka_key)
 
 
-def value_to_kafka(value):
+def value_to_kafka(value: Any) -> bytes:
     """Serialize some data for storage in kafka"""
     return msgpack_dumps(value)
 
 
-def kafka_to_value(kafka_value):
+def kafka_to_value(kafka_value: bytes) -> Any:
     """Deserialize some data stored in kafka"""
     return msgpack_loads(kafka_value)
diff --git a/swh/journal/tests/test_kafka_writer.py b/swh/journal/tests/test_kafka_writer.py
index 90861a2..575c040 100644
--- a/swh/journal/tests/test_kafka_writer.py
+++ b/swh/journal/tests/test_kafka_writer.py
@@ -16,12 +16,14 @@ from swh.journal.replay import object_converter_fn
 from swh.journal.serializers import (
     kafka_to_key, kafka_to_value
 )
-from swh.journal.writer.kafka import KafkaJournalWriter
+from swh.journal.writer.kafka import KafkaJournalWriter, OBJECT_TYPES
 
 from swh.model.model import Content, Origin, BaseModel
 
 from .conftest import OBJECT_TYPE_KEYS
 
+MODEL_OBJECTS = {v: k for (k, v) in OBJECT_TYPES.items()}
+
 
 def assert_written(consumer, kafka_prefix, expected_messages):
     consumed_objects = defaultdict(list)
@@ -76,25 +78,24 @@ def test_kafka_writer(
         consumer: Consumer):
     kafka_prefix += '.swh.journal.objects'
 
-    config = {
-        'brokers': ['localhost:%d' % kafka_server[1]],
-        'client_id': 'kafka_writer',
-        'prefix': kafka_prefix,
-        'producer_config': {
+    writer = KafkaJournalWriter(
+        brokers=[f'localhost:{kafka_server[1]}'],
+        client_id='kafka_writer',
+        prefix=kafka_prefix,
+        producer_config={
             'message.max.bytes': 100000000,
-        }
-    }
-
-    writer = KafkaJournalWriter(**config)
+        })
 
     expected_messages = 0
 
     for (object_type, (_, objects)) in OBJECT_TYPE_KEYS.items():
-        for (num, object_) in enumerate(objects):
+        for (num, object_d) in enumerate(objects):
             if object_type == 'origin_visit':
-                object_ = {**object_, 'visit': num}
+                object_d = {**object_d, 'visit': num}
             if object_type == 'content':
-                object_ = {**object_, 'ctime': datetime.datetime.now()}
+                object_d = {**object_d, 'ctime': datetime.datetime.now()}
+            object_ = MODEL_OBJECTS[object_type].from_dict(object_d)
+
             writer.write_addition(object_type, object_)
             expected_messages += 1
 
diff --git a/swh/journal/writer/inmemory.py b/swh/journal/writer/inmemory.py
index 6c7b84c..175f473 100644
--- a/swh/journal/writer/inmemory.py
+++ b/swh/journal/writer/inmemory.py
@@ -5,10 +5,14 @@
 
 import logging
 import copy
+
 from multiprocessing import Manager
+from typing import List
 
 from swh.model.model import BaseModel
 
+from .kafka import ModelObject
+
 logger = logging.getLogger(__name__)
 
 
@@ -18,13 +22,12 @@ class InMemoryJournalWriter:
         self.manager = Manager()
         self.objects = self.manager.list()
 
-    def write_addition(self, object_type, object_):
-        if isinstance(object_, BaseModel):
-            object_ = object_.to_dict()
+    def write_addition(self, object_type: str, object_: ModelObject) -> None:
+        assert isinstance(object_, BaseModel)
         self.objects.append((object_type, copy.deepcopy(object_)))
 
     write_update = write_addition
 
-    def write_additions(self, object_type, objects):
+    def write_additions(self, object_type: str, objects: List[ModelObject]) -> None:
         for object_ in objects:
             self.write_addition(object_type, object_)
diff --git a/swh/journal/writer/kafka.py b/swh/journal/writer/kafka.py
index 648d20a..3c2949c 100644
--- a/swh/journal/writer/kafka.py
+++ b/swh/journal/writer/kafka.py
@@ -4,54 +4,84 @@
 # See top-level LICENSE file for more information
 
 import logging
+from typing import Dict, Iterable, List, Type, Union, overload
 
 from confluent_kafka import Producer, KafkaException
 
 from swh.model.hashutil import DEFAULT_ALGORITHMS
-from swh.model.model import BaseModel
+from swh.model.model import (
+    BaseModel,
+    Content,
+    Directory,
+    Origin,
+    OriginVisit,
+    Release,
+    Revision,
+    SkippedContent,
+    Snapshot,
+)
 
 from swh.journal.serializers import key_to_kafka, value_to_kafka
 
 logger = logging.getLogger(__name__)
 
+OBJECT_TYPES: Dict[Type[BaseModel], str] = {
+    Content: "content",
+    Directory: "directory",
+    Origin: "origin",
+    OriginVisit: "origin_visit",
+    Release: "release",
+    Revision: "revision",
+    SkippedContent: "skipped_content",
+    Snapshot: "snapshot",
+}
+
+ModelObject = Union[
+    Content, Directory, Origin, OriginVisit, Release, Revision, SkippedContent, Snapshot
+]
+
 
 class KafkaJournalWriter:
     """This class is instantiated and used by swh-storage to write incoming
     new objects to Kafka before adding them to the storage backend
     (eg. postgresql) itself."""
-    def __init__(self, brokers, prefix, client_id, producer_config=None):
-        self._prefix = prefix
 
-        if isinstance(brokers, str):
-            brokers = [brokers]
+    def __init__(
+        self,
+        brokers: Iterable[str],
+        prefix: str,
+        client_id: str,
+        producer_config: Union[Dict, None] = None,
+    ):
+        self._prefix = prefix
 
         if not producer_config:
             producer_config = {}
 
-        self.producer = Producer({
-            'bootstrap.servers': ','.join(brokers),
-            'client.id': client_id,
-            'on_delivery': self._on_delivery,
-            'error_cb': self._error_cb,
-            'logger': logger,
-            'acks': 'all',
-            **producer_config,
-        })
+        self.producer = Producer(
+            {
+                "bootstrap.servers": ",".join(brokers),
+                "client.id": client_id,
+                "on_delivery": self._on_delivery,
+                "error_cb": self._error_cb,
+                "logger": logger,
+                "acks": "all",
+                **producer_config,
+            }
+        )
 
     def _error_cb(self, error):
         if error.fatal():
             raise KafkaException(error)
-        logger.info('Received non-fatal kafka error: %s', error)
+        logger.info("Received non-fatal kafka error: %s", error)
 
     def _on_delivery(self, error, message):
         if error is not None:
             self._error_cb(error)
 
-    def send(self, topic, key, value):
+    def send(self, topic: str, key, value):
         self.producer.produce(
-            topic=topic,
-            key=key_to_kafka(key),
-            value=value_to_kafka(value),
+            topic=topic, key=key_to_kafka(key), value=value_to_kafka(value),
         )
 
         # Need to service the callbacks regularly by calling poll
@@ -60,54 +90,72 @@ class KafkaJournalWriter:
     def flush(self):
         self.producer.flush()
 
+    # these @overload'ed versions of the _get_key method aim at helping mypy figuring
+    # the correct type-ing.
+    @overload
+    def _get_key(
+        self, object_type: str, object_: Union[Revision, Release, Directory, Snapshot]
+    ) -> bytes:
+        ...
+
+    @overload
+    def _get_key(self, object_type: str, object_: Content) -> bytes:
+        ...
+
+    @overload
+    def _get_key(self, object_type: str, object_: SkippedContent) -> Dict[str, bytes]:
+        ...
+
+    @overload
+    def _get_key(self, object_type: str, object_: Origin) -> Dict[str, bytes]:
+        ...
+
+    @overload
+    def _get_key(self, object_type: str, object_: OriginVisit) -> Dict[str, str]:
+        ...
+
     def _get_key(self, object_type, object_):
-        if object_type in ('revision', 'release', 'directory', 'snapshot'):
-            return object_['id']
-        elif object_type == 'content':
-            return object_['sha1']  # TODO: use a dict of hashes
-        elif object_type == 'skipped_content':
-            return {
-                hash: object_[hash]
-                for hash in DEFAULT_ALGORITHMS
-            }
-        elif object_type == 'origin':
-            return {'url': object_['url']}
-        elif object_type == 'origin_visit':
+        if object_type in ("revision", "release", "directory", "snapshot"):
+            return object_.id
+        elif object_type == "content":
+            return object_.sha1  # TODO: use a dict of hashes
+        elif object_type == "skipped_content":
+            return {algo: getattr(object_, algo) for algo in DEFAULT_ALGORITHMS}
+        elif object_type == "origin":
+            return {"url": object_.url}
+        elif object_type == "origin_visit":
             return {
-                'origin': object_['origin'],
-                'date': str(object_['date']),
+                "origin": object_.origin,
+                "date": str(object_.date),
             }
         else:
-            raise ValueError('Unknown object type: %s.' % object_type)
-
-    def _sanitize_object(self, object_type, object_):
-        if object_type == 'origin_visit':
-            return {
-                **object_,
-                'date': str(object_['date']),
-            }
-        elif object_type == 'origin':
-            assert 'id' not in object_
-        return object_
-
-    def _write_addition(self, object_type, object_):
+            raise ValueError("Unknown object type: %s." % object_type)
+
+    def _sanitize_object(
+        self, object_type: str, object_: ModelObject
+    ) -> Dict[str, str]:
+        dict_ = object_.to_dict()
+        if object_type == "origin_visit":
+            # :(
+            dict_["date"] = str(dict_["date"])
+        return dict_
+
+    def _write_addition(self, object_type: str, object_: ModelObject) -> None:
         """Write a single object to the journal"""
-        if isinstance(object_, BaseModel):
-            object_ = object_.to_dict()
-        topic = '%s.%s' % (self._prefix, object_type)
+        topic = f"{self._prefix}.{object_type}"
         key = self._get_key(object_type, object_)
         dict_ = self._sanitize_object(object_type, object_)
-        logger.debug('topic: %s, key: %s, value: %s', topic, key, dict_)
+        logger.debug("topic: %s, key: %s, value: %s", topic, key, dict_)
         self.send(topic, key=key, value=dict_)
 
-    def write_addition(self, object_type, object_):
+    def write_addition(self, object_type: str, object_: ModelObject) -> None:
         """Write a single object to the journal"""
         self._write_addition(object_type, object_)
         self.flush()
 
     write_update = write_addition
 
-    def write_additions(self, object_type, objects):
+    def write_additions(self, object_type: str, objects: List[ModelObject]) -> None:
         """Write a set of objects to the journal"""
         for object_ in objects:
             self._write_addition(object_type, object_)
-- 
GitLab