From 85ca7d7848008951f2e26c55c1c72ed9fa92cefb Mon Sep 17 00:00:00 2001
From: David Douard <david.douard@sdfa3.org>
Date: Fri, 20 Mar 2020 12:59:56 +0100
Subject: [PATCH] model: use attrs_static to enforce type validation of model
 objects

This ensures all instanciated model entities have valid types for attributes.

Related to T2308.
---
 mypy.ini           |   4 +-
 requirements.txt   |   1 +
 swh/model/model.py | 319 ++++++++++++++++++++++++++++++++-------------
 3 files changed, 229 insertions(+), 95 deletions(-)

diff --git a/mypy.ini b/mypy.ini
index 8e421de2..5467ded0 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -2,9 +2,11 @@
 namespace_packages = True
 warn_unused_ignores = True
 
-
 # 3rd party libraries without stubs (yet)
 
+[mypy-attrs_strict.*]  # a bit sad, but...
+ignore_missing_imports = True
+
 [mypy-django.*]  # false positive, only used my hypotesis' extras
 ignore_missing_imports = True
 
diff --git a/requirements.txt b/requirements.txt
index 1577daa9..a097082e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@
 # dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html
 vcversioner
 attrs
+attrs_strict
 hypothesis
 python-dateutil
 iso8601
diff --git a/swh/model/model.py b/swh/model/model.py
index b296b227..eb2ec15a 100644
--- a/swh/model/model.py
+++ b/swh/model/model.py
@@ -7,9 +7,10 @@ import datetime
 
 from abc import ABCMeta, abstractmethod
 from enum import Enum
-from typing import List, Optional, Dict, Union
+from typing import Dict, List, Optional, Union
 
 import attr
+from attrs_strict import type_validator
 import dateutil.parser
 import iso8601
 
@@ -84,9 +85,15 @@ class HashableObject(metaclass=ABCMeta):
 @attr.s(frozen=True)
 class Person(BaseModel):
     """Represents the author/committer of a revision or release."""
-    fullname = attr.ib(type=bytes)
-    name = attr.ib(type=Optional[bytes])
-    email = attr.ib(type=Optional[bytes])
+    fullname = attr.ib(
+        type=bytes,
+        validator=type_validator())
+    name = attr.ib(
+        type=Optional[bytes],
+        validator=type_validator())
+    email = attr.ib(
+        type=Optional[bytes],
+        validator=type_validator())
 
     @classmethod
     def from_fullname(cls, fullname: bytes):
@@ -131,8 +138,12 @@ class Person(BaseModel):
 @attr.s(frozen=True)
 class Timestamp(BaseModel):
     """Represents a naive timestamp from a VCS."""
-    seconds = attr.ib(type=int)
-    microseconds = attr.ib(type=int)
+    seconds = attr.ib(
+        type=int,
+        validator=type_validator())
+    microseconds = attr.ib(
+        type=int,
+        validator=type_validator())
 
     @seconds.validator
     def check_seconds(self, attribute, value):
@@ -150,9 +161,15 @@ class Timestamp(BaseModel):
 @attr.s(frozen=True)
 class TimestampWithTimezone(BaseModel):
     """Represents a TZ-aware timestamp from a VCS."""
-    timestamp = attr.ib(type=Timestamp)
-    offset = attr.ib(type=int)
-    negative_utc = attr.ib(type=bool)
+    timestamp = attr.ib(
+        type=Timestamp,
+        validator=type_validator())
+    offset = attr.ib(
+        type=int,
+        validator=type_validator())
+    negative_utc = attr.ib(
+        type=bool,
+        validator=type_validator())
 
     @offset.validator
     def check_offset(self, attribute, value):
@@ -193,25 +210,38 @@ class TimestampWithTimezone(BaseModel):
 @attr.s(frozen=True)
 class Origin(BaseModel):
     """Represents a software source: a VCS and an URL."""
-    url = attr.ib(type=str)
+    url = attr.ib(
+        type=str,
+        validator=type_validator())
 
 
 @attr.s(frozen=True)
 class OriginVisit(BaseModel):
     """Represents a visit of an origin at a given point in time, by a
     SWH loader."""
-    origin = attr.ib(type=str)
-    date = attr.ib(type=datetime.datetime)
+    origin = attr.ib(
+        type=str,
+        validator=type_validator())
+    date = attr.ib(
+        type=datetime.datetime,
+        validator=type_validator())
     status = attr.ib(
         type=str,
         validator=attr.validators.in_(['ongoing', 'full', 'partial']))
-    type = attr.ib(type=str)
-    snapshot = attr.ib(type=Optional[Sha1Git])
-    metadata = attr.ib(type=Optional[Dict[str, object]],
-                       default=None)
-
-    visit = attr.ib(type=Optional[int],
-                    default=None)
+    type = attr.ib(
+        type=str,
+        validator=type_validator())
+    snapshot = attr.ib(
+        type=Optional[Sha1Git],
+        validator=type_validator())
+    metadata = attr.ib(
+        type=Optional[Dict[str, object]],
+        validator=type_validator(),
+        default=None)
+    visit = attr.ib(
+        type=Optional[int],
+        validator=type_validator(),
+        default=None)
     """Should not be set before calling 'origin_visit_add()'."""
 
     def to_dict(self):
@@ -225,13 +255,10 @@ class OriginVisit(BaseModel):
     @classmethod
     def from_dict(cls, d):
         """Parses the date from a string, and accepts missing visit ids."""
-        d = d.copy()
-        date = d.pop('date')
-        return cls(
-            date=(date
-                  if isinstance(date, datetime.datetime)
-                  else dateutil.parser.parse(date)),
-            **d)
+        if isinstance(d['date'], str):
+            d = d.copy()
+            d['date'] = dateutil.parser.parse(d['date'])
+        return super().from_dict(d)
 
 
 @attr.s(frozen=True)
@@ -239,16 +266,26 @@ class OriginVisitUpdate(BaseModel):
     """Represents a visit update of an origin at a given point in time.
 
     """
-    origin = attr.ib(type=str)
-    visit = attr.ib(type=int)
-
-    date = attr.ib(type=datetime.datetime)
+    origin = attr.ib(
+        type=str,
+        validator=type_validator())
+    visit = attr.ib(
+        type=int,
+        validator=type_validator())
+
+    date = attr.ib(
+        type=datetime.datetime,
+        validator=type_validator())
     status = attr.ib(
         type=str,
         validator=attr.validators.in_(['ongoing', 'full', 'partial']))
-    snapshot = attr.ib(type=Optional[Sha1Git])
-    metadata = attr.ib(type=Optional[Dict[str, object]],
-                       default=None)
+    snapshot = attr.ib(
+        type=Optional[Sha1Git],
+        validator=type_validator())
+    metadata = attr.ib(
+        type=Optional[Dict[str, object]],
+        validator=type_validator(),
+        default=None)
 
 
 class TargetType(Enum):
@@ -274,8 +311,12 @@ class ObjectType(Enum):
 @attr.s(frozen=True)
 class SnapshotBranch(BaseModel):
     """Represents one of the branches of a snapshot."""
-    target = attr.ib(type=bytes)
-    target_type = attr.ib(type=TargetType)
+    target = attr.ib(
+        type=bytes,
+        validator=type_validator())
+    target_type = attr.ib(
+        type=TargetType,
+        validator=type_validator())
 
     @target.validator
     def check_target(self, attribute, value):
@@ -296,8 +337,13 @@ class SnapshotBranch(BaseModel):
 @attr.s(frozen=True)
 class Snapshot(BaseModel, HashableObject):
     """Represents the full state of an origin at a given point in time."""
-    branches = attr.ib(type=Dict[bytes, Optional[SnapshotBranch]])
-    id = attr.ib(type=Sha1Git, default=b'')
+    branches = attr.ib(
+        type=Dict[bytes, Optional[SnapshotBranch]],
+        validator=type_validator())
+    id = attr.ib(
+        type=Sha1Git,
+        validator=type_validator(),
+        default=b'')
 
     @staticmethod
     def compute_hash(object_dict):
@@ -316,18 +362,37 @@ class Snapshot(BaseModel, HashableObject):
 
 @attr.s(frozen=True)
 class Release(BaseModel, HashableObject):
-    name = attr.ib(type=bytes)
-    message = attr.ib(type=bytes)
-    target = attr.ib(type=Optional[Sha1Git])
-    target_type = attr.ib(type=ObjectType)
-    synthetic = attr.ib(type=bool)
-    author = attr.ib(type=Optional[Person],
-                     default=None)
-    date = attr.ib(type=Optional[TimestampWithTimezone],
-                   default=None)
-    metadata = attr.ib(type=Optional[Dict[str, object]],
-                       default=None)
-    id = attr.ib(type=Sha1Git, default=b'')
+    name = attr.ib(
+        type=bytes,
+        validator=type_validator())
+    message = attr.ib(
+        type=bytes,
+        validator=type_validator())
+    target = attr.ib(
+        type=Optional[Sha1Git],
+        validator=type_validator())
+    target_type = attr.ib(
+        type=ObjectType,
+        validator=type_validator())
+    synthetic = attr.ib(
+        type=bool,
+        validator=type_validator())
+    author = attr.ib(
+        type=Optional[Person],
+        validator=type_validator(),
+        default=None)
+    date = attr.ib(
+        type=Optional[TimestampWithTimezone],
+        validator=type_validator(),
+        default=None)
+    metadata = attr.ib(
+        type=Optional[Dict[str, object]],
+        validator=type_validator(),
+        default=None)
+    id = attr.ib(
+        type=Sha1Git,
+        validator=type_validator(),
+        default=b'')
 
     @staticmethod
     def compute_hash(object_dict):
@@ -367,19 +432,42 @@ class RevisionType(Enum):
 
 @attr.s(frozen=True)
 class Revision(BaseModel, HashableObject):
-    message = attr.ib(type=bytes)
-    author = attr.ib(type=Person)
-    committer = attr.ib(type=Person)
-    date = attr.ib(type=Optional[TimestampWithTimezone])
-    committer_date = attr.ib(type=Optional[TimestampWithTimezone])
-    type = attr.ib(type=RevisionType)
-    directory = attr.ib(type=Sha1Git)
-    synthetic = attr.ib(type=bool)
-    metadata = attr.ib(type=Optional[Dict[str, object]],
-                       default=None)
-    parents = attr.ib(type=List[Sha1Git],
-                      default=attr.Factory(list))
-    id = attr.ib(type=Sha1Git, default=b'')
+    message = attr.ib(
+        type=bytes,
+        validator=type_validator())
+    author = attr.ib(
+        type=Person,
+        validator=type_validator())
+    committer = attr.ib(
+        type=Person,
+        validator=type_validator())
+    date = attr.ib(
+        type=Optional[TimestampWithTimezone],
+        validator=type_validator())
+    committer_date = attr.ib(
+        type=Optional[TimestampWithTimezone],
+        validator=type_validator())
+    type = attr.ib(
+        type=RevisionType,
+        validator=type_validator())
+    directory = attr.ib(
+        type=Sha1Git,
+        validator=type_validator())
+    synthetic = attr.ib(
+        type=bool,
+        validator=type_validator())
+    metadata = attr.ib(
+        type=Optional[Dict[str, object]],
+        validator=type_validator(),
+        default=None)
+    parents = attr.ib(
+        type=List[Sha1Git],
+        validator=type_validator(),
+        default=attr.Factory(list))
+    id = attr.ib(
+        type=Sha1Git,
+        validator=type_validator(),
+        default=b'')
 
     @staticmethod
     def compute_hash(object_dict):
@@ -408,18 +496,30 @@ class Revision(BaseModel, HashableObject):
 
 @attr.s(frozen=True)
 class DirectoryEntry(BaseModel):
-    name = attr.ib(type=bytes)
-    type = attr.ib(type=str,
-                   validator=attr.validators.in_(['file', 'dir', 'rev']))
-    target = attr.ib(type=Sha1Git)
-    perms = attr.ib(type=int)
+    name = attr.ib(
+        type=bytes,
+        validator=type_validator())
+    type = attr.ib(
+        type=str,
+        validator=attr.validators.in_(['file', 'dir', 'rev']))
+    target = attr.ib(
+        type=Sha1Git,
+        validator=type_validator())
+    perms = attr.ib(
+        type=int,
+        validator=type_validator())
     """Usually one of the values of `swh.model.from_disk.DentryPerms`."""
 
 
 @attr.s(frozen=True)
 class Directory(BaseModel, HashableObject):
-    entries = attr.ib(type=List[DirectoryEntry])
-    id = attr.ib(type=Sha1Git, default=b'')
+    entries = attr.ib(
+        type=List[DirectoryEntry],
+        validator=type_validator())
+    id = attr.ib(
+        type=Sha1Git,
+        validator=type_validator(),
+        default=b'')
 
     @staticmethod
     def compute_hash(object_dict):
@@ -478,22 +578,37 @@ class BaseContent(BaseModel):
 
 @attr.s(frozen=True)
 class Content(BaseContent):
-    sha1 = attr.ib(type=bytes)
-    sha1_git = attr.ib(type=Sha1Git)
-    sha256 = attr.ib(type=bytes)
-    blake2s256 = attr.ib(type=bytes)
-
-    length = attr.ib(type=int)
+    sha1 = attr.ib(
+        type=bytes,
+        validator=type_validator())
+    sha1_git = attr.ib(
+        type=Sha1Git,
+        validator=type_validator())
+    sha256 = attr.ib(
+        type=bytes,
+        validator=type_validator())
+    blake2s256 = attr.ib(
+        type=bytes,
+        validator=type_validator())
+
+    length = attr.ib(
+        type=int,
+        validator=type_validator())
 
     status = attr.ib(
         type=str,
-        default='visible',
-        validator=attr.validators.in_(['visible', 'hidden']))
+        validator=attr.validators.in_(['visible', 'hidden']),
+        default='visible')
 
-    data = attr.ib(type=Optional[bytes], default=None)
+    data = attr.ib(
+        type=Optional[bytes],
+        validator=type_validator(),
+        default=None)
 
-    ctime = attr.ib(type=Optional[datetime.datetime],
-                    default=None)
+    ctime = attr.ib(
+        type=Optional[datetime.datetime],
+        validator=type_validator(),
+        default=None)
 
     @length.validator
     def check_length(self, attribute, value):
@@ -535,24 +650,40 @@ class Content(BaseContent):
 
 @attr.s(frozen=True)
 class SkippedContent(BaseContent):
-    sha1 = attr.ib(type=Optional[bytes])
-    sha1_git = attr.ib(type=Optional[Sha1Git])
-    sha256 = attr.ib(type=Optional[bytes])
-    blake2s256 = attr.ib(type=Optional[bytes])
-
-    length = attr.ib(type=Optional[int])
+    sha1 = attr.ib(
+        type=Optional[bytes],
+        validator=type_validator())
+    sha1_git = attr.ib(
+        type=Optional[Sha1Git],
+        validator=type_validator())
+    sha256 = attr.ib(
+        type=Optional[bytes],
+        validator=type_validator())
+    blake2s256 = attr.ib(
+        type=Optional[bytes],
+        validator=type_validator())
+
+    length = attr.ib(
+        type=Optional[int],
+        validator=type_validator())
 
     status = attr.ib(
         type=str,
         validator=attr.validators.in_(['absent']))
-    reason = attr.ib(type=Optional[str],
-                     default=None)
-
-    origin = attr.ib(type=Optional[Origin],
-                     default=None)
-
-    ctime = attr.ib(type=Optional[datetime.datetime],
-                    default=None)
+    reason = attr.ib(
+        type=Optional[str],
+        validator=type_validator(),
+        default=None)
+
+    origin = attr.ib(
+        type=Optional[Origin],
+        validator=type_validator(),
+        default=None)
+
+    ctime = attr.ib(
+        type=Optional[datetime.datetime],
+        validator=type_validator(),
+        default=None)
 
     @reason.validator
     def check_reason(self, attribute, value):
-- 
GitLab