diff --git a/swh/model/merkle.py b/swh/model/merkle.py
index 098c8723086b74ef397e5055123ef99910cbdf21..8934ad18e88266116ff85e7a2fb3ccf9c6452f0d 100644
--- a/swh/model/merkle.py
+++ b/swh/model/merkle.py
@@ -277,18 +277,20 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
         for child in self.values():
             child.reset_collect()
 
-    def iter_tree(self) -> Iterator["MerkleNode"]:
-        """Yields all children nodes, recursively. Common nodes are
-        deduplicated.
+    def iter_tree(self, dedup: bool = True) -> Iterator["MerkleNode"]:
+        """Yields this node and all its descendants, recursively. Nodes sharing
+        a hash are yielded only once by default; pass 'dedup=False' to yield
+        every occurrence.
         """
-        yield from self._iter_tree(set())
+        yield from self._iter_tree(set(), dedup)
 
-    def _iter_tree(self, seen: Set[bytes]) -> Iterator["MerkleNode"]:
+    def _iter_tree(self, seen: Set[bytes], dedup: bool) -> Iterator["MerkleNode"]:
         if self.hash not in seen:
-            seen.add(self.hash)
+            if dedup:
+                seen.add(self.hash)
             yield self
             for child in self.values():
-                yield from child._iter_tree(seen=seen)
+                yield from child._iter_tree(seen=seen, dedup=dedup)
 
 
 class MerkleLeaf(MerkleNode):
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
index 65992f45d0c9c69b155cfbb1dbf88cd589b278a7..32de872592dbee4df54a82da4302cb87a31e2611 100644
--- a/swh/model/tests/test_merkle.py
+++ b/swh/model/tests/test_merkle.py
@@ -172,10 +172,18 @@ class TestMerkleNode(unittest.TestCase):
         collected2 = self.root.collect()
         self.assertEqual(collected2, {})
 
-    def test_iter_tree(self):
+    def test_iter_tree_with_deduplication(self):
         nodes = list(self.root.iter_tree())
         self.assertCountEqual(nodes, self.nodes.values())
 
+    def test_iter_tree_without_deduplication(self):
+        # duplicate existing hash in merkle tree
+        self.root[b"d"] = MerkleTestNode({"value": b"root/c/c/c"})
+        nodes_dedup = list(self.root.iter_tree())
+        nodes = list(self.root.iter_tree(dedup=False))
+        self.assertNotEqual(nodes, nodes_dedup)
+        self.assertEqual(len(nodes), len(nodes_dedup) + 1)
+
     def test_get(self):
         for key in (b"a", b"b", b"c"):
             self.assertEqual(self.root[key], self.nodes[b"root/" + key])