Skip to content
Snippets Groups Projects
Commit 6d7f36f0 authored by Antoine Pietri's avatar Antoine Pietri
Browse files

Create the aiohttp server

parent 998d3a4d
No related branches found
No related tags found
No related merge requests found
package org.softwareheritage.graph;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import py4j.GatewayServer;
import org.softwareheritage.graph.Graph;
import org.softwareheritage.graph.algo.NodeIdsConsumer;
import org.softwareheritage.graph.algo.Traversal;
/**
 * Entry point object instantiated from the Python side through Py4J.
 *
 * <p>Holds a single loaded {@link Graph} and streams traversal results to a
 * client-provided FIFO as raw 64-bit integers.
 */
public class Entry {
    Graph graph;

    /**
     * Loads the graph identified by {@code graphBasename} and keeps it for
     * subsequent {@link #visit} calls. Progress is reported on stderr.
     *
     * @param graphBasename basename handed to the {@link Graph} constructor
     * @throws IOException if the graph cannot be loaded
     */
    public void load_graph(String graphBasename) throws IOException {
        System.err.println("Loading graph " + graphBasename + " ...");
        this.graph = new Graph(graphBasename);
        System.err.println("Graph loaded.");
    }

    /**
     * Runs a traversal from {@code srcNodeId} and writes every node id handed
     * to the visitor callback to {@code clientFIFO}, each as an 8-byte
     * big-endian long ({@link DataOutputStream#writeLong}).
     *
     * <p>I/O failures while opening or closing the FIFO are reported on stderr
     * and swallowed; a failure while writing a node id aborts the traversal
     * with a {@link RuntimeException} carrying the original cause.
     *
     * @param srcNodeId  node the traversal starts from
     * @param direction  direction argument forwarded to {@link Traversal}
     * @param edgesFmt   edge restriction argument forwarded to {@link Traversal}
     * @param clientFIFO path of the file (FIFO) the node ids are written to
     */
    public void visit(long srcNodeId, String direction, String edgesFmt,
                      String clientFIFO) {
        Traversal t = new Traversal(this.graph, direction, edgesFmt);
        // try-with-resources guarantees the FIFO is closed even when the
        // visitor callback throws, so the reading side always sees EOF.
        try (DataOutputStream data =
                 new DataOutputStream(new FileOutputStream(clientFIFO))) {
            t.visitNodesVisitor(srcNodeId, (nodeId) -> {
                try {
                    data.writeLong(nodeId);
                } catch (IOException e) {
                    // The callback cannot throw checked exceptions; keep the
                    // original exception as the cause for diagnostics.
                    throw new RuntimeException("cannot write response to client: " + e, e);
                }
            });
        } catch (IOException e) {
            System.err.println("cannot write response to client: " + e);
        }
    }
}
package org.softwareheritage.graph;
import java.io.DataOutputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import py4j.GatewayServer;
import org.softwareheritage.graph.Graph;
import org.softwareheritage.graph.algo.NodeIdsConsumer;
import org.softwareheritage.graph.algo.Traversal;
/**
 * Standalone Py4J gateway exposing graph traversals to Python clients.
 *
 * <p>The graph is loaded once at construction time; {@link #main} then wraps
 * the instance in a {@link GatewayServer} listening on
 * {@link #GATEWAY_SERVER_PORT}.
 */
public class Py4JEntryPoint {
    static final int GATEWAY_SERVER_PORT = 25333;

    Graph graph;

    /**
     * Loads the graph identified by {@code graphBasename}, reporting progress
     * on stdout.
     *
     * @param graphBasename basename handed to the {@link Graph} constructor
     * @throws IOException if the graph cannot be loaded
     */
    public Py4JEntryPoint(String graphBasename) throws IOException {
        System.out.println("loading graph " + graphBasename + " ...");
        this.graph = new Graph(graphBasename);
        System.out.println("graph loaded.");
    }

    /**
     * Runs a traversal from {@code srcNodeId} and writes every node id handed
     * to the visitor callback to {@code clientFIFO}, each as an 8-byte
     * big-endian long ({@link DataOutputStream#writeLong}).
     *
     * <p>I/O failures while opening or closing the FIFO are reported on stderr
     * and swallowed; a failure while writing a node id aborts the traversal
     * with a {@link RuntimeException} carrying the original cause.
     *
     * @param srcNodeId  node the traversal starts from
     * @param direction  direction argument forwarded to {@link Traversal}
     * @param edgesFmt   edge restriction argument forwarded to {@link Traversal}
     * @param clientFIFO path of the file (FIFO) the node ids are written to
     */
    public void visit(long srcNodeId, String direction, String edgesFmt,
                      String clientFIFO) {
        Traversal t = new Traversal(this.graph, direction, edgesFmt);
        // try-with-resources guarantees the FIFO is closed even when the
        // visitor callback throws, so the reading side always sees EOF.
        try (DataOutputStream data =
                 new DataOutputStream(new FileOutputStream(clientFIFO))) {
            t.visitNodesVisitor(srcNodeId, (nodeId) -> {
                try {
                    data.writeLong(nodeId);
                } catch (IOException e) {
                    // The callback cannot throw checked exceptions; keep the
                    // original exception as the cause for diagnostics.
                    throw new RuntimeException("cannot write response to client: " + e, e);
                }
            });
        } catch (IOException e) {
            System.err.println("cannot write response to client: " + e);
        }
    }

    /**
     * Command-line entry point: {@code Py4JEntryPoint GRAPH_BASENAME}.
     *
     * <p>Exits with status 1 on a usage error and status 2 when the graph
     * cannot be loaded; otherwise starts the gateway server.
     *
     * @param args command-line arguments; exactly one is expected
     */
    public static void main(String[] args) {
        if (args.length != 1) {
            System.out.println("Usage: Py4JEntryPoint GRAPH_BASENAME");
            System.exit(1);
        }

        GatewayServer server = null;
        try {
            server = new GatewayServer(new Py4JEntryPoint(args[0]), GATEWAY_SERVER_PORT);
        } catch (IOException e) {
            System.out.println("Could not load graph: " + e);
            System.exit(2);
        }

        server.start();
        System.out.println("swh-graph: Py4J gateway server started");
    }
}
#!/usr/bin/python3
import os
import struct
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor
from py4j.java_gateway import JavaGateway, GatewayParameters
# Port passed to GatewayParameters when connecting to the Py4J gateway.
GATEWAY_SERVER_PORT = 25333
# Number of bytes requested per read() call on the results file.
BUF_SIZE = 64*1024
BIN_FMT = '>q' # 64 bit integer, big endian
def print_node_ids(fname):
    """Print, one per line, every integer stored in the binary file `fname`.

    The file is consumed in chunks of BUF_SIZE bytes and each chunk is
    decoded as consecutive BIN_FMT records.
    """
    with open(fname, 'rb') as stream:
        # iter() with a sentinel stops as soon as read() returns b'' (EOF).
        for chunk in iter(lambda: stream.read(BUF_SIZE), b''):
            for record in struct.iter_unpack(BIN_FMT, chunk):
                print(record[0])
if __name__ == '__main__':
    # Command-line client: py4jcli NODE_ID
    try:
        node_id = int(sys.argv[1])
    except (IndexError, ValueError):
        # Missing argument or an argument that is not an integer.
        print('Usage: py4jcli NODE_ID')
        sys.exit(1)

    gw_params = GatewayParameters(port=GATEWAY_SERVER_PORT)
    gateway = JavaGateway(gateway_parameters=gw_params)

    with tempfile.TemporaryDirectory() as tmpdirname:
        # The Java side writes its results into this FIFO.
        cli_fifo = os.path.join(tmpdirname, 'swh-graph.fifo')
        os.mkfifo(cli_fifo)

        with ThreadPoolExecutor(max_workers=1) as executor:
            # Start the reader first: it runs in a worker thread while the
            # (blocking) Java call below produces the data.
            future = executor.submit(print_node_ids, cli_fifo)
            gateway.entry_point.visit(node_id, 'forward', '*', cli_fifo)
            # Wait for the reader and re-raise any exception it hit.
            future.result()

    gateway.shutdown()
# Copyright (C) 2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import argparse
import aiohttp.web
from swh.core.api.asynchronous import RPCServerApp
from swh.graph.server.backend import Backend
async def index(request):
    """Handle requests for the root URL by returning a fixed banner."""
    banner = "SWH Graph API server"
    return aiohttp.web.Response(body=banner)
async def visit(request):
    """Stream the node ids produced by visiting the node given in the URL.

    The `id` path component is parsed as an integer; the ids yielded by the
    application's backend are written to the response one per line as they
    become available.

    Raises:
        aiohttp.web.HTTPBadRequest: if the `id` path component is not an
            integer.
    """
    raw_id = request.match_info['id']
    try:
        src_node_id = int(raw_id)
    except ValueError:
        # Reject malformed ids as a client error instead of letting the
        # ValueError surface as an internal server error.
        raise aiohttp.web.HTTPBadRequest(
            text='invalid node id: {!r}'.format(raw_id))

    response = aiohttp.web.StreamResponse(status=200)
    await response.prepare(request)

    # A distinct loop variable keeps the requested id intact.
    async for node_id in request.app['backend'].visit(src_node_id):
        await response.write('{}\n'.format(node_id).encode())

    await response.write_eof()
    return response
def make_app(backend, **kwargs):
    """Build the RPC server application and attach `backend` to it.

    Extra keyword arguments are forwarded to RPCServerApp. The backend is
    stored under the 'backend' key of the application.
    """
    routes = (
        ('GET', '/', index),
        # Endpoints used by the web API
        ('GET', '/visit/{id}', visit),
    )

    app = RPCServerApp(**kwargs)
    for method, path, handler in routes:
        app.router.add_route(method, path, handler)

    app['backend'] = backend
    return app
def main():
    """Parse the command line and run the HTTP server with a live backend."""
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--host', {'default': '0.0.0.0'}),
        ('--port', {'type': int, 'default': 5009}),
        ('--graph', {'required': True}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    backend = Backend(graph_path=options.graph)
    app = make_app(backend=backend)

    # The backend is used as a context manager around the server's lifetime.
    with backend:
        aiohttp.web.run_app(app, host=options.host, port=options.port)


if __name__ == '__main__':
    main()
import asyncio
import os
import struct
import sys
import tempfile
from py4j.java_gateway import JavaGateway
# Port passed to JavaGateway.launch_gateway when starting the Py4J gateway.
GATEWAY_SERVER_PORT = 25335
# Number of bytes requested per read() call on the results file.
BUF_SIZE = 64*1024
BIN_FMT = '>q' # 64 bit integer, big endian
async def read_node_ids(fname):
    """Asynchronously yield every integer stored in the binary file `fname`.

    The file is read in chunks of BUF_SIZE bytes, each decoded as
    consecutive BIN_FMT records. All blocking file operations are delegated
    to the event loop's default executor.
    """
    loop = asyncio.get_event_loop()
    # open() can block (on a FIFO it waits until a writer opens the other
    # end), so it must not run on the event loop thread.
    stream = await loop.run_in_executor(None, open, fname, 'rb')
    try:
        while True:
            chunk = await loop.run_in_executor(None, stream.read, BUF_SIZE)
            if not chunk:
                break
            for record in struct.iter_unpack(BIN_FMT, chunk):
                yield record[0]
    finally:
        stream.close()
class Backend:
    """Context manager owning a Py4J gateway and the Java graph entry object.

    Entering the context launches the gateway, instantiates the Java
    `org.softwareheritage.graph.Entry` class and loads the graph found at
    `graph_path`; leaving it shuts the gateway down.
    """

    def __init__(self, graph_path):
        # Both are populated by __enter__.
        self.gateway = None
        self.entry = None
        self.graph_path = graph_path

    def __enter__(self):
        self.gateway = JavaGateway.launch_gateway(
            port=GATEWAY_SERVER_PORT,
            classpath='java/server/target/swh-graph-0.0.2-jar-with-dependencies.jar',
            die_on_exit=True,
            redirect_stdout=sys.stdout,
            redirect_stderr=sys.stderr,
        )
        self.entry = self.gateway.jvm.org.softwareheritage.graph.Entry()
        self.entry.load_graph(self.graph_path)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # The context manager protocol always passes the three exception
        # arguments; returning None lets any exception propagate.
        self.gateway.shutdown()

    async def visit(self, node_id):
        """Asynchronously yield the node ids produced by visiting `node_id`.

        The Java side writes its results into a temporary FIFO while this
        coroutine reads and decodes them.
        """
        loop = asyncio.get_event_loop()
        # Capture the start node under its own name so the closure below is
        # not affected by any later rebinding.
        src_node_id = node_id

        with tempfile.TemporaryDirectory() as tmpdirname:
            cli_fifo = os.path.join(tmpdirname, 'swh-graph.fifo')
            os.mkfifo(cli_fifo)

            def _visit():
                return self.entry.visit(src_node_id, 'forward', '*', cli_fifo)

            # The Java call blocks until it is done writing, so run it in the
            # executor while the results are consumed here.
            java_call = loop.run_in_executor(None, _visit)
            async for visited_id in read_node_ids(cli_fifo):
                yield visited_id
            await java_call
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment