From 375832f51bf5cf96b5e9cbc0b6af5cf1c9dc4b84 Mon Sep 17 00:00:00 2001
From: Daniele Serafini <seraf.daniele@gmail.com>
Date: Fri, 4 Oct 2019 19:08:16 +0200
Subject: [PATCH] PID: move validation checks to PersistentId constructor

... from test_persistent_identifier.

Closes T1986
---
 swh/model/identifiers.py            | 29 ++++--------
 swh/model/tests/test_identifiers.py | 73 +++++++++++++++--------------
 2 files changed, 46 insertions(+), 56 deletions(-)

diff --git a/swh/model/identifiers.py b/swh/model/identifiers.py
index 6a000681..9bf7c6cb 100644
--- a/swh/model/identifiers.py
+++ b/swh/model/identifiers.py
@@ -695,6 +695,13 @@ class PersistentId(_PersistentId):
         if not o:
             raise ValidationError('Wrong input: Supported types are %s' % (
                 list(_object_type_map.keys())))
+        if namespace != PID_NAMESPACE:
+            raise ValidationError(
+                "Wrong format: only supported namespace is '%s'"
+                % PID_NAMESPACE)
+        if scheme_version != PID_VERSION:
+            raise ValidationError(
+                'Wrong format: only supported version is %d' % PID_VERSION)
         # internal swh representation resolution
         if isinstance(object_id, dict):
             object_id = object_id[o['key_id']]
@@ -773,22 +780,8 @@ def parse_persistent_identifier(persistent_id):
 
     # Checking for parsing errors
     _ns, _version, _type, _id = pid_data
-    if _ns != PID_NAMESPACE:
-        raise ValidationError(
-            "Wrong format: only supported namespace is '%s'" % PID_NAMESPACE)
-
-    if _version != str(PID_VERSION):
-        raise ValidationError(
-            'Wrong format: only supported version is %d' % PID_VERSION)
-
     pid_data[1] = int(pid_data[1])
 
-    expected_types = PID_TYPES
-    if _type not in expected_types:
-        raise ValidationError(
-            'Wrong format: Supported types are %s' % (
-                ', '.join(expected_types)))
-
     for otype, data in _object_type_map.items():
         if _type == data['short_name']:
             pid_data[2] = otype
@@ -798,12 +791,6 @@ def parse_persistent_identifier(persistent_id):
         raise ValidationError(
             'Wrong format: Identifier should be present')
 
-    try:
-        validate_sha1(_id)
-    except ValidationError:
-        raise ValidationError(
-           'Wrong format: Identifier should be a valid hash')
-
     persistent_id_metadata = {}
     for part in persistent_id_parts:
         try:
@@ -813,4 +800,4 @@ def parse_persistent_identifier(persistent_id):
             msg = 'Contextual data is badly formatted, form key=val expected'
             raise ValidationError(msg)
     pid_data.append(persistent_id_metadata)
-    return PersistentId._make(pid_data)
+    return PersistentId(*pid_data)
diff --git a/swh/model/tests/test_identifiers.py b/swh/model/tests/test_identifiers.py
index 83294d5a..9e6cd571 100644
--- a/swh/model/tests/test_identifiers.py
+++ b/swh/model/tests/test_identifiers.py
@@ -10,8 +10,8 @@ import unittest
 from swh.model import hashutil, identifiers
 from swh.model.exceptions import ValidationError
 from swh.model.identifiers import (CONTENT, DIRECTORY,
-                                   PID_TYPES, RELEASE,
-                                   REVISION, SNAPSHOT, PersistentId)
+                                   RELEASE, REVISION,
+                                   SNAPSHOT, PersistentId)
 
 
 class UtilityFunctionsIdentifier(unittest.TestCase):
@@ -768,8 +768,8 @@ class SnapshotIdentifier(unittest.TestCase):
                  'swh:1:snp:c7c108084bc0bf3d81436bf980b46e98bd338453',
                  None, {}),
                 (RELEASE, _release_id,
-                 'swh:2:rel:22ece559cc7cc2364edc5e5593d63ae8bd229f9f',
-                 2, {}),
+                 'swh:1:rel:22ece559cc7cc2364edc5e5593d63ae8bd229f9f',
+                 1, {}),
                 (REVISION, _revision_id,
                  'swh:1:rev:309cf2674ee7a0749978cf8265ab91a60aea0f7d',
                  None, {}),
@@ -783,8 +783,8 @@ class SnapshotIdentifier(unittest.TestCase):
                  'swh:1:snp:c7c108084bc0bf3d81436bf980b46e98bd338453',
                  None, {}),
                 (RELEASE, _release,
-                 'swh:2:rel:22ece559cc7cc2364edc5e5593d63ae8bd229f9f',
-                 2, {}),
+                 'swh:1:rel:22ece559cc7cc2364edc5e5593d63ae8bd229f9f',
+                 1, {}),
                 (REVISION, _revision,
                  'swh:1:rev:309cf2674ee7a0749978cf8265ab91a60aea0f7d',
                  None, {}),
@@ -811,12 +811,12 @@ class SnapshotIdentifier(unittest.TestCase):
         _snapshot_id = 'notahash4bc0bf3d81436bf980b46e98bd338453'
         _snapshot = {'id': _snapshot_id}
 
-        for _type, _hash, _error in [
-                (SNAPSHOT, _snapshot_id, 'Unexpected characters'),
-                (SNAPSHOT, _snapshot, 'Unexpected characters'),
-                ('foo', '', 'Wrong input: Supported types are'),
+        for _type, _hash in [
+                (SNAPSHOT, _snapshot_id),
+                (SNAPSHOT, _snapshot),
+                ('foo', ''),
         ]:
-            with self.assertRaisesRegex(ValidationError, _error):
+            with self.assertRaises(ValidationError):
                 identifiers.persistent_identifier(_type, _hash)
 
     def test_parse_persistent_identifier(self):
@@ -866,34 +866,37 @@ class SnapshotIdentifier(unittest.TestCase):
             self.assertEqual(actual_result, expected_result)
 
     def test_parse_persistent_identifier_parsing_error(self):
-        for pid, _error in [
-                ('swh:1:cnt',
-                 'Wrong format: There should be 4 mandatory values'),
-                ('swh:1:',
-                 'Wrong format: There should be 4 mandatory values'),
-                ('swh:',
-                 'Wrong format: There should be 4 mandatory values'),
-                ('swh:1:cnt:',
-                 'Wrong format: Identifier should be present'),
-                ('foo:1:cnt:abc8bc9d7a6bcf6db04f476d29314f157507d505',
-                 'Wrong format: only supported namespace is \'swh\''),
-                ('swh:2:dir:def8bc9d7a6bcf6db04f476d29314f157507d505',
-                 'Wrong format: only supported version is 1'),
-                ('swh:1:foo:fed8bc9d7a6bcf6db04f476d29314f157507d505',
-                 'Wrong format: Supported types are %s' % (
-                     ', '.join(PID_TYPES))),
+        for pid in [
+                ('swh:1:cnt'),
+                ('swh:1:'),
+                ('swh:'),
+                ('swh:1:cnt:'),
+                ('foo:1:cnt:abc8bc9d7a6bcf6db04f476d29314f157507d505'),
+                ('swh:2:dir:def8bc9d7a6bcf6db04f476d29314f157507d505'),
+                ('swh:1:foo:fed8bc9d7a6bcf6db04f476d29314f157507d505'),
                 ('swh:1:dir:0b6959356d30f1a4e9b7f6bca59b9a336464c03d;invalid;'
-                 'malformed',
-                 'Contextual data is badly formatted, form key=val expected'),
-                ('swh:1:snp:gh6959356d30f1a4e9b7f6bca59b9a336464c03d',
-                 'Wrong format: Identifier should be a valid hash'),
-                ('swh:1:snp:foo',
-                 'Wrong format: Identifier should be a valid hash')
+                 'malformed'),
+                ('swh:1:snp:gh6959356d30f1a4e9b7f6bca59b9a336464c03d'),
+                ('swh:1:snp:foo'),
         ]:
-            with self.assertRaisesRegex(
-                    ValidationError, _error):
+            with self.assertRaises(ValidationError):
                 identifiers.parse_persistent_identifier(pid)
 
+    def test_persistentid_class_validation_error(self):
+        for _ns, _version, _type, _id in [
+            ('foo', 1, CONTENT, 'abc8bc9d7a6bcf6db04f476d29314f157507d505'),
+            ('swh', 2, DIRECTORY, 'def8bc9d7a6bcf6db04f476d29314f157507d505'),
+            ('swh', 1, 'foo', 'fed8bc9d7a6bcf6db04f476d29314f157507d505'),
+            ('swh', 1, SNAPSHOT, 'gh6959356d30f1a4e9b7f6bca59b9a336464c03d'),
+        ]:
+            with self.assertRaises(ValidationError):
+                PersistentId(
+                    namespace=_ns,
+                    scheme_version=_version,
+                    object_type=_type,
+                    object_id=_id
+                )
+
 
 class OriginIdentifier(unittest.TestCase):
     def setUp(self):
-- 
GitLab