Add MerkleNode.iter_tree(): yields the node itself, then its descendants
recursively, skipping any node whose hash has already been seen.
Add a unit test comparing the yielded nodes with the test fixture's nodes.

diff --git a/swh/model/merkle.py b/swh/model/merkle.py
index 31be3d178cfe259be55592e18d1178dd9f578a0f..9d97efdc55b1c0bf23c5abfdea8f995988197ea9 100644
--- a/swh/model/merkle.py
+++ b/swh/model/merkle.py
@@ -8,7 +8,7 @@
 
 import abc
 import collections
-from typing import List, Optional
+from typing import Iterator, List, Optional, Set
 
 
 def deep_update(left, right):
@@ -273,6 +273,20 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
         for child in self.values():
             child.reset_collect()
 
+    def iter_tree(self) -> Iterator['MerkleNode']:
+        """Yields all children nodes, recursively. Common nodes are
+        deduplicated.
+        """
+        yield from self._iter_tree(set())
+
+    def _iter_tree(
+            self, seen: Set[bytes]) -> Iterator['MerkleNode']:
+        if self.hash not in seen:
+            seen.add(self.hash)
+            yield self
+            for child in self.values():
+                yield from child._iter_tree(seen=seen)
+
 
 class MerkleLeaf(MerkleNode):
     """A leaf to a Merkle tree.
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
index dc7da63bb3af8570a42a6b3f27668141dba10afe..734f7c036143163a24b7e9c9be3be9103d6070fa 100644
--- a/swh/model/tests/test_merkle.py
+++ b/swh/model/tests/test_merkle.py
@@ -184,6 +184,10 @@ class TestMerkleNode(unittest.TestCase):
         collected2 = self.root.collect()
         self.assertEqual(collected2, {})
 
+    def test_iter_tree(self):
+        nodes = list(self.root.iter_tree())
+        self.assertCountEqual(nodes, self.nodes.values())
+
     def test_get(self):
         for key in (b'a', b'b', b'c'):
             self.assertEqual(self.root[key], self.nodes[b'root/' + key])