Source code for futureproof.task_manager

import logging
import queue
from typing import Any, Callable, List, Union, Iterable, Iterator
from enum import Enum
from functools import partial
from threading import Lock
from itertools import chain

import attr

from futureproof import executors

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class ErrorPolicyEnum(Enum):
    """Allowed behaviours when a managed task raises an exception.

    Every member's value is the lowercase string that may be passed in
    place of the member itself.
    """

    # Take no action; the exception is still stored as the task's result.
    IGNORE = "ignore"
    # Log the exception and keep executing the remaining tasks.
    LOG = "log"
    # Re-raise the exception immediately and stop execution.
    RAISE = "raise"
@attr.s(eq=False)
class Task:
    """A callable together with the arguments to call it with, and its outcome.

    Tasks describe an execution with parameters and encapsulate the result.
    The task manager records completion on a new ``Task`` copy that carries
    ``result`` and ``complete=True``; the instance handed back by
    ``TaskManager.submit`` is not mutated.
    """

    # The callable to execute.
    fn = attr.ib()  # type: Callable
    # Positional arguments passed to ``fn``.
    args = attr.ib(default=())  # type: tuple
    # Keyword arguments passed to ``fn``. ``factory=dict`` gives each instance
    # its own dict; ``default={}`` would share one dict object between every
    # Task created without explicit kwargs.
    kwargs = attr.ib(factory=dict)  # type: dict
    # Return value of ``fn``, or the exception it raised, once completed.
    result = attr.ib(default=None)  # type: Any
    # True once the task has finished executing.
    complete = attr.ib(default=False)  # type: bool
class TaskManager:
    """Manages how tasks are created and placed in the queue.

    Executors pull Tasks from the managers.

    :param executor: The executor that will execute the submitted tasks.
    :param error_policy: Error policy indicating what the behaviour should be
        if an exception is raised, given either as an :class:`ErrorPolicyEnum`
        member or as its (case-insensitive) string value. Defaults to
        ``raise``, raising the exception as soon as it happens and stopping
        all execution. ``log`` will only log the exception but continue the
        execution of the remaining tasks. ``ignore`` will not do anything;
        note, however, that the exception will be set as the result of the
        task so users can re-raise it if they so choose.
    :raises ValueError: If ``error_policy`` is a string that matches no
        :class:`ErrorPolicyEnum` value.
    """

    def __init__(
        self,
        executor: executors._FutureProofExecutor,
        error_policy: Union[ErrorPolicyEnum, str] = ErrorPolicyEnum.RAISE,
    ):
        # Tasks handed to the executor whose results have not been collected
        # yet; as_completed() throttles submission against max_workers.
        self._tasks_in_queue = 0
        # Accept either an enum member or its string value, case-insensitively.
        self._error_policy = (
            error_policy
            if isinstance(error_policy, ErrorPolicyEnum)
            else ErrorPolicyEnum(error_policy.lower())
        )
        self._executor = executor
        # Set by _raise(); stops as_completed() from submitting further tasks.
        self._shutdown = False
        # Pending tasks, lazily chained together by submit() and map() and
        # consumed by as_completed().
        self._tasks = []  # type: Iterable
        self._submitted_task_count = 0  # type: int
        # Filled by _on_complete() from the executing threads, drained by
        # _wait_for_result().
        self._results_queue = queue.Queue()  # type: queue.Queue
        #: List of completed Task objects
        self.completed_tasks = []  # type: List[Task]
        # Guards ``completed_tasks`` between the writer (_wait_for_result)
        # and readers such as ``results``.
        self._completed_tasks_lock = Lock()  # type: Lock

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Run every pending task and wait for all of them. Note this happens
        # unconditionally, even when the ``with`` body itself raised.
        self.run()
        self.join()

    @property
    def results(self) -> List:
        """List of results for completed tasks.

        Note the contents may differ as more tasks are completed.
        """
        with self._completed_tasks_lock:
            return [task.result for task in self.completed_tasks if task.complete]

    def submit(self, fn: Callable, *args: Any, **kwargs: Any) -> Task:
        """Submit a task for execution.

        Submission is lazy: the task is only handed to the executor once
        :meth:`run`, :meth:`as_completed` or the context-manager exit starts
        consuming pending tasks.

        :returns: The newly created :class:`Task`. Completion is recorded on a
            separate copy appended to :attr:`completed_tasks`; the returned
            instance itself is not updated.
        """
        task = Task(fn, args, kwargs)
        # Extend the lazy pending-task iterator rather than building a list.
        self._tasks = chain(self._tasks, [task])
        return task

    def map(self, fn: Callable, iterable: Iterable) -> None:
        """Submit a set of tasks from a callable and an iterable of arguments.

        ``iterable`` may be any iterable of primitives or an iterable of
        argument tuples: a tuple item is unpacked as positional arguments,
        any other item is passed as the single positional argument. The
        iterable is consumed lazily, as tasks are submitted.
        """

        def gen():
            for i in iterable:
                args = i if isinstance(i, tuple) else (i,)
                yield Task(fn, args)

        self._tasks = chain(self._tasks, gen())

    def run(self) -> None:
        """Start the manager and wait until all tasks are completed before
        shutting down."""
        for _ in self.as_completed():
            pass

    def as_completed(self) -> Iterator[Task]:
        """Start the manager and return an iterator of completed tasks.

        When using the task manager as a context manager as_completed must be
        used *inside* the context, otherwise there will be no effect as the
        task manager will wait until all tasks are completed.
        """
        for task in self._tasks:
            if self._shutdown:
                break
            # Throttle: never keep more uncollected tasks than workers.
            if self._tasks_in_queue == self._executor.max_workers:
                logger.debug("Queue full, waiting for result")
                yield self._wait_for_result()
            self._submit_task(task)
        # Drain the results of every task already handed to the executor.
        while len(self.completed_tasks) < self._submitted_task_count:
            yield self._wait_for_result()
        self._executor.join()

    def _submit_task(self, task: Task) -> None:
        """Submits a task to the executor, note this will block if the queue is full."""
        logger.debug(
            "Tasks in queue: %d, submitting task %r", self._tasks_in_queue, task
        )
        self._tasks_in_queue += 1
        fut = self._executor.submit(task.fn, *task.args, **task.kwargs)
        cb = partial(self._on_complete, task=task)
        fut.add_done_callback(cb)
        self._submitted_task_count += 1

    def _on_complete(self, future, task):
        """Called once per future to perform an operation over the result.

        Note this function is called by the executing threads. A new
        :class:`Task` copy is built with ``complete=True`` and either the
        future's result or the exception it raised, then placed on the
        results queue.
        """
        complete_task = Task(task.fn, task.args, task.kwargs)
        complete_task.complete = True
        try:
            complete_task.result = future.result()
        except Exception as exc:
            logger.debug("Exception on task %r", complete_task)
            complete_task.result = exc
        finally:
            # Always enqueue: the consumer blocks on the queue in
            # _wait_for_result() and would otherwise wait forever.
            logger.debug("Completed task %r", complete_task)
            self._results_queue.put(complete_task)

    def join(self) -> None:
        """Block until all tasks are completed.

        Alternatively use the task manager as a context manager, upon exiting
        join will be called and results will be available.
        """
        while len(self.completed_tasks) < self._submitted_task_count:
            self._wait_for_result()
        self._executor.join()

    def _wait_for_result(self) -> Task:
        """Gather the result of one submitted task, blocking until available.

        The completed task is appended to :attr:`completed_tasks` and the
        error policy is applied if its result is an exception.

        :returns: The completed :class:`Task` copy.
        :raises Exception: The task's exception, when the error policy is
            ``raise``.
        """
        completed_task = self._results_queue.get(block=True)
        logger.debug("Gathering result for completed task %r", completed_task)
        self._tasks_in_queue -= 1
        # Hold the same lock ``results`` reads under, so the lock actually
        # guards the list instead of protecting only one side.
        with self._completed_tasks_lock:
            self.completed_tasks.append(completed_task)
        if isinstance(completed_task.result, Exception):
            if self._error_policy == ErrorPolicyEnum.RAISE:
                self._raise(completed_task.result)
            elif self._error_policy == ErrorPolicyEnum.LOG:
                logger.exception(
                    "Task %s raised an exception",
                    completed_task,
                    exc_info=completed_task.result,
                )
        return completed_task

    def _raise(self, exception) -> None:
        """Performs cleanup before raising an exception.

        Marks the manager as shut down (so no further tasks are submitted),
        joins the executor and re-raises ``exception``.
        """
        logger.info("Raising exception and shutting down")
        self._shutdown = True
        self._executor.join()
        raise exception