Skip to content
Snippets Groups Projects
Commit b4a0fe14 authored by Antoine Pietri
Browse files

Enable iteration on the graph

parent 9657658b
Branches graph_iter
No related tags found
No related merge requests found
......@@ -159,6 +159,18 @@ class Graph:
node_id = self.pid2node[node_id]
return GraphNode(self, node_id)
def __iter__(self):
for pid, pos in self.backend.pid2node:
yield self[pid]
def iter_prefix(self, prefix):
    """Iterate over the nodes whose PID matches the given prefix.

    Args:
        prefix: PID prefix, forwarded as-is to ``pid2node.iter_prefix``

    Yields:
        ``self[pid]`` for every ``(pid, value)`` pair produced by the
        backend's ``pid2node.iter_prefix(prefix)``
    """
    matches = self.backend.pid2node.iter_prefix(prefix)
    for entry in matches:
        pid, _node_id = entry
        yield self[pid]
def iter_type(self, pid_type):
    """Iterate over the nodes of a given PID type.

    Args:
        pid_type: PID type, forwarded as-is to ``pid2node.iter_type``

    Yields:
        ``self[pid]`` for every ``(pid, value)`` pair produced by the
        backend's ``pid2node.iter_type(pid_type)``
    """
    typed_entries = self.backend.pid2node.iter_type(pid_type)
    yield from (self[pid] for pid, _node_id in typed_entries)
@contextlib.contextmanager
def load(graph_path):
......
......@@ -236,36 +236,50 @@ class PidToIntMap(_OnDiskMap, MutableMapping):
f.write(str_to_bytes(pid))
f.write(struct.pack(INT_BIN_FMT, int))
def _find(self, pid_str: str) -> Tuple[int, int]:
"""lookup the integer identifier of a pid and its position
def _bisect_pos(self, pid_str: str) -> int:
"""bisect the position of the given identifier. If the identifier is
not found, the position of the pid immediately after is returned.
Args:
pid_str: the pid as a string
Returns:
a pair `(pid, pos)` with pid integer identifier and its logical
record position in the map
the logical record of the bisected position in the map
"""
if not isinstance(pid_str, str):
raise TypeError('PID must be a str, not ' + type(pid_str))
raise TypeError('PID must be a str, not {}'.format(type(pid_str)))
try:
target = str_to_bytes(pid_str) # desired PID as bytes
except ValueError:
raise ValueError('invalid PID: "{}"'.format(pid_str))
min = 0
max = self.length - 1
while (min <= max):
mid = (min + max) // 2
(pid, int) = self._get_bin_record(mid)
lo = 0
hi = self.length - 1
while lo < hi:
mid = (lo + hi) // 2
(pid, _value) = self._get_bin_record(mid)
if pid < target:
min = mid + 1
elif pid > target:
max = mid - 1
else: # pid == target
return (struct.unpack(INT_BIN_FMT, int)[0], mid)
lo = mid + 1
else:
hi = mid
return lo
def _find(self, pid_str: str) -> Tuple[int, int]:
"""lookup the integer identifier of a pid and its position
Args:
pid_str: the pid as a string
Returns:
a pair `(pid, pos)` with pid integer identifier and its logical
record position in the map
"""
pos = self._bisect_pos(pid_str)
pid_found, value = self._get_record(pos)
if pid_found == pid_str:
return (value, pos)
raise KeyError(pid_str)
def __getitem__(self, pid_str: str) -> int:
......@@ -292,6 +306,21 @@ class PidToIntMap(_OnDiskMap, MutableMapping):
for pos in range(self.length):
yield self._get_record(pos)
def iter_prefix(self, prefix: str):
    """Iterate over the records whose PID starts with the given prefix.

    The prefix must split into exactly four ``:``-separated fields (the
    last one may be empty or a partial hash). The last field is
    right-padded with ``'0'`` up to 40 characters to build the PID from
    which the scan starts.

    Args:
        prefix: the PID prefix to match

    Yields:
        `(pid, value)` pairs, from the bisected start position until the
        first record whose PID does not start with `prefix`
    """
    scheme, version, obj_type, partial_hash = prefix.split(':')
    padded_hash = partial_hash.ljust(40, '0')
    lower_bound = ':'.join([scheme, version, obj_type, padded_hash])

    first = self._bisect_pos(lower_bound)
    for index in range(first, self.length):
        pid, value = self._get_record(index)
        if pid.startswith(prefix):
            yield pid, value
        else:
            # first PID outside the prefix: stop scanning
            return
def iter_type(self, pid_type: str) -> Iterator[Tuple[str, int]]:
    """Iterate over the records of a given PID type.

    Args:
        pid_type: the type field inserted into the ``swh:1:<type>:`` prefix

    Yields:
        the `(pid, value)` pairs produced by `iter_prefix` for that prefix
    """
    yield from self.iter_prefix('swh:1:{}:'.format(pid_type))
class IntToPidMap(_OnDiskMap, MutableMapping):
"""memory mapped map from a continuous range of 0..N (8-byte long) integers to
......
......@@ -110,3 +110,13 @@ def test_count(graph):
.count_visit_nodes(edges='rel:rev,rev:rev') == 3)
assert (graph['swh:1:rev:0000000000000000000000000000000000000009']
.count_neighbors(direction='backward') == 3)
def test_iter_type(graph):
    """Iterating over the 'rev' type yields exactly the expected revision
    PIDs, in this order."""
    expected = [
        'swh:1:rev:0000000000000000000000000000000000000003',
        'swh:1:rev:0000000000000000000000000000000000000009',
        'swh:1:rev:0000000000000000000000000000000000000013',
        'swh:1:rev:0000000000000000000000000000000000000018',
    ]
    # materialize the iterator first, then extract the PID of each node
    rev_nodes = list(graph.iter_type('rev'))
    actual = [node.pid for node in rev_nodes]
    assert expected == actual
......@@ -12,6 +12,7 @@ from itertools import islice
from swh.graph.pid import str_to_bytes, bytes_to_str
from swh.graph.pid import PidToIntMap, IntToPidMap
from swh.model.identifiers import PID_TYPES
class TestPidSerialization(unittest.TestCase):
......@@ -137,6 +138,14 @@ class TestPidToIntMap(unittest.TestCase):
os.unlink(fname2) # tmpdir will be cleaned even if we don't reach this
def test_iter_type(self):
    """For every PID type, the first 20 records returned by iter_type
    must be consecutive: PID hash and value both count up from the value
    of the first record."""
    for pid_type in PID_TYPES:
        head = list(islice(self.map.iter_type(pid_type), 20))
        # value of the first record: the expected sequence starts there
        base = head[0][1]
        expected = []
        for i in range(base, base + 20):
            expected.append(('swh:1:{}:{:040x}'.format(pid_type, i), i))
        assert head == expected
class TestIntToPidMap(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment