Skip to content
Snippets Groups Projects
Commit ce1e4523 authored by Nicolas Dandrimont's avatar Nicolas Dandrimont
Browse files

Add type annotations for swh.core.db.BaseDb

parent 7124063e
No related branches found
No related tags found
No related merge requests found
......@@ -10,12 +10,13 @@ import logging
import os
import sys
import threading
from typing import Any, Callable, Iterable, Mapping, Optional
from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type, TypeVar
from contextlib import contextmanager
import psycopg2
import psycopg2.extras
import psycopg2.pool
logger = logging.getLogger(__name__)
......@@ -112,6 +113,9 @@ def typecast_bytea(value, cur):
return data.tobytes()
BaseDbType = TypeVar("BaseDbType", bound="BaseDb")
class BaseDb:
"""Base class for swh.*.*Db.
......@@ -119,8 +123,8 @@ class BaseDb:
"""
@classmethod
def adapt_conn(cls, conn):
@staticmethod
def adapt_conn(conn: psycopg2.extensions.connection):
"""Makes psycopg2 use 'bytes' to decode bytea instead of
'memoryview', for this connection."""
t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea)
......@@ -130,7 +134,7 @@ class BaseDb:
psycopg2.extensions.register_type(t_bytes_array, conn)
@classmethod
def connect(cls, *args, **kwargs):
def connect(cls: Type[BaseDbType], *args, **kwargs) -> BaseDbType:
"""factory method to create a DB proxy
Accepts all arguments of psycopg2.connect; only some specific
......@@ -144,11 +148,17 @@ class BaseDb:
return cls(conn)
@classmethod
def from_pool(cls, pool):
def from_pool(
cls: Type[BaseDbType], pool: psycopg2.pool.AbstractConnectionPool
) -> BaseDbType:
conn = pool.getconn()
return cls(conn, pool=pool)
def __init__(self, conn, pool=None):
def __init__(
self,
conn: psycopg2.extensions.connection,
pool: Optional[psycopg2.pool.AbstractConnectionPool] = None,
):
"""create a DB proxy
Args:
......@@ -160,11 +170,13 @@ class BaseDb:
self.conn = conn
self.pool = pool
def put_conn(self):
def put_conn(self) -> None:
if self.pool:
self.pool.putconn(self.conn)
def cursor(self, cur_arg=None):
def cursor(
self, cur_arg: Optional[psycopg2.extensions.cursor] = None
) -> psycopg2.extensions.cursor:
"""get a cursor: from cur_arg if given, or a fresh one otherwise
meant to avoid boilerplate if/then/else in methods that proxy stored
......@@ -179,7 +191,7 @@ class BaseDb:
_cursor = cursor # for bw compat
@contextmanager
def transaction(self):
def transaction(self) -> Iterator[psycopg2.extensions.cursor]:
"""context manager to execute within a DB transaction
Yields:
......@@ -283,5 +295,5 @@ class BaseDb:
# postgresql returned an error, let's raise it.
raise exc_info[1].with_traceback(exc_info[2])
def mktemp(self, tblname, cur=None):
def mktemp(self, tblname: str, cur: Optional[psycopg2.extensions.cursor] = None):
self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment