Skip to content
Snippets Groups Projects
Verified Commit f2422d65 authored by Antoine R. Dumont's avatar Antoine R. Dumont
Browse files

identifiers: Validate that inputs are correct

Related T1104
parent dfb128e9
No related branches found
No related tags found
No related merge requests found
...@@ -609,6 +609,12 @@ def persistent_identifier(type, object, version=1): ...@@ -609,6 +609,12 @@ def persistent_identifier(type, object, version=1):
identifier identifier
version (int): persistent identifier version (default to 1) version (int): persistent identifier version (default to 1)
Raises:
ValidationError (class) in case of:
invalid type
invalid hash object
Returns: Returns:
Persistent identifier as string. Persistent identifier as string.
...@@ -635,11 +641,16 @@ def persistent_identifier(type, object, version=1): ...@@ -635,11 +641,16 @@ def persistent_identifier(type, object, version=1):
'key_id': 'sha1_git' 'key_id': 'sha1_git'
}, },
} }
o = _map[type] o = _map.get(type)
if not o:
raise ValidationError('Wrong input: Supported types are %s' % (
list(_map.keys())))
if isinstance(object, dict): # internal swh representation resolution if isinstance(object, dict): # internal swh representation resolution
_hash = object[o['key_id']] _hash = object[o['key_id']]
else: # client passed direct identifier (bytes/str) else: # client passed direct identifier (bytes/str)
_hash = object _hash = object
validate_sha1(_hash) # can raise if invalid hash
_hash = hash_to_hex(_hash) _hash = hash_to_hex(_hash)
return 'swh:%s:%s:%s' % (version, o['short_name'], _hash) return 'swh:%s:%s:%s' % (version, o['short_name'], _hash)
......
...@@ -11,6 +11,7 @@ from nose.tools import istest ...@@ -11,6 +11,7 @@ from nose.tools import istest
from swh.model import hashutil, identifiers from swh.model import hashutil, identifiers
from swh.model.exceptions import ValidationError
from swh.model.identifiers import SNAPSHOT, RELEASE, REVISION, DIRECTORY from swh.model.identifiers import SNAPSHOT, RELEASE, REVISION, DIRECTORY
from swh.model.identifiers import CONTENT from swh.model.identifiers import CONTENT
...@@ -816,6 +817,18 @@ class SnapshotIdentifier(unittest.TestCase): ...@@ -816,6 +817,18 @@ class SnapshotIdentifier(unittest.TestCase):
self.assertEquals(actual_value, expected_persistent_id) self.assertEquals(actual_value, expected_persistent_id)
def test_persistent_identifier_wrong_input(self):
_snapshot_id = 'notahash4bc0bf3d81436bf980b46e98bd338453'
_snapshot = {'id': _snapshot_id}
for _type, _hash, _error in [
(SNAPSHOT, _snapshot_id, 'Unexpected characters'),
(SNAPSHOT, _snapshot, 'Unexpected characters'),
('foo', '', 'Wrong input: Supported types are'),
]:
with self.assertRaisesRegex(ValidationError, _error):
identifiers.persistent_identifier(_type, _hash)
def test_parse_persistent_identifier(self): def test_parse_persistent_identifier(self):
for pid, _type, _version, _hash in [ for pid, _type, _version, _hash in [
('swh:1:cnt:94a9ed024d3859793618152ea559a168bbcbb5e2', 'cnt', ('swh:1:cnt:94a9ed024d3859793618152ea559a168bbcbb5e2', 'cnt',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment