diff --git a/swh/model/collections.py b/swh/model/collections.py index 2724f85c59c38ab664e696f5a644674fe2af4f44..495b43c5c3c79a17bdf8d5611869c58b9f42203c 100644 --- a/swh/model/collections.py +++ b/swh/model/collections.py @@ -13,9 +13,16 @@ VT = TypeVar("VT") class ImmutableDict(Mapping, Generic[KT, VT]): data: Tuple[Tuple[KT, VT], ...] - def __init__(self, data: Union[Iterable[Tuple[KT, VT]], Dict[KT, VT]] = {}): + def __init__( + self, + data: Union[ + Iterable[Tuple[KT, VT]], "ImmutableDict[KT, VT]", Dict[KT, VT] + ] = {}, + ): if isinstance(data, dict): self.data = tuple(item for item in data.items()) + elif isinstance(data, ImmutableDict): + self.data = data.data else: self.data = tuple(data) diff --git a/swh/model/tests/test_collections.py b/swh/model/tests/test_collections.py index c7b44cb1f8c2863a70d36f7a108268129e256dea..b042c5929ec67f5481dcb72577e9a18f763aef4c 100644 --- a/swh/model/tests/test_collections.py +++ b/swh/model/tests/test_collections.py @@ -32,6 +32,22 @@ def test_immutabledict_one_item(): assert list(d.items()) == [("foo", "bar")] +def test_immutabledict_from_iterable(): + d1 = ImmutableDict() + d2 = ImmutableDict({"foo": "bar"}) + + assert ImmutableDict([]) == d1 + assert ImmutableDict([("foo", "bar")]) == d2 + + +def test_immutabledict_from_immutabledict(): + d1 = ImmutableDict() + d2 = ImmutableDict({"foo": "bar"}) + + assert ImmutableDict(d1) == d1 + assert ImmutableDict(d2) == d2 + + def test_immutabledict_immutable(): d = ImmutableDict({"foo": "bar"})