"""Integration testing utilities."""
import socket
import sys
from collections import defaultdict
from functools import partial
from itertools import count
from kombu.utils.functional import retry_over_time
from celery import states
from celery.exceptions import TimeoutError
from celery.result import ResultSet
from celery.utils.text import truncate
from celery.utils.time import humanize_seconds as _humanize_seconds
#: Format template for the "still waiting" warning built in
#: ``ManagerMixin.wait_for``: ``{0}`` receives the description, ``when`` the
#: humanized retry delay, and ``exc`` the exception that was caught.
E_STILL_WAITING = (
    'Still waiting for {0}. '
    'Trying again {when}: {exc!r}'
)

#: Celery's ``humanize_seconds`` with ``microseconds=True`` pre-bound.
humanize_seconds = partial(_humanize_seconds, microseconds=True)
class Sentinel(Exception):
    """Signifies the end of something.

    Carries no payload; it is raised purely as a control-flow marker
    (for example when a polled predicate is still falsy) so that retry
    helpers can catch it and try again.
    """
class ManagerMixin:
    """Mixin that adds :class:`Manager` capabilities.

    The host class must set ``self.app`` before calling
    :meth:`_init_manager`, which reads it.
    """

    def _init_manager(self,
                      block_timeout=30 * 60.0, no_join=False,
                      stdout=None, stderr=None):
        # type: (float, bool, TextIO, TextIO) -> None
        """Initialize manager state.

        Arguments:
            block_timeout (float): Stored as ``self.block_timeout``.
            no_join (bool): When true, :meth:`join` returns immediately.
            stdout: Stream used by :meth:`remark` (default ``sys.stdout``).
            stderr: Stream used by :meth:`warn` (default ``sys.stderr``).
        """
        self.stdout = sys.stdout if stdout is None else stdout
        self.stderr = sys.stderr if stderr is None else stderr
        # Exception classes that join() reports and retries instead of
        # letting them propagate.
        self.connerrors = self.app.connection().recoverable_connection_errors
        self.block_timeout = block_timeout
        self.no_join = no_join

    def remark(self, s, sep='-'):
        # type: (str, str) -> None
        """Print a progress note, prefixed with ``sep``, to ``self.stdout``."""
        print(f'{sep}{s}', file=self.stdout)

    def warn(self, s):
        # type: (str) -> None
        """Print a warning message to ``self.stderr``."""
        print(s, file=self.stderr)

    def missing_results(self, r):
        # type: (Sequence[AsyncResult]) -> Sequence[str]
        """Return ids of results in ``r`` not yet present in their backend cache."""
        return [res.id for res in r if res.id not in res.backend._cache]

    def wait_for(
        self,
        fun,  # type: Callable
        catch,  # type: Sequence[Any]
        desc="thing",  # type: str
        args=(),  # type: Tuple
        kwargs=None,  # type: Dict
        errback=None,  # type: Callable
        max_retries=10,  # type: int
        interval_start=0.1,  # type: float
        interval_step=0.5,  # type: float
        interval_max=5.0,  # type: float
        emit_warning=False,  # type: bool
        **options  # type: Any
    ):
        # type: (...) -> Any
        """Wait for event to happen.

        The `catch` argument specifies the exception that means the event
        has not happened yet.  ``fun(*args, **kwargs)`` is invoked through
        :meth:`retry_over_time` until it stops raising one of ``catch``.

        Arguments:
            desc (str): Description used in the "still waiting" warning.
            errback (Callable): Optional ``errback(exc, interval, retries)``
                called after every failed attempt.
            max_retries, interval_start, interval_step, interval_max:
                Retry policy forwarded to :meth:`retry_over_time`.
            emit_warning (bool): Call :meth:`warn` after each failed attempt.
            **options: Extra keyword arguments for :meth:`retry_over_time`.

        Returns:
            The value returned by :meth:`retry_over_time`.
        """
        kwargs = {} if not kwargs else kwargs

        def on_error(exc, intervals, retries):
            # Invoked by retry_over_time after each failed attempt; the
            # returned value is the delay before the next attempt.
            interval = next(intervals)
            if emit_warning:
                self.warn(E_STILL_WAITING.format(
                    desc, when=humanize_seconds(interval, 'in', ' '), exc=exc,
                ))
            if errback:
                errback(exc, interval, retries)
            return interval

        return self.retry_over_time(
            fun, catch,
            args=args, kwargs=kwargs,
            errback=on_error, max_retries=max_retries,
            interval_start=interval_start, interval_step=interval_step,
            # Previously accepted but never forwarded, so the cap was ignored.
            interval_max=interval_max,
            **options
        )

    def ensure_not_for_a_while(self, fun, catch,
                               desc='thing', max_retries=20,
                               interval_start=0.1, interval_step=0.02,
                               interval_max=1.0, emit_warning=False,
                               **options):
        """Make sure something does not happen (at least for a while).

        ``fun`` is polled via :meth:`wait_for`; raising one of ``catch``
        means "has not happened yet".  If the retries are exhausted and
        ``catch`` propagates out of :meth:`wait_for`, the check passes.

        Raises:
            AssertionError: if :meth:`wait_for` returns normally, i.e. the
                thing happened.
        """
        try:
            # No ``return`` here: a return inside ``try`` would skip the
            # ``else`` clause and the assertion could never fire.
            self.wait_for(
                fun, catch, desc=desc, max_retries=max_retries,
                interval_start=interval_start, interval_step=interval_step,
                interval_max=interval_max, emit_warning=emit_warning,
                **options
            )
        except catch:
            # Still not happened after all retries: this is the success path.
            pass
        else:
            raise AssertionError(f'Should not have happened: {desc}')

    def retry_over_time(self, *args, **kwargs):
        """Delegate to kombu's ``retry_over_time`` (overridable hook)."""
        return retry_over_time(*args, **kwargs)

    def join(self, r, propagate=False, max_retries=10, **kwargs):
        """Wait for the result(s) ``r`` and return what ``r.get`` returns.

        A non-ResultSet ``r`` is wrapped in ``self.app.ResultSet``.  Each
        attempt calls ``r.get(callback=..., propagate=propagate, **kwargs)``.
        Timeouts and ``self.connerrors`` are reported via :meth:`remark`
        and retried, up to ``max_retries`` attempts (forever if falsy).

        Returns ``None`` immediately when ``self.no_join`` is set.

        Raises:
            AssertionError: if every attempt failed.
        """
        if self.no_join:
            return
        if not isinstance(r, ResultSet):
            r = self.app.ResultSet([r])
        received = []

        def on_result(task_id, value):
            received.append(task_id)

        attempts = range(max_retries) if max_retries else count(0)
        for _ in attempts:
            # Reset per attempt so the "still waiting" count is accurate.
            received[:] = []
            try:
                return r.get(callback=on_result, propagate=propagate, **kwargs)
            except (socket.timeout, TimeoutError) as exc:
                waiting_for = self.missing_results(r)
                self.remark(
                    'Still waiting for {}/{}: [{}]: {!r}'.format(
                        len(r) - len(received), len(r),
                        truncate(', '.join(waiting_for)), exc), '!',
                )
            except self.connerrors as exc:
                self.remark(f'join: connection lost: {exc!r}', '!')
        raise AssertionError('Test failed: Missing task results')

    def inspect(self, timeout=3.0):
        """Return ``self.app.control.inspect`` bound to ``timeout``."""
        return self.app.control.inspect(timeout=timeout)

    def query_tasks(self, ids, timeout=0.5):
        """Yield ``(hostname, reply)`` pairs from a ``query_task`` broadcast."""
        tasks = self.inspect(timeout).query_task(*ids) or {}
        yield from tasks.items()

    def query_task_states(self, ids, timeout=0.5):
        """Map each reported state to the set of task ids in that state.

        Returns a ``defaultdict(set)``, so unknown states map to an empty set.
        """
        # Named ``by_state`` to avoid shadowing the ``celery.states`` import.
        by_state = defaultdict(set)
        for _hostname, reply in self.query_tasks(ids, timeout=timeout):
            for task_id, (state, _info) in reply.items():
                by_state[state].add(task_id)
        return by_state

    def assert_accepted(self, ids, interval=0.5,
                        desc='waiting for tasks to be accepted', **policy):
        """Wait until every id in ``ids`` satisfies :meth:`is_accepted`."""
        return self.assert_task_worker_state(
            self.is_accepted, ids, interval=interval, desc=desc, **policy
        )

    def assert_received(self, ids, interval=0.5,
                        desc='waiting for tasks to be received', **policy):
        """Wait until every id in ``ids`` satisfies :meth:`is_received`."""
        # Was ``self.is_accepted``; "received" also covers reserved tasks.
        return self.assert_task_worker_state(
            self.is_received, ids, interval=interval, desc=desc, **policy
        )

    def assert_result_tasks_in_progress_or_completed(
        self,
        async_results,
        interval=0.5,
        desc='waiting for tasks to be started or completed',
        **policy
    ):
        """Wait until every result is STARTED or SUCCESS."""
        return self.assert_task_state_from_result(
            self.is_result_task_in_progress,
            async_results,
            interval=interval, desc=desc, **policy
        )

    def assert_task_state_from_result(self, fun, results,
                                      interval=0.5, **policy):
        """Retry ``fun(results, timeout=interval)`` until it is truthy."""
        return self.wait_for(
            partial(self.true_or_raise, fun, results, timeout=interval),
            (Sentinel,), **policy
        )

    @staticmethod
    def is_result_task_in_progress(results, **kwargs):
        """Return True if every result's state is STARTED or SUCCESS."""
        possible_states = (states.STARTED, states.SUCCESS)
        return all(result.state in possible_states for result in results)

    def assert_task_worker_state(self, fun, ids, interval=0.5, **policy):
        """Retry ``fun(ids, timeout=interval)`` until it is truthy."""
        return self.wait_for(
            partial(self.true_or_raise, fun, ids, timeout=interval),
            (Sentinel,), **policy
        )

    def is_received(self, ids, **kwargs):
        """Return True if every id is reported reserved, active or ready."""
        return self._ids_matches_state(
            ['reserved', 'active', 'ready'], ids, **kwargs)

    def is_accepted(self, ids, **kwargs):
        """Return True if every id is reported active or ready."""
        return self._ids_matches_state(['active', 'ready'], ids, **kwargs)

    def _ids_matches_state(self, expected_states, ids, timeout=0.5):
        """Return True if every id appears under one of ``expected_states``."""
        by_state = self.query_task_states(ids, timeout=timeout)
        # Union built once, so each membership test below is a set lookup.
        matched = set().union(*(by_state[k] for k in expected_states))
        return all(task_id in matched for task_id in ids)

    def true_or_raise(self, fun, *args, **kwargs):
        """Return ``fun(*args, **kwargs)``; raise :class:`Sentinel` if falsy."""
        res = fun(*args, **kwargs)
        if not res:
            raise Sentinel()
        return res
class Manager(ManagerMixin):
    """Test helpers for task integration tests.

    Arguments:
        app: The Celery app under test, stored as ``self.app``.
        **kwargs: Forwarded unchanged to ``_init_manager``.
    """

    def __init__(self, app, **manager_options):
        # _init_manager reads self.app, so it must be bound first.
        self.app = app
        self._init_manager(**manager_options)