Skip to content
Snippets Groups Projects
Commit 25fee948 authored by David Douard's avatar David Douard
Browse files

Make get_sql_for_package() return pathlib.Path objects

parent 5822dab0
No related branches found
No related tags found
No related merge requests found
......@@ -5,10 +5,10 @@
from datetime import datetime, timezone
import functools
import glob
from importlib import import_module
import logging
from os import path
import pathlib
import re
import subprocess
from typing import Collection, Dict, List, Optional, Tuple, Union
......@@ -181,7 +181,7 @@ def swh_db_upgrade(
sqlfiles = [
fname
for fname in get_sql_for_package(modname, upgrade=True)
if db_version < int(path.splitext(path.basename(fname))[0]) <= to_version
if db_version < int(fname.stem) <= to_version
]
for sqlfile in sqlfiles:
......@@ -293,7 +293,9 @@ def swh_set_db_module(
return None
sqlfiles = [
fname for fname in get_sql_for_package("swh.core.db") if "dbmodule" in fname
fname
for fname in get_sql_for_package("swh.core.db")
if "dbmodule" in fname.stem
]
execute_sqlfiles(sqlfiles, db_or_conninfo)
......@@ -496,7 +498,7 @@ def import_swhmodule(modname):
return m
def get_sql_for_package(modname: str, upgrade: bool = False) -> List[str]:
def get_sql_for_package(modname: str, upgrade: bool = False) -> List[pathlib.Path]:
"""Return the (sorted) list of sql script files for the given swh module
If upgrade is True, return the list of available migration scripts,
......@@ -505,14 +507,15 @@ def get_sql_for_package(modname: str, upgrade: bool = False) -> List[str]:
m = import_swhmodule(modname)
if m is None:
raise ValueError(f"Module {modname} cannot be loaded")
sqldir = path.join(path.dirname(m.__file__), "sql")
sqldir = pathlib.Path(m.__file__).parent / "sql"
if upgrade:
sqldir += "/upgrades"
if not path.isdir(sqldir):
sqldir /= "upgrades"
if not sqldir.is_dir():
raise ValueError(
"Module {} does not provide a db schema " "(no sql/ dir)".format(modname)
"Module {} does not provide a db schema (no sql/ dir)".format(modname)
)
return sorted(glob.glob(path.join(sqldir, "*.sql")), key=sortkey)
return sorted(sqldir.glob("*.sql"), key=lambda x: sortkey(x.name))
def populate_database_for_package(
......@@ -541,8 +544,8 @@ def populate_database_for_package(
return sortkey(path.basename(key))
sqlfiles = get_sql_for_package(modname) + get_sql_for_package("swh.core.db")
sqlfiles = sorted(sqlfiles, key=globalsortkey)
sqlfiles = [fname for fname in sqlfiles if "-superuser-" not in fname]
sqlfiles = sorted(sqlfiles, key=lambda x: sortkey(x.stem))
sqlfiles = [fpath for fpath in sqlfiles if "-superuser-" not in fpath.stem]
execute_sqlfiles(sqlfiles, conninfo, flavor)
# populate the dbmodule table
......@@ -581,7 +584,7 @@ def init_admin_extensions(modname: str, conninfo: str) -> None:
"""
sqlfiles = get_sql_for_package(modname)
sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname]
sqlfiles = [fname for fname in sqlfiles if "-superuser-" in fname.stem]
execute_sqlfiles(sqlfiles, conninfo)
......@@ -621,7 +624,7 @@ def create_database_for_package(
def execute_sqlfiles(
sqlfiles: Collection[str], conninfo: str, flavor: Optional[str] = None
sqlfiles: Collection[pathlib.Path], conninfo: str, flavor: Optional[str] = None
):
"""Execute a list of SQL files on the database pointed at with ``conninfo``.
......@@ -643,9 +646,13 @@ def execute_sqlfiles(
flavor_set = False
for sqlfile in sqlfiles:
logger.debug(f"execute SQL file {sqlfile} dbname={conninfo}")
subprocess.check_call(psql_command + ["-f", sqlfile])
subprocess.check_call(psql_command + ["-f", str(sqlfile)])
if flavor is not None and not flavor_set and sqlfile.endswith("-flavor.sql"):
if (
flavor is not None
and not flavor_set
and sqlfile.name.endswith("-flavor.sql")
):
logger.debug("Setting database flavor %s", flavor)
query = f"insert into dbflavor (flavor) values ('{flavor}')"
subprocess.check_call(psql_command + ["-c", query])
......
......@@ -3,8 +3,8 @@
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import glob
import os
import pathlib
from click.testing import CliRunner
from hypothesis import HealthCheck
......@@ -45,12 +45,10 @@ def mock_package_sql(mocker, datadir):
def get_sql_for_package_mock(modname, upgrade=False):
if modname.startswith("test."):
sqldir = modname.split(".", 1)[1]
sqldir = pathlib.Path(datadir) / modname.split(".", 1)[1]
if upgrade:
sqldir += "/upgrades"
return sorted(
glob.glob(os.path.join(datadir, sqldir, "*.sql")), key=sortkey
)
sqldir /= "upgrades"
return sorted(sqldir.glob("*.sql"), key=lambda x: sortkey(x.name))
return get_sql_for_package(modname)
mock_sql_files = mocker.patch(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment