From 255e0719127abec2ff24328a27508febd52ba5a9 Mon Sep 17 00:00:00 2001 From: Nicolas Dandrimont <nicolas@dandrimont.eu> Date: Wed, 23 Sep 2020 14:04:37 +0200 Subject: [PATCH] Support using a full DSN or a single database name in `swh db create` This matches the functionality advertised as examples in the documentation for the command. --- swh/core/cli/db.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py index 37a1d35c..57254584 100755 --- a/swh/core/cli/db.py +++ b/swh/core/cli/db.py @@ -6,7 +6,7 @@ import logging from os import environ, path -from typing import Collection, Optional, Tuple +from typing import Collection, Dict, Optional, Tuple import warnings import click @@ -194,6 +194,19 @@ def populate_database_for_package( return True, current_version, dbflavor +def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]: + """Parse a psycopg2 dsn, falling back to supporting plain database names as well""" + import psycopg2 + from psycopg2.extensions import parse_dsn as _parse_dsn + + try: + return _parse_dsn(dsn_or_dbname) + except psycopg2.ProgrammingError: + # psycopg2 failed to parse the DSN; it's probably a database name, + # handle it as such + return _parse_dsn(f"dbname={dsn_or_dbname}") + + def create_database_for_package( modname: str, conninfo: str, template: str = "template1" ): @@ -202,18 +215,18 @@ def create_database_for_package( Args: modname: Name of the module of which we're loading the files - conninfo: connection info string for the SQL database + conninfo: connection info string or plain database name for the SQL database template: the name of the database to connect to and use as template to create the new database """ import subprocess - from psycopg2.extensions import make_dsn, parse_dsn + from psycopg2.extensions import make_dsn - # Use the given conninfo but with dbname replaced by the template dbname + # Use the given conninfo string, but with dbname replaced by the template dbname # for the database creation step - creation_dsn = parse_dsn(conninfo) + creation_dsn = parse_dsn_or_dbname(conninfo) db_name = creation_dsn["dbname"] creation_dsn["dbname"] = template logger.debug("db_create db_name=%s (from %s)", db_name, template) -- GitLab