| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- import threading
- from dataclasses import dataclass, field
- from typing import Dict, Optional, Any
- from datetime import datetime
- from utils.get_logger import setup_logger
- # Module-level logger, obtained from the project's setup_logger helper.
- logger = setup_logger(__name__)
class TaskCancelledException(Exception):
    """Exception type signalling that a task was cancelled."""
@dataclass
class TaskContext:
    """Task context: holds the state information of a single task."""

    # Task ID (the only field without a default).
    task_id: str
    # User ID.
    user_id: str = ""
    # Knowledge-base ID.
    knowledge_id: str = ""
    # Cancellation flag for the task.
    is_cancelled: bool = False
    # Creation time, filled in with datetime.now() when the context is built.
    created_at: datetime = field(default_factory=datetime.now)
    # Progress reporter.
    reporter: Optional[Any] = None

    # Bookkeeping of data already inserted, used for cleanup on cancellation.
    inserted_doc_ids: list = field(default_factory=list)
    # Whether data has been inserted into Milvus.
    inserted_to_milvus: bool = False
    # Whether data has been inserted into the MySQL slice table.
    inserted_to_mysql_slice: bool = False
    # Whether data has been inserted into the MySQL image table.
    inserted_to_mysql_image: bool = False
class TaskRegistry:
    """
    Global task registry (singleton).

    Tracks the tasks that are currently registered and supports looking up
    and cancelling a task by its ``task_id``. Every method that touches the
    task table does so while holding an internal lock.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    # Build the instance completely BEFORE publishing it on
                    # the class. The first check above runs without the lock,
                    # so publishing early would let another thread receive an
                    # instance that has no _tasks / _task_lock yet.
                    instance = super().__new__(cls)
                    instance._tasks: Dict[str, "TaskContext"] = {}
                    instance._task_lock = threading.Lock()
                    cls._instance = instance
        return cls._instance

    def register(self, task_id: str, user_id: str = "", knowledge_id: str = "") -> "TaskContext":
        """
        Register a new task.

        Args:
            task_id: Unique task identifier (usually the document_id).
            user_id: User ID.
            knowledge_id: Knowledge-base ID.

        Returns:
            The newly created TaskContext stored under ``task_id``.
        """
        with self._task_lock:
            # An existing entry is only warned about and then overwritten;
            # the old context is not cancelled or otherwise cleaned up here.
            if task_id in self._tasks:
                logger.warning(f"任务 {task_id} 已存在,将覆盖旧任务")

            ctx = TaskContext(
                task_id=task_id,
                user_id=user_id,
                knowledge_id=knowledge_id,
            )
            self._tasks[task_id] = ctx
            logger.info(f"任务注册成功: task_id={task_id}, knowledge_id={knowledge_id}")
            return ctx

    def get(self, task_id: str) -> Optional["TaskContext"]:
        """Return the context registered under ``task_id``, or None."""
        with self._task_lock:
            return self._tasks.get(task_id)

    def cancel(self, task_id: str) -> tuple[bool, str]:
        """
        Cancel a task by setting its cancellation flag.

        Args:
            task_id: Task ID.

        Returns:
            ``(success, message)``: ``success`` is False when the task is not
            registered or its flag is already set, True once the flag has
            been set by this call.
        """
        with self._task_lock:
            ctx = self._tasks.get(task_id)

            if ctx is None:
                return False, "任务不存在或已完成"

            if ctx.is_cancelled:
                return False, "任务已被取消"

            # Set the cancellation flag.
            ctx.is_cancelled = True

            # Stop progress reporting by marking the reporter as completed.
            if ctx.reporter:
                ctx.reporter.is_completed = True

            logger.info(f"任务取消标志已设置: task_id={task_id}")
            return True, "取消信号已发送"

    def unregister(self, task_id: str):
        """Remove a task from the registry (call after it finishes or is cancelled)."""
        with self._task_lock:
            ctx = self._tasks.pop(task_id, None)
            if ctx is not None:
                logger.info(f"任务注销成功: task_id={task_id}")

    def is_cancelled(self, task_id: str) -> bool:
        """Return the task's cancellation flag; False if the task is not registered."""
        with self._task_lock:
            ctx = self._tasks.get(task_id)
            return ctx.is_cancelled if ctx else False

    def get_all_tasks(self) -> Dict[str, dict]:
        """Return a summary dict for every registered task (for debugging)."""
        with self._task_lock:
            return {
                task_id: {
                    "user_id": ctx.user_id,
                    "knowledge_id": ctx.knowledge_id,
                    "is_cancelled": ctx.is_cancelled,
                    "created_at": ctx.created_at.isoformat(),
                }
                for task_id, ctx in self._tasks.items()
            }
- # Global singleton instance
- task_registry = TaskRegistry()
|