From ce0d4b725742848227412a927ac728e268f3860c Mon Sep 17 00:00:00 2001
From: Antoine Lambert <anlambert@softwareheritage.org>
Date: Wed, 6 Nov 2024 11:27:10 +0100
Subject: [PATCH] api: Add an optional requests retry feature to the RPC client

Add a new enable_requests_retry flag to the swh.core.api.RPCClient
class, defaulting to False, that allows requests sent by the client
to be retried when specific errors are encountered. The default policy
is to retry when connection errors or transient remote exceptions are
raised. Subclasses can change that policy by overriding the
retry_policy method. Such failed requests will be retried at most five
times with a delay of 10 seconds between each attempt.
---
 swh/core/api/__init__.py              | 42 ++++++++++++++++++++++++++-
 swh/core/api/tests/test_rpc_client.py | 33 ++++++++++++++++++---
 2 files changed, 70 insertions(+), 5 deletions(-)

diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
index c528dfe..cc680b7 100644
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2023  The Software Heritage developers
+# Copyright (C) 2015-2024  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -26,8 +26,12 @@ from deprecated import deprecated
 from flask import Flask, Request, Response, abort, request
 import requests
 import sentry_sdk
+from tenacity.before_sleep import before_sleep_log
+from tenacity.wait import wait_fixed
 from werkzeug.exceptions import HTTPException
 
+from swh.core.retry import http_retry, retry_if_exception
+
 from .negotiation import Formatter as FormatterBase
 from .negotiation import Negotiator as NegotiatorBase
 from .negotiation import negotiate as _negotiate
@@ -41,6 +45,8 @@ from .serializers import (
     msgpack_loads,
 )
 
+RETRY_WAIT_INTERVAL = 10
+
 logger = logging.getLogger(__name__)
 
 
@@ -211,6 +217,10 @@ class RPCClient(metaclass=MetaRPCClient):
       reraise_exceptions: On server errors, if any of the exception classes in
         this list has the same name as the error name, then the exception will
         be instantiated and raised instead of a generic RemoteException.
+      enable_requests_retry: If set to :const:`True`, requests sent by the client will
+        be retried when encountering specific errors. Default policy is to retry when
+        connection errors or transient remote exceptions are raised. Subclasses can
+        change that policy by overriding the :meth:`retry_policy` method.
 
     """
 
@@ -238,6 +248,11 @@ class RPCClient(metaclass=MetaRPCClient):
     extra_type_decoders: Dict[str, Callable] = {}
     """Value of `extra_decoders` passed to `json_loads` or `msgpack_loads`
     to be able to deserialize more object types."""
+    enable_requests_retry: bool = False
+    """If set to :const:`True`, requests sent by the client will be retried
+    when encountering specific errors. Default policy is to retry when connection
+    errors or transient remote exceptions are raised. Subclasses can change that
+    policy by overriding the :meth:`retry_policy` method."""
 
     def __init__(
         self,
@@ -250,12 +265,15 @@ class RPCClient(metaclass=MetaRPCClient):
         adapter_kwargs: Optional[Dict[str, Any]] = None,
         api_exception: Optional[Type[Exception]] = None,
         reraise_exceptions: Optional[List[Type[Exception]]] = None,
+        enable_requests_retry: Optional[bool] = None,
         **kwargs,
     ):
         if api_exception:
             self.api_exception = api_exception
         if reraise_exceptions:
             self.reraise_exceptions = reraise_exceptions
+        if enable_requests_retry is not None:
+            self.enable_requests_retry = enable_requests_retry
         base_url = url if url.endswith("/") else url + "/"
         self.url = base_url
 
@@ -278,6 +296,28 @@ class RPCClient(metaclass=MetaRPCClient):
 
         self.chunk_size = chunk_size
 
+        if self.enable_requests_retry:
+
+            retry = http_retry(
+                retry=self.retry_policy,
+                wait=wait_fixed(RETRY_WAIT_INTERVAL),
+                before_sleep=before_sleep_log(logger, logging.WARNING),
+            )
+            setattr(self, "_get", retry(self._get))
+            setattr(self, "_post", retry(self._post))
+
+    def retry_policy(self, retry_state):
+        return retry_if_exception(
+            retry_state,
+            lambda e: (
+                isinstance(e, TransientRemoteException)
+                or (
+                    isinstance(e, self.api_exception)
+                    and isinstance(e.args[0], requests.exceptions.ConnectionError)
+                )
+            ),
+        )
+
     def _url(self, endpoint):
         return "%s%s" % (self.url, endpoint)
 
diff --git a/swh/core/api/tests/test_rpc_client.py b/swh/core/api/tests/test_rpc_client.py
index 6e641e6..2706d21 100644
--- a/swh/core/api/tests/test_rpc_client.py
+++ b/swh/core/api/tests/test_rpc_client.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2018-2022  The Software Heritage developers
+# Copyright (C) 2018-2024  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -9,6 +9,7 @@ import pytest
 from requests.exceptions import ConnectionError
 
 from swh.core.api import (
+    RETRY_WAIT_INTERVAL,
     APIError,
     RemoteException,
     RPCClient,
@@ -16,6 +17,7 @@ from swh.core.api import (
     remote_api_endpoint,
 )
 from swh.core.api.serializers import exception_to_dict, msgpack_dumps
+from swh.core.retry import MAX_NUMBER_ATTEMPTS
 
 from .test_serializers import ExtraType, extra_decoders, extra_encoders
 
@@ -48,6 +50,7 @@ def rpc_client_class(requests_mock):
         extra_type_encoders = extra_encoders
         extra_type_decoders = extra_decoders
         reraise_exceptions = [ReraiseException]
+        enable_requests_retry = True
 
         def overridden_method(self, data):
             return "bar"
@@ -115,10 +118,11 @@ def test_client_request_too_large(rpc_client):
     assert exc_info.value.args[1].status_code == 413
 
 
-def test_client_connexion_error(rpc_client, requests_mock):
+def test_client_connexion_error(rpc_client, requests_mock, mocker):
     """
     ConnectionError should be wrapped and raised as an APIError.
     """
+    mock_sleep = mocker.patch("time.sleep")
     error_message = "unreachable host"
     requests_mock.post(
         re.compile("mock://example.com/connection_error"),
@@ -131,6 +135,14 @@ def test_client_connexion_error(rpc_client, requests_mock):
     assert type(exc_info.value.args[0]) is ConnectionError
     assert str(exc_info.value.args[0]) == error_message
 
+    # check request retries on connection errors
+    mock_sleep.assert_has_calls(
+        [
+            mocker.call(param)
+            for param in [RETRY_WAIT_INTERVAL] * (MAX_NUMBER_ATTEMPTS - 1)
+        ]
+    )
+
 
 def _exception_response(exception, status_code):
     def callback(request, context):
@@ -144,10 +156,11 @@ def _exception_response(exception, status_code):
     return callback
 
 
-def test_client_reraise_exception(rpc_client, requests_mock):
+def test_client_reraise_exception(rpc_client, requests_mock, mocker):
     """
     Exception caught server-side and whitelisted will be raised again client-side.
     """
+    mock_sleep = mocker.patch("time.sleep")
     error_message = "something went wrong"
     endpoint = "reraise_exception"
 
@@ -163,10 +176,12 @@ def test_client_reraise_exception(rpc_client, requests_mock):
         rpc_client._post(endpoint, data={})
 
     assert str(exc_info.value) == error_message
+    # no request retry for such exception
+    mock_sleep.assert_not_called()
 
 
 @pytest.mark.parametrize("status_code", [400, 500, 502, 503])
-def test_client_raise_remote_exception(rpc_client, requests_mock, status_code):
+def test_client_raise_remote_exception(rpc_client, requests_mock, status_code, mocker):
     """
     Exception caught server-side and not whitelisted will be wrapped and raised
     as a RemoteException client-side.
@@ -181,6 +196,7 @@ def test_client_raise_remote_exception(rpc_client, requests_mock, status_code):
             status_code=status_code,
         ),
     )
+    mock_sleep = mocker.patch("time.sleep")
 
     with pytest.raises(RemoteException) as exc_info:
         rpc_client._post(endpoint, data={})
@@ -189,8 +205,17 @@ def test_client_raise_remote_exception(rpc_client, requests_mock, status_code):
     assert str(exc_info.value.args[0]["message"]) == error_message
     if status_code in (502, 503):
         assert isinstance(exc_info.value, TransientRemoteException)
+        # check request retry on transient remote exception
+        mock_sleep.assert_has_calls(
+            [
+                mocker.call(param)
+                for param in [RETRY_WAIT_INTERVAL] * (MAX_NUMBER_ATTEMPTS - 1)
+            ]
+        )
     else:
         assert not isinstance(exc_info.value, TransientRemoteException)
+        # no request retry on other remote exceptions
+        mock_sleep.assert_not_called()
 
 
 @pytest.mark.parametrize(
-- 
GitLab