diff --git a/swh/model/model.py b/swh/model/model.py index d570a552924f454e86a066c50975bd5262bba570..aebe4ad7d8366f9bc2bb77a616c2e0f74d6945e5 100644 --- a/swh/model/model.py +++ b/swh/model/model.py @@ -541,18 +541,22 @@ class Timestamp(BaseModel): object_type: Final = "timestamp" - seconds = attr.ib(type=int, validator=generic_type_validator) - microseconds = attr.ib(type=int, validator=generic_type_validator) + seconds = attr.ib(type=int) + microseconds = attr.ib(type=int) @seconds.validator def check_seconds(self, attribute, value): """Check that seconds fit in a 64-bits signed integer.""" + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) if not (-(2**63) <= value < 2**63): raise ValueError("Seconds must be a signed 64-bits integer.") @microseconds.validator def check_microseconds(self, attribute, value): """Checks that microseconds are positive and < 1000000.""" + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) if not (0 <= value < 10**6): raise ValueError("Microseconds must be in [0, 1000000[.") @@ -786,7 +790,7 @@ class OriginVisit(BaseModel): object_type: Final = "origin_visit" origin = attr.ib(type=str, validator=generic_type_validator) - date = attr.ib(type=datetime.datetime, validator=generic_type_validator) + date = attr.ib(type=datetime.datetime) type = attr.ib(type=str, validator=generic_type_validator) """Should not be set before calling 'origin_visit_add()'.""" visit = attr.ib(type=Optional[int], validator=generic_type_validator, default=None) @@ -794,6 +798,8 @@ class OriginVisit(BaseModel): @date.validator def check_date(self, attribute, value): """Checks the date has a timezone.""" + if value.__class__ is not datetime.datetime: + raise AttributeTypeError(value, attribute) if value is not None and value.tzinfo is None: raise ValueError("date must be a timezone-aware datetime.") @@ -818,7 +824,7 @@ class OriginVisitStatus(BaseModel): origin = attr.ib(type=str, validator=generic_type_validator) visit = attr.ib(type=int, validator=generic_type_validator) - date = attr.ib(type=datetime.datetime, validator=generic_type_validator) + date = attr.ib(type=datetime.datetime) status = attr.ib( type=str, validator=attr.validators.in_( @@ -840,6 +846,8 @@ class OriginVisitStatus(BaseModel): @date.validator def check_date(self, attribute, value): """Checks the date has a timezone.""" + if value.__class__ is not datetime.datetime: + raise AttributeTypeError(value, attribute) if value is not None and value.tzinfo is None: raise ValueError("date must be a timezone-aware datetime.") @@ -881,13 +889,15 @@ class SnapshotBranch(BaseModel): object_type: Final = "snapshot_branch" - target = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr) + target = attr.ib(type=bytes, repr=hash_repr) target_type = attr.ib(type=TargetType, validator=generic_type_validator) @target.validator def check_target(self, attribute, value): """Checks the target type is not an alias, checks the target is a valid sha1_git.""" + if value.__class__ is not bytes: + raise AttributeTypeError(value, attribute) if self.target_type != TargetType.ALIAS and self.target is not None: if len(value) != 20: raise ValueError("Wrong length for bytes identifier: %d" % len(value)) @@ -1139,7 +1149,7 @@ _DIR_ENTRY_TYPES = ["file", "dir", "rev"] class DirectoryEntry(BaseModel): object_type: Final = "directory_entry" - name = attr.ib(type=bytes, validator=generic_type_validator) + name = attr.ib(type=bytes) type = attr.ib(type=str, validator=attr.validators.in_(_DIR_ENTRY_TYPES)) target = attr.ib(type=Sha1Git, validator=generic_type_validator, repr=hash_repr) perms = attr.ib(type=int, validator=generic_type_validator, converter=int, repr=oct) @@ -1147,6 +1157,8 @@ class DirectoryEntry(BaseModel): @name.validator def check_name(self, attribute, value): + if value.__class__ is not bytes: + raise AttributeTypeError(value, attribute) if b"/" in value: raise ValueError(f"{value!r} is not a valid directory entry name.") @@ -1317,7 +1329,7 @@ class Content(BaseContent): sha256 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr) blake2s256 = attr.ib(type=bytes, validator=generic_type_validator, repr=hash_repr) - length = attr.ib(type=int, validator=generic_type_validator) + length = attr.ib(type=int) status = attr.ib( type=str, @@ -1329,7 +1341,6 @@ class Content(BaseContent): ctime = attr.ib( type=Optional[datetime.datetime], - validator=generic_type_validator, default=None, eq=False, ) @@ -1337,14 +1348,19 @@ class Content(BaseContent): @length.validator def check_length(self, attribute, value): """Checks the length is positive.""" + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) if value < 0: raise ValueError("Length must be positive.") @ctime.validator def check_ctime(self, attribute, value): """Checks the ctime has a timezone.""" - if value is not None and value.tzinfo is None: - raise ValueError("ctime must be a timezone-aware datetime.") + if value is not None: + if value.__class__ is not datetime.datetime: + raise AttributeTypeError(value, attribute) + if value.tzinfo is None: + raise ValueError("ctime must be a timezone-aware datetime.") def to_dict(self): content = super().to_dict() @@ -1408,10 +1424,10 @@ class SkippedContent(BaseContent): type=Optional[bytes], validator=generic_type_validator, repr=hash_repr ) - length = attr.ib(type=Optional[int], validator=generic_type_validator) + length = attr.ib(type=Optional[int]) status = attr.ib(type=str, validator=attr.validators.in_(["absent"])) - reason = attr.ib(type=Optional[str], validator=generic_type_validator, default=None) + reason = attr.ib(type=Optional[str], default=None) origin = attr.ib(type=Optional[str], validator=generic_type_validator, default=None) @@ -1428,18 +1444,25 @@ class SkippedContent(BaseContent): assert self.reason == value if value is None: raise ValueError("Must provide a reason if content is absent.") + elif value.__class__ is not str: + raise AttributeTypeError(value, attribute) @length.validator def check_length(self, attribute, value): """Checks the length is positive or -1.""" - if value < -1: + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) + elif value < -1: raise ValueError("Length must be positive or -1.") @ctime.validator def check_ctime(self, attribute, value): """Checks the ctime has a timezone.""" - if value is not None and value.tzinfo is None: - raise ValueError("ctime must be a timezone-aware datetime.") + if value is not None: + if value.__class__ is not datetime.datetime: + raise AttributeTypeError(value, attribute) + elif value.tzinfo is None: + raise ValueError("ctime must be a timezone-aware datetime.") def to_dict(self): content = super().to_dict() @@ -1577,20 +1600,12 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): # context origin = attr.ib(type=Optional[str], default=None, validator=generic_type_validator) - visit = attr.ib(type=Optional[int], default=None, validator=generic_type_validator) - snapshot = attr.ib( - type=Optional[CoreSWHID], default=None, validator=generic_type_validator - ) - release = attr.ib( - type=Optional[CoreSWHID], default=None, validator=generic_type_validator - ) - revision = attr.ib( - type=Optional[CoreSWHID], default=None, validator=generic_type_validator - ) - path = attr.ib(type=Optional[bytes], default=None, validator=generic_type_validator) - directory = attr.ib( - type=Optional[CoreSWHID], default=None, validator=generic_type_validator - ) + visit = attr.ib(type=Optional[int], default=None) + snapshot = attr.ib(type=Optional[CoreSWHID], default=None) + release = attr.ib(type=Optional[CoreSWHID], default=None) + revision = attr.ib(type=Optional[CoreSWHID], default=None) + path = attr.ib(type=Optional[bytes], default=None) + directory = attr.ib(type=Optional[CoreSWHID], default=None) id = attr.ib( type=Sha1Git, validator=generic_type_validator, default=b"", repr=hash_repr @@ -1606,12 +1621,15 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): if value is None: return - if self.target.object_type not in ( - SwhidExtendedObjectType.SNAPSHOT, - SwhidExtendedObjectType.RELEASE, - SwhidExtendedObjectType.REVISION, - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not str: + raise AttributeTypeError(value, attribute) + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.SNAPSHOT + or obj_type is SwhidExtendedObjectType.RELEASE + or obj_type is SwhidExtendedObjectType.REVISION + or obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'origin' context for " @@ -1630,13 +1648,16 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): def check_visit(self, attribute, value): if value is None: return - - if self.target.object_type not in ( - SwhidExtendedObjectType.SNAPSHOT, - SwhidExtendedObjectType.RELEASE, - SwhidExtendedObjectType.REVISION, - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not int: + raise AttributeTypeError(value, attribute) + + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.SNAPSHOT + or obj_type is SwhidExtendedObjectType.RELEASE + or obj_type is SwhidExtendedObjectType.REVISION + or obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'visit' context for " @@ -1653,12 +1674,15 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): def check_snapshot(self, attribute, value): if value is None: return - - if self.target.object_type not in ( - SwhidExtendedObjectType.RELEASE, - SwhidExtendedObjectType.REVISION, - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not CoreSWHID: + raise AttributeTypeError(value, attribute) + + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.RELEASE + or obj_type is SwhidExtendedObjectType.REVISION + or obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'snapshot' context for " @@ -1671,11 +1695,14 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): def check_release(self, attribute, value): if value is None: return - - if self.target.object_type not in ( - SwhidExtendedObjectType.REVISION, - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not CoreSWHID: + raise AttributeTypeError(value, attribute) + + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.REVISION + or obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'release' context for " @@ -1689,9 +1716,13 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): if value is None: return - if self.target.object_type not in ( - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not CoreSWHID: + raise AttributeTypeError(value, attribute) + + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'revision' context for " @@ -1705,9 +1736,13 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): if value is None: return - if self.target.object_type not in ( - SwhidExtendedObjectType.DIRECTORY, - SwhidExtendedObjectType.CONTENT, + if value.__class__ is not bytes: + raise AttributeTypeError(value, attribute) + + obj_type = self.target.object_type + if not ( + obj_type is SwhidExtendedObjectType.DIRECTORY + or obj_type is SwhidExtendedObjectType.CONTENT ): raise ValueError( f"Unexpected 'path' context for " @@ -1719,7 +1754,10 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): if value is None: return - if self.target.object_type not in (SwhidExtendedObjectType.CONTENT,): + if value.__class__ is not CoreSWHID: + raise AttributeTypeError(value, attribute) + + if self.target.object_type is not SwhidExtendedObjectType.CONTENT: raise ValueError( f"Unexpected 'directory' context for " f"{self.target.object_type.name.lower()} object: {value}" @@ -1728,8 +1766,8 @@ class RawExtrinsicMetadata(HashableObject, BaseModel): self._check_swhid(SwhidObjectType.DIRECTORY, value) def _check_swhid(self, expected_object_type, swhid): - if isinstance(swhid, str): - raise ValueError(f"Expected SWHID, got a string: {swhid}") + if swhid.__class__ is not CoreSWHID: + raise ValueError(f"Expected SWHID, got a {swhid.__class__}: {swhid}") if swhid.object_type != expected_object_type: raise ValueError(