diff --git a/swh/model/tests/swh_model_data.py b/swh/model/tests/swh_model_data.py index 2e0f1529620583e2bf58cac20879323e1ad66773..5c4fcdda9908f0200fb2558f4242abce2af9a909 100644 --- a/swh/model/tests/swh_model_data.py +++ b/swh/model/tests/swh_model_data.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019-2021 The Software Heritage developers +# Copyright (C) 2019-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -18,7 +18,6 @@ from swh.model.model import ( MetadataAuthority, MetadataAuthorityType, MetadataFetcher, - ModelObjectType, ObjectType, Origin, OriginVisit, @@ -449,7 +448,7 @@ EXTIDS: List[ExtID] = [ ), ] -TEST_OBJECTS: Dict[ModelObjectType, Sequence[BaseModel]] = {} +TEST_OBJECTS: Dict[str, Sequence[BaseModel]] = {} # generate this mapping with code to avoid error for objects in [ CONTENTS, @@ -467,8 +466,8 @@ for objects in [ SKIPPED_CONTENTS, ]: objects = cast(List[BaseModel], objects) - object_type = objects[0].object_type - assert all(object_type == o.object_type for o in objects) + object_type = objects[0].object_type.value + assert all(object_type == o.object_type.value for o in objects) assert object_type not in TEST_OBJECTS TEST_OBJECTS[object_type] = objects diff --git a/swh/model/tests/test_model.py b/swh/model/tests/test_model.py index a50b3369fd0c5550513d8776c5de2014472a5a64..cf91ec3da442d5cd1b117c7b37788b4eb5efaafe 100644 --- a/swh/model/tests/test_model.py +++ b/swh/model/tests/test_model.py @@ -1,4 +1,4 @@ -# Copyright (C) 2019-2020 The Software Heritage developers +# Copyright (C) 2019-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -327,8 +327,8 @@ def test_optimized_type_validator_invalid(type_, value): validator(None, attr.ib(type=type_), value) -@pytest.mark.parametrize("object_type, objects", TEST_OBJECTS.items()) -def test_swh_model_todict_fromdict(object_type, objects): +@pytest.mark.parametrize("objects", TEST_OBJECTS.values()) +def test_swh_model_todict_fromdict(objects): """checks model objects in swh_model_data are in correct shape""" assert objects for obj in objects: diff --git a/swh/model/tests/test_swh_model_data.py b/swh/model/tests/test_swh_model_data.py index 3815c487540d0d01ea4c24952e598b6bea53641c..2877827f64ec113771f902d4776256d30d11b308 100644 --- a/swh/model/tests/test_swh_model_data.py +++ b/swh/model/tests/test_swh_model_data.py @@ -1,4 +1,4 @@ -# Copyright (C) 2021 The Software Heritage developers +# Copyright (C) 2021-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -6,7 +6,14 @@ import attr import pytest -from swh.model.model import ModelObjectType +from swh.model.model import ( + Directory, + OriginVisit, + OriginVisitStatus, + Release, + Revision, + Snapshot, +) from swh.model.tests.swh_model_data import TEST_OBJECTS @@ -15,17 +22,17 @@ def test_swh_model_data(object_type, objects): """checks model objects in swh_model_data are in correct shape""" assert objects for obj in objects: - assert obj.object_type == object_type + assert obj.object_type.value == object_type attr.validate(obj) @pytest.mark.parametrize( "object_type", ( - ModelObjectType.DIRECTORY, - ModelObjectType.REVISION, - ModelObjectType.RELEASE, - ModelObjectType.SNAPSHOT, + Directory.object_type.value, + Revision.object_type.value, + Release.object_type.value, + Snapshot.object_type.value, ), ) def test_swh_model_data_hash(object_type): @@ -44,8 +51,8 @@ def test_ensure_visit_status_date_consistency(): parameters from the origin-visit {origin, visit, date}... """ - visits = TEST_OBJECTS[ModelObjectType.ORIGIN_VISIT] - visit_statuses = TEST_OBJECTS[ModelObjectType.ORIGIN_VISIT_STATUS] + visits = TEST_OBJECTS[OriginVisit.object_type.value] + visit_statuses = TEST_OBJECTS[OriginVisitStatus.object_type.value] for visit, visit_status in zip(visits, visit_statuses): assert visit.origin == visit_status.origin assert visit.visit == visit_status.visit @@ -54,7 +61,7 @@ def test_ensure_visit_status_date_consistency(): def test_ensure_visit_status_snapshot_consistency(): """ensure origin-visit-status snapshots exist in the test dataset""" - snapshots = [snp.id for snp in TEST_OBJECTS[ModelObjectType.SNAPSHOT]] - for visit_status in TEST_OBJECTS[ModelObjectType.ORIGIN_VISIT_STATUS]: + snapshots = [snp.id for snp in TEST_OBJECTS[Snapshot.object_type.value]] + for visit_status in TEST_OBJECTS[OriginVisitStatus.object_type.value]: if visit_status.snapshot: assert visit_status.snapshot in snapshots