Skip to content
Snippets Groups Projects
Commit c674a974 authored by Nicolas Dandrimont's avatar Nicolas Dandrimont
Browse files

Implement a separate kafka communication thread for journal clients

This communication thread is in charge of pulling the messages from
kafka and handing them off to a processing thread, as well as doing
regular polling of the rdkafka client (which in turn notifies the
brokers that the consumer is still alive).

Doing this allows the kafka communication thread to pause the kafka
consumption explicitly when processing a batch of messages takes too
long. This can in turn avoid a lot of rebalance traffic on the kafka
brokers, and overall avoids a bunch of internal rdkafka timeouts.
parent 3b0eb69c
Branches poll-thread
No related tags found
No related merge requests found
......@@ -4,11 +4,14 @@
# See top-level LICENSE file for more information
from collections import defaultdict
import enum
from enum import Enum
from importlib import import_module
from itertools import cycle
import logging
import os
import queue
from threading import Thread
import time
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import warnings
......@@ -29,7 +32,6 @@ from .serializers import kafka_to_value
logger = logging.getLogger(__name__)
rdkafka_logger = logging.getLogger(__name__ + ".rdkafka")
# Only accepted offset reset policy accepted
ACCEPTED_OFFSET_RESET = ["earliest", "latest"]
......@@ -40,7 +42,7 @@ _SPAMMY_ERRORS = [
]
class EofBehavior(enum.Enum):
class EofBehavior(Enum):
"""Possible behaviors when reaching the end of the log"""
CONTINUE = "continue"
......@@ -48,6 +50,14 @@ class EofBehavior(enum.Enum):
RESTART = "restart"
class PollThreadState(Enum):
INITIALIZING = "INITIALIZING"
WAITING = "WAITING"
PROCESSING = "PROCESSING"
REBALANCING = "REBALANCING"
TERMINATING = "TERMINATING"
def get_journal_client(cls: str, **kwargs: Any):
"""Factory function to instantiate a journal client object.
......@@ -261,6 +271,11 @@ class JournalClient:
)
self.consumer = Consumer(consumer_settings)
self.poll_thread = Thread(target=self.kafka_poll_thread)
self.poll_thread_queue: queue.Queue[PollThreadState] = queue.Queue()
self.poll_thread.start()
if privileged:
privileged_prefix = f"{prefix}_privileged"
else: # do not attempt to subscribe to privileged topics
......@@ -331,13 +346,67 @@ class JournalClient:
"JournalClient; please remove it from your configuration.",
)
def kafka_poll_thread(self):
prev_state = PollThreadState.INITIALIZING
state = PollThreadState.INITIALIZING
last_state_change = time.monotonic()
paused_partitions = False
while True:
try:
new_state = self.poll_thread_queue.get(timeout=0.1)
except queue.Empty:
pass
else:
if new_state != state:
logger.debug("Poll thread now %s", new_state)
if state != PollThreadState.REBALANCING:
prev_state = state
state = new_state
last_state_change = time.monotonic()
now = time.monotonic()
if state == PollThreadState.INITIALIZING:
continue
elif state == PollThreadState.REBALANCING:
paused_partitions = True
self.poll_thread_queue.put(prev_state)
elif state == PollThreadState.WAITING:
if paused_partitions:
self.consumer.resume(self.consumer.assignment())
paused_partitions = False
elif state == PollThreadState.PROCESSING:
if not paused_partitions and now - last_state_change > 15:
self.consumer.pause(self.consumer.assignment())
paused_partitions = True
if paused_partitions:
msg = self.consumer.poll(timeout=10)
if not msg:
continue
error = msg.error()
if not error:
raise ValueError("poll thread got a non-error message?")
_error_cb(error)
elif state == PollThreadState.TERMINATING:
break
else:
logger.warning("Unknown poll_thread_state: %s", state)
def rebalance_cb(self, consumer, partitions):
consumer.pause(partitions)
self.poll_thread_queue.put(PollThreadState.REBALANCING)
def subscribe(self):
"""Subscribe to topics listed in self.subscription
This can be overridden if you need, for instance, to manually assign partitions.
"""
logger.debug(f"Subscribing to: {self.subscription}")
self.consumer.subscribe(topics=self.subscription)
self.consumer.subscribe(
topics=self.subscription,
on_assign=self.rebalance_cb,
on_revoke=self.rebalance_cb,
)
def process(self, worker_fn: Callable[[Dict[str, List[dict]]], None]):
"""Polls Kafka for a batch of messages, and calls the worker_fn
......@@ -366,6 +435,7 @@ class JournalClient:
batch_size,
)
set_status("waiting")
self.poll_thread_queue.put(PollThreadState.WAITING)
for i in cycle(reversed(range(10))):
messages = self.consumer.consume(
timeout=timeout, num_messages=batch_size
......@@ -402,9 +472,11 @@ class JournalClient:
if messages:
set_status("processing")
self.poll_thread_queue.put(PollThreadState.PROCESSING)
batch_processed, at_eof = self.handle_messages(messages, worker_fn)
set_status("idle")
self.poll_thread_queue.put(PollThreadState.WAITING)
# report the number of handled messages
self.statsd.increment("handle_message_total", value=batch_processed)
total_objects_processed += batch_processed
......@@ -460,4 +532,6 @@ class JournalClient:
return self.value_deserializer(object_type, message.value())
def close(self):
self.poll_thread_queue.put(PollThreadState.TERMINATING)
self.poll_thread.join()
self.consumer.close()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment