From f47cc1b79308609d40d4d2331311ed0d9f4d5354 Mon Sep 17 00:00:00 2001
From: Pierre-Yves David <pierre-yves.david@ens-lyon.org>
Date: Tue, 14 May 2024 02:29:47 +0200
Subject: [PATCH] model: adds type annotation for iter_directory

This requires cleaning various item along the ways. Which is probably an
added benefit. Especially, mypy now consider BaseModel to hold a
object_type attribute.
---
 swh/model/from_disk.py        | 40 +++++++++++++++++++++++------------
 swh/model/model.py            | 17 +++++++++++----
 swh/model/tests/test_model.py | 12 ++++++++---
 3 files changed, 49 insertions(+), 20 deletions(-)

diff --git a/swh/model/from_disk.py b/swh/model/from_disk.py
index 26bb442b..6ce71cbc 100644
--- a/swh/model/from_disk.py
+++ b/swh/model/from_disk.py
@@ -27,6 +27,8 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
+    Union,
+    cast,
 )
 import warnings
 
@@ -385,7 +387,7 @@ def ignore_directories_patterns(root_path: bytes, patterns: Iterable[bytes]):
 
 
 def iter_directory(
-    directory,
+    directory: "Directory",
 ) -> Tuple[List[model.Content], List[model.SkippedContent], List[model.Directory]]:
     """Return the directory listing from a disk-memory directory instance.
 
@@ -400,18 +402,19 @@ def iter_directory(
     skipped_contents: List[model.SkippedContent] = []
     directories: List[model.Directory] = []
 
-    for obj in directory.iter_tree():
-        obj = obj.to_model()
-        obj_type = obj.object_type
-        if obj_type in (model.Content.object_type, DiskBackedContent.object_type):
-            # FIXME: read the data from disk later (when the
-            # storage buffer is flushed).
-            obj = obj.with_data()
-            contents.append(obj)
-        elif obj_type == model.SkippedContent.object_type:
-            skipped_contents.append(obj)
-        elif obj_type == model.Directory.object_type:
-            directories.append(obj)
+    for i_obj in directory.iter_tree():
+        if isinstance(i_obj, Directory):
+            directories.append(i_obj.to_model())
+        elif isinstance(i_obj, Content):
+            obj = i_obj.to_model()
+            if isinstance(obj, model.SkippedContent):
+                skipped_contents.append(obj)
+            else:
+                # FIXME: read the data from disk later (when the
+                # storage buffer is flushed).
+                #
+                c_obj = cast(Union[model.Content, DiskBackedContent], obj)
+                contents.append(c_obj.with_data())
         else:
             raise TypeError(f"Unexpected object type from disk: {obj}")
 
@@ -518,6 +521,17 @@ class Directory(MerkleNode):
         self.__entries = None
         self.__model_object = None
 
+    # note: the overwrite could probably be done by parametrysing the
+    # MerkelNode type, but that is a much bigger rework than the series
+    # introducting this change.
+    def iter_tree(self, dedup=True) -> Iterator[Union["Directory", "Content"]]:
+        """Yields all children nodes, recursively. Common nodes are deduplicated
+        by default (deduplication can be turned off setting the given argument
+        'dedup' to False).
+        """
+        tree = super().iter_tree(dedup=dedup)
+        yield from cast(Iterator[Union["Directory", "Content"]], tree)
+
     def invalidate_hash(self):
         self.__entries = None
         self.__model_object = None
diff --git a/swh/model/model.py b/swh/model/model.py
index 06e19013..df25977e 100644
--- a/swh/model/model.py
+++ b/swh/model/model.py
@@ -15,7 +15,7 @@ All classes define a ``from_dict`` class method and a ``to_dict``
 method to convert between them and msgpack-serializable objects.
 """
 
-from abc import ABCMeta, abstractmethod
+from abc import ABC, ABCMeta, abstractmethod
 import collections
 import datetime
 from enum import Enum
@@ -341,7 +341,7 @@ ModelType = TypeVar("ModelType", bound="BaseModel")
 HashableModelType = TypeVar("HashableModelType", bound="BaseHashableModel")
 
 
-class BaseModel:
+class BaseModel(ABC):
     """Base class for SWH model classes.
 
     Provides serialization/deserialization to/from Python dictionaries,
@@ -349,6 +349,15 @@ class BaseModel:
 
     __slots__ = ()
 
+    @property
+    @abstractmethod
+    def object_type(self) -> str:
+        # Some juggling to please mypy
+        #
+        # Note: starting from Python 3.11 we can combine @property with
+        # @classmethod which is the real intend here.
+        raise NotImplementedError
+
     def to_dict(self):
         """Wrapper of `attr.asdict` that can be overridden by subclasses
         that have special handling of some of the fields."""
@@ -1380,7 +1389,7 @@ class Directory(HashableObjectWithManifest, BaseModel):
 
 
 @attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
-class BaseContent(BaseModel):
+class BaseContent(BaseModel, ABC):
     status = attr.ib(
         type=str, validator=attr.validators.in_(["visible", "hidden", "absent"])
     )
@@ -1417,7 +1426,7 @@ class BaseContent(BaseModel):
 
 @attr.s(frozen=True, slots=True, field_transformer=optimize_all_validators)
 class Content(BaseContent):
-    object_type: Final = "content"
+    object_type: Final[str] = "content"
 
     sha1 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr)
     sha1_git = attr.ib(type=Sha1Git, validator=generic_type_validator, repr=hash_repr)
diff --git a/swh/model/tests/test_model.py b/swh/model/tests/test_model.py
index 5b71f455..2a12ec6e 100644
--- a/swh/model/tests/test_model.py
+++ b/swh/model/tests/test_model.py
@@ -1717,12 +1717,18 @@ def test_object_type_is_final():
     def check_final(cls):
         if cls in checked_classes:
             return
+
         checked_classes.add(cls)
-        if hasattr(cls, "object_type"):
+        obj_type = sentinel = object()
+        obj_type = getattr(cls, "object_type", sentinel)
+        if getattr(obj_type, "__isabstractmethod__", False):
+            obj_type = sentinel
+        if obj_type is sentinel:
+            assert cls.__subclasses__()
+        else:
+            assert not cls.__subclasses__()
             assert cls.object_type not in object_types
             object_types.add(cls.object_type)
-        if cls.__subclasses__():
-            assert not hasattr(cls, "object_type")
         for subcls in cls.__subclasses__():
             check_final(subcls)
 
-- 
GitLab