From d298c584fe8d9c45951ad5ab9051415e4cc0ec70 Mon Sep 17 00:00:00 2001
From: Valentin Lorentz <vlorentz@softwareheritage.org>
Date: Mon, 11 Jul 2022 18:46:42 +0200
Subject: [PATCH] hypothesis_strategies: Add kwargs to composite strategies, to
 replace default sub-strategies

This allows callers to pass different (typically stricter) strategies
to Hypothesis instead of filtering the output, which makes example
generation faster.
---
 swh/model/hypothesis_strategies.py | 195 ++++++++++++++---------------
 swh/model/tests/test_model.py      |  16 +--
 2 files changed, 101 insertions(+), 110 deletions(-)

diff --git a/swh/model/hypothesis_strategies.py b/swh/model/hypothesis_strategies.py
index dabecf96..53c66f00 100644
--- a/swh/model/hypothesis_strategies.py
+++ b/swh/model/hypothesis_strategies.py
@@ -120,22 +120,22 @@ def persons_d(draw):
     return dict(fullname=fullname, name=name, email=email)
 
 
-def persons():
-    return persons_d().map(Person.from_dict)
+def persons(**kwargs):
+    return persons_d(**kwargs).map(Person.from_dict)
 
 
-def timestamps_d():
+def timestamps_d(**kwargs):
     max_seconds = datetime.datetime.max.replace(
         tzinfo=datetime.timezone.utc
     ).timestamp()
     min_seconds = datetime.datetime.min.replace(
         tzinfo=datetime.timezone.utc
     ).timestamp()
-    return builds(
-        dict,
+    defaults = dict(
         seconds=integers(min_seconds, max_seconds),
         microseconds=integers(0, 1000000),
     )
+    return builds(dict, **{**defaults, **kwargs})
 
 
 def timestamps():
@@ -145,6 +145,7 @@ def timestamps():
 @composite
 def timestamps_with_timezone_d(
     draw,
+    *,
     timestamp=timestamps_d(),
     offset=integers(min_value=-14 * 60, max_value=14 * 60),
     negative_utc=booleans(),
@@ -161,35 +162,34 @@ timestamps_with_timezone = timestamps_with_timezone_d().map(
 )
 
 
-def origins_d():
-    return builds(dict, url=iris())
+def origins_d(*, url=iris()):
+    return builds(dict, url=url)
 
 
-def origins():
-    return origins_d().map(Origin.from_dict)
+def origins(**kwargs):
+    return origins_d(**kwargs).map(Origin.from_dict)
 
 
-def origin_visits_d():
-    return builds(
-        dict,
+def origin_visits_d(**kwargs):
+    defaults = dict(
         visit=integers(1, 1000),
         origin=iris(),
         date=aware_datetimes(),
         type=pgsql_text(),
     )
+    return builds(dict, **{**defaults, **kwargs})
 
 
-def origin_visits():
-    return origin_visits_d().map(OriginVisit.from_dict)
+def origin_visits(**kwargs):
+    return origin_visits_d(**kwargs).map(OriginVisit.from_dict)
 
 
 def metadata_dicts():
     return dictionaries(pgsql_text(), pgsql_text())
 
 
-def origin_visit_statuses_d():
-    return builds(
-        dict,
+def origin_visit_statuses_d(**kwargs):
+    defaults = dict(
         visit=integers(1, 1000),
         origin=iris(),
         type=optional(sampled_from(["git", "svn", "pypi", "debian"])),
@@ -200,60 +200,48 @@ def origin_visit_statuses_d():
         snapshot=optional(sha1_git()),
         metadata=optional(metadata_dicts()),
     )
+    return builds(dict, **{**defaults, **kwargs})
 
 
-def origin_visit_statuses():
-    return origin_visit_statuses_d().map(OriginVisitStatus.from_dict)
+def origin_visit_statuses(**kwargs):
+    return origin_visit_statuses_d(**kwargs).map(OriginVisitStatus.from_dict)
 
 
 @composite
-def releases_d(draw):
-    target_type = sampled_from([x.value for x in ObjectType])
-    name = binary()
-    message = optional(binary())
-    synthetic = booleans()
-    target = sha1_git()
-    metadata = optional(revision_metadata())
+def releases_d(draw, **kwargs):
+    defaults = dict(
+        target_type=sampled_from([x.value for x in ObjectType]),
+        name=binary(),
+        message=optional(binary()),
+        synthetic=booleans(),
+        target=sha1_git(),
+        metadata=optional(revision_metadata()),
+        raw_manifest=optional(binary()),
+    )
 
     d = draw(
         one_of(
             # None author/date:
-            builds(
-                dict,
-                name=name,
-                message=message,
-                synthetic=synthetic,
-                author=none(),
-                date=none(),
-                target=target,
-                target_type=target_type,
-                metadata=metadata,
-            ),
+            builds(dict, author=none(), date=none(), **{**defaults, **kwargs}),
             # non-None author/date:
             builds(
                 dict,
-                name=name,
-                message=message,
-                synthetic=synthetic,
                 date=timestamps_with_timezone_d(),
                 author=persons_d(),
-                target=target,
-                target_type=target_type,
-                metadata=metadata,
+                **{**defaults, **kwargs},
             ),
             # it is also possible for date to be None but not author, but let's not
             # overwhelm hypothesis with this edge case
         )
     )
 
-    raw_manifest = draw(optional(binary()))
-    if raw_manifest:
-        d["raw_manifest"] = raw_manifest
+    if d["raw_manifest"] is None:
+        del d["raw_manifest"]
     return d
 
 
-def releases():
-    return releases_d().map(Release.from_dict)
+def releases(**kwargs):
+    return releases_d(**kwargs).map(Release.from_dict)
 
 
 revision_metadata = metadata_dicts
@@ -266,38 +254,36 @@ def extra_headers():
 
 
 @composite
-def revisions_d(draw):
+def revisions_d(draw, **kwargs):
+    defaults = dict(
+        message=optional(binary()),
+        synthetic=booleans(),
+        parents=tuples(sha1_git()),
+        directory=sha1_git(),
+        type=sampled_from([x.value for x in RevisionType]),
+        metadata=optional(revision_metadata()),
+        extra_headers=extra_headers(),
+        raw_manifest=optional(binary()),
+    )
     d = draw(
         one_of(
             # None author/committer/date/committer_date
             builds(
                 dict,
-                message=optional(binary()),
-                synthetic=booleans(),
                 author=none(),
                 committer=none(),
                 date=none(),
                 committer_date=none(),
-                parents=tuples(sha1_git()),
-                directory=sha1_git(),
-                type=sampled_from([x.value for x in RevisionType]),
-                metadata=optional(revision_metadata()),
-                extra_headers=extra_headers(),
+                **{**defaults, **kwargs},
             ),
             # non-None author/committer/date/committer_date
             builds(
                 dict,
-                message=optional(binary()),
-                synthetic=booleans(),
                 author=persons_d(),
                 committer=persons_d(),
                 date=timestamps_with_timezone_d(),
                 committer_date=timestamps_with_timezone_d(),
-                parents=tuples(sha1_git()),
-                directory=sha1_git(),
-                type=sampled_from([x.value for x in RevisionType]),
-                metadata=optional(revision_metadata()),
-                extra_headers=extra_headers(),
+                **{**defaults, **kwargs},
             ),
             # There are many other combinations, but let's not overwhelm hypothesis
             # with these edge cases
@@ -305,67 +291,67 @@ def revisions_d(draw):
     )
     # TODO: metadata['extra_headers'] can have binary keys and values
 
-    raw_manifest = draw(optional(binary()))
-    if raw_manifest:
-        d["raw_manifest"] = raw_manifest
+    if d["raw_manifest"] is None:
+        del d["raw_manifest"]
     return d
 
 
-def revisions():
-    return revisions_d().map(Revision.from_dict)
+def revisions(**kwargs):
+    return revisions_d(**kwargs).map(Revision.from_dict)
 
 
-def directory_entries_d():
+def directory_entries_d(**kwargs):
+    defaults = dict(
+        name=binaries_without_bytes(b"/"),
+        target=sha1_git(),
+    )
     return one_of(
         builds(
             dict,
-            name=binaries_without_bytes(b"/"),
-            target=sha1_git(),
             type=just("file"),
             perms=one_of(
                 integers(min_value=0o100000, max_value=0o100777),  # regular file
                 integers(min_value=0o120000, max_value=0o120777),  # symlink
             ),
+            **{**defaults, **kwargs},
         ),
         builds(
             dict,
-            name=binaries_without_bytes(b"/"),
-            target=sha1_git(),
             type=just("dir"),
             perms=integers(
                 min_value=DentryPerms.directory,
                 max_value=DentryPerms.directory + 0o777,
             ),
+            **{**defaults, **kwargs},
         ),
         builds(
             dict,
-            name=binaries_without_bytes(b"/"),
-            target=sha1_git(),
             type=just("rev"),
             perms=integers(
                 min_value=DentryPerms.revision,
                 max_value=DentryPerms.revision + 0o777,
             ),
+            **{**defaults, **kwargs},
         ),
     )
 
 
-def directory_entries():
-    return directory_entries_d().map(DirectoryEntry)
+def directory_entries(**kwargs):
+    return directory_entries_d(**kwargs).map(DirectoryEntry)
 
 
 @composite
-def directories_d(draw):
+def directories_d(draw, raw_manifest=optional(binary())):
     d = draw(builds(dict, entries=tuples(directory_entries_d())))
 
-    raw_manifest = draw(optional(binary()))
-    if raw_manifest:
-        d["raw_manifest"] = raw_manifest
+    d["raw_manifest"] = draw(raw_manifest)
+    if d["raw_manifest"] is None:
+        del d["raw_manifest"]
     return d
 
 
-def directories():
-    return directories_d().map(Directory.from_dict)
+def directories(**kwargs):
+    return directories_d(**kwargs).map(Directory.from_dict)
 
 
 def contents_d():
@@ -376,21 +362,23 @@ def contents():
     return one_of(present_contents(), skipped_contents())
 
 
-def present_contents_d():
-    return builds(
-        dict,
+def present_contents_d(**kwargs):
+    defaults = dict(
         data=binary(max_size=4096),
         ctime=optional(aware_datetimes()),
         status=one_of(just("visible"), just("hidden")),
     )
+    return builds(dict, **{**defaults, **kwargs})
 
 
-def present_contents():
-    return present_contents_d().map(lambda d: Content.from_data(**d))
+def present_contents(**kwargs):  # kwargs override present_contents_d defaults
+    return present_contents_d(**kwargs).map(lambda d: Content.from_data(**d))
 
 
 @composite
-def skipped_contents_d(draw):
+def skipped_contents_d(
+    draw, reason=pgsql_text(), status=just("absent"), ctime=optional(aware_datetimes())
+):
     result = BaseContent._hash_data(draw(binary(max_size=4096)))
     result.pop("data")
     nullify_attrs = draw(
@@ -398,13 +386,13 @@ def skipped_contents_d(draw):
     )
     for k in nullify_attrs:
         result[k] = None
-    result["reason"] = draw(pgsql_text())
-    result["status"] = "absent"
-    result["ctime"] = draw(optional(aware_datetimes()))
+    result["reason"] = draw(reason)
+    result["status"] = draw(status)
+    result["ctime"] = draw(ctime)
     return result
 
 
-def skipped_contents():
-    return skipped_contents_d().map(SkippedContent.from_dict)
+def skipped_contents(**kwargs):  # kwargs: reason/status/ctime strategies
+    return skipped_contents_d(**kwargs).map(SkippedContent.from_dict)
 
 
@@ -492,35 +480,38 @@ def snapshots(*, min_size=0, max_size=100, only_objects=False):
     ).map(Snapshot.from_dict)
 
 
-def metadata_authorities():
-    return builds(MetadataAuthority, url=iris(), metadata=just(None))
+def metadata_authorities(url=iris()):
+    return builds(MetadataAuthority, url=url, metadata=just(None))
 
 
-def metadata_fetchers():
-    return builds(
-        MetadataFetcher,
+def metadata_fetchers(**kwargs):
+    defaults = dict(
         name=text(min_size=1, alphabet=string.printable),
         version=text(
             min_size=1,
             alphabet=string.ascii_letters + string.digits + string.punctuation,
         ),
+    )
+    return builds(
+        MetadataFetcher,
         metadata=just(None),
+        **{**defaults, **kwargs},
     )
 
 
-def raw_extrinsic_metadata():
-    return builds(
-        RawExtrinsicMetadata,
+def raw_extrinsic_metadata(**kwargs):
+    defaults = dict(
         target=extended_swhids(),
         discovery_date=aware_datetimes(),
         authority=metadata_authorities(),
         fetcher=metadata_fetchers(),
         format=text(min_size=1, alphabet=string.printable),
     )
+    return builds(RawExtrinsicMetadata, **{**defaults, **kwargs})
 
 
-def raw_extrinsic_metadata_d():
-    return raw_extrinsic_metadata().map(RawExtrinsicMetadata.to_dict)
+def raw_extrinsic_metadata_d(**kwargs):
+    return raw_extrinsic_metadata(**kwargs).map(RawExtrinsicMetadata.to_dict)
 
 
 def objects(blacklist_types=("origin_visit_status",), split_content=False):
diff --git a/swh/model/tests/test_model.py b/swh/model/tests/test_model.py
index 0172386a..4540c433 100644
--- a/swh/model/tests/test_model.py
+++ b/swh/model/tests/test_model.py
@@ -13,7 +13,7 @@ import attr
 from attrs_strict import AttributeTypeError
 import dateutil
 from hypothesis import given
-from hypothesis.strategies import binary
+from hypothesis.strategies import binary, none
 import pytest
 
 from swh.model.collections import ImmutableDict
@@ -841,7 +841,7 @@ def test_content_naive_datetime():
         )
 
 
-@given(strategies.present_contents().filter(lambda cnt: cnt.data is not None))
+@given(strategies.present_contents())
 def test_content_git_roundtrip(content):
     assert content.data is not None
     raw = swh.model.git_objects.content_git_object(content)
@@ -886,7 +886,7 @@ def test_skipped_content_naive_datetime():
 # Directory
 
 
-@given(strategies.directories().filter(lambda d: d.raw_manifest is None))
+@given(strategies.directories(raw_manifest=none()))
 def test_directory_check(directory):
     directory.check()
 
@@ -903,7 +903,7 @@ def test_directory_check(directory):
         directory2.check()
 
 
-@given(strategies.directories().filter(lambda d: d.raw_manifest is None))
+@given(strategies.directories(raw_manifest=none()))
 def test_directory_raw_manifest(directory):
     assert "raw_manifest" not in directory.to_dict()
 
@@ -1083,7 +1083,7 @@ def test_directory_from_possibly_duplicated_entries__preserve_manifest():
 # Release
 
 
-@given(strategies.releases().filter(lambda rel: rel.raw_manifest is None))
+@given(strategies.releases(raw_manifest=none()))
 def test_release_check(release):
     release.check()
 
@@ -1100,7 +1100,7 @@ def test_release_check(release):
         release2.check()
 
 
-@given(strategies.releases().filter(lambda rev: rev.raw_manifest is None))
+@given(strategies.releases(raw_manifest=none()))
 def test_release_raw_manifest(release):
     raw_manifest = b"foo"
     id_ = hashlib.new("sha1", raw_manifest).digest()
@@ -1120,7 +1120,7 @@ def test_release_raw_manifest(release):
 # Revision
 
 
-@given(strategies.revisions().filter(lambda rev: rev.raw_manifest is None))
+@given(strategies.revisions(raw_manifest=none()))
 def test_revision_check(revision):
     revision.check()
 
@@ -1137,7 +1137,7 @@ def test_revision_check(revision):
         revision2.check()
 
 
-@given(strategies.revisions().filter(lambda rev: rev.raw_manifest is None))
+@given(strategies.revisions(raw_manifest=none()))
 def test_revision_raw_manifest(revision):
 
     raw_manifest = b"foo"
-- 
GitLab