Verified Commit 0e71ebfa authored by Antoine R. Dumont

hashutil: Add MultiHash to compute hashes of content in 1 roundtrip

This is the first step to improve the hashutil module according to
"the" plan [1].

In this regard, hashutil exposes the same functions as before.
Internally though, they now use a MultiHash instance.

Related D410

[1] D410#7952
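
A minimal sketch of the single-pass idea, assuming the module is importable as swh.model.hashutil (the import path is not shown in this diff):

from swh.model.hashutil import MultiHash, DEFAULT_ALGORITHMS

data = b'some content'
h = MultiHash(DEFAULT_ALGORITHMS, length=len(data))  # length is needed for the git-specific hashes
h.update(data)          # a single pass feeds every configured hash
digests = h.digest()    # dict mapping each algorithm name to a bytes digest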
parent 7f885ed5
# Copyright (C) 2015-2017 The Software Heritage developers
# Copyright (C) 2015-2018 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
@@ -12,6 +12,10 @@ in a ValueError explaining the error.
This modules defines the following hashing functions:
- hash_stream: Hash the contents of something iterable (file, stream,
...) with the given algorithms (defaulting to DEFAULT_ALGORITHMS if
none provided).
- hash_file: Hash the contents of the given file object with the given
algorithms (defaulting to DEFAULT_ALGORITHMS if none provided).
@@ -45,6 +49,73 @@ HASH_BLOCK_SIZE = 32768
_blake2_hash_cache = {}
HASH_FORMATS = set(['bytes', 'bytehex', 'hex'])
"""Supported output hash formats
"""
class MultiHash:
"""Hashutil class to support multiple hashes computation.
Args:
hash_names (set): Set of hash algorithm names (plus optionally
'length') to compute (cf. DEFAULT_ALGORITHMS)
length (int): Total length of the content to hash, in bytes
(needed by the git-specific algorithms)
If 'length' is passed as a hash name, the total length read is
also computed and returned in the result.
"""
def __init__(self, hash_names, length=None):
self.state = {}
self.track_length = False
for name in hash_names:
if name == 'length':
self.state['length'] = 0
self.track_length = True
else:
self.state[name] = _new_hash(name, length)
@classmethod
def from_state(cls, state, track_length):
ret = cls([])
ret.state = state
ret.track_length = track_length
return ret
def update(self, chunk):
for name, h in self.state.items():
if name == 'length':
continue
h.update(chunk)
if self.track_length:
self.state['length'] += len(chunk)
def digest(self):
return {
name: h.digest() if name != 'length' else h
for name, h in self.state.items()
}
def hexdigest(self):
return {
name: h.hexdigest() if name != 'length' else h
for name, h in self.state.items()
}
def bytehexdigest(self):
return {
name: hash_to_bytehex(h.digest()) if name != 'length' else h
for name, h in self.state.items()
}
def copy(self):
copied_state = {
name: h.copy() if name != 'length' else h
for name, h in self.state.items()
}
return self.from_state(copied_state, self.track_length)
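
A short usage sketch of the class above, exercising the 'length' pseudo-algorithm and copy() (hypothetical data, assuming the swh.model.hashutil import path):

from swh.model.hashutil import MultiHash

h = MultiHash({'sha1', 'length'})
for chunk in (b'foo', b'bar'):
    h.update(chunk)
snapshot = h.copy()    # independent copy of the running state
print(h.digest())      # {'sha1': b'...', 'length': 6}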
def _new_blake2_hash(algo):
"""Return a function that initializes a blake2 hash.
@@ -162,128 +233,135 @@ def _new_hash(algo, length=None):
return _new_hashlib_hash(algo)
def hash_file(fobj, length=None, algorithms=DEFAULT_ALGORITHMS,
chunk_cb=None, with_length=False, hexdigest=False):
"""Hash the contents of the given file object with the given algorithms.
def _read(fobj):
"""Wrapper function around reading a chunk from fobj.
"""
return fobj.read(HASH_BLOCK_SIZE)
def hash_stream(s, readfn=_read, length=None, algorithms=DEFAULT_ALGORITHMS,
chunk_cb=None, hash_format='bytes'):
"""Hash the contents of a stream
Args:
fobj: a file-like object
length: the length of the contents of the file-like object (for the
git-specific algorithms)
algorithms: the hashing algorithms to be used, as an iterable over
strings
with_length (bool): Include length in the dict result
hexdigest (bool): False returns the hash as binary, otherwise
returns as hex
s: stream or object to consume by successive calls to `readfn`
readfn (fn): Function to read chunk data from s
length (int): the length of the contents of the object (for the
git-specific algorithms)
algorithms (set): the hashing algorithms to be used, as an
iterable over strings
hash_format (str): Format required for the output of the
computed hashes (cf. HASH_FORMATS)
Returns: a dict mapping each algorithm to a digest (bytes by default).
Raises:
ValueError if algorithms contains an unknown hash algorithm.
ValueError if:
algorithms contains an unknown hash algorithm.
hash_format is an unknown hash format
"""
hashes = {algo: _new_hash(algo, length) for algo in algorithms}
if hash_format not in HASH_FORMATS:
raise ValueError('Unexpected hash format %s, expected one of %s' % (
hash_format, HASH_FORMATS))
h = MultiHash(algorithms, length)
while True:
chunk = fobj.read(HASH_BLOCK_SIZE)
chunk = readfn(s)
if not chunk:
break
for hash in hashes.values():
hash.update(chunk)
h.update(chunk)
if chunk_cb:
chunk_cb(chunk)
if hexdigest:
h = {algo: hash.hexdigest() for algo, hash in hashes.items()}
else:
h = {algo: hash.digest() for algo, hash in hashes.items()}
if with_length:
h['length'] = length
return h
if hash_format == 'bytes':
return h.digest()
if hash_format == 'bytehex':
return h.bytehexdigest()
return h.hexdigest()
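
For illustration, a sketch of hash_stream with a custom readfn over an iterator of chunks (the chunk values are hypothetical):

from swh.model.hashutil import hash_stream

data = b'chunked content'
chunks = iter([data[:7], data[7:]])   # e.g. what requests' iter_content() yields

def readfn(it):
    return next(it, None)    # hash_stream stops on a falsy chunk

hexdigests = hash_stream(chunks, readfn=readfn,
                         length=len(data), hash_format='hex')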
def hash_stream(s, length=None, algorithms=DEFAULT_ALGORITHMS,
chunk_cb=None, with_length=False, hexdigest=False):
"""Hash the contents of the given stream with the given algorithms.
def hash_file(fobj, length=None, algorithms=DEFAULT_ALGORITHMS,
chunk_cb=None, hash_format='bytes'):
"""Hash the contents of the given file object with the given algorithms.
Args:
s (stream): a stream object (e.g requests.get(stream=True))
length (int): the length of the contents of the stream (for the
git-specific algorithms)
algorithms (dict): the hashing algorithms to be used, as an
iterable over strings
with_length (bool): Include length in the dict result
hexdigest (bool): False returns the hash as binary, otherwise
returns as hex
fobj: a file-like object
length: the length of the contents of the file-like object (for the
git-specific algorithms)
algorithms: the hashing algorithms to be used, as an iterable over
strings
hash_format (str): Format required for the output of the
computed hashes (cf. HASH_FORMATS)
Returns: a dict mapping each algorithm to a digest (bytes by default).
Raises:
ValueError if algorithms contains an unknown hash algorithm.
ValueError if:
"""
hashes = {algo: _new_hash(algo, length) for algo in algorithms}
algorithms contains an unknown hash algorithm.
hash_format is an unknown hash format
for chunk in s.iter_content():
if not chunk:
break
for hash in hashes.values():
hash.update(chunk)
if chunk_cb:
chunk_cb(chunk)
if hexdigest:
h = {algo: hash.hexdigest() for algo, hash in hashes.items()}
else:
h = {algo: hash.digest() for algo, hash in hashes.items()}
if with_length:
h['length'] = length
return h
"""
return hash_stream(fobj, length=length, algorithms=algorithms,
chunk_cb=chunk_cb, hash_format=hash_format)
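
A sketch of hash_file with a chunk callback (the callback is hypothetical, assuming the same import path):

import io
from swh.model.hashutil import hash_file

data = b'file-like content'
seen = []
digests = hash_file(io.BytesIO(data), length=len(data),
                    chunk_cb=seen.append)    # invoked once per chunk read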
def hash_path(path, algorithms=DEFAULT_ALGORITHMS, chunk_cb=None,
with_length=True, hexdigest=False):
hash_format='bytes', track_length=True):
"""Hash the contents of the file at the given path with the given
algorithms.
Args:
path: the path of the file to hash
algorithms: the hashing algorithms used
chunk_cb: a callback
with_length (bool): Include length in the dict result
hexdigest (bool): False returns the hash as binary, otherwise
returns as hex
path (str): the path of the file to hash
algorithms (set): the hashing algorithms used
chunk_cb (callable): a callback, called on each chunk read
hash_format (str): Format required for the output of the
computed hashes (cf. HASH_FORMATS)
Returns: a dict mapping each algorithm to a bytes digest.
Raises:
ValueError if algorithms contains an unknown hash algorithm.
ValueError if:
algorithms contains an unknown hash algorithm.
hash_format is an unknown hash format
OSError on file access error
"""
if track_length:
algorithms = set(['length']).union(algorithms)
length = os.path.getsize(path)
with open(path, 'rb') as fobj:
return hash_file(fobj, length, algorithms, chunk_cb=chunk_cb,
with_length=with_length, hexdigest=hexdigest)
hash_format=hash_format)
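
A brief usage sketch of hash_path (hypothetical path, assuming the same import path):

from swh.model.hashutil import hash_path

# track_length=True by default, so 'length' appears alongside the digests
digests = hash_path('/tmp/some-file')    # hypothetical path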
def hash_data(data, algorithms=DEFAULT_ALGORITHMS, with_length=False):
def hash_data(data, algorithms=DEFAULT_ALGORITHMS, hash_format='bytes'):
"""Hash the given binary blob with the given algorithms.
Args:
data (bytes): raw content to hash
algorithms (list): the hashing algorithms used
with_length (bool): add the length key in the resulting dict
hash_format (str): Format required for the output of the
computed hashes (cf. HASH_FORMATS)
Returns: a dict mapping each algorithm to a bytes digest
Raises:
TypeError if data does not support the buffer interface.
ValueError if algorithms contains an unknown hash algorithm.
ValueError if:
algorithms contains an unknown hash algorithm.
hash_format is an unknown hash format
"""
fobj = BytesIO(data)
length = len(data)
return hash_file(fobj, length, algorithms, with_length=with_length)
return hash_file(fobj, length, algorithms, hash_format=hash_format)
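
And a sketch of hash_data with the new hash_format parameter (values per HASH_FORMATS above):

from swh.model.hashutil import hash_data

raw = hash_data(b'hello')                          # bytes digests (default)
hexed = hash_data(b'hello', hash_format='hex')     # hex string digests
bytehexed = hash_data(b'hello', hash_format='bytehex')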
def hash_git_data(data, git_type, base_algo='sha1'):
......
# Copyright (C) 2015-2017 The Software Heritage developers
# Copyright (C) 2015-2018 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
@@ -35,6 +35,11 @@ class Hashutil(unittest.TestCase):
for type, cksum in self.hex_checksums.items()
}
self.bytehex_checksums = {
type: hashutil.hash_to_bytehex(cksum)
for type, cksum in self.checksums.items()
}
self.git_hex_checksums = {
'blob': self.hex_checksums['sha1_git'],
'tree': '5b2e883aa33d2efab98442693ea4dd5f1b8871b0',
@@ -58,7 +63,8 @@ class Hashutil(unittest.TestCase):
expected_checksums = self.checksums.copy()
expected_checksums['length'] = len(self.data)
checksums = hashutil.hash_data(self.data, with_length=True)
algos = set(['length']).union(hashutil.DEFAULT_ALGORITHMS)
checksums = hashutil.hash_data(self.data, algorithms=algos)
self.assertEqual(checksums, expected_checksums)
self.assertTrue('length' in checksums)
@@ -71,6 +77,16 @@ class Hashutil(unittest.TestCase):
self.assertIn('Unexpected hashing algorithm', cm.exception.args[0])
self.assertIn('unknown-hash', cm.exception.args[0])
@istest
def hash_data_unknown_hash_format(self):
with self.assertRaises(ValueError) as cm:
hashutil.hash_data(
self.data, hashutil.DEFAULT_ALGORITHMS,
hash_format='unknown-format')
self.assertIn('Unexpected hash format', cm.exception.args[0])
self.assertIn('unknown-format', cm.exception.args[0])
@istest
def hash_git_data(self):
checksums = {
@@ -98,10 +114,17 @@ class Hashutil(unittest.TestCase):
@istest
def hash_file_hexdigest(self):
fobj = io.BytesIO(self.data)
checksums = hashutil.hash_file(fobj, length=len(self.data),
hexdigest=True)
checksums = hashutil.hash_file(
fobj, length=len(self.data), hash_format='hex')
self.assertEqual(checksums, self.hex_checksums)
@istest
def hash_file_bytehexdigest(self):
fobj = io.BytesIO(self.data)
checksums = hashutil.hash_file(
fobj, length=len(self.data), hash_format='bytehex')
self.assertEqual(checksums, self.bytehex_checksums)
@istest
def hash_stream(self):
class StreamStub:
@@ -111,9 +134,16 @@ class Hashutil(unittest.TestCase):
def iter_content(self):
yield from io.BytesIO(self.data)
s = StreamStub(self.data)
checksums = hashutil.hash_stream(s, length=len(self.data),
hexdigest=True)
s = StreamStub(self.data).iter_content()
def _readfn(s):
try:
return next(s)
except StopIteration:
return None
checksums = hashutil.hash_stream(
s, readfn=_readfn, length=len(self.data), hash_format='hex')
self.assertEqual(checksums, self.hex_checksums)
@istest
......