Skip to content
Snippets Groups Projects
Commit 13e7adc3 authored by Antoine Lambert's avatar Antoine Lambert
Browse files

merkle: Make MerkleNode.collect return a set of nodes instead of a dict

Previously the MerkleNode.collect method was returning a dict whose keys
are node types and values dict of {<node_hash>: <node_data>}.

In order to give more flexibility to client code for the processing of
collected nodes, prefer to simply return a set of MerkleNode.

As a consequence, MerkleNode objects need to be hashable by Python so
the __hash__ method has also been implemented.

Closes T4633
parent 9b8beef1
No related branches found
No related tags found
No related merge requests found
# Copyright (C) 2017-2020 The Software Heritage developers
# Copyright (C) 2017-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
......
# Copyright (C) 2017-2020 The Software Heritage developers
# Copyright (C) 2017-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
"""Merkle tree data structure"""
import abc
from collections.abc import Mapping
from typing import Dict, Iterator, List, Set
def deep_update(left, right):
"""Recursively update the left mapping with deeply nested values from the right
mapping.
This function is useful to merge the results of several calls to
:func:`MerkleNode.collect`.
Arguments:
left: a mapping (modified by the update operation)
right: a mapping
Returns:
the left mapping, updated with nested values from the right mapping
Example:
>>> a = {
... 'key1': {
... 'key2': {
... 'key3': 'value1/2/3',
... },
... },
... }
>>> deep_update(a, {
... 'key1': {
... 'key2': {
... 'key4': 'value1/2/4',
... },
... },
... }) == {
... 'key1': {
... 'key2': {
... 'key3': 'value1/2/3',
... 'key4': 'value1/2/4',
... },
... },
... }
True
>>> deep_update(a, {
... 'key1': {
... 'key2': {
... 'key3': 'newvalue1/2/3',
... },
... },
... }) == {
... 'key1': {
... 'key2': {
... 'key3': 'newvalue1/2/3',
... 'key4': 'value1/2/4',
... },
... },
... }
True
from __future__ import annotations
"""
for key, rvalue in right.items():
if isinstance(rvalue, Mapping):
new_lvalue = deep_update(left.get(key, {}), rvalue)
left[key] = new_lvalue
else:
left[key] = rvalue
return left
import abc
from typing import Any, Dict, Iterator, List, Set
class MerkleNode(dict, metaclass=abc.ABCMeta):
......@@ -141,7 +79,7 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
for parent in self.parents:
parent.invalidate_hash()
def update_hash(self, *, force=False):
def update_hash(self, *, force=False) -> Any:
"""Recursively compute the hash of the current node.
Args:
......@@ -161,14 +99,17 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
return self.__hash
@property
def hash(self):
def hash(self) -> Any:
"""The hash of the current node, as calculated by
:func:`compute_hash`.
"""
return self.update_hash()
def __hash__(self):
return hash(self.hash)
@abc.abstractmethod
def compute_hash(self):
def compute_hash(self) -> Any:
"""Compute the hash of the current node.
The hash should depend on the data of the node, as well as on hashes
......@@ -223,47 +164,24 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
"""
return self.data
def collect_node(self, **kwargs):
"""Collect the data for the current node, for use by :func:`collect`.
Arguments:
kwargs: passed as-is to :func:`get_data`.
Returns:
A :class:`dict` compatible with :func:`collect`.
"""
def collect_node(self) -> Set[MerkleNode]:
"""Collect the current node if it has not been yet, for use by :func:`collect`."""
if not self.collected:
self.collected = True
return {self.object_type: {self.hash: self.get_data(**kwargs)}}
return {self}
else:
return {}
return set()
def collect(self, **kwargs):
"""Collect the data for all nodes in the subtree rooted at `self`.
The data is deduplicated by type and by hash.
Arguments:
kwargs: passed as-is to :func:`get_data`.
def collect(self) -> Set[MerkleNode]:
"""Collect the added and modified nodes in the subtree rooted at `self`
since the last collect operation.
Returns:
A :class:`dict` with the following structure::
{
'typeA': {
node1.hash: node1.get_data(),
node2.hash: node2.get_data(),
},
'typeB': {
node3.hash: node3.get_data(),
...
},
...
}
A :class:`set` of collected nodes
"""
ret = self.collect_node(**kwargs)
ret = self.collect_node()
for child in self.values():
deep_update(ret, child.collect(**kwargs))
ret.update(child.collect())
return ret
......@@ -277,14 +195,14 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
for child in self.values():
child.reset_collect()
def iter_tree(self, dedup=True) -> Iterator["MerkleNode"]:
def iter_tree(self, dedup=True) -> Iterator[MerkleNode]:
"""Yields all children nodes, recursively. Common nodes are deduplicated
by default (deduplication can be turned off setting the given argument
'dedup' to False).
"""
yield from self._iter_tree(set(), dedup)
def _iter_tree(self, seen: Set[bytes], dedup) -> Iterator["MerkleNode"]:
def _iter_tree(self, seen: Set[bytes], dedup) -> Iterator[MerkleNode]:
if self.hash not in seen:
if dedup:
seen.add(self.hash)
......
# Copyright (C) 2017-2020 The Software Heritage developers
# Copyright (C) 2017-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
......@@ -715,6 +715,21 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
empties = os.path.join(self.tmpdir_name, b"empty1", b"empty2")
os.makedirs(empties)
def check_collect(
self, directory, expected_directory_count, expected_content_count
):
objs = directory.collect()
contents = []
directories = []
for obj in objs:
if isinstance(obj, Content):
contents.append(obj)
elif isinstance(obj, Directory):
directories.append(obj)
self.assertEqual(len(directories), expected_directory_count)
self.assertEqual(len(contents), expected_content_count)
def test_directory_to_objects(self):
directory = Directory.from_disk(path=self.tmpdir_name)
......@@ -743,13 +758,10 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
with self.assertRaisesRegex(KeyError, "b'nonexistentdir'"):
directory[b"nonexistentdir/file"]
objs = directory.collect()
self.assertCountEqual(["content", "directory"], objs)
self.assertEqual(len(objs["directory"]), 6)
self.assertEqual(
len(objs["content"]), len(self.contents) + len(self.symlinks) + 1
self.check_collect(
directory,
expected_directory_count=6,
expected_content_count=len(self.contents) + len(self.symlinks) + 1,
)
def test_directory_to_objects_ignore_empty(self):
......@@ -775,13 +787,10 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
with self.assertRaisesRegex(KeyError, "b'empty1'"):
directory[b"empty1/empty2"]
objs = directory.collect()
self.assertCountEqual(["content", "directory"], objs)
self.assertEqual(len(objs["directory"]), 4)
self.assertEqual(
len(objs["content"]), len(self.contents) + len(self.symlinks) + 1
self.check_collect(
directory,
expected_directory_count=4,
expected_content_count=len(self.contents) + len(self.symlinks) + 1,
)
def test_directory_to_objects_ignore_name(self):
......@@ -806,12 +815,11 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
with self.assertRaisesRegex(KeyError, "b'symlinks'"):
directory[b"symlinks"]
objs = directory.collect()
self.assertCountEqual(["content", "directory"], objs)
self.assertEqual(len(objs["directory"]), 5)
self.assertEqual(len(objs["content"]), len(self.contents) + 1)
self.check_collect(
directory,
expected_directory_count=5,
expected_content_count=len(self.contents) + 1,
)
def test_directory_to_objects_ignore_name_case(self):
directory = Directory.from_disk(
......@@ -837,12 +845,11 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
with self.assertRaisesRegex(KeyError, "b'symlinks'"):
directory[b"symlinks"]
objs = directory.collect()
self.assertCountEqual(["content", "directory"], objs)
self.assertEqual(len(objs["directory"]), 5)
self.assertEqual(len(objs["content"]), len(self.contents) + 1)
self.check_collect(
directory,
expected_directory_count=5,
expected_content_count=len(self.contents) + 1,
)
def test_directory_entry_order(self):
with tempfile.TemporaryDirectory() as dirname:
......
# Copyright (C) 2017-2020 The Software Heritage developers
# Copyright (C) 2017-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
......@@ -15,11 +15,10 @@ class MerkleTestNode(merkle.MerkleNode):
super().__init__(data)
self.compute_hash_called = 0
def compute_hash(self):
def compute_hash(self) -> bytes:
self.compute_hash_called += 1
child_data = [child + b"=" + self[child].hash for child in sorted(self)]
return b"hash(" + b", ".join([self.data["value"]] + child_data) + b")"
return b"hash(" + b", ".join([self.data.get("value", b"")] + child_data) + b")"
class MerkleTestLeaf(merkle.MerkleLeaf):
......@@ -31,7 +30,7 @@ class MerkleTestLeaf(merkle.MerkleLeaf):
def compute_hash(self):
self.compute_hash_called += 1
return b"hash(" + self.data["value"] + b")"
return b"hash(" + self.data.get("value", b"") + b")"
class TestMerkleLeaf(unittest.TestCase):
......@@ -62,14 +61,10 @@ class TestMerkleLeaf(unittest.TestCase):
collected = self.instance.collect()
self.assertEqual(
collected,
{
self.instance.object_type: {
self.instance.hash: self.instance.get_data(),
},
},
{self.instance},
)
collected2 = self.instance.collect()
self.assertEqual(collected2, {})
self.assertEqual(collected2, set())
self.instance.reset_collect()
collected3 = self.instance.collect()
self.assertEqual(collected, collected3)
......@@ -123,17 +118,17 @@ class TestMerkleNode(unittest.TestCase):
self.nodes[value3] = node3
def test_equality(self):
node1 = merkle.MerkleNode({"foo": b"bar"})
node2 = merkle.MerkleNode({"foo": b"bar"})
node3 = merkle.MerkleNode({})
node1 = MerkleTestNode({"value": b"bar"})
node2 = MerkleTestNode({"value": b"bar"})
node3 = MerkleTestNode({})
self.assertEqual(node1, node2)
self.assertNotEqual(node1, node3, node1 == node3)
node1["foo"] = node3
node1[b"a"] = node3
self.assertNotEqual(node1, node2)
node2["foo"] = node3
node2[b"a"] = node3
self.assertEqual(node1, node2)
def test_hash(self):
......@@ -178,11 +173,11 @@ class TestMerkleNode(unittest.TestCase):
def test_collect(self):
collected = self.root.collect()
self.assertEqual(len(collected[self.root.object_type]), len(self.nodes))
self.assertEqual(collected, set(self.nodes.values()))
for node in self.nodes.values():
self.assertTrue(node.collected)
collected2 = self.root.collect()
self.assertEqual(collected2, {})
self.assertEqual(collected2, set())
def test_iter_tree_with_deduplication(self):
nodes = list(self.root.iter_tree())
......@@ -252,16 +247,16 @@ class TestMerkleNode(unittest.TestCase):
# Ensure we collected root, root/b, and both new children
collected_after_update = self.root.collect()
self.assertCountEqual(
collected_after_update[MerkleTestNode.object_type],
[
self.nodes[b"root"].hash,
self.nodes[b"root/b"].hash,
new_children[b"c"].hash,
new_children[b"d"].hash,
],
self.assertEqual(
collected_after_update,
{
self.nodes[b"root"],
self.nodes[b"root/b"],
new_children[b"c"],
new_children[b"d"],
},
)
# test that noop updates doesn't invalidate anything
self.root[b"a"][b"b"].update({})
self.assertEqual(self.root.collect(), {})
self.assertEqual(self.root.collect(), set())
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment