Skip to content
Snippets Groups Projects
Commit e199549f authored by vlorentz's avatar vlorentz
Browse files

masking: Add method MaskingQuery.iter_masked_swhids

This will be used by swh-dataset to list all SWHIDs to mask before an export,
instead of querying the database over and over while exporting.
parent 8a1b4346
Branches iter_masked_swhids
Tags v2.3.1
No related merge requests found
......@@ -6,7 +6,8 @@
import datetime
import enum
from typing import Dict, List, Optional, Tuple
import itertools
from typing import Dict, Iterator, List, Optional, Tuple
from uuid import UUID
import attr
......@@ -19,6 +20,8 @@ from swh.model.swhids import ExtendedObjectType, ExtendedSWHID
from swh.storage.exc import StorageArgumentException
METRIC_QUERY_TOTAL = "swh_storage_masking_queried_total"
METRIC_LIST_REQUESTS_TOTAL = "swh_storage_masking_list_requests_total"
METRIC_LISTED_TOTAL = "swh_storage_masking_listed_total"
METRIC_MASKED_TOTAL = "swh_storage_masking_masked_total"
......@@ -358,8 +361,9 @@ class MaskingQuery(MaskingDb):
) -> Dict[ExtendedSWHID, List[MaskedStatus]]:
"""Checks which objects in the list are masked.
Returns: For each masked object, a list of :class:`MaskedStatus` objects
where the State is not :const:`MaskedState.VISIBLE`.
Returns:
For each masked object, a list of :class:`MaskedStatus` objects
where the State is not :const:`MaskedState.VISIBLE`.
"""
cur = self.cursor()
......@@ -401,3 +405,46 @@ class MaskingQuery(MaskingDb):
if ret:
statsd.increment(METRIC_MASKED_TOTAL, len(ret))
return ret
def iter_masked_swhids(self) -> Iterator[Tuple[ExtendedSWHID, List[MaskedStatus]]]:
"""Returns the complete list of masked SWHIDs.
SWHIDs are guaranteed to be unique in the iterator.
Yields:
For each masked object, its SWHID and a list of :class:`MaskedStatus`
objects where the State is not :const:`MaskedState.VISIBLE`.
"""
cur = self.cursor()
statsd.increment(METRIC_LIST_REQUESTS_TOTAL, 1)
cur.execute(
"""
SELECT object_id, object_type, request, state
FROM masked_object
WHERE state != 'visible'
ORDER BY object_id, object_type
"""
)
count = 0
for (object_id, object_type), statuses in itertools.groupby(
cur, key=lambda t: (t[0], t[1])
):
count += 1
swhid = ExtendedSWHID(
object_id=object_id,
object_type=ExtendedObjectType[object_type.upper()],
)
yield (
swhid,
[
MaskedStatus(request=request_id, state=MaskedState[state.upper()])
for (_, _, request_id, state) in statuses
],
)
statsd.increment(METRIC_LISTED_TOTAL, count)
......@@ -226,6 +226,7 @@ def test_swhid_lifecycle(masking_admin: MaskingAdmin, masking_query: MaskingQuer
}
assert masking_query.swhids_are_masked(all_swhids) == expected
assert dict(masking_query.iter_masked_swhids()) == expected
restricted = masked_swhids[0:2]
......@@ -239,6 +240,7 @@ def test_swhid_lifecycle(masking_admin: MaskingAdmin, masking_query: MaskingQuer
]
assert masking_query.swhids_are_masked(all_swhids) == expected
assert dict(masking_query.iter_masked_swhids()) == expected
visible = masked_swhids[2:4]
......@@ -250,6 +252,7 @@ def test_swhid_lifecycle(masking_admin: MaskingAdmin, masking_query: MaskingQuer
del expected[swhid]
assert masking_query.swhids_are_masked(all_swhids) == expected
assert dict(masking_query.iter_masked_swhids()) == expected
def test_query_metrics(
......@@ -278,18 +281,33 @@ def test_query_metrics(
StorageData.origin2.swhid(),
]
# Query with no masked SWHIDs
assert masking_query.swhids_are_masked(all_swhids) == {}
increment.assert_called_once_with(
"swh_storage_masking_queried_total", len(all_swhids)
)
increment.reset_mock()
assert dict(masking_query.iter_masked_swhids()) == {}
increment.assert_has_calls(
[
call("swh_storage_masking_list_requests_total", 1),
call("swh_storage_masking_listed_total", 0),
]
)
increment.reset_mock()
# Mask some SWHIDs
masking_admin.set_object_state(
request_id=request.id,
new_state=MaskedState.DECISION_PENDING,
swhids=masked_swhids,
)
# Query again
assert len(masking_query.swhids_are_masked(all_swhids)) == len(masked_swhids)
increment.assert_has_calls(
[
......@@ -297,3 +315,12 @@ def test_query_metrics(
call("swh_storage_masking_masked_total", len(masked_swhids)),
]
)
increment.reset_mock()
assert set(dict(masking_query.iter_masked_swhids())) == set(masked_swhids)
increment.assert_has_calls(
[
call("swh_storage_masking_list_requests_total", 1),
call("swh_storage_masking_listed_total", len(masked_swhids)),
]
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment