Skip to content
Snippets Groups Projects
Verified Commit 29701977 authored by Antoine R. Dumont's avatar Antoine R. Dumont
Browse files

db_utils: Make connect_to_conninfo use through contextmanager

This allows to reduce the boilerplate regarding initial connection failures.
parent 95709cff
No related branches found
No related tags found
1 merge request!269db_utils: Make connect_to_conninfo use through contextmanager
......@@ -3,6 +3,7 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from contextlib import contextmanager
from datetime import datetime, timezone
import functools
from importlib import import_module
......@@ -11,7 +12,7 @@ from os import path
import pathlib
import re
import subprocess
from typing import Collection, Dict, List, Optional, Tuple, Union, cast
from typing import Collection, Dict, Iterator, List, Optional, Tuple, Union, cast
import psycopg2
import psycopg2.errors
......@@ -58,28 +59,32 @@ def jsonize(value):
return value
def connect_to_conninfo(db_or_conninfo: Union[str, pgconnection]) -> pgconnection:
"""Connect to the database passed in argument
@contextmanager
def connect_to_conninfo(
db_or_conninfo: Union[str, pgconnection]
) -> Iterator[pgconnection]:
"""Connect to the database passed as argument.
Args:
db_or_conninfo: A database connection, or a database connection info string
Returns:
a connected database handle
a connected database handle or None if the database is not initialized
Raises:
psycopg2.Error if the database doesn't exist
"""
if isinstance(db_or_conninfo, pgconnection):
return db_or_conninfo
if "=" not in db_or_conninfo and "//" not in db_or_conninfo:
# Database name
db_or_conninfo = f"dbname={db_or_conninfo}"
db = psycopg2.connect(db_or_conninfo)
yield db_or_conninfo
else:
if "=" not in db_or_conninfo and "//" not in db_or_conninfo:
# Database name
db_or_conninfo = f"dbname={db_or_conninfo}"
return db
try:
db = psycopg2.connect(db_or_conninfo)
except psycopg2.Error:
logger.exception("Failed to connect to `%s`", db_or_conninfo)
else:
yield db
def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]:
......@@ -94,22 +99,18 @@ def swh_db_version(db_or_conninfo: Union[str, pgconnection]) -> Optional[int]:
Either the version of the database, or None if it couldn't be detected
"""
try:
db = connect_to_conninfo(db_or_conninfo)
except psycopg2.Error:
logger.exception("Failed to connect to `%s`", db_or_conninfo)
# Database not initialized
return None
try:
with db.cursor() as c:
query = "select version from dbversion order by dbversion desc limit 1"
try:
c.execute(query)
result = c.fetchone()
if result:
return result[0]
except psycopg2.errors.UndefinedTable:
with connect_to_conninfo(db_or_conninfo) as db:
if not db:
return None
with db.cursor() as c:
query = "select version from dbversion order by dbversion desc limit 1"
try:
c.execute(query)
result = c.fetchone()
if result:
return result[0]
except psycopg2.errors.UndefinedTable:
return None
except Exception:
logger.exception("Could not get version from `%s`", db_or_conninfo)
return None
......@@ -129,23 +130,19 @@ def swh_db_versions(
Either the version of the database, or None if it couldn't be detected
"""
try:
db = connect_to_conninfo(db_or_conninfo)
except psycopg2.Error:
logger.exception("Failed to connect to `%s`", db_or_conninfo)
# Database not initialized
return None
try:
with db.cursor() as c:
query = (
"select version, release, description "
"from dbversion order by dbversion desc"
)
try:
c.execute(query)
return cast(List[Tuple[int, datetime, str]], c.fetchall())
except psycopg2.errors.UndefinedTable:
with connect_to_conninfo(db_or_conninfo) as db:
if not db:
return None
with db.cursor() as c:
query = (
"select version, release, description "
"from dbversion order by dbversion desc"
)
try:
c.execute(query)
return cast(List[Tuple[int, datetime, str]], c.fetchall())
except psycopg2.errors.UndefinedTable:
return None
except Exception:
logger.exception("Could not get versions from `%s`", db_or_conninfo)
return None
......@@ -238,22 +235,18 @@ def swh_db_module(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
Either the module of the database, or None if it couldn't be detected
"""
try:
db = connect_to_conninfo(db_or_conninfo)
except psycopg2.Error:
logger.exception("Failed to connect to `%s`", db_or_conninfo)
# Database not initialized
return None
try:
with db.cursor() as c:
query = "select dbmodule from dbmodule limit 1"
try:
c.execute(query)
resp = c.fetchone()
if resp:
return resp[0]
except psycopg2.errors.UndefinedTable:
with connect_to_conninfo(db_or_conninfo) as db:
if not db:
return None
with db.cursor() as c:
query = "select dbmodule from dbmodule limit 1"
try:
c.execute(query)
resp = c.fetchone()
if resp:
return resp[0]
except psycopg2.errors.UndefinedTable:
return None
except Exception:
logger.exception("Could not get module from `%s`", db_or_conninfo)
return None
......@@ -289,27 +282,25 @@ def swh_set_db_module(
)
# force is True
update = True
try:
db = connect_to_conninfo(db_or_conninfo)
except psycopg2.Error:
logger.exception("Failed to connect to `%s`", db_or_conninfo)
# Database not initialized
return None
sqlfiles = [
fname
for fname in get_sql_for_package("swh.core.db")
if "dbmodule" in fname.stem
]
execute_sqlfiles(sqlfiles, db_or_conninfo)
with connect_to_conninfo(db_or_conninfo) as db:
if not db:
return None
with db.cursor() as c:
if update:
query = "update dbmodule set dbmodule = %s"
else:
query = "insert into dbmodule(dbmodule) values (%s)"
c.execute(query, (module,))
db.commit()
sqlfiles = [
fname
for fname in get_sql_for_package("swh.core.db")
if "dbmodule" in fname.stem
]
execute_sqlfiles(sqlfiles, db_or_conninfo)
with db.cursor() as c:
if update:
query = "update dbmodule set dbmodule = %s"
else:
query = "insert into dbmodule(dbmodule) values (%s)"
c.execute(query, (module,))
db.commit()
def swh_set_db_version(
......@@ -326,20 +317,19 @@ def swh_set_db_version(
db_or_conninfo: A database connection, or a database connection info string
version: the version to add
"""
try:
db = connect_to_conninfo(db_or_conninfo)
except psycopg2.Error:
logger.exception("Failed to connect to `%s`", db_or_conninfo)
# Database not initialized
return None
if ts is None:
ts = now()
with db.cursor() as c:
query = (
"insert into dbversion(version, release, description) values (%s, %s, %s)"
)
c.execute(query, (version, ts, desc))
db.commit()
with connect_to_conninfo(db_or_conninfo) as db:
if not db:
return None
with db.cursor() as c:
query = (
"insert into dbversion(version, release, description) "
"values (%s, %s, %s)"
)
c.execute(query, (version, ts, desc))
db.commit()
def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
......@@ -355,23 +345,19 @@ def swh_db_flavor(db_or_conninfo: Union[str, pgconnection]) -> Optional[str]:
The flavor of the database, or None if it could not be detected.
"""
try:
db = connect_to_conninfo(db_or_conninfo)
except psycopg2.Error:
logger.exception("Failed to connect to `%s`", db_or_conninfo)
# Database not initialized
return None
try:
with db.cursor() as c:
query = "select swh_get_dbflavor()"
try:
c.execute(query)
result = c.fetchone()
assert result is not None # to keep mypy happy
return result[0]
except psycopg2.errors.UndefinedFunction:
# function not found: no flavor
with connect_to_conninfo(db_or_conninfo) as db:
if not db:
return None
with db.cursor() as c:
query = "select swh_get_dbflavor()"
try:
c.execute(query)
result = c.fetchone()
assert result is not None # to keep mypy happy
return result[0]
except psycopg2.errors.UndefinedFunction:
# function not found: no flavor
return None
except Exception:
logger.exception("Could not get flavor from `%s`", db_or_conninfo)
return None
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment