import sys
import os
import logging
from datetime import datetime
from typing import List, Dict, Set, Tuple, Optional
import mysql.connector
from mysql.connector import Error
from pymilvus import MilvusClient

# Make sibling modules (config, utils) importable when run as a script.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from config import milvus_uri, mysql_config
from utils.get_logger import setup_logger

# Module-level logger
logger = setup_logger(__name__)


class VectorDataCleaner:
    """Reconcile slice records between MySQL and Milvus.

    MySQL's ``slice_info`` table is treated as the source of truth.
    Slices that exist in Milvus but not in MySQL are "dirty" and may be
    deleted; slices in MySQL but missing from Milvus are only reported.
    """

    def __init__(self):
        """Connect to both databases eagerly.

        Raises:
            mysql.connector.Error: if the MySQL connection fails.
            Exception: if the Milvus client cannot be created.
        """
        self.mysql_config = mysql_config
        self.milvus_uri = milvus_uri
        self.milvus_client = None
        self.mysql_conn = None
        self._connect_databases()

    def _ensure_mysql_connection(self):
        """Re-establish the MySQL connection if it has dropped."""
        if not self.mysql_conn or not self.mysql_conn.is_connected():
            logger.warning("MySQL连接失效,尝试重连")
            self.mysql_conn = mysql.connector.connect(**self.mysql_config)

    def _connect_databases(self):
        """Open the MySQL connection and the Milvus client; raise on failure."""
        try:
            self.mysql_conn = mysql.connector.connect(**self.mysql_config)
            logger.info("MySQL数据库连接成功")
        except Error as e:
            logger.error(f"MySQL数据库连接失败: {e}")
            raise

        try:
            self.milvus_client = MilvusClient(uri=self.milvus_uri)
            logger.info("Milvus向量数据库连接成功")
        except Exception as e:
            logger.error(f"Milvus向量数据库连接失败: {e}")
            raise

    def close_connections(self):
        """Close the MySQL connection; the Milvus client manages its own."""
        if self.mysql_conn:
            self.mysql_conn.close()
            logger.info("MySQL连接已关闭")

    def get_mysql_slices(self, knowledge_id: str, document_id: str = None,
                         mode: str = "gc") -> Tuple[Set[str], bool]:
        """Fetch slice IDs recorded in MySQL for a knowledge base / document.

        Args:
            knowledge_id: knowledge-base ID.
            document_id: optional document ID; when omitted, all slices of
                the knowledge base are returned.
            mode: "jt" excludes soft-deleted rows (``del_flag != '1'``);
                "gc" returns every row regardless of ``del_flag``.

        Returns:
            Tuple[Set[str], bool]: (slice IDs, success flag).  On query
            failure or unknown *mode*, returns ``(set(), False)``.
        """
        self._ensure_mysql_connection()

        cursor = None
        try:
            cursor = self.mysql_conn.cursor()
            # NOTE: SQL string literals use single quotes so the statements
            # also work under ANSI_QUOTES sql_mode, where double quotes
            # denote identifiers rather than strings.
            if mode == "jt":
                if document_id:
                    # Slices of one document, excluding soft-deleted rows.
                    query = """
                        SELECT slice_id FROM slice_info
                        WHERE knowledge_id = %s AND document_id = %s AND del_flag != '1'
                    """
                    cursor.execute(query, (knowledge_id, document_id))
                    logger.info(f"查询知识库 {knowledge_id} 中文档 {document_id} 的切片")
                else:
                    # All slices of the knowledge base, excluding soft-deleted rows.
                    query = """
                        SELECT slice_id FROM slice_info
                        WHERE knowledge_id = %s AND del_flag != '1'
                    """
                    cursor.execute(query, (knowledge_id,))
                    logger.info(f"查询知识库 {knowledge_id} 的所有切片")
            elif mode == "gc":
                if document_id:
                    # Slices of one document, including soft-deleted rows.
                    query = """
                        SELECT slice_id FROM slice_info
                        WHERE knowledge_id = %s AND document_id = %s
                    """
                    cursor.execute(query, (knowledge_id, document_id))
                    logger.info(f"查询知识库 {knowledge_id} 中文档 {document_id} 的切片")
                else:
                    # All slices of the knowledge base, including soft-deleted rows.
                    query = """
                        SELECT slice_id FROM slice_info
                        WHERE knowledge_id = %s
                    """
                    cursor.execute(query, (knowledge_id,))
                    logger.info(f"查询知识库 {knowledge_id} 的所有切片")
            else:
                # Previously an unknown mode fell through to fetchall() on an
                # unexecuted cursor, raising a misleading InterfaceError.
                logger.error(f"未知的查询模式: {mode}")
                return set(), False

            results = cursor.fetchall()
            slice_ids = {row[0] for row in results}
            logger.info(f"MySQL中找到 {len(slice_ids)} 个切片")
            return slice_ids, True

        except Error as e:
            logger.error(f"查询MySQL切片数据失败: {e}")
            return set(), False
        finally:
            if cursor:
                cursor.close()

    def get_milvus_slices(self, knowledge_id: str, document_id: str = None) -> Set[str]:
        """Fetch all slice IDs stored in Milvus for a knowledge base / document.

        Args:
            knowledge_id: knowledge-base ID (doubles as the collection name).
            document_id: optional document ID used to filter by ``doc_id``.

        Returns:
            Set[str]: chunk IDs found, or an empty set when the collection
            is missing or the query fails.
        """
        if not self.milvus_client:
            logger.error("Milvus连接未建立")
            return set()

        try:
            # Bail out early if the collection does not exist.
            if not self.milvus_client.has_collection(collection_name=knowledge_id):
                logger.warning(f"Milvus中不存在知识库 {knowledge_id} 对应的collection")
                return set()

            # Build the filter expression.  MilvusClient.query expects a
            # string filter ("" means no filter), not None.
            if document_id:
                # NOTE(review): document_id is interpolated into the filter
                # expression; assumed to come from trusted internal IDs.
                filter_expr = f"doc_id == '{document_id}'"
                logger.info(f"查询Milvus知识库 {knowledge_id} 中文档 {document_id} 的切片")
            else:
                filter_expr = ""
                logger.info(f"查询Milvus知识库 {knowledge_id} 的所有切片")

            # Page through results to avoid the server-side result limit.
            # NOTE(review): Milvus caps offset+limit for query windows
            # (16384 by default) — confirm for very large collections.
            all_results = []
            offset = 0
            limit = 10000  # page size
            while True:
                results = self.milvus_client.query(
                    collection_name=knowledge_id,
                    filter=filter_expr,
                    output_fields=["chunk_id"],
                    limit=limit,
                    offset=offset
                )
                if not results:
                    break
                all_results.extend(results)
                # A short page means we have read everything.
                if len(results) < limit:
                    break
                offset += limit

            slice_ids = {result["chunk_id"] for result in all_results if "chunk_id" in result}
            logger.info(f"Milvus中找到 {len(slice_ids)} 个切片")
            return slice_ids

        except Exception as e:
            logger.error(f"查询Milvus切片数据失败: {e}")
            return set()

    def get_milvus_slice_details(self, knowledge_id: str, chunk_ids: List[str]) -> List[Dict]:
        """Fetch full records from Milvus for the given chunk IDs.

        Args:
            knowledge_id: knowledge-base ID (collection name).
            chunk_ids: chunk IDs to look up.

        Returns:
            List[Dict]: matching records with pk/chunk_id/doc_id/content,
            or an empty list on error or empty input.
        """
        if not self.milvus_client or not chunk_ids:
            return []

        try:
            # Build an "in" filter over the requested chunk IDs.
            chunk_ids_str = "', '".join(chunk_ids)
            filter_expr = f"chunk_id in ['{chunk_ids_str}']"

            results = self.milvus_client.query(
                collection_name=knowledge_id,
                filter=filter_expr,
                output_fields=["pk", "chunk_id", "doc_id", "content"],
                limit=len(chunk_ids)
            )
            return results

        except Exception as e:
            logger.error(f"获取Milvus切片详细信息失败: {e}")
            return []

    def delete_milvus_slices(self, knowledge_id: str, chunk_ids: List[str]) -> Tuple[bool, str]:
        """Delete the given slices from Milvus, batching to limit load.

        Args:
            knowledge_id: knowledge-base ID (collection name).
            chunk_ids: chunk IDs to delete.

        Returns:
            Tuple[bool, str]: (success flag, human-readable message).
        """
        if not self.milvus_client or not chunk_ids:
            return True, "没有需要删除的切片"

        try:
            # Delete in batches to avoid oversized requests.
            batch_size = 100
            total_deleted = 0

            for i in range(0, len(chunk_ids), batch_size):
                batch_chunk_ids = chunk_ids[i:i + batch_size]

                # Resolve chunk IDs to primary keys first; delete() takes pks.
                chunk_ids_str = "', '".join(batch_chunk_ids)
                filter_expr = f"chunk_id in ['{chunk_ids_str}']"

                results = self.milvus_client.query(
                    collection_name=knowledge_id,
                    filter=filter_expr,
                    output_fields=["pk"],
                    limit=len(batch_chunk_ids)
                )

                if not results:
                    continue

                primary_keys = [result["pk"] for result in results]

                delete_result = self.milvus_client.delete(
                    collection_name=knowledge_id,
                    ids=primary_keys
                )

                total_deleted += len(primary_keys)
                logger.info(f"删除了 {len(primary_keys)} 个切片,累计删除 {total_deleted} 个")

            # Flush so deletions are persisted.
            self.milvus_client.flush(collection_name=knowledge_id)

            logger.info(f"成功删除 {total_deleted} 个脏数据切片")
            return True, f"成功删除 {total_deleted} 个脏数据切片"

        except Exception as e:
            logger.error(f"删除Milvus切片失败: {e}")
            return False, f"删除失败: {str(e)}"

    def cleanup_data(self, knowledge_id: str, document_id: str = None, dry_run: bool = True) -> Dict:
        """Run the reconciliation: diff MySQL vs Milvus and optionally delete.

        Args:
            knowledge_id: knowledge-base ID.
            document_id: optional document ID to narrow the scope.
            dry_run: when True, only report — do not delete anything.

        Returns:
            Dict: statistics and outcome, or ``{"error": ...}`` on failure.
        """
        logger.info(f"开始数据清理 - 知识库: {knowledge_id}, 文档: {document_id or '全部'}, 试运行: {dry_run}")

        # Verify both connections before doing any work.
        if not self.mysql_conn:
            return {"error": "MySQL连接失败"}
        if not self.milvus_client:
            return {"error": "Milvus连接失败"}

        try:
            # Slice IDs on the MySQL side (source of truth).
            mysql_slice_ids, success = self.get_mysql_slices(knowledge_id, document_id)
            if not success:
                return {"error": "MySQL查询失败"}

            # Slice IDs on the Milvus side.
            milvus_slice_ids = self.get_milvus_slices(knowledge_id, document_id)

            # Dirty: present in Milvus but absent from MySQL.
            dirty_slice_ids = milvus_slice_ids - mysql_slice_ids
            # Missing: present in MySQL but absent from Milvus.
            missing_slice_ids = mysql_slice_ids - milvus_slice_ids

            result = {
                "knowledge_id": knowledge_id,
                "document_id": document_id,
                "mysql_slice_count": len(mysql_slice_ids),
                "milvus_slice_count": len(milvus_slice_ids),
                "dirty_slice_count": len(dirty_slice_ids),
                "missing_slice_count": len(missing_slice_ids),
                "dirty_slice_ids": list(dirty_slice_ids),
                "missing_slice_ids": list(missing_slice_ids),
                "dry_run": dry_run
            }

            logger.info(f"数据统计 - MySQL切片: {len(mysql_slice_ids)}, "
                        f"Milvus切片: {len(milvus_slice_ids)}, "
                        f"脏数据: {len(dirty_slice_ids)}, "
                        f"缺失数据: {len(missing_slice_ids)}")

            if dirty_slice_ids:
                if dry_run:
                    logger.info(f"试运行模式 - 发现 {len(dirty_slice_ids)} 个脏数据切片,不会实际删除")
                    # Sample details (first 10) for the report.
                    dirty_details = self.get_milvus_slice_details(knowledge_id, list(dirty_slice_ids)[:10])
                    result["dirty_slice_samples"] = dirty_details
                else:
                    logger.info(f"开始删除 {len(dirty_slice_ids)} 个脏数据切片")
                    success, message = self.delete_milvus_slices(knowledge_id, list(dirty_slice_ids))
                    result["delete_success"] = success
                    result["delete_message"] = message
            else:
                logger.info("没有发现脏数据,数据一致性良好")
                result["delete_success"] = True
                result["delete_message"] = "没有脏数据需要删除"

            if missing_slice_ids:
                logger.warning(f"发现 {len(missing_slice_ids)} 个切片在MySQL中存在但在Milvus中缺失")

            return result

        except Exception as e:
            logger.error(f"数据清理过程中发生错误: {e}")
            return {"error": str(e)}
        finally:
            # Intentionally keep connections open; callers may reuse them.
            pass

    def generate_report(self, result: Dict) -> str:
        """Format the cleanup result as a human-readable report.

        Args:
            result: dict produced by :meth:`cleanup_data`.

        Returns:
            str: formatted report (or a one-line failure message).
        """
        if "error" in result:
            return f"清理失败: {result['error']}"

        # "or" (not a .get default) because the key is always present,
        # possibly holding None when the whole knowledge base was scanned.
        report = f"""
数据清理报告
{'='*50}
知识库ID: {result['knowledge_id']}
文档ID: {result.get('document_id') or '全部文档'}
执行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
试运行模式: {'是' if result['dry_run'] else '否'}

数据统计:
- MySQL中的切片数量: {result['mysql_slice_count']}
- Milvus中的切片数量: {result['milvus_slice_count']}
- 脏数据切片数量: {result['dirty_slice_count']}
- 缺失切片数量: {result['missing_slice_count']}

清理结果:
- 删除状态: {'成功' if result.get('delete_success', False) else '失败'}
- 删除消息: {result.get('delete_message', 'N/A')}

建议:
"""
        if result['dirty_slice_count'] > 0:
            report += f"- 发现 {result['dirty_slice_count']} 个脏数据切片"
            if result['dry_run']:
                report += ",建议在非试运行模式下执行清理\n"
            else:
                report += ",已执行清理\n"
        else:
            report += "- 数据一致性良好,无需清理\n"

        if result['missing_slice_count'] > 0:
            report += f"- 发现 {result['missing_slice_count']} 个切片在Milvus中缺失,建议检查数据同步\n"

        return report


def main():
    """CLI entry point: parse args, run the cleanup, print and save a report."""
    import argparse

    parser = argparse.ArgumentParser(description='数据清理脚本 - 清理向量库中的脏数据')
    parser.add_argument('knowledge_id', help='知识库ID')
    parser.add_argument('--document_id', help='文档ID(可选,不提供则清理整个知识库)')
    # Kept for interface compatibility; --execute is what actually controls it.
    parser.add_argument('--dry-run', action='store_true', default=True,
                        help='试运行模式(默认开启,不实际删除数据)')
    parser.add_argument('--execute', action='store_true',
                        help='实际执行删除操作(关闭运行模式)')

    args = parser.parse_args()

    # BUGFIX: was `dry_run = args.execute`, which inverted the flag —
    # --execute forced dry-run mode and the default run actually deleted.
    dry_run = not args.execute

    cleanup_tool = VectorDataCleaner()

    try:
        result = cleanup_tool.cleanup_data(
            knowledge_id=args.knowledge_id,
            document_id=args.document_id,
            dry_run=dry_run
        )

        report = cleanup_tool.generate_report(result)
        print(report)
    finally:
        # Always release the database connections.
        cleanup_tool.close_connections()

    # Persist the report next to the script's working directory.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    os.makedirs("./data_clean", exist_ok=True)
    report_filename = f"./data_clean/data_cleanup_report_{args.knowledge_id}_{timestamp}.txt"
    try:
        with open(report_filename, 'w', encoding='utf-8') as f:
            f.write(report)
        print(f"\n报告已保存到: {report_filename}")
    except Exception as e:
        logger.error(f"保存报告失败: {e}")


if __name__ == "__main__":
    main()