Skip to content
Snippets Groups Projects
Commit f9619fb1 authored by Antoine Lambert's avatar Antoine Lambert
Browse files

api/serializers: Add Exception type encoder and decoder

An exception can be constructed with another exception argument so handle
that special case by adding a dedicated encoder and decoder to avoid remote
exception serialization error.
parent e2bd9c31
No related branches found
No related tags found
1 merge request!187api/serializers: Add Exception type encoder
......@@ -316,13 +316,17 @@ class RPCClient(metaclass=MetaRPCClient):
data = self._decode_response(response, check_status=False)
if isinstance(data, dict):
for exc_type in self.reraise_exceptions:
if exc_type.__name__ == data["exception"]["type"]:
exception = exc_type(*data["exception"]["args"])
if exc_type.__name__ == data["type"]:
exception = exc_type(*data["args"])
break
else:
# old dict encoded exception schema
# TODO: Remove that code once all servers are using new schema
if "exception" in data:
exception = RemoteException(
payload=data["exception"], response=response
)
else:
exception = RemoteException(payload=data, response=response)
else:
exception = pickle.loads(data)
......@@ -330,10 +334,14 @@ class RPCClient(metaclass=MetaRPCClient):
data = self._decode_response(response, check_status=False)
if "exception_pickled" in data:
exception = pickle.loads(data["exception_pickled"])
else:
# old dict encoded exception schema
# TODO: Remove that code once all servers are using new schema
elif "exception" in data:
exception = RemoteException(
payload=data["exception"], response=response
)
else:
exception = RemoteException(payload=data, response=response)
except (TypeError, pickle.UnpicklingError):
raise RemoteException(payload=data, response=response)
......
......@@ -40,6 +40,23 @@ def _decode_paged_result(obj: Dict[str, Any]) -> PagedResult:
return PagedResult(results=obj["results"], next_page_token=obj["next_page_token"],)
def exception_to_dict(exception: Exception) -> Dict[str, Any]:
tb = traceback.format_exception(None, exception, exception.__traceback__)
exc_type = type(exception)
return {
"type": exc_type.__name__,
"module": exc_type.__module__,
"args": exception.args,
"message": str(exception),
"traceback": tb,
}
def dict_to_exception(exc_dict: Dict[str, Any]) -> Exception:
temp = __import__(exc_dict["module"], fromlist=[exc_dict["type"]])
return getattr(temp, exc_dict["type"])(*exc_dict["args"])
ENCODERS = [
(arrow.Arrow, "arrow", arrow.Arrow.isoformat),
(datetime.datetime, "datetime", encode_datetime),
......@@ -56,6 +73,7 @@ ENCODERS = [
(PagedResult, "paged_result", _encode_paged_result),
# Only for JSON:
(bytes, "bytes", lambda o: base64.b85encode(o).decode("ascii")),
(Exception, "exception", exception_to_dict),
]
DECODERS = {
......@@ -66,6 +84,7 @@ DECODERS = {
"paged_result": _decode_paged_result,
# Only for JSON:
"bytes": base64.b85decode,
"exception": dict_to_exception,
}
......@@ -279,15 +298,3 @@ def msgpack_loads(data: bytes, extra_decoders=None) -> Any:
return msgpack.unpackb(
data, encoding="utf-8", object_hook=decode_types, ext_hook=ext_hook
)
def exception_to_dict(exception):
tb = traceback.format_exception(None, exception, exception.__traceback__)
return {
"exception": {
"type": type(exception).__name__,
"args": exception.args,
"message": str(exception),
"traceback": tb,
}
}
......@@ -116,7 +116,7 @@ async def test_get_server_exception(cli) -> None:
assert resp.status == 500
data = await resp.read()
data = msgpack.unpackb(data, raw=False)
assert data["exception"]["type"] == "TestServerException"
assert data["type"] == "TestServerException"
async def test_get_client_error(cli) -> None:
......@@ -124,7 +124,7 @@ async def test_get_client_error(cli) -> None:
assert resp.status == 400
data = await resp.read()
data = msgpack.unpackb(data, raw=False)
assert data["exception"]["type"] == "TestClientError"
assert data["type"] == "TestClientError"
async def test_get_simple_nego(cli) -> None:
......
......@@ -30,6 +30,10 @@ class RPCTest:
def raise_typeerror(self):
raise TypeError("Did I pass through?")
@remote_api_endpoint("raise_exception_exc_arg")
def raise_exception_exc_arg(self):
raise Exception(Exception("error"))
# this class is used on the client part. We cannot inherit from RPCTest
# because the automagic metaclass based code that generates the RPCClient
......@@ -115,3 +119,12 @@ def test_api_typeerror(swh_rpc_client):
str(exc_info.value)
== "<RemoteException 500 TypeError: ['Did I pass through?']>"
)
def test_api_raise_exception_exc_arg(swh_rpc_client):
with pytest.raises(RemoteException) as exc_info:
swh_rpc_client.post("raise_exception_exc_arg", data={})
assert exc_info.value.args[0]["type"] == "Exception"
assert type(exc_info.value.args[0]["args"][0]) == Exception
assert str(exc_info.value.args[0]["args"][0]) == "error"
......@@ -12,6 +12,7 @@ import arrow
from arrow import Arrow
import pytest
import requests
from requests.exceptions import ConnectionError
from swh.core.api.classes import PagedResult
from swh.core.api.serializers import (
......@@ -148,6 +149,17 @@ def test_serializers_round_trip_json_extra_types():
assert actual_data == expected_original_data
def test_exception_serializer_round_trip_json():
error_message = "unreachable host"
json_data = json.dumps(
{"exception": ConnectionError(error_message)}, cls=SWHJSONEncoder
)
actual_data = json.loads(json_data, cls=SWHJSONDecoder)
assert "exception" in actual_data
assert type(actual_data["exception"]) == ConnectionError
assert str(actual_data["exception"]) == error_message
def test_serializers_encode_swh_json():
json_str = json.dumps(DATA, cls=SWHJSONEncoder)
actual_data = json.loads(json_str)
......@@ -172,6 +184,15 @@ def test_serializers_round_trip_msgpack_extra_types():
assert actual_data == original_data
def test_exception_serializer_round_trip_msgpack():
error_message = "unreachable host"
data = msgpack_dumps({"exception": ConnectionError(error_message)})
actual_data = msgpack_loads(data)
assert "exception" in actual_data
assert type(actual_data["exception"]) == ConnectionError
assert str(actual_data["exception"]) == error_message
def test_serializers_generator_json():
data = json.dumps((i for i in range(5)), cls=SWHJSONEncoder)
assert json.loads(data, cls=SWHJSONDecoder) == [i for i in range(5)]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment