From f44949a9ed39ac9b7ff4cd96745b35c6ce700db2 Mon Sep 17 00:00:00 2001
From: Nicolas Dandrimont <nicolas@dandrimont.eu>
Date: Fri, 22 Sep 2017 17:10:25 +0200
Subject: [PATCH] Add a Merkle tree data structure

---
 swh/model/merkle.py            | 286 +++++++++++++++++++++++++++++++++
 swh/model/tests/test_merkle.py | 229 ++++++++++++++++++++++++++
 2 files changed, 515 insertions(+)
 create mode 100644 swh/model/merkle.py
 create mode 100644 swh/model/tests/test_merkle.py

diff --git a/swh/model/merkle.py b/swh/model/merkle.py
new file mode 100644
index 00000000..c75cc2c2
--- /dev/null
+++ b/swh/model/merkle.py
@@ -0,0 +1,286 @@
+# Copyright (C) 2017 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+"""Merkle tree data structure"""
+
+import abc
+import collections
+
+
+def deep_update(left, right):
+    """Merge the ``right`` mapping into the ``left`` mapping, level by level.
+
+    Mapping values of ``right`` are merged recursively into the value stored
+    under the same key of ``left``; any other value overwrites it. This is
+    useful to combine the results of several :func:`MerkleNode.collect` calls.
+
+    Arguments:
+      left: a mutable mapping (modified in place by the update operation)
+      right: a mapping (not modified)
+
+    Returns:
+      the left mapping, updated with nested values from the right mapping
+
+    Example:
+        >>> a = {
+        ...     'key1': {
+        ...         'key2': {
+        ...              'key3': 'value1/2/3',
+        ...         },
+        ...     },
+        ... }
+        >>> deep_update(a, {
+        ...     'key1': {
+        ...         'key2': {
+        ...              'key4': 'value1/2/4',
+        ...         },
+        ...     },
+        ... }) == {
+        ...     'key1': {
+        ...         'key2': {
+        ...             'key3': 'value1/2/3',
+        ...             'key4': 'value1/2/4',
+        ...         },
+        ...     },
+        ... }
+        True
+        >>> deep_update(a, {
+        ...     'key1': {
+        ...         'key2': {
+        ...              'key3': 'newvalue1/2/3',
+        ...         },
+        ...     },
+        ... }) == {
+        ...     'key1': {
+        ...         'key2': {
+        ...             'key3': 'newvalue1/2/3',
+        ...             'key4': 'value1/2/4',
+        ...         },
+        ...     },
+        ... }
+        True
+
+    """
+    from collections.abc import Mapping  # collections.Mapping: removed in 3.10
+    for key, rvalue in right.items():
+        if isinstance(rvalue, Mapping):
+            left[key] = deep_update(left.get(key, {}), rvalue)
+        else:
+            left[key] = rvalue
+    return left
+
+
+class MerkleNode(dict, metaclass=abc.ABCMeta):
+    """Representation of a node in a Merkle Tree.
+
+    A (generalized) `Merkle Tree`_ is a tree in which every node is labeled
+    with a hash of its own data and the hash of its children.
+
+    .. _Merkle Tree: https://en.wikipedia.org/wiki/Merkle_tree
+
+    In pseudocode::
+
+      node.hash = hash(node.data
+                       + sum(child.hash for child in node.children))
+
+    This class efficiently implements the Merkle Tree data structure on top of
+    a Python :class:`dict`, minimizing hash computations and new data
+    collections when updating nodes.
+
+    Node data is stored in the :attr:`data` attribute, while (named) children
+    are stored as items of the underlying dictionary.
+
+    Item assignment, item deletion and :func:`update` are instrumented to
+    invalidate the hash of the current node and of its registered parents, and
+    to reset their collection status so the updated objects can be collected.
+    Other :class:`dict` mutators (``pop``, ``clear``...) are not instrumented.
+
+    The collection of updated data from the tree is implemented through the
+    :func:`collect` function and associated helpers.
+
+    Attributes:
+      data (dict): data associated to the current node
+      parents (list): known parents of the current node
+      collected (bool): whether the current node has been collected
+
+    """
+    __slots__ = ['parents', 'data', '__hash', 'collected']
+
+    type = None
+    """Type of the current node (used as a classifier for :func:`collect`)"""
+
+    def __init__(self, data=None):
+        super().__init__()
+        self.parents = []
+        self.data = data
+        self.__hash = None
+        self.collected = False
+
+    def invalidate_hash(self):
+        """Invalidate the cached hash of this node and of its parents."""
+        if self.__hash is None:
+            return  # no cached hash (None is the only "unset" marker)
+
+        self.__hash = None
+        self.collected = False
+        for parent in self.parents:
+            parent.invalidate_hash()
+
+    def update_hash(self, *, force=False):
+        """Recursively compute the hash of the current node, and return it.
+
+        Args:
+          force (bool): invalidate the cache and force the computation for
+            this node and all children.
+        """
+        if self.__hash is not None and not force:
+            return self.__hash
+
+        if force:
+            self.invalidate_hash()
+
+        for child in self.values():
+            child.update_hash(force=force)
+
+        self.__hash = self.compute_hash()
+        return self.__hash
+
+    @property
+    def hash(self):
+        """The hash of the current node, as calculated by
+        :func:`compute_hash`.
+        """
+        return self.update_hash()
+
+    @abc.abstractmethod
+    def compute_hash(self):
+        """Compute the hash of the current node.
+
+        The hash should depend on the data of the node, as well as on hashes
+        of the children nodes.
+        """
+        raise NotImplementedError('Must implement compute_hash method')
+
+    def __setitem__(self, name, new_child):
+        """Add or replace a child, invalidating the current hash"""
+        self.invalidate_hash()
+        new_child.parents.append(self)
+        if name in self:  # unregister from the child being replaced
+            self[name].parents.remove(self)
+        super().__setitem__(name, new_child)
+
+    def __delitem__(self, name):
+        """Remove a child, invalidating the current hash"""
+        if name in self:
+            self.invalidate_hash()
+            self[name].parents.remove(self)
+            super().__delitem__(name)
+        else:
+            raise KeyError(name)
+
+    def update(self, new_children):
+        """Add several named children from a dictionary"""
+        if not new_children:
+            return
+
+        self.invalidate_hash()
+
+        for name, new_child in new_children.items():
+            new_child.parents.append(self)
+            if name in self:
+                self[name].parents.remove(self)
+
+        super().update(new_children)
+
+    def get_data(self, **kwargs):
+        """Retrieve and format the collected data for the current node, for use by
+        :func:`collect`.
+
+        Can be overridden, for instance when you want the collected data to
+        contain information about the child nodes.
+
+        Arguments:
+          kwargs: allow subclasses to alter behaviour depending on how
+            :func:`collect` is called.
+
+        Returns:
+          data formatted for :func:`collect`
+        """
+        return self.data
+
+    def collect_node(self, **kwargs):
+        """Collect the data for the current node, for use by :func:`collect`.
+
+        Arguments:
+          kwargs: passed as-is to :func:`get_data`.
+
+        Returns:
+          A :class:`dict` compatible with :func:`collect`.
+        """
+        if not self.collected:
+            self.collected = True
+            return {self.type: {self.hash: self.get_data(**kwargs)}}
+        else:
+            return {}
+
+    def collect(self, **kwargs):
+        """Collect the data for all nodes in the subtree rooted at `self`.
+
+        The data is deduplicated by type and by hash.
+
+        Arguments:
+          kwargs: passed as-is to :func:`get_data`.
+
+        Returns:
+           A :class:`dict` with the following structure::
+
+             {
+               'typeA': {
+                 node1.hash: node1.get_data(),
+                 node2.hash: node2.get_data(),
+               },
+               'typeB': {
+                 node3.hash: node3.get_data(),
+                 ...
+               },
+               ...
+             }
+        """
+        ret = self.collect_node(**kwargs)
+        for child in self.values():
+            deep_update(ret, child.collect(**kwargs))
+
+        return ret
+
+    def reset_collect(self):
+        """Recursively unmark collected nodes in the subtree rooted at `self`.
+
+        This lets the caller use :func:`collect` again.
+        """
+        self.collected = False
+
+        for child in self.values():
+            child.reset_collect()
+
+
+class MerkleLeaf(MerkleNode):
+    """Leaf of a Merkle tree: a :class:`MerkleNode` without children.
+
+    Reading, setting, deleting or bulk-updating children raises ValueError.
+    """
+    __slots__ = []
+
+    def __getitem__(self, name):
+        raise ValueError('%s is a leaf' % type(self).__name__)
+
+    def __setitem__(self, name, child):
+        raise ValueError('%s is a leaf' % type(self).__name__)
+
+    def __delitem__(self, name):
+        raise ValueError('%s is a leaf' % type(self).__name__)
+
+    def update(self, new_children):
+        """Bulk addition of children; always refused on a leaf."""
+        raise ValueError('%s is a leaf' % type(self).__name__)
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
new file mode 100644
index 00000000..9f438928
--- /dev/null
+++ b/swh/model/tests/test_merkle.py
@@ -0,0 +1,229 @@
+# Copyright (C) 2017 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import unittest
+
+from swh.model import merkle
+
+
+class TestedMerkleNode(merkle.MerkleNode):
+    """Inner test node with a readable hash and a hash computation counter."""
+    type = 'tested_merkle_node_type'
+
+    def __init__(self, data):
+        super().__init__(data)
+        # Incremented by compute_hash(), so tests can observe hash caching.
+        self.compute_hash_called = 0
+
+    def compute_hash(self):
+        """Return ``hash(<value>, <name>=<child hash>, ...)``, names sorted."""
+        self.compute_hash_called += 1
+
+        children = []
+        for name in sorted(self):
+            children.append(name + b'=' + self[name].hash)
+
+        payload = b', '.join([self.data['value']] + children)
+        return b'hash(' + payload + b')'
+
+
+class TestedMerkleLeaf(merkle.MerkleLeaf):
+    type = 'tested_merkle_leaf_type'  # classifier key used by collect()
+
+    def __init__(self, data):
+        super().__init__(data)
+        self.compute_hash_called = 0  # how many times compute_hash() ran
+
+    def compute_hash(self):
+        self.compute_hash_called += 1  # lets tests observe hash caching
+        return b''.join((b'hash(', self.data['value'], b')'))
+
+
+class TestMerkleLeaf(unittest.TestCase):
+    def setUp(self):
+        self.data = {'value': b'value'}
+        self.instance = TestedMerkleLeaf(self.data)
+
+    def test_hash(self):
+        # The hash is computed on first access, then served from the cache.
+        leaf = self.instance
+        self.assertEqual(leaf.compute_hash_called, 0)
+        first = leaf.hash
+        self.assertEqual(leaf.compute_hash_called, 1)
+        second = leaf.hash
+        self.assertEqual(leaf.compute_hash_called, 1)
+        self.assertEqual(first, second)
+
+    def test_data(self):
+        # get_data() hands back the data given to the constructor.
+        self.assertEqual(self.instance.get_data(), self.data)
+
+    def test_collect(self):
+        # A leaf is collected once, and again only after reset_collect().
+        leaf = self.instance
+        first = leaf.collect()
+        expected = {leaf.type: {leaf.hash: leaf.get_data()}}
+        self.assertEqual(first, expected)
+
+        # Already collected: nothing more to report until the flag is reset.
+        self.assertEqual(leaf.collect(), {})
+        leaf.reset_collect()
+        self.assertEqual(first, leaf.collect())
+
+    def test_leaf(self):
+        # Every child operation on a leaf raises ValueError.
+        leaf = self.instance
+        refused = (ValueError, 'is a leaf')
+        with self.assertRaisesRegex(*refused):
+            leaf[b'key1'] = 'Test'
+        with self.assertRaisesRegex(*refused):
+            del leaf[b'key1']
+        with self.assertRaisesRegex(*refused):
+            leaf[b'key1']
+        with self.assertRaisesRegex(*refused):
+            leaf.update(self.data)
+
+
+class TestMerkleNode(unittest.TestCase):
+    maxDiff = None
+
+    def setUp(self):
+        """Build a complete three-level tree with children a, b and c.
+
+        ``self.nodes`` maps each node's path (``b'root/a/b'``...) to the node.
+        """
+        self.root = TestedMerkleNode({'value': b'root'})
+        self.nodes = {b'root': self.root}
+        for i in (b'a', b'b', b'c'):
+            value = b'root/' + i
+            node = TestedMerkleNode({'value': value})
+            self.root[i] = node
+            self.nodes[value] = node
+            for j in (b'a', b'b', b'c'):
+                value2 = value + b'/' + j
+                node2 = TestedMerkleNode({'value': value2})
+                node[j] = node2
+                self.nodes[value2] = node2
+                for k in (b'a', b'b', b'c'):
+                    # Name grandchildren after k (not j), so that each
+                    # second-level node really gets three distinct children.
+                    value3 = value2 + b'/' + k
+                    node3 = TestedMerkleNode({'value': value3})
+                    node2[k] = node3
+                    self.nodes[value3] = node3
+
+    def test_hash(self):
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 0)
+
+        # Root hash will compute hash for all the nodes
+        hash = self.root.hash
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+            self.assertIn(node.data['value'], hash)
+
+        # Should use the cached value
+        hash2 = self.root.hash
+        self.assertEqual(hash, hash2)
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+
+        # Should still use the cached value
+        hash3 = self.root.update_hash(force=False)
+        self.assertEqual(hash, hash3)
+        for node in self.nodes.values():
+            self.assertEqual(node.compute_hash_called, 1)
+
+        # Force update of the cached value for a deeply nested node
+        self.root[b'a'][b'b'].update_hash(force=True)
+        for key, node in self.nodes.items():
+            # update_hash rehashes all children
+            if key.startswith(b'root/a/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+        hash4 = self.root.hash
+        self.assertEqual(hash, hash4)
+        for key, node in self.nodes.items():
+            # update_hash also invalidates all parents
+            if key in (b'root', b'root/a') or key.startswith(b'root/a/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+    def test_collect(self):
+        collected = self.root.collect()
+        self.assertEqual(len(collected[self.root.type]), len(self.nodes))
+        for node in self.nodes.values():
+            self.assertTrue(node.collected)
+        collected2 = self.root.collect()
+        self.assertEqual(collected2, {})
+
+    def test_get(self):
+        for key in (b'a', b'b', b'c'):
+            self.assertEqual(self.root[key], self.nodes[b'root/' + key])
+
+        with self.assertRaisesRegex(KeyError, "b'nonexistent'"):
+            self.root[b'nonexistent']
+
+    def test_del(self):
+        hash_root = self.root.hash
+        hash_a = self.nodes[b'root/a'].hash
+        del self.root[b'a'][b'c']
+        hash_root2 = self.root.hash
+        hash_a2 = self.nodes[b'root/a'].hash
+
+        self.assertNotEqual(hash_root, hash_root2)
+        self.assertNotEqual(hash_a, hash_a2)
+
+        self.assertEqual(self.nodes[b'root/a/c'].parents, [])
+
+        with self.assertRaisesRegex(KeyError, "b'nonexistent'"):
+            del self.root[b'nonexistent']
+
+    def test_update(self):
+        hash_root = self.root.hash
+        hash_b = self.root[b'b'].hash
+        new_children = {
+            b'c': TestedMerkleNode({'value': b'root/b/new_c'}),
+            b'd': TestedMerkleNode({'value': b'root/b/d'}),
+        }
+
+        # collect all nodes
+        self.root.collect()
+
+        self.root[b'b'].update(new_children)
+
+        # Ensure everyone got reparented
+        self.assertEqual(new_children[b'c'].parents, [self.root[b'b']])
+        self.assertEqual(new_children[b'd'].parents, [self.root[b'b']])
+        self.assertEqual(self.nodes[b'root/b/c'].parents, [])
+
+        hash_root2 = self.root.hash
+        self.assertNotEqual(hash_root, hash_root2)
+        self.assertIn(b'root/b/new_c', hash_root2)
+        self.assertIn(b'root/b/d', hash_root2)
+
+        hash_b2 = self.root[b'b'].hash
+        self.assertNotEqual(hash_b, hash_b2)
+
+        for key, node in self.nodes.items():
+            if key in (b'root', b'root/b'):
+                self.assertEqual(node.compute_hash_called, 2)
+            else:
+                self.assertEqual(node.compute_hash_called, 1)
+
+        # Ensure we collected root, root/b, and both new children
+        collected_after_update = self.root.collect()
+        self.assertCountEqual(
+            collected_after_update[TestedMerkleNode.type],
+            [self.nodes[b'root'].hash, self.nodes[b'root/b'].hash,
+             new_children[b'c'].hash, new_children[b'd'].hash],
+        )
+
+        # test that noop updates don't invalidate anything
+        self.root[b'a'][b'b'].update({})
+        self.assertEqual(self.root.collect(), {})
-- 
GitLab