Skip to content
Snippets Groups Projects
Verified Commit 8e4026a5 authored by Antoine R. Dumont's avatar Antoine R. Dumont
Browse files

swh.vault: Unify get_vault factory function with other factories

Related to T1410
parent b5cecff6
No related branches found
No related tags found
1 merge request!65swh.vault: Unify get_vault factory function with other factories
......@@ -5,3 +5,4 @@ dulwich >= 0.18.7
swh.loader.core
swh.loader.git >= 0.0.52
swh.storage[testing]
pytest-mock
# Copyright (C) 2018 The Software Heritage developers
# Copyright (C) 2018-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU Affero General Public License version 3, or any later version
# See top-level LICENSE file for more information
from __future__ import annotations
import importlib
import logging
from typing import Dict
import warnings
logger = logging.getLogger(__name__)
def get_vault(cls="remote", args={}):
BACKEND_TYPES: Dict[str, str] = {
"remote": ".api.client.RemoteVaultClient",
"local": ".backend.VaultBackend",
}
def get_vault(cls: str = "remote", **kwargs):
"""
Get a vault object of class `vault_class` with arguments
`vault_args`.
Args:
vault (dict): dictionary with keys:
- cls (str): vault's class, either 'remote'
- args (dict): dictionary with keys
cls: vault's class, either 'remote' or 'local'
kwargs: arguments to pass to the class' constructor
Returns:
an instance of VaultBackend (either local or remote)
......@@ -24,18 +35,20 @@ def get_vault(cls="remote", args={}):
ValueError if passed an unknown storage class.
"""
if cls == "remote":
from .api.client import RemoteVaultClient as Vault
elif cls == "local":
from swh.scheduler import get_scheduler
from swh.storage import get_storage
from swh.vault.backend import VaultBackend as Vault
from swh.vault.cache import VaultCache
args["cache"] = VaultCache(**args["cache"])
args["storage"] = get_storage(**args["storage"])
args["scheduler"] = get_scheduler(**args["scheduler"])
else:
raise ValueError("Unknown storage class `%s`" % cls)
logger.debug("Instantiating %s with %s" % (Vault, args))
return Vault(**args)
if "args" in kwargs:
warnings.warn(
'Explicit "args" key is deprecated, use keys directly instead.',
DeprecationWarning,
)
kwargs = kwargs["args"]
class_path = BACKEND_TYPES.get(cls)
if class_path is None:
raise ValueError(
f"Unknown Vault class `{cls}`. " f"Supported: {', '.join(BACKEND_TYPES)}"
)
(module_path, class_name) = class_path.rsplit(".", 1)
module = importlib.import_module(module_path, package=__package__)
Vault = getattr(module, class_name)
return Vault(**kwargs)
......@@ -214,7 +214,7 @@ def get_local_backend(cfg):
if not args.get(key):
raise ValueError("invalid configuration; missing %s config entry." % key)
return get_vault("local", args)
return get_vault("local", **args)
def make_app_from_configfile(config_file=None, **kwargs):
......
# Copyright (C) 2017-2018 The Software Heritage developers
# Copyright (C) 2017-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
......@@ -12,7 +12,10 @@ import psycopg2.pool
from swh.core.db import BaseDb
from swh.core.db.common import db_transaction
from swh.model import hashutil
from swh.scheduler import get_scheduler
from swh.scheduler.utils import create_oneshot_task_dict
from swh.storage import get_storage
from swh.vault.cache import VaultCache
from swh.vault.cookers import get_cooker_cls
from swh.vault.exc import NotFoundExc
......@@ -67,11 +70,11 @@ class VaultBackend:
Backend for the Software Heritage vault.
"""
def __init__(self, db, cache, scheduler, storage=None, **config):
def __init__(self, db, **config):
self.config = config
self.cache = cache
self.scheduler = scheduler
self.storage = storage
self.cache = VaultCache(**config["cache"])
self.scheduler = get_scheduler(**config["scheduler"])
self.storage = get_storage(**config["storage"])
self.smtp_server = smtplib.SMTP()
self._pool = psycopg2.pool.ThreadedConnectionPool(
......
import glob
import os
import subprocess
from typing import Any, Dict
import pkg_resources.extern.packaging.version
import pytest
......@@ -38,8 +39,28 @@ postgresql2 = factories.postgresql("postgresql_proc", "tests2")
@pytest.fixture
def swh_vault(request, postgresql_proc, postgresql, postgresql2, tmp_path):
def swh_vault_config(postgresql, postgresql2, tmp_path) -> Dict[str, Any]:
tmp_path = str(tmp_path)
return {
"db": postgresql.dsn,
"storage": {
"cls": "local",
"db": postgresql2.dsn,
"objstorage": {
"cls": "pathslicing",
"args": {"root": tmp_path, "slicing": "0:1/1:5",},
},
},
"cache": {
"cls": "pathslicing",
"args": {"root": tmp_path, "slicing": "0:1/1:5", "allow_delete": True,},
},
"scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",},
}
@pytest.fixture
def swh_vault(request, swh_vault_config, postgresql, postgresql2, tmp_path):
for sql_dir, pg in ((SQL_DIR, postgresql), (STORAGE_SQL_DIR, postgresql2)):
dump_files = os.path.join(sql_dir, "*.sql")
all_dump_files = sorted(glob.glob(dump_files), key=sortkey)
......@@ -59,28 +80,7 @@ def swh_vault(request, postgresql_proc, postgresql, postgresql2, tmp_path):
]
)
vault_config = {
"db": db_url("tests", postgresql_proc),
"storage": {
"cls": "local",
"db": db_url("tests2", postgresql_proc),
"objstorage": {
"cls": "pathslicing",
"args": {"root": str(tmp_path), "slicing": "0:1/1:5",},
},
},
"cache": {
"cls": "pathslicing",
"args": {
"root": str(tmp_path),
"slicing": "0:1/1:5",
"allow_delete": True,
},
},
"scheduler": {"cls": "remote", "url": "http://swh-scheduler:5008",},
}
return get_vault("local", vault_config)
return get_vault("local", **swh_vault_config)
@pytest.fixture
......
# Copyright (C) 2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import pytest
from swh.vault import get_vault
from swh.vault.api.client import RemoteVaultClient
from swh.vault.backend import VaultBackend
SERVER_IMPLEMENTATIONS = [
("remote", RemoteVaultClient, {"url": "localhost"}),
(
"local",
VaultBackend,
{
"db": "something",
"cache": {"cls": "memory", "args": {}},
"storage": {"cls": "remote", "url": "mock://storage-url"},
"scheduler": {"cls": "remote", "url": "mock://scheduler-url"},
},
),
]
@pytest.fixture
def mock_psycopg2(mocker):
mocker.patch("swh.vault.backend.psycopg2.pool")
mocker.patch("swh.vault.backend.psycopg2.extras")
def test_init_get_vault_failure():
with pytest.raises(ValueError, match="Unknown Vault class"):
get_vault("unknown-vault-storage")
@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
def test_init_get_vault(class_name, expected_class, kwargs, mock_psycopg2):
concrete_vault = get_vault(class_name, **kwargs)
assert isinstance(concrete_vault, expected_class)
@pytest.mark.parametrize("class_name,expected_class,kwargs", SERVER_IMPLEMENTATIONS)
def test_init_get_vault_deprecation_warning(
class_name, expected_class, kwargs, mock_psycopg2
):
with pytest.warns(DeprecationWarning):
concrete_vault = get_vault(class_name, args=kwargs)
assert isinstance(concrete_vault, expected_class)
def test_init_get_vault_ok(swh_vault_config):
concrete_vault = get_vault("local", **swh_vault_config)
assert isinstance(concrete_vault, VaultBackend)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment