diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py index c528dfedc41268948d84a37a1bb1f445d13d94d8..cc680b75e029ee034743d3f7a15fdc299c38f66c 100644 --- a/swh/core/api/__init__.py +++ b/swh/core/api/__init__.py @@ -1,4 +1,4 @@ -# Copyright (C) 2015-2023 The Software Heritage developers +# Copyright (C) 2015-2024 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -26,8 +26,12 @@ from deprecated import deprecated from flask import Flask, Request, Response, abort, request import requests import sentry_sdk +from tenacity.before_sleep import before_sleep_log +from tenacity.wait import wait_fixed from werkzeug.exceptions import HTTPException +from swh.core.retry import http_retry, retry_if_exception + from .negotiation import Formatter as FormatterBase from .negotiation import Negotiator as NegotiatorBase from .negotiation import negotiate as _negotiate @@ -41,6 +45,8 @@ from .serializers import ( msgpack_loads, ) +RETRY_WAIT_INTERVAL = 10 + logger = logging.getLogger(__name__) @@ -211,6 +217,10 @@ class RPCClient(metaclass=MetaRPCClient): reraise_exceptions: On server errors, if any of the exception classes in this list has the same name as the error name, then the exception will be instantiated and raised instead of a generic RemoteException. + enable_requests_retry: If set to :const:`True`, requests sent by the client will + be retried when encountering specific errors. Default policy is to retry when + connection errors or transient remote exceptions are raised. Subclasses can + change that policy by overriding the :meth:`retry_policy` method. """ @@ -238,6 +248,11 @@ class RPCClient(metaclass=MetaRPCClient): extra_type_decoders: Dict[str, Callable] = {} """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads` to be able to deserialize more object types.""" + enable_requests_retry: bool = False + """If set to :const:`True`, requests sent by the client will be retried + when encountering specific errors. Default policy is to retry when connection + errors or transient remote exceptions are raised. Subclasses can change that + policy by overriding the :meth:`retry_policy` method.""" def __init__( self, @@ -250,12 +265,15 @@ class RPCClient(metaclass=MetaRPCClient): adapter_kwargs: Optional[Dict[str, Any]] = None, api_exception: Optional[Type[Exception]] = None, reraise_exceptions: Optional[List[Type[Exception]]] = None, + enable_requests_retry: Optional[bool] = None, **kwargs, ): if api_exception: self.api_exception = api_exception if reraise_exceptions: self.reraise_exceptions = reraise_exceptions + if enable_requests_retry is not None: + self.enable_requests_retry = enable_requests_retry base_url = url if url.endswith("/") else url + "/" self.url = base_url @@ -278,6 +296,28 @@ class RPCClient(metaclass=MetaRPCClient): self.chunk_size = chunk_size + if self.enable_requests_retry: + + retry = http_retry( + retry=self.retry_policy, + wait=wait_fixed(RETRY_WAIT_INTERVAL), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + setattr(self, "_get", retry(self._get)) + setattr(self, "_post", retry(self._post)) + + def retry_policy(self, retry_state): + return retry_if_exception( + retry_state, + lambda e: ( + isinstance(e, TransientRemoteException) + or ( + isinstance(e, self.api_exception) + and isinstance(e.args[0], requests.exceptions.ConnectionError) + ) + ), + ) + def _url(self, endpoint): return "%s%s" % (self.url, endpoint) diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py index 6e641e6f20b90775d0a1670ce3e3db69494976a9..2706d21a749e209a0bc578ee1896b4bd7edb9c7a 100644 --- a/swh/core/api/tests/test_rpc_client.py +++ b/swh/core/api/tests/test_rpc_client.py @@ -1,4 +1,4 @@ -# Copyright (C) 2018-2022 The Software Heritage developers +# Copyright (C) 2018-2024 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -9,6 +9,7 @@ import pytest from requests.exceptions import ConnectionError from swh.core.api import ( + RETRY_WAIT_INTERVAL, APIError, RemoteException, RPCClient, @@ -16,6 +17,7 @@ from swh.core.api import ( remote_api_endpoint, ) from swh.core.api.serializers import exception_to_dict, msgpack_dumps +from swh.core.retry import MAX_NUMBER_ATTEMPTS from .test_serializers import ExtraType, extra_decoders, extra_encoders @@ -48,6 +50,7 @@ def rpc_client_class(requests_mock): extra_type_encoders = extra_encoders extra_type_decoders = extra_decoders reraise_exceptions = [ReraiseException] + enable_requests_retry = True def overridden_method(self, data): return "bar" @@ -115,10 +118,11 @@ def test_client_request_too_large(rpc_client): assert exc_info.value.args[1].status_code == 413 -def test_client_connexion_error(rpc_client, requests_mock): +def test_client_connexion_error(rpc_client, requests_mock, mocker): """ ConnectionError should be wrapped and raised as an APIError. """ + mock_sleep = mocker.patch("time.sleep") error_message = "unreachable host" requests_mock.post( re.compile("mock://example.com/connection_error"), @@ -131,6 +135,14 @@ def test_client_connexion_error(rpc_client, requests_mock): assert type(exc_info.value.args[0]) is ConnectionError assert str(exc_info.value.args[0]) == error_message + # check request retries on connection errors + mock_sleep.assert_has_calls( + [ + mocker.call(param) + for param in [RETRY_WAIT_INTERVAL] * (MAX_NUMBER_ATTEMPTS - 1) + ] + ) + def _exception_response(exception, status_code): def callback(request, context): @@ -144,10 +156,11 @@ def _exception_response(exception, status_code): return callback -def test_client_reraise_exception(rpc_client, requests_mock): +def test_client_reraise_exception(rpc_client, requests_mock, mocker): """ Exception caught server-side and whitelisted will be raised again client-side. """ + mock_sleep = mocker.patch("time.sleep") error_message = "something went wrong" endpoint = "reraise_exception" @@ -163,10 +176,12 @@ def test_client_reraise_exception(rpc_client, requests_mock): rpc_client._post(endpoint, data={}) assert str(exc_info.value) == error_message + # no request retry for such exception + mock_sleep.assert_not_called() @pytest.mark.parametrize("status_code", [400, 500, 502, 503]) -def test_client_raise_remote_exception(rpc_client, requests_mock, status_code): +def test_client_raise_remote_exception(rpc_client, requests_mock, status_code, mocker): """ Exception caught server-side and not whitelisted will be wrapped and raised as a RemoteException client-side. @@ -181,6 +196,7 @@ def test_client_raise_remote_exception(rpc_client, requests_mock, status_code): status_code=status_code, ), ) + mock_sleep = mocker.patch("time.sleep") with pytest.raises(RemoteException) as exc_info: rpc_client._post(endpoint, data={}) @@ -189,8 +205,17 @@ def test_client_raise_remote_exception(rpc_client, requests_mock, status_code): assert str(exc_info.value.args[0]["message"]) == error_message if status_code in (502, 503): assert isinstance(exc_info.value, TransientRemoteException) + # check request retry on transient remote exception + mock_sleep.assert_has_calls( + [ + mocker.call(param) + for param in [RETRY_WAIT_INTERVAL] * (MAX_NUMBER_ATTEMPTS - 1) + ] + ) else: assert not isinstance(exc_info.value, TransientRemoteException) + # no request retry on other remote exceptions + mock_sleep.assert_not_called() @pytest.mark.parametrize(