From 0e71ebfa4f87413567e82fafbc4eeb15e530c48c Mon Sep 17 00:00:00 2001
From: "Antoine R. Dumont (@ardumont)" <antoine.romain.dumont@gmail.com>
Date: Fri, 14 Sep 2018 20:04:27 +0200
Subject: [PATCH] hashutil: Add MultiHash to compute hashes of content in 1
 roundtrip

This is the first step to improve the hashutil module according to
"the" plan [1].

In this regard, hashutil exposes the same functions as before.
Internally though, they now use a MultiHash instance.

Related D410

[1] D410#7952
---
 swh/model/hashutil.py            | 210 +++++++++++++++++++++----------
 swh/model/tests/test_hashutil.py |  44 +++++--
 2 files changed, 181 insertions(+), 73 deletions(-)

diff --git a/swh/model/hashutil.py b/swh/model/hashutil.py
index a1556038..bda5389c 100644
--- a/swh/model/hashutil.py
+++ b/swh/model/hashutil.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2017  The Software Heritage developers
+# Copyright (C) 2015-2018  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -12,6 +12,10 @@ in a ValueError explaining the error.
 
 This modules defines the following hashing functions:
 
+- hash_stream: Hash the contents of something iterable (file, stream,
+  ...) with the given algorithms (defaulting to DEFAULT_ALGORITHMS if
+  none provided).
+
 - hash_file: Hash the contents of the given file object with the given
   algorithms (defaulting to DEFAULT_ALGORITHMS if none provided).
 
@@ -45,6 +49,73 @@ HASH_BLOCK_SIZE = 32768
 
 _blake2_hash_cache = {}
 
+HASH_FORMATS = set(['bytes', 'bytehex', 'hex'])
+"""Supported output hash formats
+"""
+
+
+class MultiHash:
+    """Hashutil class to support multiple hashes computation.
+
+    Args:
+
+        hash_names (set): Set of hash algorithms (+ length) to compute
+                          hashes (cf. DEFAULT_ALGORITHMS)
+        length (int): Length of the total sum of chunks to read
+
+    If 'length' is included in hash_names, the total length of the
+    data read is also computed and returned.
+
+    """
+    def __init__(self, hash_names, length=None):
+        self.state = {}
+        self.track_length = False
+        for name in hash_names:
+            if name == 'length':
+                self.state['length'] = 0
+                self.track_length = True
+            else:
+                self.state[name] = _new_hash(name, length)
+
+    @classmethod
+    def from_state(cls, state, track_length):
+        ret = cls([])
+        ret.state = state
+        ret.track_length = track_length
+        return ret
+    def update(self, chunk):
+        for name, h in self.state.items():
+            if name == 'length':
+                continue
+            h.update(chunk)
+        if self.track_length:
+            self.state['length'] += len(chunk)
+
+    def digest(self):
+        return {
+            name: h.digest() if name != 'length' else h
+            for name, h in self.state.items()
+        }
+
+    def hexdigest(self):
+        return {
+            name: h.hexdigest() if name != 'length' else h
+            for name, h in self.state.items()
+        }
+
+    def bytehexdigest(self):
+        return {
+            name: hash_to_bytehex(h.digest()) if name != 'length' else h
+            for name, h in self.state.items()
+        }
+
+    def copy(self):
+        copied_state = {
+            name: h.copy() if name != 'length' else h
+            for name, h in self.state.items()
+        }
+        return self.from_state(copied_state, self.track_length)
+
 
 def _new_blake2_hash(algo):
     """Return a function that initializes a blake2 hash.
@@ -162,128 +233,135 @@ def _new_hash(algo, length=None):
     return _new_hashlib_hash(algo)
 
 
-def hash_file(fobj, length=None, algorithms=DEFAULT_ALGORITHMS,
-              chunk_cb=None, with_length=False, hexdigest=False):
-    """Hash the contents of the given file object with the given algorithms.
+def _read(fobj):
+    """Wrapper function around reading a chunk from fobj.
+
+    """
+    return fobj.read(HASH_BLOCK_SIZE)
+
+
+def hash_stream(s, readfn=_read, length=None, algorithms=DEFAULT_ALGORITHMS,
+                chunk_cb=None, hash_format='bytes'):
+    """Hash the contents of a stream
 
     Args:
-        fobj: a file-like object
-        length: the length of the contents of the file-like object (for the
-          git-specific algorithms)
-        algorithms: the hashing algorithms to be used, as an iterable over
-          strings
-        with_length (bool): Include length in the dict result
-        hexdigest (bool): False returns the hash as binary, otherwise
-                          returns as hex
+        s: stream or object we can consume with successive calls to `readfn`
+        readfn (fn): Function to read chunk data from s
+        length (int): the length of the contents of the object (for the
+                      git-specific algorithms)
+        algorithms (set): the hashing algorithms to be used, as an
+                          iterable over strings
+        hash_format (str): Format required for the output of the
+                           computed hashes (cf. HASH_FORMATS)
 
     Returns: a dict mapping each algorithm to a digest (bytes by default).
 
     Raises:
-        ValueError if algorithms contains an unknown hash algorithm.
+        ValueError if:
+
+            algorithms contains an unknown hash algorithm.
+            hash_format is an unknown hash format
 
     """
-    hashes = {algo: _new_hash(algo, length) for algo in algorithms}
+    if hash_format not in HASH_FORMATS:
+        raise ValueError('Unexpected hash format %s, expected one of %s' % (
+            hash_format, HASH_FORMATS))
 
+    h = MultiHash(algorithms, length)
     while True:
-        chunk = fobj.read(HASH_BLOCK_SIZE)
+        chunk = readfn(s)
         if not chunk:
             break
-        for hash in hashes.values():
-            hash.update(chunk)
+        h.update(chunk)
         if chunk_cb:
             chunk_cb(chunk)
 
-    if hexdigest:
-        h = {algo: hash.hexdigest() for algo, hash in hashes.items()}
-    else:
-        h = {algo: hash.digest() for algo, hash in hashes.items()}
-    if with_length:
-        h['length'] = length
-    return h
+    if hash_format == 'bytes':
+        return h.digest()
+    if hash_format == 'bytehex':
+        return h.bytehexdigest()
+    return h.hexdigest()
 
 
-def hash_stream(s, length=None, algorithms=DEFAULT_ALGORITHMS,
-                chunk_cb=None, with_length=False, hexdigest=False):
-    """Hash the contents of the given stream with the given algorithms.
+def hash_file(fobj, length=None, algorithms=DEFAULT_ALGORITHMS,
+              chunk_cb=None, hash_format='bytes'):
+    """Hash the contents of the given file object with the given algorithms.
 
     Args:
-        s (stream): a stream object (e.g requests.get(stream=True))
-        length (int): the length of the contents of the stream (for the
-                      git-specific algorithms)
-        algorithms (dict): the hashing algorithms to be used, as an
-                           iterable over strings
-        with_length (bool): Include length in the dict result
-        hexdigest (bool): False returns the hash as binary, otherwise
-                          returns as hex
+        fobj: a file-like object
+        length: the length of the contents of the file-like object (for the
+          git-specific algorithms)
+        algorithms: the hashing algorithms to be used, as an iterable over
+          strings
+        hash_format (str): Format required for the output of the
+                           computed hashes (cf. HASH_FORMATS)
 
     Returns: a dict mapping each algorithm to a digest (bytes by default).
 
     Raises:
-        ValueError if algorithms contains an unknown hash algorithm.
+        ValueError if:
 
-    """
-    hashes = {algo: _new_hash(algo, length) for algo in algorithms}
+            algorithms contains an unknown hash algorithm.
+            hash_format is an unknown hash format
 
-    for chunk in s.iter_content():
-        if not chunk:
-            break
-        for hash in hashes.values():
-            hash.update(chunk)
-        if chunk_cb:
-            chunk_cb(chunk)
-
-    if hexdigest:
-        h = {algo: hash.hexdigest() for algo, hash in hashes.items()}
-    else:
-        h = {algo: hash.digest() for algo, hash in hashes.items()}
-    if with_length:
-        h['length'] = length
-    return h
+    """
+    return hash_stream(fobj, length=length, algorithms=algorithms,
+                       chunk_cb=chunk_cb, hash_format=hash_format)
 
 
 def hash_path(path, algorithms=DEFAULT_ALGORITHMS, chunk_cb=None,
-              with_length=True, hexdigest=False):
+              hash_format='bytes', track_length=True):
     """Hash the contents of the file at the given path with the given
        algorithms.
 
     Args:
-        path: the path of the file to hash
-        algorithms: the hashing algorithms used
-        chunk_cb: a callback
-        with_length (bool): Include length in the dict result
-        hexdigest (bool): False returns the hash as binary, otherwise
-                          returns as hex
+        path (str): the path of the file to hash
+        algorithms (set): the hashing algorithms used
+        chunk_cb (def): a callback
+        hash_format (str): Format required for the output of the
+                           computed hashes (cf. HASH_FORMATS)
 
     Returns: a dict mapping each algorithm to a bytes digest.
 
     Raises:
-        ValueError if algorithms contains an unknown hash algorithm.
+        ValueError if:
+
+            algorithms contains an unknown hash algorithm.
+            hash_format is an unknown hash format
+
         OSError on file access error
 
     """
+    if track_length:
+        algorithms = set(['length']).union(algorithms)
     length = os.path.getsize(path)
     with open(path, 'rb') as fobj:
         return hash_file(fobj, length, algorithms, chunk_cb=chunk_cb,
-                         with_length=with_length, hexdigest=hexdigest)
+                         hash_format=hash_format)
 
 
-def hash_data(data, algorithms=DEFAULT_ALGORITHMS, with_length=False):
+def hash_data(data, algorithms=DEFAULT_ALGORITHMS, hash_format='bytes'):
     """Hash the given binary blob with the given algorithms.
 
     Args:
         data (bytes): raw content to hash
         algorithms (list): the hashing algorithms used
-        with_length (bool): add the length key in the resulting dict
+        hash_format (str): Format required for the output of the
+                           computed hashes (cf. HASH_FORMATS)
 
     Returns: a dict mapping each algorithm to a bytes digest
 
     Raises:
         TypeError if data does not support the buffer interface.
-        ValueError if algorithms contains an unknown hash algorithm.
+        ValueError if:
+
+            algorithms contains an unknown hash algorithm.
+            hash_format is an unknown hash format
+
     """
     fobj = BytesIO(data)
     length = len(data)
-    return hash_file(fobj, length, algorithms, with_length=with_length)
+    return hash_file(fobj, length, algorithms, hash_format=hash_format)
 
 
 def hash_git_data(data, git_type, base_algo='sha1'):
diff --git a/swh/model/tests/test_hashutil.py b/swh/model/tests/test_hashutil.py
index 99bd78e1..4b0efa56 100644
--- a/swh/model/tests/test_hashutil.py
+++ b/swh/model/tests/test_hashutil.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2017  The Software Heritage developers
+# Copyright (C) 2015-2018  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -35,6 +35,11 @@ class Hashutil(unittest.TestCase):
             for type, cksum in self.hex_checksums.items()
         }
 
+        self.bytehex_checksums = {
+            type: hashutil.hash_to_bytehex(cksum)
+            for type, cksum in self.checksums.items()
+        }
+
         self.git_hex_checksums = {
             'blob': self.hex_checksums['sha1_git'],
             'tree': '5b2e883aa33d2efab98442693ea4dd5f1b8871b0',
@@ -58,7 +63,8 @@ class Hashutil(unittest.TestCase):
         expected_checksums = self.checksums.copy()
         expected_checksums['length'] = len(self.data)
 
-        checksums = hashutil.hash_data(self.data, with_length=True)
+        algos = set(['length']).union(hashutil.DEFAULT_ALGORITHMS)
+        checksums = hashutil.hash_data(self.data, algorithms=algos)
 
         self.assertEqual(checksums, expected_checksums)
         self.assertTrue('length' in checksums)
@@ -71,6 +77,16 @@ class Hashutil(unittest.TestCase):
         self.assertIn('Unexpected hashing algorithm', cm.exception.args[0])
         self.assertIn('unknown-hash', cm.exception.args[0])
 
+    @istest
+    def hash_data_unknown_hash_format(self):
+        with self.assertRaises(ValueError) as cm:
+            hashutil.hash_data(
+                self.data, hashutil.DEFAULT_ALGORITHMS,
+                hash_format='unknown-format')
+
+        self.assertIn('Unexpected hash format', cm.exception.args[0])
+        self.assertIn('unknown-format', cm.exception.args[0])
+
     @istest
     def hash_git_data(self):
         checksums = {
@@ -98,10 +114,17 @@ class Hashutil(unittest.TestCase):
     @istest
     def hash_file_hexdigest(self):
         fobj = io.BytesIO(self.data)
-        checksums = hashutil.hash_file(fobj, length=len(self.data),
-                                       hexdigest=True)
+        checksums = hashutil.hash_file(
+            fobj, length=len(self.data), hash_format='hex')
         self.assertEqual(checksums, self.hex_checksums)
 
+    @istest
+    def hash_file_bytehexdigest(self):
+        fobj = io.BytesIO(self.data)
+        checksums = hashutil.hash_file(
+            fobj, length=len(self.data), hash_format='bytehex')
+        self.assertEqual(checksums, self.bytehex_checksums)
+
     @istest
     def hash_stream(self):
         class StreamStub:
@@ -111,9 +134,16 @@ class Hashutil(unittest.TestCase):
             def iter_content(self):
                 yield from io.BytesIO(self.data)
 
-        s = StreamStub(self.data)
-        checksums = hashutil.hash_stream(s, length=len(self.data),
-                                         hexdigest=True)
+        s = StreamStub(self.data).iter_content()
+
+        def _readfn(s):
+            try:
+                return next(s)
+            except StopIteration:
+                return None
+
+        checksums = hashutil.hash_stream(
+            s, readfn=_readfn, length=len(self.data), hash_format='hex')
         self.assertEqual(checksums, self.hex_checksums)
 
     @istest
-- 
GitLab