From 2ae14b9633f10dcee3ffeffb54d94ce2cb894cd0 Mon Sep 17 00:00:00 2001
From: Pierre-Yves David <pierre-yves.david@octobus.net>
Date: Fri, 21 Mar 2025 23:57:34 +0100
Subject: [PATCH] Migrating to psycopg3

---
 requirements-swh.txt                 |  4 +-
 requirements.txt                     |  2 +
 swh/provenance/backend/postgresql.py | 56 +++++++++++++++-------------
 3 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/requirements-swh.txt b/requirements-swh.txt
index 5f8b46f..a921156 100644
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -1,5 +1,5 @@
 # Add here internal Software Heritage dependencies, one per line.
-swh.core[db,http] >= 2
+swh.core[db,http] >= 4.0.0
 swh.model >= 2.6.1
-swh.storage
+swh.storage >= 3.0.0
 swh.graph >= 6.7.0
diff --git a/requirements.txt b/requirements.txt
index 54ce666..a10e2e0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,6 @@
 # Add here external Python modules dependencies, one per line. Module names
 # should match https://pypi.python.org/pypi names. For the full spec or
 # dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html
+psycopg
+psycopg_pool
 
diff --git a/swh/provenance/backend/postgresql.py b/swh/provenance/backend/postgresql.py
index dd16162..1cdfef5 100644
--- a/swh/provenance/backend/postgresql.py
+++ b/swh/provenance/backend/postgresql.py
@@ -5,10 +5,10 @@
 
 from contextlib import contextmanager
 import logging
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
-import psycopg2.extras
-import psycopg2.pool
+import psycopg
+import psycopg_pool
 
 from swh.core.db import BaseDb
 from swh.core.db.common import db_transaction
@@ -28,37 +28,41 @@ class Db(BaseDb):
 class PostgresqlProvenance:
     current_version: int = 1
 
-    def __init__(self, db, min_pool_conns=1, max_pool_conns=10):
-        try:
-            if isinstance(db, psycopg2.extensions.connection):
-                self._pool = None
-                self._db = Db(db)
+    def __init__(
+        self,
+        db: Union[str, psycopg.Connection[Any]],
+        min_pool_conns: int = 1,
+        max_pool_conns: int = 10,
+    ):
+        self._db: Optional[Db]
+        self._pool: Optional[psycopg_pool.ConnectionPool]
 
-                # See comment below
-                self._db.cursor().execute("SET TIME ZONE 'UTC'")
-            else:
-                self._pool = psycopg2.pool.ThreadedConnectionPool(
-                    min_pool_conns, max_pool_conns, db
+        try:
+            if isinstance(db, str):
+                self._pool = psycopg_pool.ConnectionPool(
+                    conninfo=db,
+                    min_size=min_pool_conns,
+                    max_size=max_pool_conns,
+                    open=False,
                 )
                 self._db = None
-        except psycopg2.OperationalError as e:
+                # Wait until the pool holds its min_size connections; open()
+                # raises PoolTimeout (an OperationalError) after the 1 s timeout
+                self._pool.open(wait=True, timeout=1)
+            else:
+                self._pool = None
+                self._db = Db(db)
+        except psycopg.OperationalError as e:
             raise ProvenanceDBError(e)
 
-    def get_db(self):
+    def get_db(self) -> Db:
         if self._db:
             return self._db
         else:
-            db = Db.from_pool(self._pool)
-
-            # Workaround for psycopg2 < 2.9.0 not handling fractional timezones,
-            # which may happen on old revision/release dates on systems configured
-            # with non-UTC timezones.
-            # https://www.psycopg.org/docs/usage.html#time-zones-handling
-            db.cursor().execute("SET TIME ZONE 'UTC'")
-
-            return db
+            assert self._pool is not None
+            return Db.from_pool(self._pool)
 
-    def put_db(self, db):
+    def put_db(self, db: Db):
         if db is not self._db:
             db.put_conn()
 
@@ -74,7 +78,7 @@ class PostgresqlProvenance:
 
     @db_transaction()
     def check_config(self, *, check_write: bool, db: Db, cur=None) -> bool:
-        dbversion = swh_db_version(db.conn.dsn)
+        dbversion = swh_db_version(db.conn)
         if dbversion != self.current_version:
             logger.warning(
                 "database dbversion (%s) != %s current_version (%s)",
-- 
GitLab