Skip to content
Snippets Groups Projects
Commit 6dd6acec authored by vlorentz's avatar vlorentz
Browse files

model: Raise error on naive datetimes.

We may unknowingly pass naive datetimes to the storage through them,
causing the underlying DB to assign them a timezone that might not match
the actual one.

It already happens in swh.model and swh.loader.package tests.
parent d1db7b99
No related branches found
No related tags found
No related merge requests found
......@@ -267,6 +267,12 @@ class OriginVisit(BaseModel):
"""Should not be set before calling 'origin_visit_add()'."""
visit = attr.ib(type=Optional[int], validator=type_validator(), default=None)
@date.validator
def check_date(self, attribute, value):
"""Checks the date has a timezone."""
if value is not None and value.tzinfo is None:
raise ValueError("date must be a timezone-aware datetime.")
def to_dict(self):
"""Serializes the date as a string and omits the visit id if it is
`None`."""
......@@ -300,6 +306,12 @@ class OriginVisitStatus(BaseModel):
default=None,
)
@date.validator
def check_date(self, attribute, value):
"""Checks the date has a timezone."""
if value is not None and value.tzinfo is None:
raise ValueError("date must be a timezone-aware datetime.")
class TargetType(Enum):
"""The type of content pointed to by a snapshot branch. Usually a
......@@ -621,6 +633,12 @@ class Content(BaseContent):
if value < 0:
raise ValueError("Length must be positive.")
@ctime.validator
def check_ctime(self, attribute, value):
"""Checks the ctime has a timezone."""
if value is not None and value.tzinfo is None:
raise ValueError("ctime must be a timezone-aware datetime.")
def to_dict(self):
content = super().to_dict()
if content["data"] is None:
......@@ -695,6 +713,12 @@ class SkippedContent(BaseContent):
if value < -1:
raise ValueError("Length must be positive or -1.")
@ctime.validator
def check_ctime(self, attribute, value):
"""Checks the ctime has a timezone."""
if value is not None and value.tzinfo is None:
raise ValueError("ctime must be a timezone-aware datetime.")
def to_dict(self):
content = super().to_dict()
if content["origin"] is None:
......@@ -835,6 +859,12 @@ class RawExtrinsicMetadata(BaseModel):
else:
self._check_pid(self.type.value, value)
@discovery_date.validator
def check_discovery_date(self, attribute, value):
"""Checks the discovery_date has a timezone."""
if value is not None and value.tzinfo is None:
raise ValueError("discovery_date must be a timezone-aware datetime.")
@origin.validator
def check_origin(self, attribute, value):
if value is None:
......
......@@ -21,6 +21,8 @@ from swh.model.model import (
Release,
Snapshot,
Origin,
OriginVisit,
OriginVisitStatus,
Timestamp,
TimestampWithTimezone,
MissingData,
......@@ -97,7 +99,7 @@ def test_anonymization(objtype_and_obj):
assert anon_obj is None
# Origin, OriginVisit
# Origin, OriginVisit, OriginVisitStatus
@given(strategies.origins())
......@@ -115,6 +117,13 @@ def test_todict_origin_visits(origin_visit):
assert origin_visit == type(origin_visit).from_dict(obj)
def test_origin_visit_naive_datetime():
with pytest.raises(ValueError, match="must be a timezone-aware datetime"):
OriginVisit(
origin="http://foo/", date=datetime.datetime.now(), type="git",
)
@given(strategies.origin_visit_statuses())
def test_todict_origin_visit_statuses(origin_visit_status):
obj = origin_visit_status.to_dict()
......@@ -122,6 +131,17 @@ def test_todict_origin_visit_statuses(origin_visit_status):
assert origin_visit_status == type(origin_visit_status).from_dict(obj)
def test_origin_visit_status_naive_datetime():
with pytest.raises(ValueError, match="must be a timezone-aware datetime"):
OriginVisitStatus(
origin="http://foo/",
visit=42,
date=datetime.datetime.now(),
status="ongoing",
snapshot=None,
)
# Timestamp
......@@ -224,6 +244,13 @@ def test_timestampwithtimezone_from_datetime():
)
def test_timestampwithtimezone_from_naive_datetime():
date = datetime.datetime(2020, 2, 27, 14, 39, 19)
with pytest.raises(ValueError, match="datetime without timezone"):
TimestampWithTimezone.from_datetime(date)
def test_timestampwithtimezone_from_iso8601():
date = "2020-02-27 14:39:19.123456+0100"
......@@ -363,7 +390,7 @@ def test_content_from_dict(content_d):
def test_content_from_dict_str_ctime():
# test with ctime as a string
n = datetime.datetime(2020, 5, 6, 12, 34)
n = datetime.datetime(2020, 5, 6, 12, 34, tzinfo=datetime.timezone.utc)
content_d = {
"ctime": n.isoformat(),
"data": b"",
......@@ -377,6 +404,22 @@ def test_content_from_dict_str_ctime():
assert c.ctime == n
def test_content_from_dict_str_naive_ctime():
# test with ctime as a string
n = datetime.datetime(2020, 5, 6, 12, 34)
content_d = {
"ctime": n.isoformat(),
"data": b"",
"length": 0,
"sha1": b"\x00",
"sha256": b"\x00",
"sha1_git": b"\x00",
"blake2s256": b"\x00",
}
with pytest.raises(ValueError, match="must be a timezone-aware datetime."):
Content.from_dict(content_d)
@given(binary(max_size=4096))
def test_content_from_data(data):
c = Content.from_data(data)
......@@ -397,6 +440,14 @@ def test_hidden_content_from_data(data):
assert getattr(c, key) == value
def test_content_naive_datetime():
c = Content.from_data(b"foo")
with pytest.raises(ValueError, match="must be a timezone-aware datetime"):
Content(
**c.to_dict(), ctime=datetime.datetime.now(),
)
# SkippedContent
......@@ -422,6 +473,14 @@ def test_skipped_content_origin_is_str(skipped_content_d):
SkippedContent.from_dict(skipped_content_d)
def test_skipped_content_naive_datetime():
c = SkippedContent.from_data(b"foo", reason="reason")
with pytest.raises(ValueError, match="must be a timezone-aware datetime"):
SkippedContent(
**c.to_dict(), ctime=datetime.datetime.now(),
)
# Revision
......@@ -694,7 +753,7 @@ _metadata_fetcher = MetadataFetcher(name="test-fetcher", version="0.0.1",)
_content_swhid = parse_swhid("swh:1:cnt:94a9ed024d3859793618152ea559a168bbcbb5e2")
_origin_url = "https://forge.softwareheritage.org/source/swh-model.git"
_common_metadata_fields = dict(
discovery_date=datetime.datetime.now(),
discovery_date=datetime.datetime.now(tz=datetime.timezone.utc),
authority=_metadata_authority,
fetcher=_metadata_fetcher,
format="json",
......@@ -802,6 +861,15 @@ def test_metadata_invalid_id():
)
def test_metadata_naive_datetime():
with pytest.raises(ValueError, match="must be a timezone-aware datetime"):
RawExtrinsicMetadata(
type=MetadataTargetType.ORIGIN,
id=_origin_url,
**{**_common_metadata_fields, "discovery_date": datetime.datetime.now()},
)
def test_metadata_validate_context_origin():
"""Checks validation of RawExtrinsicMetadata.origin."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment