From e632abed41c5cd71507c4222e1e85987556b232c Mon Sep 17 00:00:00 2001
From: David Douard <david.douard@sdfa3.org>
Date: Wed, 13 May 2020 17:03:01 +0200
Subject: [PATCH] Tag model entities with their "object_type"

This aims at preventing constant usage of isinstance()-based dispatch
code when writing generic code handling model entities.

For example, the "object_type" argument of JournalWriter.write_addition() has
become superfluous now that we only pass model entities, etc.

This idea comes from olasd's reading of the mypy doc:

  https://mypy.readthedocs.io/en/latest/literal_types.html#tagged-unions

This comes with a refactoring of from_disk.DiskBackedContent to make
it *not* inherit from model.Content: object_type being Final, it cannot
be overridden.
---
 requirements.txt              |  5 +++--
 swh/model/from_disk.py        | 33 +++++++++++++++++++++++++++++++--
 swh/model/model.py            | 31 ++++++++++++++++++++++++++++++-
 swh/model/tests/test_model.py | 22 ++++++++++++++++++++++
 4 files changed, 86 insertions(+), 5 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index c2d38df1..b70187b2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,10 @@
 # Add here external Python modules dependencies, one per line. Module names
 # should match https://pypi.python.org/pypi names. For the full spec or
 # dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html
-vcversioner
 attrs
 attrs_strict >= 0.0.7
 hypothesis
-python-dateutil
 iso8601
+python-dateutil
+typing_extensions
+vcversioner
diff --git a/swh/model/from_disk.py b/swh/model/from_disk.py
index 5176dc9e..da559121 100644
--- a/swh/model/from_disk.py
+++ b/swh/model/from_disk.py
@@ -3,12 +3,15 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import datetime
 import enum
 import os
 import stat
 
 import attr
+from attrs_strict import type_validator
 from typing import List, Optional, Iterable, Any
+from typing_extensions import Final
 
 from .hashutil import MultiHash
 from .merkle import MerkleLeaf, MerkleNode
@@ -22,11 +25,37 @@ from . import model
 
 
 @attr.s
-class DiskBackedContent(model.Content):
-    """Subclass of Content, which allows lazy-loading data from the disk."""
+class DiskBackedContent(model.BaseContent):
+    """Content-like class, which allows lazy-loading data from the disk."""
+
+    object_type: Final = "content_file"
+
+    sha1 = attr.ib(type=bytes, validator=type_validator())
+    sha1_git = attr.ib(type=model.Sha1Git, validator=type_validator())
+    sha256 = attr.ib(type=bytes, validator=type_validator())
+    blake2s256 = attr.ib(type=bytes, validator=type_validator())
+
+    length = attr.ib(type=int, validator=type_validator())
+
+    status = attr.ib(
+        type=str,
+        validator=attr.validators.in_(["visible", "hidden"]),
+        default="visible",
+    )
+
+    ctime = attr.ib(
+        type=Optional[datetime.datetime],
+        validator=type_validator(),
+        default=None,
+        eq=False,
+    )
 
     path = attr.ib(type=Optional[bytes], default=None)
 
+    @classmethod
+    def from_dict(cls, d):
+        return cls(**d)
+
     def __attrs_post_init__(self):
         if self.path is None:
             raise TypeError("path must not be None.")
diff --git a/swh/model/model.py b/swh/model/model.py
index 220733f4..bffe711d 100644
--- a/swh/model/model.py
+++ b/swh/model/model.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2018-2019 The Software Heritage developers
+# Copyright (C) 2018-2020 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -9,6 +9,7 @@ from abc import ABCMeta, abstractmethod
 from enum import Enum
 from hashlib import sha256
 from typing import Dict, Optional, Tuple, TypeVar, Union
+from typing_extensions import Final
 
 import attr
 from attrs_strict import type_validator
@@ -101,6 +102,8 @@ class HashableObject(metaclass=ABCMeta):
 class Person(BaseModel):
     """Represents the author/committer of a revision or release."""
 
+    object_type: Final = "person"
+
     fullname = attr.ib(type=bytes, validator=type_validator())
     name = attr.ib(type=Optional[bytes], validator=type_validator())
     email = attr.ib(type=Optional[bytes], validator=type_validator())
@@ -153,6 +156,8 @@ class Person(BaseModel):
 class Timestamp(BaseModel):
     """Represents a naive timestamp from a VCS."""
 
+    object_type: Final = "timestamp"
+
     seconds = attr.ib(type=int, validator=type_validator())
     microseconds = attr.ib(type=int, validator=type_validator())
 
@@ -173,6 +178,8 @@ class Timestamp(BaseModel):
 class TimestampWithTimezone(BaseModel):
     """Represents a TZ-aware timestamp from a VCS."""
 
+    object_type: Final = "timestamp_with_timezone"
+
     timestamp = attr.ib(type=Timestamp, validator=type_validator())
     offset = attr.ib(type=int, validator=type_validator())
     negative_utc = attr.ib(type=bool, validator=type_validator())
@@ -223,6 +230,8 @@ class TimestampWithTimezone(BaseModel):
 class Origin(BaseModel):
     """Represents a software source: a VCS and an URL."""
 
+    object_type: Final = "origin"
+
     url = attr.ib(type=str, validator=type_validator())
 
 
@@ -231,6 +240,8 @@ class OriginVisit(BaseModel):
     """Represents a visit of an origin at a given point in time, by a
     SWH loader."""
 
+    object_type: Final = "origin_visit"
+
     origin = attr.ib(type=str, validator=type_validator())
     date = attr.ib(type=datetime.datetime, validator=type_validator())
     type = attr.ib(type=str, validator=type_validator())
@@ -258,6 +269,8 @@ class OriginVisitStatus(BaseModel):
 
     """
 
+    object_type: Final = "origin_visit_status"
+
     origin = attr.ib(type=str, validator=type_validator())
     visit = attr.ib(type=int, validator=type_validator())
 
@@ -298,6 +311,8 @@ class ObjectType(Enum):
 class SnapshotBranch(BaseModel):
     """Represents one of the branches of a snapshot."""
 
+    object_type: Final = "snapshot_branch"
+
     target = attr.ib(type=bytes, validator=type_validator())
     target_type = attr.ib(type=TargetType, validator=type_validator())
 
@@ -318,6 +333,8 @@ class SnapshotBranch(BaseModel):
 class Snapshot(BaseModel, HashableObject):
     """Represents the full state of an origin at a given point in time."""
 
+    object_type: Final = "snapshot"
+
     branches = attr.ib(
         type=Dict[bytes, Optional[SnapshotBranch]], validator=type_validator()
     )
@@ -341,6 +358,8 @@ class Snapshot(BaseModel, HashableObject):
 
 @attr.s(frozen=True)
 class Release(BaseModel, HashableObject):
+    object_type: Final = "release"
+
     name = attr.ib(type=bytes, validator=type_validator())
     message = attr.ib(type=Optional[bytes], validator=type_validator())
     target = attr.ib(type=Optional[Sha1Git], validator=type_validator())
@@ -399,6 +418,8 @@ class RevisionType(Enum):
 
 @attr.s(frozen=True)
 class Revision(BaseModel, HashableObject):
+    object_type: Final = "revision"
+
     message = attr.ib(type=Optional[bytes], validator=type_validator())
     author = attr.ib(type=Person, validator=type_validator())
     committer = attr.ib(type=Person, validator=type_validator())
@@ -453,6 +474,8 @@ class Revision(BaseModel, HashableObject):
 
 @attr.s(frozen=True)
 class DirectoryEntry(BaseModel):
+    object_type: Final = "directory_entry"
+
     name = attr.ib(type=bytes, validator=type_validator())
     type = attr.ib(type=str, validator=attr.validators.in_(["file", "dir", "rev"]))
     target = attr.ib(type=Sha1Git, validator=type_validator())
@@ -462,6 +485,8 @@ class DirectoryEntry(BaseModel):
 
 @attr.s(frozen=True)
 class Directory(BaseModel, HashableObject):
+    object_type: Final = "directory"
+
     entries = attr.ib(type=Tuple[DirectoryEntry, ...], validator=type_validator())
     id = attr.ib(type=Sha1Git, validator=type_validator(), default=b"")
 
@@ -518,6 +543,8 @@ class BaseContent(BaseModel):
 
 @attr.s(frozen=True)
 class Content(BaseContent):
+    object_type: Final = "content"
+
     sha1 = attr.ib(type=bytes, validator=type_validator())
     sha1_git = attr.ib(type=Sha1Git, validator=type_validator())
     sha256 = attr.ib(type=bytes, validator=type_validator())
@@ -584,6 +611,8 @@ class Content(BaseContent):
 
 @attr.s(frozen=True)
 class SkippedContent(BaseContent):
+    object_type: Final = "skipped_content"
+
     sha1 = attr.ib(type=Optional[bytes], validator=type_validator())
     sha1_git = attr.ib(type=Optional[Sha1Git], validator=type_validator())
     sha256 = attr.ib(type=Optional[bytes], validator=type_validator())
diff --git a/swh/model/tests/test_model.py b/swh/model/tests/test_model.py
index e126ca57..df949244 100644
--- a/swh/model/tests/test_model.py
+++ b/swh/model/tests/test_model.py
@@ -13,6 +13,7 @@ from hypothesis.strategies import binary
 import pytest
 
 from swh.model.model import (
+    BaseModel,
     Content,
     SkippedContent,
     Directory,
@@ -468,3 +469,24 @@ def test_snapshot_model_id_computation():
     snp_id = hash_to_bytes(snapshot_identifier(snp_dict))
     snp_model = Snapshot.from_dict(snp_dict)
     assert snp_model.id == snp_id
+
+
+@given(strategies.objects(split_content=True))
+def test_object_type(objtype_and_obj):
+    obj_type, obj = objtype_and_obj
+    assert obj_type == obj.object_type
+
+
+def test_object_type_is_final():
+    object_types = set()
+
+    def check_final(cls):
+        if hasattr(cls, "object_type"):
+            assert cls.object_type not in object_types
+            object_types.add(cls.object_type)
+        if cls.__subclasses__():
+            assert not hasattr(cls, "object_type")
+        for subcls in cls.__subclasses__():
+            check_final(subcls)
+
+    check_final(BaseModel)
-- 
GitLab