Skip to content
Snippets Groups Projects
Commit 9cf7a04a authored by vlorentz's avatar vlorentz
Browse files

Add method MerkleNode.iter_tree, to visit all nodes in the subtree of a node.

parent c0ce38ed
No related branches found
No related tags found
No related merge requests found
......@@ -8,7 +8,7 @@
import abc
import collections
from typing import List, Optional
from typing import Iterator, List, Optional, Set
def deep_update(left, right):
......@@ -273,6 +273,20 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
for child in self.values():
child.reset_collect()
def iter_tree(self) -> Iterator['MerkleNode']:
"""Yields all children nodes, recursively. Common nodes are
deduplicated.
"""
yield from self._iter_tree(set())
def _iter_tree(
self, seen: Set[bytes]) -> Iterator['MerkleNode']:
if self.hash not in seen:
seen.add(self.hash)
yield self
for child in self.values():
yield from child._iter_tree(seen=seen)
class MerkleLeaf(MerkleNode):
"""A leaf to a Merkle tree.
......
......@@ -184,6 +184,10 @@ class TestMerkleNode(unittest.TestCase):
collected2 = self.root.collect()
self.assertEqual(collected2, {})
def test_iter_tree(self):
nodes = list(self.root.iter_tree())
self.assertCountEqual(nodes, self.nodes.values())
def test_get(self):
for key in (b'a', b'b', b'c'):
self.assertEqual(self.root[key], self.nodes[b'root/' + key])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment