Skip to content
Snippets Groups Projects
Commit 4790c2a1 authored by (@ardumont)'s avatar (@ardumont) Committed by Phabricator Migration user
Browse files

rpc_client: Allow http method declaration consistently with the rpc server

a7d1aa7b introduced that use for the server. Without doing this consistenly in
the rpc client part, this won't work though.

Related to a7d1aa7b
parent 0796ac6c
No related branches found
No related tags found
1 merge request!181api: Refactor to simplify the post/get code to a minimum
......@@ -160,11 +160,12 @@ class MetaRPCClient(type):
if backend_class:
for (meth_name, meth) in backend_class.__dict__.items():
if hasattr(meth, "_endpoint_path"):
cls.__add_endpoint(meth_name, meth, attributes)
http_method = meth._method # POST by default
cls.__add_endpoint(http_method, meth_name, meth, attributes)
return super().__new__(cls, name, bases, attributes)
@staticmethod
def __add_endpoint(meth_name, meth, attributes):
def __add_endpoint(http_method: str, meth_name: str, meth, attributes):
wrapped_meth = inspect.unwrap(meth)
@functools.wraps(meth) # Copy signature and doc
......@@ -178,7 +179,11 @@ class MetaRPCClient(type):
post_data.pop("db", None)
# Send the request.
return self.post(meth._endpoint_path, post_data)
if http_method == "POST":
return self.post(meth._endpoint_path, post_data)
else:
data = post_data or {}
return self.get(meth._endpoint_path, data=data)
if meth_name not in attributes:
attributes[meth_name] = meth_
......@@ -256,15 +261,11 @@ class RPCClient(metaclass=MetaRPCClient):
raise self.api_exception(e)
def post(self, endpoint, data, **opts):
if isinstance(data, (abc.Iterator, abc.Generator)):
data = (self._encode_data(x) for x in data)
else:
data = self._encode_data(data)
chunk_size = opts.pop("chunk_size", self.chunk_size)
response = self.raw_verb(
"post",
endpoint,
data=data,
data=self._encode_data(data),
headers={
"content-type": "application/x-msgpack",
"accept": "application/x-msgpack",
......@@ -278,15 +279,34 @@ class RPCClient(metaclass=MetaRPCClient):
return self._decode_response(response)
def _encode_data(self, data):
return encode_data(data, extra_encoders=self.extra_type_encoders)
if isinstance(data, (abc.Iterator, abc.Generator)):
data = (
encode_data(x, extra_encoders=self.extra_type_encoders) for x in data
)
else:
data = encode_data(data, extra_encoders=self.extra_type_encoders)
return data
post_stream = post
def get(self, endpoint, **opts):
def get(self, endpoint: str, data={}, **opts):
chunk_size = opts.pop("chunk_size", self.chunk_size)
response = self.raw_verb(
"get", endpoint, headers={"accept": "application/x-msgpack"}, **opts
)
if data:
response = self.raw_verb(
"get",
endpoint,
headers={
"accept": "application/x-msgpack",
"content-type": "application/x-msgpack",
},
data=self._encode_data(data),
**opts,
)
else:
response = self.raw_verb(
"get", endpoint, headers={"accept": "application/x-msgpack"}, **opts
)
if opts.get("stream") or response.headers.get("transfer-encoding") == "chunked":
self.raise_for_status(response)
return response.iter_content(chunk_size)
......@@ -440,12 +460,13 @@ class RPCServerApp(Flask):
backend_factory = backend_factory or backend_class
for (meth_name, meth) in backend_class.__dict__.items():
if hasattr(meth, "_endpoint_path"):
self.__add_endpoint(meth_name, meth, backend_factory)
http_method = meth._method # default to POST
self.__add_endpoint(http_method, meth_name, meth, backend_factory)
def __add_endpoint(self, meth_name, meth, backend_factory):
def __add_endpoint(self, http_method: str, meth_name: str, meth, backend_factory):
from flask import request
@self.route("/" + meth._endpoint_path, methods=["POST"])
@self.route(f"/{meth._endpoint_path}", methods=[http_method])
@negotiate(MsgpackFormatter, extra_encoders=self.extra_type_encoders)
@negotiate(JSONFormatter, extra_encoders=self.extra_type_encoders)
@functools.wraps(meth) # Copy signature and doc
......
# Copyright (C) 2018-2019 The Software Heritage developers
# Copyright (C) 2018-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
......@@ -27,6 +27,10 @@ def rpc_client(requests_mock):
def serializer_test(self, data, db=None, cur=None):
...
@remote_api_endpoint("another_endpoint", method="GET")
def some_get_method(self, data):
...
@remote_api_endpoint("overridden/endpoint")
def overridden_method(self, data):
return "foo"
......@@ -39,7 +43,7 @@ def rpc_client(requests_mock):
def overridden_method(self, data):
return "bar"
def callback(request, context):
def callback_post(request, context):
assert request.headers["Content-Type"] == "application/x-msgpack"
context.headers["Content-Type"] = "application/x-msgpack"
if request.path == "/test_endpoint_url":
......@@ -55,7 +59,17 @@ def rpc_client(requests_mock):
assert False
return context.content
requests_mock.post(re.compile("mock://example.com/"), content=callback)
def callback_get(request, context):
context.headers["Content-Type"] = "application/x-msgpack"
if request.path == "/another_endpoint":
context.content = b"\xc4\x0eanother-result"
else:
assert False
return context.content
requests_mock.post(re.compile("mock://example.com/"), content=callback_post)
requests_mock.get(re.compile("mock://example.com/"), content=callback_get)
return Testclient(url="mock://example.com")
......@@ -75,6 +89,9 @@ def test_client(rpc_client):
res = rpc_client.something(data="whatever")
assert res == "spam"
res = rpc_client.some_get_method(data="something")
assert res == b"another-result"
def test_client_extra_serializers(rpc_client):
res = rpc_client.serializer_test(["foo", ExtraType("bar", b"baz")])
......
......@@ -30,6 +30,10 @@ class RPCTest:
def raise_typeerror(self):
raise TypeError("Did I pass through?")
@remote_api_endpoint("stuff", method="GET")
def get_stuff(self, test_input, db=None, cur=None):
return test_input
# this class is used on the client part. We cannot inherit from RPCTest
# because the automagic metaclass based code that generates the RPCClient
......@@ -54,6 +58,10 @@ class RPCTest2:
def raise_typeerror(self):
return "data"
@remote_api_endpoint("stuff", method="GET")
def get_stuff(self, test_input, db=None, cur=None):
return test_input
class RPCTestClient(RPCClient):
backend_class = RPCTest2
......@@ -98,6 +106,11 @@ def test_api_endpoint_kwargs(swh_rpc_client):
assert res == "egg"
def test_api_endpoint_get_stuff(swh_rpc_client):
res = swh_rpc_client.get_stuff("something")
assert res == "something"
def test_api_endpoint_args(swh_rpc_client):
res = swh_rpc_client.something("whatever")
assert res == "whatever"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment