From 9af0c3a29458b77d85ecbfb4a50726f44114b047 Mon Sep 17 00:00:00 2001 From: Renaud Boyer <renaud.boyer@sofwareheritage.org> Date: Mon, 6 Jan 2025 18:32:41 +0100 Subject: [PATCH 1/5] Use drf serializers to sanitize query params --- swh/web/api/apiresponse.py | 6 ++- swh/web/api/tests/views/test_snapshot.py | 29 +++++++++++- swh/web/api/views/snapshot.py | 60 +++++++++++++++++++++--- 3 files changed, 86 insertions(+), 9 deletions(-) diff --git a/swh/web/api/apiresponse.py b/swh/web/api/apiresponse.py index ff57ca6d1..840f03fb5 100644 --- a/swh/web/api/apiresponse.py +++ b/swh/web/api/apiresponse.py @@ -13,7 +13,7 @@ from django.shortcuts import render from django.urls import get_resolver from django.utils.cache import add_never_cache_headers from django.utils.html import escape -from rest_framework.exceptions import APIException +from rest_framework.exceptions import APIException, ValidationError from rest_framework.request import Request from rest_framework.response import Response from rest_framework.utils.encoders import JSONEncoder @@ -200,6 +200,7 @@ def error_response( doc_data: documentation data for HTML response """ + error_data: dict[str, Any] error_data = { "exception": exception.__class__.__name__, "reason": str(exception), @@ -212,6 +213,9 @@ def error_response( error_code = 500 if isinstance(exception, BadInputExc): error_code = 400 + elif isinstance(exception, ValidationError): + error_code = 400 + error_data["reason"] = exception.detail elif isinstance(exception, UnauthorizedExc): error_code = 401 elif isinstance(exception, NotFoundExc): diff --git a/swh/web/api/tests/views/test_snapshot.py b/swh/web/api/tests/views/test_snapshot.py index 49b4961b7..d2ed6fc9c 100644 --- a/swh/web/api/tests/views/test_snapshot.py +++ b/swh/web/api/tests/views/test_snapshot.py @@ -2,10 +2,11 @@ # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information - +from http import HTTPStatus import random from hypothesis import given +import pytest from swh.model.hashutil import hash_to_hex from swh.model.model import Snapshot @@ -154,3 +155,29 @@ def test_api_snapshot_no_pull_request_branches_filtering( url = reverse("api-1-snapshot", url_args={"snapshot_id": snapshot["id"]}) resp = check_api_get_responses(api_client, url, status_code=200) assert any([b.startswith("refs/pull/") for b in resp.data["branches"]]) + + +@pytest.mark.parametrize( + "query_params,expected_key", + [ + ( + {"target_types": "content,an_invalid_target_type"}, + "target_types", + ), + ( + {"branches_count": "500/"}, + "branches_count", + ), + ], +) +def test_api_snapshot_invalid_params( + api_client, archive_data, snapshot, query_params, expected_key +): + url = reverse( + "api-1-snapshot", + url_args={"snapshot_id": snapshot}, + query_params=query_params, + ) + rv = check_api_get_responses(api_client, url, status_code=HTTPStatus.BAD_REQUEST) + assert rv.data["exception"] == "ValidationError" + assert expected_key in rv.data["reason"] diff --git a/swh/web/api/views/snapshot.py b/swh/web/api/views/snapshot.py index d4889c228..75b3e6d54 100644 --- a/swh/web/api/views/snapshot.py +++ b/swh/web/api/views/snapshot.py @@ -1,8 +1,10 @@ -# Copyright (C) 2018-2022 The Software Heritage developers +# Copyright (C) 2018-2024 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information +from typing import Optional +from rest_framework import serializers from rest_framework.request import Request from swh.web.api.apidoc import api_doc, format_docstring @@ -12,6 +14,48 @@ from swh.web.api.views.utils import api_lookup from swh.web.config import get_config from swh.web.utils import archive, reverse +VALID_TARGET_TYPES = [ + "content", + "directory", + "revision", + "release", + "snapshot", + "alias", +] + + +class SnapshotQuerySerializer(serializers.Serializer): + branches_from = serializers.CharField(default="", required=False) + branches_count = serializers.IntegerField(required=False, min_value=0) + target_types = serializers.CharField(default="", required=False) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + snapshot_content_max_size = get_config()["snapshot_content_max_size"] + self.fields["branches_count"].default = snapshot_content_max_size + self.fields["branches_count"].max_value = snapshot_content_max_size + + def validate_target_types(self, data: Optional[str]) -> list[str]: + """Parse and validate target types. + + Args: + data: an optional list of target_types separated by commas + + Raises: + serializers.ValidationError: an invalid target type was requested + + Returns: + A list of target types + """ + if not data: + return [] + requested_types = [t.strip() for t in data.split(",")] + if not all(e in VALID_TARGET_TYPES for e in requested_types): + raise serializers.ValidationError( + f"Valid target type values are: {', '.join(VALID_TARGET_TYPES)}" + ) + return requested_types + @api_route( r"/snapshot/(?P<snapshot_id>[0-9a-f]+)/", @@ -59,7 +103,8 @@ def api_snapshot(request: Request, snapshot_id: str): :>json string id: the unique identifier of the snapshot :statuscode 200: no error - :statuscode 400: an invalid snapshot identifier has been provided + :statuscode 400: an invalid snapshot identifier or invalid query parameters has + been provided :statuscode 404: requested snapshot cannot be found in the archive **Example:** @@ -69,12 +114,13 @@ def api_snapshot(request: Request, snapshot_id: str): :swh_web_api:`snapshot/6a3a2cf0b2b90ce7ae1cf0a221ed68035b686f5a/` """ - snapshot_content_max_size = get_config()["snapshot_content_max_size"] + serializer = SnapshotQuerySerializer(data=request.GET) + serializer.is_valid(raise_exception=True) + params = serializer.validated_data - branches_from = request.GET.get("branches_from", "") - branches_count = int(request.GET.get("branches_count", snapshot_content_max_size)) - target_types_str = request.GET.get("target_types", None) - target_types = target_types_str.split(",") if target_types_str else None + branches_from = params["branches_from"] + branches_count = params["branches_count"] + target_types = params["target_types"] results = api_lookup( archive.lookup_snapshot, -- GitLab From 155916bc1379678167ed144dd8e2be06a2e97570 Mon Sep 17 00:00:00 2001 From: Renaud Boyer <renaud.boyer@sofwareheritage.org> Date: Tue, 21 Jan 2025 09:32:36 +0100 Subject: [PATCH 2/5] api: a new query_params_serializer to validate query parameters --- swh/web/api/apiurls.py | 32 +++++++++++++++++++--- swh/web/api/tests/views/test_snapshot.py | 35 ++++++++++++++++++++++-- swh/web/api/views/snapshot.py | 20 +++++++------- 3 files changed, 70 insertions(+), 17 deletions(-) diff --git a/swh/web/api/apiurls.py b/swh/web/api/apiurls.py index 89e33d5ce..d3756b306 100644 --- a/swh/web/api/apiurls.py +++ b/swh/web/api/apiurls.py @@ -4,9 +4,10 @@ # See top-level LICENSE file for more information import functools -from typing import Dict, List, Literal, Optional +from typing import Dict, List, Literal, Optional, Union -from django.http.response import HttpResponseBase +from django.http.response import HttpResponse, HttpResponseBase +from rest_framework import serializers from rest_framework.decorators import api_view from swh.web.api import throttling @@ -73,6 +74,7 @@ def api_route( checksum_args: Optional[List[str]] = None, never_cache: bool = False, api_urls: APIUrls = api_urls, + query_params_serializer: Optional[type[serializers.Serializer]] = None, ): """ Decorator to ease the registration of an API endpoint @@ -81,12 +83,14 @@ def api_route( Args: url_pattern: the url pattern used by DRF to identify the API route view_name: the name of the API view associated to the route used to - reverse the url + reverse the url methods: array of HTTP methods supported by the API route throttle_scope: Named scope for rate limiting api_version: web API version checksum_args: list of view argument names holding checksum values never_cache: define if api response must be cached + query_params_serializer: an optional DRF serializer to validate the API call + query parameters """ @@ -97,9 +101,28 @@ def api_route( @api_view(methods) @throttling.throttle_scope(throttle_scope) @functools.wraps(f) - def api_view_f(request, **kwargs): + def api_view_f(request, **kwargs) -> Union[HttpResponse, HttpResponseBase]: + """Creates and calls a DRF view from the wrapped function. + + If `query_params_serializer` is set an extra kwarg `validated_query_params` + will be added to the view call containing the serializer's `validated_data`. + + Raises: + django.core.exceptions.ValidationError: query_params_serializer was + unable to validate the query parameters + + Returns: + An HttpResponse + """ # never_cache will be handled in apiresponse module request.never_cache = never_cache + + # a DRF serializer has been passed to validate the query parameters + if query_params_serializer: + serializer = query_params_serializer(data=request.GET) + serializer.is_valid(raise_exception=True) + kwargs["validated_query_params"] = serializer.validated_data + response = f(request, **kwargs) doc_data = None # check if response has been forwarded by api_doc decorator @@ -107,6 +130,7 @@ def api_route( doc_data = response["doc_data"] response = response["data"] # check if HTTP response needs to be created + api_response: Union[HttpResponse, HttpResponseBase] if not isinstance(response, HttpResponseBase): api_response = make_api_response( request, data=response, doc_data=doc_data diff --git a/swh/web/api/tests/views/test_snapshot.py b/swh/web/api/tests/views/test_snapshot.py index d2ed6fc9c..f5e46315f 100644 --- a/swh/web/api/tests/views/test_snapshot.py +++ b/swh/web/api/tests/views/test_snapshot.py @@ -11,6 +11,7 @@ import pytest from swh.model.hashutil import hash_to_hex from swh.model.model import Snapshot from swh.web.api.utils import enrich_snapshot +from swh.web.api.views.snapshot import VALID_TARGET_TYPES, SnapshotQuerySerializer from swh.web.tests.data import random_sha1 from swh.web.tests.helpers import check_api_get_responses, check_http_get_response from swh.web.tests.strategies import new_snapshot @@ -158,7 +159,7 @@ def test_api_snapshot_no_pull_request_branches_filtering( @pytest.mark.parametrize( - "query_params,expected_key", + "query_params,expected_error_key", [ ( {"target_types": "content,an_invalid_target_type"}, @@ -171,7 +172,7 @@ def test_api_snapshot_no_pull_request_branches_filtering( ], ) def test_api_snapshot_invalid_params( - api_client, archive_data, snapshot, query_params, expected_key + api_client, archive_data, snapshot, query_params, expected_error_key ): url = reverse( "api-1-snapshot", @@ -180,4 +181,32 @@ def test_api_snapshot_invalid_params( ) rv = check_api_get_responses(api_client, url, status_code=HTTPStatus.BAD_REQUEST) assert rv.data["exception"] == "ValidationError" - assert expected_key in rv.data["reason"] + assert expected_error_key in rv.data["reason"] + + +@pytest.mark.parametrize( + "target_types,expected", + [ + ( + ",".join(VALID_TARGET_TYPES), + VALID_TARGET_TYPES, + ), + ( + " content, directory", + ["content", "directory"], + ), + ( + "alias", + ["alias"], + ), + ], +) +def test_api_snapshot_serializer_valid_target_types(target_types, expected): + serializer = SnapshotQuerySerializer(data={"target_types": target_types}) + assert serializer.is_valid() + assert serializer.validated_data["target_types"] == expected + + +def test_api_snapshot_serializer_invalid_target_types(): + serializer = SnapshotQuerySerializer(data={"target_types": "branch"}) + assert not serializer.is_valid() diff --git a/swh/web/api/views/snapshot.py b/swh/web/api/views/snapshot.py index 75b3e6d54..f8da5ae99 100644 --- a/swh/web/api/views/snapshot.py +++ b/swh/web/api/views/snapshot.py @@ -2,7 +2,7 @@ # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Optional +from typing import Any, Optional from rest_framework import serializers from rest_framework.request import Request @@ -61,10 +61,15 @@ class SnapshotQuerySerializer(serializers.Serializer): r"/snapshot/(?P<snapshot_id>[0-9a-f]+)/", "api-1-snapshot", checksum_args=["snapshot_id"], + query_params_serializer=SnapshotQuerySerializer, ) @api_doc("/snapshot/", category="Archive") @format_docstring() -def api_snapshot(request: Request, snapshot_id: str): +def api_snapshot( + request: Request, + snapshot_id: str, + validated_query_params: dict[str, Any], +): """ .. http:get:: /api/1/snapshot/(snapshot_id)/ @@ -113,14 +118,9 @@ def api_snapshot(request: Request, snapshot_id: str): :swh_web_api:`snapshot/6a3a2cf0b2b90ce7ae1cf0a221ed68035b686f5a/` """ - - serializer = SnapshotQuerySerializer(data=request.GET) - serializer.is_valid(raise_exception=True) - params = serializer.validated_data - - branches_from = params["branches_from"] - branches_count = params["branches_count"] - target_types = params["target_types"] + branches_from = validated_query_params["branches_from"] + branches_count = validated_query_params["branches_count"] + target_types = validated_query_params["target_types"] results = api_lookup( archive.lookup_snapshot, -- GitLab From 2f1ec4fd95ddd7350785262facd8d2d8b794cfee Mon Sep 17 00:00:00 2001 From: Renaud Boyer <renaud.boyer@sofwareheritage.org> Date: Tue, 21 Jan 2025 15:16:57 +0100 Subject: [PATCH 3/5] Refactoring --- swh/web/api/apiurls.py | 2 +- swh/web/api/tests/views/test_snapshot.py | 7 ++- swh/web/api/views/snapshot.py | 68 ++++++++++++++---------- 3 files changed, 46 insertions(+), 31 deletions(-) diff --git a/swh/web/api/apiurls.py b/swh/web/api/apiurls.py index d3756b306..5431a59d0 100644 --- a/swh/web/api/apiurls.py +++ b/swh/web/api/apiurls.py @@ -119,7 +119,7 @@ def api_route( # a DRF serializer has been passed to validate the query parameters if query_params_serializer: - serializer = query_params_serializer(data=request.GET) + serializer = query_params_serializer(data=request.query_params.dict()) serializer.is_valid(raise_exception=True) kwargs["validated_query_params"] = serializer.validated_data diff --git a/swh/web/api/tests/views/test_snapshot.py b/swh/web/api/tests/views/test_snapshot.py index f5e46315f..0f61cb8af 100644 --- a/swh/web/api/tests/views/test_snapshot.py +++ b/swh/web/api/tests/views/test_snapshot.py @@ -9,9 +9,9 @@ from hypothesis import given import pytest from swh.model.hashutil import hash_to_hex -from swh.model.model import Snapshot +from swh.model.model import Snapshot, TargetType from swh.web.api.utils import enrich_snapshot -from swh.web.api.views.snapshot import VALID_TARGET_TYPES, SnapshotQuerySerializer +from swh.web.api.views.snapshot import SnapshotQuerySerializer from swh.web.tests.data import random_sha1 from swh.web.tests.helpers import check_api_get_responses, check_http_get_response from swh.web.tests.strategies import new_snapshot @@ -184,6 +184,9 @@ def test_api_snapshot_invalid_params( assert expected_error_key in rv.data["reason"] +VALID_TARGET_TYPES = [e.value for e in TargetType] + + @pytest.mark.parametrize( "target_types,expected", [ diff --git a/swh/web/api/views/snapshot.py b/swh/web/api/views/snapshot.py index f8da5ae99..f9dd907c2 100644 --- a/swh/web/api/views/snapshot.py +++ b/swh/web/api/views/snapshot.py @@ -2,11 +2,12 @@ # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information -from typing import Any, Optional +from typing import Any from rest_framework import serializers from rest_framework.request import Request +from swh.model.model import TargetType from swh.web.api.apidoc import api_doc, format_docstring from swh.web.api.apiurls import api_route from swh.web.api.utils import enrich_snapshot @@ -14,47 +15,58 @@ from swh.web.api.views.utils import api_lookup from swh.web.config import get_config from swh.web.utils import archive, reverse -VALID_TARGET_TYPES = [ - "content", - "directory", - "revision", - "release", - "snapshot", - "alias", -] +snapshot_content_max_size = get_config()["snapshot_content_max_size"] -class SnapshotQuerySerializer(serializers.Serializer): - branches_from = serializers.CharField(default="", required=False) - branches_count = serializers.IntegerField(required=False, min_value=0) - target_types = serializers.CharField(default="", required=False) +class TargetTypesField(serializers.Field): + """A DRF field to handle snapshot target types.""" + + def to_representation(self, value: list[str]) -> str: + """Serialize value. + + Args: + value: a list of target types + + Returns: + A comma separated list of target types + """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - snapshot_content_max_size = get_config()["snapshot_content_max_size"] - self.fields["branches_count"].default = snapshot_content_max_size - self.fields["branches_count"].max_value = snapshot_content_max_size + return ",".join(value) - def validate_target_types(self, data: Optional[str]) -> list[str]: - """Parse and validate target types. + def to_internal_value(self, data: str) -> list[str]: + """From a comma separated string to a list. + + Handles serialization and validation of the target types requested. Args: - data: an optional list of target_types separated by commas + data: a comma separated string of target types Raises: - serializers.ValidationError: an invalid target type was requested + serializers.ValidationError: one or more target types are not valid. Returns: A list of target types """ - if not data: - return [] - requested_types = [t.strip() for t in data.split(",")] - if not all(e in VALID_TARGET_TYPES for e in requested_types): + choices = [e.value for e in TargetType] + target_types = [t.strip() for t in data.split(",")] + if not all(e in choices for e in target_types): raise serializers.ValidationError( - f"Valid target type values are: {', '.join(VALID_TARGET_TYPES)}" + f"Valid target type values are: {', '.join(choices)}" ) - return requested_types + return target_types + + +class SnapshotQuerySerializer(serializers.Serializer): + """Snapshot query parameters serializer.""" + + branches_from = serializers.CharField(default="", required=False) + branches_count = serializers.IntegerField( + default=snapshot_content_max_size, + required=False, + min_value=0, + max_value=snapshot_content_max_size, + ) + target_types = TargetTypesField(default="", required=False) @api_route( -- GitLab From 9b4066670f94ef72f3806a7153c3ec594a42d31f Mon Sep 17 00:00:00 2001 From: Renaud Boyer <renaud.boyer@sofwareheritage.org> Date: Tue, 21 Jan 2025 15:19:26 +0100 Subject: [PATCH 4/5] Fix dates --- swh/web/api/apiresponse.py | 2 +- swh/web/api/apiurls.py | 2 +- swh/web/api/tests/views/test_snapshot.py | 2 +- swh/web/api/views/snapshot.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/swh/web/api/apiresponse.py b/swh/web/api/apiresponse.py index 840f03fb5..c7e232074 100644 --- a/swh/web/api/apiresponse.py +++ b/swh/web/api/apiresponse.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2022 The Software Heritage developers +# Copyright (C) 2017-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information diff --git a/swh/web/api/apiurls.py b/swh/web/api/apiurls.py index 5431a59d0..14b159eff 100644 --- a/swh/web/api/apiurls.py +++ b/swh/web/api/apiurls.py @@ -1,4 +1,4 @@ -# Copyright (C) 2017-2024 The Software Heritage developers +# Copyright (C) 2017-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information diff --git a/swh/web/api/tests/views/test_snapshot.py b/swh/web/api/tests/views/test_snapshot.py index 0f61cb8af..b0776ddad 100644 --- a/swh/web/api/tests/views/test_snapshot.py +++ b/swh/web/api/tests/views/test_snapshot.py @@ -1,4 +1,4 @@ -# Copyright (C) 2018-2021 The Software Heritage developers +# Copyright (C) 2018-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information diff --git a/swh/web/api/views/snapshot.py b/swh/web/api/views/snapshot.py index f9dd907c2..ec7920aa7 100644 --- a/swh/web/api/views/snapshot.py +++ b/swh/web/api/views/snapshot.py @@ -1,4 +1,4 @@ -# Copyright (C) 2018-2024 The Software Heritage developers +# Copyright (C) 2018-2025 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information -- GitLab From 910092123a794eb2d7f3dd54fd8696feb6416b9e Mon Sep 17 00:00:00 2001 From: Renaud Boyer <renaud.boyer@sofwareheritage.org> Date: Tue, 21 Jan 2025 15:25:32 +0100 Subject: [PATCH 5/5] Fix typing --- swh/web/api/apiurls.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/swh/web/api/apiurls.py b/swh/web/api/apiurls.py index 14b159eff..7ade53cc7 100644 --- a/swh/web/api/apiurls.py +++ b/swh/web/api/apiurls.py @@ -4,9 +4,9 @@ # See top-level LICENSE file for more information import functools -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional -from django.http.response import HttpResponse, HttpResponseBase +from django.http.response import HttpResponseBase from rest_framework import serializers from rest_framework.decorators import api_view @@ -101,7 +101,7 @@ def api_route( @api_view(methods) @throttling.throttle_scope(throttle_scope) @functools.wraps(f) - def api_view_f(request, **kwargs) -> Union[HttpResponse, HttpResponseBase]: + def api_view_f(request, **kwargs) -> HttpResponseBase: """Creates and calls a DRF view from the wrapped function. If `query_params_serializer` is set an extra kwarg `validated_query_params` @@ -130,7 +130,7 @@ def api_route( doc_data = response["doc_data"] response = response["data"] # check if HTTP response needs to be created - api_response: Union[HttpResponse, HttpResponseBase] + api_response: HttpResponseBase if not isinstance(response, HttpResponseBase): api_response = make_api_response( request, data=response, doc_data=doc_data -- GitLab