upload_queue_manager.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import asyncio
  2. from queue import Queue
  3. from typing import Dict, Callable, Any
  4. from utils.get_logger import setup_logger
  5. import time
  6. logger = setup_logger(__name__)
  7. class UploadQueueManager:
  8. """
  9. 文档上传队列管理器
  10. 功能:
  11. 1. 控制最多10个并发任务
  12. 2. 前一个任务完成后,下一个任务自动开始
  13. 3. 提供队列状态监控
  14. """
  15. def __init__(self, max_concurrent: int = 10):
  16. """
  17. 初始化队列管理器
  18. 参数:
  19. max_concurrent: 最大并发数,默认10
  20. """
  21. self.max_concurrent = max_concurrent
  22. self.task_queue = asyncio.Queue() # 任务队列
  23. self.semaphore = asyncio.Semaphore(max_concurrent) # 信号量控制并发
  24. self.active_tasks: Dict[str, Dict[str, Any]] = {} # 活跃任务字典 {task_id: task_info}
  25. self.completed_tasks: Dict[str, Dict[str, Any]] = {} # 已完成任务字典
  26. self.failed_tasks: Dict[str, Dict[str, Any]] = {} # 失败任务字典
  27. self.workers = [] # 工作协程列表
  28. self.is_running = False
  29. self._lock = asyncio.Lock() # 用于保护共享状态
  30. logger.info(f"队列管理器初始化完成,最大并发数: {max_concurrent}")
  31. async def start(self):
  32. """启动队列管理器的工作协程"""
  33. if self.is_running:
  34. logger.warning("队列管理器已经在运行中")
  35. return
  36. self.is_running = True
  37. # 启动多个工作协程
  38. for i in range(self.max_concurrent):
  39. worker = asyncio.create_task(self._worker(i))
  40. self.workers.append(worker)
  41. logger.info(f"队列管理器已启动,工作协程数: {self.max_concurrent}")
  42. async def stop(self):
  43. """停止队列管理器"""
  44. self.is_running = False
  45. # 等待所有工作协程完成
  46. if self.workers:
  47. await asyncio.gather(*self.workers, return_exceptions=True)
  48. self.workers.clear()
  49. logger.info("队列管理器已停止")
  50. async def _worker(self, worker_id: int):
  51. """
  52. 工作协程,从队列中取任务并执行
  53. 参数:
  54. worker_id: 工作协程ID
  55. """
  56. logger.info(f"工作协程 {worker_id} 已启动")
  57. while self.is_running:
  58. try:
  59. # 从队列中获取任务(带超时,避免永久阻塞)
  60. try:
  61. task_info = await asyncio.wait_for(
  62. self.task_queue.get(),
  63. timeout=1.0
  64. )
  65. except asyncio.TimeoutError:
  66. continue
  67. task_id = task_info["task_id"]
  68. task_func = task_info["task_func"]
  69. task_args = task_info["task_args"]
  70. task_kwargs = task_info["task_kwargs"]
  71. logger.info(f"工作协程 {worker_id} 开始处理任务: {task_id}")
  72. # 使用信号量控制并发
  73. async with self.semaphore:
  74. # 记录任务开始
  75. async with self._lock:
  76. self.active_tasks[task_id] = {
  77. "task_id": task_id,
  78. "worker_id": worker_id,
  79. "start_time": time.time(),
  80. "status": "running"
  81. }
  82. try:
  83. # 执行任务
  84. result = await task_func(*task_args, **task_kwargs)
  85. # 记录任务完成
  86. async with self._lock:
  87. if task_id in self.active_tasks:
  88. task_info = self.active_tasks.pop(task_id)
  89. task_info["end_time"] = time.time()
  90. task_info["duration"] = task_info["end_time"] - task_info["start_time"]
  91. task_info["status"] = "completed"
  92. task_info["result"] = result
  93. self.completed_tasks[task_id] = task_info
  94. logger.info(f"工作协程 {worker_id} 完成任务: {task_id}, 耗时: {task_info['duration']:.2f}秒")
  95. except Exception as e:
  96. # 记录任务失败
  97. logger.error(f"工作协程 {worker_id} 执行任务 {task_id} 失败: {e}", exc_info=True)
  98. async with self._lock:
  99. if task_id in self.active_tasks:
  100. task_info = self.active_tasks.pop(task_id)
  101. task_info["end_time"] = time.time()
  102. task_info["duration"] = task_info["end_time"] - task_info["start_time"]
  103. task_info["status"] = "failed"
  104. task_info["error"] = str(e)
  105. self.failed_tasks[task_id] = task_info
  106. # 标记任务完成
  107. self.task_queue.task_done()
  108. except Exception as e:
  109. logger.error(f"工作协程 {worker_id} 发生异常: {e}", exc_info=True)
  110. logger.info(f"工作协程 {worker_id} 已停止")
  111. async def submit_task(self, task_id: str, task_func: Callable, *args, **kwargs) -> None:
  112. """
  113. 提交任务到队列
  114. 参数:
  115. task_id: 任务ID(通常是 document_id)
  116. task_func: 任务函数(异步函数)
  117. *args: 任务函数的位置参数
  118. **kwargs: 任务函数的关键字参数
  119. """
  120. task_info = {
  121. "task_id": task_id,
  122. "task_func": task_func,
  123. "task_args": args,
  124. "task_kwargs": kwargs,
  125. "submit_time": time.time()
  126. }
  127. await self.task_queue.put(task_info)
  128. queue_size = self.task_queue.qsize()
  129. active_count = len(self.active_tasks)
  130. logger.info(f"任务已提交到队列: {task_id}, 队列长度: {queue_size}, 活跃任务数: {active_count}")
  131. def get_queue_status(self) -> Dict[str, Any]:
  132. """
  133. 获取队列状态
  134. 返回:
  135. 包含队列状态信息的字典
  136. """
  137. return {
  138. "is_running": self.is_running,
  139. "max_concurrent": self.max_concurrent,
  140. "queue_size": self.task_queue.qsize(),
  141. "active_tasks_count": len(self.active_tasks),
  142. "completed_tasks_count": len(self.completed_tasks),
  143. "failed_tasks_count": len(self.failed_tasks),
  144. "active_tasks": list(self.active_tasks.keys()),
  145. "worker_count": len(self.workers)
  146. }
  147. def get_task_status(self, task_id: str) -> Dict[str, Any]:
  148. """
  149. 获取指定任务的状态
  150. 参数:
  151. task_id: 任务ID
  152. 返回:
  153. 任务状态信息,如果任务不存在返回 None
  154. """
  155. if task_id in self.active_tasks:
  156. return {**self.active_tasks[task_id], "status": "running"}
  157. elif task_id in self.completed_tasks:
  158. return {**self.completed_tasks[task_id], "status": "completed"}
  159. elif task_id in self.failed_tasks:
  160. return {**self.failed_tasks[task_id], "status": "failed"}
  161. else:
  162. return {"status": "not_found"}
  163. async def wait_for_completion(self):
  164. """等待队列中的所有任务完成"""
  165. await self.task_queue.join()
  166. logger.info("队列中的所有任务已完成")
  167. async def remove_from_queue(self, task_id: str) -> bool:
  168. """
  169. 从等待队列中删除指定任务
  170. 参数:
  171. task_id: 任务ID
  172. 返回:
  173. bool: 是否成功删除
  174. """
  175. async with self._lock:
  176. # 获取队列中的所有任务
  177. temp_tasks = []
  178. found = False
  179. # 从队列中取出所有任务
  180. while not self.task_queue.empty():
  181. try:
  182. task_info = self.task_queue.get_nowait()
  183. if task_info["task_id"] == task_id:
  184. found = True
  185. logger.info(f"从队列中找到并删除任务: {task_id}")
  186. self.task_queue.task_done()
  187. else:
  188. temp_tasks.append(task_info)
  189. except:
  190. break
  191. # 将其他任务放回队列
  192. for task_info in temp_tasks:
  193. await self.task_queue.put(task_info)
  194. return found
  195. def clear_history(self):
  196. """清空已完成和失败任务的历史记录"""
  197. cleared_completed = len(self.completed_tasks)
  198. cleared_failed = len(self.failed_tasks)
  199. self.completed_tasks.clear()
  200. self.failed_tasks.clear()
  201. logger.info(f"已清空历史记录: 完成任务 {cleared_completed} 个, 失败任务 {cleared_failed} 个")
  202. # 全局队列管理器实例(单例模式)
  203. _global_queue_manager: UploadQueueManager = None
  204. def get_queue_manager(max_concurrent: int = 10) -> UploadQueueManager:
  205. """
  206. 获取全局队列管理器实例(单例模式)
  207. 参数:
  208. max_concurrent: 最大并发数,默认10
  209. 返回:
  210. UploadQueueManager 实例
  211. """
  212. global _global_queue_manager
  213. if _global_queue_manager is None:
  214. _global_queue_manager = UploadQueueManager(max_concurrent)
  215. logger.info("创建全局队列管理器实例")
  216. return _global_queue_manager