From 8c988c674353ab8afb941e9e03a689028e66dc6c Mon Sep 17 00:00:00 2001 From: "(@ardumont)" <(@ardumont)> Date: Fri, 23 Oct 2020 12:06:46 +0200 Subject: [PATCH] rpc_client: Allow http method declaration consistently with the rpc server a7d1aa7 introduced that use for the server. Without doing this consistenly in the rpc client part, this won't work though. Related to a7d1aa7 --- swh/core/api/__init__.py | 53 ++++++++++++++------ swh/core/api/tests/test_rpc_client.py | 23 +++++++-- swh/core/api/tests/test_rpc_client_server.py | 13 +++++ 3 files changed, 70 insertions(+), 19 deletions(-) diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index 8ffc3b0a..e749994e 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -160,11 +160,12 @@ class MetaRPCClient(type): if backend_class: for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): - cls.__add_endpoint(meth_name, meth, attributes) + http_method = meth._method # POST by default + cls.__add_endpoint(http_method, meth_name, meth, attributes) return super().__new__(cls, name, bases, attributes) @staticmethod - def __add_endpoint(meth_name, meth, attributes): + def __add_endpoint(http_method: str, meth_name: str, meth, attributes): wrapped_meth = inspect.unwrap(meth) @functools.wraps(meth) # Copy signature and doc @@ -178,7 +179,11 @@ class MetaRPCClient(type): post_data.pop("db", None) # Send the request. - return self.post(meth._endpoint_path, post_data) + if http_method == "POST": + return self.post(meth._endpoint_path, post_data) + else: + data = post_data or {} + return self.get(meth._endpoint_path, data=data) if meth_name not in attributes: attributes[meth_name] = meth_ @@ -256,15 +261,11 @@ class RPCClient(metaclass=MetaRPCClient): raise self.api_exception(e) def post(self, endpoint, data, **opts): - if isinstance(data, (abc.Iterator, abc.Generator)): - data = (self._encode_data(x) for x in data) - else: - data = self._encode_data(data) chunk_size = opts.pop("chunk_size", self.chunk_size) response = self.raw_verb( "post", endpoint, - data=data, + data=self._encode_data(data), headers={ "content-type": "application/x-msgpack", "accept": "application/x-msgpack", @@ -278,15 +279,34 @@ class RPCClient(metaclass=MetaRPCClient): return self._decode_response(response) def _encode_data(self, data): - return encode_data(data, extra_encoders=self.extra_type_encoders) + if isinstance(data, (abc.Iterator, abc.Generator)): + data = ( + encode_data(x, extra_encoders=self.extra_type_encoders) for x in data + ) + else: + data = encode_data(data, extra_encoders=self.extra_type_encoders) + return data post_stream = post - def get(self, endpoint, **opts): + def get(self, endpoint: str, data={}, **opts): chunk_size = opts.pop("chunk_size", self.chunk_size) - response = self.raw_verb( - "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts - ) + if data: + response = self.raw_verb( + "get", + endpoint, + headers={ + "accept": "application/x-msgpack", + "content-type": "application/x-msgpack", + }, + data=self._encode_data(data), + **opts, + ) + else: + response = self.raw_verb( + "get", endpoint, headers={"accept": "application/x-msgpack"}, **opts + ) + if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked": self.raise_for_status(response) return response.iter_content(chunk_size) @@ -440,12 +460,13 @@ class RPCServerApp(Flask): backend_factory = backend_factory or backend_class for (meth_name, meth) in backend_class.__dict__.items(): if hasattr(meth, "_endpoint_path"): - self.__add_endpoint(meth_name, meth, backend_factory) + http_method = meth._method # default to POST + self.__add_endpoint(http_method, meth_name, meth, backend_factory) - def __add_endpoint(self, meth_name, meth, backend_factory): + def __add_endpoint(self, http_method: str, meth_name: str, meth, backend_factory): from flask import request - @self.route("/" + meth._endpoint_path, methods=["POST"]) + @self.route(f"/{meth._endpoint_path}", methods=[http_method]) @negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders) @negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders) @functools.wraps(meth) # Copy signature and doc diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py index 7dffac19..a5660d70 100644 --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,4 +1,4 @@ -# Copyright (C) 2018-2019 The Software Heritage developers +# Copyright (C) 2018-2020 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -27,6 +27,10 @@ def rpc_client(requests_mock): def serializer_test(self, data, db=None, cur=None): ... + @remote_api_endpoint("another_endpoint", method="GET") + def some_get_method(self, data): + ... + @remote_api_endpoint("overridden/endpoint") def overridden_method(self, data): return "foo" @@ -39,7 +43,7 @@ def rpc_client(requests_mock): def overridden_method(self, data): return "bar" - def callback(request, context): + def callback_post(request, context): assert request.headers["Content-Type"] == "application/x-msgpack" context.headers["Content-Type"] = "application/x-msgpack" if request.path == "/test_endpoint_url": @@ -55,7 +59,17 @@ def rpc_client(requests_mock): assert False return context.content - requests_mock.post(re.compile("mock://example.com/"), content=callback) + def callback_get(request, context): + context.headers["Content-Type"] = "application/x-msgpack" + + if request.path == "/another_endpoint": + context.content = b"\xc4\x0eanother-result" + else: + assert False + return context.content + + requests_mock.post(re.compile("mock://example.com/"), content=callback_post) + requests_mock.get(re.compile("mock://example.com/"), content=callback_get) return Testclient(url="mock://example.com") @@ -75,6 +89,9 @@ def test_client(rpc_client): res = rpc_client.something(data="whatever") assert res == "spam" + res = rpc_client.some_get_method(data="something") + assert res == b"another-result" + def test_client_extra_serializers(rpc_client): res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")]) diff --git a/swh/core/api/tests/test_rpc_client_server.py b/swh/core/api/tests/test_rpc_client_server.py index 81b0afa5..e6046be4 100644 --- a/swh/core/api/tests/test_rpc_client_server.py +++ b/swh/core/api/tests/test_rpc_client_server.py @@ -30,6 +30,10 @@ class RPCTest: def raise_typeerror(self): raise TypeError("Did I pass through?") + @remote_api_endpoint("stuff", method="GET") + def get_stuff(self, test_input, db=None, cur=None): + return test_input + # this class is used on the client part. We cannot inherit from RPCTest # because the automagic metaclass based code that generates the RPCClient @@ -54,6 +58,10 @@ class RPCTest2: def raise_typeerror(self): return "data" + @remote_api_endpoint("stuff", method="GET") + def get_stuff(self, test_input, db=None, cur=None): + return test_input + class RPCTestClient(RPCClient): backend_class = RPCTest2 @@ -98,6 +106,11 @@ def test_api_endpoint_kwargs(swh_rpc_client): assert res == "egg" +def test_api_endpoint_get_stuff(swh_rpc_client): + res = swh_rpc_client.get_stuff("something") + assert res == "something" + + def test_api_endpoint_args(swh_rpc_client): res = swh_rpc_client.something("whatever") assert res == "whatever" -- GitLab