From 8e4026a5e8c9d43832cd3b038b9175f3f4914407 Mon Sep 17 00:00:00 2001
From: "Antoine R. Dumont (@ardumont)" <ardumont@softwareheritage.org>
Date: Fri, 16 Oct 2020 19:34:02 +0200
Subject: [PATCH] swh.vault: Unify get_vault factory function with other
 factories

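get_vault now takes the backend constructor arguments as keyword
arguments and resolves the class through a BACKEND_TYPES mapping.
Passing them through an explicit "args" key still works but emits a
DeprecationWarning.

VaultBackend now instantiates its cache, scheduler and storage itself
from their configuration entries.

Related to T1410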
---
 requirements-test.txt        |  1 +
 swh/vault/__init__.py        | 53 +++++++++++++++++++++-------------
 swh/vault/api/server.py      |  2 +-
 swh/vault/backend.py         | 13 +++++----
 swh/vault/tests/conftest.py  | 46 +++++++++++++++---------------
 swh/vault/tests/test_init.py | 55 ++++++++++++++++++++++++++++++++++++
 6 files changed, 121 insertions(+), 49 deletions(-)
 create mode 100644 swh/vault/tests/test_init.py

diff --git a/requirements-test.txt b/requirements-test.txt
index 078a4e3..66b4544 100644
--- a/requirements-test.txt
+++ b/requirements-test.txt
@@ -5,3 +5,4 @@ dulwich >= 0.18.7
 swh.loader.core
 swh.loader.git >= 0.0.52
 swh.storage[testing]
+pytest-mock
diff --git a/swh/vault/__init__.py b/swh/vault/__init__.py
index a39a171..db16ff9 100644
--- a/swh/vault/__init__.py
+++ b/swh/vault/__init__.py
@@ -1,21 +1,32 @@
-# Copyright (C) 2018  The Software Heritage developers
+# Copyright (C) 2018-2020  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU Affero General Public License version 3, or any later version
 # See top-level LICENSE file for more information
+
+from __future__ import annotations
+
+import importlib
 import logging
+from typing import Dict
+import warnings
 
 logger = logging.getLogger(__name__)
 
 
-def get_vault(cls="remote", args={}):
+BACKEND_TYPES: Dict[str, str] = {
+    "remote": ".api.client.RemoteVaultClient",
+    "local": ".backend.VaultBackend",
+}
+
+
+def get_vault(cls: str = "remote", **kwargs):
     """
     Get a vault object of class `vault_class` with arguments
     `vault_args`.
 
     Args:
-        vault (dict): dictionary with keys:
-        - cls (str): vault's class, either 'remote'
-        - args (dict): dictionary with keys
+        cls: vault's class, either 'remote' or 'local'
+        kwargs: arguments to pass to the class' constructor
 
     Returns:
         an instance of VaultBackend (either local or remote)
@@ -24,18 +35,20 @@ def get_vault(cls="remote", args={}):
         ValueError if passed an unknown storage class.
 
     """
-    if cls == "remote":
-        from .api.client import RemoteVaultClient as Vault
-    elif cls == "local":
-        from swh.scheduler import get_scheduler
-        from swh.storage import get_storage
-        from swh.vault.backend import VaultBackend as Vault
-        from swh.vault.cache import VaultCache
-
-        args["cache"] = VaultCache(**args["cache"])
-        args["storage"] = get_storage(**args["storage"])
-        args["scheduler"] = get_scheduler(**args["scheduler"])
-    else:
-        raise ValueError("Unknown storage class `%s`" % cls)
-    logger.debug("Instantiating %s with %s" % (Vault, args))
-    return Vault(**args)
+    if "args" in kwargs:
+        warnings.warn(
+            'Explicit "args" key is deprecated, use keys directly instead.',
+            DeprecationWarning,
+        )
+        kwargs = kwargs["args"]
+
+    class_path = BACKEND_TYPES.get(cls)
+    if class_path is None:
+        raise ValueError(
+            f"Unknown Vault class `{cls}`. " f"Supported: {', '.join(BACKEND_TYPES)}"
+        )
+
+    (module_path, class_name) = class_path.rsplit(".", 1)
+    module = importlib.import_module(module_path, package=__package__)
+    Vault = getattr(module, class_name)
+    return Vault(**kwargs)
diff --git a/swh/vault/api/server.py b/swh/vault/api/server.py
index 6c178e0..6440fc2 100644
--- a/swh/vault/api/server.py
+++ b/swh/vault/api/server.py
@@ -214,7 +214,7 @@ def get_local_backend(cfg):
         if not args.get(key):
             raise ValueError("invalid configuration; missing %s config entry." % key)
 
-    return get_vault("local", args)
+    return get_vault("local", **args)
 
 
 def make_app_from_configfile(config_file=None, **kwargs):
diff --git a/swh/vault/backend.py b/swh/vault/backend.py
index 1974e9e..69d4690 100644
--- a/swh/vault/backend.py
+++ b/swh/vault/backend.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2017-2018  The Software Heritage developers
+# Copyright (C) 2017-2020  The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -12,7 +12,10 @@ import psycopg2.pool
 from swh.core.db import BaseDb
 from swh.core.db.common import db_transaction
 from swh.model import hashutil
+from swh.scheduler import get_scheduler
 from swh.scheduler.utils import create_oneshot_task_dict
+from swh.storage import get_storage
+from swh.vault.cache import VaultCache
 from swh.vault.cookers import get_cooker_cls
 from swh.vault.exc import NotFoundExc
 
@@ -67,11 +70,11 @@ class VaultBackend:
     Backend for the Software Heritage vault.
     """
 
-    def __init__(self, db, cache, scheduler, storage=None, **config):
+    def __init__(self, db, **config):
         self.config = config
-        self.cache = cache
-        self.scheduler = scheduler
-        self.storage = storage
+        self.cache = VaultCache(**config["cache"])
+        self.scheduler = get_scheduler(**config["scheduler"])
+        self.storage = get_storage(**config["storage"])
         self.smtp_server = smtplib.SMTP()
 
         self._pool = psycopg2.pool.ThreadedConnectionPool(
diff --git a/swh/vault/tests/conftest.py b/swh/vault/tests/conftest.py
index 9090e46..163ca80 100644
--- a/swh/vault/tests/conftest.py
+++ b/swh/vault/tests/conftest.py
@@ -1,6 +1,7 @@
 import glob
 import os
 import subprocess
+from typing import Any, Dict
 
 import pkg_resources.extern.packaging.version
 import pytest
@@ -38,8 +39,28 @@ postgresql2 = factories.postgresql("postgresql_proc", "tests2")
 
 
 @pytest.fixture
-def swh_vault(request, postgresql_proc, postgresql, postgresql2, tmp_path):
+def swh_vault_config(postgresql, postgresql2, tmp_path) -> Dict[str, Any]:
+    tmp_path = str(tmp_path)
+    return {
+        "db": postgresql.dsn,
+        "storage": {
+            "cls": "local",
+            "db": postgresql2.dsn,
+            "objstorage": {
+                "cls": "pathslicing",
+                "args": {"root": tmp_path, "slicing": "0:1/1:5",},
+            },
+        },
+        "cache": {
+            "cls": "pathslicing",
+            "args": {"root": tmp_path, "slicing": "0:1/1:5", "allow_delete": True,},
+        },
+        "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",},
+    }
+
 
+@pytest.fixture
+def swh_vault(request, swh_vault_config, postgresql, postgresql2, tmp_path):
     for sql_dir, pg in ((SQL_DIR, postgresql), (STORAGE_SQL_DIR, postgresql2)):
         dump_files = os.path.join(sql_dir, "*.sql")
         all_dump_files = sorted(glob.glob(dump_files), key=sortkey)
@@ -59,28 +80,7 @@ def swh_vault(request, postgresql_proc, postgresql, postgresql2, tmp_path):
                 ]
             )
 
-    vault_config = {
-        "db": db_url("tests", postgresql_proc),
-        "storage": {
-            "cls": "local",
-            "db": db_url("tests2", postgresql_proc),
-            "objstorage": {
-                "cls": "pathslicing",
-                "args": {"root": str(tmp_path), "slicing": "0:1/1:5",},
-            },
-        },
-        "cache": {
-            "cls": "pathslicing",
-            "args": {
-                "root": str(tmp_path),
-                "slicing": "0:1/1:5",
-                "allow_delete": True,
-            },
-        },
-        "scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",},
-    }
-
-    return get_vault("local", vault_config)
+    return get_vault("local", **swh_vault_config)
 
 
 @pytest.fixture
diff --git a/swh/vault/tests/test_init.py b/swh/vault/tests/test_init.py
new file mode 100644
index 0000000..7f402d6
--- /dev/null
+++ b/swh/vault/tests/test_init.py
@@ -0,0 +1,55 @@
+# Copyright (C) 2020 The Software Heritage developers
+# See the AUTHORS file at the top-level directory of this distribution
+# License: GNU General Public License version 3, or any later version
+# See top-level LICENSE file for more information
+
+import pytest
+
+from swh.vault import get_vault
+from swh.vault.api.client import RemoteVaultClient
+from swh.vault.backend import VaultBackend
+
+SERVER_IMPLEMENTATIONS = [
+    ("remote", RemoteVaultClient, {"url": "localhost"}),
+    (
+        "local",
+        VaultBackend,
+        {
+            "db": "something",
+            "cache": {"cls": "memory", "args": {}},
+            "storage": {"cls": "remote", "url": "mock://storage-url"},
+            "scheduler": {"cls": "remote", "url": "mock://scheduler-url"},
+        },
+    ),
+]
+
+
+@pytest.fixture
+def mock_psycopg2(mocker):
+    mocker.patch("swh.vault.backend.psycopg2.pool")
+    mocker.patch("swh.vault.backend.psycopg2.extras")
+
+
+def test_init_get_vault_failure():
+    with pytest.raises(ValueError, match="Unknown Vault class"):
+        get_vault("unknown-vault-storage")
+
+
+@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
+def test_init_get_vault(class_name, expected_class, kwargs, mock_psycopg2):
+    concrete_vault = get_vault(class_name, **kwargs)
+    assert isinstance(concrete_vault, expected_class)
+
+
+@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
+def test_init_get_vault_deprecation_warning(
+    class_name, expected_class, kwargs, mock_psycopg2
+):
+    with pytest.warns(DeprecationWarning):
+        concrete_vault = get_vault(class_name, args=kwargs)
+    assert isinstance(concrete_vault, expected_class)
+
+
+def test_init_get_vault_ok(swh_vault_config):
+    concrete_vault = get_vault("local", **swh_vault_config)
+    assert isinstance(concrete_vault, VaultBackend)
-- 
GitLab