Skip to content
Snippets Groups Projects
Commit c6b2b577 authored by Nicolas Dandrimont's avatar Nicolas Dandrimont
Browse files

dumb loader: add support for extra requests kwargs

This is useful to override the default settings of the requests Session,
e.g. certificate verification of connect/read timeouts.
parent f51d542f
No related branches found
No related tags found
1 merge request!177Add support for passing extra arguments to urllib3 and requests (including timeouts and certificate verification)
......@@ -6,11 +6,12 @@
from __future__ import annotations
from collections import defaultdict
import copy
import logging
import stat
import struct
from tempfile import SpooledTemporaryFile
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Set, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Set, cast
import urllib.parse
from dulwich.errors import NotGitRepository
......@@ -28,17 +29,26 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
HEADERS = {"User-Agent": "Software Heritage dumb Git loader"}
def requests_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""Inject User-Agent header in the requests kwargs"""
ret = copy.deepcopy(kwargs)
ret.setdefault("headers", {}).update(
{"User-Agent": "Software Heritage dumb Git loader"}
)
ret.setdefault("timeout", (120, 60))
return ret
@http_retry(
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def check_protocol(repo_url: str) -> bool:
def check_protocol(repo_url: str, requests_extra_kwargs: Dict[str, Any] = {}) -> bool:
"""Checks if a git repository can be cloned using the dumb protocol.
Args:
repo_url: Base URL of a git repository
requests_extra_kwargs: extra keyword arguments to be passed to requests,
e.g. `timeout`, `verify`.
Returns:
Whether the dumb protocol is supported.
......@@ -50,7 +60,7 @@ def check_protocol(repo_url: str) -> bool:
repo_url.rstrip("/") + "/", "info/refs?service=git-upload-pack/"
)
logger.debug("Fetching %s", url)
response = requests.get(url, headers=HEADERS)
response = requests.get(url, **requests_kwargs(requests_extra_kwargs))
response.raise_for_status()
content_type = response.headers.get("Content-Type")
return (
......@@ -73,10 +83,18 @@ class GitObjectsFetcher:
Args:
repo_url: Base URL of a git repository
base_repo: State of repository archived by Software Heritage
requests_extra_kwargs: extra keyword arguments to be passed to requests,
e.g. `timeout`, `verify`.
"""
def __init__(self, repo_url: str, base_repo: RepoRepresentation):
def __init__(
self,
repo_url: str,
base_repo: RepoRepresentation,
requests_extra_kwargs: Dict[str, Any] = {},
):
self._session = requests.Session()
self.requests_extra_kwargs = requests_extra_kwargs
self.repo_url = repo_url
self.base_repo = base_repo
self.objects: Dict[bytes, Set[bytes]] = defaultdict(set)
......@@ -130,7 +148,7 @@ class GitObjectsFetcher:
def _http_get(self, path: str) -> SpooledTemporaryFile:
url = urllib.parse.urljoin(self.repo_url.rstrip("/") + "/", path)
logger.debug("Fetching %s", url)
response = self._session.get(url, headers=HEADERS)
response = self._session.get(url, **requests_kwargs(self.requests_extra_kwargs))
response.raise_for_status()
buffer = SpooledTemporaryFile(max_size=100 * 1024 * 1024)
for chunk in response.iter_content(chunk_size=10 * 1024 * 1024):
......
......@@ -180,6 +180,7 @@ class GitLoader(BaseGitLoader):
pack_size_bytes: int = 4 * 1024 * 1024 * 1024,
temp_file_cutoff: int = 100 * 1024 * 1024,
urllib3_extra_kwargs: Dict[str, Any] = {},
requests_extra_kwargs: Dict[str, Any] = {},
**kwargs: Any,
):
"""Initialize the bulk updater.
......@@ -206,6 +207,7 @@ class GitLoader(BaseGitLoader):
self.ext_refs: Dict[bytes, Optional[Tuple[int, bytes]]] = {}
self.repo_pack_size_bytes = 0
self.urllib3_extra_kwargs = urllib3_extra_kwargs
self.requests_extra_kwargs = requests_extra_kwargs
def fetch_pack_from_origin(
self,
......@@ -371,7 +373,7 @@ class GitLoader(BaseGitLoader):
# by the fetch_pack operation when encountering a repository with
# dumb transfer protocol so we check if the repository supports it
# here to continue the loading if it is the case
self.dumb = dumb.check_protocol(self.origin.url)
self.dumb = dumb.check_protocol(self.origin.url, self.requests_extra_kwargs)
if not self.dumb:
raise
......@@ -379,7 +381,11 @@ class GitLoader(BaseGitLoader):
"Protocol used for communication: %s", "dumb" if self.dumb else "smart"
)
if self.dumb:
self.dumb_fetcher = dumb.GitObjectsFetcher(self.origin.url, base_repo)
self.dumb_fetcher = dumb.GitObjectsFetcher(
self.origin.url,
base_repo,
requests_extra_kwargs=self.requests_extra_kwargs,
)
self.dumb_fetcher.fetch_object_ids()
self.remote_refs = utils.filter_refs(self.dumb_fetcher.refs)
self.symbolic_refs = utils.filter_refs(self.dumb_fetcher.head)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment