Skip to content
Snippets Groups Projects
Commit 366680cd authored by Nicolas Dandrimont's avatar Nicolas Dandrimont
Browse files

RPC server: ensure MaskedObjectExceptions can be encoded

Our error handlers were missing an application of the extra encoders, so
SWHIDs weren't encodable as exception arguments, and the new
MaskedStatuses weren't supported by our encoders.
parent 7f8c52e9
No related branches found
No related tags found
1 merge request!1121RPC server: ensure MaskedObjectExceptions can be encoded
......@@ -6,9 +6,11 @@
"""Decoder and encoders for swh-model objects."""
from typing import Any, Callable, Dict, List, Tuple
import uuid
from swh.model import model, swhids
from swh.storage import interface
from swh.storage.proxies.masking.db import MaskedState, MaskedStatus
def _encode_model_object(obj):
......@@ -48,6 +50,17 @@ def _encode_snapshot_branch_by_name_response(
}
def _encode_masked_status(masked_status: MaskedStatus):
return {
"state": masked_status.state.name,
"request": str(masked_status.request),
}
def _decode_masked_status(d: Dict[str, Any]):
return MaskedStatus(state=MaskedState[d["state"]], request=uuid.UUID(d["request"]))
def _decode_origin_visit_with_statuses(
ovws: Dict[str, Any],
) -> interface.OriginVisitWithStatuses:
......@@ -109,6 +122,7 @@ ENCODERS: List[Tuple[type, str, Callable]] = [
"branch_by_name_response",
_encode_snapshot_branch_by_name_response,
),
(MaskedStatus, "masked_status", _encode_masked_status),
]
......@@ -124,4 +138,5 @@ DECODERS: Dict[str, Callable] = {
"origin_visit_with_statuses": _decode_origin_visit_with_statuses,
"object_reference": _decode_object_reference,
"branch_by_name_response": _decode_snapshot_branch_by_name_response,
"masked_status": _decode_masked_status,
}
......@@ -3,6 +3,7 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from functools import partial
import logging
import os
from typing import Any, Dict, Optional
......@@ -97,7 +98,9 @@ storage = None
def non_retryable_error_handler(exception):
"""Send all non-retryable errors with a 400 status code so the client can
re-raise them."""
return error_handler(exception, encode_data, status_code=400)
return error_handler(
exception, partial(encode_data, extra_type_encoders=ENCODERS), status_code=400
)
app.setup_psycopg2_errorhandlers()
......
......@@ -3,13 +3,18 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from uuid import uuid4
import psycopg2.errors
import pytest
from swh.core.api import RemoteException, TransientRemoteException
from swh.model.swhids import ExtendedSWHID
import swh.storage
from swh.storage import get_storage
import swh.storage.api.server as server
from swh.storage.exc import MaskedObjectException
from swh.storage.proxies.masking.db import MaskedState, MaskedStatus
from swh.storage.tests.storage_tests import (
TestStorageGeneratedData as _TestStorageGeneratedData,
)
......@@ -130,6 +135,23 @@ class TestStorageApi(_TestStorage):
swh_storage.revision_get(["\x01" * 20])
assert not isinstance(excinfo.value, TransientRemoteException)
def test_masked_object_exception(self, app_server, swh_storage, mocker):
"""Checks the client re-raises masking proxy exceptions"""
assert swh_storage.revision_get(["\x01" * 20]) == [None]
masked = {
ExtendedSWHID.from_string("swh:1:rev:" + ("01" * 20)): [
MaskedStatus(MaskedState.DECISION_PENDING, request=uuid4())
]
}
mocker.patch.object(
app_server.storage._cql_runner,
"revision_get",
side_effect=MaskedObjectException(masked),
)
with pytest.raises(MaskedObjectException) as e:
swh_storage.revision_get(["\x01" * 20])
assert e.value.masked == masked
class TestStorageApiGeneratedData(_TestStorageGeneratedData):
@pytest.mark.skip("Not supported by Cassandra")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment