diff --git a/swh/model/from_disk.py b/swh/model/from_disk.py
index 8795b1fb56c742fac8577c8b48599f2eaa77df29..8bd7f5d18f3bb39895969c8b260fd06d038391f6 100644
--- a/swh/model/from_disk.py
+++ b/swh/model/from_disk.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2017-2020 The Software Heritage developers
+# Copyright (C) 2017-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
diff --git a/swh/model/merkle.py b/swh/model/merkle.py
index ab6b8ea35e393ece319fe401359e2e7b9b106433..b224840782e38d09a9ccad1786b6e99a4b6cf630 100644
--- a/swh/model/merkle.py
+++ b/swh/model/merkle.py
@@ -1,76 +1,14 @@
-# Copyright (C) 2017-2020 The Software Heritage developers
+# Copyright (C) 2017-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
 """Merkle tree data structure"""
 
-import abc
-from collections.abc import Mapping
-from typing import Dict, Iterator, List, Set
-
-
-def deep_update(left, right):
-    """Recursively update the left mapping with deeply nested values from the right
-    mapping.
-
-    This function is useful to merge the results of several calls to
-    :func:`MerkleNode.collect`.
-
-    Arguments:
-      left: a mapping (modified by the update operation)
-      right: a mapping
-
-    Returns:
-      the left mapping, updated with nested values from the right mapping
-
-    Example:
-        >>> a = {
-        ...     'key1': {
-        ...         'key2': {
-        ...              'key3': 'value1/2/3',
-        ...         },
-        ...     },
-        ... }
-        >>> deep_update(a, {
-        ...     'key1': {
-        ...         'key2': {
-        ...              'key4': 'value1/2/4',
-        ...         },
-        ...     },
-        ... }) == {
-        ...     'key1': {
-        ...         'key2': {
-        ...             'key3': 'value1/2/3',
-        ...             'key4': 'value1/2/4',
-        ...         },
-        ...     },
-        ... }
-        True
-        >>> deep_update(a, {
-        ...     'key1': {
-        ...         'key2': {
-        ...              'key3': 'newvalue1/2/3',
-        ...         },
-        ...     },
-        ... }) == {
-        ...     'key1': {
-        ...         'key2': {
-        ...             'key3': 'newvalue1/2/3',
-        ...             'key4': 'value1/2/4',
-        ...         },
-        ...     },
-        ... }
-        True
+from __future__ import annotations
 
-    """
-    for key, rvalue in right.items():
-        if isinstance(rvalue, Mapping):
-            new_lvalue = deep_update(left.get(key, {}), rvalue)
-            left[key] = new_lvalue
-        else:
-            left[key] = rvalue
-    return left
+import abc
+from typing import Any, Dict, Iterator, List, Set
 
 
 class MerkleNode(dict, metaclass=abc.ABCMeta):
@@ -141,7 +79,7 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
         for parent in self.parents:
             parent.invalidate_hash()
 
-    def update_hash(self, *, force=False):
+    def update_hash(self, *, force=False) -> Any:
         """Recursively compute the hash of the current node.
 
         Args:
@@ -161,14 +99,17 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
         return self.__hash
 
     @property
-    def hash(self):
+    def hash(self) -> Any:
         """The hash of the current node, as calculated by
         :func:`compute_hash`.
         """
         return self.update_hash()
 
+    def __hash__(self):
+        return hash(self.hash)
+
     @abc.abstractmethod
-    def compute_hash(self):
+    def compute_hash(self) -> Any:
         """Compute the hash of the current node.
 
         The hash should depend on the data of the node, as well as on hashes
@@ -223,47 +164,24 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
         """
         return self.data
 
-    def collect_node(self, **kwargs):
-        """Collect the data for the current node, for use by :func:`collect`.
-
-        Arguments:
-          kwargs: passed as-is to :func:`get_data`.
-
-        Returns:
-          A :class:`dict` compatible with :func:`collect`.
-        """
+    def collect_node(self) -> Set[MerkleNode]:
+        """Collect the current node if it has not been yet, for use by :func:`collect`."""
         if not self.collected:
             self.collected = True
-            return {self.object_type: {self.hash: self.get_data(**kwargs)}}
+            return {self}
         else:
-            return {}
+            return set()
 
-    def collect(self, **kwargs):
-        """Collect the data for all nodes in the subtree rooted at `self`.
-
-        The data is deduplicated by type and by hash.
-
-        Arguments:
-          kwargs: passed as-is to :func:`get_data`.
+    def collect(self) -> Set[MerkleNode]:
+        """Collect the added and modified nodes in the subtree rooted at `self`
+        since the last collect operation.
 
         Returns:
-           A :class:`dict` with the following structure::
-
-             {
-               'typeA': {
-                 node1.hash: node1.get_data(),
-                 node2.hash: node2.get_data(),
-               },
-               'typeB': {
-                 node3.hash: node3.get_data(),
-                 ...
-               },
-               ...
-             }
+           A :class:`set` of collected nodes
         """
-        ret = self.collect_node(**kwargs)
+        ret = self.collect_node()
         for child in self.values():
-            deep_update(ret, child.collect(**kwargs))
+            ret.update(child.collect())
 
         return ret
 
@@ -277,14 +195,14 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
         for child in self.values():
             child.reset_collect()
 
-    def iter_tree(self, dedup=True) -> Iterator["MerkleNode"]:
+    def iter_tree(self, dedup=True) -> Iterator[MerkleNode]:
         """Yields all children nodes, recursively. Common nodes are deduplicated
         by default (deduplication can be turned off setting the given argument
         'dedup' to False).
         """
         yield from self._iter_tree(set(), dedup)
 
-    def _iter_tree(self, seen: Set[bytes], dedup) -> Iterator["MerkleNode"]:
+    def _iter_tree(self, seen: Set[bytes], dedup) -> Iterator[MerkleNode]:
         if self.hash not in seen:
             if dedup:
                 seen.add(self.hash)
diff --git a/swh/model/tests/test_from_disk.py b/swh/model/tests/test_from_disk.py
index b7674d4307a4166a2ede388af8c8e6f52b72f190..c07fef683b2c47583aa514dc0961331f63df1ef2 100644
--- a/swh/model/tests/test_from_disk.py
+++ b/swh/model/tests/test_from_disk.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2017-2020 The Software Heritage developers
+# Copyright (C) 2017-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -715,6 +715,21 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
         empties = os.path.join(self.tmpdir_name, b"empty1", b"empty2")
         os.makedirs(empties)
 
+    def check_collect(
+        self, directory, expected_directory_count, expected_content_count
+    ):
+        objs = directory.collect()
+        contents = []
+        directories = []
+        for obj in objs:
+            if isinstance(obj, Content):
+                contents.append(obj)
+            elif isinstance(obj, Directory):
+                directories.append(obj)
+
+        self.assertEqual(len(directories), expected_directory_count)
+        self.assertEqual(len(contents), expected_content_count)
+
     def test_directory_to_objects(self):
         directory = Directory.from_disk(path=self.tmpdir_name)
 
@@ -743,13 +758,10 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
         with self.assertRaisesRegex(KeyError, "b'nonexistentdir'"):
             directory[b"nonexistentdir/file"]
 
-        objs = directory.collect()
-
-        self.assertCountEqual(["content", "directory"], objs)
-
-        self.assertEqual(len(objs["directory"]), 6)
-        self.assertEqual(
-            len(objs["content"]), len(self.contents) + len(self.symlinks) + 1
+        self.check_collect(
+            directory,
+            expected_directory_count=6,
+            expected_content_count=len(self.contents) + len(self.symlinks) + 1,
         )
 
     def test_directory_to_objects_ignore_empty(self):
@@ -775,13 +787,10 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
         with self.assertRaisesRegex(KeyError, "b'empty1'"):
             directory[b"empty1/empty2"]
 
-        objs = directory.collect()
-
-        self.assertCountEqual(["content", "directory"], objs)
-
-        self.assertEqual(len(objs["directory"]), 4)
-        self.assertEqual(
-            len(objs["content"]), len(self.contents) + len(self.symlinks) + 1
+        self.check_collect(
+            directory,
+            expected_directory_count=4,
+            expected_content_count=len(self.contents) + len(self.symlinks) + 1,
         )
 
     def test_directory_to_objects_ignore_name(self):
@@ -806,12 +815,11 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
         with self.assertRaisesRegex(KeyError, "b'symlinks'"):
             directory[b"symlinks"]
 
-        objs = directory.collect()
-
-        self.assertCountEqual(["content", "directory"], objs)
-
-        self.assertEqual(len(objs["directory"]), 5)
-        self.assertEqual(len(objs["content"]), len(self.contents) + 1)
+        self.check_collect(
+            directory,
+            expected_directory_count=5,
+            expected_content_count=len(self.contents) + 1,
+        )
 
     def test_directory_to_objects_ignore_name_case(self):
         directory = Directory.from_disk(
@@ -837,12 +845,11 @@ class DirectoryToObjects(DataMixin, unittest.TestCase):
         with self.assertRaisesRegex(KeyError, "b'symlinks'"):
             directory[b"symlinks"]
 
-        objs = directory.collect()
-
-        self.assertCountEqual(["content", "directory"], objs)
-
-        self.assertEqual(len(objs["directory"]), 5)
-        self.assertEqual(len(objs["content"]), len(self.contents) + 1)
+        self.check_collect(
+            directory,
+            expected_directory_count=5,
+            expected_content_count=len(self.contents) + 1,
+        )
 
     def test_directory_entry_order(self):
         with tempfile.TemporaryDirectory() as dirname:
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
index 52edb2c1d8424263199b28323f601ffa1b8c234b..a852541bf6d87677fda4f615a3551699738806ba 100644
--- a/swh/model/tests/test_merkle.py
+++ b/swh/model/tests/test_merkle.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2017-2020 The Software Heritage developers
+# Copyright (C) 2017-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -15,11 +15,10 @@ class MerkleTestNode(merkle.MerkleNode):
         super().__init__(data)
         self.compute_hash_called = 0
 
-    def compute_hash(self):
+    def compute_hash(self) -> bytes:
         self.compute_hash_called += 1
         child_data = [child + b"=" + self[child].hash for child in sorted(self)]
-
-        return b"hash(" + b", ".join([self.data["value"]] + child_data) + b")"
+        return b"hash(" + b", ".join([self.data.get("value", b"")] + child_data) + b")"
 
 
 class MerkleTestLeaf(merkle.MerkleLeaf):
@@ -31,7 +30,7 @@ class MerkleTestLeaf(merkle.MerkleLeaf):
 
     def compute_hash(self):
         self.compute_hash_called += 1
-        return b"hash(" + self.data["value"] + b")"
+        return b"hash(" + self.data.get("value", b"") + b")"
 
 
 class TestMerkleLeaf(unittest.TestCase):
@@ -62,14 +61,10 @@ class TestMerkleLeaf(unittest.TestCase):
         collected = self.instance.collect()
         self.assertEqual(
             collected,
-            {
-                self.instance.object_type: {
-                    self.instance.hash: self.instance.get_data(),
-                },
-            },
+            {self.instance},
         )
         collected2 = self.instance.collect()
-        self.assertEqual(collected2, {})
+        self.assertEqual(collected2, set())
         self.instance.reset_collect()
         collected3 = self.instance.collect()
         self.assertEqual(collected, collected3)
@@ -123,17 +118,17 @@ class TestMerkleNode(unittest.TestCase):
                     self.nodes[value3] = node3
 
     def test_equality(self):
-        node1 = merkle.MerkleNode({"foo": b"bar"})
-        node2 = merkle.MerkleNode({"foo": b"bar"})
-        node3 = merkle.MerkleNode({})
+        node1 = MerkleTestNode({"value": b"bar"})
+        node2 = MerkleTestNode({"value": b"bar"})
+        node3 = MerkleTestNode({})
 
         self.assertEqual(node1, node2)
         self.assertNotEqual(node1, node3, node1 == node3)
 
-        node1["foo"] = node3
+        node1[b"a"] = node3
         self.assertNotEqual(node1, node2)
 
-        node2["foo"] = node3
+        node2[b"a"] = node3
         self.assertEqual(node1, node2)
 
     def test_hash(self):
@@ -178,11 +173,11 @@ class TestMerkleNode(unittest.TestCase):
 
     def test_collect(self):
         collected = self.root.collect()
-        self.assertEqual(len(collected[self.root.object_type]), len(self.nodes))
+        self.assertEqual(collected, set(self.nodes.values()))
         for node in self.nodes.values():
             self.assertTrue(node.collected)
         collected2 = self.root.collect()
-        self.assertEqual(collected2, {})
+        self.assertEqual(collected2, set())
 
     def test_iter_tree_with_deduplication(self):
         nodes = list(self.root.iter_tree())
@@ -252,16 +247,16 @@ class TestMerkleNode(unittest.TestCase):
 
         # Ensure we collected root, root/b, and both new children
         collected_after_update = self.root.collect()
-        self.assertCountEqual(
-            collected_after_update[MerkleTestNode.object_type],
-            [
-                self.nodes[b"root"].hash,
-                self.nodes[b"root/b"].hash,
-                new_children[b"c"].hash,
-                new_children[b"d"].hash,
-            ],
+        self.assertEqual(
+            collected_after_update,
+            {
+                self.nodes[b"root"],
+                self.nodes[b"root/b"],
+                new_children[b"c"],
+                new_children[b"d"],
+            },
         )
 
         # test that noop updates doesn't invalidate anything
         self.root[b"a"][b"b"].update({})
-        self.assertEqual(self.root.collect(), {})
+        self.assertEqual(self.root.collect(), set())