Skip to content
Snippets Groups Projects
Commit 2986827f authored by Nicolas Dandrimont
Browse files

winery: Only run database administration operations once per process

By keeping track of created databases, created tables, and connection
pools, we cut down drastically on the admin operations performed while
the code is running.
parent d3341426
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,9 @@
import abc
from contextlib import contextmanager
import logging
import os
import time
from typing import Dict, Set, Tuple
import psycopg
import psycopg.errors
......@@ -14,6 +16,15 @@ from psycopg_pool import ConnectionPool
logger = logging.getLogger(__name__)
# Process-wide cache of connection pools: Database.__init__ creates a pool the
# first time a key is seen and reuses the cached one afterwards. The key starts
# with os.getpid() (presumably so that a forked child never reuses a pool that
# was opened by its parent -- TODO confirm).
POOLS: Dict[Tuple[int, str, str, str], ConnectionPool] = {}
"""Maps a tuple (pid, conninfo, dbname, application_name) to the matching ConnectionPool"""
# Added to by DatabaseAdmin.create_database (which returns early for entries
# already present) and discarded from by DatabaseAdmin.drop_database.
DATABASES_CREATED: Set[Tuple[str, str]] = set()
"""Set of (conninfo, dbname) entries for databases that we know have been created"""
# Added to by Database.create_tables (which returns early for entries already
# present) and discarded from by DatabaseAdmin.drop_database.
TABLES_CREATED: Set[Tuple[str, str]] = set()
"""Set of (conninfo, dbname) entries for databases for which we know tables have been created"""
class DatabaseAdmin:
def __init__(self, dsn, dbname=None, application_name=None):
......@@ -37,6 +48,9 @@ class DatabaseAdmin:
c.close()
def create_database(self):
if (self.dsn, self.dbname) in DATABASES_CREATED:
return
logger.debug("database %s: create", self.dbname)
with self.admin_cursor() as c:
c.execute(
......@@ -53,8 +67,10 @@ class DatabaseAdmin:
# someone else created the database, it is fine
pass
DATABASES_CREATED.add((self.dsn, self.dbname))
def drop_database(self):
logger.debug("database %s: drop", self.dbname)
logger.debug("database %s/%s: drop", self.dsn, self.dbname)
with self.admin_cursor() as c:
c.execute(
"SELECT pg_terminate_backend(pg_stat_activity.pid)"
......@@ -84,12 +100,16 @@ class DatabaseAdmin:
for i in range(60):
try:
c.execute(f"DROP DATABASE IF EXISTS {self.dbname}")
return
break
except psycopg.errors.ObjectInUse:
logger.warning(f"{self.dbname} database drop fails, waiting 10s")
time.sleep(10)
continue
raise Exception(f"database drop failed on {self.dbname}")
else:
raise Exception(f"database drop failed on {self.dbname}")
DATABASES_CREATED.discard((self.dsn, self.dbname))
TABLES_CREATED.discard((self.dsn, self.dbname))
def list_databases(self):
with self.admin_cursor() as c:
......@@ -105,20 +125,24 @@ class Database(abc.ABC):
self.dsn = dsn
self.dbname = dbname
self.application_name = application_name
self.pool = ConnectionPool(
conninfo=self.dsn,
kwargs={
"dbname": self.dbname,
"application_name": self.application_name,
"fallback_application_name": "SWH Winery",
"autocommit": True,
},
min_size=0,
max_size=4,
open=True,
max_idle=5,
check=ConnectionPool.check_connection,
)
pool_key = (os.getpid(), self.dsn, self.dbname, self.application_name)
if pool_key not in POOLS:
POOLS[pool_key] = ConnectionPool(
conninfo=self.dsn,
kwargs={
"dbname": self.dbname,
"application_name": self.application_name,
"fallback_application_name": "SWH Winery",
"autocommit": True,
},
min_size=0,
max_size=4,
open=True,
max_idle=5,
check=ConnectionPool.check_connection,
)
self.pool = POOLS[pool_key]
@property
@abc.abstractmethod
......@@ -132,10 +156,19 @@ class Database(abc.ABC):
"Return the list of CREATE TABLE statements for all tables in the database"
raise NotImplementedError("Database.database_tables")
def uninit(self):
    """Tear-down hook; the base implementation does nothing.

    ``self.pool`` is taken from the module-level ``POOLS`` cache in
    ``__init__``, and it is not closed here.
    """
    return None
def create_tables(self):
    """Create this database's tables, at most once per process.

    Returns immediately when ``(self.dsn, self.dbname)`` is already recorded
    in ``TABLES_CREATED``. Otherwise every statement of
    ``self.database_tables`` is executed on a pooled connection while
    holding the PostgreSQL advisory lock ``self.lock``, and the
    ``(dsn, dbname)`` pair is recorded once all statements succeeded.

    Raises:
        Whatever ``db.execute`` raises; in that case nothing is recorded in
        ``TABLES_CREATED``, so a later call tries again.
    """
    created_key = (self.dsn, self.dbname)
    if created_key in TABLES_CREATED:
        return

    logger.debug("database %s: create tables", self.dbname)
    logger.debug("pool stats: %s", self.pool.get_stats())

    with self.pool.connection() as db:
        db.execute("SELECT pg_advisory_lock(%s)", (self.lock,))
        try:
            for table in self.database_tables:
                db.execute(table)
        finally:
            # pg_advisory_lock is held by the session, not the transaction,
            # and this connection goes back to a shared pool: always release
            # the lock, even when a CREATE TABLE statement failed.
            db.execute("SELECT pg_advisory_unlock(%s)", (self.lock,))

    TABLES_CREATED.add(created_key)
......@@ -30,9 +30,6 @@ class RWShard(Database):
self.size = self.total_size()
self.limit = kwargs["shard_max_size"]
def uninit(self):
self.pool.close()
@property
def lock(self):
    """Advisory-lock key; ``create_tables`` passes ``self.lock`` to ``pg_advisory_lock``."""
    advisory_key = 452343  # an arbitrary unique number
    return advisory_key
......
......@@ -68,6 +68,7 @@ class SharedBase(Database):
if self._locked_shard is not None:
self.set_shard_state(new_state=ShardState.STANDBY)
self._locked_shard = None
super().uninit()
@property
def lock(self):
......
......@@ -128,9 +128,6 @@ class IOThrottler(Database):
)
self.sync_interval = 60
def uninit(self):
self.pool.close()
@property
def lock(self):
    """Advisory-lock key; ``create_tables`` passes ``self.lock`` to ``pg_advisory_lock``."""
    advisory_key = 9485433  # an arbitrary unique number
    return advisory_key
......@@ -139,12 +136,13 @@ class IOThrottler(Database):
def database_tables(self):
return [
f"""
CREATE TABLE IF NOT EXISTS t_{self.name}(
CREATE TABLE IF NOT EXISTS t_{name} (
id SERIAL PRIMARY KEY,
updated TIMESTAMP NOT NULL,
bytes INTEGER NOT NULL
)
""",
"""
for name in ("read", "write")
]
def download_info(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment