Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Source

Select target project
No results found

Target

Select target project
  • anlambert/swh-model
  • lunar/swh-model
  • franckbret/swh-model
  • douardda/swh-model
  • olasd/swh-model
  • swh/devel/swh-model
  • Alphare/swh-model
  • samplet/swh-model
  • marmoute/swh-model
  • rboyer/swh-model
10 results
Show changes
Showing changes with 4822 additions and 999 deletions.
swh.core >= 0.3

aiohttp
click
dulwich

pytest >= 8.1
pytz
types-click
types-python-dateutil
types-pytz
types-deprecated

# Add here external Python modules dependencies, one per line. Module names
# should match https://pypi.python.org/pypi names. For the full spec or
# dependency lines, see https://pip.readthedocs.org/en/1.1/requirements.html
attrs != 21.1.0  # https://github.com/python-attrs/attrs/issues/804
attrs_strict >= 0.0.7
deprecated
hypothesis
iso8601
python-dateutil
typing_extensions
#!/usr/bin/env python3
# Copyright (C) 2015-2018 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

from setuptools import setup, find_packages

from os import path
from io import open

here = path.abspath(path.dirname(__file__))

# Get the long description from the README file
with open(path.join(here, 'README.md'), encoding='utf-8') as f:
    long_description = f.read()


def parse_requirements(name=None):
    if name:
        reqf = 'requirements-%s.txt' % name
    else:
        reqf = 'requirements.txt'

    requirements = []
    if not path.exists(reqf):
        return requirements

    with open(reqf) as f:
        for line in f.readlines():
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            requirements.append(line)
    return requirements


blake2_requirements = ['pyblake2;python_version<"3.6"']

setup(
    name='swh.model',
    description='Software Heritage data model',
    long_description=long_description,
    long_description_content_type='text/markdown',
    author='Software Heritage developers',
    author_email='swh-devel@inria.fr',
    url='https://forge.softwareheritage.org/diffusion/DMOD/',
    packages=find_packages(),
    setup_requires=['vcversioner'],
    install_requires=(parse_requirements() + parse_requirements('swh') +
                      blake2_requirements),
    extras_require={
        'cli': parse_requirements('cli'),
        'testing': parse_requirements('test'),
    },
    vcversioner={},
    include_package_data=True,
    entry_points='''
        [console_scripts]
        swh-identify=swh.model.cli:identify
        [swh.cli.subcommands]
        identify=swh.model.cli:identify
    ''',
    classifiers=[
        "Programming Language :: Python :: 3",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
        "Operating System :: OS Independent",
        "Development Status :: 5 - Production/Stable",
    ],
    project_urls={
        'Bug Reports': 'https://forge.softwareheritage.org/maniphest',
        'Funding': 'https://www.softwareheritage.org/donate',
        'Source': 'https://forge.softwareheritage.org/source/swh-model',
    },
)
from pkgutil import extend_path
from typing import Iterable
__path__ = extend_path(__path__, __name__) # type: Iterable[str]
# Copyright (C) 2018-2020 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

import os
import sys
from typing import Callable, Dict, Iterable, Optional

# WARNING: do not import unnecessary things here to keep cli startup time under
# control
try:
    import click
except ImportError:
    print(
        "Cannot run swh-identify; the Click package is not installed. "
        "Please install 'swh.model[cli]' for full functionality.",
        file=sys.stderr,
    )
    exit(1)

try:
    import swh.core.cli

    cli_command = swh.core.cli.swh.command
except ImportError:
    # stub so that swh-identify can be used when swh-core isn't installed
    cli_command = click.command

from swh.model.from_disk import Directory
from swh.model.swhids import CoreSWHID

CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])

# Mapping between dulwich types and Software Heritage ones. Used by snapshot ID
# computation.
_DULWICH_TYPES = {
    b"blob": "content",
    b"tree": "directory",
    b"commit": "revision",
    b"tag": "release",
}
class CoreSWHIDParamType(click.ParamType):
    """Click argument that accepts core SWHIDs and returns them as
    :class:`swh.model.swhids.CoreSWHID` instances"""

    name = "SWHID"

    def convert(self, value, param, ctx) -> CoreSWHID:
        from swh.model.exceptions import ValidationError

        try:
            return CoreSWHID.from_string(value)
        except ValidationError as e:
            self.fail(f'"{value}" is not a valid core SWHID: {e}', param, ctx)


def swhid_of_file(path) -> CoreSWHID:
    from swh.model.from_disk import Content

    object = Content.from_file(path=path)
    return object.swhid()


def swhid_of_file_content(data) -> CoreSWHID:
    from swh.model.from_disk import Content

    object = Content.from_bytes(mode=644, data=data)
    return object.swhid()
def model_of_dir(
    path: bytes,
    exclude_patterns: Optional[Iterable[bytes]] = None,
    update_info: Optional[Callable[[int], None]] = None,
) -> Directory:
    from swh.model.from_disk import accept_all_paths, ignore_directories_patterns

    path_filter = (
        ignore_directories_patterns(path, exclude_patterns)
        if exclude_patterns
        else accept_all_paths
    )
    return Directory.from_disk(
        path=path, path_filter=path_filter, progress_callback=update_info
    )


def swhid_of_dir(
    path: bytes, exclude_patterns: Optional[Iterable[bytes]] = None
) -> CoreSWHID:
    obj = model_of_dir(path, exclude_patterns)
    return obj.swhid()
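
For example, a sketch of identifying a directory tree while skipping VCS metadata (the path is illustrative; the module path swh.model.cli follows from the entry point declared in setup.py):

    from swh.model.cli import swhid_of_dir

    # exclude_patterns are glob patterns, passed as bytes like the path itself
    swhid = swhid_of_dir(b"/tmp/project", exclude_patterns=[b"*.git"])
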
def swhid_of_origin(url):
    from swh.model.model import Origin

    return Origin(url).swhid()


def swhid_of_git_repo(path) -> CoreSWHID:
    try:
        import dulwich.repo
    except ImportError:
        raise click.ClickException(
            "Cannot compute snapshot identifier; the Dulwich package is not installed. "
            "Please install 'swh.model[cli]' for full functionality.",
        )

    from swh.model import hashutil
    from swh.model.model import Snapshot

    repo = dulwich.repo.Repo(path)

    branches: Dict[bytes, Optional[Dict]] = {}
    for ref, target in repo.refs.as_dict().items():
        obj = repo[target]
        if obj:
            branches[ref] = {
                "target": hashutil.bytehex_to_hash(target),
                "target_type": _DULWICH_TYPES[obj.type_name],
            }
        else:
            branches[ref] = None

    for ref, target in repo.refs.get_symrefs().items():
        branches[ref] = {
            "target": target,
            "target_type": "alias",
        }

    snapshot = {"branches": branches}

    return Snapshot.from_dict(snapshot).swhid()
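
And a corresponding sketch for snapshots, mirroring the docstring example further down (the repository path is illustrative):

    from swh.model.cli import swhid_of_git_repo

    # works on a local (e.g. mirror) clone; requires dulwich,
    # i.e. the swh.model[cli] extra
    swhid = swhid_of_git_repo("helloworld.git")
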
def identify_object(
    obj_type: str, follow_symlinks: bool, exclude_patterns: Iterable[bytes], obj
) -> str:
    from urllib.parse import urlparse

    if obj_type == "auto":
        if obj == "-" or os.path.isfile(obj):
            obj_type = "content"
        elif os.path.isdir(obj):
            obj_type = "directory"
        else:
            try:  # URL parsing
                if urlparse(obj).scheme:
                    obj_type = "origin"
                else:
                    raise ValueError
            except ValueError:
                raise click.BadParameter("cannot detect object type for %s" % obj)

    if obj == "-":
        content = sys.stdin.buffer.read()
        swhid = str(swhid_of_file_content(content))
    elif obj_type in ["content", "directory"]:
        path = obj.encode(sys.getfilesystemencoding())
        if follow_symlinks and os.path.islink(obj):
            path = os.path.realpath(obj)
        if obj_type == "content":
            swhid = str(swhid_of_file(path))
        elif obj_type == "directory":
            swhid = str(swhid_of_dir(path, exclude_patterns))
    elif obj_type == "origin":
        swhid = str(swhid_of_origin(obj))
    elif obj_type == "snapshot":
        swhid = str(swhid_of_git_repo(obj))
    else:  # shouldn't happen, due to option validation
        raise click.BadParameter("invalid object type: " + obj_type)

    # note: we return original obj instead of path here, to preserve user-given
    # file name in output
    return swhid
@cli_command(context_settings=CONTEXT_SETTINGS)
@click.option(
    "--dereference/--no-dereference",
    "follow_symlinks",
    default=True,
    help="follow (or not) symlinks for OBJECTS passed as arguments "
    + "(default: follow)",
)
@click.option(
    "--filename/--no-filename",
    "show_filename",
    default=True,
    help="show/hide file name (default: show)",
)
@click.option(
    "--type",
    "-t",
    "obj_type",
    default="auto",
    type=click.Choice(["auto", "content", "directory", "origin", "snapshot"]),
    help="type of object to identify (default: auto)",
)
@click.option(
    "--exclude",
    "-x",
    "exclude_patterns",
    metavar="PATTERN",
    multiple=True,
    help="Exclude directories using glob patterns \
    (e.g., ``*.git`` to exclude all .git directories)",
)
@click.option(
    "--verify",
    "-v",
    metavar="SWHID",
    type=CoreSWHIDParamType(),
    help="reference identifier to be compared with computed one",
)
@click.option(
    "-r",
    "--recursive",
    is_flag=True,
    help="compute SWHID recursively",
)
@click.argument("objects", nargs=-1, required=True)
def identify(
    obj_type,
    verify,
    show_filename,
    follow_symlinks,
    objects,
    exclude_patterns,
    recursive,
):
    """Compute the Software Heritage persistent identifier (SWHID) for the given
    source code object(s).

    For more details about SWHIDs see:

    https://docs.softwareheritage.org/devel/swh-model/persistent-identifiers.html

    Tip: you can pass "-" to identify the content of standard input.

    Examples::

        $ swh identify fork.c kmod.c sched/deadline.c
        swh:1:cnt:2e391c754ae730bd2d8520c2ab497c403220c6e3    fork.c
        swh:1:cnt:0277d1216f80ae1adeed84a686ed34c9b2931fc2    kmod.c
        swh:1:cnt:57b939c81bce5d06fa587df8915f05affbe22b82    sched/deadline.c

        $ swh identify --no-filename /usr/src/linux/kernel/
        swh:1:dir:f9f858a48d663b3809c9e2f336412717496202ab

        $ git clone --mirror https://forge.softwareheritage.org/source/helloworld.git
        $ swh identify --type snapshot helloworld.git/
        swh:1:snp:510aa88bdc517345d258c1fc2babcd0e1f905e93    helloworld.git

    """
    from functools import partial
    import logging

    if exclude_patterns:
        exclude_patterns = set(pattern.encode() for pattern in exclude_patterns)

    if verify and len(objects) != 1:
        raise click.BadParameter("verification requires a single object")

    if recursive and not os.path.isdir(objects[0]):
        recursive = False
        logging.warning("recursive option disabled, input is not a directory object")

    if recursive:
        if verify:
            raise click.BadParameter(
                "verification of recursive object identification is not supported"
            )

        if obj_type not in ("auto", "directory"):
            raise click.BadParameter(
                "recursive identification is supported only for directories"
            )

        path = os.fsencode(objects[0])
        dir_obj = model_of_dir(path, exclude_patterns)
        for sub_obj in dir_obj.iter_tree():
            path_name = "path" if "path" in sub_obj.data.keys() else "data"
            path = os.fsdecode(sub_obj.data[path_name])
            swhid = str(sub_obj.swhid())
            msg = f"{swhid}\t{path}" if show_filename else f"{swhid}"
            click.echo(msg)
    else:
        results = zip(
            objects,
            map(
                partial(identify_object, obj_type, follow_symlinks, exclude_patterns),
                objects,
            ),
        )

        if verify:
            swhid = next(results)[1]
            if str(verify) == swhid:
                click.echo("SWHID match: %s" % swhid)
                sys.exit(0)
            else:
                click.echo("SWHID mismatch: %s != %s" % (verify, swhid))
                sys.exit(1)
        else:
            for obj, swhid in results:
                msg = swhid
                if show_filename:
                    msg = "%s\t%s" % (swhid, os.fsdecode(obj))
                click.echo(msg)


if __name__ == "__main__":
    identify()
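
A short usage sketch of the options added in this revision (paths and SWHIDs are illustrative; the output format matches the code above):

    $ swh identify --recursive --exclude '*.git' /tmp/project
    $ swh identify --verify swh:1:cnt:2e391c754ae730bd2d8520c2ab497c403220c6e3 fork.c
    SWHID match: swh:1:cnt:2e391c754ae730bd2d8520c2ab497c403220c6e3
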
# Copyright (C) 2020-2023 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

"""Utility data structures."""

from __future__ import annotations

from collections.abc import Mapping
import copy
from typing import Dict, Generic, Iterable, Optional, Tuple, TypeVar, Union

KT = TypeVar("KT")
VT = TypeVar("VT")


class ImmutableDict(Mapping, Generic[KT, VT]):
    """A frozen dictionary.

    This class behaves like a dictionary, but internally stores objects in a tuple,
    so it is both immutable and hashable."""

    _data: Dict[KT, VT]

    def __init__(
        self,
        data: Union[Iterable[Tuple[KT, VT]], ImmutableDict[KT, VT], Dict[KT, VT]] = {},
    ):
        if isinstance(data, dict):
            self._data = data
        elif isinstance(data, ImmutableDict):
            self._data = data._data
        else:
            self._data = {k: v for k, v in data}

    @property
    def data(self):
        return tuple(self._data.items())

    def __repr__(self):
        return f"ImmutableDict({dict(self.data)!r})"

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        for k, v in self.data:
            yield k

    def __len__(self):
        return len(self._data)

    def items(self):
        yield from self.data

    def __hash__(self):
        return hash(tuple(sorted(self.data)))

    def copy_pop(self, popped_key) -> Tuple[Optional[VT], ImmutableDict[KT, VT]]:
        """Returns a copy of this ImmutableDict without the given key,
        as well as the value associated to the key."""
        new_items = copy.deepcopy(self._data)
        popped_value: Optional[VT] = new_items.pop(popped_key, None)
        return (popped_value, ImmutableDict(new_items))
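
A quick usage sketch of ImmutableDict (the import path is an assumption, as this diff does not show the file name; values are made up):

    from swh.model.collections import ImmutableDict  # assumed module path

    d = ImmutableDict({"a": 1, "b": 2})
    d["a"]  # 1
    cache = {d: "hashable, so usable as a dict key"}
    value, rest = d.copy_pop("b")  # (2, ImmutableDict({'a': 1}))
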
# Copyright (C) 2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

"""Primitives for finding unknown content efficiently."""

from __future__ import annotations

from collections import namedtuple
import itertools
import logging
import random
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Mapping,
    NamedTuple,
    Optional,
    Set,
    Union,
)

from typing_extensions import Protocol, runtime_checkable

from .from_disk import model
from .model import Sha1Git

logger = logging.getLogger(__name__)

# Maximum amount when sampling from the undecided set of directory entries
SAMPLE_SIZE = 1000

# Sets of sha1 of contents, skipped contents and directories respectively
Sample: NamedTuple = namedtuple(
    "Sample", ["contents", "skipped_contents", "directories"]
)


@runtime_checkable
class ArchiveDiscoveryInterface(Protocol):
    """Interface used in discovery code to abstract over ways of connecting to
    the SWH archive (direct storage, web API, etc.) for all methods needed by
    discovery algorithms."""

    contents: List[model.Content]
    skipped_contents: List[model.SkippedContent]
    directories: List[model.Directory]

    def __init__(
        self,
        contents: List[model.Content],
        skipped_contents: List[model.SkippedContent],
        directories: List[model.Directory],
    ) -> None:
        self.contents = contents
        self.skipped_contents = skipped_contents
        self.directories = directories

    def content_missing(self, contents: List[Sha1Git]) -> Iterable[Sha1Git]:
        """List content missing from the archive by sha1"""

    def skipped_content_missing(
        self, skipped_contents: List[Sha1Git]
    ) -> Iterable[Sha1Git]:
        """List skipped content missing from the archive by sha1"""

    def directory_missing(self, directories: List[Sha1Git]) -> Iterable[Sha1Git]:
        """List directories missing from the archive by sha1"""


class BaseDiscoveryGraph:
    """Creates the base structures and methods needed for discovery algorithms.
    Subclasses should override ``get_sample`` to affect how the discovery is made.

    The `update_info_callback` is an optional argument that will get called for
    each new piece of information we get. The callback arguments are `(content,
    known)`.

    - content: the relevant model.Content object,
    - known: a boolean, True if the file is known to the archive, False otherwise.
    """

    def __init__(
        self,
        contents,
        skipped_contents,
        directories,
        update_info_callback: Optional[Callable[[Any, bool], None]] = None,
    ):
        self._all_contents: Mapping[
            Sha1Git, Union[model.Content, model.SkippedContent]
        ] = {}
        self._undecided_directories: Set[Sha1Git] = set()
        self._children: Mapping[Sha1Git, Set[Sha1Git]] = {}
        self._parents: Mapping[model.DirectoryEntry, Set[Any]] = {}
        self.undecided: Set[Sha1Git] = set()

        for content in itertools.chain(contents, skipped_contents):
            self.undecided.add(content.sha1_git)
            self._all_contents[content.sha1_git] = content

        for directory in directories:
            self.undecided.add(directory.id)
            self._undecided_directories.add(directory.id)
            self._children[directory.id] = {c.target for c in directory.entries}
            for child in directory.entries:
                self._parents.setdefault(child.target, set()).add(directory.id)

        self.undecided |= self._undecided_directories
        self.known: Set[Sha1Git] = set()
        self.unknown: Set[Sha1Git] = set()

        self._update_info_callback = update_info_callback
        self._sha1_to_obj = {}
        for content in itertools.chain(contents, skipped_contents):
            self._sha1_to_obj[content.sha1_git] = content
        for directory in directories:
            self._sha1_to_obj[directory.id] = directory

    def mark_known(self, entries: Iterable[Sha1Git]):
        """Mark ``entries`` and those they imply as known in the SWH archive"""
        self._mark_entries(entries, self._children, self.known)

    def mark_unknown(self, entries: Iterable[Sha1Git]):
        """Mark ``entries`` and those they imply as unknown in the SWH archive"""
        self._mark_entries(entries, self._parents, self.unknown)

    def _mark_entries(
        self,
        entries: Iterable[Sha1Git],
        transitive_mapping: Mapping[Any, Any],
        target_set: Set[Any],
    ):
        """Use Merkle graph properties to mark a directory entry as known or unknown.

        If an entry is known, then all of its descendants are known. If it's
        unknown, then all of its ancestors are unknown.

        - ``entries``: directory entries to mark along with their ancestors/descendants
          where applicable.
        - ``transitive_mapping``: mapping from an entry to the next entries to mark
          in the hierarchy, if any.
        - ``target_set``: set where marked entries will be added.
        """
        callback = self._update_info_callback
        to_process = set(entries)
        while to_process:
            current = to_process.pop()
            target_set.add(current)
            new = current in self.undecided
            self.undecided.discard(current)
            self._undecided_directories.discard(current)
            next_entries = transitive_mapping.get(current, set()) & self.undecided
            to_process.update(next_entries)
            if new and callback is not None:
                obj = self._sha1_to_obj[current]
                callback(obj, current in self.known)

    def get_sample(
        self,
    ) -> Sample:
        """Return a three-tuple of samples from the undecided sets of contents,
        skipped contents and directories respectively.

        These samples will be queried against the storage which will tell us
        which are known."""
        raise NotImplementedError()

    def do_query(self, archive: ArchiveDiscoveryInterface, sample: Sample) -> None:
        """Given a three-tuple of samples, ask the archive which are known or
        unknown and mark them as such."""
        methods = (
            archive.content_missing,
            archive.skipped_content_missing,
            archive.directory_missing,
        )

        for sample_per_type, method in zip(sample, methods):
            if not sample_per_type:
                continue
            known = set(sample_per_type)
            unknown = set(method(list(sample_per_type)))
            known -= unknown

            self.mark_known(known)
            self.mark_unknown(unknown)


class RandomDirSamplingDiscoveryGraph(BaseDiscoveryGraph):
    """Use a random sampling using only directories.

    This allows us to find a statistically good spread of entries in the graph
    with a smaller population than using all types of entries. When there are
    no more directories, only contents or skipped contents are undecided if any
    are left: we send them directly to the storage since they should be few and
    their structure flat."""

    def get_sample(self) -> Sample:
        if self._undecided_directories:
            if len(self._undecided_directories) <= SAMPLE_SIZE:
                return Sample(
                    contents=set(),
                    skipped_contents=set(),
                    directories=set(self._undecided_directories),
                )
            sample = random.sample(tuple(self._undecided_directories), SAMPLE_SIZE)
            directories = {o for o in sample}
            return Sample(
                contents=set(), skipped_contents=set(), directories=directories
            )

        contents = set()
        skipped_contents = set()

        for sha1 in self.undecided:
            obj = self._all_contents[sha1]
            obj_type = obj.object_type

            if obj_type == model.Content.object_type:
                contents.add(sha1)
            elif obj_type == model.SkippedContent.object_type:
                skipped_contents.add(sha1)
            else:
                raise TypeError(f"Unexpected object type {obj_type}")

        return Sample(
            contents=contents, skipped_contents=skipped_contents, directories=set()
        )


def filter_known_objects(
    archive: ArchiveDiscoveryInterface,
    update_info_callback: Optional[Callable[[Any, bool], None]] = None,
):
    """Filter ``archive``'s ``contents``, ``skipped_contents`` and ``directories``
    to only return those that are unknown to the SWH archive using a discovery
    algorithm.

    The `update_info_callback` is an optional argument that will get called for
    each new piece of information we get. The callback arguments are `(content,
    known)`.

    - content: the relevant model.Content object,
    - known: a boolean, True if the file is known to the archive, False otherwise.
    """
    contents = archive.contents
    skipped_contents = archive.skipped_contents
    directories = archive.directories

    contents_count = len(contents)
    skipped_contents_count = len(skipped_contents)
    directories_count = len(directories)

    graph = RandomDirSamplingDiscoveryGraph(
        contents,
        skipped_contents,
        directories,
        update_info_callback=update_info_callback,
    )

    while graph.undecided:
        sample = graph.get_sample()
        graph.do_query(archive, sample)

    contents = [c for c in contents if c.sha1_git in graph.unknown]
    skipped_contents = [c for c in skipped_contents if c.sha1_git in graph.unknown]
    directories = [c for c in directories if c.id in graph.unknown]

    logger.debug(
        "Filtered out %d contents, %d skipped contents and %d directories",
        contents_count - len(contents),
        skipped_contents_count - len(skipped_contents),
        directories_count - len(directories),
    )
    return (contents, skipped_contents, directories)
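
Given lists contents, skipped_contents, and directories of model objects, a minimal sketch of plugging a backend into these primitives (the module path swh.model.discovery is assumed, and the stub below simply reports everything as missing):

    from swh.model.discovery import ArchiveDiscoveryInterface, filter_known_objects

    class EmptyArchive(ArchiveDiscoveryInterface):
        # Stub backend: report every hash as missing from the archive.
        def content_missing(self, contents):
            return iter(contents)

        def skipped_content_missing(self, skipped_contents):
            return iter(skipped_contents)

        def directory_missing(self, directories):
            return iter(directories)

    archive = EmptyArchive(contents, skipped_contents, directories)
    # With nothing known, everything comes back as unknown.
    unknown_contents, unknown_skipped, unknown_dirs = filter_known_objects(archive)
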
@@ -33,11 +33,12 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

NON_FIELD_ERRORS = "__all__"


class ValidationError(Exception):
    """An error while validating data."""

    def __init__(self, message, code=None, params=None):
        """
        The `message` argument can be a single error, a list of errors, or a
@@ -54,16 +55,15 @@ class ValidationError(Exception):
            message = message[0]

        if isinstance(message, ValidationError):
            if hasattr(message, "error_dict"):
                message = message.error_dict
            # PY2 has a `message` property which is always there so we can't
            # duck-type on it. It was introduced in Python 2.5 and already
            # deprecated in Python 2.6.
            elif not hasattr(message, "message"):
                message = message.error_list
            else:
                message, code, params = (message.message, message.code, message.params)

        if isinstance(message, dict):
            self.error_dict = {}
@@ -78,9 +78,8 @@ class ValidationError(Exception):
                # Normalize plain strings to instances of ValidationError.
                if not isinstance(message, ValidationError):
                    message = ValidationError(message)

                if hasattr(message, "error_dict"):
                    self.error_list.extend(sum(message.error_dict.values(), []))
                else:
                    self.error_list.extend(message.error_list)
@@ -94,18 +93,18 @@ class ValidationError(Exception):
    def message_dict(self):
        # Trigger an AttributeError if this ValidationError
        # doesn't have an error_dict.
        getattr(self, "error_dict")
        return dict(self)

    @property
    def messages(self):
        if hasattr(self, "error_dict"):
            return sum(dict(self).values(), [])
        return list(self)

    def update_error_dict(self, error_dict):
        if hasattr(self, "error_dict"):
            for field, error_list in self.error_dict.items():
                error_dict.setdefault(field, []).extend(error_list)
        else:
@@ -113,7 +112,7 @@ class ValidationError(Exception):
        return error_dict

    def __iter__(self):
        if hasattr(self, "error_dict"):
            for field, errors in self.error_dict.items():
                yield field, list(ValidationError(errors))
        else:
@@ -124,9 +123,13 @@ class ValidationError(Exception):
                yield message

    def __str__(self):
        if hasattr(self, "error_dict"):
            return repr(dict(self))
        return repr(list(self))

    def __repr__(self):
        return "ValidationError(%s)" % self


class InvalidDirectoryPath(Exception):
    pass
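
A usage sketch of the structured error data this class carries (values invented; assumes the Django-style interpolation of params into messages that this class is modeled on):

    from swh.model.exceptions import ValidationError

    try:
        raise ValidationError(
            "Unexpected length %(length)d for hash type %(hash_type)s",
            params={"length": 19, "hash_type": "sha1"},
            code="unexpected-hash-length",
        )
    except ValidationError as e:
        print(e.messages)  # ['Unexpected length 19 for hash type sha1']
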
@@ -6,8 +6,13 @@
# We do our imports here but we don't use them, so flake8 complains
# flake8: noqa

from .compound import validate_against_schema, validate_all_keys, validate_any_key
from .hashes import validate_sha1, validate_sha1_git, validate_sha256
from .simple import (
    validate_bytes,
    validate_datetime,
    validate_enum,
    validate_int,
    validate_str,
    validate_type,
)
@@ -6,7 +6,7 @@
from collections import defaultdict
import itertools

from ..exceptions import NON_FIELD_ERRORS, ValidationError


def validate_against_schema(model, schema, value):
@@ -26,19 +26,19 @@ def validate_against_schema(model, schema, value):
    if not isinstance(value, dict):
        raise ValidationError(
            "Unexpected type %(type)s for %(model)s, expected dict",
            params={
                "model": model,
                "type": value.__class__.__name__,
            },
            code="model-unexpected-type",
        )

    errors = defaultdict(list)

    for key, (mandatory, validators) in itertools.chain(
        ((k, v) for k, v in schema.items() if k != NON_FIELD_ERRORS),
        [(NON_FIELD_ERRORS, (False, schema.get(NON_FIELD_ERRORS, [])))],
    ):
        if not validators:
            continue
@@ -54,9 +54,9 @@ def validate_against_schema(model, schema, value):
            if mandatory:
                errors[key].append(
                    ValidationError(
                        "Field %(field)s is mandatory",
                        params={"field": key},
                        code="model-field-mandatory",
                    )
                )
@@ -74,19 +74,21 @@ def validate_against_schema(model, schema, value):
            else:
                if not valid:
                    errdata = {
                        "validator": validator.__name__,
                    }

                    if key == NON_FIELD_ERRORS:
                        errmsg = (
                            "Validation of model %(model)s failed in %(validator)s"
                        )
                        errdata["model"] = model
                        errcode = "model-validation-failed"
                    else:
                        errmsg = (
                            "Validation of field %(field)s failed in %(validator)s"
                        )
                        errdata["field"] = key
                        errcode = "field-validation-failed"

                    errors[key].append(
                        ValidationError(errmsg, params=errdata, code=errcode)
@@ -102,11 +104,11 @@ def validate_all_keys(value, keys):
    """Validate that all the given keys are present in value"""
    missing_keys = set(keys) - set(value)
    if missing_keys:
        missing_fields = ", ".join(sorted(missing_keys))
        raise ValidationError(
            "Missing mandatory fields %(missing_fields)s",
            params={"missing_fields": missing_fields},
            code="missing-mandatory-field",
        )

    return True
@@ -116,11 +118,11 @@ def validate_any_key(value, keys):
    """Validate that any of the given keys is present in value"""
    present_keys = set(keys) & set(value)
    if not present_keys:
        missing_fields = ", ".join(sorted(keys))
        raise ValidationError(
            "Must contain one of the alternative fields %(missing_fields)s",
            params={"missing_fields": missing_fields},
            code="missing-alternative-field",
        )

    return True
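
A small sketch of the schema shape the loop in validate_against_schema reads (field names invented; the import path is an assumption, since this diff shows only relative imports within the package):

    from swh.model.fields import validate_against_schema, validate_int, validate_str

    # Each schema key maps to a (mandatory, validators) pair.
    schema = {
        "name": (True, [validate_str]),
        "size": (False, [validate_int]),
    }
    validate_against_schema("my_model", schema, {"name": "README", "size": 42})
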
@@ -4,6 +4,7 @@
# See top-level LICENSE file for more information

import string

from ..exceptions import ValidationError
@@ -22,22 +23,22 @@ def validate_hash(value, hash_type):
    """
    hash_lengths = {
        "sha1": 20,
        "sha1_git": 20,
        "sha256": 32,
    }

    hex_digits = set(string.hexdigits)

    if hash_type not in hash_lengths:
        raise ValidationError(
            "Unexpected hash type %(hash_type)s, expected one of %(hash_types)s",
            params={
                "hash_type": hash_type,
                "hash_types": ", ".join(sorted(hash_lengths)),
            },
            code="unexpected-hash-type",
        )

    if isinstance(value, str):
        errors = []
@@ -48,10 +49,10 @@ def validate_hash(value, hash_type):
                    "Unexpected characters `%(unexpected_chars)s' for hash "
                    "type %(hash_type)s",
                    params={
                        "unexpected_chars": ", ".join(sorted(extra_chars)),
                        "hash_type": hash_type,
                    },
                    code="unexpected-hash-contents",
                )
            )
@@ -60,14 +61,14 @@ def validate_hash(value, hash_type):
        if length != expected_length:
            errors.append(
                ValidationError(
                    "Unexpected length %(length)d for hash type "
                    "%(hash_type)s, expected %(expected_length)d",
                    params={
                        "length": length,
                        "expected_length": expected_length,
                        "hash_type": hash_type,
                    },
                    code="unexpected-hash-length",
                )
            )
@@ -81,37 +82,37 @@ def validate_hash(value, hash_type):
        expected_length = hash_lengths[hash_type]
        if length != expected_length:
            raise ValidationError(
                "Unexpected length %(length)d for hash type "
                "%(hash_type)s, expected %(expected_length)d",
                params={
                    "length": length,
                    "expected_length": expected_length,
                    "hash_type": hash_type,
                },
                code="unexpected-hash-length",
            )

        return True

    raise ValidationError(
        "Unexpected type %(type)s for hash, expected str or bytes",
        params={
            "type": value.__class__.__name__,
        },
        code="unexpected-hash-value-type",
    )


def validate_sha1(sha1):
    """Validate that sha1 is a valid sha1 hash"""
    return validate_hash(sha1, "sha1")


def validate_sha1_git(sha1_git):
    """Validate that sha1_git is a valid sha1_git hash"""
    return validate_hash(sha1_git, "sha1_git")


def validate_sha256(sha256):
    """Validate that sha256 is a valid sha256 hash"""
    return validate_hash(sha256, "sha256")
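
For example (the hex value is the well-known sha1 of the empty string; import path assumed as above):

    from swh.model.fields import validate_sha1

    validate_sha1("da39a3ee5e6b4b0d3255bfef95601890afd80709")  # returns True
    validate_sha1(b"\x00" * 20)  # a 20-byte raw digest also validates
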
@@ -13,16 +13,16 @@ def validate_type(value, type):
    """Validate that value is of the given type"""
    if not isinstance(value, type):
        if isinstance(type, tuple):
            typestr = "one of %s" % ", ".join(typ.__name__ for typ in type)
        else:
            typestr = type.__name__
        raise ValidationError(
            "Unexpected type %(type)s, expected %(expected_type)s",
            params={
                "type": value.__class__.__name__,
                "expected_type": typestr,
            },
            code="unexpected-type",
        )

    return True
@@ -54,10 +54,12 @@ def validate_datetime(value):
        errors.append(e)

    if isinstance(value, datetime.datetime) and value.tzinfo is None:
        errors.append(
            ValidationError(
                "Datetimes must be timezone-aware in swh",
                code="datetime-without-tzinfo",
            )
        )

    if errors:
        raise ValidationError(errors)
@@ -69,12 +71,12 @@ def validate_enum(value, expected_values):
    """Validate that value is contained in expected_values"""
    if value not in expected_values:
        raise ValidationError(
            "Unexpected value %(value)s, expected one of %(expected_values)s",
            params={
                "value": value,
                "expected_values": ", ".join(sorted(expected_values)),
            },
            code="unexpected-value",
        )

    return True
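
And for the simple validators (values made up; import path assumed as above):

    import datetime

    from swh.model.fields import validate_datetime, validate_enum

    validate_enum("sha1", {"sha1", "sha256"})  # returns True
    validate_datetime(datetime.datetime.now(tz=datetime.timezone.utc))  # returns True
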
This diff is collapsed.
@@ -54,15 +54,16 @@ Basic usage examples:
import binascii
import functools
import hashlib
from io import BytesIO
import os
from typing import Callable, Dict, Optional, Union

ALGORITHMS = set(
    ["sha1", "sha256", "sha1_git", "blake2s256", "blake2b512", "md5", "sha512"]
)
"""Hashing algorithms supported by this module"""

DEFAULT_ALGORITHMS = set(["sha1", "sha256", "sha1_git", "blake2s256"])
"""Algorithms computed by default when calling the functions from this module.

Subset of :const:`ALGORITHMS`.
@@ -71,7 +72,7 @@
HASH_BLOCK_SIZE = 32768
"""Block size for streaming hash computations made in this module"""

_blake2_hash_cache: Dict[str, Callable] = {}


class MultiHash:
@@ -87,12 +88,13 @@ class MultiHash:
    computed and returned.

    """

    def __init__(self, hash_names=DEFAULT_ALGORITHMS, length=None):
        self.state = {}
        self.track_length = False
        for name in hash_names:
            if name == "length":
                self.state["length"] = 0
                self.track_length = True
            else:
                self.state[name] = _new_hash(name, length)
@@ -116,7 +118,7 @@ class MultiHash:
    @classmethod
    def from_path(cls, path, hash_names=DEFAULT_ALGORITHMS):
        length = os.path.getsize(path)
        with open(path, "rb") as f:
            ret = cls.from_file(f, hash_names=hash_names, length=length)
        return ret
@@ -128,48 +130,45 @@ class MultiHash:
    def update(self, chunk):
        for name, h in self.state.items():
            if name == "length":
                continue
            h.update(chunk)
        if self.track_length:
            self.state["length"] += len(chunk)

    def digest(self):
        return {
            name: h.digest() if name != "length" else h
            for name, h in self.state.items()
        }

    def hexdigest(self):
        return {
            name: h.hexdigest() if name != "length" else h
            for name, h in self.state.items()
        }

    def bytehexdigest(self):
        return {
            name: hash_to_bytehex(h.digest()) if name != "length" else h
            for name, h in self.state.items()
        }

    def copy(self):
        copied_state = {
            name: h.copy() if name != "length" else h for name, h in self.state.items()
        }
        return self.from_state(copied_state, self.track_length)


def _new_blake2_hash(algo):
    """Initialize and return a blake2 hash object for the given algorithm."""
    if algo in _blake2_hash_cache:
        return _blake2_hash_cache[algo]()

    lalgo = algo.lower()
    if not lalgo.startswith("blake2"):
        raise ValueError("Algorithm %s is not a blake2 hash" % algo)

    blake_family = lalgo[:7]
@@ -178,27 +177,14 @@ def _new_blake2_hash(algo):
    try:
        digest_size, remainder = divmod(int(lalgo[7:]), 8)
    except ValueError:
        raise ValueError("Unknown digest size for algo %s" % algo) from None
    if remainder:
        raise ValueError(
            "Digest size for algorithm %s must be a multiple of 8" % algo
        )

    blake2 = getattr(hashlib, blake_family)
    _blake2_hash_cache[algo] = lambda: blake2(digest_size=digest_size)

    return _blake2_hash_cache[algo]()
@@ -208,18 +194,16 @@ def _new_hashlib_hash(algo):
    Handle the swh-specific names for the blake2-related algorithms
    """
    if algo.startswith("blake2"):
        return _new_blake2_hash(algo)
    else:
        return hashlib.new(algo)


def git_object_header(git_type: str, length: int) -> bytes:
    """Returns the header for a git object of the given type and length.

    The header of a git object consists of:
     - The type of the object (encoded in ASCII)
     - One ASCII space (\x20)
     - The length of the object (decimal encoded in ASCII)
@@ -234,15 +218,26 @@
    Returns:
        the git object header, as bytes
    """
    git_object_types = {
        "blob",
        "tree",
        "commit",
        "tag",
        "snapshot",
        "raw_extrinsic_metadata",
        "extid",
    }

    if git_type not in git_object_types:
        raise ValueError(
            "Unexpected git object type %s, expected one of %s"
            % (git_type, ", ".join(sorted(git_object_types)))
        )

    return ("%s %d\0" % (git_type, length)).encode("ascii")


def _new_hash(algo: str, length: Optional[int] = None):
    """Initialize a digest object (as returned by python's hashlib) for
    the requested algorithm. See the constant ALGORITHMS for the list
    of supported algorithms. If a git-specific hashing algorithm is
@@ -264,19 +259,22 @@ def _new_hash(algo, length=None):
    """
    if algo not in ALGORITHMS:
        raise ValueError(
            "Unexpected hashing algorithm %s, expected one of %s"
            % (algo, ", ".join(sorted(ALGORITHMS)))
        )

    if algo.endswith("_git"):
        if length is None:
            raise ValueError("Missing length for git hashing algorithm")
        base_algo = algo[:-4]
        h = _new_hashlib_hash(base_algo)
        h.update(git_object_header("blob", length))
        return h

    return _new_hashlib_hash(algo)


def hash_git_data(data, git_type, base_algo="sha1"):
    """Hash the given data as a git object of type git_type.

    Args:
@@ -289,21 +287,15 @@ def hash_git_data(data, git_type, base_algo='sha1'):
    Raises:
        ValueError if the git_type is unexpected.
    """
    h = _new_hashlib_hash(base_algo)
    h.update(git_object_header(git_type, len(data)))
    h.update(data)

    return h.digest()


@functools.lru_cache()
def hash_to_hex(hash: Union[str, bytes]) -> str:
    """Converts a hash (in hex or bytes form) to its hexadecimal ascii form

    Args:
@@ -315,11 +307,11 @@ def hash_to_hex(hash):
    """
    if isinstance(hash, str):
        return hash
    return binascii.hexlify(hash).decode("ascii")


@functools.lru_cache()
def hash_to_bytehex(hash: bytes) -> bytes:
    """Converts a hash to its hexadecimal bytes representation

    Args:
@@ -332,7 +324,7 @@
@functools.lru_cache()
def hash_to_bytes(hash: Union[str, bytes]) -> bytes:
    """Converts a hash (in hex or bytes form) to its raw bytes form

    Args:
@@ -348,7 +340,7 @@
@functools.lru_cache()
def bytehex_to_hash(hex: bytes) -> bytes:
    """Converts a hexadecimal bytes representation of a hash to that hash

    Args:
...
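
A short sketch of the resulting API, using only names shown above (digest values elided):

    from swh.model import hashutil

    # Stream-hash a blob with explicit algorithm names, tracking length too.
    h = hashutil.MultiHash(hash_names={"sha1", "sha256", "length"})
    h.update(b"hello world\n")
    digests = h.hexdigest()  # {'sha1': '...', 'sha256': '...', 'length': 12}

    # Hash data as a git blob, using git_object_header() under the hood.
    digest = hashutil.hash_git_data(b"hello world\n", "blob")
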
This diff is collapsed.
# Copyright (C) 2017-2022 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information

"""Merkle tree data structure"""

from __future__ import annotations

import abc
from typing import Any, Dict, Iterator, List, Set


class MerkleNode(dict, metaclass=abc.ABCMeta):
@@ -102,16 +39,18 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
    The collection of updated data from the tree is implemented through the
    :func:`collect` function and associated helpers.

    """

    __slots__ = ["parents", "data", "__hash", "collected"]

    data: Dict
    """data associated to the current node"""

    parents: List
    """known parents of the current node"""

    collected: bool
    """whether the current node has been collected"""

    def __init__(self, data=None):
        super().__init__()
@@ -120,6 +59,16 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
        self.__hash = None
        self.collected = False

    def __eq__(self, other):
        return (
            isinstance(other, MerkleNode)
            and super().__eq__(other)
            and self.data == other.data
        )

    def __ne__(self, other):
        return not self.__eq__(other)

    def invalidate_hash(self):
        """Invalidate the cached hash of the current node."""
        if not self.__hash:
@@ -130,7 +79,7 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
        for parent in self.parents:
            parent.invalidate_hash()

    def update_hash(self, *, force=False) -> Any:
        """Recursively compute the hash of the current node.

        Args:
@@ -150,20 +99,23 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
        return self.__hash

    @property
    def hash(self) -> Any:
        """The hash of the current node, as calculated by
        :func:`compute_hash`.
        """
        return self.update_hash()

    def __hash__(self):
        return hash(self.hash)

    @abc.abstractmethod
    def compute_hash(self) -> Any:
        """Compute the hash of the current node.

        The hash should depend on the data of the node, as well as on hashes
        of the children nodes.
        """
        raise NotImplementedError("Must implement compute_hash method")

    def __setitem__(self, name, new_child):
        """Add a child, invalidating the current hash"""
@@ -212,47 +164,24 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
        """
        return self.data

    def collect_node(self) -> Set[MerkleNode]:
        """Collect the current node if it has not been collected yet, for use by
        :func:`collect`."""
        if not self.collected:
            self.collected = True
            return {self}
        else:
            return set()

    def collect(self) -> Set[MerkleNode]:
        """Collect the added and modified nodes in the subtree rooted at `self`
        since the last collect operation.

        Returns:
            A :class:`set` of collected nodes
        """
        ret = self.collect_node()
        for child in self.values():
            ret.update(child.collect())
        return ret
@@ -266,23 +195,39 @@ class MerkleNode(dict, metaclass=abc.ABCMeta):
        for child in self.values():
            child.reset_collect()

    def iter_tree(self, dedup=True) -> Iterator[MerkleNode]:
        """Yields all children nodes, recursively. Common nodes are deduplicated
        by default (deduplication can be turned off by setting the given argument
        'dedup' to False).
        """
        yield from self._iter_tree(seen=set(), dedup=dedup)

    def _iter_tree(self, seen: Set[bytes], dedup) -> Iterator[MerkleNode]:
        if self.hash not in seen:
            if dedup:
                seen.add(self.hash)
            yield self
            for child in self.values():
                yield from child._iter_tree(seen=seen, dedup=dedup)


class MerkleLeaf(MerkleNode):
    """A leaf to a Merkle tree.

    A Merkle leaf is simply a Merkle node with children disabled.

    """

    __slots__: List[str] = []

    def __setitem__(self, name, child):
        raise ValueError("%s is a leaf" % self.__class__.__name__)

    def __getitem__(self, name):
        raise ValueError("%s is a leaf" % self.__class__.__name__)

    def __delitem__(self, name):
        raise ValueError("%s is a leaf" % self.__class__.__name__)

    def update(self, new_children):
        """Children update operation. Disabled for leaves."""
        raise ValueError("%s is a leaf" % self.__class__.__name__)
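
To make the abstract API concrete, a toy pair of subclasses (invented for illustration; this is not Software Heritage's real content or directory hashing):

    from swh.model import hashutil
    from swh.model.merkle import MerkleLeaf, MerkleNode

    class Blob(MerkleLeaf):
        def compute_hash(self):
            return hashutil.hash_git_data(self.data["data"], "blob")

    class Tree(MerkleNode):
        def compute_hash(self):
            # toy scheme: hash the concatenation of the sorted child hashes
            payload = b"".join(child.hash for _, child in sorted(self.items()))
            return hashutil.hash_git_data(payload, "tree")

    root = Tree()
    root[b"README"] = Blob({"data": b"hello\n"})
    print(root.hash.hex())
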
This diff is collapsed.
This diff is collapsed.