From 255e0719127abec2ff24328a27508febd52ba5a9 Mon Sep 17 00:00:00 2001
From: Nicolas Dandrimont <nicolas@dandrimont.eu>
Date: Wed, 23 Sep 2020 14:04:37 +0200
Subject: [PATCH] Support using a full DSN or a single database name in `swh db
 create`

This matches the functionality advertised as examples in the documentation for
the command.
---
 swh/core/cli/db.py | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/swh/core/cli/db.py b/swh/core/cli/db.py
index 37a1d35c..57254584 100755
--- a/swh/core/cli/db.py
+++ b/swh/core/cli/db.py
@@ -6,7 +6,7 @@
 
 import logging
 from os import environ, path
-from typing import Collection, Optional, Tuple
+from typing import Collection, Dict, Optional, Tuple
 import warnings
 
 import click
@@ -194,6 +194,19 @@ def populate_database_for_package(
     return True, current_version, dbflavor
 
 
+def parse_dsn_or_dbname(dsn_or_dbname: str) -> Dict[str, str]:
+    """Parse a psycopg2 dsn, falling back to supporting plain database names as well"""
+    import psycopg2
+    from psycopg2.extensions import parse_dsn as _parse_dsn
+
+    try:
+        return _parse_dsn(dsn_or_dbname)
+    except psycopg2.ProgrammingError:
+        # psycopg2 failed to parse the DSN; it's probably a database name,
+        # handle it as such
+        return _parse_dsn(f"dbname={dsn_or_dbname}")
+
+
 def create_database_for_package(
     modname: str, conninfo: str, template: str = "template1"
 ):
@@ -202,18 +215,18 @@ def create_database_for_package(
 
     Args:
       modname: Name of the module of which we're loading the files
-      conninfo: connection info string for the SQL database
+      conninfo: connection info string or plain database name for the SQL database
       template: the name of the database to connect to and use as template to create
                 the new database
 
     """
     import subprocess
 
-    from psycopg2.extensions import make_dsn, parse_dsn
+    from psycopg2.extensions import make_dsn
 
-    # Use the given conninfo but with dbname replaced by the template dbname
+    # Use the given conninfo string, but with dbname replaced by the template dbname
     # for the database creation step
-    creation_dsn = parse_dsn(conninfo)
+    creation_dsn = parse_dsn_or_dbname(conninfo)
     db_name = creation_dsn["dbname"]
     creation_dsn["dbname"] = template
     logger.debug("db_create db_name=%s (from %s)", db_name, template)
-- 
GitLab