Skip to content
Snippets Groups Projects
Commit 292e5dcd authored by David Douard's avatar David Douard
Browse files

Kill the CustomCelery class

use functions instead of methods.

This is required to be able to use celery pytest fixtures so one can
really test celery tasks (especially when a task spawns sub tasks).

one (get_queue_lenth) of the 3 methods has been added as (monkeypatched)
method on the Celery class for the sake of bw compat, but this should really
be removed as well as soon as possible (seems only used in swh-archiver).
parent 96ad58b1
No related branches found
No related tags found
1 merge request!32Kill the CustomCelery class
......@@ -108,7 +108,7 @@ def setup_queues_and_tasks(sender, instance, **kwargs):
and obj != Task # Don't register the abstract class itself
):
class_name = '%s.%s' % (module_name, name)
instance.app.register_task_class(class_name, obj)
register_task_class(instance.app, class_name, obj)
for task_name in instance.app.tasks:
if task_name.startswith('swh.'):
......@@ -127,58 +127,59 @@ def route_for_task(name, args, kwargs, options, task=None, **kw):
return {'queue': name}
class CustomCelery(Celery):
def get_queue_stats(self, queue_name):
"""Get the statistics regarding a queue on the broker.
Arguments:
queue_name: name of the queue to check
Returns a dictionary raw from the RabbitMQ management API;
or `None` if the current configuration does not use RabbitMQ.
Interesting keys:
- consumers (number of consumers for the queue)
- messages (number of messages in queue)
- messages_unacknowledged (number of messages currently being
processed)
Documentation: https://www.rabbitmq.com/management.html#http-api
"""
conn_info = self.connection().info()
if conn_info['transport'] == 'memory':
# We're running in a test environment, without RabbitMQ.
return None
url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format(
hostname=conn_info['hostname'],
port=conn_info['port'] + 10000,
vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''),
queue=urllib.parse.quote(queue_name, safe=''),
)
credentials = (conn_info['userid'], conn_info['password'])
r = requests.get(url, auth=credentials)
if r.status_code == 404:
return {}
if r.status_code != 200:
raise ValueError('Got error %s when reading queue stats: %s' % (
r.status_code, r.json()))
return r.json()
def get_queue_length(self, queue_name):
"""Shortcut to get a queue's length"""
stats = self.get_queue_stats(queue_name)
if stats:
return stats.get('messages')
def register_task_class(self, name, cls):
"""Register a class-based task under the given name"""
if name in self.tasks:
return
task_instance = cls()
task_instance.name = name
self.register_task(task_instance)
def get_queue_stats(app, queue_name):
"""Get the statistics regarding a queue on the broker.
Arguments:
queue_name: name of the queue to check
Returns a dictionary raw from the RabbitMQ management API;
or `None` if the current configuration does not use RabbitMQ.
Interesting keys:
- consumers (number of consumers for the queue)
- messages (number of messages in queue)
- messages_unacknowledged (number of messages currently being
processed)
Documentation: https://www.rabbitmq.com/management.html#http-api
"""
conn_info = app.connection().info()
if conn_info['transport'] == 'memory':
# We're running in a test environment, without RabbitMQ.
return None
url = 'http://{hostname}:{port}/api/queues/{vhost}/{queue}'.format(
hostname=conn_info['hostname'],
port=conn_info['port'] + 10000,
vhost=urllib.parse.quote(conn_info['virtual_host'], safe=''),
queue=urllib.parse.quote(queue_name, safe=''),
)
credentials = (conn_info['userid'], conn_info['password'])
r = requests.get(url, auth=credentials)
if r.status_code == 404:
return {}
if r.status_code != 200:
raise ValueError('Got error %s when reading queue stats: %s' % (
r.status_code, r.json()))
return r.json()
def get_queue_length(app, queue_name):
"""Shortcut to get a queue's length"""
stats = get_queue_stats(app, queue_name)
if stats:
return stats.get('messages')
def register_task_class(app, name, cls):
"""Register a class-based task under the given name"""
if name in app.tasks:
return
task_instance = cls()
task_instance.name = name
app.register_task(task_instance)
INSTANCE_NAME = os.environ.get(CONFIG_NAME_ENVVAR)
......@@ -196,11 +197,7 @@ CELERY_QUEUES = [Queue('celery', Exchange('celery'), routing_key='celery')]
for queue in CONFIG['task_queues']:
CELERY_QUEUES.append(Queue(queue, Exchange(queue), routing_key=queue))
# Instantiate the Celery app
app = CustomCelery()
app.conf.update(
# The broker
broker_url=CONFIG['task_broker'],
CELERY_DEFAULT_CONFIG = dict(
# Timezone configuration: all in UTC
enable_utc=True,
timezone='UTC',
......@@ -254,4 +251,11 @@ app.conf.update(
worker_send_task_events=True,
# Do not send useless task_sent events
task_send_sent_event=False,
)
)
# Instantiate the Celery app
app = Celery(broker=CONFIG['task_broker'])
app.add_defaults(CELERY_DEFAULT_CONFIG)
# XXX for BW compat
Celery.get_queue_length = get_queue_length
......@@ -4,7 +4,7 @@ import datetime
from celery.result import AsyncResult
from celery.contrib.testing.worker import start_worker
import celery.contrib.testing.tasks # noqa
import celery.contrib.testing.tasks # noqa
import pytest
from swh.core.tests.db_testing import DbTestFixture, DB_DUMP_TYPES
......@@ -12,7 +12,7 @@ from swh.core.utils import numfile_sortkey as sortkey
from swh.scheduler import get_scheduler
from swh.scheduler.celery_backend.runner import run_ready_tasks
from swh.scheduler.celery_backend.config import app
from swh.scheduler.celery_backend.config import app, register_task_class
from swh.scheduler.tests.celery_testing import CeleryTestFixture
from . import SQL_DIR
......@@ -42,7 +42,7 @@ class SchedulerTestFixture(CeleryTestFixture, DbTestFixture):
}
self.scheduler.create_task_type(task_type)
if task_class:
app.register_task_class(backend_name, task_class)
register_task_class(app, backend_name, task_class)
def run_ready_tasks(self):
"""Runs the scheduler and a Celery worker, then blocks until
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment