From 7f12aa271c035687d44426702bc815e247f1acb5 Mon Sep 17 00:00:00 2001
From: Pierre-Yves David <pierre-yves.david@octobus.net>
Date: Mon, 23 Dec 2024 16:35:31 +0100
Subject: [PATCH] Migration to psycopg3

---
 requirements-swh.txt                         |  4 +--
 swh/indexer/storage/__init__.py              | 22 +++++++++--------
 swh/indexer/storage/api/server.py            |  3 ---
 swh/indexer/storage/db.py                    | 26 +++++++++++++++-----
 swh/indexer/tests/storage/test_api_client.py |  6 ++---
 5 files changed, 37 insertions(+), 24 deletions(-)

diff --git a/requirements-swh.txt b/requirements-swh.txt
index 69a70822..c3e4d6e2 100644
--- a/requirements-swh.txt
+++ b/requirements-swh.txt
@@ -1,5 +1,5 @@
-swh.core[db,http] >= 3.6.1
+swh.core[db,http] >= 4.0.0
 swh.model >= 6.13.0
 swh.objstorage >= 2.3.1
-swh.storage >= 2.0.0
+swh.storage >= 3.0.0
 swh.journal >= 0.1.0
diff --git a/swh/indexer/storage/__init__.py b/swh/indexer/storage/__init__.py
index 33d59b57..8e4b6268 100644
--- a/swh/indexer/storage/__init__.py
+++ b/swh/indexer/storage/__init__.py
@@ -9,8 +9,8 @@ from typing import Dict, Iterable, List, Optional, Tuple, Union
 import warnings
 
 import attr
-import psycopg2
-import psycopg2.pool
+import psycopg
+import psycopg_pool
 
 from swh.core.db.common import db_transaction
 from swh.indexer.storage.interface import IndexerStorageInterface
@@ -129,22 +129,24 @@ class IndexerStorage:
     def __init__(self, db, min_pool_conns=1, max_pool_conns=10, journal_writer=None):
         """
         Args:
-            db: either a libpq connection string, or a psycopg2 connection
+            db: either a libpq connection string, or a psycopg connection
             journal_writer: configuration passed to
                             `swh.journal.writer.get_journal_writer`
 
         """
         self.journal_writer = JournalWriter(journal_writer)
         try:
-            if isinstance(db, psycopg2.extensions.connection):
-                self._pool = None
-                self._db = Db(db)
-            else:
-                self._pool = psycopg2.pool.ThreadedConnectionPool(
-                    min_pool_conns, max_pool_conns, db
+            if isinstance(db, str):  # conninfo string: use a connection pool
+                self._pool = psycopg_pool.ConnectionPool(
+                    conninfo=db,
+                    min_size=min_pool_conns,
+                    max_size=max_pool_conns,
                 )
                 self._db = None
-        except psycopg2.OperationalError as e:
+            else:  # any non-str value is wrapped directly in Db, without a pool
+                self._pool = None
+                self._db = Db(db)
+        except psycopg.OperationalError as e:  # TODO confirm pool creation can raise
             raise StorageDBError(e)
 
     def get_db(self):
diff --git a/swh/indexer/storage/api/server.py b/swh/indexer/storage/api/server.py
index 1f4dbcd8..dab5075e 100644
--- a/swh/indexer/storage/api/server.py
+++ b/swh/indexer/storage/api/server.py
@@ -43,9 +43,6 @@ def my_error_handler(exception):
     return error_handler(exception, encode_data)
 
 
-app.setup_psycopg2_errorhandlers()
-
-
 @app.errorhandler(IndexerStorageArgumentException)
 def argument_error_handler(exception):
     return error_handler(exception, encode_data, status_code=400)
diff --git a/swh/indexer/storage/db.py b/swh/indexer/storage/db.py
index c9885a2c..5d706a20 100644
--- a/swh/indexer/storage/db.py
+++ b/swh/indexer/storage/db.py
@@ -3,14 +3,27 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-from typing import Dict, Iterable, Iterator, List
+from typing import Any, Dict, Iterable, Iterator, List
+
+from psycopg import Cursor
 
 from swh.core.db import BaseDb
-from swh.core.db.db_utils import execute_values_generator, stored_procedure
+from swh.core.db.db_utils import stored_procedure
 
 from .interface import Sha1
 
 
+def execute_values_generator(  # run `query` once per item of `values`
+    cur: Cursor, query: str, values: Iterable[Any]
+) -> Iterator[Any]:
+    cur.executemany(query, values, returning=True)  # keep every result set
+    if cur.pgresult is None:  # nothing to fetch (presumably empty `values`)
+        return
+    yield from cur.fetchall()  # rows of the first result set
+    while cur.nextset():  # step through the remaining result sets
+        yield from cur.fetchall()
+
+
 class Db(BaseDb):
     """Proxy to the SWH Indexer DB, with wrappers around stored procedures"""
 
@@ -32,17 +45,18 @@ class Db(BaseDb):
         """
         cur = self._cursor(cur)
         keys = ", ".join(hash_keys)
+        values_place_holder = ", ".join(["%s"] * len(hash_keys))
         equality = " AND ".join(("t.%s = c.%s" % (key, key)) for key in hash_keys)
         yield from execute_values_generator(
             cur,
             """
-            select %s from (values %%s) as t(%s)
+            select %s from (values (%s)) as t(%s)
             where not exists (
                 select 1 from %s c
                 where %s
             )
             """
-            % (keys, keys, table, equality),
+            % (keys, values_place_holder, keys, table, equality),
             (tuple(m[k] for k in hash_keys) for m in data),
         )
 
@@ -114,7 +128,7 @@ class Db(BaseDb):
         keys = map(self._convert_key, cols)
         query = """
             select {keys}
-            from (values %s) as t(id)
+            from (values (%s)) as t(id)
             inner join {table} c
                 on c.{id_col}=t.id
             inner join indexer_configuration i
@@ -206,7 +220,7 @@ class Db(BaseDb):
             cur,
             """
             select %s
-            from (values %%s) as t(id)
+            from (values (%%s)) as t(id)
             inner join content_fossology_license c on t.id=c.id
             inner join indexer_configuration i
                 on i.id=c.indexer_configuration_id
diff --git a/swh/indexer/tests/storage/test_api_client.py b/swh/indexer/tests/storage/test_api_client.py
index 251e5988..ef5f515f 100644
--- a/swh/indexer/tests/storage/test_api_client.py
+++ b/swh/indexer/tests/storage/test_api_client.py
@@ -3,7 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
-import psycopg2
+import psycopg
 import pytest
 
 from swh.core.api import RemoteException, TransientRemoteException
@@ -79,7 +79,7 @@ def test_operationalerror_exception(app_server, swh_indexer_storage, mocker):
     mocker.patch.object(
         app_server.storage,
         "content_mimetype_get",
-        side_effect=psycopg2.errors.AdminShutdown("cluster is shutting down"),
+        side_effect=psycopg.errors.AdminShutdown("cluster is shutting down"),
     )
     with pytest.raises(RemoteException) as excinfo:
         swh_indexer_storage.content_mimetype_get([b"\x01" * 20])
@@ -94,7 +94,7 @@ def test_querycancelled_exception(app_server, swh_indexer_storage, mocker):
     mocker.patch.object(
         app_server.storage,
         "content_mimetype_get",
-        side_effect=psycopg2.errors.QueryCanceled("too big!"),
+        side_effect=psycopg.errors.QueryCanceled("too big!"),
     )
     with pytest.raises(RemoteException) as excinfo:
         swh_indexer_storage.content_mimetype_get([b"\x01" * 20])
-- 
GitLab