Skip to content
Snippets Groups Projects
Commit 4f85d9c4 authored by David Douard's avatar David Douard
Browse files

Extract the BaseDb class from swh-storage

This class is meant to be used not only to wrap storage-like db.

Also rewrite the test_logger using pytest-postgresql fixture so
we do not need pifpaf anymore.
parent 17d4a542
No related branches found
No related tags found
No related merge requests found
pytest
pytest < 4
pytest-postgresql
requests-mock
# Copyright (C) 2015-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import binascii
import datetime
import enum
import json
import os
import threading
from contextlib import contextmanager
import psycopg2
import psycopg2.extras
psycopg2.extras.register_uuid()
class BaseDb:
"""Base class for swh.*.*Db.
cf. swh.storage.db.Db, swh.archiver.db.ArchiverDb
"""
@classmethod
def connect(cls, *args, **kwargs):
"""factory method to create a DB proxy
Accepts all arguments of psycopg2.connect; only some specific
possibilities are reported below.
Args:
connstring: libpq2 connection string
"""
conn = psycopg2.connect(*args, **kwargs)
return cls(conn)
@classmethod
def from_pool(cls, pool):
return cls(pool.getconn(), pool=pool)
def _cursor(self, cur_arg):
"""get a cursor: from cur_arg if given, or a fresh one otherwise
meant to avoid boilerplate if/then/else in methods that proxy stored
procedures
"""
if cur_arg is not None:
return cur_arg
else:
return self.conn.cursor()
def __init__(self, conn, pool=None):
"""create a DB proxy
Args:
conn: psycopg2 connection to the SWH DB
pool: psycopg2 pool of connections
"""
self.conn = conn
self.pool = pool
def __del__(self):
if self.pool:
self.pool.putconn(self.conn)
@contextmanager
def transaction(self):
"""context manager to execute within a DB transaction
Yields:
a psycopg2 cursor
"""
with self.conn.cursor() as cur:
try:
yield cur
self.conn.commit()
except Exception:
if not self.conn.closed:
self.conn.rollback()
raise
def copy_to(self, items, tblname, columns, cur=None, item_cb=None):
"""Copy items' entries to table tblname with columns information.
Args:
items (dict): dictionary of data to copy over tblname
tblname (str): Destination table's name
columns ([str]): keys to access data in items and also the
column names in the destination table.
item_cb (fn): optional function to apply to items's entry
"""
def escape(data):
if data is None:
return ''
if isinstance(data, bytes):
return '\\x%s' % binascii.hexlify(data).decode('ascii')
elif isinstance(data, str):
return '"%s"' % data.replace('"', '""')
elif isinstance(data, datetime.datetime):
# We escape twice to make sure the string generated by
# isoformat gets escaped
return escape(data.isoformat())
elif isinstance(data, dict):
return escape(json.dumps(data))
elif isinstance(data, list):
return escape("{%s}" % ','.join(escape(d) for d in data))
elif isinstance(data, psycopg2.extras.Range):
# We escape twice here too, so that we make sure
# everything gets passed to copy properly
return escape(
'%s%s,%s%s' % (
'[' if data.lower_inc else '(',
'-infinity' if data.lower_inf else escape(data.lower),
'infinity' if data.upper_inf else escape(data.upper),
']' if data.upper_inc else ')',
)
)
elif isinstance(data, enum.IntEnum):
return escape(int(data))
else:
# We don't escape here to make sure we pass literals properly
return str(data)
read_file, write_file = os.pipe()
def writer():
cursor = self._cursor(cur)
with open(read_file, 'r') as f:
cursor.copy_expert('COPY %s (%s) FROM STDIN CSV' % (
tblname, ', '.join(columns)), f)
write_thread = threading.Thread(target=writer)
write_thread.start()
try:
with open(write_file, 'w') as f:
for d in items:
if item_cb is not None:
item_cb(d)
line = [escape(d.get(k)) for k in columns]
f.write(','.join(line))
f.write('\n')
finally:
# No problem bubbling up exceptions, but we still need to make sure
# we finish copying, even though we're probably going to cancel the
# transaction.
write_thread.join()
def mktemp(self, tblname, cur=None):
self._cursor(cur).execute('SELECT swh_mktemp(%s)', (tblname,))
# Copyright (C) 2015-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import inspect
import functools
def apply_options(cursor, options):
"""Applies the given postgresql client options to the given cursor.
Returns a dictionary with the old values if they changed."""
old_options = {}
for option, value in options.items():
cursor.execute('SHOW %s' % option)
old_value = cursor.fetchall()[0][0]
if old_value != value:
cursor.execute('SET LOCAL %s TO %%s' % option, (value,))
old_options[option] = old_value
return old_options
def db_transaction(**client_options):
"""decorator to execute Backend methods within DB transactions
The decorated method must accept a `cur` and `db` keyword argument
Client options are passed as `set` options to the postgresql server
"""
def decorator(meth, __client_options=client_options):
if inspect.isgeneratorfunction(meth):
raise ValueError(
'Use db_transaction_generator for generator functions.')
@functools.wraps(meth)
def _meth(self, *args, **kwargs):
if 'cur' in kwargs and kwargs['cur']:
cur = kwargs['cur']
old_options = apply_options(cur, __client_options)
ret = meth(self, *args, **kwargs)
apply_options(cur, old_options)
return ret
else:
db = self.get_db()
with db.transaction() as cur:
apply_options(cur, __client_options)
return meth(self, *args, db=db, cur=cur, **kwargs)
return _meth
return decorator
def db_transaction_generator(**client_options):
"""decorator to execute Backend methods within DB transactions, while
returning a generator
The decorated method must accept a `cur` and `db` keyword argument
Client options are passed as `set` options to the postgresql server
"""
def decorator(meth, __client_options=client_options):
if not inspect.isgeneratorfunction(meth):
raise ValueError(
'Use db_transaction for non-generator functions.')
@functools.wraps(meth)
def _meth(self, *args, **kwargs):
if 'cur' in kwargs and kwargs['cur']:
cur = kwargs['cur']
old_options = apply_options(cur, __client_options)
yield from meth(self, *args, **kwargs)
apply_options(cur, old_options)
else:
db = self.get_db()
with db.transaction() as cur:
apply_options(cur, __client_options)
yield from meth(self, *args, db=db, cur=cur, **kwargs)
return _meth
return decorator
# Copyright (C) 2015-2019 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This code has been imported from psycopg2, version 2.7.4,
# https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd
# and modified by Software Heritage.
#
# Original file: lib/extras.py
#
# psycopg2 is free software: you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option) any
# later version.
import re
import functools
import psycopg2.extensions
def stored_procedure(stored_proc):
"""decorator to execute remote stored procedure, specified as argument
Generally, the body of the decorated function should be empty. If it is
not, the stored procedure will be executed first; the function body then.
"""
def wrap(meth):
@functools.wraps(meth)
def _meth(self, *args, **kwargs):
cur = kwargs.get('cur', None)
self._cursor(cur).execute('SELECT %s()' % stored_proc)
meth(self, *args, **kwargs)
return _meth
return wrap
def jsonize(value):
"""Convert a value to a psycopg2 JSON object if necessary"""
if isinstance(value, dict):
return psycopg2.extras.Json(value)
return value
def entry_to_bytes(entry):
"""Convert an entry coming from the database to bytes"""
if isinstance(entry, memoryview):
return entry.tobytes()
if isinstance(entry, list):
return [entry_to_bytes(value) for value in entry]
return entry
def line_to_bytes(line):
"""Convert a line coming from the database to bytes"""
if not line:
return line
if isinstance(line, dict):
return {k: entry_to_bytes(v) for k, v in line.items()}
return line.__class__(entry_to_bytes(entry) for entry in line)
def cursor_to_bytes(cursor):
"""Yield all the data from a cursor as bytes"""
yield from (line_to_bytes(line) for line in cursor)
def execute_values_to_bytes(*args, **kwargs):
for line in execute_values_generator(*args, **kwargs):
yield line_to_bytes(line)
def _paginate(seq, page_size):
"""Consume an iterable and return it in chunks.
Every chunk is at most `page_size`. Never return an empty chunk.
"""
page = []
it = iter(seq)
while 1:
try:
for i in range(page_size):
page.append(next(it))
yield page
page = []
except StopIteration:
if page:
yield page
return
def _split_sql(sql):
"""Split *sql* on a single ``%s`` placeholder.
Split on the %s, perform %% replacement and return pre, post lists of
snippets.
"""
curr = pre = []
post = []
tokens = re.split(br'(%.)', sql)
for token in tokens:
if len(token) != 2 or token[:1] != b'%':
curr.append(token)
continue
if token[1:] == b's':
if curr is pre:
curr = post
else:
raise ValueError(
"the query contains more than one '%s' placeholder")
elif token[1:] == b'%':
curr.append(b'%')
else:
raise ValueError("unsupported format character: '%s'"
% token[1:].decode('ascii', 'replace'))
if curr is pre:
raise ValueError("the query doesn't contain any '%s' placeholder")
return pre, post
def execute_values_generator(cur, sql, argslist, template=None, page_size=100):
'''Execute a statement using SQL ``VALUES`` with a sequence of parameters.
Rows returned by the query are returned through a generator.
You need to consume the generator for the queries to be executed!
:param cur: the cursor to use to execute the query.
:param sql: the query to execute. It must contain a single ``%s``
placeholder, which will be replaced by a `VALUES list`__.
Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.
:param argslist: sequence of sequences or dictionaries with the arguments
to send to the query. The type and content must be consistent with
*template*.
:param template: the snippet to merge to every item in *argslist* to
compose the query.
- If the *argslist* items are sequences it should contain positional
placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there
are constants value...).
- If the *argslist* items are mappings it should contain named
placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).
If not specified, assume the arguments are sequence and use a simple
positional template (i.e. ``(%s, %s, ...)``), with the number of
placeholders sniffed by the first element in *argslist*.
:param page_size: maximum number of *argslist* items to include in every
statement. If there are more items the function will execute more than
one statement.
:param yield_from_cur: Whether to yield results from the cursor in this
function directly.
.. __: https://www.postgresql.org/docs/current/static/queries-values.html
After the execution of the function the `cursor.rowcount` property will
**not** contain a total result.
'''
# we can't just use sql % vals because vals is bytes: if sql is bytes
# there will be some decoding error because of stupid codec used, and Py3
# doesn't implement % on bytes.
if not isinstance(sql, bytes):
sql = sql.encode(
psycopg2.extensions.encodings[cur.connection.encoding]
)
pre, post = _split_sql(sql)
for page in _paginate(argslist, page_size=page_size):
if template is None:
template = b'(' + b','.join([b'%s'] * len(page[0])) + b')'
parts = pre[:]
for args in page:
parts.append(cur.mogrify(template, args))
parts.append(b',')
parts[-1:] = post
cur.execute(b''.join(parts))
yield from cur
......@@ -5,41 +5,46 @@
import logging
import os
import unittest
import pytest
from swh.core.logger import PostgresHandler
from swh.core.tests.db_testing import SingleDbTestFixture
from swh.core.tests import SQL_DIR
DUMP_FILE = os.path.join(SQL_DIR, 'log-schema.sql')
@pytest.mark.db
class PgLogHandler(SingleDbTestFixture, unittest.TestCase):
TEST_DB_DUMP = os.path.join(SQL_DIR, 'log-schema.sql')
@pytest.fixture
def swh_db_logger(postgresql_proc, postgresql):
def setUp(self):
super().setUp()
self.modname = 'swh.core.tests.test_logger'
self.logger = logging.Logger(self.modname, logging.DEBUG)
self.logger.addHandler(PostgresHandler('dbname=' + self.TEST_DB_NAME))
cursor = postgresql.cursor()
with open(DUMP_FILE) as fobj:
cursor.execute(fobj.read())
postgresql.commit()
modname = 'swh.core.tests.test_logger'
logger = logging.Logger(modname, logging.DEBUG)
dsn = 'postgresql://{user}@{host}:{port}/{dbname}'.format(
host=postgresql_proc.host,
port=postgresql_proc.port,
user='postgres',
dbname='tests')
logger.addHandler(PostgresHandler(dsn))
return logger
def tearDown(self):
logging.shutdown()
super().tearDown()
def test_log(self):
self.logger.info('notice',
extra={'swh_type': 'test entry', 'swh_data': 42})
self.logger.warning('warning')
def test_log(swh_db_logger, postgresql):
logger = swh_db_logger
modname = logger.name
with self.conn.cursor() as cur:
cur.execute('SELECT level, message, data, src_module FROM log')
db_log_entries = cur.fetchall()
logger.info('notice',
extra={'swh_type': 'test entry', 'swh_data': 42})
logger.warning('warning')
self.assertIn(('info', 'notice', {'type': 'test entry', 'data': 42},
self.modname),
db_log_entries)
self.assertIn(('warning', 'warning', {}, self.modname), db_log_entries)
with postgresql.cursor() as cur:
cur.execute('SELECT level, message, data, src_module FROM log')
db_log_entries = cur.fetchall()
assert ('info', 'notice', {'type': 'test entry', 'data': 42},
modname) in db_log_entries
assert ('warning', 'warning', {}, modname) in db_log_entries
......@@ -5,9 +5,8 @@ envlist=flake8,py3
deps =
.[testing]
pytest-cov
pifpaf
commands =
pifpaf run postgresql -- pytest --cov=swh --cov-branch {posargs}
pytest --cov=swh --cov-branch {posargs}
[testenv:flake8]
skip_install = true
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment