task_registry.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import threading
  2. from dataclasses import dataclass, field
  3. from typing import Dict, Optional, Any
  4. from datetime import datetime
  5. from utils.get_logger import setup_logger
  6. logger = setup_logger(__name__)
  7. class TaskCancelledException(Exception):
  8. """任务被取消异常"""
  9. pass
  10. @dataclass
  11. class TaskContext:
  12. """任务上下文 - 存储单个任务的状态信息"""
  13. task_id: str # 任务ID
  14. user_id: str = "" # 用户ID
  15. knowledge_id: str = "" # 知识库ID
  16. is_cancelled: bool = False # 任务取消标志
  17. created_at: datetime = field(default_factory=datetime.now) # 任务创建时间
  18. reporter: Optional[Any] = None # 进度报告器
  19. # 记录已插入的数据,用于取消时清理
  20. inserted_doc_ids: list = field(default_factory=list)
  21. inserted_to_milvus: bool = False # 是否已插入到 Milvus
  22. inserted_to_mysql_slice: bool = False # 是否已插入到 MySQL slice
  23. inserted_to_mysql_image: bool = False # 是否已插入到 MySQL image
  24. class TaskRegistry:
  25. """
  26. 全局任务注册表(单例模式)
  27. 用于跟踪所有正在执行的任务,支持通过 task_id 查找和取消任务
  28. """
  29. _instance = None
  30. _lock = threading.Lock()
  31. def __new__(cls):
  32. if cls._instance is None:
  33. with cls._lock:
  34. if cls._instance is None:
  35. cls._instance = super().__new__(cls)
  36. cls._instance._tasks: Dict[str, TaskContext] = {}
  37. cls._instance._task_lock = threading.Lock()
  38. return cls._instance
  39. def register(self, task_id: str, user_id: str = "", knowledge_id: str = "") -> TaskContext:
  40. """
  41. 注册新任务
  42. 输入:
  43. task_id: 任务唯一标识(通常是 document_id)
  44. user_id: 用户ID
  45. knowledge_id: 知识库ID
  46. 输出:
  47. TaskContext: 任务上下文对象
  48. """
  49. with self._task_lock:
  50. # 如果已存在,先清理旧的
  51. if task_id in self._tasks:
  52. logger.warning(f"任务 {task_id} 已存在,将覆盖旧任务")
  53. ctx = TaskContext(
  54. task_id=task_id,
  55. user_id=user_id,
  56. knowledge_id=knowledge_id
  57. )
  58. self._tasks[task_id] = ctx
  59. logger.info(f"任务注册成功: task_id={task_id}, knowledge_id={knowledge_id}")
  60. return ctx
  61. def get(self, task_id: str) -> Optional[TaskContext]:
  62. """获取任务上下文"""
  63. with self._task_lock:
  64. return self._tasks.get(task_id)
  65. def cancel(self, task_id: str) -> tuple[bool, str]:
  66. """
  67. 取消任务
  68. 输入:
  69. task_id: 任务ID
  70. 输出:
  71. (success, message): 是否成功及消息
  72. """
  73. with self._task_lock:
  74. ctx = self._tasks.get(task_id)
  75. if not ctx:
  76. return False, "任务不存在或已完成"
  77. if ctx.is_cancelled:
  78. return False, "任务已被取消"
  79. # 设置取消标志
  80. ctx.is_cancelled = True
  81. # 停止进度上报
  82. if ctx.reporter:
  83. ctx.reporter.is_completed = True
  84. logger.info(f"任务取消标志已设置: task_id={task_id}")
  85. return True, "取消信号已发送"
  86. def unregister(self, task_id: str):
  87. """注销任务(任务完成或取消后调用)"""
  88. with self._task_lock:
  89. ctx = self._tasks.pop(task_id, None)
  90. if ctx:
  91. logger.info(f"任务注销成功: task_id={task_id}")
  92. def is_cancelled(self, task_id: str) -> bool:
  93. """检查任务是否已取消"""
  94. with self._task_lock:
  95. ctx = self._tasks.get(task_id)
  96. return ctx.is_cancelled if ctx else False
  97. def get_all_tasks(self) -> Dict[str, dict]:
  98. """获取所有活跃任务的信息(用于调试)"""
  99. with self._task_lock:
  100. return {
  101. task_id: {
  102. "user_id": ctx.user_id,
  103. "knowledge_id": ctx.knowledge_id,
  104. "is_cancelled": ctx.is_cancelled,
  105. "created_at": ctx.created_at.isoformat(),
  106. }
  107. for task_id, ctx in self._tasks.items()
  108. }
  109. # 全局单例实例
  110. task_registry = TaskRegistry()