Commit c0f7774d authored by vlorentz

Replace stop_on_eof boolean with EofBehavior, add support for restarting from the beginning

parent cb058d72
# Copyright (C) 2017-2022 The Software Heritage developers
# Copyright (C) 2017-2023 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
from collections import defaultdict
import enum
from importlib import import_module
from itertools import cycle
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import warnings
from confluent_kafka import Consumer, KafkaError, KafkaException
from confluent_kafka import (
OFFSET_BEGINNING,
Consumer,
KafkaError,
KafkaException,
TopicPartition,
)
from swh.core.statsd import statsd
from swh.core.statsd import Statsd
from swh.journal import DEFAULT_PREFIX
from .serializers import kafka_to_value
@@ -30,8 +38,13 @@ _SPAMMY_ERRORS = [
KafkaError._NO_OFFSET,
]
JOURNAL_MESSAGE_NUMBER_METRIC = "swh_journal_client_handle_message_total"
JOURNAL_STATUS_METRIC = "swh_journal_client_status"
class EofBehavior(enum.Enum):
"""Possible behaviors when reaching the end of the log"""
CONTINUE = "continue"
STOP = "stop"
RESTART = "restart"
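# Illustration: Enum lookup is by value, so EofBehavior("restart") resolves to
# EofBehavior.RESTART, passing a variant through EofBehavior(...) returns it
# unchanged, and an unknown string raises ValueError. This is what lets
# configuration files spell the behavior as a plain lowercase string.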
def get_journal_client(cls: str, **kwargs: Any):
@@ -92,35 +105,37 @@ class JournalClient:
`'swh.journal.objects'`.
Clients subscribe to events specific to each object type as listed in the
`object_types` argument (if unset, defaults to all existing kafka topic under
``object_types`` argument (if unset, defaults to all existing kafka topics under
the prefix).
Clients can be sharded by setting the `group_id` to a common
Clients can be sharded by setting the ``group_id`` to a common
value across instances. The journal will share the message
throughput across the nodes sharing the same group_id.
Messages are processed by the `worker_fn` callback passed to the `process`
method, in batches of maximum `batch_size` messages (defaults to 200).
The objects passed to the `worker_fn` callback are the result of the kafka
message converted by the `value_deserializer` function. By default (if this
argument is not given), it will produce dicts (using the `kafka_to_value`
function). This signature of the function is:
Messages are processed by the ``worker_fn`` callback passed to the `process`
method, in batches of maximum ``batch_size`` messages (defaults to 200).
`value_deserializer(object_type: str, kafka_msg: bytes) -> Any`
The objects passed to the ``worker_fn`` callback are the result of the kafka
message converted by the ``value_deserializer`` function. By default (if this
argument is not given), it will produce dicts (using the ``kafka_to_value``
function). The signature of the function is::
If the value returned by `value_deserializer` is None, it is ignored and
not passed the `worker_fn` function.
value_deserializer(object_type: str, kafka_msg: bytes) -> Any
If set, the processing stops after processing `stop_after_objects` messages
in total.
If the value returned by ``value_deserializer`` is None, it is ignored and
not passed to the ``worker_fn`` function.
`stop_on_eof` stops the processing when the client has reached the end of
each partition in turn.
`auto_offset_reset` sets the behavior of the client when the consumer group
initializes: `'earliest'` (the default) processes all objects since the
inception of the topics; `''`
Arguments:
stop_after_objects: If set, the processing stops after processing
this number of messages in total.
on_eof: What to do when reaching the end of each partition (keep consuming,
stop, or restart from earliest offsets); defaults to continuing.
This can be either an :class:`EofBehavior` variant or a string holding the
value of one of the variants (e.g. ``"restart"``).
stop_on_eof: (deprecated) equivalent to passing ``on_eof=EofBehavior.STOP``
auto_offset_reset: sets the behavior of the client when the consumer group
initializes: ``'earliest'`` (the default) processes all objects since the
inception of the topics; ``''``
Any other named argument is passed directly to KafkaConsumer().
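
Example (an illustrative sketch only; the broker address, group id and object
types below are placeholders)::

    client = JournalClient(
        brokers=["localhost:9092"],
        group_id="my-consumer-group",
        object_types=["revision"],
        on_eof="stop",  # equivalent to EofBehavior.STOP
    )

    def worker_fn(batch):
        # batch maps each object type to its list of deserialized messages
        for revision in batch.get("revision", []):
            print(revision["id"].hex())

    client.process(worker_fn)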
@@ -137,7 +152,8 @@ class JournalClient:
batch_size: int = 200,
process_timeout: Optional[float] = None,
auto_offset_reset: str = "earliest",
stop_on_eof: bool = False,
stop_on_eof: Optional[bool] = None,
on_eof: Optional[Union[EofBehavior, str]] = None,
value_deserializer: Optional[Callable[[str, bytes], Any]] = None,
**kwargs,
):
@@ -155,6 +171,29 @@ class JournalClient:
self.value_deserializer = value_deserializer
else:
self.value_deserializer = lambda _, value: kafka_to_value(value)
if stop_on_eof is not None:
if on_eof is not None:
raise TypeError(
"stop_on_eof and on_eof are mutually exclusive (the former is "
"deprecated)"
)
elif stop_on_eof:
warnings.warn(
"stop_on_eof=True should be replaced with "
"on_eof=EofBehavior.STOP ('on_eof: stop' in YAML)",
DeprecationWarning,
2,
)
on_eof = EofBehavior.STOP
else:
warnings.warn(
"stop_on_eof=False should be replaced with "
"on_eof=EofBehavior.CONTINUE ('on_eof: continue' in YAML)",
DeprecationWarning,
2,
)
on_eof = EofBehavior.CONTINUE
self.on_eof = EofBehavior(on_eof or EofBehavior.CONTINUE)
if isinstance(brokers, str):
brokers = [brokers]
@@ -201,8 +240,7 @@ class JournalClient:
"logger": rdkafka_logger,
}
self.stop_on_eof = stop_on_eof
if self.stop_on_eof:
if self.on_eof != EofBehavior.CONTINUE:
consumer_settings["enable.partition.eof"] = True
if logger.isEnabledFor(logging.DEBUG):
@@ -214,6 +252,8 @@ class JournalClient:
logger.debug(" %s: %s", k, v)
self.statsd = Statsd(namespace="swh_journal_client")
self.consumer = Consumer(consumer_settings)
if privileged:
privileged_prefix = f"{prefix}_privileged"
@@ -274,21 +314,19 @@ class JournalClient:
logger.debug(f"Subscribing to: {self.subscription}")
self.consumer.subscribe(topics=self.subscription)
def process(self, worker_fn):
def process(self, worker_fn: Callable[[Dict[str, List[dict]]], None]):
"""Polls Kafka for a batch of messages, and calls the worker_fn
with these messages.
Args:
worker_fn Callable[Dict[str, List[dict]]]: Function called with
the messages as
argument.
worker_fn: Function called with the messages as argument.
"""
total_objects_processed = 0
# timeout for message poll
timeout = 1.0
with statsd.status_gauge(
JOURNAL_STATUS_METRIC, statuses=["idle", "processing", "waiting"]
with self.statsd.status_gauge(
"status", statuses=["idle", "processing", "waiting"]
) as set_status:
set_status("idle")
while True:
@@ -313,30 +351,48 @@ class JournalClient:
# do check for an EOF condition iff we already consumed
# messages, otherwise we could detect an EOF condition
# before messages had a chance to reach us (e.g. in tests)
if total_objects_processed > 0 and self.stop_on_eof and i == 0:
at_eof = all(
(tp.topic, tp.partition) in self.eof_reached
for tp in self.consumer.assignment()
)
if at_eof:
break
if total_objects_processed > 0 and i == 0:
if self.on_eof == EofBehavior.STOP:
at_eof = all(
(tp.topic, tp.partition) in self.eof_reached
for tp in self.consumer.assignment()
)
if at_eof:
break
elif self.on_eof == EofBehavior.RESTART:
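# Replay: for every partition that has reached the end of the log, clear
# its EOF marker and seek it back to its earliest offset so the whole
# topic is consumed again instead of stopping or idling.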
for tp in self.consumer.assignment():
if (tp.topic, tp.partition) in self.eof_reached:
self.eof_reached.remove((tp.topic, tp.partition))
self.statsd.increment("partition_restart_total")
new_tp = TopicPartition(
tp.topic,
tp.partition,
OFFSET_BEGINNING,
)
self.consumer.seek(new_tp)
elif self.on_eof == EofBehavior.CONTINUE:
pass # Nothing to do, we'll just keep consuming
else:
assert False, f"Unexpected on_eof behavior: {self.on_eof}"
if messages:
set_status("processing")
batch_processed, at_eof = self.handle_messages(messages, worker_fn)
set_status("idle")
# report the number of handled messages
statsd.increment(
JOURNAL_MESSAGE_NUMBER_METRIC, value=batch_processed
)
self.statsd.increment("handle_message_total", value=batch_processed)
total_objects_processed += batch_processed
if at_eof:
if self.on_eof == EofBehavior.STOP and at_eof:
self.statsd.increment("stop_total")
break
return total_objects_processed
def handle_messages(self, messages, worker_fn):
def handle_messages(
self, messages, worker_fn: Callable[[Dict[str, List[dict]]], None]
) -> Tuple[int, bool]:
objects: Dict[str, List[Any]] = defaultdict(list)
nb_processed = 0
@@ -363,10 +419,15 @@ class JournalClient:
worker_fn(dict(objects))
self.consumer.commit()
at_eof = self.stop_on_eof and all(
(tp.topic, tp.partition) in self.eof_reached
for tp in self.consumer.assignment()
)
if self.on_eof in (EofBehavior.STOP, EofBehavior.RESTART):
at_eof = all(
(tp.topic, tp.partition) in self.eof_reached
for tp in self.consumer.assignment()
)
elif self.on_eof == EofBehavior.CONTINUE:
at_eof = False
else:
assert False, f"Unexpected on_eof behavior: {self.on_eof}"
return nb_processed, at_eof
......
# Copyright (C) 2019 The Software Heritage developers
# Copyright (C) 2019-2023 The Software Heritage developers
# See the AUTHORS file at the top-level directory of this distribution
# License: GNU General Public License version 3, or any later version
# See top-level LICENSE file for more information
import math
from typing import Dict, List, cast
from unittest.mock import MagicMock
from confluent_kafka import Producer
import pytest
from swh.journal.client import JournalClient
from swh.journal.client import EofBehavior, JournalClient
from swh.journal.serializers import kafka_to_value, key_to_kafka, value_to_kafka
from swh.model.model import Content, Revision
from swh.model.tests.swh_model_data import TEST_OBJECTS
@@ -40,7 +41,10 @@ REV = {
}
def test_client(kafka_prefix: str, kafka_consumer_group: str, kafka_server: str):
@pytest.mark.parametrize("legacy_eof", [True, False])
def test_client(
kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, legacy_eof: bool
):
producer = Producer(
{
"bootstrap.servers": kafka_server,
@@ -57,21 +61,106 @@ def test_client(kafka_prefix: str, kafka_consumer_group: str, kafka_server: str)
)
producer.flush()
client = JournalClient(
brokers=[kafka_server],
group_id=kafka_consumer_group,
prefix=kafka_prefix,
stop_on_eof=True,
)
if legacy_eof:
with pytest.deprecated_call():
client = JournalClient(
brokers=[kafka_server],
group_id=kafka_consumer_group,
prefix=kafka_prefix,
stop_on_eof=True,
)
else:
client = JournalClient(
brokers=[kafka_server],
group_id=kafka_consumer_group,
prefix=kafka_prefix,
on_eof=EofBehavior.STOP,
)
worker_fn = MagicMock()
client.process(worker_fn)
worker_fn.assert_called_once_with({"revision": [REV]})
@pytest.mark.parametrize("count", [1, 2])
@pytest.mark.parametrize(
"count,legacy_eof", [(1, True), (2, True), (1, False), (2, False)]
)
def test_client_stop_after_objects(
kafka_prefix: str, kafka_consumer_group: str, kafka_server: str, count: int
kafka_prefix: str,
kafka_consumer_group: str,
kafka_server: str,
count: int,
legacy_eof: bool,
):
producer = Producer(
{
"bootstrap.servers": kafka_server,
"client.id": "test producer",
"acks": "all",
}
)
# Fill Kafka
revisions = cast(List[Revision], TEST_OBJECTS["revision"])
for rev in revisions:
producer.produce(
topic=kafka_prefix + ".revision",
key=rev.id,
value=value_to_kafka(rev.to_dict()),
)
producer.flush()
if legacy_eof:
with pytest.deprecated_call():
client = JournalClient(
brokers=[kafka_server],
group_id=kafka_consumer_group,
prefix=kafka_prefix,
stop_on_eof=False,
stop_after_objects=count,
)
else:
client = JournalClient(
brokers=[kafka_server],
group_id=kafka_consumer_group,
prefix=kafka_prefix,
on_eof=EofBehavior.CONTINUE,
stop_after_objects=count,
)
worker_fn = MagicMock()
client.process(worker_fn)
# this code below is not pretty, but needed since we have to deal with
# dicts (so no set) which can have values that are list vs tuple, and we do
# not know for sure how many calls of the worker_fn will happen during the
# consumption of the topic...
worker_fn.assert_called()
revs = [] # list of (unique) rev dicts we got from the client
for call in worker_fn.call_args_list:
callrevs = call[0][0]["revision"]
for rev in callrevs:
assert Revision.from_dict(rev) in revisions
if rev not in revs:
revs.append(rev)
assert len(revs) == count
assert len(TEST_OBJECTS["revision"]) < 10, (
'test_client_restart_and_stop_after_objects expects TEST_OBJECTS["revision"] '
"to have less than 10 objects to test exhaustively"
)
@pytest.mark.parametrize(
"count,string_eof", [(1, True), (2, False), (10, True), (20, False)]
)
def test_client_restart_and_stop_after_objects(
kafka_prefix: str,
kafka_consumer_group: str,
kafka_server: str,
count: int,
string_eof: bool,
):
producer = Producer(
{
@@ -95,7 +184,7 @@ def test_client_stop_after_objects(
brokers=[kafka_server],
group_id=kafka_consumer_group,
prefix=kafka_prefix,
stop_on_eof=False,
on_eof="restart" if string_eof else EofBehavior.RESTART,
stop_after_objects=count,
)
@@ -107,14 +196,32 @@ def test_client_stop_after_objects(
# not know for sure how many calls of the worker_fn will happen during the
# consumption of the topic...
worker_fn.assert_called()
revs = [] # list of (unique) rev dicts we got from the client
revs = [] # list of (possibly duplicated) rev dicts we got from the client
unique_revs = [] # list of (unique) rev dicts we got from the client
for call in worker_fn.call_args_list:
callrevs = call[0][0]["revision"]
for rev in callrevs:
assert Revision.from_dict(rev) in revisions
if rev not in revs:
revs.append(rev)
unique_revs.append(rev)
revs.append(rev)
assert len(revs) == count
assert len(unique_revs) == min(len(revisions), count)
# Each revision should be seen approximately count/len(revisions) times
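# e.g. (hypothetical figures): with 7 test revisions and count=20, each id
# should appear between floor(20/7) = 2 and ceil(20/7) = 3 times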
rev_ids = [r["id"].hex() for r in revs] # type: ignore
for rev in revisions:
assert (
math.floor(count / len(revisions))
<= rev_ids.count(rev.id.hex())
<= math.ceil(count / len(revisions))
)
# Check that each complete pass (all but possibly the last, partial one) contains all revisions
for i in range(int(count / len(revisions))):
assert set(rev_ids[i * len(revisions) : (i + 1) * len(revisions)]) == set(
rev.id.hex() for rev in revisions
), i
@pytest.mark.parametrize("batch_size", [1, 5, 100])
......
@@ -6,7 +6,7 @@
from collections import OrderedDict
from datetime import datetime, timedelta, timezone
import itertools
from typing import Iterable
from typing import List
import pytest
@@ -49,10 +49,10 @@ def test_pprint_key():
assert pprinted_key == key.hex()
def test_kafka_to_key():
def test_kafka_to_key() -> None:
"""Standard back and forth serialization with keys"""
# All KeyType(s)
keys: Iterable[serializers.KeyType] = [
keys: List[serializers.KeyType] = [
{
"a": "foo",
"b": "bar",
......