Skip to content
Snippets Groups Projects
Commit 68591a22 authored by vlorentz's avatar vlorentz
Browse files

Add max_matching_nodes parameter to /leaves

To match the new parameter in the gRPC API
parent 33726994
No related branches found
No related tags found
No related merge requests found
...@@ -52,7 +52,13 @@ class RemoteGraphClient(RPCClient): ...@@ -52,7 +52,13 @@ class RemoteGraphClient(RPCClient):
return self.get("stats") return self.get("stats")
def leaves( def leaves(
self, src, edges="*", direction="forward", max_edges=0, return_types="*" self,
src,
edges="*",
direction="forward",
max_edges=0,
return_types="*",
max_matching_nodes=0,
): ):
return self.get_lines( return self.get_lines(
"leaves/{}".format(src), "leaves/{}".format(src),
...@@ -61,6 +67,7 @@ class RemoteGraphClient(RPCClient): ...@@ -61,6 +67,7 @@ class RemoteGraphClient(RPCClient):
"direction": direction, "direction": direction,
"max_edges": max_edges, "max_edges": max_edges,
"return_types": return_types, "return_types": return_types,
"max_matching_nodes": max_matching_nodes,
}, },
) )
...@@ -137,10 +144,14 @@ class RemoteGraphClient(RPCClient): ...@@ -137,10 +144,14 @@ class RemoteGraphClient(RPCClient):
}, },
) )
def count_leaves(self, src, edges="*", direction="forward"): def count_leaves(self, src, edges="*", direction="forward", max_matching_nodes=0):
return self.get( return self.get(
"leaves/count/{}".format(src), "leaves/count/{}".format(src),
params={"edges": edges, "direction": direction}, params={
"edges": edges,
"direction": direction,
"max_matching_nodes": max_matching_nodes,
},
) )
def count_neighbors(self, src, edges="*", direction="forward"): def count_neighbors(self, src, edges="*", direction="forward"):
......
# Copyright (C) 2021 The Software Heritage developers # Copyright (C) 2021-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution # See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version # License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information # See top-level LICENSE file for more information
import functools import functools
import inspect import inspect
import itertools
import re import re
import statistics import statistics
from typing import ( from typing import (
...@@ -150,9 +151,10 @@ class NaiveClient: ...@@ -150,9 +151,10 @@ class NaiveClient:
direction: str = "forward", direction: str = "forward",
max_edges: int = 0, max_edges: int = 0,
return_types: str = "*", return_types: str = "*",
max_matching_nodes: int = 0,
) -> Iterator[str]: ) -> Iterator[str]:
# TODO: max_edges # TODO: max_edges
yield from filter_node_types( leaves = filter_node_types(
return_types, return_types,
[ [
node node
...@@ -161,6 +163,11 @@ class NaiveClient: ...@@ -161,6 +163,11 @@ class NaiveClient:
], ],
) )
if max_matching_nodes > 0:
leaves = itertools.islice(leaves, max_matching_nodes)
return leaves
@check_arguments @check_arguments
def neighbors( def neighbors(
self, self,
...@@ -250,9 +257,19 @@ class NaiveClient: ...@@ -250,9 +257,19 @@ class NaiveClient:
@check_arguments @check_arguments
def count_leaves( def count_leaves(
self, src: str, edges: str = "*", direction: str = "forward" self,
src: str,
edges: str = "*",
direction: str = "forward",
max_matching_nodes: int = 0,
) -> int: ) -> int:
return len(list(self.leaves(src, edges, direction))) return len(
list(
self.leaves(
src, edges, direction, max_matching_nodes=max_matching_nodes
)
)
)
@check_arguments @check_arguments
def count_neighbors( def count_neighbors(
......
...@@ -144,13 +144,15 @@ class GraphView(aiohttp.web.View): ...@@ -144,13 +144,15 @@ class GraphView(aiohttp.web.View):
else: else:
return s return s
def get_limit(self): def get_max_matching_nodes(self):
"""Validate HTTP query parameter `limit`, i.e., number of results""" """Validate HTTP query parameter `max_matching_nodes`, i.e., number of results"""
s = self.request.query.get("limit", "0") s = self.request.query.get("max_matching_nodes", "0")
try: try:
return int(s) return int(s)
except ValueError: except ValueError:
raise aiohttp.web.HTTPBadRequest(text=f"invalid limit value: {s}") raise aiohttp.web.HTTPBadRequest(
text=f"invalid max_matching_nodes value: {s}"
)
def get_max_edges(self): def get_max_edges(self):
"""Validate HTTP query parameter 'max_edges', i.e., """Validate HTTP query parameter 'max_edges', i.e.,
...@@ -249,6 +251,7 @@ class SimpleTraversalView(StreamingGraphView): ...@@ -249,6 +251,7 @@ class SimpleTraversalView(StreamingGraphView):
direction=self.get_direction(), direction=self.get_direction(),
return_nodes=NodeFilter(types=self.get_return_types()), return_nodes=NodeFilter(types=self.get_return_types()),
mask=FieldMask(paths=["swhid"]), mask=FieldMask(paths=["swhid"]),
max_matching_nodes=self.get_max_matching_nodes(),
) )
if self.get_max_edges(): if self.get_max_edges():
self.traversal_request.max_edges = self.get_max_edges() self.traversal_request.max_edges = self.get_max_edges()
...@@ -307,6 +310,7 @@ class CountView(GraphView): ...@@ -307,6 +310,7 @@ class CountView(GraphView):
direction=self.get_direction(), direction=self.get_direction(),
return_nodes=NodeFilter(types=self.get_return_types()), return_nodes=NodeFilter(types=self.get_return_types()),
mask=FieldMask(paths=["swhid"]), mask=FieldMask(paths=["swhid"]),
max_matching_nodes=self.get_max_matching_nodes(),
) )
if self.get_max_edges(): if self.get_max_edges():
self.traversal_request.max_edges = self.get_max_edges() self.traversal_request.max_edges = self.get_max_edges()
......
...@@ -43,6 +43,25 @@ def test_leaves(graph_client): ...@@ -43,6 +43,25 @@ def test_leaves(graph_client):
assert set(actual) == set(expected) assert set(actual) == set(expected)
@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31])
def test_leaves_with_limit(graph_client, max_matching_nodes):
actual = list(
graph_client.leaves(TEST_ORIGIN_ID, max_matching_nodes=max_matching_nodes)
)
expected = [
"swh:1:cnt:0000000000000000000000000000000000000001",
"swh:1:cnt:0000000000000000000000000000000000000004",
"swh:1:cnt:0000000000000000000000000000000000000005",
"swh:1:cnt:0000000000000000000000000000000000000007",
]
if max_matching_nodes == 0:
assert set(actual) == set(expected)
else:
assert set(actual) <= set(expected)
assert len(actual) == min(4, max_matching_nodes)
def test_neighbors(graph_client): def test_neighbors(graph_client):
actual = list( actual = list(
graph_client.neighbors( graph_client.neighbors(
...@@ -326,6 +345,17 @@ def test_count(graph_client): ...@@ -326,6 +345,17 @@ def test_count(graph_client):
assert actual == 3 assert actual == 3
@pytest.mark.parametrize("max_matching_nodes", [0, 1, 2, 3, 4, 5, 10, 1 << 31])
def test_count_with_limit(graph_client, max_matching_nodes):
actual = graph_client.count_leaves(
TEST_ORIGIN_ID, max_matching_nodes=max_matching_nodes
)
if max_matching_nodes == 0:
assert actual == 4
else:
assert actual == min(4, max_matching_nodes)
def test_param_validation(graph_client): def test_param_validation(graph_client):
with raises(GraphArgumentException) as exc_info: # SWHID not found with raises(GraphArgumentException) as exc_info: # SWHID not found
list(graph_client.leaves("swh:1:rel:00ffffffff000000000000000000000000000010")) list(graph_client.leaves("swh:1:rel:00ffffffff000000000000000000000000000010"))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment