Source code for celery.worker.control

"""Worker remote control command implementations."""
import io
import tempfile
from collections import UserDict, namedtuple

from billiard.common import TERM_SIGNAME
from kombu.utils.encoding import safe_repr

from celery.exceptions import WorkerShutdown
from celery.platforms import signals as _signals
from celery.utils.functional import maybe_list
from celery.utils.log import get_logger
from celery.utils.serialization import jsonify, strtobool
from celery.utils.time import rate

from . import state as worker_state
from .request import Request

__all__ = ('Panel',)

DEFAULT_TASK_INFO_ITEMS = ('exchange', 'routing_key', 'rate_limit')
logger = get_logger(__name__)

controller_info_t = namedtuple('controller_info_t', [
    'alias', 'type', 'visible', 'default_timeout',
    'help', 'signature', 'args', 'variadic',
])


def ok(value):
    return {'ok': value}


def nok(value):
    return {'error': value}


[docs]class Panel(UserDict): """Global registry of remote control commands.""" data = {} # global dict. meta = {} # -"-
[docs] @classmethod def register(cls, *args, **kwargs): if args: return cls._register(**kwargs)(*args) return cls._register(**kwargs)
@classmethod def _register(cls, name=None, alias=None, type='control', visible=True, default_timeout=1.0, help=None, signature=None, args=None, variadic=None): def _inner(fun): control_name = name or fun.__name__ _help = help or (fun.__doc__ or '').strip().split('\n')[0] cls.data[control_name] = fun cls.meta[control_name] = controller_info_t( alias, type, visible, default_timeout, _help, signature, args, variadic) if alias: cls.data[alias] = fun return fun return _inner
def control_command(**kwargs): return Panel.register(type='control', **kwargs) def inspect_command(**kwargs): return Panel.register(type='inspect', **kwargs) # -- App @inspect_command() def report(state): """Information about Celery installation for bug reports.""" return ok(state.app.bugreport()) @inspect_command( alias='dump_conf', # XXX < backwards compatible signature='[include_defaults=False]', args=[('with_defaults', strtobool)], ) def conf(state, with_defaults=False, **kwargs): """List configuration.""" return jsonify(state.app.conf.table(with_defaults=with_defaults), keyfilter=_wanted_config_key, unknown_type_filter=safe_repr) def _wanted_config_key(key): return isinstance(key, str) and not key.startswith('__') # -- Task @inspect_command( variadic='ids', signature='[id1 [id2 [... [idN]]]]', ) def query_task(state, ids, **kwargs): """Query for task information by id.""" return { req.id: (_state_of_task(req), req.info()) for req in _find_requests_by_id(maybe_list(ids)) } def _find_requests_by_id(ids, get_request=worker_state.requests.__getitem__): for task_id in ids: try: yield get_request(task_id) except KeyError: pass def _state_of_task(request, is_active=worker_state.active_requests.__contains__, is_reserved=worker_state.reserved_requests.__contains__): if is_active(request): return 'active' elif is_reserved(request): return 'reserved' return 'ready' @control_command( variadic='task_id', signature='[id1 [id2 [... [idN]]]]', ) def revoke(state, task_id, terminate=False, signal=None, **kwargs): """Revoke task by task id (or list of ids). Keyword Arguments: terminate (bool): Also terminate the process if the task is active. signal (str): Name of signal to use for terminate (e.g., ``KILL``). """ # pylint: disable=redefined-outer-name # XXX Note that this redefines `terminate`: # Outside of this scope that is a function. # supports list argument since 3.1 task_ids, task_id = set(maybe_list(task_id) or []), None size = len(task_ids) terminated = set() worker_state.revoked.update(task_ids) if terminate: signum = _signals.signum(signal or TERM_SIGNAME) for request in _find_requests_by_id(task_ids): if request.id not in terminated: terminated.add(request.id) logger.info('Terminating %s (%s)', request.id, signum) request.terminate(state.consumer.pool, signal=signum) if len(terminated) >= size: break if not terminated: return ok('terminate: tasks unknown') return ok('terminate: {}'.format(', '.join(terminated))) idstr = ', '.join(task_ids) logger.info('Tasks flagged as revoked: %s', idstr) return ok(f'tasks {idstr} flagged as revoked') @control_command( variadic='task_id', args=[('signal', str)], signature='<signal> [id1 [id2 [... [idN]]]]' ) def terminate(state, signal, task_id, **kwargs): """Terminate task by task id (or list of ids).""" return revoke(state, task_id, terminate=True, signal=signal) @control_command( args=[('task_name', str), ('rate_limit', str)], signature='<task_name> <rate_limit (e.g., 5/s | 5/m | 5/h)>', ) def rate_limit(state, task_name, rate_limit, **kwargs): """Tell worker(s) to modify the rate limit for a task by type. See Also: :attr:`celery.task.base.Task.rate_limit`. Arguments: task_name (str): Type of task to set rate limit for. rate_limit (int, str): New rate limit. """ # pylint: disable=redefined-outer-name # XXX Note that this redefines `terminate`: # Outside of this scope that is a function. try: rate(rate_limit) except ValueError as exc: return nok(f'Invalid rate limit string: {exc!r}') try: state.app.tasks[task_name].rate_limit = rate_limit except KeyError: logger.error('Rate limit attempt for unknown task %s', task_name, exc_info=True) return nok('unknown task') state.consumer.reset_rate_limits() if not rate_limit: logger.info('Rate limits disabled for tasks of type %s', task_name) return ok('rate limit disabled successfully') logger.info('New rate limit for tasks of type %s: %s.', task_name, rate_limit) return ok('new rate limit set successfully') @control_command( args=[('task_name', str), ('soft', float), ('hard', float)], signature='<task_name> <soft_secs> [hard_secs]', ) def time_limit(state, task_name=None, hard=None, soft=None, **kwargs): """Tell worker(s) to modify the time limit for task by type. Arguments: task_name (str): Name of task to change. hard (float): Hard time limit. soft (float): Soft time limit. """ try: task = state.app.tasks[task_name] except KeyError: logger.error('Change time limit attempt for unknown task %s', task_name, exc_info=True) return nok('unknown task') task.soft_time_limit = soft task.time_limit = hard logger.info('New time limits for tasks of type %s: soft=%s hard=%s', task_name, soft, hard) return ok('time limits set successfully') # -- Events @inspect_command() def clock(state, **kwargs): """Get current logical clock value.""" return {'clock': state.app.clock.value} @control_command() def election(state, id, topic, action=None, **kwargs): """Hold election. Arguments: id (str): Unique election id. topic (str): Election topic. action (str): Action to take for elected actor. """ if state.consumer.gossip: state.consumer.gossip.election(id, topic, action) @control_command() def enable_events(state): """Tell worker(s) to send task-related events.""" dispatcher = state.consumer.event_dispatcher if dispatcher.groups and 'task' not in dispatcher.groups: dispatcher.groups.add('task') logger.info('Events of group {task} enabled by remote.') return ok('task events enabled') return ok('task events already enabled') @control_command() def disable_events(state): """Tell worker(s) to stop sending task-related events.""" dispatcher = state.consumer.event_dispatcher if 'task' in dispatcher.groups: dispatcher.groups.discard('task') logger.info('Events of group {task} disabled by remote.') return ok('task events disabled') return ok('task events already disabled') @control_command() def heartbeat(state): """Tell worker(s) to send event heartbeat immediately.""" logger.debug('Heartbeat requested by remote.') dispatcher = state.consumer.event_dispatcher dispatcher.send('worker-heartbeat', freq=5, **worker_state.SOFTWARE_INFO) # -- Worker @inspect_command(visible=False) def hello(state, from_node, revoked=None, **kwargs): """Request mingle sync-data.""" # pylint: disable=redefined-outer-name # XXX Note that this redefines `revoked`: # Outside of this scope that is a function. if from_node != state.hostname: logger.info('sync with %s', from_node) if revoked: worker_state.revoked.update(revoked) return { 'revoked': worker_state.revoked._data, 'clock': state.app.clock.forward(), } @inspect_command(default_timeout=0.2) def ping(state, **kwargs): """Ping worker(s).""" return ok('pong') @inspect_command() def stats(state, **kwargs): """Request worker statistics/information.""" return state.consumer.controller.stats() @inspect_command(alias='dump_schedule') def scheduled(state, **kwargs): """List of currently scheduled ETA/countdown tasks.""" return list(_iter_schedule_requests(state.consumer.timer)) def _iter_schedule_requests(timer): for waiting in timer.schedule.queue: try: arg0 = waiting.entry.args[0] except (IndexError, TypeError): continue else: if isinstance(arg0, Request): yield { 'eta': arg0.eta.isoformat() if arg0.eta else None, 'priority': waiting.priority, 'request': arg0.info(), } @inspect_command(alias='dump_reserved') def reserved(state, **kwargs): """List of currently reserved tasks, not including scheduled/active.""" reserved_tasks = ( state.tset(worker_state.reserved_requests) - state.tset(worker_state.active_requests) ) if not reserved_tasks: return [] return [request.info() for request in reserved_tasks] @inspect_command(alias='dump_active') def active(state, **kwargs): """List of tasks currently being executed.""" return [request.info() for request in state.tset(worker_state.active_requests)] @inspect_command(alias='dump_revoked') def revoked(state, **kwargs): """List of revoked task-ids.""" return list(worker_state.revoked) @inspect_command( alias='dump_tasks', variadic='taskinfoitems', signature='[attr1 [attr2 [... [attrN]]]]', ) def registered(state, taskinfoitems=None, builtins=False, **kwargs): """List of registered tasks. Arguments: taskinfoitems (Sequence[str]): List of task attributes to include. Defaults to ``exchange,routing_key,rate_limit``. builtins (bool): Also include built-in tasks. """ reg = state.app.tasks taskinfoitems = taskinfoitems or DEFAULT_TASK_INFO_ITEMS tasks = reg if builtins else ( task for task in reg if not task.startswith('celery.')) def _extract_info(task): fields = { field: str(getattr(task, field, None)) for field in taskinfoitems if getattr(task, field, None) is not None } if fields: info = ['='.join(f) for f in fields.items()] return '{} [{}]'.format(task.name, ' '.join(info)) return task.name return [_extract_info(reg[task]) for task in sorted(tasks)] # -- Debugging @inspect_command( default_timeout=60.0, args=[('type', str), ('num', int), ('max_depth', int)], signature='[object_type=Request] [num=200 [max_depth=10]]', ) def objgraph(state, num=200, max_depth=10, type='Request'): # pragma: no cover """Create graph of uncollected objects (memory-leak debugging). Arguments: num (int): Max number of objects to graph. max_depth (int): Traverse at most n levels deep. type (str): Name of object to graph. Default is ``"Request"``. """ try: import objgraph as _objgraph except ImportError: raise ImportError('Requires the objgraph library') logger.info('Dumping graph for type %r', type) with tempfile.NamedTemporaryFile(prefix='cobjg', suffix='.png', delete=False) as fh: objects = _objgraph.by_type(type)[:num] _objgraph.show_backrefs( objects, max_depth=max_depth, highlight=lambda v: v in objects, filename=fh.name, ) return {'filename': fh.name} @inspect_command() def memsample(state, **kwargs): """Sample current RSS memory usage.""" from celery.utils.debug import sample_mem return sample_mem() @inspect_command( args=[('samples', int)], signature='[n_samples=10]', ) def memdump(state, samples=10, **kwargs): # pragma: no cover """Dump statistics of previous memsample requests.""" from celery.utils import debug out = io.StringIO() debug.memdump(file=out) return out.getvalue() # -- Pool @control_command( args=[('n', int)], signature='[N=1]', ) def pool_grow(state, n=1, **kwargs): """Grow pool by n processes/threads.""" if state.consumer.controller.autoscaler: return nok("pool_grow is not supported with autoscale. Adjust autoscale range instead.") else: state.consumer.pool.grow(n) state.consumer._update_prefetch_count(n) return ok('pool will grow') @control_command( args=[('n', int)], signature='[N=1]', ) def pool_shrink(state, n=1, **kwargs): """Shrink pool by n processes/threads.""" if state.consumer.controller.autoscaler: return nok("pool_shrink is not supported with autoscale. Adjust autoscale range instead.") else: state.consumer.pool.shrink(n) state.consumer._update_prefetch_count(-n) return ok('pool will shrink') @control_command() def pool_restart(state, modules=None, reload=False, reloader=None, **kwargs): """Restart execution pool.""" if state.app.conf.worker_pool_restarts: state.consumer.controller.reload(modules, reload, reloader=reloader) return ok('reload started') else: raise ValueError('Pool restarts not enabled') @control_command( args=[('max', int), ('min', int)], signature='[max [min]]', ) def autoscale(state, max=None, min=None): """Modify autoscale settings.""" autoscaler = state.consumer.controller.autoscaler if autoscaler: max_, min_ = autoscaler.update(max, min) return ok(f'autoscale now max={max_} min={min_}') raise ValueError('Autoscale not enabled') @control_command() def shutdown(state, msg='Got shutdown from remote', **kwargs): """Shutdown worker(s).""" logger.warning(msg) raise WorkerShutdown(msg) # -- Queues @control_command( args=[ ('queue', str), ('exchange', str), ('exchange_type', str), ('routing_key', str), ], signature='<queue> [exchange [type [routing_key]]]', ) def add_consumer(state, queue, exchange=None, exchange_type=None, routing_key=None, **options): """Tell worker(s) to consume from task queue by name.""" state.consumer.call_soon( state.consumer.add_task_queue, queue, exchange, exchange_type or 'direct', routing_key, **options) return ok(f'add consumer {queue}') @control_command( args=[('queue', str)], signature='<queue>', ) def cancel_consumer(state, queue, **_): """Tell worker(s) to stop consuming from task queue by name.""" state.consumer.call_soon( state.consumer.cancel_task_queue, queue, ) return ok(f'no longer consuming from {queue}') @inspect_command() def active_queues(state): """List the task queues a worker is currently consuming from.""" if state.consumer.task_consumer: return [dict(queue.as_dict(recurse=True)) for queue in state.consumer.task_consumer.queues] return []
Back to Top