From 6d8ddab9196ccde3d508cfac52546e3d3d363e52 Mon Sep 17 00:00:00 2001
From: Pierre-Yves David <pierre-yves.david@ens-lyon.org>
Date: Fri, 24 May 2024 16:39:11 +0200
Subject: [PATCH] enum-cleanup: use ModelObjectType in strategies.objects

We also typed the function, which turned out more painful than
anticipated.
---
 swh/model/hypothesis_strategies.py | 60 ++++++++++++++++++++----------
 1 file changed, 40 insertions(+), 20 deletions(-)

diff --git a/swh/model/hypothesis_strategies.py b/swh/model/hypothesis_strategies.py
index af145a14..5ae56fda 100644
--- a/swh/model/hypothesis_strategies.py
+++ b/swh/model/hypothesis_strategies.py
@@ -5,12 +5,13 @@
 
 import datetime
 import string
-from typing import Sequence
+from typing import Any, Callable, List, Sequence, Set, Tuple, Union
 
 from deprecated import deprecated
 from hypothesis import assume
 from hypothesis.extra.dateutil import timezones
 from hypothesis.strategies import (
+    SearchStrategy,
     binary,
     booleans,
     builds,
@@ -33,11 +34,13 @@ from hypothesis.strategies import (
 from .from_disk import DentryPerms
 from .model import (
     BaseContent,
+    BaseModel,
     Content,
     Directory,
     DirectoryEntry,
     MetadataAuthority,
     MetadataFetcher,
+    ModelObjectType,
     Origin,
     OriginVisit,
     OriginVisitStatus,
@@ -545,7 +548,22 @@ def raw_extrinsic_metadata_d(**kwargs):
     return raw_extrinsic_metadata(**kwargs).map(RawExtrinsicMetadata.to_dict)
 
 
-def objects(blacklist_types=("origin_visit_status",), split_content=False):
+def _tuplify(
+    object_type: ModelObjectType,
+) -> Callable[[BaseModel], Tuple[ModelObjectType, BaseModel]]:
+    def tupler(obj: BaseModel):
+        return (object_type, obj)
+
+    return tupler
+
+
+def objects(
+    # remove the Union once deprecated usage have been migrated
+    blacklist_types: Union[Set[ModelObjectType] | Any] = {
+        ModelObjectType.ORIGIN_VISIT_STATUS,
+    },
+    split_content: bool = False,
+):
     """generates a random couple (type, obj)
 
     which obj is an instance of the Model class corresponding to obj_type.
@@ -555,27 +573,29 @@ def objects(blacklist_types=("origin_visit_status",), split_content=False):
     If `split_content` is True, generates Content and SkippedContent under different
     obj_type, resp. "content" and "skipped_content".
     """
-    strategies = [
-        ("origin", origins),
-        ("origin_visit", origin_visits),
-        ("origin_visit_status", origin_visit_statuses),
-        ("snapshot", snapshots),
-        ("release", releases),
-        ("revision", revisions),
-        ("directory", directories),
-        ("raw_extrinsic_metadata", raw_extrinsic_metadata),
+    strategies: List[
+        Tuple[ModelObjectType, Callable[[], SearchStrategy[BaseModel]]]
+    ] = [
+        (ModelObjectType.ORIGIN, origins),
+        (ModelObjectType.ORIGIN_VISIT, origin_visits),
+        (ModelObjectType.ORIGIN_VISIT_STATUS, origin_visit_statuses),
+        (ModelObjectType.SNAPSHOT, snapshots),
+        (ModelObjectType.RELEASE, releases),
+        (ModelObjectType.REVISION, revisions),
+        (ModelObjectType.DIRECTORY, directories),
+        (ModelObjectType.RAW_EXTRINSIC_METADATA, raw_extrinsic_metadata),
     ]
     if split_content:
-        strategies.append(("content", present_contents))
-        strategies.append(("skipped_content", skipped_contents))
+        strategies.append((ModelObjectType.CONTENT, present_contents))
+        strategies.append((ModelObjectType.SKIPPED_CONTENT, skipped_contents))
     else:
-        strategies.append(("content", contents))
-    args = [
-        obj_gen().map(lambda x, obj_type=obj_type: (obj_type, x))
-        for (obj_type, obj_gen) in strategies
-        if obj_type not in blacklist_types
-    ]
-    return one_of(*args)
+        strategies.append((ModelObjectType.CONTENT, contents))
+
+    candidates = []
+    for obj_type, obj_gen in strategies:
+        if obj_type not in blacklist_types:
+            candidates.append(obj_gen().map(_tuplify(obj_type)))
+    return one_of(*candidates)
 
 
 def object_dicts(blacklist_types=("origin_visit_status",), split_content=False):
-- 
GitLab