From d230cb3df3a363ab89390062027e6b55d008433a Mon Sep 17 00:00:00 2001 From: Valentin Lorentz <vlorentz@softwareheritage.org> Date: Tue, 29 Sep 2020 19:24:17 +0200 Subject: [PATCH] SortedList: Don't inherit from UserList. A class should only inherit from UserList if the type of data it presents is the same as the data in the 'data' attribute, which isn't true here. This means, for example, that SortedList.__contains__ checked if the value is in self.data, which always returns False (unless unlucky, but then it returns True while it shouldn't). By removing this inheritance, methods that are no longer implemented no longer default to a buggy implementation. --- swh/core/collections.py | 5 ++--- swh/core/tests/test_collections.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/swh/core/collections.py b/swh/core/collections.py index 92fab407..ed7b8692 100644 --- a/swh/core/collections.py +++ b/swh/core/collections.py @@ -4,7 +4,6 @@ # See top-level LICENSE file for more information import bisect -import collections import itertools from typing import Any, Callable, Generic, Iterator, List, Optional, Tuple, TypeVar @@ -12,7 +11,7 @@ SortedListItem = TypeVar("SortedListItem") SortedListKey = TypeVar("SortedListKey") -class SortedList(collections.UserList, Generic[SortedListKey, SortedListItem]): +class SortedList(Generic[SortedListKey, SortedListItem]): data: List[Tuple[SortedListKey, SortedListItem]] # https://github.com/python/mypy/issues/708 @@ -29,7 +28,7 @@ class SortedList(collections.UserList, Generic[SortedListKey, SortedListItem]): return item assert key is not None # for mypy - super().__init__(sorted((key(x), x) for x in data or [])) + self.data = sorted((key(x), x) for x in data or []) self.key: Callable[[SortedListItem], SortedListKey] = key diff --git a/swh/core/tests/test_collections.py b/swh/core/tests/test_collections.py index 22efbc06..b2b4b21a 100644 --- a/swh/core/tests/test_collections.py +++ b/swh/core/tests/test_collections.py @@ -69,3 +69,18 @@ def test_sorted_list_iter_after__key(items): for split in items: expected = reversed(sorted(item for item in items if item < split)) assert list(list_.iter_after(-split)) == list(expected), f"split: {split}" + + +@parametrize +def test_contains(items): + list_ = SortedList() + for i in range(len(items)): + for item in items[0:i]: + assert item in list_ + for item in items[i:]: + assert item not in list_ + + list_.add(items[i]) + + for item in items: + assert item in list_ -- GitLab