diff --git a/swh/storage/proxies/masking/db.py b/swh/storage/proxies/masking/db.py index 3718b9c0ac739351898b3d3d7e04e07834d548ca..9c28698dcbc75f65329d11c939e83a8b8658aa21 100644 --- a/swh/storage/proxies/masking/db.py +++ b/swh/storage/proxies/masking/db.py @@ -6,7 +6,8 @@ import datetime import enum -from typing import Dict, List, Optional, Tuple +import itertools +from typing import Dict, Iterator, List, Optional, Tuple from uuid import UUID import attr @@ -19,6 +20,8 @@ from swh.model.swhids import ExtendedObjectType, ExtendedSWHID from swh.storage.exc import StorageArgumentException METRIC_QUERY_TOTAL = "swh_storage_masking_queried_total" +METRIC_LIST_REQUESTS_TOTAL = "swh_storage_masking_list_requests_total" +METRIC_LISTED_TOTAL = "swh_storage_masking_listed_total" METRIC_MASKED_TOTAL = "swh_storage_masking_masked_total" @@ -358,8 +361,9 @@ class MaskingQuery(MaskingDb): ) -> Dict[ExtendedSWHID, List[MaskedStatus]]: """Checks which objects in the list are masked. - Returns: For each masked object, a list of :class:`MaskedStatus` objects - where the State is not :const:`MaskedState.VISIBLE`. + Returns: + For each masked object, a list of :class:`MaskedStatus` objects + where the State is not :const:`MaskedState.VISIBLE`. """ cur = self.cursor() @@ -401,3 +405,46 @@ class MaskingQuery(MaskingDb): if ret: statsd.increment(METRIC_MASKED_TOTAL, len(ret)) return ret + + def iter_masked_swhids(self) -> Iterator[Tuple[ExtendedSWHID, List[MaskedStatus]]]: + """Returns the complete list of masked SWHIDs. + + SWHIDs are guaranteed to be unique in the iterator. + + Yields: + For each masked object, its SWHID and a list of :class:`MaskedStatus` + objects where the State is not :const:`MaskedState.VISIBLE`. + """ + + cur = self.cursor() + + statsd.increment(METRIC_LIST_REQUESTS_TOTAL, 1) + + cur.execute( + """ + SELECT object_id, object_type, request, state + FROM masked_object + WHERE state != 'visible' + ORDER BY object_id, object_type + """ + ) + + count = 0 + + for (object_id, object_type), statuses in itertools.groupby( + cur, key=lambda t: (t[0], t[1]) + ): + count += 1 + swhid = ExtendedSWHID( + object_id=object_id, + object_type=ExtendedObjectType[object_type.upper()], + ) + yield ( + swhid, + [ + MaskedStatus(request=request_id, state=MaskedState[state.upper()]) + for (_, _, request_id, state) in statuses + ], + ) + + statsd.increment(METRIC_LISTED_TOTAL, count) diff --git a/swh/storage/tests/masking/test_db.py b/swh/storage/tests/masking/test_db.py index 5cab2f4aeaadeebc7f0daaf4aaf02abcd9681add..09f95ceba2ee1ad667caa0df95919cf46c0f76c2 100644 --- a/swh/storage/tests/masking/test_db.py +++ b/swh/storage/tests/masking/test_db.py @@ -226,6 +226,7 @@ def test_swhid_lifecycle(masking_admin: MaskingAdmin, masking_query: MaskingQuer } assert masking_query.swhids_are_masked(all_swhids) == expected + assert dict(masking_query.iter_masked_swhids()) == expected restricted = masked_swhids[0:2] @@ -239,6 +240,7 @@ def test_swhid_lifecycle(masking_admin: MaskingAdmin, masking_query: MaskingQuer ] assert masking_query.swhids_are_masked(all_swhids) == expected + assert dict(masking_query.iter_masked_swhids()) == expected visible = masked_swhids[2:4] @@ -250,6 +252,7 @@ def test_swhid_lifecycle(masking_admin: MaskingAdmin, masking_query: MaskingQuer del expected[swhid] assert masking_query.swhids_are_masked(all_swhids) == expected + assert dict(masking_query.iter_masked_swhids()) == expected def test_query_metrics( @@ -278,18 +281,33 @@ def test_query_metrics( StorageData.origin2.swhid(), ] + # Query with no masked SWHIDs + assert masking_query.swhids_are_masked(all_swhids) == {} increment.assert_called_once_with( "swh_storage_masking_queried_total", len(all_swhids) ) increment.reset_mock() + assert dict(masking_query.iter_masked_swhids()) == {} + increment.assert_has_calls( + [ + call("swh_storage_masking_list_requests_total", 1), + call("swh_storage_masking_listed_total", 0), + ] + ) + increment.reset_mock() + + # Mask some SWHIDs + masking_admin.set_object_state( request_id=request.id, new_state=MaskedState.DECISION_PENDING, swhids=masked_swhids, ) + # Query again + assert len(masking_query.swhids_are_masked(all_swhids)) == len(masked_swhids) increment.assert_has_calls( [ @@ -297,3 +315,12 @@ def test_query_metrics( call("swh_storage_masking_masked_total", len(masked_swhids)), ] ) + increment.reset_mock() + + assert set(dict(masking_query.iter_masked_swhids())) == set(masked_swhids) + increment.assert_has_calls( + [ + call("swh_storage_masking_list_requests_total", 1), + call("swh_storage_masking_listed_total", len(masked_swhids)), + ] + )