
Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Showing changes: 5467 additions and 717 deletions.
swh/model/exceptions.py
@@ -33,11 +33,12 @@
 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 # POSSIBILITY OF SUCH DAMAGE.
 
-NON_FIELD_ERRORS = '__all__'
+NON_FIELD_ERRORS = "__all__"
 
 
 class ValidationError(Exception):
     """An error while validating data."""
 
     def __init__(self, message, code=None, params=None):
         """
         The `message` argument can be a single error, a list of errors, or a
@@ -54,16 +55,15 @@ class ValidationError(Exception):
             message = message[0]
         if isinstance(message, ValidationError):
-            if hasattr(message, 'error_dict'):
+            if hasattr(message, "error_dict"):
                 message = message.error_dict
             # PY2 has a `message` property which is always there so we can't
             # duck-type on it. It was introduced in Python 2.5 and already
             # deprecated in Python 2.6.
-            elif not hasattr(message, 'message'):
+            elif not hasattr(message, "message"):
                 message = message.error_list
             else:
-                message, code, params = (message.message, message.code,
-                                         message.params)
+                message, code, params = (message.message, message.code, message.params)
 
         if isinstance(message, dict):
             self.error_dict = {}
@@ -78,9 +78,8 @@ class ValidationError(Exception):
                 # Normalize plain strings to instances of ValidationError.
                 if not isinstance(message, ValidationError):
                     message = ValidationError(message)
-                if hasattr(message, 'error_dict'):
-                    self.error_list.extend(sum(message.error_dict.values(),
-                                               []))
+                if hasattr(message, "error_dict"):
+                    self.error_list.extend(sum(message.error_dict.values(), []))
                 else:
                     self.error_list.extend(message.error_list)
@@ -94,18 +93,18 @@ class ValidationError(Exception):
     def message_dict(self):
         # Trigger an AttributeError if this ValidationError
         # doesn't have an error_dict.
-        getattr(self, 'error_dict')
+        getattr(self, "error_dict")
         return dict(self)
 
     @property
     def messages(self):
-        if hasattr(self, 'error_dict'):
+        if hasattr(self, "error_dict"):
             return sum(dict(self).values(), [])
         return list(self)
 
     def update_error_dict(self, error_dict):
-        if hasattr(self, 'error_dict'):
+        if hasattr(self, "error_dict"):
             for field, error_list in self.error_dict.items():
                 error_dict.setdefault(field, []).extend(error_list)
         else:
@@ -113,7 +112,7 @@ class ValidationError(Exception):
         return error_dict
 
     def __iter__(self):
-        if hasattr(self, 'error_dict'):
+        if hasattr(self, "error_dict"):
             for field, errors in self.error_dict.items():
                 yield field, list(ValidationError(errors))
         else:
@@ -124,9 +123,13 @@ class ValidationError(Exception):
             yield message
 
     def __str__(self):
-        if hasattr(self, 'error_dict'):
+        if hasattr(self, "error_dict"):
             return repr(dict(self))
         return repr(list(self))
 
     def __repr__(self):
-        return 'ValidationError(%s)' % self
+        return "ValidationError(%s)" % self
+
+
+class InvalidDirectoryPath(Exception):
+    pass
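
Aside: this hunk only changes string quoting, so here is a minimal sketch of how this (Django-derived) ValidationError aggregates messages, for readers skimming the diff:

    from swh.model.exceptions import ValidationError

    # A dict of messages yields an error_dict keyed by field name:
    err = ValidationError({"url": ["not a valid URL"], "date": ["missing timezone"]})
    print(err.message_dict)  # {'url': ['not a valid URL'], 'date': ['missing timezone']}
    print(err.messages)      # flattened across fields: ['not a valid URL', 'missing timezone']

    # A single message keeps its code and params accessible:
    err = ValidationError(
        "Unexpected type %(type)s", code="unexpected-type", params={"type": "int"}
    )
    print(err.code)  # 'unexpected-type'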
swh/model/fields/__init__.py
@@ -6,8 +6,13 @@
 # We do our imports here but we don't use them, so flake8 complains
 # flake8: noqa
 
-from .simple import (validate_type, validate_int, validate_str, validate_bytes,
-                     validate_datetime, validate_enum)
-from .hashes import (validate_sha1, validate_sha1_git, validate_sha256)
-from .compound import (validate_against_schema, validate_all_keys,
-                       validate_any_key)
+from .compound import validate_against_schema, validate_all_keys, validate_any_key
+from .hashes import validate_sha1, validate_sha1_git, validate_sha256
+from .simple import (
+    validate_bytes,
+    validate_datetime,
+    validate_enum,
+    validate_int,
+    validate_str,
+    validate_type,
+)
swh/model/fields/compound.py
@@ -6,7 +6,7 @@
 from collections import defaultdict
 import itertools
 
-from ..exceptions import ValidationError, NON_FIELD_ERRORS
+from ..exceptions import NON_FIELD_ERRORS, ValidationError
 
 
 def validate_against_schema(model, schema, value):
@@ -26,19 +26,19 @@ def validate_against_schema(model, schema, value):
     if not isinstance(value, dict):
         raise ValidationError(
-            'Unexpected type %(type)s for %(model)s, expected dict',
+            "Unexpected type %(type)s for %(model)s, expected dict",
             params={
-                'model': model,
-                'type': value.__class__.__name__,
+                "model": model,
+                "type": value.__class__.__name__,
             },
-            code='model-unexpected-type',
+            code="model-unexpected-type",
         )
 
     errors = defaultdict(list)
 
     for key, (mandatory, validators) in itertools.chain(
         ((k, v) for k, v in schema.items() if k != NON_FIELD_ERRORS),
-        [(NON_FIELD_ERRORS, (False, schema.get(NON_FIELD_ERRORS, [])))]
+        [(NON_FIELD_ERRORS, (False, schema.get(NON_FIELD_ERRORS, [])))],
     ):
         if not validators:
             continue
@@ -54,9 +54,9 @@ def validate_against_schema(model, schema, value):
             if mandatory:
                 errors[key].append(
                     ValidationError(
-                        'Field %(field)s is mandatory',
-                        params={'field': key},
-                        code='model-field-mandatory',
+                        "Field %(field)s is mandatory",
+                        params={"field": key},
+                        code="model-field-mandatory",
                     )
                 )
@@ -74,19 +74,21 @@ def validate_against_schema(model, schema, value):
             else:
                 if not valid:
                     errdata = {
-                        'validator': validator.__name__,
+                        "validator": validator.__name__,
                     }
 
                     if key == NON_FIELD_ERRORS:
-                        errmsg = 'Validation of model %(model)s failed in ' \
-                                 '%(validator)s'
-                        errdata['model'] = model
-                        errcode = 'model-validation-failed'
+                        errmsg = (
+                            "Validation of model %(model)s failed in " "%(validator)s"
+                        )
+                        errdata["model"] = model
+                        errcode = "model-validation-failed"
                     else:
-                        errmsg = 'Validation of field %(field)s failed in ' \
-                                 '%(validator)s'
-                        errdata['field'] = key
-                        errcode = 'field-validation-failed'
+                        errmsg = (
+                            "Validation of field %(field)s failed in " "%(validator)s"
+                        )
+                        errdata["field"] = key
+                        errcode = "field-validation-failed"
 
                     errors[key].append(
                         ValidationError(errmsg, params=errdata, code=errcode)
@@ -102,11 +104,11 @@ def validate_all_keys(value, keys):
     """Validate that all the given keys are present in value"""
     missing_keys = set(keys) - set(value)
     if missing_keys:
-        missing_fields = ', '.join(sorted(missing_keys))
+        missing_fields = ", ".join(sorted(missing_keys))
         raise ValidationError(
-            'Missing mandatory fields %(missing_fields)s',
-            params={'missing_fields': missing_fields},
-            code='missing-mandatory-field'
+            "Missing mandatory fields %(missing_fields)s",
+            params={"missing_fields": missing_fields},
+            code="missing-mandatory-field",
         )
 
     return True
@@ -116,11 +118,11 @@ def validate_any_key(value, keys):
     """Validate that any of the given keys is present in value"""
     present_keys = set(keys) & set(value)
     if not present_keys:
-        missing_fields = ', '.join(sorted(keys))
+        missing_fields = ", ".join(sorted(keys))
         raise ValidationError(
-            'Must contain one of the alternative fields %(missing_fields)s',
-            params={'missing_fields': missing_fields},
-            code='missing-alternative-field',
+            "Must contain one of the alternative fields %(missing_fields)s",
+            params={"missing_fields": missing_fields},
+            code="missing-alternative-field",
        )
 
     return True
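
Aside: a hedged usage sketch for validate_against_schema. The schema maps each field to a (mandatory, validators) pair, as the iteration above shows; wrapping validators in a list and raising the accumulated errors as one ValidationError are assumptions based on that loop and the defaultdict pattern:

    from swh.model.exceptions import ValidationError
    from swh.model.fields import validate_against_schema, validate_int, validate_str

    schema = {
        "id": (True, [validate_int]),     # mandatory field
        "name": (False, [validate_str]),  # optional field
    }

    assert validate_against_schema("example-model", schema, {"id": 42})

    try:
        validate_against_schema("example-model", schema, {"name": "no id given"})
    except ValidationError as e:
        print(e.message_dict)  # assumed: {'id': ['Field id is mandatory']}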
swh/model/fields/hashes.py
@@ -4,6 +4,7 @@
 # See top-level LICENSE file for more information
 
 import string
+
 from ..exceptions import ValidationError
@@ -22,22 +23,22 @@ def validate_hash(value, hash_type):
     """
     hash_lengths = {
-        'sha1': 20,
-        'sha1_git': 20,
-        'sha256': 32,
+        "sha1": 20,
+        "sha1_git": 20,
+        "sha256": 32,
     }
 
     hex_digits = set(string.hexdigits)
 
     if hash_type not in hash_lengths:
         raise ValidationError(
-            'Unexpected hash type %(hash_type)s, expected one of'
-            ' %(hash_types)s',
+            "Unexpected hash type %(hash_type)s, expected one of" " %(hash_types)s",
             params={
-                'hash_type': hash_type,
-                'hash_types': ', '.join(sorted(hash_lengths)),
+                "hash_type": hash_type,
+                "hash_types": ", ".join(sorted(hash_lengths)),
             },
-            code='unexpected-hash-type')
+            code="unexpected-hash-type",
+        )
 
     if isinstance(value, str):
         errors = []
@@ -48,10 +49,10 @@ def validate_hash(value, hash_type):
                     "Unexpected characters `%(unexpected_chars)s' for hash "
                     "type %(hash_type)s",
                     params={
-                        'unexpected_chars': ', '.join(sorted(extra_chars)),
-                        'hash_type': hash_type,
+                        "unexpected_chars": ", ".join(sorted(extra_chars)),
+                        "hash_type": hash_type,
                     },
-                    code='unexpected-hash-contents',
+                    code="unexpected-hash-contents",
                 )
             )
@@ -60,14 +61,14 @@ def validate_hash(value, hash_type):
         if length != expected_length:
             errors.append(
                 ValidationError(
-                    'Unexpected length %(length)d for hash type '
-                    '%(hash_type)s, expected %(expected_length)d',
+                    "Unexpected length %(length)d for hash type "
+                    "%(hash_type)s, expected %(expected_length)d",
                     params={
-                        'length': length,
-                        'expected_length': expected_length,
-                        'hash_type': hash_type,
+                        "length": length,
+                        "expected_length": expected_length,
+                        "hash_type": hash_type,
                     },
-                    code='unexpected-hash-length',
+                    code="unexpected-hash-length",
                 )
             )
@@ -81,37 +82,37 @@ def validate_hash(value, hash_type):
         expected_length = hash_lengths[hash_type]
         if length != expected_length:
             raise ValidationError(
-                'Unexpected length %(length)d for hash type '
-                '%(hash_type)s, expected %(expected_length)d',
+                "Unexpected length %(length)d for hash type "
+                "%(hash_type)s, expected %(expected_length)d",
                 params={
-                    'length': length,
-                    'expected_length': expected_length,
-                    'hash_type': hash_type,
+                    "length": length,
+                    "expected_length": expected_length,
+                    "hash_type": hash_type,
                 },
-                code='unexpected-hash-length',
+                code="unexpected-hash-length",
             )
 
         return True
 
     raise ValidationError(
-        'Unexpected type %(type)s for hash, expected str or bytes',
+        "Unexpected type %(type)s for hash, expected str or bytes",
         params={
-            'type': value.__class__.__name__,
+            "type": value.__class__.__name__,
        },
-        code='unexpected-hash-value-type',
+        code="unexpected-hash-value-type",
     )
 
 
 def validate_sha1(sha1):
     """Validate that sha1 is a valid sha1 hash"""
-    return validate_hash(sha1, 'sha1')
+    return validate_hash(sha1, "sha1")
 
 
 def validate_sha1_git(sha1_git):
     """Validate that sha1_git is a valid sha1_git hash"""
-    return validate_hash(sha1_git, 'sha1_git')
+    return validate_hash(sha1_git, "sha1_git")
 
 
 def validate_sha256(sha256):
     """Validate that sha256 is a valid sha256 hash"""
-    return validate_hash(sha256, 'sha256')
+    return validate_hash(sha256, "sha256")
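
Aside: a quick sketch of what these hash validators accept. Hex strings are checked character by character and (assumed from the str branch above) against twice the byte length; raw bytes are checked against the byte length itself:

    from swh.model.exceptions import ValidationError
    from swh.model.fields import validate_sha1

    assert validate_sha1("356a192b7913b04c54574d18c28d46e6395428ab")  # 40 hex chars
    assert validate_sha1(b"\x00" * 20)                                # 20 raw bytes

    try:
        validate_sha1("zzz")  # wrong characters *and* wrong length
    except ValidationError as e:
        print(e.messages)  # both errors, aggregated into one exception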
swh/model/fields/simple.py
@@ -13,16 +13,16 @@ def validate_type(value, type):
     """Validate that value is an integer"""
     if not isinstance(value, type):
         if isinstance(type, tuple):
-            typestr = 'one of %s' % ', '.join(typ.__name__ for typ in type)
+            typestr = "one of %s" % ", ".join(typ.__name__ for typ in type)
         else:
             typestr = type.__name__
         raise ValidationError(
-            'Unexpected type %(type)s, expected %(expected_type)s',
+            "Unexpected type %(type)s, expected %(expected_type)s",
             params={
-                'type': value.__class__.__name__,
-                'expected_type': typestr,
+                "type": value.__class__.__name__,
+                "expected_type": typestr,
             },
-            code='unexpected-type'
+            code="unexpected-type",
         )
 
     return True
@@ -54,10 +54,12 @@ def validate_datetime(value):
         errors.append(e)
 
     if isinstance(value, datetime.datetime) and value.tzinfo is None:
-        errors.append(ValidationError(
-            'Datetimes must be timezone-aware in swh',
-            code='datetime-without-tzinfo',
-        ))
+        errors.append(
+            ValidationError(
+                "Datetimes must be timezone-aware in swh",
+                code="datetime-without-tzinfo",
+            )
+        )
 
     if errors:
         raise ValidationError(errors)
@@ -69,12 +71,12 @@ def validate_enum(value, expected_values):
     """Validate that value is contained in expected_values"""
     if value not in expected_values:
         raise ValidationError(
-            'Unexpected value %(value)s, expected one of %(expected_values)s',
+            "Unexpected value %(value)s, expected one of %(expected_values)s",
             params={
-                'value': value,
-                'expected_values': ', '.join(sorted(expected_values)),
+                "value": value,
+                "expected_values": ", ".join(sorted(expected_values)),
            },
-            code='unexpected-value',
+            code="unexpected-value",
         )
 
     return True
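
Aside: the simple validators in action, mirroring the tests further down this diff:

    import datetime

    from swh.model.exceptions import ValidationError
    from swh.model.fields import simple

    assert simple.validate_int(42)
    assert simple.validate_enum("other", {"an enum value", "other"})

    try:
        simple.validate_datetime(datetime.datetime(1999, 1, 1, 12, 0, 0))  # naive
    except ValidationError as e:
        print(e.code)  # 'datetime-without-tzinfo'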
swh/model/hashutil.py
-# Copyright (C) 2015 The Software Heritage developers
+# Copyright (C) 2015-2018 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+"""Module in charge of hashing function definitions. This is the base
+module use to compute swh's hashes.
+
+Only a subset of hashing algorithms is supported as defined in the
+ALGORITHMS set. Any provided algorithms not in that list will result
+in a ValueError explaining the error.
+
+This module defines a MultiHash class to ease the softwareheritage
+hashing algorithms computation. This allows to compute hashes from
+file object, path, data using a similar interface as what the standard
+hashlib module provides.
+
+Basic usage examples:
+
+- file object: MultiHash.from_file(
+      file_object, hash_names=DEFAULT_ALGORITHMS).digest()
+
+- path (filepath): MultiHash.from_path(b'foo').hexdigest()
+
+- data (bytes): MultiHash.from_data(b'foo').bytehexdigest()
+
+"Complex" usage, defining a swh hashlib instance first:
+
+- To compute length, integrate the length to the set of algorithms to
+  compute, for example:
+
+  .. code-block:: python
+
+      h = MultiHash(hash_names=set({'length'}).union(DEFAULT_ALGORITHMS))
+      with open(filepath, 'rb') as f:
+          h.update(f.read(HASH_BLOCK_SIZE))
+      hashes = h.digest()  # returns a dict of {hash_algo_name: hash_in_bytes}
+
+- Write alongside computing hashing algorithms (from a stream), example:
+
+  .. code-block:: python
+
+      h = MultiHash(length=length)
+      with open(filepath, 'wb') as f:
+          for chunk in r.iter_content():  # r a stream of sort
+              h.update(chunk)
+              f.write(chunk)
+      hashes = h.hexdigest()  # returns a dict of {hash_algo_name: hash_in_hex}
+
+"""
+
 import binascii
 import functools
 import hashlib
 from io import BytesIO
 import os
+from typing import Callable, Dict, Optional, Union
 
-# supported hashing algorithms
-ALGORITHMS = set(['sha1', 'sha256', 'sha1_git'])
+ALGORITHMS = set(
+    ["sha1", "sha256", "sha1_git", "blake2s256", "blake2b512", "md5", "sha512"]
+)
+"""Hashing algorithms supported by this module"""
+
+DEFAULT_ALGORITHMS = set(["sha1", "sha256", "sha1_git", "blake2s256"])
+"""Algorithms computed by default when calling the functions from this module.
+
+Subset of :const:`ALGORITHMS`.
+"""
 
-# should be a multiple of 64 (sha1/sha256's block size)
-# FWIW coreutils' sha1sum uses 32768
 HASH_BLOCK_SIZE = 32768
+"""Block size for streaming hash computations made in this module"""
+
+_blake2_hash_cache: Dict[str, Callable] = {}
+
+
+class MultiHash:
+    """Hashutil class to support multiple hashes computation.
+
+    Args:
+
+        hash_names (set): Set of hash algorithms (+ optionally length)
+                          to compute hashes (cf. DEFAULT_ALGORITHMS)
+        length (int): Length of the total sum of chunks to read
+
+    If the length is provided as algorithm, the length is also
+    computed and returned.
+
+    """
+
+    def __init__(self, hash_names=DEFAULT_ALGORITHMS, length=None):
+        self.state = {}
+        self.track_length = False
+        for name in hash_names:
+            if name == "length":
+                self.state["length"] = 0
+                self.track_length = True
+            else:
+                self.state[name] = _new_hash(name, length)
+
+    @classmethod
+    def from_state(cls, state, track_length):
+        ret = cls([])
+        ret.state = state
+        ret.track_length = track_length
+        return ret  # restored: dropped in the page extraction, needed by copy()
+
+    @classmethod
+    def from_file(cls, fobj, hash_names=DEFAULT_ALGORITHMS, length=None):
+        ret = cls(length=length, hash_names=hash_names)
+        while True:
+            chunk = fobj.read(HASH_BLOCK_SIZE)
+            if not chunk:
+                break
+            ret.update(chunk)
+        return ret
+
+    @classmethod
+    def from_path(cls, path, hash_names=DEFAULT_ALGORITHMS):
+        length = os.path.getsize(path)
+        with open(path, "rb") as f:
+            ret = cls.from_file(f, hash_names=hash_names, length=length)
+        return ret
+
+    @classmethod
+    def from_data(cls, data, hash_names=DEFAULT_ALGORITHMS):
+        length = len(data)
+        fobj = BytesIO(data)
+        return cls.from_file(fobj, hash_names=hash_names, length=length)
+
+    def update(self, chunk):
+        for name, h in self.state.items():
+            if name == "length":
+                continue
+            h.update(chunk)
+        if self.track_length:
+            self.state["length"] += len(chunk)
+
+    def digest(self):
+        return {
+            name: h.digest() if name != "length" else h
+            for name, h in self.state.items()
+        }
+
+    def hexdigest(self):
+        return {
+            name: h.hexdigest() if name != "length" else h
+            for name, h in self.state.items()
+        }
+
+    def bytehexdigest(self):
+        return {
+            name: hash_to_bytehex(h.digest()) if name != "length" else h
+            for name, h in self.state.items()
+        }
+
+    def copy(self):
+        copied_state = {
+            name: h.copy() if name != "length" else h for name, h in self.state.items()
+        }
+        return self.from_state(copied_state, self.track_length)
+
+
+def _new_blake2_hash(algo):
+    """Return a function that initializes a blake2 hash."""
+    if algo in _blake2_hash_cache:
+        return _blake2_hash_cache[algo]()
+
+    lalgo = algo.lower()
+    if not lalgo.startswith("blake2"):
+        raise ValueError("Algorithm %s is not a blake2 hash" % algo)
+
+    blake_family = lalgo[:7]
+
+    digest_size = None
+    if lalgo[7:]:
+        try:
+            digest_size, remainder = divmod(int(lalgo[7:]), 8)
+        except ValueError:
+            raise ValueError("Unknown digest size for algo %s" % algo) from None
+        if remainder:
+            raise ValueError(
+                "Digest size for algorithm %s must be a multiple of 8" % algo
+            )
+
+    blake2 = getattr(hashlib, blake_family)
+    _blake2_hash_cache[algo] = lambda: blake2(digest_size=digest_size)
+
+    return _blake2_hash_cache[algo]()
+
+
+def _new_hashlib_hash(algo):
+    """Initialize a digest object from hashlib.
+
+    Handle the swh-specific names for the blake2-related algorithms
+    """
+    if algo.startswith("blake2"):
+        return _new_blake2_hash(algo)
+    else:
+        return hashlib.new(algo)
+
+
-def _new_git_hash(base_algo, git_type, length):
-    """Initialize a digest object (as returned by python's hashlib) for the
-    requested algorithm, and feed it with the header for a git object of the
-    given type and length.
+def git_object_header(git_type: str, length: int) -> bytes:
+    """Returns the header for a git object of the given type and length.
 
-    The header for hashing a git object consists of:
+    The header of a git object consists of:
      - The type of the object (encoded in ASCII)
      - One ASCII space (\x20)
      - The length of the object (decimal encoded in ASCII)
      - One NUL byte
 
     Args:
-        base_algo: a hashlib-supported algorithm
+        base_algo (str from :const:`ALGORITHMS`): a hashlib-supported algorithm
         git_type: the type of the git object (supposedly one of 'blob',
                   'commit', 'tag', 'tree')
         length: the length of the git object you're encoding
@@ -37,25 +218,36 @@
 
     Returns:
         a hashutil.hash object
     """
+    git_object_types = {
+        "blob",
+        "tree",
+        "commit",
+        "tag",
+        "snapshot",
+        "raw_extrinsic_metadata",
+        "extid",
+    }
 
-    h = hashlib.new(base_algo)
-    git_header = '%s %d\0' % (git_type, length)
-    h.update(git_header.encode('ascii'))
+    if git_type not in git_object_types:
+        raise ValueError(
+            "Unexpected git object type %s, expected one of %s"
+            % (git_type, ", ".join(sorted(git_object_types)))
+        )
 
-    return h
+    return ("%s %d\0" % (git_type, length)).encode("ascii")
 
 
-def _new_hash(algo, length=None):
-    """Initialize a digest object (as returned by python's hashlib) for the
-    requested algorithm. See the constant ALGORITHMS for the list of supported
-    algorithms. If a git-specific hashing algorithm is requested (e.g.,
-    "sha1_git"), the hashing object will be pre-fed with the needed header; for
-    this to work, length must be given.
+def _new_hash(algo: str, length: Optional[int] = None):
+    """Initialize a digest object (as returned by python's hashlib) for
+    the requested algorithm. See the constant ALGORITHMS for the list
+    of supported algorithms. If a git-specific hashing algorithm is
+    requested (e.g., "sha1_git"), the hashing object will be pre-fed
+    with the needed header; for this to work, length must be given.
 
     Args:
-        algo: a hashing algorithm (one of ALGORITHMS)
-        length: the length of the hashed payload (needed for git-specific
-                algorithms)
+        algo (str): a hashing algorithm (one of ALGORITHMS)
+        length (int): the length of the hashed payload (needed for
+                      git-specific algorithms)
 
     Returns:
         a hashutil.hash object
@@ -63,125 +255,99 @@
     Raises:
         ValueError if algo is unknown, or length is missing for a git-specific
         hash.
 
     """
     if algo not in ALGORITHMS:
-        raise ValueError('Unexpected hashing algorithm %s, '
-                         'expected one of %s' %
-                         (algo, ', '.join(sorted(ALGORITHMS))))
+        raise ValueError(
+            "Unexpected hashing algorithm %s, expected one of %s"
+            % (algo, ", ".join(sorted(ALGORITHMS)))
+        )
 
-    h = None
-    if algo.endswith('_git'):
+    if algo.endswith("_git"):
         if length is None:
-            raise ValueError('Missing length for git hashing algorithm')
+            raise ValueError("Missing length for git hashing algorithm")
         base_algo = algo[:-4]
-        h = _new_git_hash(base_algo, 'blob', length)
-    else:
-        h = hashlib.new(algo)
+        h = _new_hashlib_hash(base_algo)
+        h.update(git_object_header("blob", length))
+        return h
 
-    return h
+    return _new_hashlib_hash(algo)
 
 
-def hash_file(fobj, length=None, algorithms=ALGORITHMS, chunk_cb=None):
-    """Hash the contents of the given file object with the given algorithms.
-
-    Args:
-        fobj: a file-like object
-        length: the length of the contents of the file-like object (for the
-                git-specific algorithms)
-        algorithms: the hashing algorithms used
-
-    Returns: a dict mapping each algorithm to a bytes digest.
-
-    Raises:
-        ValueError if algorithms contains an unknown hash algorithm.
-    """
-    hashes = {algo: _new_hash(algo, length) for algo in algorithms}
-
-    while True:
-        chunk = fobj.read(HASH_BLOCK_SIZE)
-        if not chunk:
-            break
-        for hash in hashes.values():
-            hash.update(chunk)
-        if chunk_cb:
-            chunk_cb(chunk)
-
-    return {algo: hash.digest() for algo, hash in hashes.items()}
-
-
-def hash_path(path, algorithms=ALGORITHMS, chunk_cb=None):
-    """Hash the contents of the file at the given path with the given algorithms.
-
-    Args:
-        path: the path of the file to hash
-        algorithms: the hashing algorithms used
-        chunk_cb: a callback
-
-    Returns: a dict mapping each algorithm to a bytes digest.
-
-    Raises:
-        ValueError if algorithms contains an unknown hash algorithm.
-        OSError on file access error
-    """
-    length = os.path.getsize(path)
-    with open(path, 'rb') as fobj:
-        return hash_file(fobj, length, algorithms, chunk_cb)
-
-
-def hash_data(data, algorithms=ALGORITHMS):
-    """Hash the given binary blob with the given algorithms.
-
-    Args:
-        data: a bytes object
-        algorithms: the hashing algorithms used
-
-    Returns: a dict mapping each algorithm to a bytes digest
-
-    Raises:
-        TypeError if data does not support the buffer interface.
-        ValueError if algorithms contains an unknown hash algorithm.
-    """
-    fobj = BytesIO(data)
-    return hash_file(fobj, len(data), algorithms)
-
-
-def hash_git_data(data, git_type, base_algo='sha1'):
+def hash_git_data(data, git_type, base_algo="sha1"):
     """Hash the given data as a git object of type git_type.
 
     Args:
         data: a bytes object
         git_type: the git object type
         base_algo: the base hashing algorithm used (default: sha1)
 
     Returns: a dict mapping each algorithm to a bytes digest
 
     Raises:
         ValueError if the git_type is unexpected.
     """
-
-    git_object_types = {'blob', 'tree', 'commit', 'tag'}
-
-    if git_type not in git_object_types:
-        raise ValueError('Unexpected git object type %s, expected one of %s' %
-                         (git_type, ', '.join(sorted(git_object_types))))
-
-    h = _new_git_hash(base_algo, git_type, len(data))
+    h = _new_hashlib_hash(base_algo)
+    h.update(git_object_header(git_type, len(data)))
     h.update(data)
 
     return h.digest()
 
 
 @functools.lru_cache()
-def hash_to_hex(hash):
-    """Converts a hash (in hex or bytes form) to its hexadecimal ascii form"""
+def hash_to_hex(hash: Union[str, bytes]) -> str:
+    """Converts a hash (in hex or bytes form) to its hexadecimal ascii form
+
+    Args:
+        hash (str or bytes): a :class:`bytes` hash or a :class:`str` containing
+            the hexadecimal form of the hash
+
+    Returns:
+        str: the hexadecimal form of the hash
+
+    """
     if isinstance(hash, str):
         return hash
-    return binascii.hexlify(hash).decode('ascii')
+    return binascii.hexlify(hash).decode("ascii")
+
+
+@functools.lru_cache()
+def hash_to_bytehex(hash: bytes) -> bytes:
+    """Converts a hash to its hexadecimal bytes representation
+
+    Args:
+        hash (bytes): a :class:`bytes` hash
+
+    Returns:
+        bytes: the hexadecimal form of the hash, as :class:`bytes`
+
+    """
+    return binascii.hexlify(hash)
 
 
 @functools.lru_cache()
-def hash_to_bytes(hash):
-    """Converts a hash (in hex or bytes form) to its raw bytes form"""
+def hash_to_bytes(hash: Union[str, bytes]) -> bytes:
+    """Converts a hash (in hex or bytes form) to its raw bytes form
+
+    Args:
+        hash (str or bytes): a :class:`bytes` hash or a :class:`str` containing
+            the hexadecimal form of the hash
+
+    Returns:
+        bytes: the :class:`bytes` form of the hash
+
+    """
     if isinstance(hash, bytes):
         return hash
     return bytes.fromhex(hash)
+
+
+@functools.lru_cache()
+def bytehex_to_hash(hex: bytes) -> bytes:
+    """Converts a hexadecimal bytes representation of a hash to that hash
+
+    Args:
+        hash (bytes): a :class:`bytes` containing the hexadecimal form of the
+            hash encoded in ascii
+
+    Returns:
+        bytes: the :class:`bytes` form of the hash
+
+    """
+    return hash_to_bytes(hex.decode())
swh/model/hypothesis_strategies.py (new file)
# Copyright (C) 2019-2021 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import datetime
import functools
import string
from typing import Any, Callable, List, Sequence, Set, Tuple, Union
from deprecated import deprecated
from hypothesis import assume
from hypothesis.extra.dateutil import timezones
from hypothesis.strategies import (
SearchStrategy,
binary,
booleans,
builds,
characters,
composite,
datetimes,
dictionaries,
from_regex,
integers,
just,
lists,
none,
one_of,
sampled_from,
sets,
text,
tuples,
)
from .from_disk import DentryPerms
from .model import (
BaseContent,
BaseModel,
Content,
Directory,
DirectoryEntry,
MetadataAuthority,
MetadataFetcher,
ModelObjectType,
Origin,
OriginVisit,
OriginVisitStatus,
Person,
RawExtrinsicMetadata,
Release,
ReleaseTargetType,
Revision,
RevisionType,
SkippedContent,
Snapshot,
SnapshotBranch,
SnapshotTargetType,
Timestamp,
TimestampWithTimezone,
)
from .swhids import ExtendedObjectType, ExtendedSWHID
pgsql_alphabet = characters(
blacklist_categories=["Cs"],
blacklist_characters=["\u0000"],
) # postgresql does not like these
def optional(strategy):
return one_of(none(), strategy)
def pgsql_text():
return text(alphabet=pgsql_alphabet)
def sha1_git():
return binary(min_size=20, max_size=20)
def sha1():
return binary(min_size=20, max_size=20)
def binaries_without_bytes(blacklist: Sequence[int]):
"""Like hypothesis.strategies.binary, but takes a sequence of bytes that
should not be included."""
return lists(sampled_from([i for i in range(256) if i not in blacklist])).map(bytes)
@composite
def extended_swhids(draw):
object_type = draw(sampled_from(ExtendedObjectType))
object_id = draw(sha1_git())
return ExtendedSWHID(object_type=object_type, object_id=object_id)
def aware_datetimes():
# datetimes in Software Heritage are not used for software artifacts
# (which may be much older than 2000), but only for objects like scheduler
# task runs, and origin visits, which were created by Software Heritage,
# so at least in 2015.
# We're forbidding old datetimes, because until 1956, many timezones had seconds
# in their "UTC offsets" (see
# <https://en.wikipedia.org/wiki/Time_zone#Worldwide_time_zones>), which is not
# encodable in ISO8601; and we need our datetimes to be ISO8601-encodable in the
# RPC protocol
min_value = datetime.datetime(2000, 1, 1, 0, 0, 0)
return datetimes(min_value=min_value, timezones=timezones())
@composite
def iris(draw):
protocol = draw(sampled_from(["git", "http", "https", "deb"]))
domain = draw(from_regex(r"\A([a-z]([a-z0-9é🏛️-]*)\.){1,3}([a-z0-9é])+\Z"))
return "%s://%s" % (protocol, domain)
@composite
def persons_d(draw):
fullname = draw(binary())
email = draw(optional(binary()))
name = draw(optional(binary()))
assume(not (len(fullname) == 32 and email is None and name is None))
return dict(fullname=fullname, name=name, email=email)
def persons(**kwargs):
return persons_d(**kwargs).map(Person.from_dict)
def timestamps_d(**kwargs):
defaults = dict(
seconds=integers(Timestamp.MIN_SECONDS, Timestamp.MAX_SECONDS),
microseconds=integers(Timestamp.MIN_MICROSECONDS, Timestamp.MAX_MICROSECONDS),
)
return builds(dict, **{**defaults, **kwargs})
def timestamps():
return timestamps_d().map(Timestamp.from_dict)
@composite
def timestamps_with_timezone_d(
draw,
*,
timestamp=timestamps_d(),
offset=integers(min_value=-14 * 60, max_value=14 * 60),
negative_utc=booleans(),
):
timestamp = draw(timestamp)
offset = draw(offset)
negative_utc = draw(negative_utc)
assume(not (negative_utc and offset))
return dict(timestamp=timestamp, offset=offset, negative_utc=negative_utc)
timestamps_with_timezone = timestamps_with_timezone_d().map(
TimestampWithTimezone.from_dict
)
def origins_d(*, url=iris().filter(lambda iri: len(iri.encode()) < 2048)):
return builds(dict, url=url)
def origins(**kwargs):
return origins_d(**kwargs).map(Origin.from_dict)
def origin_visits_d(**kwargs):
defaults = dict(
visit=integers(1, 1000),
origin=iris(),
date=aware_datetimes(),
type=pgsql_text(),
)
return builds(dict, **{**defaults, **kwargs})
def origin_visits(**kwargs):
return origin_visits_d(**kwargs).map(OriginVisit.from_dict)
def metadata_dicts():
return dictionaries(pgsql_text(), pgsql_text())
def origin_visit_statuses_d(**kwargs):
defaults = dict(
visit=integers(1, 1000),
origin=iris(),
type=optional(sampled_from(["git", "svn", "pypi", "debian"])),
status=sampled_from(
["created", "ongoing", "full", "partial", "not_found", "failed"]
),
date=aware_datetimes(),
snapshot=optional(sha1_git()),
metadata=optional(metadata_dicts()),
)
return builds(dict, **{**defaults, **kwargs})
def origin_visit_statuses(**kwargs):
return origin_visit_statuses_d(**kwargs).map(OriginVisitStatus.from_dict)
@composite
def releases_d(draw, **kwargs):
defaults = dict(
target_type=sampled_from([x.value for x in ReleaseTargetType]),
name=binary(),
message=optional(binary()),
synthetic=booleans(),
target=sha1_git(),
metadata=optional(revision_metadata()),
raw_manifest=optional(binary()),
)
d = draw(
one_of(
# None author/date:
builds(dict, author=none(), date=none(), **{**defaults, **kwargs}),
# non-None author/date:
builds(
dict,
date=timestamps_with_timezone_d(),
author=persons_d(),
**{**defaults, **kwargs},
),
# it is also possible for date to be None but not author, but let's not
# overwhelm hypothesis with this edge case
)
)
if d["raw_manifest"] is None:
del d["raw_manifest"]
return d
def releases(**kwargs):
return releases_d(**kwargs).map(Release.from_dict)
revision_metadata = metadata_dicts
def extra_headers():
return lists(
tuples(binary(min_size=0, max_size=50), binary(min_size=0, max_size=500))
).map(tuple)
@composite
def revisions_d(draw, **kwargs):
defaults = dict(
message=optional(binary()),
synthetic=booleans(),
parents=tuples(sha1_git()),
directory=sha1_git(),
type=sampled_from([x.value for x in RevisionType]),
metadata=optional(revision_metadata()),
extra_headers=extra_headers(),
raw_manifest=optional(binary()),
)
d = draw(
one_of(
# None author/committer/date/committer_date
builds(
dict,
author=none(),
committer=none(),
date=none(),
committer_date=none(),
**{**defaults, **kwargs},
),
# non-None author/committer/date/committer_date
builds(
dict,
author=persons_d(),
committer=persons_d(),
date=timestamps_with_timezone_d(),
committer_date=timestamps_with_timezone_d(),
**{**defaults, **kwargs},
),
# There are many other combinations, but let's not overwhelm hypothesis
# with these edge cases
)
)
# TODO: metadata['extra_headers'] can have binary keys and values
if d["raw_manifest"] is None:
del d["raw_manifest"]
return d
def revisions(**kwargs):
return revisions_d(**kwargs).map(Revision.from_dict)
def directory_entries_d(**kwargs):
defaults = dict(
name=binaries_without_bytes(b"/"),
target=sha1_git(),
)
return one_of(
builds(
dict,
type=just("file"),
perms=one_of(
integers(min_value=0o100000, max_value=0o100777), # regular file
integers(min_value=0o120000, max_value=0o120777), # symlink
),
**{**defaults, **kwargs},
),
builds(
dict,
type=just("dir"),
perms=integers(
min_value=DentryPerms.directory,
max_value=DentryPerms.directory + 0o777,
),
**{**defaults, **kwargs},
),
builds(
dict,
type=just("rev"),
perms=integers(
min_value=DentryPerms.revision,
max_value=DentryPerms.revision + 0o777,
),
**{**defaults, **kwargs},
),
)
def directory_entries(**kwargs):
return directory_entries_d(**kwargs).map(DirectoryEntry)
@composite
def directories_d(draw, raw_manifest=optional(binary())):
d = draw(builds(dict, entries=tuples(directory_entries_d())))
d["raw_manifest"] = draw(raw_manifest)
if d["raw_manifest"] is None:
del d["raw_manifest"]
return d
def directories(**kwargs):
return directories_d(**kwargs).map(Directory.from_dict)
def contents_d():
return one_of(present_contents_d(), skipped_contents_d())
def contents():
return one_of(present_contents(), skipped_contents())
def present_contents_d(**kwargs):
defaults = dict(
data=binary(max_size=4096),
ctime=optional(aware_datetimes()),
status=one_of(just("visible"), just("hidden")),
)
return builds(dict, **{**defaults, **kwargs})
def present_contents(**kwargs):
return present_contents_d().map(lambda d: Content.from_data(**d))
@composite
def skipped_contents_d(
draw, reason=pgsql_text(), status=just("absent"), ctime=optional(aware_datetimes())
):
result = BaseContent._hash_data(draw(binary(max_size=4096)))
result.pop("data")
nullify_attrs = draw(
sets(sampled_from(["sha1", "sha1_git", "sha256", "blake2s256"]))
)
for k in nullify_attrs:
result[k] = None
result["reason"] = draw(reason)
result["status"] = draw(status)
result["ctime"] = draw(ctime)
return result
def skipped_contents(**kwargs):
return skipped_contents_d().map(SkippedContent.from_dict)
def branch_names():
return binary(min_size=1)
def snapshot_targets_object_d():
return builds(
dict,
target=sha1_git(),
target_type=sampled_from(
[x.value for x in SnapshotTargetType if x.value not in ("alias",)]
),
)
branch_targets_object_d = deprecated(
version="v6.13.0", reason="use snapshot_targets_object_d"
)(snapshot_targets_object_d)
def snapshot_targets_alias_d():
return builds(
dict, target=sha1_git(), target_type=just("alias")
) # SnapshotTargetType.ALIAS.value))
branch_targets_alias_d = deprecated(
version="v6.13.0", reason="use snapshot_targets_alias_d"
)(snapshot_targets_alias_d)
def snapshot_targets_d(*, only_objects=False):
if only_objects:
return snapshot_targets_object_d()
else:
return one_of(snapshot_targets_alias_d(), snapshot_targets_object_d())
branch_targets_d = deprecated(version="v6.13.0", reason="use snapshot_targets_d")(
snapshot_targets_d
)
def snapshot_targets(*, only_objects=False):
return builds(
SnapshotBranch.from_dict, snapshot_targets_d(only_objects=only_objects)
)
@composite
def snapshots_d(draw, *, min_size=0, max_size=100, only_objects=False):
branches = draw(
dictionaries(
keys=branch_names(),
values=optional(snapshot_targets_d(only_objects=only_objects)),
min_size=min_size,
max_size=max_size,
)
)
if not only_objects:
# Make sure aliases point to actual branches
unresolved_aliases = {
branch: target["target"]
for branch, target in branches.items()
if (
target
and target["target_type"] == "alias"
and target["target"] not in branches
)
}
for alias_name, alias_target in unresolved_aliases.items():
# Override alias branch with one pointing to a real object
# if max_size constraint is reached
alias = alias_target if len(branches) < max_size else alias_name
branches[alias] = draw(snapshot_targets_d(only_objects=True))
# Ensure no cycles between aliases
while True:
try:
snapshot = Snapshot.from_dict(
{
"branches": {
name: branch or None for (name, branch) in branches.items()
}
}
)
except ValueError as e:
for source, target in e.args[1]:
branches[source] = draw(snapshot_targets_d(only_objects=True))
else:
break
return snapshot.to_dict()
def snapshots(*, min_size=0, max_size=100, only_objects=False):
return snapshots_d(
min_size=min_size, max_size=max_size, only_objects=only_objects
).map(Snapshot.from_dict)
def metadata_authorities(url=iris()):
return builds(MetadataAuthority, url=url, metadata=just(None))
def metadata_fetchers(**kwargs):
defaults = dict(
name=text(min_size=1, alphabet=string.printable),
version=text(
min_size=1,
alphabet=string.ascii_letters + string.digits + string.punctuation,
),
)
return builds(
MetadataFetcher,
metadata=just(None),
**{**defaults, **kwargs},
)
def raw_extrinsic_metadata(**kwargs):
defaults = dict(
target=extended_swhids(),
discovery_date=aware_datetimes(),
authority=metadata_authorities(),
fetcher=metadata_fetchers(),
format=text(min_size=1, alphabet=string.printable),
)
return builds(RawExtrinsicMetadata, **{**defaults, **kwargs})
def raw_extrinsic_metadata_d(**kwargs):
return raw_extrinsic_metadata(**kwargs).map(RawExtrinsicMetadata.to_dict)
def _tuplify(object_type: ModelObjectType, obj: BaseModel):
return (object_type, obj)
def objects(
# remove the Union once deprecated usage have been migrated
    blacklist_types: Union[Set[ModelObjectType], Any] = {
ModelObjectType.ORIGIN_VISIT_STATUS,
},
split_content: bool = False,
):
"""generates a random couple (type, obj)
which obj is an instance of the Model class corresponding to obj_type.
`blacklist_types` is a list of obj_type to exclude from the strategy.
If `split_content` is True, generates Content and SkippedContent under different
obj_type, resp. "content" and "skipped_content".
"""
strategies: List[
Tuple[ModelObjectType, Callable[[], SearchStrategy[BaseModel]]]
] = [
(ModelObjectType.ORIGIN, origins),
(ModelObjectType.ORIGIN_VISIT, origin_visits),
(ModelObjectType.ORIGIN_VISIT_STATUS, origin_visit_statuses),
(ModelObjectType.SNAPSHOT, snapshots),
(ModelObjectType.RELEASE, releases),
(ModelObjectType.REVISION, revisions),
(ModelObjectType.DIRECTORY, directories),
(ModelObjectType.RAW_EXTRINSIC_METADATA, raw_extrinsic_metadata),
]
if split_content:
strategies.append((ModelObjectType.CONTENT, present_contents))
strategies.append((ModelObjectType.SKIPPED_CONTENT, skipped_contents))
else:
strategies.append((ModelObjectType.CONTENT, contents))
candidates = [
obj_gen().map(functools.partial(_tuplify, obj_type))
for (obj_type, obj_gen) in strategies
if obj_type not in blacklist_types
]
return one_of(*candidates)
def object_dicts(
blacklist_types=(ModelObjectType.ORIGIN_VISIT_STATUS,), split_content=False
):
"""generates a random couple (type, dict)
which dict is suitable for <ModelForType>.from_dict() factory methods.
`blacklist_types` is a list of obj_type to exclude from the strategy.
If `split_content` is True, generates Content and SkippedContent under different
obj_type, resp. "content" and "skipped_content".
"""
strategies = [
(ModelObjectType.ORIGIN, origins_d),
(ModelObjectType.ORIGIN_VISIT, origin_visits_d),
(ModelObjectType.ORIGIN_VISIT_STATUS, origin_visit_statuses_d),
(ModelObjectType.SNAPSHOT, snapshots_d),
(ModelObjectType.RELEASE, releases_d),
(ModelObjectType.REVISION, revisions_d),
(ModelObjectType.DIRECTORY, directories_d),
(ModelObjectType.RAW_EXTRINSIC_METADATA, raw_extrinsic_metadata_d),
]
if split_content:
strategies.append((ModelObjectType.CONTENT, present_contents_d))
strategies.append((ModelObjectType.SKIPPED_CONTENT, skipped_contents_d))
else:
strategies.append((ModelObjectType.CONTENT, contents_d))
args = [
obj_gen().map(lambda x, obj_type=obj_type: (obj_type, x))
for (obj_type, obj_gen) in strategies
if obj_type not in blacklist_types
]
return one_of(*args)
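
Aside: a hedged sketch of how these strategies are typically consumed in a hypothesis-based test (the dict round-trip relies on the to_dict/from_dict methods of the model classes imported above):

    from hypothesis import given

    from swh.model.hypothesis_strategies import objects, snapshots

    @given(objects())
    def test_object_dict_roundtrip(type_and_object):
        (obj_type, obj) = type_and_object
        # Every generated model object should survive a dict round-trip:
        assert obj == type(obj).from_dict(obj.to_dict())

    @given(snapshots(min_size=1))
    def test_snapshots_have_branches(snapshot):
        assert len(snapshot.branches) >= 1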
swh/model/py.typed (new file)
# Marker file for PEP 561.
swh/model/tests/fields/test_simple.py
@@ -6,131 +6,120 @@
 import datetime
 import unittest
 
-from nose.tools import istest
-
 from swh.model.exceptions import ValidationError
 from swh.model.fields import simple
 
 
 class ValidateSimple(unittest.TestCase):
     def setUp(self):
-        self.valid_str = 'I am a valid string'
+        self.valid_str = "I am a valid string"
 
-        self.valid_bytes = b'I am a valid bytes object'
+        self.valid_bytes = b"I am a valid bytes object"
 
-        self.enum_values = {'an enum value', 'other', 'and another'}
-        self.invalid_enum_value = 'invalid enum value'
+        self.enum_values = {"an enum value", "other", "and another"}
+        self.invalid_enum_value = "invalid enum value"
 
         self.valid_int = 42
 
         self.valid_real = 42.42
 
-        self.valid_datetime = datetime.datetime(1999, 1, 1, 12, 0, 0,
-                                                tzinfo=datetime.timezone.utc)
+        self.valid_datetime = datetime.datetime(
+            1999, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc
+        )
         self.invalid_datetime_notz = datetime.datetime(1999, 1, 1, 12, 0, 0)
 
-    @istest
-    def validate_int(self):
+    def test_validate_int(self):
         self.assertTrue(simple.validate_int(self.valid_int))
 
-    @istest
-    def validate_int_invalid_type(self):
+    def test_validate_int_invalid_type(self):
         with self.assertRaises(ValidationError) as cm:
             simple.validate_int(self.valid_str)
 
         exc = cm.exception
         self.assertIsInstance(str(exc), str)
-        self.assertEqual(exc.code, 'unexpected-type')
-        self.assertEqual(exc.params['expected_type'], 'Integral')
-        self.assertEqual(exc.params['type'], 'str')
+        self.assertEqual(exc.code, "unexpected-type")
+        self.assertEqual(exc.params["expected_type"], "Integral")
+        self.assertEqual(exc.params["type"], "str")
 
-    @istest
-    def validate_str(self):
+    def test_validate_str(self):
         self.assertTrue(simple.validate_str(self.valid_str))
 
-    @istest
-    def validate_str_invalid_type(self):
+    def test_validate_str_invalid_type(self):
         with self.assertRaises(ValidationError) as cm:
             simple.validate_str(self.valid_int)
 
         exc = cm.exception
         self.assertIsInstance(str(exc), str)
-        self.assertEqual(exc.code, 'unexpected-type')
-        self.assertEqual(exc.params['expected_type'], 'str')
-        self.assertEqual(exc.params['type'], 'int')
+        self.assertEqual(exc.code, "unexpected-type")
+        self.assertEqual(exc.params["expected_type"], "str")
+        self.assertEqual(exc.params["type"], "int")
 
         with self.assertRaises(ValidationError) as cm:
             simple.validate_str(self.valid_bytes)
 
         exc = cm.exception
         self.assertIsInstance(str(exc), str)
-        self.assertEqual(exc.code, 'unexpected-type')
-        self.assertEqual(exc.params['expected_type'], 'str')
-        self.assertEqual(exc.params['type'], 'bytes')
+        self.assertEqual(exc.code, "unexpected-type")
+        self.assertEqual(exc.params["expected_type"], "str")
+        self.assertEqual(exc.params["type"], "bytes")
 
-    @istest
-    def validate_bytes(self):
+    def test_validate_bytes(self):
         self.assertTrue(simple.validate_bytes(self.valid_bytes))
 
-    @istest
-    def validate_bytes_invalid_type(self):
+    def test_validate_bytes_invalid_type(self):
         with self.assertRaises(ValidationError) as cm:
             simple.validate_bytes(self.valid_int)
 
         exc = cm.exception
         self.assertIsInstance(str(exc), str)
-        self.assertEqual(exc.code, 'unexpected-type')
-        self.assertEqual(exc.params['expected_type'], 'bytes')
-        self.assertEqual(exc.params['type'], 'int')
+        self.assertEqual(exc.code, "unexpected-type")
+        self.assertEqual(exc.params["expected_type"], "bytes")
+        self.assertEqual(exc.params["type"], "int")
 
         with self.assertRaises(ValidationError) as cm:
             simple.validate_bytes(self.valid_str)
 
         exc = cm.exception
         self.assertIsInstance(str(exc), str)
-        self.assertEqual(exc.code, 'unexpected-type')
-        self.assertEqual(exc.params['expected_type'], 'bytes')
-        self.assertEqual(exc.params['type'], 'str')
+        self.assertEqual(exc.code, "unexpected-type")
+        self.assertEqual(exc.params["expected_type"], "bytes")
+        self.assertEqual(exc.params["type"], "str")
 
-    @istest
-    def validate_datetime(self):
+    def test_validate_datetime(self):
         self.assertTrue(simple.validate_datetime(self.valid_datetime))
         self.assertTrue(simple.validate_datetime(self.valid_int))
         self.assertTrue(simple.validate_datetime(self.valid_real))
 
-    @istest
-    def validate_datetime_invalid_type(self):
+    def test_validate_datetime_invalid_type(self):
        with self.assertRaises(ValidationError) as cm:
            simple.validate_datetime(self.valid_str)
 
         exc = cm.exception
         self.assertIsInstance(str(exc), str)
-        self.assertEqual(exc.code, 'unexpected-type')
-        self.assertEqual(exc.params['expected_type'], 'one of datetime, Real')
-        self.assertEqual(exc.params['type'], 'str')
+        self.assertEqual(exc.code, "unexpected-type")
+        self.assertEqual(exc.params["expected_type"], "one of datetime, Real")
+        self.assertEqual(exc.params["type"], "str")
 
-    @istest
-    def validate_datetime_invalide_tz(self):
+    def test_validate_datetime_invalide_tz(self):
         with self.assertRaises(ValidationError) as cm:
             simple.validate_datetime(self.invalid_datetime_notz)
 
         exc = cm.exception
         self.assertIsInstance(str(exc), str)
-        self.assertEqual(exc.code, 'datetime-without-tzinfo')
+        self.assertEqual(exc.code, "datetime-without-tzinfo")
 
-    @istest
-    def validate_enum(self):
+    def test_validate_enum(self):
         for value in self.enum_values:
             self.assertTrue(simple.validate_enum(value, self.enum_values))
 
-    @istest
-    def validate_enum_invalid_value(self):
+    def test_validate_enum_invalid_value(self):
         with self.assertRaises(ValidationError) as cm:
             simple.validate_enum(self.invalid_enum_value, self.enum_values)
 
         exc = cm.exception
         self.assertIsInstance(str(exc), str)
-        self.assertEqual(exc.code, 'unexpected-value')
-        self.assertEqual(exc.params['value'], self.invalid_enum_value)
-        self.assertEqual(exc.params['expected_values'],
-                         ', '.join(sorted(self.enum_values)))
+        self.assertEqual(exc.code, "unexpected-value")
+        self.assertEqual(exc.params["value"], self.invalid_enum_value)
+        self.assertEqual(
+            exc.params["expected_values"], ", ".join(sorted(self.enum_values))
+        )