diff --git a/swh/web/api/apiresponse.py b/swh/web/api/apiresponse.py index ff57ca6d164791427ced8fe96ad5a61f9e5b0895..c7e232074587ece9e9c8d8f471ccd1403c2e390e 100644 --- a/swh/web/api/apiresponse.py +++ b/swh/web/api/apiresponse.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2022 The Software Heritage developers +# Copyright (C) 2017-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -13,7 +13,7 @@ from django.shortcuts import render from django.urls import get_resolver from django.utils.cache import add_never_cache_headers from django.utils.html import escape -from rest_framework.exceptions import APIException +from rest_framework.exceptions import APIException, ValidationError from rest_framework.request import Request from rest_framework.response import Response from rest_framework.utils.encoders import JSONEncoder @@ -200,6 +200,7 @@ def error_response( doc_data: documentation data for HTML response """ + error_data: dict[str, Any] error_data = { "exception": exception.__class__.__name__, "reason": str(exception), @@ -212,6 +213,9 @@ def error_response( error_code = 500 if isinstance(exception, BadInputExc): error_code = 400 + elif isinstance(exception, ValidationError): + error_code = 400 + error_data["reason"] = exception.detail elif isinstance(exception, UnauthorizedExc): error_code = 401 elif isinstance(exception, NotFoundExc): diff --git a/swh/web/api/apiurls.py b/swh/web/api/apiurls.py index 89e33d5ce0dd7b73fbb18fbd75b5b3397b6ec20b..7ade53cc746b38700574dbb08b4574a3191d12a1 100644 --- a/swh/web/api/apiurls.py +++ b/swh/web/api/apiurls.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2024 The Software Heritage developers +# Copyright (C) 2017-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information @@ -7,6 +7,7 @@ import functools from typing import Dict, List, Literal, Optional from django.http.response import HttpResponseBase +from rest_framework import serializers from rest_framework.decorators import api_view from swh.web.api import throttling @@ -73,6 +74,7 @@ def api_route( checksum_args: Optional[List[str]] = None, never_cache: bool = False, api_urls: APIUrls = api_urls, + query_params_serializer: Optional[type[serializers.Serializer]] = None, ): """ Decorator to ease the registration of an API endpoint @@ -81,12 +83,14 @@ def api_route( Args: url_pattern: the url pattern used by DRF to identify the API route view_name: the name of the API view associated to the route used to - reverse the url + reverse the url methods: array of HTTP methods supported by the API route throttle_scope: Named scope for rate limiting api_version: web API version checksum_args: list of view argument names holding checksum values never_cache: define if api response must be cached + query_params_serializer: an optional DRF serializer to validate the API call + query parameters """ @@ -97,9 +101,28 @@ def api_route( @api_view(methods) @throttling.throttle_scope(throttle_scope) @functools.wraps(f) - def api_view_f(request, **kwargs): + def api_view_f(request, **kwargs) -> HttpResponseBase: + """Creates and calls a DRF view from the wrapped function. + + If `query_params_serializer` is set an extra kwarg `validated_query_params` + will be added to the view call containing the serializer's `validated_data`. + + Raises: + django.core.exceptions.ValidationError: query_params_serializer was + unable to validate the query parameters + + Returns: + An HttpResponse + """ # never_cache will be handled in apiresponse module request.never_cache = never_cache + + # a DRF serializer has been passed to validate the query parameters + if query_params_serializer: + serializer = query_params_serializer(data=request.query_params.dict()) + serializer.is_valid(raise_exception=True) + kwargs["validated_query_params"] = serializer.validated_data + response = f(request, **kwargs) doc_data = None # check if response has been forwarded by api_doc decorator @@ -107,6 +130,7 @@ def api_route( doc_data = response["doc_data"] response = response["data"] # check if HTTP response needs to be created + api_response: HttpResponseBase if not isinstance(response, HttpResponseBase): api_response = make_api_response( request, data=response, doc_data=doc_data diff --git a/swh/web/api/tests/views/test_snapshot.py b/swh/web/api/tests/views/test_snapshot.py index 49b4961b74a77cf5322bc887e9d5fe8f3f578879..b0776ddadd0b8930836e6b236fe6d0f7c05b2b7b 100644 --- a/swh/web/api/tests/views/test_snapshot.py +++ b/swh/web/api/tests/views/test_snapshot.py @@ -1,15 +1,17 @@ -# Copyright (C) 2018-2021 The Software Heritage developers +# Copyright (C) 2018-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information - +from http import HTTPStatus import random from hypothesis import given +import pytest from swh.model.hashutil import hash_to_hex -from swh.model.model import Snapshot +from swh.model.model import Snapshot, TargetType from swh.web.api.utils import enrich_snapshot +from swh.web.api.views.snapshot import SnapshotQuerySerializer from swh.web.tests.data import random_sha1 from swh.web.tests.helpers import check_api_get_responses, check_http_get_response from swh.web.tests.strategies import new_snapshot @@ -154,3 +156,60 @@ def test_api_snapshot_no_pull_request_branches_filtering( url = reverse("api-1-snapshot", url_args={"snapshot_id": snapshot["id"]}) resp = check_api_get_responses(api_client, url, status_code=200) assert any([b.startswith("refs/pull/") for b in resp.data["branches"]]) + + +@pytest.mark.parametrize( + "query_params,expected_error_key", + [ + ( + {"target_types": "content,an_invalid_target_type"}, + "target_types", + ), + ( + {"branches_count": "500/"}, + "branches_count", + ), + ], +) +def test_api_snapshot_invalid_params( + api_client, archive_data, snapshot, query_params, expected_error_key +): + url = reverse( + "api-1-snapshot", + url_args={"snapshot_id": snapshot}, + query_params=query_params, + ) + rv = check_api_get_responses(api_client, url, status_code=HTTPStatus.BAD_REQUEST) + assert rv.data["exception"] == "ValidationError" + assert expected_error_key in rv.data["reason"] + + +VALID_TARGET_TYPES = [e.value for e in TargetType] + + +@pytest.mark.parametrize( + "target_types,expected", + [ + ( + ",".join(VALID_TARGET_TYPES), + VALID_TARGET_TYPES, + ), + ( + " content, directory", + ["content", "directory"], + ), + ( + "alias", + ["alias"], + ), + ], +) +def test_api_snapshot_serializer_valid_target_types(target_types, expected): + serializer = SnapshotQuerySerializer(data={"target_types": target_types}) + assert serializer.is_valid() + assert serializer.validated_data["target_types"] == expected + + +def test_api_snapshot_serializer_invalid_target_types(): + serializer = SnapshotQuerySerializer(data={"target_types": "branch"}) + assert not serializer.is_valid() diff --git a/swh/web/api/views/snapshot.py b/swh/web/api/views/snapshot.py index d4889c228d0f00834a6d7778e4130b24720eb7c9..ec7920aa77e477b3ca52ef692479f9aaf63161b4 100644 --- a/swh/web/api/views/snapshot.py +++ b/swh/web/api/views/snapshot.py @@ -1,10 +1,13 @@ -# Copyright (C) 2018-2022 The Software Heritage developers +# Copyright (C) 2018-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Any +from rest_framework import serializers from rest_framework.request import Request +from swh.model.model import TargetType from swh.web.api.apidoc import api_doc, format_docstring from swh.web.api.apiurls import api_route from swh.web.api.utils import enrich_snapshot @@ -12,15 +15,73 @@ from swh.web.api.views.utils import api_lookup from swh.web.config import get_config from swh.web.utils import archive, reverse +snapshot_content_max_size = get_config()["snapshot_content_max_size"] + + +class TargetTypesField(serializers.Field): + """A DRF field to handle snapshot target types.""" + + def to_representation(self, value: list[str]) -> str: + """Serialize value. + + Args: + value: a list of target types + + Returns: + A comma separated list of target types + """ + + return ",".join(value) + + def to_internal_value(self, data: str) -> list[str]: + """From a comma separated string to a list. + + Handles serialization and validation of the target types requested. + + Args: + data: a comma separated string of target types + + Raises: + serializers.ValidationError: one or more target types are not valid. + + Returns: + A list of target types + """ + choices = [e.value for e in TargetType] + target_types = [t.strip() for t in data.split(",")] + if not all(e in choices for e in target_types): + raise serializers.ValidationError( + f"Valid target type values are: {', '.join(choices)}" + ) + return target_types + + +class SnapshotQuerySerializer(serializers.Serializer): + """Snapshot query parameters serializer.""" + + branches_from = serializers.CharField(default="", required=False) + branches_count = serializers.IntegerField( + default=snapshot_content_max_size, + required=False, + min_value=0, + max_value=snapshot_content_max_size, + ) + target_types = TargetTypesField(default="", required=False) + @api_route( r"/snapshot/(?P<snapshot_id>[0-9a-f]+)/", "api-1-snapshot", checksum_args=["snapshot_id"], + query_params_serializer=SnapshotQuerySerializer, ) @api_doc("/snapshot/", category="Archive") @format_docstring() -def api_snapshot(request: Request, snapshot_id: str): +def api_snapshot( + request: Request, + snapshot_id: str, + validated_query_params: dict[str, Any], +): """ .. http:get:: /api/1/snapshot/(snapshot_id)/ @@ -59,7 +120,8 @@ def api_snapshot(request: Request, snapshot_id: str): :>json string id: the unique identifier of the snapshot :statuscode 200: no error - :statuscode 400: an invalid snapshot identifier has been provided + :statuscode 400: an invalid snapshot identifier or invalid query parameters has + been provided :statuscode 404: requested snapshot cannot be found in the archive **Example:** @@ -68,13 +130,9 @@ def api_snapshot(request: Request, snapshot_id: str): :swh_web_api:`snapshot/6a3a2cf0b2b90ce7ae1cf0a221ed68035b686f5a/` """ - - snapshot_content_max_size = get_config()["snapshot_content_max_size"] - - branches_from = request.GET.get("branches_from", "") - branches_count = int(request.GET.get("branches_count", snapshot_content_max_size)) - target_types_str = request.GET.get("target_types", None) - target_types = target_types_str.split(",") if target_types_str else None + branches_from = validated_query_params["branches_from"] + branches_count = validated_query_params["branches_count"] + target_types = validated_query_params["target_types"] results = api_lookup( archive.lookup_snapshot,