Skip to content
Snippets Groups Projects
Commit 52ef52ea authored by vlorentz's avatar vlorentz
Browse files

Use attr instead of NamedTuple to generate SWHID.

As NamedTuple inherits from tuple, msgpack serializes it like a tuple,
which makes it indistinguishable from a tuple when deserializing,
which is an issue for the RPC API.
parent bea256e3
No related branches found
No related tags found
No related merge requests found
......@@ -8,8 +8,9 @@ import datetime
import hashlib
from functools import lru_cache
from typing import Any, Dict, NamedTuple, Union
from typing import Any, Dict, Union
import attr
from deprecated import deprecated
from .collections import ImmutableDict
......@@ -650,19 +651,8 @@ _object_type_map = {
}
_SWHID = NamedTuple(
"SWHID",
[
("namespace", str),
("scheme_version", int),
("object_type", str),
("object_id", str),
("metadata", ImmutableDict[str, Any]),
],
)
class SWHID(_SWHID):
@attr.s(frozen=True)
class SWHID:
"""
Named tuple holding the relevant info associated to a SoftWare Heritage
persistent IDentifier (SWHID)
......@@ -700,44 +690,38 @@ class SWHID(_SWHID):
# 'swh:1:cnt:8ff44f081d43176474b267de5451f2c2e88089d0'
"""
__slots__ = ()
def __new__(
cls,
namespace: str = SWHID_NAMESPACE,
scheme_version: int = SWHID_VERSION,
object_type: str = "",
object_id: str = "",
metadata: Union[ImmutableDict[str, Any], Dict[str, Any]] = ImmutableDict(),
):
o = _object_type_map.get(object_type)
if not o:
raise ValidationError(
"Wrong input: Supported types are %s" % (list(_object_type_map.keys()))
)
if namespace != SWHID_NAMESPACE:
namespace = attr.ib(type=str, default="swh")
scheme_version = attr.ib(type=int, default=1)
object_type = attr.ib(type=str, default="")
object_id = attr.ib(type=str, converter=hash_to_hex, default="") # type: ignore
metadata = attr.ib(
type=ImmutableDict[str, Any], converter=ImmutableDict, default=ImmutableDict()
)
@namespace.validator
def check_namespace(self, attribute, value):
if value != SWHID_NAMESPACE:
raise ValidationError(
"Wrong format: only supported namespace is '%s'" % SWHID_NAMESPACE
)
if scheme_version != SWHID_VERSION:
@scheme_version.validator
def check_scheme_version(self, attribute, value):
if value != SWHID_VERSION:
raise ValidationError(
"Wrong format: only supported version is %d" % SWHID_VERSION
)
# internal swh representation resolution
if isinstance(object_id, dict):
object_id = object_id[o["key_id"]]
validate_sha1(object_id) # can raise if invalid hash
object_id = hash_to_hex(object_id)
return super().__new__(
cls,
namespace,
scheme_version,
object_type,
object_id,
ImmutableDict(metadata),
)
@object_type.validator
def check_object_type(self, attribute, value):
if value not in _object_type_map:
raise ValidationError(
"Wrong input: Supported types are %s" % (list(_object_type_map.keys()))
)
@object_id.validator
def check_object_id(self, attribute, value):
validate_sha1(value) # can raise if invalid hash
def __str__(self) -> str:
o = _object_type_map.get(self.object_type)
......@@ -762,13 +746,12 @@ class PersistentId(SWHID):
"""
def __new__(cls, *args, **kwargs):
return super(cls, PersistentId).__new__(cls, *args, **kwargs)
pass
def swhid(
object_type: str,
object_id: str,
object_id: Union[str, Dict[str, Any]],
scheme_version: int = 1,
metadata: Union[ImmutableDict[str, Any], Dict[str, Any]] = ImmutableDict(),
) -> str:
......@@ -788,11 +771,14 @@ def swhid(
the SWHID of the object
"""
if isinstance(object_id, dict):
o = _object_type_map[object_type]
object_id = object_id[o["key_id"]]
swhid = SWHID(
scheme_version=scheme_version,
object_type=object_type,
object_id=object_id,
metadata=metadata,
metadata=metadata, # type: ignore # mypy can't properly unify types
)
return str(swhid)
......@@ -854,7 +840,13 @@ def parse_swhid(swhid: str) -> SWHID:
except Exception:
msg = "Contextual data is badly formatted, form key=val expected"
raise ValidationError(msg)
return SWHID(_ns, int(_version), _type, _id, _metadata)
return SWHID(
_ns,
int(_version),
_type,
_id,
_metadata, # type: ignore # mypy can't properly unify types
)
@deprecated("Use swh.model.identifiers.parse_swhid instead")
......@@ -864,4 +856,4 @@ def parse_persistent_identifier(persistent_id: str) -> PersistentId:
.. deprecated:: 0.3.8
Use :func:`swh.model.identifiers.parse_swhid` instead
"""
return PersistentId(**parse_swhid(persistent_id)._asdict())
return PersistentId(**attr.asdict(parse_swhid(persistent_id)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment