From c0ce38ed4948ced52633b2be3f59e62bc7d82e50 Mon Sep 17 00:00:00 2001
From: Valentin Lorentz <vlorentz@softwareheritage.org>
Date: Mon, 24 Feb 2020 16:00:14 +0100
Subject: [PATCH] Take the value of MerkleNode.data into account to compute
 equality.

It just makes more sense that way.

eg. before this change, all leafs would be equal to each other.
---
 swh/model/merkle.py            |  7 +++++++
 swh/model/tests/test_merkle.py | 22 ++++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/swh/model/merkle.py b/swh/model/merkle.py
index 02c6f2b2..31be3d17 100644
--- a/swh/model/merkle.py
+++ b/swh/model/merkle.py
@@ -120,6 +120,13 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
         self.__hash = None
         self.collected = False
 
+    def __eq__(self, other):
+        return isinstance(other, MerkleNode) \
+            and super().__eq__(other) and self.data == other.data
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
     def invalidate_hash(self):
         """Invalidate the cached hash of the current node."""
         if not self.__hash:
diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py
index 8b1180a4..dc7da63b 100644
--- a/swh/model/tests/test_merkle.py
+++ b/swh/model/tests/test_merkle.py
@@ -46,6 +46,14 @@ class TestMerkleLeaf(unittest.TestCase):
         self.data = {'value': b'value'}
         self.instance = MerkleTestLeaf(self.data)
 
+    def test_equality(self):
+        leaf1 = MerkleTestLeaf(self.data)
+        leaf2 = MerkleTestLeaf(self.data)
+        leaf3 = MerkleTestLeaf({})
+
+        self.assertEqual(leaf1, leaf2)
+        self.assertNotEqual(leaf1, leaf3)
+
     def test_hash(self):
         self.assertEqual(self.instance.compute_hash_called, 0)
         instance_hash = self.instance.hash
@@ -114,6 +122,20 @@ class TestMerkleNode(unittest.TestCase):
                     node2[j] = node3
                     self.nodes[value3] = node3
 
+    def test_equality(self):
+        node1 = merkle.MerkleNode({'foo': b'bar'})
+        node2 = merkle.MerkleNode({'foo': b'bar'})
+        node3 = merkle.MerkleNode({})
+
+        self.assertEqual(node1, node2)
+        self.assertNotEqual(node1, node3, node1 == node3)
+
+        node1['foo'] = node3
+        self.assertNotEqual(node1, node2)
+
+        node2['foo'] = node3
+        self.assertEqual(node1, node2)
+
     def test_hash(self):
         for node in self.nodes.values():
             self.assertEqual(node.compute_hash_called, 0)
-- 
GitLab