Skip to content
Snippets Groups Projects
Commit fbbd044a authored by Nicolas Dandrimont
Browse files

swh.model.hashutil: Implement missing swh.core.hashutil functionality

parent e92e3c51
No related branches found
No related tags found
No related merge requests found
......@@ -3,8 +3,11 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import binascii
import functools
import hashlib
from io import BytesIO
import os
# supported hashing algorithms
ALGORITHMS = set(['sha1', 'sha256', 'sha1_git'])
......@@ -78,7 +81,7 @@ def _new_hash(algo, length=None):
return h
def hash_file(fobj, length=None, algorithms=ALGORITHMS):
def hash_file(fobj, length=None, algorithms=ALGORITHMS, chunk_cb=None):
"""Hash the contents of the given file object with the given algorithms.
Args:
......@@ -87,7 +90,7 @@ def hash_file(fobj, length=None, algorithms=ALGORITHMS):
git-specific algorithms)
algorithms: the hashing algorithms used
Returns: a dict mapping each algorithm to a hexadecimal digest
Returns: a dict mapping each algorithm to a bytes digest.
Raises:
ValueError if algorithms contains an unknown hash algorithm.
......@@ -100,8 +103,29 @@ def hash_file(fobj, length=None, algorithms=ALGORITHMS):
break
for hash in hashes.values():
hash.update(chunk)
if chunk_cb:
chunk_cb(chunk)
return {algo: hash.hexdigest() for algo, hash in hashes.items()}
return {algo: hash.digest() for algo, hash in hashes.items()}
def hash_path(path, algorithms=ALGORITHMS, chunk_cb=None):
    """Hash the contents of the file at the given path with the given algorithms.

    Args:
        path: the path of the file to hash
        algorithms: the hashing algorithms used
        chunk_cb: optional callable, passed through to hash_file, which
            calls it with each chunk of data read from the file

    Returns: a dict mapping each algorithm to a bytes digest.

    Raises:
        ValueError if algorithms contains an unknown hash algorithm.
        OSError on file access error

    """
    with open(path, 'rb') as fobj:
        # Read the size from the already-open descriptor instead of doing a
        # separate os.path.getsize(path) lookup, so the length always
        # describes the very file being hashed even if the path is replaced
        # between the two operations.
        length = os.fstat(fobj.fileno()).st_size
        return hash_file(fobj, length, algorithms, chunk_cb)
def hash_data(data, algorithms=ALGORITHMS):
......@@ -111,7 +135,7 @@ def hash_data(data, algorithms=ALGORITHMS):
data: a bytes object
algorithms: the hashing algorithms used
Returns: a dict mapping each algorithm to a hexadecimal digest
Returns: a dict mapping each algorithm to a bytes digest
Raises:
TypeError if data does not support the buffer interface.
......@@ -129,7 +153,7 @@ def hash_git_data(data, git_type, base_algo='sha1'):
git_type: the git object type
base_algo: the base hashing algorithm used (default: sha1)
Returns: a dict mapping each algorithm to a hexadecimal digest
Returns: a dict mapping each algorithm to a bytes digest
Raises:
ValueError if the git_type is unexpected.
......@@ -144,4 +168,20 @@ def hash_git_data(data, git_type, base_algo='sha1'):
h = _new_git_hash(base_algo, git_type, len(data))
h.update(data)
return h.hexdigest()
return h.digest()
@functools.lru_cache()
def hash_to_hex(hash):
    """Return the hexadecimal ascii string form of a hash.

    A str argument is taken to be in hexadecimal form already and is
    returned unchanged; any other argument is hexlified and decoded as
    ascii. Results are memoized through functools.lru_cache, so the
    argument has to be hashable.
    """
    already_hex = isinstance(hash, str)
    return hash if already_hex else binascii.hexlify(hash).decode('ascii')
@functools.lru_cache()
def hash_to_bytes(hash):
    """Return the raw bytes form of a hash.

    A bytes argument is returned unchanged; any other argument is parsed
    as a hexadecimal string with bytes.fromhex. Results are memoized
    through functools.lru_cache, so the argument has to be hashable.
    """
    if not isinstance(hash, bytes):
        return bytes.fromhex(hash)
    return hash
......@@ -159,7 +159,8 @@ def directory_identifier(directory):
identifier_to_bytes(entry['target']),
])
return hashutil.hash_git_data(b''.join(components), 'tree')
return identifier_to_str(hashutil.hash_git_data(b''.join(components),
'tree'))
def format_date(date):
......@@ -265,7 +266,8 @@ def revision_identifier(revision):
revision['message'],
])
return hashutil.hash_git_data(b''.join(components), 'commit')
return identifier_to_str(hashutil.hash_git_data(b''.join(components),
'commit'))
def target_type_to_git(target_type):
......@@ -294,4 +296,5 @@ def release_identifier(release):
components.extend([b'\n', release['message']])
return hashutil.hash_git_data(b''.join(components), 'tag')
return identifier_to_str(hashutil.hash_git_data(b''.join(components),
'tag'))
......@@ -4,6 +4,7 @@
# See top-level LICENSE file for more information
import io
import tempfile
import unittest
from nose.tools import istest
......@@ -21,17 +22,27 @@ class Hashutil(unittest.TestCase):
'4a9b50ee5b5866c0d91fab0e65907311',
}
self.git_checksums = {
self.checksums = {
type: bytes.fromhex(cksum)
for type, cksum in self.hex_checksums.items()
}
self.git_hex_checksums = {
'blob': self.hex_checksums['sha1_git'],
'tree': '5b2e883aa33d2efab98442693ea4dd5f1b8871b0',
'commit': '79e4093542e72f0fcb7cbd75cb7d270f9254aa8f',
'tag': 'd6bf62466f287b4d986c545890716ce058bddf67',
}
self.git_checksums = {
type: bytes.fromhex(cksum)
for type, cksum in self.git_hex_checksums.items()
}
@istest
def hash_data(self):
checksums = hashutil.hash_data(self.data)
self.assertEqual(checksums, self.hex_checksums)
self.assertEqual(checksums, self.checksums)
@istest
def hash_data_unknown_hash(self):
......@@ -63,7 +74,7 @@ class Hashutil(unittest.TestCase):
fobj = io.BytesIO(self.data)
checksums = hashutil.hash_file(fobj, length=len(self.data))
self.assertEqual(checksums, self.hex_checksums)
self.assertEqual(checksums, self.checksums)
@istest
def hash_file_missing_length(self):
......@@ -73,3 +84,28 @@ class Hashutil(unittest.TestCase):
hashutil.hash_file(fobj, algorithms=['sha1_git'])
self.assertIn('Missing length', cm.exception.args[0])
@istest
def hash_path(self):
    """Check that hash_path returns the expected digests for a file on disk."""
    # Local import: only this test needs os, to remove the temporary file.
    import os

    # delete=False so the file can be reopened by name once it is closed;
    # leaving the with block closes (and flushes) it before hashing.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        f.write(self.data)
    try:
        hashes = hashutil.hash_path(f.name)
    finally:
        # The file is not auto-deleted, so remove it explicitly to avoid
        # leaking a temporary file on every test run.
        os.unlink(f.name)
    # assertEqual: the assertEquals alias is deprecated and was removed in
    # Python 3.12.
    self.assertEqual(self.checksums, hashes)
@istest
def hash_to_hex(self):
    """Check hash_to_hex on both hexadecimal and raw bytes inputs."""
    for algo, raw_hash in self.checksums.items():
        hex_hash = self.hex_checksums[algo]
        # assertEqual: the assertEquals alias is deprecated and was removed
        # in Python 3.12.
        self.assertEqual(hashutil.hash_to_hex(hex_hash), hex_hash)
        self.assertEqual(hashutil.hash_to_hex(raw_hash), hex_hash)
@istest
def hash_to_bytes(self):
    """Check hash_to_bytes on both hexadecimal and raw bytes inputs."""
    for algo, raw_hash in self.checksums.items():
        hex_hash = self.hex_checksums[algo]
        # assertEqual: the assertEquals alias is deprecated and was removed
        # in Python 3.12.
        self.assertEqual(hashutil.hash_to_bytes(hex_hash), raw_hash)
        self.assertEqual(hashutil.hash_to_bytes(raw_hash), raw_hash)
......@@ -3,8 +3,6 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import binascii
from .exceptions import ValidationError, NON_FIELD_ERRORS
from . import fields, hashutil
......@@ -50,9 +48,7 @@ def validate_content(content):
for hash_type, computed_hash in hashes.items():
if hash_type not in content:
continue
content_hash = content[hash_type]
if isinstance(content_hash, bytes):
content_hash = binascii.hexlify(content_hash).decode()
content_hash = hashutil.hash_to_bytes(content[hash_type])
if content_hash != computed_hash:
errors.append(ValidationError(
'hash mismatch in content for hash %(hash)s',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment