From efc7e725991ad68f0d233d9cc713e42bead53954 Mon Sep 17 00:00:00 2001
From: Valentin Lorentz <vlorentz@softwareheritage.org>
Date: Fri, 12 Apr 2019 15:51:15 +0200
Subject: [PATCH] Add a from_dict() method to model classes, that does the
 inverse of to_dict().

---
 requirements.txt              |   1 +
 swh/model/model.py            | 180 +++++++++++++++++++++++++++-------
 swh/model/tests/test_model.py |  14 +++
 3 files changed, 158 insertions(+), 37 deletions(-)
 create mode 100644 swh/model/tests/test_model.py

diff --git a/requirements.txt b/requirements.txt
index cd97184a..59623453 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ vcversioner
 Click
 attrs
 hypothesis
+python-dateutil
diff --git a/swh/model/model.py b/swh/model/model.py
index 036879de..5b3d947d 100644
--- a/swh/model/model.py
+++ b/swh/model/model.py
@@ -8,21 +8,81 @@ from enum import Enum
 from typing import List, Optional, Dict
 
 import attr
+import dateutil.parser
 
+from .identifiers import normalize_timestamp
 
 # TODO: Limit this to 20 bytes
 Sha1Git = bytes
 
 
+def contains_optional_validator(validator):
+    """Inspects an attribute's validator to find its type.
+    Inspired by `hypothesis/searchstrategy/attrs.py`."""
+    if isinstance(validator, attr.validators._OptionalValidator):
+        return True
+    elif isinstance(validator, attr.validators._AndValidator):
+        for validator in validator._validators:
+            res = contains_optional_validator(validator)
+            if res:
+                return True
+    else:
+        return False
+
+
+class BaseModel:
+    """Base class for SWH model classes.
+
+    Provides serialization/deserialization to/from Python dictionaries,
+    that are suitable for JSON/msgpack-like formats."""
+
+    def to_dict(self):
+        """Wrapper of `attr.asdict` that can be overriden by subclasses
+        that have special handling of some of the fields."""
+        return attr.asdict(self)
+
+    @classmethod
+    def from_dict(cls, d):
+        """Takes a dictionary representing a tree of SWH objects, and
+        recursively builds the corresponding objects."""
+        if not isinstance(d, dict):
+            raise TypeError(
+                '%s.from_dict expects a dict, not %r' % (cls.__name__, d))
+        for (name, attribute) in attr.fields_dict(cls).items():
+            type_ = attribute.type
+
+            # Heuristic to detect `Optional[X]` and unwrap it to `X`.
+            if contains_optional_validator(attribute.validator):
+                if name not in d:
+                    continue
+                if d[name] is None:
+                    del d[name]
+                    continue
+                else:
+                    type_ = type_.__args__[0]
+
+            # Construct an object of the expected type
+            if issubclass(type_, BaseModel):
+                d[name] = type_.from_dict(d[name])
+            elif issubclass(type_, Enum):
+                d[name] = type_(d[name])
+            else:
+                pass
+
+        return cls(**d)
+
+
 @attr.s
-class Person:
+class Person(BaseModel):
+    """Represents the author/committer of a revision or release."""
     name = attr.ib(type=bytes)
     email = attr.ib(type=bytes)
     fullname = attr.ib(type=bytes)
 
 
 @attr.s
-class Timestamp:
+class Timestamp(BaseModel):
+    """Represents a naive timestamp from a VCS."""
     seconds = attr.ib(type=int)
     microseconds = attr.ib(type=int)
 
@@ -40,48 +100,66 @@ class Timestamp:
 
 
 @attr.s
-class TimestampWithTimezone:
+class TimestampWithTimezone(BaseModel):
+    """Represents a TZ-aware timestamp from a VCS."""
     timestamp = attr.ib(type=Timestamp)
     offset = attr.ib(type=int)
     negative_utc = attr.ib(type=bool)
 
-    def to_dict(self):
-        return attr.asdict(self)
-
     @offset.validator
     def check_offset(self, attribute, value):
+        """Checks the offset is a 16-bits signed integer (in theory, it
+        should always be between -14 and +14 hours)."""
         if not (-2**15 <= value < 2**15):
             # max 14 hours offset in theory, but you never know what
             # you'll find in the wild...
             raise ValueError('offset too large: %d minutes' % value)
 
+    @classmethod
+    def from_dict(cls, d):
+        """Builds a TimestampWithTimezone from any of the formats
+        accepted by :py:`swh.model.normalize_timestamp`."""
+        return super().from_dict(normalize_timestamp(d))
+
 
 @attr.s
-class Origin:
+class Origin(BaseModel):
+    """Represents a software source: a VCS and an URL."""
     type = attr.ib(type=str)
     url = attr.ib(type=str)
 
-    def to_dict(self):
-        return attr.asdict(self)
-
 
 @attr.s
-class OriginVisit:
+class OriginVisit(BaseModel):
+    """Represents a visit of an origin at a given point in time, by a
+    SWH loader."""
     origin = attr.ib(type=Origin)
     date = attr.ib(type=datetime.datetime)
-    visit = attr.ib(type=Optional[int])
+    visit = attr.ib(type=Optional[int],
+                    validator=attr.validators.optional([]))
     """Should not be set before calling 'origin_visit_add()'."""
 
     def to_dict(self):
-        ov = attr.asdict(self)
-        ov['origin'] = self.origin.to_dict()
+        """Serializes the date as a string and omits the visit id if it is
+        `None`."""
+        ov = super().to_dict()
         ov['date'] = str(self.date)
-        if not ov['visit']:
+        if ov['visit'] is None:
             del ov['visit']
         return ov
 
+    @classmethod
+    def from_dict(cls, d):
+        """Parses the date from a string, and accepts missing visit ids."""
+        return cls(
+            origin=Origin.from_dict(d['origin']),
+            date=dateutil.parser.parse(d['date']),
+            visit=d.get('visit'))
+
 
 class TargetType(Enum):
+    """The type of content pointed to by a snapshot branch. Usually a
+    revision or an alias."""
     CONTENT = 'content'
     DIRECTORY = 'directory'
     REVISION = 'revision'
@@ -91,6 +169,7 @@ class TargetType(Enum):
 
 
 class ObjectType(Enum):
+    """The type of content pointed to by a release. Usually a revision"""
     CONTENT = 'content'
     DIRECTORY = 'directory'
     REVISION = 'revision'
@@ -99,12 +178,15 @@ class ObjectType(Enum):
 
 
 @attr.s
-class SnapshotBranch:
+class SnapshotBranch(BaseModel):
+    """Represents one of the branches of a snapshot."""
     target = attr.ib(type=bytes)
     target_type = attr.ib(type=TargetType)
 
     @target.validator
     def check_target(self, attribute, value):
+        """Checks the target type is not an alias, checks the target is a
+        valid sha1_git."""
         if self.target_type != TargetType.ALIAS:
             if len(value) != 20:
                 raise ValueError('Wrong length for bytes identifier: %d' %
@@ -117,7 +199,8 @@ class SnapshotBranch:
 
 
 @attr.s
-class Snapshot:
+class Snapshot(BaseModel):
+    """Represents the full state of an origin at a given point in time."""
     id = attr.ib(type=Sha1Git)
     branches = attr.ib(type=Dict[bytes, Optional[SnapshotBranch]])
 
@@ -130,17 +213,36 @@ class Snapshot:
             }
         }
 
+    @classmethod
+    def from_dict(cls, d):
+        d['branches'] = {
+            name: SnapshotBranch.from_dict(branch)
+            for (name, branch) in d['branches'].items()
+        }
+        return cls(**d)
+
 
 @attr.s
-class Release:
+class Release(BaseModel):
     id = attr.ib(type=Sha1Git)
     name = attr.ib(type=bytes)
     message = attr.ib(type=bytes)
-    date = attr.ib(type=Optional[TimestampWithTimezone])
-    author = attr.ib(type=Optional[Person])
-    target = attr.ib(type=Optional[Sha1Git])
+    target = attr.ib(type=Optional[Sha1Git],
+                     validator=attr.validators.optional([]))
     target_type = attr.ib(type=ObjectType)
     synthetic = attr.ib(type=bool)
+    author = attr.ib(type=Optional[Person],
+                     default=None,
+                     validator=attr.validators.optional([]))
+    date = attr.ib(type=Optional[TimestampWithTimezone],
+                   default=None,
+                   validator=attr.validators.optional([]))
+
+    @author.validator
+    def check_author(self, attribute, value):
+        """If the author is `None`, checks the date is `None` too."""
+        if self.author is None and self.date is not None:
+            raise ValueError('release date must be None if author is None.')
 
     def to_dict(self):
         rel = attr.asdict(self)
@@ -148,11 +250,6 @@ class Release:
         rel['target_type'] = rel['target_type'].value
         return rel
 
-    @author.validator
-    def check_author(self, attribute, value):
-        if self.author is None and self.date is not None:
-            raise ValueError('release date must be None if date is None.')
-
 
 class RevisionType(Enum):
     GIT = 'git'
@@ -163,18 +260,21 @@ class RevisionType(Enum):
 
 
 @attr.s
-class Revision:
+class Revision(BaseModel):
     id = attr.ib(type=Sha1Git)
     message = attr.ib(type=bytes)
     author = attr.ib(type=Person)
     committer = attr.ib(type=Person)
     date = attr.ib(type=TimestampWithTimezone)
     committer_date = attr.ib(type=TimestampWithTimezone)
-    parents = attr.ib(type=List[Sha1Git])
     type = attr.ib(type=RevisionType)
     directory = attr.ib(type=Sha1Git)
-    metadata = attr.ib(type=Optional[Dict[str, object]])
     synthetic = attr.ib(type=bool)
+    metadata = attr.ib(type=Optional[Dict[str, object]],
+                       default=None,
+                       validator=attr.validators.optional([]))
+    parents = attr.ib(type=List[Sha1Git],
+                      default=attr.Factory(list))
 
     def to_dict(self):
         rev = attr.asdict(self)
@@ -185,7 +285,7 @@ class Revision:
 
 
 @attr.s
-class DirectoryEntry:
+class DirectoryEntry(BaseModel):
     name = attr.ib(type=bytes)
     type = attr.ib(type=str,
                    validator=attr.validators.in_(['file', 'dir', 'rev']))
@@ -193,12 +293,9 @@ class DirectoryEntry:
     perms = attr.ib(type=int)
     """Usually one of the values of `swh.model.from_disk.DentryPerms`."""
 
-    def to_dict(self):
-        return attr.asdict(self)
-
 
 @attr.s
-class Directory:
+class Directory(BaseModel):
     id = attr.ib(type=Sha1Git)
     entries = attr.ib(type=List[DirectoryEntry])
 
@@ -207,20 +304,29 @@ class Directory:
         dir_['entries'] = [entry.to_dict() for entry in self.entries]
         return dir_
 
+    @classmethod
+    def from_dict(cls, d):
+        d['entries'] = list(map(DirectoryEntry.from_dict, d['entries']))
+        return super().from_dict(d)
+
 
 @attr.s
-class Content:
+class Content(BaseModel):
     sha1 = attr.ib(type=bytes)
     sha1_git = attr.ib(type=Sha1Git)
     sha256 = attr.ib(type=bytes)
     blake2s256 = attr.ib(type=bytes)
 
-    data = attr.ib(type=bytes)
     length = attr.ib(type=int)
     status = attr.ib(
         type=str,
         validator=attr.validators.in_(['visible', 'absent', 'hidden']))
-    reason = attr.ib(type=Optional[str])
+    reason = attr.ib(type=Optional[str],
+                     default=None,
+                     validator=attr.validators.optional([]))
+    data = attr.ib(type=Optional[bytes],
+                   default=None,
+                   validator=attr.validators.optional([]))
 
     @length.validator
     def check_length(self, attribute, value):
diff --git a/swh/model/tests/test_model.py b/swh/model/tests/test_model.py
new file mode 100644
index 00000000..2a5452fb
--- /dev/null
+++ b/swh/model/tests/test_model.py
@@ -0,0 +1,14 @@
+# Copyright (C) 2019 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+from hypothesis import given
+
+from swh.model.hypothesis_strategies import objects
+
+
+@given(objects())
+def test_todict_inverse_fromdict(objtype_and_obj):
+    (obj_type, obj) = objtype_and_obj
+    assert obj == type(obj).from_dict(obj.to_dict())
-- 
GitLab