# data_cleanup_script.py

import sys
import os
import logging
from datetime import datetime
from typing import List, Dict, Set, Tuple, Optional
import mysql.connector
from mysql.connector import Error
from pymilvus import MilvusClient
# Put this script's own directory on sys.path so the project-local imports
# below (config, utils) resolve; this must run before those imports.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from config import milvus_uri, mysql_config
from utils.get_logger import setup_logger
# Module-level logger shared by everything in this file.
logger = setup_logger(__name__)
  15. class VectorDataCleaner:
  16. """向量数据清理工具类"""
  17. def __init__(self):
  18. """初始化数据清理工具"""
  19. self.mysql_config = mysql_config
  20. self.milvus_uri = milvus_uri
  21. self.milvus_client = None
  22. self.mysql_conn = None
  23. self._connect_databases()
  24. def _ensure_mysql_connection(self):
  25. if not self.mysql_conn or not self.mysql_conn.is_connected():
  26. logger.warning("MySQL连接失效,尝试重连")
  27. self.mysql_conn = mysql.connector.connect(**self.mysql_config)
  28. def _connect_databases(self):
  29. """连接数据库"""
  30. # 连接MySQL
  31. try:
  32. self.mysql_conn = mysql.connector.connect(**self.mysql_config)
  33. logger.info("MySQL数据库连接成功")
  34. except Error as e:
  35. logger.error(f"MySQL数据库连接失败: {e}")
  36. raise
  37. # 连接Milvus
  38. try:
  39. self.milvus_client = MilvusClient(uri=self.milvus_uri)
  40. logger.info("Milvus向量数据库连接成功")
  41. except Exception as e:
  42. logger.error(f"Milvus向量数据库连接失败: {e}")
  43. raise
  44. def close_connections(self):
  45. """关闭数据库连接"""
  46. if self.mysql_conn:
  47. self.mysql_conn.close()
  48. logger.info("MySQL连接已关闭")
  49. # Milvus客户端会自动管理连接
  50. def get_mysql_slices(self, knowledge_id: str, document_id: str = None, mode: str = "gc") -> Set[str]:
  51. """
  52. 从MySQL获取指定知识库/文件的所有切片ID
  53. Args:
  54. knowledge_id: 知识库ID
  55. document_id: 文档ID,可选。如果不提供则获取整个知识库的切片
  56. Returns:
  57. Set[str]: 切片ID集合
  58. """
  59. # if not self.mysql_conn:
  60. # logger.error("MySQL连接未建立")
  61. # return set()
  62. self._ensure_mysql_connection()
  63. cursor = None
  64. try:
  65. cursor = self.mysql_conn.cursor()
  66. if mode == "jt":
  67. if document_id:
  68. # 查询指定文档的切片
  69. query = """
  70. SELECT slice_id FROM slice_info
  71. WHERE knowledge_id = %s AND document_id = %s AND del_flag != "1"
  72. """
  73. cursor.execute(query, (knowledge_id, document_id))
  74. logger.info(f"查询知识库 {knowledge_id} 中文档 {document_id} 的切片")
  75. else:
  76. # 查询整个知识库的切片
  77. query = """
  78. SELECT slice_id FROM slice_info
  79. WHERE knowledge_id = %s AND del_flag != "1"
  80. """
  81. cursor.execute(query, (knowledge_id,))
  82. logger.info(f"查询知识库 {knowledge_id} 的所有切片")
  83. elif mode == "gc":
  84. if document_id:
  85. # 查询指定文档的切片
  86. query = """
  87. SELECT slice_id FROM slice_info
  88. WHERE knowledge_id = %s AND document_id = %s
  89. """
  90. cursor.execute(query, (knowledge_id, document_id))
  91. logger.info(f"查询知识库 {knowledge_id} 中文档 {document_id} 的切片")
  92. else:
  93. # 查询整个知识库的切片
  94. query = """
  95. SELECT slice_id FROM slice_info
  96. WHERE knowledge_id = %s
  97. """
  98. cursor.execute(query, (knowledge_id,))
  99. logger.info(f"查询知识库 {knowledge_id} 的所有切片")
  100. results = cursor.fetchall()
  101. slice_ids = {row[0] for row in results}
  102. logger.info(f"MySQL中找到 {len(slice_ids)} 个切片")
  103. return slice_ids, True
  104. except Error as e:
  105. logger.error(f"查询MySQL切片数据失败: {e}")
  106. return set(), False
  107. finally:
  108. if cursor:
  109. cursor.close()
  110. def get_milvus_slices(self, knowledge_id: str, document_id: str = None) -> Set[str]:
  111. """
  112. 从Milvus获取指定知识库/文件的所有切片ID
  113. Args:
  114. knowledge_id: 知识库ID(也是collection名称)
  115. document_id: 文档ID,可选。如果不提供则获取整个知识库的切片
  116. Returns:
  117. Set[str]: 切片ID集合
  118. """
  119. if not self.milvus_client:
  120. logger.error("Milvus连接未建立")
  121. return set()
  122. try:
  123. # 检查collection是否存在
  124. if not self.milvus_client.has_collection(collection_name=knowledge_id):
  125. logger.warning(f"Milvus中不存在知识库 {knowledge_id} 对应的collection")
  126. return set()
  127. # 构建查询条件
  128. if document_id:
  129. filter_expr = f"doc_id == '{document_id}'"
  130. logger.info(f"查询Milvus知识库 {knowledge_id} 中文档 {document_id} 的切片")
  131. else:
  132. filter_expr = None
  133. logger.info(f"查询Milvus知识库 {knowledge_id} 的所有切片")
  134. # 查询所有切片ID(分批查询以避免限制)
  135. all_results = []
  136. offset = 0
  137. limit = 10000 # 每次查询10000条
  138. while True:
  139. results = self.milvus_client.query(
  140. collection_name=knowledge_id,
  141. filter=filter_expr,
  142. output_fields=["chunk_id"],
  143. limit=limit,
  144. offset=offset
  145. )
  146. if not results:
  147. break
  148. all_results.extend(results)
  149. # 如果返回的结果少于limit,说明已经查询完毕
  150. if len(results) < limit:
  151. break
  152. offset += limit
  153. results = all_results
  154. slice_ids = {result["chunk_id"] for result in results if "chunk_id" in result}
  155. logger.info(f"Milvus中找到 {len(slice_ids)} 个切片")
  156. return slice_ids
  157. except Exception as e:
  158. logger.error(f"查询Milvus切片数据失败: {e}")
  159. return set()
  160. def get_milvus_slice_details(self, knowledge_id: str, chunk_ids: List[str]) -> List[Dict]:
  161. """
  162. 获取Milvus中指定切片的详细信息
  163. Args:
  164. knowledge_id: 知识库ID
  165. chunk_ids: 切片ID列表
  166. Returns:
  167. List[Dict]: 切片详细信息列表
  168. """
  169. if not self.milvus_client or not chunk_ids:
  170. return []
  171. try:
  172. # 构建查询条件
  173. chunk_ids_str = "', '".join(chunk_ids)
  174. filter_expr = f"chunk_id in ['{chunk_ids_str}']"
  175. results = self.milvus_client.query(
  176. collection_name=knowledge_id,
  177. filter=filter_expr,
  178. output_fields=["pk", "chunk_id", "doc_id", "content"],
  179. limit=len(chunk_ids)
  180. )
  181. return results
  182. except Exception as e:
  183. logger.error(f"获取Milvus切片详细信息失败: {e}")
  184. return []
  185. def delete_milvus_slices(self, knowledge_id: str, chunk_ids: List[str]) -> Tuple[bool, str]:
  186. """
  187. 删除Milvus中的指定切片
  188. Args:
  189. knowledge_id: 知识库ID
  190. chunk_ids: 要删除的切片ID列表
  191. Returns:
  192. Tuple[bool, str]: (是否成功, 消息)
  193. """
  194. if not self.milvus_client or not chunk_ids:
  195. return True, "没有需要删除的切片"
  196. try:
  197. # 分批删除,避免一次删除太多数据
  198. batch_size = 100
  199. total_deleted = 0
  200. for i in range(0, len(chunk_ids), batch_size):
  201. batch_chunk_ids = chunk_ids[i:i + batch_size]
  202. # 先查询获取主键
  203. chunk_ids_str = "', '".join(batch_chunk_ids)
  204. filter_expr = f"chunk_id in ['{chunk_ids_str}']"
  205. results = self.milvus_client.query(
  206. collection_name=knowledge_id,
  207. filter=filter_expr,
  208. output_fields=["pk"],
  209. limit=len(batch_chunk_ids)
  210. )
  211. if not results:
  212. continue
  213. # 提取主键
  214. primary_keys = [result["pk"] for result in results]
  215. # 执行删除
  216. delete_result = self.milvus_client.delete(
  217. collection_name=knowledge_id,
  218. ids=primary_keys
  219. )
  220. total_deleted += len(primary_keys)
  221. logger.info(f"删除了 {len(primary_keys)} 个切片,累计删除 {total_deleted} 个")
  222. # 执行flush和compact操作
  223. self.milvus_client.flush(collection_name=knowledge_id)
  224. logger.info(f"成功删除 {total_deleted} 个脏数据切片")
  225. return True, f"成功删除 {total_deleted} 个脏数据切片"
  226. except Exception as e:
  227. logger.error(f"删除Milvus切片失败: {e}")
  228. return False, f"删除失败: {str(e)}"
  229. def cleanup_data(self, knowledge_id: str, document_id: str = None, dry_run: bool = True) -> Dict:
  230. """
  231. 执行数据清理
  232. Args:
  233. knowledge_id: 知识库ID
  234. document_id: 文档ID,可选
  235. dry_run: 是否为试运行模式(不实际删除数据)
  236. Returns:
  237. Dict: 清理结果统计
  238. """
  239. logger.info(f"开始数据清理 - 知识库: {knowledge_id}, 文档: {document_id or '全部'}, 试运行: {dry_run}")
  240. # 检查数据库连接
  241. if not self.mysql_conn:
  242. return {"error": "MySQL连接失败"}
  243. if not self.milvus_client:
  244. return {"error": "Milvus连接失败"}
  245. try:
  246. # 获取MySQL中的切片ID
  247. mysql_slice_ids, success = self.get_mysql_slices(knowledge_id, document_id)
  248. if not success:
  249. return {"error": "MySQL查询失败"}
  250. # 获取Milvus中的切片ID
  251. milvus_slice_ids = self.get_milvus_slices(knowledge_id, document_id)
  252. # 找出需要删除的切片(在Milvus中但不在MySQL中)
  253. dirty_slice_ids = milvus_slice_ids - mysql_slice_ids
  254. # 找出缺失的切片(在MySQL中但不在Milvus中)
  255. missing_slice_ids = mysql_slice_ids - milvus_slice_ids
  256. result = {
  257. "knowledge_id": knowledge_id,
  258. "document_id": document_id,
  259. "mysql_slice_count": len(mysql_slice_ids),
  260. "milvus_slice_count": len(milvus_slice_ids),
  261. "dirty_slice_count": len(dirty_slice_ids),
  262. "missing_slice_count": len(missing_slice_ids),
  263. "dirty_slice_ids": list(dirty_slice_ids),
  264. "missing_slice_ids": list(missing_slice_ids),
  265. "dry_run": dry_run
  266. }
  267. logger.info(f"数据统计 - MySQL切片: {len(mysql_slice_ids)}, "
  268. f"Milvus切片: {len(milvus_slice_ids)}, "
  269. f"脏数据: {len(dirty_slice_ids)}, "
  270. f"缺失数据: {len(missing_slice_ids)}")
  271. if dirty_slice_ids:
  272. if dry_run:
  273. logger.info(f"试运行模式 - 发现 {len(dirty_slice_ids)} 个脏数据切片,不会实际删除")
  274. # 获取脏数据的详细信息用于日志记录
  275. dirty_details = self.get_milvus_slice_details(knowledge_id, list(dirty_slice_ids)[:10]) # 只获取前10个的详情
  276. result["dirty_slice_samples"] = dirty_details
  277. else:
  278. logger.info(f"开始删除 {len(dirty_slice_ids)} 个脏数据切片")
  279. success, message = self.delete_milvus_slices(knowledge_id, list(dirty_slice_ids))
  280. result["delete_success"] = success
  281. result["delete_message"] = message
  282. else:
  283. logger.info("没有发现脏数据,数据一致性良好")
  284. result["delete_success"] = True
  285. result["delete_message"] = "没有脏数据需要删除"
  286. if missing_slice_ids:
  287. logger.warning(f"发现 {len(missing_slice_ids)} 个切片在MySQL中存在但在Milvus中缺失")
  288. return result
  289. except Exception as e:
  290. logger.error(f"数据清理过程中发生错误: {e}")
  291. return {"error": str(e)}
  292. finally:
  293. # 注意:不在这里关闭连接,因为可能还需要使用
  294. pass
  295. def generate_report(self, result: Dict) -> str:
  296. """
  297. 生成清理报告
  298. Args:
  299. result: 清理结果
  300. Returns:
  301. str: 格式化的报告
  302. """
  303. if "error" in result:
  304. return f"清理失败: {result['error']}"
  305. report = f"""
  306. 数据清理报告
  307. {'='*50}
  308. 知识库ID: {result['knowledge_id']}
  309. 文档ID: {result.get('document_id', '全部文档')}
  310. 执行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
  311. 试运行模式: {'是' if result['dry_run'] else '否'}
  312. 数据统计:
  313. - MySQL中的切片数量: {result['mysql_slice_count']}
  314. - Milvus中的切片数量: {result['milvus_slice_count']}
  315. - 脏数据切片数量: {result['dirty_slice_count']}
  316. - 缺失切片数量: {result['missing_slice_count']}
  317. 清理结果:
  318. - 删除状态: {'成功' if result.get('delete_success', False) else '失败'}
  319. - 删除消息: {result.get('delete_message', 'N/A')}
  320. 建议:
  321. """
  322. if result['dirty_slice_count'] > 0:
  323. report += f"- 发现 {result['dirty_slice_count']} 个脏数据切片"
  324. if result['dry_run']:
  325. report += ",建议在非试运行模式下执行清理\n"
  326. else:
  327. report += ",已执行清理\n"
  328. else:
  329. report += "- 数据一致性良好,无需清理\n"
  330. if result['missing_slice_count'] > 0:
  331. report += f"- 发现 {result['missing_slice_count']} 个切片在Milvus中缺失,建议检查数据同步\n"
  332. return report
  333. def main():
  334. """主函数"""
  335. import argparse
  336. parser = argparse.ArgumentParser(description='数据清理脚本 - 清理向量库中的脏数据')
  337. parser.add_argument('knowledge_id', help='知识库ID')
  338. parser.add_argument('--document_id', help='文档ID(可选,不提供则清理整个知识库)')
  339. parser.add_argument('--dry-run', action='store_true', default=True,
  340. help='试运行模式(默认开启,不实际删除数据)')
  341. parser.add_argument('--execute', action='store_true',
  342. help='实际执行删除操作(关闭运行模式)')
  343. args = parser.parse_args()
  344. # 如果指定了--execute,则关闭运行模式
  345. dry_run = args.execute
  346. # 创建清理工具
  347. cleanup_tool = VectorDataCleaner()
  348. try:
  349. # 执行清理
  350. result = cleanup_tool.cleanup_data(
  351. knowledge_id=args.knowledge_id,
  352. document_id=args.document_id,
  353. dry_run=dry_run
  354. )
  355. # 生成并保存报告
  356. report = cleanup_tool.generate_report(result)
  357. print(report)
  358. finally:
  359. # 关闭连接
  360. cleanup_tool.close_connections()
  361. # 保存报告到文件
  362. timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  363. os.makedirs("./data_clean", exist_ok=True)
  364. report_filename = f"./data_clean/data_cleanup_report_{args.knowledge_id}_{timestamp}.txt"
  365. try:
  366. with open(report_filename, 'w', encoding='utf-8') as f:
  367. f.write(report)
  368. print(f"\n报告已保存到: {report_filename}")
  369. except Exception as e:
  370. logger.error(f"保存报告失败: {e}")
  371. if __name__ == "__main__":
  372. main()