From 9084f96c9eb1a3befd8cd7c179950ea00dc66178 Mon Sep 17 00:00:00 2001
From: "Antoine R. Dumont (@ardumont)" <antoine.romain.dumont@gmail.com>
Date: Sat, 15 Sep 2018 00:35:43 +0200
Subject: [PATCH] hashutil: Improve MultiHash class from_* to compute hashes

This allows calls like:

- MultiHash.from_file(file_object).digest()
- MultiHash.from_path(b'foo').hexdigest()
- MultiHash.from_data(b'foo').bytehexdigest()
---
 swh/model/hashutil.py            | 29 +++++++++++++
 swh/model/tests/test_hashutil.py | 72 +++++++++++++++++++++++++++++++-
 2 files changed, 100 insertions(+), 1 deletion(-)

diff --git a/swh/model/hashutil.py b/swh/model/hashutil.py
index 69586a82..339b72fa 100644
--- a/swh/model/hashutil.py
+++ b/swh/model/hashutil.py
@@ -49,6 +49,10 @@ HASH_FORMATS = set(['bytes', 'bytehex', 'hex'])
 """Supported output hash formats
 """
 
+EXTRA_LENGTH = set(['length'])
+"""Extra information to compute
+"""
+
 
 class MultiHash:
     """Hashutil class to support multiple hashes computation.
@@ -79,6 +83,31 @@ class MultiHash:
         ret.state = state
         ret.track_length = track_length
 
+    @classmethod
+    def from_file(cls, file, hash_names=DEFAULT_ALGORITHMS, length=None):
+        ret = cls(length=length, hash_names=hash_names)  # new cls instance
+        for chunk in file:  # consumes `file` by iteration, chunk by chunk
+            ret.update(chunk)
+        return ret  # caller picks digest()/hexdigest()/bytehexdigest()
+
+    @classmethod
+    def from_path(cls, path, hash_names=DEFAULT_ALGORITHMS, length=None,
+                  track_length=True):
+        if not length:  # no (or zero) length given: use the on-disk size
+            length = os.path.getsize(path)
+        # For compatibility with `hash_path`; accepts any iterable of names
+        if track_length:
+            hash_names = EXTRA_LENGTH.union(hash_names)
+        with open(path, 'rb') as f:
+            return cls.from_file(f, hash_names=hash_names, length=length)
+
+    @classmethod
+    def from_data(cls, data, hash_names=DEFAULT_ALGORITHMS, length=None):
+        if not length:  # no (or zero) length given: measure the data
+            length = len(data)
+        fobj = BytesIO(data)  # wrap the data so from_file can consume it
+        return cls.from_file(fobj, hash_names=hash_names, length=length)
+
     def update(self, chunk):
         for name, h in self.state.items():
             if name == 'length':
diff --git a/swh/model/tests/test_hashutil.py b/swh/model/tests/test_hashutil.py
index cbe16603..d288149b 100644
--- a/swh/model/tests/test_hashutil.py
+++ b/swh/model/tests/test_hashutil.py
@@ -13,9 +13,10 @@ from nose.tools import istest
 from unittest.mock import patch
 
 from swh.model import hashutil
+from swh.model.hashutil import MultiHash
 
 
-class Hashutil(unittest.TestCase):
+class BaseHashutil(unittest.TestCase):
     def setUp(self):
         # Reset function cache
         hashutil._blake2_hash_cache = {}
@@ -52,6 +53,75 @@ class Hashutil(unittest.TestCase):
             for type, cksum in self.git_hex_checksums.items()
         }
 
+
+class MultiHashTest(BaseHashutil):
+    @istest
+    def multi_hash_data(self):
+        checksums = MultiHash.from_data(self.data).digest()
+        self.assertEqual(checksums, self.checksums)
+        self.assertNotIn('length', checksums)
+
+    @istest
+    def multi_hash_data_with_length(self):
+        expected_checksums = self.checksums.copy()
+        expected_checksums['length'] = len(self.data)
+
+        algos = set(['length']).union(hashutil.DEFAULT_ALGORITHMS)
+        checksums = MultiHash.from_data(self.data, hash_names=algos).digest()
+
+        self.assertEqual(checksums, expected_checksums)
+        self.assertIn('length', checksums)
+
+    @istest
+    def multi_hash_data_unknown_hash(self):
+        with self.assertRaises(ValueError) as cm:
+            MultiHash.from_data(self.data, ['unknown-hash'])
+
+        self.assertIn('Unexpected hashing algorithm', cm.exception.args[0])
+        self.assertIn('unknown-hash', cm.exception.args[0])
+
+    @istest
+    def multi_hash_file(self):
+        fobj = io.BytesIO(self.data)
+
+        checksums = MultiHash.from_file(fobj, length=len(self.data)).digest()
+        self.assertEqual(checksums, self.checksums)
+
+    @istest
+    def multi_hash_file_hexdigest(self):
+        fobj = io.BytesIO(self.data)
+        length = len(self.data)
+        checksums = MultiHash.from_file(fobj, length=length).hexdigest()
+        self.assertEqual(checksums, self.hex_checksums)
+
+    @istest
+    def multi_hash_file_bytehexdigest(self):
+        fobj = io.BytesIO(self.data)
+        length = len(self.data)
+        checksums = MultiHash.from_file(fobj, length=length).bytehexdigest()
+        self.assertEqual(checksums, self.bytehex_checksums)
+
+    @istest
+    def multi_hash_file_missing_length(self):
+        fobj = io.BytesIO(self.data)
+        with self.assertRaises(ValueError) as cm:  # no length is passed below
+            MultiHash.from_file(fobj, hash_names=['sha1_git'])
+
+        self.assertIn('Missing length', cm.exception.args[0])
+
+    @istest
+    def multi_hash_path(self):
+        with tempfile.NamedTemporaryFile(delete=False) as f:
+            f.write(self.data)
+        self.addCleanup(os.remove, f.name)  # also runs if the test fails
+
+        hashes = MultiHash.from_path(f.name).digest()
+
+        self.checksums['length'] = len(self.data)
+        self.assertEqual(self.checksums, hashes)
+
+
+class Hashutil(BaseHashutil):
     @istest
     def hash_data(self):
         checksums = hashutil.hash_data(self.data)
-- 
GitLab