diff --git a/swh/model/merkle.py b/swh/model/merkle.py index 02c6f2b29d17e5f6d9dc5336fe760bfc68d1617e..31be3d178cfe259be55592e18d1178dd9f578a0f 100644 --- a/swh/model/merkle.py +++ b/swh/model/merkle.py @@ -120,6 +120,13 @@ class MerkleNode(dict, metaclass=abc.ABCMeta): self.__hash = None self.collected = False + def __eq__(self, other): + return isinstance(other, MerkleNode) \ + and super().__eq__(other) and self.data == other.data + + def __ne__(self, other): + return not self.__eq__(other) + def invalidate_hash(self): """Invalidate the cached hash of the current node.""" if not self.__hash: diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py index 8b1180a4094c19005b19ea52d8879fac9ac405fb..dc7da63bb3af8570a42a6b3f27668141dba10afe 100644 --- a/swh/model/tests/test_merkle.py +++ b/swh/model/tests/test_merkle.py @@ -46,6 +46,14 @@ class TestMerkleLeaf(unittest.TestCase): self.data = {'value': b'value'} self.instance = MerkleTestLeaf(self.data) + def test_equality(self): + leaf1 = MerkleTestLeaf(self.data) + leaf2 = MerkleTestLeaf(self.data) + leaf3 = MerkleTestLeaf({}) + + self.assertEqual(leaf1, leaf2) + self.assertNotEqual(leaf1, leaf3) + def test_hash(self): self.assertEqual(self.instance.compute_hash_called, 0) instance_hash = self.instance.hash @@ -114,6 +122,20 @@ class TestMerkleNode(unittest.TestCase): node2[j] = node3 self.nodes[value3] = node3 + def test_equality(self): + node1 = merkle.MerkleNode({'foo': b'bar'}) + node2 = merkle.MerkleNode({'foo': b'bar'}) + node3 = merkle.MerkleNode({}) + + self.assertEqual(node1, node2) + self.assertNotEqual(node1, node3, node1 == node3) + + node1['foo'] = node3 + self.assertNotEqual(node1, node2) + + node2['foo'] = node3 + self.assertEqual(node1, node2) + def test_hash(self): for node in self.nodes.values(): self.assertEqual(node.compute_hash_called, 0)