From 955c001f24aa6f7bf47564264244ac94ae0f4934 Mon Sep 17 00:00:00 2001
From: Pierre-Yves David <pierre-yves.david@octobus.net>
Date: Tue, 28 May 2024 12:44:53 +0200
Subject: [PATCH] Migrate to psycopg3

This replaces the previous usage of psycopg2. Beware that MRs in various
dependencies need to be landed and released accordingly.

We upgrade imports and various function and class names

Data type wrapping seems significantly simpler in psycopg3 so we can drop code
related to it.

Adding various typing information to help catch bugs and make mypy happy.

We use context manager more to make sure the connection is getting properly
closed (and not just the cursors) as psycopg3 seems more sensitive to this.

We have to jump through quite hacky hoops to get connection
information with a (reinjected) password in it. I am a bit nervous about
it, especially as the tests do not seem to exercise this case.

We drop the `execute_values_generator` from core.db as psycopg3 has better
capabilities in that regard that make it unnecessary in more situations. We have
a reimplementation of it compatible with psycopg3 in swh-storage (part of the
series migrating it to psycopg3).

We are dropping the "psycopg2.extras.register_default_jsonb" call in tests,
as it seems to no longer be necessary in psycopg3.

We are dropping the call to "psycopg2.extras.register_uuid" as it no
longer seems useful.

We are dropping the adapter for memory views as psycopg3 already does the
right thing:
https://www.psycopg.org/psycopg3/docs/basic/adapt.html#binary-adaptation

We use proper value swapping to set the config.
---
 requirements-db.txt                   |   3 +-
 requirements-test.txt                 |   2 +-
 swh/core/api/__init__.py              |   8 +-
 swh/core/api/tests/test_rpc_server.py |  16 +--
 swh/core/db/__init__.py               |  73 ++++++------
 swh/core/db/common.py                 |  11 +-
 swh/core/db/db_utils.py               | 155 ++++++++++----------------
 swh/core/db/tests/conftest.py         |   4 +-
 swh/core/db/tests/test_cli.py         |  71 +++++++-----
 swh/core/db/tests/test_db.py          |  15 ++-
 swh/core/db/tests/test_db_utils.py    |  70 ++++++------
 swh/core/github/utils.py              |   2 +-
 swh/core/pytest_plugin.py             |  57 ++++++----
 13 files changed, 241 insertions(+), 246 deletions(-)

diff --git a/requirements-db.txt b/requirements-db.txt
index 921e04d0..a187c96f 100644
--- a/requirements-db.txt
+++ b/requirements-db.txt
@@ -1,3 +1,4 @@
 # requirements for swh.core.db
-psycopg2
+psycopg
+psycopg_pool
 typing-extensions
diff --git a/requirements-test.txt b/requirements-test.txt
index 48581180..2d277429 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -7,7 +7,7 @@ pytest-postgresql > 5
 pytz
 requests-mock
 types-deprecated
-types-psycopg2
 types-pytz
 types-pyyaml
 types-requests
+types-setuptools
diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
index cc680b75..93e43154 100644
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -554,11 +554,11 @@ class RPCServerApp(Flask):
         (Exception, 500),
         # These errors are noisy, and are better logged on the caller's side after
         # it retried a few times:
-        ("psycopg2.errors.OperationalError", 503),
+        ("psycopg.errors.OperationalError", 503),
         # Subclass of OperationalError; but it is unlikely to be solved after retries
         # (short of getting more cache hits) because this is usually caused by the query
         # size instead of a transient failure
-        ("psycopg2.errors.QueryCanceled", 500),
+        ("psycopg.errors.QueryCanceled", 500),
         # Often a transient error because of connectivity issue with, or restart of,
         # the Kafka brokers:
         ("swh.journal.writer.kafka.KafkaDeliveryError", 503),
@@ -631,10 +631,10 @@ class RPCServerApp(Flask):
 
         self.route("/" + meth._endpoint_path, methods=["POST"])(f)
 
-    def setup_psycopg2_errorhandlers(self) -> None:
+    def setup_psycopg_errorhandlers(self) -> None:
         """Deprecated method; error handlers are now setup in the constructor."""
         warnings.warn(
-            "setup_psycopg2_errorhandlers has no effect; error handlers are now setup "
+            "setup_psycopg_errorhandlers has no effect; error handlers are now setup "
             "by the constructor.",
             DeprecationWarning,
         )
diff --git a/swh/core/api/tests/test_rpc_server.py b/swh/core/api/tests/test_rpc_server.py
index 94cb1d72..c7eb088d 100644
--- a/swh/core/api/tests/test_rpc_server.py
+++ b/swh/core/api/tests/test_rpc_server.py
@@ -58,13 +58,13 @@ class TestStorage:
 
     @remote_api_endpoint("crashy/adminshutdown")
     def adminshutdown_crash(self, data, db=None, cur=None):
-        from psycopg2.errors import AdminShutdown
+        from psycopg.errors import AdminShutdown
 
         raise AdminShutdown("cluster is shutting down")
 
     @remote_api_endpoint("crashy/querycancelled")
     def querycancelled_crash(self, data, db=None, cur=None):
-        from psycopg2.errors import QueryCanceled
+        from psycopg.errors import QueryCanceled
 
         raise QueryCanceled("too big!")
 
@@ -181,8 +181,8 @@ def test_rpc_server_custom_exception(flask_app_client):
     assert data["args"] == ["try again later!"]
 
 
-def test_rpc_server_psycopg2_adminshutdown(flask_app_client):
-    pytest.importorskip("psycopg2")
+def test_rpc_server_psycopg_adminshutdown(flask_app_client):
+    pytest.importorskip("psycopg")
 
     res = flask_app_client.post(
         url_for("adminshutdown_crash"),
@@ -197,12 +197,12 @@ def test_rpc_server_psycopg2_adminshutdown(flask_app_client):
     assert res.mimetype == "application/x-msgpack", res.data
     data = msgpack.loads(res.data)
     assert data["type"] == "AdminShutdown"
-    assert data["module"] == "psycopg2.errors"
+    assert data["module"] == "psycopg.errors"
     assert data["args"] == ["cluster is shutting down"]
 
 
-def test_rpc_server_psycopg2_querycancelled(flask_app_client):
-    pytest.importorskip("psycopg2")
+def test_rpc_server_psycopg_querycancelled(flask_app_client):
+    pytest.importorskip("psycopg")
 
     res = flask_app_client.post(
         url_for("querycancelled_crash"),
@@ -217,7 +217,7 @@ def test_rpc_server_psycopg2_querycancelled(flask_app_client):
     assert res.mimetype == "application/x-msgpack", res.data
     data = msgpack.loads(res.data)
     assert data["type"] == "QueryCanceled"
-    assert data["module"] == "psycopg2.errors"
+    assert data["module"] == "psycopg.errors"
     assert data["args"] == ["too big!"]
 
 
diff --git a/swh/core/db/__init__.py b/swh/core/db/__init__.py
index eb1341f6..122b3639 100644
--- a/swh/core/db/__init__.py
+++ b/swh/core/db/__init__.py
@@ -13,16 +13,13 @@ import sys
 import threading
 from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Type, TypeVar
 
-import psycopg2
-import psycopg2.extras
-import psycopg2.pool
+import psycopg
+from psycopg.types.range import Range
+import psycopg_pool
 
 logger = logging.getLogger(__name__)
 
 
-psycopg2.extras.register_uuid()
-
-
 def render_array(data) -> str:
     """Render the data as a postgresql array"""
     # From https://www.postgresql.org/docs/11/arrays.html#ARRAYS-IO
@@ -73,7 +70,7 @@ def value_as_pg_text(data: Any) -> str:
         return json.dumps(data)
     elif isinstance(data, (list, tuple)):
         return render_array(data)
-    elif isinstance(data, psycopg2.extras.Range):
+    elif isinstance(data, Range):
         return "%s%s,%s%s" % (
             "[" if data.lower_inc else "(",
             "-infinity" if data.lower_inf else value_as_pg_text(data.lower),
@@ -105,12 +102,6 @@ def escape_copy_column(column: str) -> str:
     return ret
 
 
-def typecast_bytea(value, cur):
-    if value is not None:
-        data = psycopg2.BINARY(value, cur)
-        return data.tobytes()
-
-
 BaseDbType = TypeVar("BaseDbType", bound="BaseDb")
 
 
@@ -121,60 +112,50 @@ class BaseDb:
 
     """
 
-    @staticmethod
-    def adapt_conn(conn: psycopg2.extensions.connection):
-        """Makes psycopg2 use 'bytes' to decode bytea instead of
-        'memoryview', for this connection."""
-        t_bytes = psycopg2.extensions.new_type((17,), "bytea", typecast_bytea)
-        psycopg2.extensions.register_type(t_bytes, conn)
-
-        t_bytes_array = psycopg2.extensions.new_array_type((1001,), "bytea[]", t_bytes)
-        psycopg2.extensions.register_type(t_bytes_array, conn)
-
     @classmethod
     def connect(cls: Type[BaseDbType], *args, **kwargs) -> BaseDbType:
         """factory method to create a DB proxy
 
-        Accepts all arguments of psycopg2.connect; only some specific
+        Accepts all arguments of psycopg.connect; only some specific
         possibilities are reported below.
 
         Args:
             connstring: libpq2 connection string
 
         """
-        conn = psycopg2.connect(*args, **kwargs)
+        conn = psycopg.connect(*args, **kwargs)
         return cls(conn)
 
     @classmethod
     def from_pool(
-        cls: Type[BaseDbType], pool: psycopg2.pool.AbstractConnectionPool
+        cls: Type[BaseDbType], pool: psycopg_pool.ConnectionPool
     ) -> BaseDbType:
         conn = pool.getconn()
         return cls(conn, pool=pool)
 
     def __init__(
         self,
-        conn: psycopg2.extensions.connection,
-        pool: Optional[psycopg2.pool.AbstractConnectionPool] = None,
+        conn: psycopg.Connection[Any],
+        pool: Optional[psycopg_pool.ConnectionPool] = None,
     ):
         """create a DB proxy
 
         Args:
-            conn: psycopg2 connection to the SWH DB
-            pool: psycopg2 pool of connections
+            conn: psycopg connection to the SWH DB
+            pool: psycopg pool of connections
 
         """
-        self.adapt_conn(conn)
         self.conn = conn
         self.pool = pool
 
+    def close(self):
+        return self.conn.close()
+
     def put_conn(self) -> None:
         if self.pool:
             self.pool.putconn(self.conn)
 
-    def cursor(
-        self, cur_arg: Optional[psycopg2.extensions.cursor] = None
-    ) -> psycopg2.extensions.cursor:
+    def cursor(self, cur_arg: Optional[psycopg.Cursor] = None) -> psycopg.Cursor:
         """get a cursor: from cur_arg if given, or a fresh one otherwise
 
         meant to avoid boilerplate if/then/else in methods that proxy stored
@@ -186,14 +167,21 @@ class BaseDb:
         else:
             return self.conn.cursor()
 
+    def __enter__(self):
+        self.conn.__enter__()
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        return self.conn.__exit__(*args, **kwargs)
+
     _cursor = cursor  # for bw compat
 
     @contextmanager
-    def transaction(self) -> Iterator[psycopg2.extensions.cursor]:
+    def transaction(self) -> Iterator[psycopg.Cursor]:
         """context manager to execute within a DB transaction
 
         Yields:
-            a psycopg2 cursor
+            a psycopg cursor
 
         """
         with self.conn.cursor() as cur:
@@ -210,7 +198,7 @@ class BaseDb:
         items: Iterable[Mapping[str, Any]],
         tblname: str,
         columns: Iterable[str],
-        cur: Optional[psycopg2.extensions.cursor] = None,
+        cur: Optional[psycopg.Cursor] = None,
         item_cb: Optional[Callable[[Any], Any]] = None,
         default_values: Optional[Mapping[str, Any]] = None,
     ) -> None:
@@ -240,9 +228,12 @@ class BaseDb:
             cursor = self.cursor(cur)
             with open(read_file, "r") as f:
                 try:
-                    cursor.copy_expert(
-                        "COPY %s (%s) FROM STDIN" % (tblname, ", ".join(columns)), f
-                    )
+                    with cursor.copy(
+                        "COPY %s (%s) FROM STDIN" % (tblname, ", ".join(columns))
+                    ) as c:
+                        while data := f.read(4096):
+                            c.write(data)
+
                 except Exception:
                     # Tell the main thread about the exception
                     exc_info = sys.exc_info()
@@ -293,5 +284,5 @@ class BaseDb:
                 # postgresql returned an error, let's raise it.
                 raise exc_info[1].with_traceback(exc_info[2])
 
-    def mktemp(self, tblname: str, cur: Optional[psycopg2.extensions.cursor] = None):
+    def mktemp(self, tblname: str, cur: Optional[psycopg.Cursor] = None):
         self.cursor(cur).execute("SELECT swh_mktemp(%s)", (tblname,))
diff --git a/swh/core/db/common.py b/swh/core/db/common.py
index 965b6686..77909555 100644
--- a/swh/core/db/common.py
+++ b/swh/core/db/common.py
@@ -28,7 +28,16 @@ def apply_options(cursor, options):
         cursor.execute("SHOW %s" % option)
         old_value = cursor.fetchall()[0][0]
         if old_value != value:
-            cursor.execute("SET LOCAL %s TO %%s" % option, (value,))
+            # We could also pre-format the option and value using:
+            #
+            #      (str(option), str(value))
+            #
+            # However using %s::text is going through the psycopg adapter
+            # system and is likely more robust.
+            cursor.execute(
+                "SELECT set_config(%s::text, %s::text, true)",
+                (option, value),
+            )
             old_options[option] = old_value
     return old_options
 
diff --git a/swh/core/db/db_utils.py b/swh/core/db/db_utils.py
index bcfae1ce..adb5d597 100644
--- a/swh/core/db/db_utils.py
+++ b/swh/core/db/db_utils.py
@@ -13,15 +13,13 @@ import pathlib
 import re
 import subprocess
 from types import ModuleType
-from typing import Collection, Dict, Iterator, List, Optional, Tuple, Union, cast
+from typing import Any, Collection, Dict, Iterator, List, Optional, Tuple, Union, cast
 
-import psycopg2
-import psycopg2.errors
-import psycopg2.extensions
-from psycopg2.extensions import connection as pgconnection
-from psycopg2.extensions import encodings as pgencodings
-from psycopg2.extensions import make_dsn
-from psycopg2.extensions import parse_dsn as _parse_dsn
+import psycopg
+from psycopg import Connection
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
+import psycopg.errors
+from psycopg.types.json import Json
 
 from swh.core.config import get_swh_backend_module
 from swh.core.utils import numfile_sortkey as sortkey
@@ -54,17 +52,17 @@ def stored_procedure(stored_proc):
 
 
 def jsonize(value):
-    """Convert a value to a psycopg2 JSON object if necessary"""
+    """Convert a value to a psycopg JSON object if necessary"""
     if isinstance(value, dict):
-        return psycopg2.extras.Json(value)
+        return Json(value)
 
     return value
 
 
 @contextmanager
 def connect_to_conninfo(
-    db_or_conninfo: Union[str, pgconnection],
-) -> Iterator[pgconnection]:
+    db_or_conninfo: Union[str, Connection[Any]],
+) -> Iterator[Connection[Any]]:
     """Connect to the database passed as argument.
 
     Args:
@@ -74,7 +72,12 @@ def connect_to_conninfo(
         a connected database handle or None if the database is not initialized
 
     """
-    if isinstance(db_or_conninfo, pgconnection):
+    if isinstance(db_or_conninfo, Connection):
+        # we don't use a connext manager here, as we let the caller manage the
+        # connection life time.
+        #
+        # We don't want a Connection to be prematurely closed by simple
+        # function like `swh_db_version`
         yield db_or_conninfo
     else:
         if "=" not in db_or_conninfo and "//" not in db_or_conninfo:
@@ -82,14 +85,15 @@ def connect_to_conninfo(
             db_or_conninfo = f"dbname={db_or_conninfo}"
 
         try:
-            db = psycopg2.connect(db_or_conninfo)
-        except psycopg2.Error:
+            db = psycopg.connect(db_or_conninfo)
+        except psycopg.Error:
             logger.exception("Failed to connect to `%s`", db_or_conninfo)
         else:
-            yield db
+            with db:
+                yield db
 
 
-def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]:
+def swh_db_version(db_or_conninfo: Union[str, Connection[Any]]) -> Optional[int]:
     """Retrieve the swh version of the database.
 
     If the database is not initialized, this logs a warning and returns None.
@@ -101,7 +105,8 @@ def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]:
         Either the version of the database, or None if it couldn't be detected
     """
     try:
-        with connect_to_conninfo(db_or_conninfo) as db:
+        co = connect_to_conninfo(db_or_conninfo)
+        with co as db:
             if not db:
                 return None
             with db.cursor() as c:
@@ -111,7 +116,7 @@ def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]:
                     result = c.fetchone()
                     if result:
                         return result[0]
-                except psycopg2.errors.UndefinedTable:
+                except psycopg.errors.UndefinedTable:
                     return None
     except Exception:
         logger.exception("Could not get version from `%s`", db_or_conninfo)
@@ -119,7 +124,7 @@ def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]:
 
 
 def swh_db_versions(
-    db_or_conninfo: Union[str, pgconnection],
+    db_or_conninfo: Union[str, Connection[Any]],
 ) -> Optional[List[Tuple[int, datetime, str]]]:
     """Retrieve the swh version history of the database.
 
@@ -143,7 +148,7 @@ def swh_db_versions(
                 try:
                     c.execute(query)
                     return cast(List[Tuple[int, datetime, str]], c.fetchall())
-                except psycopg2.errors.UndefinedTable:
+                except psycopg.errors.UndefinedTable:
                     return None
     except Exception:
         logger.exception("Could not get versions from `%s`", db_or_conninfo)
@@ -232,7 +237,7 @@ def swh_db_upgrade(
     return new_version
 
 
-def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
+def swh_db_module(db_or_conninfo: Union[str, Connection[Any]]) -> Optional[str]:
     """Retrieve the swh module used to create the database.
 
     If the database is not initialized, this logs a warning and returns None.
@@ -254,7 +259,7 @@ def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
                     resp = c.fetchone()
                     if resp:
                         return resp[0]
-                except psycopg2.errors.UndefinedTable:
+                except psycopg.errors.UndefinedTable:
                     return None
     except Exception:
         logger.exception("Could not get module from `%s`", db_or_conninfo)
@@ -262,7 +267,7 @@ def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
 
 
 def swh_set_db_module(
-    db_or_conninfo: Union[str, pgconnection], module: str, force=False
+    db_or_conninfo: Union[str, Connection[Any]], module: str, force=False
 ) -> None:
     """Set the swh module used to create the database.
 
@@ -319,7 +324,7 @@ def swh_set_db_module(
 
 
 def swh_set_db_version(
-    db_or_conninfo: Union[str, pgconnection],
+    db_or_conninfo: Union[str, Connection[Any]],
     version: int,
     ts: Optional[datetime] = None,
     desc: str = "Work in progress",
@@ -347,7 +352,7 @@ def swh_set_db_version(
             db.commit()
 
 
-def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
+def swh_db_flavor(db_or_conninfo: Union[str, Connection[Any]]) -> Optional[str]:
     """Retrieve the swh flavor of the database.
 
     If the database is not initialized, or the database doesn't support
@@ -370,21 +375,21 @@ def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
                     result = c.fetchone()
                     assert result is not None  # to keep mypy happy
                     return result[0]
-                except psycopg2.errors.UndefinedFunction:
+                except psycopg.errors.UndefinedFunction:
                     # function not found: no flavor
                     return None
     except Exception:
         logger.exception("Could not get flavor from `%s`", db_or_conninfo)
         return None
 
 
 # The following code has been imported from psycopg2, version 2.7.4,
 # https://github.com/psycopg/psycopg2/tree/5afb2ce803debea9533e293eef73c92ffce95bcd
 # and modified by Software Heritage.
 #
 # Original file: lib/extras.py
 #
 # psycopg2 is free software: you can redistribute it and/or modify it under the
 # terms of the GNU Lesser General Public License as published by the Free
 # Software Foundation, either version 3 of the License, or (at your option) any
 # later version.
@@ -440,60 +446,6 @@ def _split_sql(sql):
     return pre, post
 
 
-def execute_values_generator(cur, sql, argslist, template=None, page_size=100):
-    """Execute a statement using SQL ``VALUES`` with a sequence of parameters.
-    Rows returned by the query are returned through a generator.
-    You need to consume the generator for the queries to be executed!
-
-    :param cur: the cursor to use to execute the query.
-    :param sql: the query to execute. It must contain a single ``%s``
-        placeholder, which will be replaced by a `VALUES list`__.
-        Example: ``"INSERT INTO mytable (id, f1, f2) VALUES %s"``.
-    :param argslist: sequence of sequences or dictionaries with the arguments
-        to send to the query. The type and content must be consistent with
-        *template*.
-    :param template: the snippet to merge to every item in *argslist* to
-        compose the query.
-
-        - If the *argslist* items are sequences it should contain positional
-          placeholders (e.g. ``"(%s, %s, %s)"``, or ``"(%s, %s, 42)``" if there
-          are constants value...).
-        - If the *argslist* items are mappings it should contain named
-          placeholders (e.g. ``"(%(id)s, %(f1)s, 42)"``).
-
-        If not specified, assume the arguments are sequence and use a simple
-        positional template (i.e.  ``(%s, %s, ...)``), with the number of
-        placeholders sniffed by the first element in *argslist*.
-    :param page_size: maximum number of *argslist* items to include in every
-        statement. If there are more items the function will execute more than
-        one statement.
-    :param yield_from_cur: Whether to yield results from the cursor in this
-        function directly.
-
-    .. __: https://www.postgresql.org/docs/current/static/queries-values.html
-
-    After the execution of the function the `cursor.rowcount` property will
-    **not** contain a total result.
-    """
-    # we can't just use sql % vals because vals is bytes: if sql is bytes
-    # there will be some decoding error because of stupid codec used, and Py3
-    # doesn't implement % on bytes.
-    if not isinstance(sql, bytes):
-        sql = sql.encode(pgencodings[cur.connection.encoding])
-    pre, post = _split_sql(sql)
-
-    for page in _paginate(argslist, page_size=page_size):
-        if template is None:
-            template = b"(" + b",".join([b"%s"] * len(page[0])) + b")"
-        parts = pre[:]
-        for args in page:
-            parts.append(cur.mogrify(template, args))
-            parts.append(b",")
-        parts[-1:] = post
-        cur.execute(b"".join(parts))
-        yield from cur
-
-
 def import_swhmodule(modname: str) -> Optional[ModuleType]:
     # TODO: move import_swhmodule in swh.core.config, but swh-scrubber needs to
     # be aware of that befaore it can happen...
@@ -596,12 +548,13 @@ def initialize_database_for_module(
       storage_postgresql = factories.postgresql("storage_postgresql_proc")
 
     """
-    conninfo = psycopg2.connect(**kwargs).dsn
+    with psycopg.connect(**kwargs) as con:
+        conninfo = dsn_with_password(con)
     init_admin_extensions(modname, conninfo)
     populate_database_for_package(modname, conninfo, flavor)
     try:
         swh_set_db_version(conninfo, version)
-    except psycopg2.errors.UniqueViolation:
+    except psycopg.errors.UniqueViolation:
         logger.warn(
             "Version already set by db init scripts. "
             f"This generally means the swh.{modname} package needs to be "
@@ -610,7 +563,7 @@ def initialize_database_for_module(
 
 
 def get_database_info(
-    conninfo: str,
+    conninfo: Union[str, Connection[Any]],
 ) -> Tuple[Optional[str], Optional[int], Optional[str]]:
     """Get version, flavor and module of the db"""
     dbmodule = swh_db_module(conninfo)
@@ -622,13 +575,16 @@ def get_database_info(
 
 
 def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]:
-    """Parse a psycopg2 dsn, falling back to supporting plain database names as well"""
+    """Parse a psycopg dsn, falling back to supporting plain database names as well"""
     try:
-        return _parse_dsn(dsn_or_dbname)
-    except psycopg2.ProgrammingError:
-        # psycopg2 failed to parse the DSN; it's probably a database name,
+        d = conninfo_to_dict(dsn_or_dbname)
+    except psycopg.ProgrammingError:
+        # psycopg failed to parse the DSN; it's probably a database name,
         # handle it as such
-        return _parse_dsn(f"dbname={dsn_or_dbname}")
+        d = conninfo_to_dict(f"dbname={dsn_or_dbname}")
+    # conninfo_to_dict only returns non-str values when given keyword arguments
+    conninfo = cast(Dict[str, str], d)
+    return conninfo
 
 
 def init_admin_extensions(modname: str, conninfo: str) -> None:
@@ -668,7 +624,7 @@ def create_database_for_package(
             "-v",
             "ON_ERROR_STOP=1",
             "-d",
-            make_dsn(**creation_dsn),
+            make_conninfo(**creation_dsn),
             "-c",
             f'CREATE DATABASE "{dbname}"',
         ]
@@ -676,9 +632,18 @@ def create_database_for_package(
     init_admin_extensions(modname, conninfo)
 
 
+def dsn_with_password(con: Connection[Any]):
+    """fetch the connection info with password"""
+    conn_info = con.info.dsn
+    password = con.info.password
+    if password is not None:
+        conn_info = make_conninfo(conn_info, password=password)
+    return conn_info
+
+
 def execute_sqlfiles(
     sqlfiles: Collection[pathlib.Path],
-    db_or_conninfo: Union[str, pgconnection],
+    db_or_conninfo: Union[str, Connection[Any]],
     flavor: Optional[str] = None,
 ):
     """Execute a list of SQL files on the database pointed at with ``db_or_conninfo``.
@@ -691,7 +656,7 @@ def execute_sqlfiles(
     if isinstance(db_or_conninfo, str):
         conninfo = db_or_conninfo
     else:
-        conninfo = db_or_conninfo.dsn
+        conninfo = dsn_with_password(db_or_conninfo)
 
     psql_command = [
         "psql",
diff --git a/swh/core/db/tests/conftest.py b/swh/core/db/tests/conftest.py
index 1b51223e..8319ea89 100644
--- a/swh/core/db/tests/conftest.py
+++ b/swh/core/db/tests/conftest.py
@@ -7,7 +7,7 @@ import os
 
 from click.testing import CliRunner
 from hypothesis import HealthCheck
-import psycopg2
+import psycopg
 import pytest
 from pytest_postgresql import factories
 
@@ -22,7 +22,7 @@ function_scoped_fixture_check = (
 
 
 def create_role_guest(**kwargs):
-    with psycopg2.connect(**kwargs) as conn:
+    with psycopg.connect(**kwargs) as conn:
         with conn.cursor() as cur:
             cur.execute("REVOKE CREATE ON SCHEMA public FROM PUBLIC")
             cur.execute("CREATE ROLE guest NOINHERIT LOGIN PASSWORD 'guest'")
diff --git a/swh/core/db/tests/test_cli.py b/swh/core/db/tests/test_cli.py
index 4cb8ce43..45e685a9 100644
--- a/swh/core/db/tests/test_cli.py
+++ b/swh/core/db/tests/test_cli.py
@@ -15,6 +15,12 @@ from swh.core.tests.test_cli import assert_result, assert_section_contains
 postgresql2 = factories.postgresql("postgresql_proc", dbname="tests2")
 
 
+def assert_no_pending_transaction(cursor):
+    sql = """SELECT * FROM pg_stat_activity WHERE state = 'idle in transaction'"""
+    idle = cursor.execute(sql).fetchall()
+    assert idle == []
+
+
 def test_cli_swh_help(swhmain, cli_runner):
     swhmain.add_command(swhdb)
     result = cli_runner.invoke(swhmain, ["-h"])
@@ -99,10 +105,11 @@ def test_cli_swh_db_create_and_init_db(
 
     # the origin value in the scripts uses a hash function (which implementation wise
     # uses a function from the pgcrypt extension, installed during db creation step)
-    with BaseDb.connect(conninfo).cursor() as cur:
-        cur.execute(f"select * from {table}")
-        origins = cur.fetchall()
-        assert len(origins) == 1
+    with BaseDb.connect(conninfo) as conn:
+        with conn.cursor() as cur:
+            cur.execute(f"select * from {table}")
+            origins = cur.fetchall()
+            assert len(origins) == 1
 
 
 def test_cli_swh_db_initialization_fail_without_creation_first(
@@ -154,10 +161,11 @@ def test_cli_swh_db_initialization_works_with_flags(
     assert_result(result)
     # the origin values in the scripts uses a hash function (which implementation wise
     # uses a function from the pgcrypt extension, init-admin calls installs it)
-    with BaseDb.connect(postgresql.info.dsn).cursor() as cur:
-        cur.execute("select * from origin")
-        origins = cur.fetchall()
-        assert len(origins) == 1
+    with BaseDb.connect(postgresql.info.dsn) as conn:
+        with conn.cursor() as cur:
+            cur.execute("select * from origin")
+            origins = cur.fetchall()
+            assert len(origins) == 1
 
 
 def test_cli_swh_db_initialization_with_env(
@@ -183,10 +191,11 @@ def test_cli_swh_db_initialization_with_env(
 
     # the origin values in the scripts uses a hash function (which implementation wise
     # uses a function from the pgcrypt extension, init-admin calls installs it)
-    with BaseDb.connect(postgresql.info.dsn).cursor() as cur:
-        cur.execute("select * from origin")
-        origins = cur.fetchall()
-        assert len(origins) == 1
+    with BaseDb.connect(postgresql.info.dsn) as conn:
+        with conn.cursor() as cur:
+            cur.execute("select * from origin")
+            origins = cur.fetchall()
+            assert len(origins) == 1
 
 
 def test_cli_swh_db_initialization_idempotent(
@@ -218,10 +227,11 @@ def test_cli_swh_db_initialization_idempotent(
 
     # the origin values in the scripts uses a hash function (which implementation wise
     # uses a function from the pgcrypt extension, init-admin calls installs it)
-    with BaseDb.connect(postgresql.info.dsn).cursor() as cur:
-        cur.execute("select * from origin")
-        origins = cur.fetchall()
-        assert len(origins) == 1
+    with BaseDb.connect(postgresql.info.dsn) as conn:
+        with conn.cursor() as cur:
+            cur.execute("select * from origin")
+            origins = cur.fetchall()
+            assert len(origins) == 1
 
 
 @pytest.mark.parametrize("with_module_config_key", [True, False])
@@ -252,10 +262,11 @@ def test_cli_swh_db_create_and_init_db_new_api(
 
     # the origin value in the scripts uses a hash function (which implementation wise
     # uses a function from the pgcrypt extension, installed during db creation step)
-    with BaseDb.connect(conninfo).cursor() as cur:
-        cur.execute("select * from origin")
-        origins = cur.fetchall()
-        assert len(origins) == 1
+    with BaseDb.connect(conninfo) as conn:
+        with conn.cursor() as cur:
+            cur.execute("select * from origin")
+            origins = cur.fetchall()
+            assert len(origins) == 1
 
 
 def test_cli_swh_db_init_report_sqlsh_error(
@@ -301,7 +312,8 @@ def test_cli_swh_db_upgrade_new_api(
 
     # This initializes the schema and data
     cfgfile = tmp_path / "config.yml"
-    cfgfile.write_text(yaml.dump({module_name: {"cls": "postgresql", "db": conninfo}}))
+    with open(cfgfile, "w") as f:
+        f.write(yaml.dump({module_name: {"cls": "postgresql", "db": conninfo}}))
     result = cli_runner.invoke(swhdb, ["init-admin", module_name, "--dbname", conninfo])
     assert_result(result)
     result = cli_runner.invoke(swhdb, ["-C", cfgfile, "init", module_name])
@@ -358,9 +370,11 @@ def test_cli_swh_db_upgrade_new_api(
     assert swh_db_version(conninfo) == 4
 
     cnx = BaseDb.connect(conninfo)
-    with cnx.transaction() as cur:
-        cur.execute("drop table dbmodule")
-    assert swh_db_module(conninfo) is None
+    with cnx:
+        with cnx.transaction() as cur:
+            assert_no_pending_transaction(cur)
+            cur.execute("drop table dbmodule")
+        assert swh_db_module(conninfo) is None
 
     # db migration should recreate the missing dbmodule table
     result = cli_runner.invoke(swhdb, ["-C", cfgfile, "upgrade", module_name])
@@ -415,10 +429,11 @@ def test_cli_swh_db_version(swh_db_cli, mock_get_entry_points, postgresql):
 
     actual_db_version = swh_db_version(conninfo)
 
-    with BaseDb.connect(conninfo).cursor() as cur:
-        cur.execute("select version from dbversion order by version desc limit 1")
-        expected_version = cur.fetchone()[0]
-        assert actual_db_version == expected_version
+    with BaseDb.connect(conninfo) as conn:
+        with conn.cursor() as cur:
+            cur.execute("select version from dbversion order by version desc limit 1")
+            expected_version = cur.fetchone()[0]
+            assert actual_db_version == expected_version
 
     assert_result(result)
     assert (
diff --git a/swh/core/db/tests/test_db.py b/swh/core/db/tests/test_db.py
index 44932e80..da1439a0 100644
--- a/swh/core/db/tests/test_db.py
+++ b/swh/core/db/tests/test_db.py
@@ -14,7 +14,9 @@ import uuid
 
 from hypothesis import given, settings, strategies
 from hypothesis.extra.pytz import timezones
-import psycopg2
+import psycopg
+from psycopg.types.json import Json
+from psycopg.types.range import TimestamptzRange
 import pytest
 from pytest_postgresql import factories
 
@@ -169,7 +171,7 @@ FIELDS = (
         "jsonb",
         {"str": "bar", "int": 1, "list": ["a", "b"], "nested": {"a": "b"}},
         pg_jsonb(min_size=0, max_size=5),
-        in_wrapper=psycopg2.extras.Json,
+        in_wrapper=Json,
     ),
     Field(
         "intenum",
@@ -196,7 +198,7 @@ FIELDS = (
     Field(
         "tstz_range",
         "tstzrange",
-        psycopg2.extras.DateTimeTZRange(
+        TimestamptzRange(
             lower=now(),
             upper=now() + datetime.timedelta(days=1),
             bounds="[)",
@@ -208,7 +210,7 @@ FIELDS = (
             strategies.sampled_from(["[]", "()", "[)", "(]"]),
         ).map(
             # and build the actual DateTimeTZRange object from these args
-            lambda args: psycopg2.extras.DateTimeTZRange(
+            lambda args: TimestamptzRange(
                 lower=args[0][0],
                 upper=args[0][1],
                 bounds=args[1],
@@ -247,7 +249,6 @@ def db_with_data(test_db, request):
     """Fixture to initialize a db with some data out of the "INIT_SQL above"""
     db = BaseDb.connect(test_db.info.dsn)
     with db.cursor() as cur:
-        psycopg2.extras.register_default_jsonb(cur)
         cur.execute(INIT_SQL)
     yield db
     db.conn.rollback()
@@ -257,7 +258,6 @@ def db_with_data(test_db, request):
 @pytest.mark.db
 def test_db_connect(db_with_data):
     with db_with_data.cursor() as cur:
-        psycopg2.extras.register_default_jsonb(cur)
         cur.execute(INSERT_SQL, STATIC_ROW_IN)
         cur.execute("select * from test_table;")
         output = convert_lines(cur)
@@ -267,7 +267,6 @@ def test_db_connect(db_with_data):
 
 def test_db_initialized(db_with_data):
     with db_with_data.cursor() as cur:
-        psycopg2.extras.register_default_jsonb(cur)
         cur.execute(INSERT_SQL, STATIC_ROW_IN)
         cur.execute("select * from test_table;")
         output = convert_lines(cur)
@@ -304,7 +303,7 @@ def test_db_copy_to_thread_exception(db_with_data):
     data = [(2**65, "foo", b"bar")]
 
     items = [dict(zip(COLUMNS, item)) for item in data]
-    with pytest.raises(psycopg2.errors.NumericValueOutOfRange):
+    with pytest.raises(psycopg.errors.NumericValueOutOfRange):
         db_with_data.copy_to(items, "test_table", COLUMNS)
 
 
diff --git a/swh/core/db/tests/test_db_utils.py b/swh/core/db/tests/test_db_utils.py
index 1f842d12..2348a1a8 100644
--- a/swh/core/db/tests/test_db_utils.py
+++ b/swh/core/db/tests/test_db_utils.py
@@ -6,7 +6,7 @@
 from datetime import timedelta
 from os import path
 
-from psycopg2.errors import InsufficientPrivilege
+from psycopg.errors import InsufficientPrivilege
 import pytest
 
 from swh.core.cli.db import db as swhdb
@@ -22,7 +22,7 @@ from swh.core.db.db_utils import get_database_info, get_sql_for_package, now
 from swh.core.db.db_utils import parse_dsn_or_dbname as parse_dsn
 from swh.core.tests.test_cli import assert_result
 
-from .test_cli import craft_conninfo
+from .test_cli import assert_no_pending_transaction, craft_conninfo
 
 
 def test_get_sql_for_package(mock_import_module):
@@ -146,38 +146,40 @@ def test_db_utils_swh_db_upgrade_sanity_checks(
     result = cli_runner.invoke(swhdb, ["init", module, "--dbname", conninfo])
     assert_result(result)
 
-    cnx = BaseDb.connect(conninfo)
-    with cnx.transaction() as cur:
-        cur.execute("drop table dbmodule")
-
-    # try to upgrade with a unset module
-    with pytest.raises(ValueError):
-        swh_db_upgrade(conninfo, module)
-
-    # check the dbmodule is unset
-    assert swh_db_module(conninfo) is None
-
-    # set the stored module to something else
-    swh_set_db_module(conninfo, f"{module}2")
-    assert swh_db_module(conninfo) == f"{module}2"
-
-    # try to upgrade with a different module
-    with pytest.raises(ValueError):
-        swh_db_upgrade(conninfo, module)
-
-    # revert to the proper module in the db
-    swh_set_db_module(conninfo, module, force=True)
-    assert swh_db_module(conninfo) == module
-    # trying again is a noop
-    swh_set_db_module(conninfo, module)
-    assert swh_db_module(conninfo) == module
-
-    # drop the dbversion table
-    with cnx.transaction() as cur:
-        cur.execute("drop table dbversion")
-    # an upgrade should fail due to missing stored version
-    with pytest.raises(ValueError):
-        swh_db_upgrade(conninfo, module)
+    with BaseDb.connect(conninfo) as cnx:
+        with cnx.transaction() as cur:
+            assert_no_pending_transaction(cur)
+            cur.execute("drop table dbmodule")
+
+        # try to upgrade with an unset module
+        with pytest.raises(ValueError):
+            swh_db_upgrade(conninfo, module)
+
+        # check the dbmodule is unset
+        assert swh_db_module(conninfo) is None
+
+        # set the stored module to something else
+        swh_set_db_module(conninfo, f"{module}2")
+        assert swh_db_module(conninfo) == f"{module}2"
+
+        # try to upgrade with a different module
+        with pytest.raises(ValueError):
+            swh_db_upgrade(conninfo, module)
+
+        # revert to the proper module in the db
+        swh_set_db_module(conninfo, module, force=True)
+        assert swh_db_module(conninfo) == module
+        # trying again is a noop
+        swh_set_db_module(conninfo, module)
+        assert swh_db_module(conninfo) == module
+
+        # drop the dbversion table
+        with cnx.transaction() as cur:
+            assert_no_pending_transaction(cur)
+            cur.execute("drop table dbversion")
+        # an upgrade should fail due to missing stored version
+        with pytest.raises(ValueError):
+            swh_db_upgrade(conninfo, module)
 
 
 @pytest.mark.parametrize("flavor", [None, "default", "flavorA", "flavorB"])
diff --git a/swh/core/github/utils.py b/swh/core/github/utils.py
index a2f00dd1..219524f7 100644
--- a/swh/core/github/utils.py
+++ b/swh/core/github/utils.py
@@ -62,7 +62,7 @@ def get_canonical_github_origin_url(
 
 
 class RateLimited(Exception):
-    def __init__(self, response):
+    def __init__(self, response: requests.Response) -> None:
         self.reset_time: Optional[int]
 
         # Figure out how long we need to sleep because of that rate limit
diff --git a/swh/core/pytest_plugin.py b/swh/core/pytest_plugin.py
index 0f431c1b..301849f6 100644
--- a/swh/core/pytest_plugin.py
+++ b/swh/core/pytest_plugin.py
@@ -440,33 +440,46 @@ def clean_scopes():
     scope._current_scope.set(None)
 
 
-@pytest.fixture()
-def mock_import_module(request, mocker, datadir):
-    mock = mocker.MagicMock
+# Some tests don't have "db" available, so we need to work around it.
+try:
+    import swh.core.db
 
-    def import_module_mocker(name, package=None):
-        if not name.startswith("swh.test"):
-            return import_module(name, package)
+    swh.core.db.__doc__
+except ImportError:
 
-        m = request.node.get_closest_marker("init_version")
-        if m:
-            version = m.kwargs.get("version", 1)
-        else:
-            version = 3
-        if name.startswith("swh."):
-            name = name[4:]
-        modpath = name.split(".")
+    @pytest.fixture()
+    def mock_import_module(request, mocker, datadir):
+        return None
 
-        def get_datastore(*args, **kw):
-            return mock(current_version=version)
+else:
 
-        return mock(
-            __name__=name.split(".")[-1],
-            __file__=str(Path(datadir, *modpath, "__init__.py")),
-            get_datastore=get_datastore,
-        )
+    @pytest.fixture()
+    def mock_import_module(request, mocker, datadir):
+        mock = mocker.MagicMock
+
+        def import_module_mocker(name, package=None):
+            if not name.startswith("swh.test"):
+                return import_module(name, package)
+
+            m = request.node.get_closest_marker("init_version")
+            if m:
+                version = m.kwargs.get("version", 1)
+            else:
+                version = 3
+            if name.startswith("swh."):
+                name = name[4:]
+            modpath = name.split(".")
+
+            def get_datastore(*args, **kw):
+                return mock(current_version=version)
+
+            return mock(
+                __name__=name.split(".")[-1],
+                __file__=str(Path(datadir, *modpath, "__init__.py")),
+                get_datastore=get_datastore,
+            )
 
-    return mocker.patch("swh.core.db.db_utils.import_module", import_module_mocker)
+        return mocker.patch("swh.core.db.db_utils.import_module", import_module_mocker)
 
 
 @pytest.fixture()
-- 
GitLab