Skip to content
Snippets Groups Projects
Commit 9b3e3aa6 authored by vlorentz's avatar vlorentz
Browse files

Add support for extra {de,en}coders.

This will allow swh-storage to pass model objects.
parent 6c5e89a4
No related branches found
No related tags found
No related merge requests found
......@@ -42,21 +42,23 @@ DECODERS = {
}
def encode_data_client(data: Any) -> bytes:
def encode_data_client(data: Any, extra_encoders=None) -> bytes:
try:
return msgpack_dumps(data)
return msgpack_dumps(data, extra_encoders=extra_encoders)
except OverflowError as e:
raise ValueError('Limits were reached. Please, check your input.\n' +
str(e))
def decode_response(response: Response) -> Any:
def decode_response(response: Response, extra_decoders=None) -> Any:
content_type = response.headers['content-type']
if content_type.startswith('application/x-msgpack'):
r = msgpack_loads(response.content)
r = msgpack_loads(response.content,
extra_decoders=extra_decoders)
elif content_type.startswith('application/json'):
r = json.loads(response.text, cls=SWHJSONDecoder)
r = json.loads(response.text, cls=SWHJSONDecoder,
extra_decoders=extra_decoders)
elif content_type.startswith('text/'):
r = response.text
else:
......@@ -90,9 +92,15 @@ class SWHJSONEncoder(json.JSONEncoder):
"""
def __init__(self, extra_encoders=None, **kwargs):
super().__init__(**kwargs)
self.encoders = ENCODERS
if extra_encoders:
self.encoders += extra_encoders
def default(self, o: Any
) -> Union[Dict[str, Union[Dict[str, int], str]], list]:
for (type_, type_name, encoder) in ENCODERS:
for (type_, type_name, encoder) in self.encoders:
if isinstance(o, type_):
return {
'swhtype': type_name,
......@@ -129,12 +137,18 @@ class SWHJSONDecoder(json.JSONDecoder):
"""
def __init__(self, extra_decoders=None, **kwargs):
super().__init__(**kwargs)
self.decoders = DECODERS
if extra_decoders:
self.decoders = {**self.decoders, **extra_decoders}
def decode_data(self, o: Any) -> Any:
if isinstance(o, dict):
if set(o.keys()) == {'d', 'swhtype'}:
if o['swhtype'] == 'bytes':
return base64.b85decode(o['d'])
decoder = DECODERS.get(o['swhtype'])
decoder = self.decoders.get(o['swhtype'])
if decoder:
return decoder(self.decode_data(o['d']))
return {key: self.decode_data(value) for key, value in o.items()}
......@@ -148,13 +162,17 @@ class SWHJSONDecoder(json.JSONDecoder):
return self.decode_data(data), index
def msgpack_dumps(data: Any) -> bytes:
def msgpack_dumps(data: Any, extra_encoders=None) -> bytes:
"""Write data as a msgpack stream"""
encoders = ENCODERS
if extra_encoders:
encoders += extra_encoders
def encode_types(obj):
if isinstance(obj, types.GeneratorType):
return list(obj)
for (type_, type_name, encoder) in ENCODERS:
for (type_, type_name, encoder) in encoders:
if isinstance(obj, type_):
return {
b'swhtype': type_name,
......@@ -165,11 +183,15 @@ def msgpack_dumps(data: Any) -> bytes:
return msgpack.packb(data, use_bin_type=True, default=encode_types)
def msgpack_loads(data: bytes) -> Any:
def msgpack_loads(data: bytes, extra_decoders=None) -> Any:
"""Read data as a msgpack stream"""
decoders = DECODERS
if extra_decoders:
decoders = {**decoders, **extra_decoders}
def decode_types(obj):
if set(obj.keys()) == {b'd', b'swhtype'}:
decoder = DECODERS.get(obj[b'swhtype'])
decoder = decoders.get(obj[b'swhtype'])
if decoder:
return decoder(obj[b'd'])
return obj
......
......@@ -21,6 +21,28 @@ from swh.core.api.serializers import (
)
class ExtraType:
def __init__(self, arg1, arg2):
self.arg1 = arg1
self.arg2 = arg2
def __repr__(self):
return f'ExtraType({self.arg1}, {self.arg2})'
def __eq__(self, other):
return (self.arg1, self.arg2) == (other.arg1, other.arg2)
extra_encoders = [
(ExtraType, 'extratype', lambda o: (o.arg1, o.arg2))
]
extra_decoders = {
'extratype': lambda o: ExtraType(*o),
}
class Serializers(unittest.TestCase):
def setUp(self):
self.tz = datetime.timezone(datetime.timedelta(minutes=118))
......@@ -67,6 +89,16 @@ class Serializers(unittest.TestCase):
data = json.dumps(self.data, cls=SWHJSONEncoder)
self.assertEqual(self.data, json.loads(data, cls=SWHJSONDecoder))
def test_round_trip_json_extra_types(self):
original_data = [ExtraType('baz', self.data), 'qux']
data = json.dumps(original_data, cls=SWHJSONEncoder,
extra_encoders=extra_encoders)
self.assertEqual(
original_data,
json.loads(
data, cls=SWHJSONDecoder, extra_decoders=extra_decoders))
def test_encode_swh_json(self):
data = json.dumps(self.data, cls=SWHJSONEncoder)
self.assertEqual(self.encoded_data, json.loads(data))
......@@ -75,6 +107,13 @@ class Serializers(unittest.TestCase):
data = msgpack_dumps(self.data)
self.assertEqual(self.data, msgpack_loads(data))
def test_round_trip_msgpack_extra_types(self):
original_data = [ExtraType('baz', self.data), 'qux']
data = msgpack_dumps(original_data, extra_encoders=extra_encoders)
self.assertEqual(
original_data, msgpack_loads(data, extra_decoders=extra_decoders))
def test_generator_json(self):
data = json.dumps(self.generator, cls=SWHJSONEncoder)
self.assertEqual(self.gen_lst, json.loads(data, cls=SWHJSONDecoder))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment