Commit 933ea14a authored by David Douard

Convert dates to tz-aware UTC timestamps before inserting them in PostgreSQL

This is required to prevent errors when dealing with weird/pathological
timezones that psycopg does not support, such as UTC offsets approaching
24 hours.
parent 468ae961
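The problem and the fix can be sketched in isolation. This is an illustrative snippet, not part of the commit; the offset mirrors the one exercised by the new tests below:

from datetime import datetime, timedelta, timezone

UTC = timezone.utc

# Valid for Python, but an offset this close to 24 hours makes psycopg
# choke (per the commit message above).
pathological = timezone(-timedelta(hours=23, minutes=59, seconds=59))
date = datetime(2021, 5, 17, 12, 0, tzinfo=pathological)

# The fix: convert to UTC before the value ever reaches the driver.
# Same instant in time, but with an offset any backend can handle.
date = date.astimezone(UTC) if date is not None else None
print(date)  # 2021-05-18 11:59:59+00:00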
@@ -5,13 +5,15 @@
 from __future__ import annotations

-from datetime import datetime
+from datetime import datetime, timezone
 from typing import Iterator, List, Optional

 from swh.model.model import Origin, Sha1Git

 from .archive import ArchiveInterface

+UTC = timezone.utc
+

 class OriginEntry:
@@ -57,8 +59,8 @@ class RevisionEntry:
         root: Optional[Sha1Git] = None,
     ) -> None:
         self.id = id
-        self.date = date
-        assert self.date is None or self.date.tzinfo is not None
+        assert date is None or date.tzinfo is not None
+        self.date = date.astimezone(UTC) if date is not None else None
         self.root = root

     def __str__(self) -> str:
......
@@ -6,7 +6,7 @@
 from __future__ import annotations

 from contextlib import contextmanager
-from datetime import datetime
+from datetime import datetime, timezone
 from functools import wraps
 from hashlib import sha1
 import itertools
@@ -30,6 +30,8 @@ from swh.provenance.storage.interface import (
     RevisionData,
 )

+UTC = timezone.utc
+
 LOGGER = logging.getLogger(__name__)

 STORAGE_DURATION_METRIC = "swh_provenance_storage_postgresql_duration_seconds"
@@ -111,16 +113,18 @@ class ProvenanceStoragePostgreSql:
     def content_add(self, cnts: Dict[Sha1Git, datetime]) -> bool:
         if cnts:
             # Upsert in consistent order to avoid deadlocks
-            rows = sorted(cnts.items())
             sql = """
                 INSERT INTO content(sha1, date) VALUES %s
                 ON CONFLICT (sha1) DO
                 UPDATE SET date=LEAST(EXCLUDED.date,content.date)
                 """
+            rows = [
+                (sha1git, date.astimezone(UTC)) for (sha1git, date) in sorted(cnts.items())
+            ]
             page_size = self.page_size or len(rows)
             with self.transaction() as cursor:
                 psycopg2.extras.execute_values(
-                    cursor, sql, argslist=rows, page_size=page_size
+                    cursor, sql, argslist=rows, page_size=page_size,
                 )
         return True
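A note on the pattern in content_add above: sorting the rows gives every concurrent transaction the same lock-acquisition order, which is what prevents the upsert deadlocks the comment mentions, and psycopg2.extras.execute_values expands the single VALUES %s placeholder in pages of page_size rows. A minimal standalone sketch, assuming a reachable database with the content table from the SQL above (the DSN is hypothetical):

from datetime import datetime, timezone

import psycopg2
import psycopg2.extras

# Two tz-aware rows, sorted by sha1 so concurrent writers take row locks
# in the same order.
rows = sorted(
    {
        b"\x02" * 20: datetime(2020, 6, 1, tzinfo=timezone.utc),
        b"\x01" * 20: datetime(2021, 1, 1, tzinfo=timezone.utc),
    }.items()
)

conn = psycopg2.connect("dbname=provenance")  # hypothetical DSN
with conn:
    with conn.cursor() as cursor:
        psycopg2.extras.execute_values(
            cursor,
            """
            INSERT INTO content(sha1, date) VALUES %s
            ON CONFLICT (sha1) DO
            UPDATE SET date=LEAST(EXCLUDED.date,content.date)
            """,
            argslist=rows,
            page_size=1000,
        )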
......@@ -162,8 +166,11 @@ class ProvenanceStoragePostgreSql:
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"})
@handle_raise_on_commit
def directory_add(self, dirs: Dict[Sha1Git, DirectoryData]) -> bool:
# Upsert in consistent order to avoid deadlocks
data = sorted((sha1, rev.date, rev.flat) for sha1, rev in dirs.items())
# sorted: Upsert in consistent order to avoid deadlocks
data = [
(sha1, rev.date.astimezone(UTC) if rev.date else None, rev.flat)
for sha1, rev in sorted(dirs.items())
]
if data:
sql = """
INSERT INTO directory(sha1, date, flat) VALUES %s
@@ -291,8 +298,11 @@ class ProvenanceStoragePostgreSql:
     @handle_raise_on_commit
     def revision_add(self, revs: Dict[Sha1Git, RevisionData]) -> bool:
         if revs:
-            # Upsert in consistent order to avoid deadlocks
-            data = sorted((sha1, rev.date, rev.origin) for sha1, rev in revs.items())
+            # sorted: Upsert in consistent order to avoid deadlocks
+            data = [
+                (sha1, rev.date.astimezone(UTC) if rev.date else None, rev.origin)
+                for sha1, rev in sorted(revs.items())
+            ]
             sql = """
                 INSERT INTO revision(sha1, date, origin)
                 (SELECT V.rev AS sha1, V.date::timestamptz AS date, O.id AS origin
......
@@ -3,7 +3,7 @@
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information

-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 import hashlib
 import inspect
 import os
@@ -31,6 +31,8 @@ from swh.provenance.storage.interface import (
 from .utils import fill_storage, load_repo_data, ts2dt

+UTC = timezone.utc
+

 class TestProvenanceStorage:
     def test_provenance_storage_content(
@@ -42,13 +44,60 @@ class TestProvenanceStorage:
         # Read data/README.md for more details on how these datasets are generated.
         data = load_repo_data("cmdbts2")

         # Add all content present in the current repo to the storage, just assigning their
         # creation dates. Then check that the returned results when querying are the same.
-        cnt_dates = {
-            cnt["sha1_git"]: cnt["ctime"] for idx, cnt in enumerate(data["content"])
-        }
+        cnt_dates = {cnt["sha1_git"]: cnt["ctime"] for cnt in data["content"]}
+        expected_dates = {
+            cnt["sha1_git"]: cnt["ctime"].astimezone(UTC) for cnt in data["content"]
+        }
         assert provenance_storage.content_add(cnt_dates)
-        assert provenance_storage.content_get(set(cnt_dates.keys())) == cnt_dates
+        assert provenance_storage.content_get(set(cnt_dates.keys())) == expected_dates
         assert provenance_storage.entity_get_all(EntityType.CONTENT) == set(
             cnt_dates.keys()
         )
+
+    def test_provenance_storage_content_invalid_dates(
+        self,
+        provenance_storage: ProvenanceStorageInterface,
+    ) -> None:
+        """Tests content methods for every `ProvenanceStorageInterface` implementation."""
+        # Read data/README.md for more details on how these datasets are generated.
+        data = load_repo_data("cmdbts2")
+
+        # Add all content present in the current repo to the storage, just assigning their
+        # creation dates. Then check that the returned results when querying are the same.
+        cnt_dates = {
+            cnt["sha1_git"]: cnt["ctime"].replace(
+                tzinfo=timezone(-timedelta(hours=23, minutes=59, seconds=59))
+            )
+            for cnt in data["content"]
+        }
+        expected_dates = {
+            sha1_git: date.astimezone(UTC) for sha1_git, date in cnt_dates.items()
+        }
+        assert provenance_storage.content_add(cnt_dates)
+        assert provenance_storage.content_get(set(cnt_dates.keys())) == expected_dates
+        assert provenance_storage.entity_get_all(EntityType.CONTENT) == set(
+            cnt_dates.keys()
+        )
+
+        # Add all content present in the current repo to the storage, just assigning their
+        # creation dates. Then check that the returned results when querying are the same.
+        cnt_dates = {
+            cnt["sha1_git"]: cnt["ctime"].replace(
+                tzinfo=timezone(timedelta(hours=23, minutes=59, seconds=59))
+            )
+            for cnt in data["content"]
+        }
+        expected_dates = {
+            sha1_git: date.astimezone(UTC) for sha1_git, date in cnt_dates.items()
+        }
+        assert provenance_storage.content_add(cnt_dates)
+        assert provenance_storage.content_get(set(cnt_dates.keys())) == expected_dates
+        assert provenance_storage.entity_get_all(EntityType.CONTENT) == set(
+            cnt_dates.keys()
+        )
......
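The test above leans on the difference between datetime.replace, which swaps in a new tzinfo while leaving the wall-clock fields untouched, and datetime.astimezone, which converts between zones. A quick standalone illustration:

from datetime import datetime, timedelta, timezone

UTC = timezone.utc
weird = timezone(timedelta(hours=23, minutes=59, seconds=59))

dt = datetime(2021, 5, 17, 12, 0, tzinfo=UTC)
shifted = dt.replace(tzinfo=weird)  # same wall clock, new offset

print(shifted)                  # 2021-05-17 12:00:00+23:59:59
print(shifted.astimezone(UTC))  # 2021-05-16 12:00:01+00:00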