Skip to content
Snippets Groups Projects
Commit 68591a22 authored by vlorentz's avatar vlorentz
Browse files

Add max_matching_nodes parameter to /leaves

To match the new parameter in the gRPC API
parent 33726994
No related branches found
Tags v2.1.0
No related merge requests found
......@@ -52,7 +52,13 @@ class RemoteGraphClient(RPCClient):
return self.get("stats")
def leaves(
self, src, edges="*", direction="forward", max_edges=0, return_types="*"
self,
src,
edges="*",
direction="forward",
max_edges=0,
return_types="*",
max_matching_nodes=0,
):
return self.get_lines(
"leaves/{}".format(src),
......@@ -61,6 +67,7 @@ class RemoteGraphClient(RPCClient):
"direction": direction,
"max_edges": max_edges,
"return_types": return_types,
"max_matching_nodes": max_matching_nodes,
},
)
......@@ -137,10 +144,14 @@ class RemoteGraphClient(RPCClient):
},
)
def count_leaves(self, src, edges="*", direction="forward"):
def count_leaves(self, src, edges="*", direction="forward", max_matching_nodes=0):
return self.get(
"leaves/count/{}".format(src),
params={"edges": edges, "direction": direction},
params={
"edges": edges,
"direction": direction,
"max_matching_nodes": max_matching_nodes,
},
)
def count_neighbors(self, src, edges="*", direction="forward"):
......
# Copyright (C) 2021 The Software Heritage developers
# Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import functools
import inspect
import itertools
import re
import statistics
from typing import (
......@@ -150,9 +151,10 @@ class NaiveClient:
direction: str = "forward",
max_edges: int = 0,
return_types: str = "*",
max_matching_nodes: int = 0,
) -> Iterator[str]:
# TODO: max_edges
yield from filter_node_types(
leaves = filter_node_types(
return_types,
[
node
......@@ -161,6 +163,11 @@ class NaiveClient:
],
)
if max_matching_nodes > 0:
leaves = itertools.islice(leaves, max_matching_nodes)
return leaves
@check_arguments
def neighbors(
self,
......@@ -250,9 +257,19 @@ class NaiveClient:
@check_arguments
def count_leaves(
self, src: str, edges: str = "*", direction: str = "forward"
self,
src: str,
edges: str = "*",
direction: str = "forward",
max_matching_nodes: int = 0,
) -> int:
return len(list(self.leaves(src, edges, direction)))
return len(
list(
self.leaves(
src, edges, direction, max_matching_nodes=max_matching_nodes
)
)
)
@check_arguments
def count_neighbors(
......
......@@ -144,13 +144,15 @@ class GraphView(aiohttp.web.View):
else:
return s
def get_limit(self):
"""Validate HTTP query parameter `limit`, i.e., number of results"""
s = self.request.query.get("limit", "0")
def get_max_matching_nodes(self):
"""Validate HTTP query parameter `max_matching_nodes`, i.e., number of results"""
s = self.request.query.get("max_matching_nodes", "0")
try:
return int(s)
except ValueError:
raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}")
raise aiohttp.web.HTTPBadRequest(
text=f"invalid max_matching_nodes value: {s}"
)
def get_max_edges(self):
"""Validate HTTP query parameter 'max_edges', i.e.,
......@@ -249,6 +251,7 @@ class SimpleTraversalView(StreamingGraphView):
direction=self.get_direction(),
return_nodes=NodeFilter(types=self.get_return_types()),
mask=FieldMask(paths=["swhid"]),
max_matching_nodes=self.get_max_matching_nodes(),
)
if self.get_max_edges():
self.traversal_request.max_edges = self.get_max_edges()
......@@ -307,6 +310,7 @@ class CountView(GraphView):
direction=self.get_direction(),
return_nodes=NodeFilter(types=self.get_return_types()),
mask=FieldMask(paths=["swhid"]),
max_matching_nodes=self.get_max_matching_nodes(),
)
if self.get_max_edges():
self.traversal_request.max_edges = self.get_max_edges()
......
......@@ -43,6 +43,25 @@ def test_leaves(graph_client):
assert set(actual) == set(expected)
@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31])
def test_leaves_with_limit(graph_client, max_matching_nodes):
actual = list(
graph_client.leaves(TEST_ORIGIN_ID, max_matching_nodes=max_matching_nodes)
)
expected = [
"swh:1:cnt:0000000000000000000000000000000000000001",
"swh:1:cnt:0000000000000000000000000000000000000004",
"swh:1:cnt:0000000000000000000000000000000000000005",
"swh:1:cnt:0000000000000000000000000000000000000007",
]
if max_matching_nodes == 0:
assert set(actual) == set(expected)
else:
assert set(actual) <= set(expected)
assert len(actual) == min(4, max_matching_nodes)
def test_neighbors(graph_client):
actual = list(
graph_client.neighbors(
......@@ -326,6 +345,17 @@ def test_count(graph_client):
assert actual == 3
@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31])
def test_count_with_limit(graph_client, max_matching_nodes):
actual = graph_client.count_leaves(
TEST_ORIGIN_ID, max_matching_nodes=max_matching_nodes
)
if max_matching_nodes == 0:
assert actual == 4
else:
assert actual == min(4, max_matching_nodes)
def test_param_validation(graph_client):
with raises(GraphArgumentException) as exc_info: # SWHID not found
list(graph_client.leaves("swh:1:rel:00ffffffff000000000000000000000000000010"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment