Skip to content
Snippets Groups Projects
Commit 32423fb1 authored by vlorentz's avatar vlorentz
Browse files

Make db_transaction* remove db/cur from the signature.

Rremoving them allows testing the function's signature
matches the existing signature of a specification and type checking.

Moreover, they should not be used by users of the class, so there is no
reason to have them appear in the documentation (generated from
the signature).
parent 0d367617
No related branches found
Tags v0.0.87
1 merge request!122Make db_transaction* remove db/cur from the signature.
......@@ -7,6 +7,19 @@ import inspect
import functools
def remove_kwargs(names):
def decorator(f):
sig = inspect.signature(f)
params = sig.parameters
params = [param for param in params.values()
if param.name not in names]
sig = sig.replace(parameters=params)
f.__signature__ = sig
return f
return decorator
def apply_options(cursor, options):
"""Applies the given postgresql client options to the given cursor.
......@@ -33,6 +46,7 @@ def db_transaction(**client_options):
raise ValueError(
'Use db_transaction_generator for generator functions.')
@remove_kwargs(['cur', 'db'])
@functools.wraps(meth)
def _meth(self, *args, **kwargs):
if 'cur' in kwargs and kwargs['cur']:
......@@ -67,6 +81,7 @@ def db_transaction_generator(**client_options):
raise ValueError(
'Use db_transaction for non-generator functions.')
@remove_kwargs(['cur', 'db'])
@functools.wraps(meth)
def _meth(self, *args, **kwargs):
if 'cur' in kwargs and kwargs['cur']:
......
......@@ -3,6 +3,7 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import inspect
import os.path
import tempfile
import unittest
......@@ -151,6 +152,21 @@ def test_db_transaction__with_generator():
yield None
def test_db_transaction_signature():
"""Checks db_transaction removes the 'cur' and 'db' arguments."""
def f(self, foo, *, bar=None):
pass
expected_sig = inspect.signature(f)
@db_transaction()
def g(self, foo, *, bar=None, db=None, cur=None):
pass
actual_sig = inspect.signature(g)
assert actual_sig == expected_sig
def test_db_transaction_generator(mocker):
expected_cur = object()
......@@ -189,3 +205,18 @@ def test_db_transaction_generator__with_nongenerator():
@db_transaction_generator()
def endpoint(self, cur=None, db=None):
pass
def test_db_transaction_generator_signature():
"""Checks db_transaction removes the 'cur' and 'db' arguments."""
def f(self, foo, *, bar=None):
pass
expected_sig = inspect.signature(f)
@db_transaction_generator()
def g(self, foo, *, bar=None, db=None, cur=None):
yield None
actual_sig = inspect.signature(g)
assert actual_sig == expected_sig
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment