Skip to content
Snippets Groups Projects
Commit 1e924e84 authored by vlorentz's avatar vlorentz
Browse files

cli: stop using the deprecated SWHID class

parent 8e011996
No related branches found
No related tags found
No related merge requests found
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
import os import os
import sys import sys
from typing import List from typing import Dict, List, Optional
# WARNING: do not import unnecessary things here to keep cli startup time under # WARNING: do not import unnecessary things here to keep cli startup time under
# control # control
import click import click
from swh.core.cli import swh as swh_cli_group from swh.core.cli import swh as swh_cli_group
from swh.model.identifiers import SWHID from swh.model.identifiers import CoreSWHID, ObjectType
CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])
...@@ -26,45 +26,48 @@ _DULWICH_TYPES = { ...@@ -26,45 +26,48 @@ _DULWICH_TYPES = {
} }
class SWHIDParamType(click.ParamType): class CoreSWHIDParamType(click.ParamType):
"""Click argument that accepts SWHID and return them as """Click argument that accepts a core SWHID and returns them as
:class:`swh.model.identifiers.SWHID` instances """ :class:`swh.model.identifiers.CoreSWHID` instances """
name = "SWHID" name = "SWHID"
def convert(self, value, param, ctx) -> SWHID: def convert(self, value, param, ctx) -> CoreSWHID:
from swh.model.exceptions import ValidationError from swh.model.exceptions import ValidationError
from swh.model.identifiers import parse_swhid
try: try:
return parse_swhid(value) return CoreSWHID.from_string(value)
except ValidationError as e: except ValidationError as e:
self.fail(f'"{value}" is not a valid SWHID: {e}', param, ctx) self.fail(f'"{value}" is not a valid core SWHID: {e}', param, ctx)
def swhid_of_file(path): def swhid_of_file(path) -> CoreSWHID:
from swh.model.from_disk import Content from swh.model.from_disk import Content
from swh.model.identifiers import CONTENT, swhid from swh.model.hashutil import hash_to_bytes
object = Content.from_file(path=path).get_data() object = Content.from_file(path=path).get_data()
return swhid(CONTENT, object) return CoreSWHID(
object_type=ObjectType.CONTENT, object_id=hash_to_bytes(object["sha1_git"])
)
def swhid_of_file_content(data): def swhid_of_file_content(data) -> CoreSWHID:
from swh.model.from_disk import Content from swh.model.from_disk import Content
from swh.model.identifiers import CONTENT, swhid from swh.model.hashutil import hash_to_bytes
object = Content.from_bytes(mode=644, data=data).get_data() object = Content.from_bytes(mode=644, data=data).get_data()
return swhid(CONTENT, object) return CoreSWHID(
object_type=ObjectType.CONTENT, object_id=hash_to_bytes(object["sha1_git"])
)
def swhid_of_dir(path: bytes, exclude_patterns: List[bytes] = None) -> str: def swhid_of_dir(path: bytes, exclude_patterns: List[bytes] = None) -> CoreSWHID:
from swh.model.from_disk import ( from swh.model.from_disk import (
Directory, Directory,
accept_all_directories, accept_all_directories,
ignore_directories_patterns, ignore_directories_patterns,
) )
from swh.model.identifiers import DIRECTORY, swhid from swh.model.hashutil import hash_to_bytes
dir_filter = ( dir_filter = (
ignore_directories_patterns(path, exclude_patterns) ignore_directories_patterns(path, exclude_patterns)
...@@ -73,24 +76,34 @@ def swhid_of_dir(path: bytes, exclude_patterns: List[bytes] = None) -> str: ...@@ -73,24 +76,34 @@ def swhid_of_dir(path: bytes, exclude_patterns: List[bytes] = None) -> str:
) )
object = Directory.from_disk(path=path, dir_filter=dir_filter).get_data() object = Directory.from_disk(path=path, dir_filter=dir_filter).get_data()
return swhid(DIRECTORY, object) return CoreSWHID(
object_type=ObjectType.DIRECTORY, object_id=hash_to_bytes(object["id"])
)
def swhid_of_origin(url): def swhid_of_origin(url):
from swh.model.identifiers import SWHID, origin_identifier from swh.model.hashutil import hash_to_bytes
from swh.model.identifiers import (
ExtendedObjectType,
ExtendedSWHID,
origin_identifier,
)
return str(SWHID(object_type="origin", object_id=origin_identifier({"url": url}))) return ExtendedSWHID(
object_type=ExtendedObjectType.ORIGIN,
object_id=hash_to_bytes(origin_identifier({"url": url})),
)
def swhid_of_git_repo(path): def swhid_of_git_repo(path) -> CoreSWHID:
import dulwich.repo import dulwich.repo
from swh.model import hashutil from swh.model import hashutil
from swh.model.identifiers import SWHID, snapshot_identifier from swh.model.identifiers import snapshot_identifier
repo = dulwich.repo.Repo(path) repo = dulwich.repo.Repo(path)
branches = {} branches: Dict[bytes, Optional[Dict]] = {}
for ref, target in repo.refs.as_dict().items(): for ref, target in repo.refs.as_dict().items():
obj = repo[target] obj = repo[target]
if obj: if obj:
...@@ -109,10 +122,13 @@ def swhid_of_git_repo(path): ...@@ -109,10 +122,13 @@ def swhid_of_git_repo(path):
snapshot = {"branches": branches} snapshot = {"branches": branches}
return str(SWHID(object_type="snapshot", object_id=snapshot_identifier(snapshot))) return CoreSWHID(
object_type=ObjectType.SNAPSHOT,
object_id=hashutil.hash_to_bytes(snapshot_identifier(snapshot)),
)
def identify_object(obj_type, follow_symlinks, exclude_patterns, obj): def identify_object(obj_type, follow_symlinks, exclude_patterns, obj) -> str:
from urllib.parse import urlparse from urllib.parse import urlparse
if obj_type == "auto": if obj_type == "auto":
...@@ -129,31 +145,29 @@ def identify_object(obj_type, follow_symlinks, exclude_patterns, obj): ...@@ -129,31 +145,29 @@ def identify_object(obj_type, follow_symlinks, exclude_patterns, obj):
except ValueError: except ValueError:
raise click.BadParameter("cannot detect object type for %s" % obj) raise click.BadParameter("cannot detect object type for %s" % obj)
swhid = None
if obj == "-": if obj == "-":
content = sys.stdin.buffer.read() content = sys.stdin.buffer.read()
swhid = swhid_of_file_content(content) swhid = str(swhid_of_file_content(content))
elif obj_type in ["content", "directory"]: elif obj_type in ["content", "directory"]:
path = obj.encode(sys.getfilesystemencoding()) path = obj.encode(sys.getfilesystemencoding())
if follow_symlinks and os.path.islink(obj): if follow_symlinks and os.path.islink(obj):
path = os.path.realpath(obj) path = os.path.realpath(obj)
if obj_type == "content": if obj_type == "content":
swhid = swhid_of_file(path) swhid = str(swhid_of_file(path))
elif obj_type == "directory": elif obj_type == "directory":
swhid = swhid_of_dir( swhid = str(
path, [pattern.encode() for pattern in exclude_patterns] swhid_of_dir(path, [pattern.encode() for pattern in exclude_patterns])
) )
elif obj_type == "origin": elif obj_type == "origin":
swhid = swhid_of_origin(obj) swhid = str(swhid_of_origin(obj))
elif obj_type == "snapshot": elif obj_type == "snapshot":
swhid = swhid_of_git_repo(obj) swhid = str(swhid_of_git_repo(obj))
else: # shouldn't happen, due to option validation else: # shouldn't happen, due to option validation
raise click.BadParameter("invalid object type: " + obj_type) raise click.BadParameter("invalid object type: " + obj_type)
# note: we return original obj instead of path here, to preserve user-given # note: we return original obj instead of path here, to preserve user-given
# file name in output # file name in output
return (obj, swhid) return swhid
@swh_cli_group.command(context_settings=CONTEXT_SETTINGS) @swh_cli_group.command(context_settings=CONTEXT_SETTINGS)
...@@ -191,7 +205,7 @@ def identify_object(obj_type, follow_symlinks, exclude_patterns, obj): ...@@ -191,7 +205,7 @@ def identify_object(obj_type, follow_symlinks, exclude_patterns, obj):
"--verify", "--verify",
"-v", "-v",
metavar="SWHID", metavar="SWHID",
type=SWHIDParamType(), type=CoreSWHIDParamType(),
help="reference identifier to be compared with computed one", help="reference identifier to be compared with computed one",
) )
@click.argument("objects", nargs=-1, required=True) @click.argument("objects", nargs=-1, required=True)
...@@ -232,8 +246,12 @@ def identify( ...@@ -232,8 +246,12 @@ def identify(
if verify and len(objects) != 1: if verify and len(objects) != 1:
raise click.BadParameter("verification requires a single object") raise click.BadParameter("verification requires a single object")
results = map( results = zip(
partial(identify_object, obj_type, follow_symlinks, exclude_patterns), objects, objects,
map(
partial(identify_object, obj_type, follow_symlinks, exclude_patterns),
objects,
),
) )
if verify: if verify:
......
...@@ -23,7 +23,7 @@ class TestIdentify(DataMixin, unittest.TestCase): ...@@ -23,7 +23,7 @@ class TestIdentify(DataMixin, unittest.TestCase):
self.runner = CliRunner() self.runner = CliRunner()
def assertSWHID(self, result, swhid): def assertSWHID(self, result, swhid):
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0, result.output)
self.assertEqual(result.output.split()[0], swhid) self.assertEqual(result.output.split()[0], swhid)
def test_no_args(self): def test_no_args(self):
...@@ -127,7 +127,7 @@ class TestIdentify(DataMixin, unittest.TestCase): ...@@ -127,7 +127,7 @@ class TestIdentify(DataMixin, unittest.TestCase):
def test_auto_origin(self): def test_auto_origin(self):
"""automatic object type detection: origin""" """automatic object type detection: origin"""
result = self.runner.invoke(cli.identify, ["https://github.com/torvalds/linux"]) result = self.runner.invoke(cli.identify, ["https://github.com/torvalds/linux"])
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0, result.output)
self.assertRegex(result.output, r"^swh:\d+:ori:") self.assertRegex(result.output, r"^swh:\d+:ori:")
def test_verify_content(self): def test_verify_content(self):
...@@ -139,7 +139,7 @@ class TestIdentify(DataMixin, unittest.TestCase): ...@@ -139,7 +139,7 @@ class TestIdentify(DataMixin, unittest.TestCase):
# match # match
path = os.path.join(self.tmpdir_name, filename) path = os.path.join(self.tmpdir_name, filename)
result = self.runner.invoke(cli.identify, ["--verify", expected_id, path]) result = self.runner.invoke(cli.identify, ["--verify", expected_id, path])
self.assertEqual(result.exit_code, 0) self.assertEqual(result.exit_code, 0, result.output)
# mismatch # mismatch
with open(path, "a") as f: with open(path, "a") as f:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment