Source code for cup.util.threadpool

#!/usr/bin/python
# -*- coding: utf-8 -*
# #############################################################################
#
#  Copyright (c) 2014 Baidu.com,  Inc. All Rights Reserved
#
# #############################################################################
"""
:author:
    Guannan Ma
:create_date:
    2014
:last_date:
    2014
:description:
    Guannan ported threadpool from twisted.python.
    The MIT License applies to the Twisted-derived code.
    http://www.opensource.org/licenses/mit-license.php
    For any concerns, please contact mythmgn@gmail.com
"""

try:
    import Queue as queue
except ImportError:
    # pylint: disable=F0401
    import queue
import contextlib
import threading
import copy
import sys
import traceback

import cup
from cup.util import context

# Module-level tracker shared by every ThreadPool: add_1job_with_callback()
# captures the caller's current context from it when a job is queued, and
# the worker threads use call_with_context() to run the job and its
# result callback inside that captured context.
_CONTEXT_TRACKER = context.ContextTracker4Thread()


class ThreadPool(object):
    """
    Threadpool class (ported from twisted.python.threadpool).

    Jobs are queued with add_1job() / add_1job_with_callback() and are run
    by between ``minthreads`` and ``maxthreads`` worker threads once
    start() has been called.
    """
    _THREAD_FACTORY = threading.Thread
    # staticmethod so the function is not bound to the instance on lookup.
    _CURRENT_THREAD = staticmethod(threading.current_thread)
    # Sentinel put on the job queue to tell one worker thread to exit.
    _WORKER_STOP_SIGN = object()

    def __init__(self, minthreads=5, maxthreads=20, name=None):
        """
        Create a thread pool. No threads are spawned until start().

        :param minthreads:
            minimum number of worker threads kept once started (>= 0).
        :param maxthreads:
            maximum number of worker threads (>= minthreads).
        :param name:
            optional name used in worker thread names
            ("PoolThread-<name>-<n>"); id(self) is used when it is falsy.
        """
        # Accept 0, as the message states and as adjust_poolsize() allows;
        # workers are then spawned purely on demand.
        assert minthreads >= 0, 'minimum must be >= 0 '
        assert minthreads <= maxthreads, 'minimum is greater than maximum'
        self._min = minthreads
        self._max = maxthreads
        self._name = name
        self._joined = False
        self._started = False
        # Worker count as tracked by start1worker()/stop1worker(); it is
        # decremented when a stop sign is queued, before the thread exits.
        self._workers = 0
        # Queue is a thread-safe queue; maxsize 0 means unbounded.
        self._jobqueue = queue.Queue(0)
        self._waiters = []   # threads idle, blocked on the job queue
        self._threads = []   # threads spawned and not yet exited
        self._working = []   # threads currently running a job

    def start(self):
        """
        Start the pool: spawn at least ``minthreads`` workers, plus more
        (up to ``maxthreads``) if jobs are already queued.
        """
        self._joined = False
        self._started = True
        # Start some threads.
        self.adjust_poolsize()

    def start1worker(self):
        """Spawn one additional worker thread."""
        self._workers += 1
        name = "PoolThread-%s-%s" % (self._name or id(self), self._workers)
        new_thd = self._THREAD_FACTORY(target=self._worker, name=name)
        self._threads.append(new_thd)
        new_thd.start()

    def stop1worker(self):
        """
        Ask one worker thread to exit.

        A stop sign is queued behind any pending jobs; whichever worker
        dequeues it exits.
        """
        self._jobqueue.put(self._WORKER_STOP_SIGN)
        self._workers -= 1

    def __setstate__(self, state):
        """
        Rebuild a fresh, not-yet-started pool from __getstate__() output.

        Only the configuration survives pickling; queued jobs and threads
        do not.
        """
        # The state dict uses 'min'/'max' keys, not the private attribute
        # names, so read them from the dict. 'name' is optional so states
        # that only carry min/max still load.
        ThreadPool.__init__(
            self, state['min'], state['max'], state.get('name')
        )

    def __getstate__(self):
        """Return the picklable configuration of the pool."""
        state = {}
        state['min'] = self._min
        state['max'] = self._max
        state['name'] = self._name
        return state

    def _start_decent_workers(self):
        """Spawn workers to cover queued plus running jobs, capped at max."""
        need_size = self._jobqueue.qsize() + len(self._working)
        # Create enough, but not too many
        while self._workers < min(self._max, need_size):
            self.start1worker()

    def add_1job(self, func, *args, **kwargs):
        """
        Queue a job without a result callback.

        :param func: callable to be scheduled by the thread pool
        :param *args: positional arguments for func
        :param **kwargs: keyword arguments for func
        """
        self.add_1job_with_callback(None, func, *args, **kwargs)

    def add_1job_with_callback(self, result_callback, func, *args, **kwargs):
        """
        Queue a job whose outcome is reported to ``result_callback``.

        :param result_callback:
            called after func has run, whether or not it succeeded, as
            ``result_callback(ret_in_bool, result)``: ``(True, result)`` on
            success, ``(False, str(exception))`` if func raised.
            May be None.
        :param func: same as add_1job, the callable to schedule
        :param *args: same as add_1job, positional arguments for func
        :param **kwargs: same as add_1job, keyword arguments for func

        Jobs added after stop()/try_stop() are silently dropped.
        """
        if self._joined:
            return
        ctx = _CONTEXT_TRACKER.current_context().contexts[-1]
        job = (ctx, func, args, kwargs, result_callback)
        self._jobqueue.put(job)
        if self._started:
            self._start_decent_workers()

    @contextlib.contextmanager
    def _worker_state(self, state_list, worker_thread):
        """Keep worker_thread in state_list for the duration of a with-block."""
        state_list.append(worker_thread)
        try:
            yield
        finally:
            state_list.remove(worker_thread)

    def _log_err_context(self, context):
        """Log a warning followed by str(context) after a failed call."""
        cup.log.warn(
            'Seems a call with context failed. See the context info'
        )
        cup.log.warn(str(context))

    def _worker(self):
        """
        Worker loop run by every pool thread.

        Repeatedly dequeues a job, runs it inside the context captured when
        it was queued, then invokes its result callback (if any). Exits
        when it dequeues _WORKER_STOP_SIGN.
        """
        current_thd = self._CURRENT_THREAD()
        # While blocked on the queue this thread is counted as a waiter.
        with self._worker_state(self._waiters, current_thd):
            job = self._jobqueue.get()
        while job is not self._WORKER_STOP_SIGN:
            with self._worker_state(self._working, current_thd):
                ctx, function, args, kwargs, result_callback = job
                del job
                try:
                    # pylint: disable=W0142
                    result = _CONTEXT_TRACKER.call_with_context(
                        ctx, function, *args, **kwargs
                    )
                    success = True
                except Exception as error:
                    success = False
                    cup.log.warn(
                        'Func failed, func:%s, error_msg: %s, traceback:%s\n'
                        % (str(function), str(error), traceback.format_exc())
                    )
                    if result_callback is None:
                        cup.log.warn('This func does not have callback.')
                        _CONTEXT_TRACKER.call_with_context(
                            ctx, self._log_err_context, ctx
                        )
                        result = None
                    else:
                        result = str(error)
                # Drop references to the job's arguments before the
                # callback runs.
                del function, args, kwargs
            # Leaving the with-block above removed this thread from
            # self._working; the callback runs outside that window.
            if result_callback is not None:
                try:
                    _CONTEXT_TRACKER.call_with_context(
                        ctx, result_callback, success, result
                    )
                except Exception as e:
                    traceback.print_exc(file=sys.stderr)
                    cup.log.warn(
                        'result_callback func failed, callback func:%s,'
                        'err_msg:%s' % (str(result_callback), str(e))
                    )
                    _CONTEXT_TRACKER.call_with_context(
                        ctx, self._log_err_context, ctx
                    )
            del ctx, result_callback, result
            with self._worker_state(self._waiters, current_thd):
                job = self._jobqueue.get()
        # Stop sign received: remove this thread from the pool's list.
        self._threads.remove(current_thd)

    def stop(self):
        """
        Stop the pool and block until every worker thread has exited.

        New jobs are dropped from this point on. Jobs queued before stop()
        are dequeued ahead of the stop signs (the queue is FIFO).
        """
        self._joined = True
        threads = copy.copy(self._threads)
        while self._workers:
            self._jobqueue.put(self._WORKER_STOP_SIGN)
            self._workers -= 1
        # and let's just make sure
        # FIXME: threads that have died before calling stop() are not joined.
        for thread in threads:
            thread.join()

    def try_stop(self, check_interval=0.1):
        """
        Send the stop command and check whether the pool has stopped.

        Unlike stop(), this does not wait indefinitely: each thread is
        joined for at most ``check_interval`` seconds.

        :return: True if every worker thread has exited, else False.
        """
        self._joined = True
        threads = copy.copy(self._threads)
        while self._workers:
            self._jobqueue.put(self._WORKER_STOP_SIGN)
            self._workers -= 1
        for thread in threads:
            thread.join(check_interval)
        # is_alive must be *called*: a bare method reference is always
        # truthy (and Thread.isAlive no longer exists on Python 3.9+).
        return not any(thread.is_alive() for thread in threads)

    def adjust_poolsize(self, minthreads=None, maxthreads=None):
        """
        Change the minimum/maximum number of worker threads.

        :param minthreads: new minimum; None keeps the current value
        :param maxthreads: new maximum; None keeps the current value

        If the pool has been started, workers are stopped or spawned
        immediately to respect the new bounds.
        """
        if minthreads is None:
            minthreads = self._min
        if maxthreads is None:
            maxthreads = self._max
        assert minthreads >= 0, 'minimum is negative'
        assert minthreads <= maxthreads, 'minimum is greater than maximum'
        self._min = minthreads
        self._max = maxthreads
        if not self._started:
            return
        # Kill off some threads if we have too many.
        while self._workers > self._max:
            self.stop1worker()
        # Start some threads if we have too few.
        while self._workers < self._min:
            self.start1worker()
        # Start some threads if there is a need.
        self._start_decent_workers()

    def get_stats(self):
        """
        Return a snapshot of the pool state as a dict:

        - queue_len: number of items waiting in the job queue
        - waiters_num: number of idle threads blocked on the queue
        - working_num: number of threads currently running a job
        - thread_num: number of pool threads that have not exited yet
        """
        stat = {}
        stat['queue_len'] = self._jobqueue.qsize()
        stat['waiters_num'] = len(self._waiters)
        stat['working_num'] = len(self._working)
        stat['thread_num'] = len(self._threads)
        return stat

    def dump_stats(self):
        """
        Print get_stats() to stdout and log it, followed by debug-level
        dumps of the raw queue and thread lists.

        :return: the dict returned by get_stats()
        """
        stat = self.get_stats()
        # Call form works on both Python 2 and 3; the ``print stat``
        # statement is a SyntaxError on Python 3.
        print(stat)
        cup.log.info('Threadpool stat: %s' % stat)
        cup.log.debug('queue: %s' % self._jobqueue.queue)
        cup.log.debug('waiters: %s' % self._waiters)
        cup.log.debug('workers: %s' % self._working)
        cup.log.debug('total: %s' % self._threads)
        return stat