From 9084f96c9eb1a3befd8cd7c179950ea00dc66178 Mon Sep 17 00:00:00 2001 From: "Antoine R. Dumont (@ardumont)" <antoine.romain.dumont@gmail.com> Date: Sat, 15 Sep 2018 00:35:43 +0200 Subject: [PATCH] hashutil: Improve MultiHash class from_* to compute hashes This allows calls like: - MultiHash.from_file(file_object).digest() - MultiHash.from_path(b'foo').hexdigest() - MultiHash.from_data(b'foo').bytehexdigest() --- swh/model/hashutil.py | 29 +++++++++++++ swh/model/tests/test_hashutil.py | 72 +++++++++++++++++++++++++++++++- 2 files changed, 100 insertions(+), 1 deletion(-) diff --git a/swh/model/hashutil.py b/swh/model/hashutil.py index 69586a82..339b72fa 100644 --- a/swh/model/hashutil.py +++ b/swh/model/hashutil.py @@ -49,6 +49,10 @@ HASH_FORMATS = set(['bytes', 'bytehex', 'hex']) """Supported output hash formats """ +EXTRA_LENGTH = set(['length']) +"""Extra information to compute +""" + class MultiHash: """Hashutil class to support multiple hashes computation. @@ -79,6 +83,31 @@ class MultiHash: ret.state = state ret.track_length = track_length + @classmethod + def from_file(cls, file, hash_names=DEFAULT_ALGORITHMS, length=None): + ret = cls(length=length, hash_names=hash_names) + for chunk in file: + ret.update(chunk) + return ret + + @classmethod + def from_path(cls, path, hash_names=DEFAULT_ALGORITHMS, length=None, + track_length=True): + if not length: + length = os.path.getsize(path) + # For compatibility reason with `hash_path` + if track_length: + hash_names = hash_names.union(EXTRA_LENGTH) + with open(path, 'rb') as f: + return cls.from_file(f, hash_names=hash_names, length=length) + + @classmethod + def from_data(cls, data, hash_names=DEFAULT_ALGORITHMS, length=None): + if not length: + length = len(data) + fobj = BytesIO(data) + return cls.from_file(fobj, hash_names=hash_names, length=length) + def update(self, chunk): for name, h in self.state.items(): if name == 'length': diff --git a/swh/model/tests/test_hashutil.py b/swh/model/tests/test_hashutil.py index cbe16603..d288149b 100644 --- a/swh/model/tests/test_hashutil.py +++ b/swh/model/tests/test_hashutil.py @@ -13,9 +13,10 @@ from nose.tools import istest from unittest.mock import patch from swh.model import hashutil +from swh.model.hashutil import MultiHash -class Hashutil(unittest.TestCase): +class BaseHashutil(unittest.TestCase): def setUp(self): # Reset function cache hashutil._blake2_hash_cache = {} @@ -52,6 +53,75 @@ class Hashutil(unittest.TestCase): for type, cksum in self.git_hex_checksums.items() } + +class MultiHashTest(BaseHashutil): + @istest + def multi_hash_data(self): + checksums = MultiHash.from_data(self.data).digest() + self.assertEqual(checksums, self.checksums) + self.assertFalse('length' in checksums) + + @istest + def multi_hash_data_with_length(self): + expected_checksums = self.checksums.copy() + expected_checksums['length'] = len(self.data) + + algos = set(['length']).union(hashutil.DEFAULT_ALGORITHMS) + checksums = MultiHash.from_data(self.data, hash_names=algos).digest() + + self.assertEqual(checksums, expected_checksums) + self.assertTrue('length' in checksums) + + @istest + def multi_hash_data_unknown_hash(self): + with self.assertRaises(ValueError) as cm: + MultiHash.from_data(self.data, ['unknown-hash']) + + self.assertIn('Unexpected hashing algorithm', cm.exception.args[0]) + self.assertIn('unknown-hash', cm.exception.args[0]) + + @istest + def multi_hash_file(self): + fobj = io.BytesIO(self.data) + + checksums = MultiHash.from_file(fobj, length=len(self.data)).digest() + self.assertEqual(checksums, self.checksums) + + @istest + def multi_hash_file_hexdigest(self): + fobj = io.BytesIO(self.data) + length = len(self.data) + checksums = MultiHash.from_file(fobj, length=length).hexdigest() + self.assertEqual(checksums, self.hex_checksums) + + @istest + def multi_hash_file_bytehexdigest(self): + fobj = io.BytesIO(self.data) + length = len(self.data) + checksums = MultiHash.from_file(fobj, length=length).bytehexdigest() + self.assertEqual(checksums, self.bytehex_checksums) + + @istest + def multi_hash_file_missing_length(self): + fobj = io.BytesIO(self.data) + with self.assertRaises(ValueError) as cm: + MultiHash.from_file(fobj, hash_names=['sha1_git']) + + self.assertIn('Missing length', cm.exception.args[0]) + + @istest + def multi_hash_path(self): + with tempfile.NamedTemporaryFile(delete=False) as f: + f.write(self.data) + + hashes = MultiHash.from_path(f.name).digest() + os.remove(f.name) + + self.checksums['length'] = len(self.data) + self.assertEquals(self.checksums, hashes) + + +class Hashutil(BaseHashutil): @istest def hash_data(self): checksums = hashutil.hash_data(self.data) -- GitLab