"""Process-wide registry of running tasks with cooperative cancellation flags."""
import threading
from dataclasses import dataclass, field
from typing import Dict, Optional, Any
from datetime import datetime
from utils.get_logger import setup_logger

logger = setup_logger(__name__)


class TaskCancelledException(Exception):
    """Raised to signal that a task has been cancelled."""


@dataclass
class TaskContext:
    """Per-task state record stored in the registry."""

    task_id: str  # task identifier
    user_id: str = ""  # owning user id
    knowledge_id: str = ""  # knowledge-base id
    is_cancelled: bool = False  # cancellation flag, set by TaskRegistry.cancel
    created_at: datetime = field(default_factory=datetime.now)  # creation time
    # Progress reporter; cancel() sets its ``is_completed`` attribute to True.
    reporter: Optional[Any] = None

    # Bookkeeping of data already written, so it can be cleaned up on cancel.
    inserted_doc_ids: list = field(default_factory=list)
    inserted_to_milvus: bool = False  # data was inserted into Milvus
    inserted_to_mysql_slice: bool = False  # data was inserted into MySQL slice
    inserted_to_mysql_image: bool = False  # data was inserted into MySQL image


class TaskRegistry:
    """
    Global task registry (singleton).

    Tracks the tasks currently registered and allows looking them up and
    flagging them as cancelled by ``task_id``. All access to the internal
    task table is serialized through ``_task_lock``.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    # Fully initialise the object BEFORE publishing it on the
                    # class. The outer check above runs without the lock, so
                    # publishing first would let another thread obtain an
                    # instance that has no _tasks / _task_lock yet.
                    instance = super().__new__(cls)
                    instance._tasks: Dict[str, TaskContext] = {}
                    instance._task_lock = threading.Lock()
                    cls._instance = instance
        return cls._instance

    def register(self, task_id: str, user_id: str = "", knowledge_id: str = "") -> TaskContext:
        """
        Register a new task.

        Args:
            task_id: Unique task identifier (usually the document_id,
                according to the original author's note).
            user_id: User id.
            knowledge_id: Knowledge-base id.

        Returns:
            The freshly created TaskContext stored under ``task_id``.
        """
        with self._task_lock:
            # An existing entry is not cancelled or cleaned up: it is simply
            # replaced by the new context after logging a warning.
            if task_id in self._tasks:
                logger.warning(f"任务 {task_id} 已存在,将覆盖旧任务")

            ctx = TaskContext(
                task_id=task_id,
                user_id=user_id,
                knowledge_id=knowledge_id,
            )
            self._tasks[task_id] = ctx
            logger.info(f"任务注册成功: task_id={task_id}, knowledge_id={knowledge_id}")
            return ctx

    def get(self, task_id: str) -> Optional[TaskContext]:
        """Return the context registered under ``task_id``, or None."""
        with self._task_lock:
            return self._tasks.get(task_id)

    def cancel(self, task_id: str) -> tuple[bool, str]:
        """
        Flag a task as cancelled.

        Args:
            task_id: Task id.

        Returns:
            ``(success, message)``: ``success`` is False when the task is not
            registered or was already flagged as cancelled.
        """
        with self._task_lock:
            ctx = self._tasks.get(task_id)
            if not ctx:
                return False, "任务不存在或已完成"

            if ctx.is_cancelled:
                return False, "任务已被取消"

            # Set the cancellation flag; this method only sets flags and does
            # not interrupt any running code itself.
            ctx.is_cancelled = True

            # Mark the reporter as completed (intended to stop progress
            # reporting).
            if ctx.reporter:
                ctx.reporter.is_completed = True

            logger.info(f"任务取消标志已设置: task_id={task_id}")
            return True, "取消信号已发送"

    def unregister(self, task_id: str):
        """Remove a task from the registry (call after it finishes or is cancelled)."""
        with self._task_lock:
            ctx = self._tasks.pop(task_id, None)
            if ctx:
                logger.info(f"任务注销成功: task_id={task_id}")

    def is_cancelled(self, task_id: str) -> bool:
        """Return True if the task is registered and flagged as cancelled."""
        with self._task_lock:
            ctx = self._tasks.get(task_id)
            return ctx.is_cancelled if ctx else False

    def get_all_tasks(self) -> Dict[str, dict]:
        """Return a summary dict for every registered task (for debugging)."""
        with self._task_lock:
            return {
                task_id: {
                    "user_id": ctx.user_id,
                    "knowledge_id": ctx.knowledge_id,
                    "is_cancelled": ctx.is_cancelled,
                    "created_at": ctx.created_at.isoformat(),
                }
                for task_id, ctx in self._tasks.items()
            }


# Module-level singleton instance
task_registry = TaskRegistry()