Skip to content
Snippets Groups Projects
Commit f44949a9 authored by Nicolas Dandrimont's avatar Nicolas Dandrimont
Browse files

Add a Merkle tree data structure

parent ac3df91a
No related branches found
No related tags found
No related merge requests found
# Copyright (C) 2017 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""Merkle tree data structure"""
import abc
import collections
import collections.abc
def deep_update(left, right):
    """Recursively update the left mapping with deeply nested values from the
    right mapping.

    This function is useful to merge the results of several calls to
    :func:`MerkleNode.collect`.

    When a key of ``right`` holds a mapping, it is merged into the value
    stored under the same key in ``left``. If ``left`` has no value for that
    key, or holds something that is not a mapping, the merge starts from a
    fresh :class:`dict` (so nested mappings of ``right`` are copied rather
    than shared). Any other value of ``right`` simply overwrites the value in
    ``left``.

    Arguments:
      left: a mapping (modified by the update operation)
      right: a mapping

    Returns:
      the left mapping, updated with nested values from the right mapping

    Example:
        >>> a = {
        ...     'key1': {
        ...         'key2': {
        ...              'key3': 'value1/2/3',
        ...         },
        ...     },
        ... }
        >>> deep_update(a, {
        ...     'key1': {
        ...         'key2': {
        ...              'key4': 'value1/2/4',
        ...         },
        ...     },
        ... }) == {
        ...     'key1': {
        ...         'key2': {
        ...             'key3': 'value1/2/3',
        ...             'key4': 'value1/2/4',
        ...         },
        ...     },
        ... }
        True
        >>> deep_update(a, {
        ...     'key1': {
        ...         'key2': {
        ...              'key3': 'newvalue1/2/3',
        ...         },
        ...     },
        ... }) == {
        ...     'key1': {
        ...         'key2': {
        ...             'key3': 'newvalue1/2/3',
        ...             'key4': 'value1/2/4',
        ...         },
        ...     },
        ... }
        True
        >>> deep_update({'key1': 'scalar'}, {'key1': {'key2': 'value'}})
        {'key1': {'key2': 'value'}}
    """
    for key, rvalue in right.items():
        # The ABCs live in collections.abc; the aliases that used to exist
        # directly on the collections module are gone in recent Pythons.
        if isinstance(rvalue, collections.abc.Mapping):
            lvalue = left.get(key)
            if not isinstance(lvalue, collections.abc.Mapping):
                # Missing key, or a non-mapping value that cannot be merged
                # into: the nested mapping from the right side replaces it.
                lvalue = {}
            left[key] = deep_update(lvalue, rvalue)
        else:
            left[key] = rvalue
    return left
class MerkleNode(dict, metaclass=abc.ABCMeta):
    """Representation of a node in a Merkle Tree.

    A (generalized) `Merkle Tree`_ is a tree in which every node is labeled
    with a hash of its own data and the hash of its children.

    .. _Merkle Tree: https://en.wikipedia.org/wiki/Merkle_tree

    In pseudocode::

      node.hash = hash(node.data
                       + sum(child.hash for child in node.children))

    This class efficiently implements the Merkle Tree data structure on top of
    a Python :class:`dict`, minimizing hash computations and new data
    collections when updating nodes.

    Node data is stored in the :attr:`data` attribute, while (named) children
    are stored as items of the underlying dictionary.

    Addition, update and removal of objects are instrumented to automatically
    invalidate the hashes of the current node as well as its registered
    parents; It also resets the collection status of the objects so the updated
    objects can be collected.

    Only item assignment, item deletion and :func:`update` are instrumented;
    the other mutating methods inherited from :class:`dict` are not
    overridden by this class.

    The collection of updated data from the tree is implemented through the
    :func:`collect` function and associated helpers.

    Attributes:
      data (dict): data associated to the current node
      parents (list): known parents of the current node
      collected (bool): whether the current node has been collected

    """
    __slots__ = ['parents', 'data', '__hash', 'collected']

    type = None
    """Type of the current node (used as a classifier for :func:`collect`)"""

    def __init__(self, data=None):
        super().__init__()
        self.parents = []
        self.data = data
        self.__hash = None
        self.collected = False

    def invalidate_hash(self):
        """Invalidate the cached hash of the current node.

        The invalidation also resets the collection status of the node and is
        propagated to all registered parents. Nothing happens when no hash is
        currently cached.
        """
        # Compare against None rather than relying on truthiness, so that a
        # legitimately falsy hash (b'', 0, ...) is still treated as cached.
        if self.__hash is None:
            return

        self.__hash = None
        self.collected = False
        for parent in self.parents:
            parent.invalidate_hash()

    def update_hash(self, *, force=False):
        """Recursively compute the hash of the current node.

        Args:
          force (bool): invalidate the cache and force the computation for
            this node and all children.

        Returns:
          the (possibly cached) hash of the current node
        """
        if self.__hash is not None and not force:
            return self.__hash

        if force:
            self.invalidate_hash()

        # Children first: compute_hash() is expected to read their hashes.
        for child in self.values():
            child.update_hash(force=force)

        self.__hash = self.compute_hash()
        return self.__hash

    @property
    def hash(self):
        """The hash of the current node, as calculated by
        :func:`compute_hash`.
        """
        return self.update_hash()

    @abc.abstractmethod
    def compute_hash(self):
        """Compute the hash of the current node.

        The hash should depend on the data of the node, as well as on hashes
        of the children nodes.
        """
        raise NotImplementedError('Must implement compute_hash method')

    def _unregister_from(self, child):
        """Remove one reference to the current node from `child.parents`.

        The lookup is done by identity: nodes are dicts, so an equality-based
        ``list.remove`` could drop another parent that merely has the same
        children. A child that does not list the current node as a parent is
        left untouched.
        """
        parents = child.parents
        for index, parent in enumerate(parents):
            if parent is self:
                del parents[index]
                break

    def __setitem__(self, name, new_child):
        """Add a child, invalidating the current hash.

        If a child is already registered under `name`, the current node is
        removed from that child's parents before the replacement.
        """
        self.invalidate_hash()

        if name in self:
            self._unregister_from(self[name])

        super().__setitem__(name, new_child)
        new_child.parents.append(self)

    def __delitem__(self, name):
        """Remove a child, invalidating the current hash.

        Raises:
          KeyError: if no child is registered under `name`
        """
        if name not in self:
            raise KeyError(name)

        self.invalidate_hash()
        self._unregister_from(self[name])
        super().__delitem__(name)

    def update(self, new_children):
        """Add several named children from a dictionary.

        Children replaced by the update lose the current node as a parent.
        An empty `new_children` is a no-op and does not invalidate anything.
        """
        if not new_children:
            return

        self.invalidate_hash()

        for name, new_child in new_children.items():
            new_child.parents.append(self)
            if name in self:
                self._unregister_from(self[name])

        super().update(new_children)

    def get_data(self, **kwargs):
        """Retrieve and format the collected data for the current node, for use by
        :func:`collect`.

        Can be overridden, for instance when you want the collected data to
        contain information about the child nodes.

        Arguments:
          kwargs: allow subclasses to alter behaviour depending on how
            :func:`collect` is called.

        Returns:
          data formatted for :func:`collect`
        """
        return self.data

    def collect_node(self, **kwargs):
        """Collect the data for the current node, for use by :func:`collect`.

        Arguments:
          kwargs: passed as-is to :func:`get_data`.

        Returns:
          A :class:`dict` compatible with :func:`collect`; it is empty when
          the node has already been collected.
        """
        if not self.collected:
            self.collected = True
            return {self.type: {self.hash: self.get_data(**kwargs)}}
        else:
            return {}

    def collect(self, **kwargs):
        """Collect the data for all nodes in the subtree rooted at `self`.

        The data is deduplicated by type and by hash.

        Arguments:
          kwargs: passed as-is to :func:`get_data`.

        Returns:
          A :class:`dict` with the following structure::

            {
              'typeA': {
                node1.hash: node1.get_data(),
                node2.hash: node2.get_data(),
              },
              'typeB': {
                node3.hash: node3.get_data(),
                ...
              },
              ...
            }
        """
        ret = self.collect_node(**kwargs)
        for child in self.values():
            deep_update(ret, child.collect(**kwargs))

        return ret

    def reset_collect(self):
        """Recursively unmark collected nodes in the subtree rooted at `self`.

        This lets the caller use :func:`collect` again.
        """
        self.collected = False

        for child in self.values():
            child.reset_collect()
class MerkleLeaf(MerkleNode):
    """Terminal node of a Merkle tree.

    A leaf behaves like any other Merkle node, except that every operation
    reading or changing its children is rejected with a :class:`ValueError`.
    """
    __slots__ = []

    def _leaf_error(self):
        """Build the error raised by all children-related operations."""
        return ValueError('%s is a leaf' % self.__class__.__name__)

    def __setitem__(self, name, child):
        """Child assignment. Disabled for leaves."""
        raise self._leaf_error()

    def __getitem__(self, name):
        """Child lookup. Disabled for leaves."""
        raise self._leaf_error()

    def __delitem__(self, name):
        """Child removal. Disabled for leaves."""
        raise self._leaf_error()

    def update(self, new_children):
        """Bulk children update. Disabled for leaves."""
        raise self._leaf_error()
# Copyright (C) 2017 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import unittest
from swh.model import merkle
class TestedMerkleNode(merkle.MerkleNode):
    """Merkle node used by the tests.

    Its hash is a readable byte string embedding the node value and the
    hashes of its children, and every hash computation is counted in
    ``compute_hash_called``.
    """
    type = 'tested_merkle_node_type'

    def __init__(self, data):
        super().__init__(data)
        self.compute_hash_called = 0

    def compute_hash(self):
        """Return ``hash(<value>, <name>=<child hash>, ...)`` as bytes,
        children being listed in sorted name order."""
        self.compute_hash_called += 1

        children_parts = []
        for name in sorted(self):
            children_parts.append(name + b'=' + self[name].hash)

        joined = b', '.join([self.data['value'], *children_parts])
        return b'hash(' + joined + b')'
class TestedMerkleLeaf(merkle.MerkleLeaf):
    """Merkle leaf used by the tests.

    Its hash is a readable byte string wrapping the leaf value, and every
    hash computation is counted in ``compute_hash_called``.
    """
    type = 'tested_merkle_leaf_type'

    def __init__(self, data):
        super().__init__(data)
        self.compute_hash_called = 0

    def compute_hash(self):
        """Return ``hash(<value>)`` as bytes."""
        self.compute_hash_called += 1
        wrapped = b'hash(' + self.data['value']
        return wrapped + b')'
class TestMerkleLeaf(unittest.TestCase):
    """Checks hashing, data access, collection and the child-less contract
    of a Merkle leaf."""

    def setUp(self):
        self.data = {'value': b'value'}
        self.instance = TestedMerkleLeaf(self.data)

    def test_hash(self):
        """The hash is computed once, then served from the cache."""
        leaf = self.instance
        self.assertEqual(leaf.compute_hash_called, 0)

        first_hash = leaf.hash
        self.assertEqual(leaf.compute_hash_called, 1)

        second_hash = leaf.hash
        self.assertEqual(leaf.compute_hash_called, 1)

        self.assertEqual(first_hash, second_hash)

    def test_data(self):
        """get_data() hands back the data given at construction."""
        self.assertEqual(self.instance.get_data(), self.data)

    def test_collect(self):
        """A leaf is collected once, until reset_collect() is called."""
        leaf = self.instance

        first_pass = leaf.collect()
        expected = {
            leaf.type: {
                leaf.hash: leaf.get_data(),
            },
        }
        self.assertEqual(first_pass, expected)

        # Already collected: nothing more to report
        second_pass = leaf.collect()
        self.assertEqual(second_pass, {})

        leaf.reset_collect()
        third_pass = leaf.collect()
        self.assertEqual(first_pass, third_pass)

    def test_leaf(self):
        """Every children-related operation is refused by a leaf."""
        def refused():
            return self.assertRaisesRegex(ValueError, 'is a leaf')

        with refused():
            self.instance[b'key1'] = 'Test'

        with refused():
            del self.instance[b'key1']

        with refused():
            self.instance[b'key1']

        with refused():
            self.instance.update(self.data)
class TestMerkleNode(unittest.TestCase):
    """Checks hashing, collection and children manipulation on a tree of
    Merkle nodes.

    The fixture is a complete ternary tree of depth three: every non-terminal
    node has the children ``a``, ``b`` and ``c``, and each node's value is its
    path from the root (``root``, ``root/a``, ``root/a/b``, ...).
    ``self.nodes`` maps each value to the matching node.
    """
    maxDiff = None

    def setUp(self):
        self.root = TestedMerkleNode({'value': b'root'})
        self.nodes = {b'root': self.root}
        for i in (b'a', b'b', b'c'):
            value = b'root/' + i
            node = TestedMerkleNode({
                'value': value,
            })
            self.root[i] = node
            self.nodes[value] = node
            for j in (b'a', b'b', b'c'):
                value2 = value + b'/' + j
                node2 = TestedMerkleNode({
                    'value': value2,
                })
                node[j] = node2
                self.nodes[value2] = node2
                for k in (b'a', b'b', b'c'):
                    # Use the innermost loop variable, so that each
                    # second-level node really gets three distinct children
                    # instead of one child overwritten three times.
                    value3 = value2 + b'/' + k
                    node3 = TestedMerkleNode({
                        'value': value3,
                    })
                    node2[k] = node3
                    self.nodes[value3] = node3

    def test_hash(self):
        for node in self.nodes.values():
            self.assertEqual(node.compute_hash_called, 0)

        # Root hash will compute hash for all the nodes
        root_hash = self.root.hash
        for node in self.nodes.values():
            self.assertEqual(node.compute_hash_called, 1)
            self.assertIn(node.data['value'], root_hash)

        # Should use the cached value
        hash2 = self.root.hash
        self.assertEqual(root_hash, hash2)
        for node in self.nodes.values():
            self.assertEqual(node.compute_hash_called, 1)

        # Should still use the cached value
        hash3 = self.root.update_hash(force=False)
        self.assertEqual(root_hash, hash3)
        for node in self.nodes.values():
            self.assertEqual(node.compute_hash_called, 1)

        # Force update of the cached value for a deeply nested node
        self.root[b'a'][b'b'].update_hash(force=True)
        for key, node in self.nodes.items():
            # update_hash rehashes all children
            if key.startswith(b'root/a/b'):
                self.assertEqual(node.compute_hash_called, 2)
            else:
                self.assertEqual(node.compute_hash_called, 1)

        hash4 = self.root.hash
        self.assertEqual(root_hash, hash4)
        for key, node in self.nodes.items():
            # update_hash also invalidates all parents
            if key in (b'root', b'root/a') or key.startswith(b'root/a/b'):
                self.assertEqual(node.compute_hash_called, 2)
            else:
                self.assertEqual(node.compute_hash_called, 1)

    def test_collect(self):
        collected = self.root.collect()
        self.assertEqual(len(collected[self.root.type]), len(self.nodes))
        for node in self.nodes.values():
            self.assertTrue(node.collected)

        # Everything is already collected: nothing left to report
        collected2 = self.root.collect()
        self.assertEqual(collected2, {})

    def test_get(self):
        for key in (b'a', b'b', b'c'):
            self.assertEqual(self.root[key], self.nodes[b'root/' + key])

        with self.assertRaisesRegex(KeyError, "b'nonexistent'"):
            self.root[b'nonexistent']

    def test_del(self):
        hash_root = self.root.hash
        hash_a = self.nodes[b'root/a'].hash
        del self.root[b'a'][b'c']
        hash_root2 = self.root.hash
        hash_a2 = self.nodes[b'root/a'].hash

        # Removing a child changes the hash of all its ancestors
        self.assertNotEqual(hash_root, hash_root2)
        self.assertNotEqual(hash_a, hash_a2)

        self.assertEqual(self.nodes[b'root/a/c'].parents, [])

        with self.assertRaisesRegex(KeyError, "b'nonexistent'"):
            del self.root[b'nonexistent']

    def test_update(self):
        hash_root = self.root.hash
        hash_b = self.root[b'b'].hash
        new_children = {
            b'c': TestedMerkleNode({'value': b'root/b/new_c'}),
            b'd': TestedMerkleNode({'value': b'root/b/d'}),
        }

        # collect all nodes
        self.root.collect()

        self.root[b'b'].update(new_children)

        # Ensure everyone got reparented
        self.assertEqual(new_children[b'c'].parents, [self.root[b'b']])
        self.assertEqual(new_children[b'd'].parents, [self.root[b'b']])
        self.assertEqual(self.nodes[b'root/b/c'].parents, [])

        hash_root2 = self.root.hash
        self.assertNotEqual(hash_root, hash_root2)
        self.assertIn(b'root/b/new_c', hash_root2)
        self.assertIn(b'root/b/d', hash_root2)

        hash_b2 = self.root[b'b'].hash
        self.assertNotEqual(hash_b, hash_b2)

        # Only the updated node and its ancestor got rehashed
        for key, node in self.nodes.items():
            if key in (b'root', b'root/b'):
                self.assertEqual(node.compute_hash_called, 2)
            else:
                self.assertEqual(node.compute_hash_called, 1)

        # Ensure we collected root, root/b, and both new children
        collected_after_update = self.root.collect()
        self.assertCountEqual(
            collected_after_update[TestedMerkleNode.type],
            [self.nodes[b'root'].hash, self.nodes[b'root/b'].hash,
             new_children[b'c'].hash, new_children[b'd'].hash],
        )

        # test that no-op updates don't invalidate anything
        self.root[b'a'][b'b'].update({})
        self.assertEqual(self.root.collect(), {})
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment