From c0ce38ed4948ced52633b2be3f59e62bc7d82e50 Mon Sep 17 00:00:00 2001 From: Valentin Lorentz <vlorentz@softwareheritage.org> Date: Mon, 24 Feb 2020 16:00:14 +0100 Subject: [PATCH] Take the value of MerkleNode.data into account to compute equality. It just makes more sense that way. eg. before this change, all leafs would be equal to each other. --- swh/model/merkle.py | 7 +++++++ swh/model/tests/test_merkle.py | 22 ++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/swh/model/merkle.py b/swh/model/merkle.py index 02c6f2b2..31be3d17 100644 --- a/swh/model/merkle.py +++ b/swh/model/merkle.py @@ -120,6 +120,13 @@ class MerkleNode(dict, metaclass=abc.ABCMeta): self.__hash = None self.collected = False + def __eq__(self, other): + return isinstance(other, MerkleNode) \ + and super().__eq__(other) and self.data == other.data + + def __ne__(self, other): + return not self.__eq__(other) + def invalidate_hash(self): """Invalidate the cached hash of the current node.""" if not self.__hash: diff --git a/swh/model/tests/test_merkle.py b/swh/model/tests/test_merkle.py index 8b1180a4..dc7da63b 100644 --- a/swh/model/tests/test_merkle.py +++ b/swh/model/tests/test_merkle.py @@ -46,6 +46,14 @@ class TestMerkleLeaf(unittest.TestCase): self.data = {'value': b'value'} self.instance = MerkleTestLeaf(self.data) + def test_equality(self): + leaf1 = MerkleTestLeaf(self.data) + leaf2 = MerkleTestLeaf(self.data) + leaf3 = MerkleTestLeaf({}) + + self.assertEqual(leaf1, leaf2) + self.assertNotEqual(leaf1, leaf3) + def test_hash(self): self.assertEqual(self.instance.compute_hash_called, 0) instance_hash = self.instance.hash @@ -114,6 +122,20 @@ class TestMerkleNode(unittest.TestCase): node2[j] = node3 self.nodes[value3] = node3 + def test_equality(self): + node1 = merkle.MerkleNode({'foo': b'bar'}) + node2 = merkle.MerkleNode({'foo': b'bar'}) + node3 = merkle.MerkleNode({}) + + self.assertEqual(node1, node2) + self.assertNotEqual(node1, node3, node1 == node3) + + node1['foo'] = node3 + self.assertNotEqual(node1, node2) + + node2['foo'] = node3 + self.assertEqual(node1, node2) + def test_hash(self): for node in self.nodes.values(): self.assertEqual(node.compute_hash_called, 0) -- GitLab