From 40442b4d6bf859234c1f551e7f9d7ac7a01b337c Mon Sep 17 00:00:00 2001
From: Daniele Serafini <danseraf@softwareheritage.org>
Date: Wed, 26 Feb 2020 16:27:36 +0100
Subject: [PATCH] scanner: asynchronous operation, type annotation

---
 swh/scanner/cli.py     |  52 +++++++++-----
 swh/scanner/py.typed   |   1 +
 swh/scanner/scanner.py | 152 +++++++++++++++++++++++++----------------
 3 files changed, 131 insertions(+), 74 deletions(-)
 create mode 100644 swh/scanner/py.typed

diff --git a/swh/scanner/cli.py b/swh/scanner/cli.py
index dbb3b61..245a44a 100644
--- a/swh/scanner/cli.py
+++ b/swh/scanner/cli.py
@@ -4,35 +4,53 @@
 # See top-level LICENSE file for more information
 
 import click
+import asyncio
+import os
+from pathlib import PosixPath
+from urllib.parse import urlparse
+
+from .scanner import run
+from .exceptions import InvalidPath
+from .logger import setup_logger, log_counters
+from .model import Tree
 
 from swh.core.cli import CONTEXT_SETTINGS
-from swh.scanner.scanner import run
 
 
-@click.group(name='scanner', context_settings=CONTEXT_SETTINGS)
-@click.pass_context
-def scanner(ctx):
-    '''Software Heritage Scanner tools.'''
-    pass
+def parse_url(url):
+    """Return 'https://<hostname>' for port 80, else ``url.geturl()``."""
+    if url.port == 80:
+        return 'https://' + url.hostname
+    return url.geturl()
 
 
-@scanner.command(name='scan')
+@click.command(context_settings=CONTEXT_SETTINGS)
 @click.argument('path', required=True)
 @click.option('--host', '-h', default='localhost',
               metavar='IP', show_default=True,
               help="web api endpoint ip")
-@click.option('--port', '-p', default='5080',
+@click.option('--port', '-p', default='',
               metavar='PORT', show_default=True,
               help="web api endpoint port")
-@click.pass_context
-def scan(ctx, path, host, port):
-    result = run(path, host, port)
-    print(result)
-
-
-def main():
-    return scanner(auto_envvar_prefix='SWH_SCANNER')
+@click.option('--debug/--no-debug', default=True,
+              help="enable debug")
+@click.option('--verbose', '-v', is_flag=True, default=False,
+              help="show debug information")
+def scanner(path, host, port, debug, verbose):
+    """Software Heritage tool to scan the source code of a project"""
+    if not os.path.exists(path):
+        raise InvalidPath(path)
+
+    if debug:
+        setup_logger(bool(verbose))
+
+    url = parse_url(urlparse('https://%s:%s' % (host, port)))
+    source_tree = Tree(None, PosixPath(path))
+    loop = asyncio.get_event_loop()
+    loop.run_until_complete(run(path, url, source_tree))
+    source_tree.show()
+    log_counters()
 
 
 if __name__ == '__main__':
-    main()
+    scanner()
diff --git a/swh/scanner/py.typed b/swh/scanner/py.typed
new file mode 100644
index 0000000..1242d43
--- /dev/null
+++ b/swh/scanner/py.typed
@@ -0,0 +1 @@
+# Marker file for PEP 561.
diff --git a/swh/scanner/scanner.py b/swh/scanner/scanner.py
index affc0f3..11065e2 100644
--- a/swh/scanner/scanner.py
+++ b/swh/scanner/scanner.py
@@ -3,12 +3,17 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import requests
 import os
-import json
 import itertools
+import asyncio
+import aiohttp
+from typing import List, Dict, Tuple, Generator, Iterator
 from pathlib import PosixPath
 
+from .logger import log_queries
+from .exceptions import APIError
+from .model import Tree
+
 from swh.model.cli import pid_of_file, pid_of_dir
 from swh.model.identifiers import (
         parse_persistent_identifier,
@@ -16,33 +21,64 @@ from swh.model.identifiers import (
 )
 
 
-def pids_discovery(pids, host, port):
-    """
+async def pids_discovery(
+        pids: List[str], session: aiohttp.ClientSession, url: str,
+        ) -> Dict[str, Dict[str, bool]]:
+    """API Request to get information about the persistent identifiers given in
+    input.
+
     Args:
-        pids list(str): A list of persistent identifier
+        pids: a list of persistent identifiers
+
     Returns:
         A dictionary with:
-        key(str): persistent identifier
-        value(dict):
-            value['known'] = True if pid is found
-            value['known'] = False if pid is not found
+        key: persistent identifier searched
+        value:
+            value['known'] = True if the pid is found
+            value['known'] = False if the pid is not found
+
     """
-    endpoint = 'http://%s:%s/api/1/known/' % (host, port)
-    req = requests.post(endpoint, json=pids)
-    resp = req.text
-    return json.loads(resp)
+    endpoint = url + '/api/1/known/'
+    chunk_size = 1000
+    requests = []
+
+    log_queries(len(pids))
 
+    def get_chunk(pids):
+        for i in range(0, len(pids), chunk_size):
+            yield pids[i:i + chunk_size]
 
-def get_sub_paths(path):
-    """Find the persistent identifier of the paths and files under
-    a given path.
+    async def make_request(pids):
+        """POST one batch of pids; raise APIError unless the reply is 200."""
+        async with session.post(endpoint, json=pids) as resp:
+            if resp.status != 200:
+                raise APIError('%s with given values %s' % (
+                    await resp.text(), str(pids)))
+            return await resp.json()
+
+    if len(pids) > chunk_size:
+        for pids_chunk in get_chunk(pids):
+            requests.append(asyncio.create_task(
+                make_request(pids_chunk)))
+
+        res = await asyncio.gather(*requests)
+        # concatenate list of dictionaries
+        return dict(itertools.chain.from_iterable(e.items() for e in res))
+    else:
+        return await make_request(pids)
+
+
+def get_subpaths(
+        path: PosixPath) -> Generator[Tuple[PosixPath, str], None, None]:
+    """Find the persistent identifier of the directories and files under a
+    given path.
 
     Args:
-        path(PosixPath): the entry root
+        path: the root path
 
     Yields:
-        tuple(path, pid): pairs of path and the relative persistent
-        identifier
+        pairs of: path, the relative persistent identifier
+
     """
     def pid_of(path):
         if path.is_dir():
@@ -52,57 +88,59 @@ def get_sub_paths(path):
 
     dirpath, dnames, fnames = next(os.walk(path))
     for node in itertools.chain(dnames, fnames):
-        path = PosixPath(dirpath).joinpath(node)
-        yield (path, pid_of(path))
+        sub_path = PosixPath(dirpath).joinpath(node)
+        yield (sub_path, pid_of(sub_path))
 
 
-def parse_path(path, host, port):
-    """Check if the sub paths of the given path is present in the
+async def parse_path(
+        path: PosixPath, session: aiohttp.ClientSession, url: str
+        ) -> Iterator[Tuple[str, str, bool]]:
+    """Check if the sub paths of the given path are present in the
     archive or not.
+
     Args:
-        path(PosixPath): The source path
-        host(str): ip for the api request
-        port(str): port for the api request
-    Yields:
-        a tuple with the path found, the persistent identifier
-        relative to the path and a boolean: False if not found,
-        True if found.
+        path: the source path
+        url: url for the API request
+
+    Returns:
+        a map containing tuples with: a subpath of the given path,
+        the pid of the subpath and the result of the api call
+
     """
-    pid_map = dict(get_sub_paths(path))
-    parsed_pids = pids_discovery(list(pid_map.values()), host, port)
+    parsed_paths = dict(get_subpaths(path))
+    parsed_pids = await pids_discovery(
+        list(parsed_paths.values()), session, url)
+
+    def unpack(tup):
+        subpath, pid = tup
+        return (subpath, pid, parsed_pids[pid]['known'])
+
+    return map(unpack, parsed_paths.items())
 
-    for sub_path, pid in pid_map.items():
-        yield (sub_path, pid, parsed_pids[pid]['known'])
 
+async def run(
+        root: PosixPath, url: str, source_tree: Tree) -> None:
+    """Start scanning from the given root.
+
+    It fills the source tree with the discovered paths.
 
-def run(root, host, port):
-    """Scan the given root
     Args:
-        path: the path to scan
-        host(str): ip for the api request
-        port(str): port for the api request
-    Returns:
-        A set containing pairs of the path discovered and the
-        relative persistent identifier
-    """
-    def _scan(root, host, port, accum):
-        assert root not in accum
+        root: the root path to scan
+        url: url for the API request
 
-        next_paths = []
-        for path, pid, found in parse_path(root, host, port):
+    """
+    async def _scan(root, session, url, source_tree):
+        for path, pid, found in await parse_path(root, session, url):
             obj_type = parse_persistent_identifier(pid).object_type
 
-            if obj_type == CONTENT and found:
-                accum.add((str(path), pid))
+            if obj_type == CONTENT:
+                source_tree.addNode(path, pid if found else None)
             elif obj_type == DIRECTORY:
                 if found:
-                    accum.add((str(path), pid))
+                    source_tree.addNode(path, pid)
                 else:
-                    next_paths.append(path)
-
-        for new_path in next_paths:
-            accum = _scan(new_path, host, port, accum)
-
-        return accum
+                    source_tree.addNode(path)
+                    await _scan(path, session, url, source_tree)
 
-    return _scan(root, host, port, set())
+    async with aiohttp.ClientSession() as session:
+        await _scan(root, session, url, source_tree)
-- 
GitLab