| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466 |
- import sys
- import os
- import logging
- from datetime import datetime
- from typing import List, Dict, Set, Tuple, Optional
- import mysql.connector
- from mysql.connector import Error
- from pymilvus import MilvusClient
- # 添加项目路径到sys.path
- sys.path.append(os.path.dirname(os.path.abspath(__file__)))
- from config import milvus_uri, mysql_config
- from utils.get_logger import setup_logger
- # 日志
- logger = setup_logger(__name__)
class VectorDataCleaner:
    """Reconcile slice data between MySQL (source of truth) and Milvus.

    "Dirty" slices are vectors that exist in a Milvus collection but have no
    matching row in MySQL's ``slice_info`` table; this tool finds them and
    can optionally delete them. Slices present in MySQL but absent from
    Milvus are reported only.
    """

    def __init__(self):
        """Store configuration and immediately connect to both databases.

        Raises:
            mysql.connector.Error: if the MySQL connection fails.
            Exception: if the Milvus client cannot be created.
        """
        self.mysql_config = mysql_config
        self.milvus_uri = milvus_uri
        self.milvus_client = None
        self.mysql_conn = None
        self._connect_databases()

    def _ensure_mysql_connection(self):
        """Reconnect if the cached MySQL connection is missing or stale."""
        if not self.mysql_conn or not self.mysql_conn.is_connected():
            logger.warning("MySQL连接失效,尝试重连")
            self.mysql_conn = mysql.connector.connect(**self.mysql_config)

    def _connect_databases(self):
        """Open the MySQL connection and the Milvus client; raise on failure."""
        try:
            self.mysql_conn = mysql.connector.connect(**self.mysql_config)
            logger.info("MySQL数据库连接成功")
        except Error as e:
            logger.error(f"MySQL数据库连接失败: {e}")
            raise

        try:
            self.milvus_client = MilvusClient(uri=self.milvus_uri)
            logger.info("Milvus向量数据库连接成功")
        except Exception as e:
            logger.error(f"Milvus向量数据库连接失败: {e}")
            raise

    def close_connections(self):
        """Close the MySQL connection; the Milvus client manages its own."""
        if self.mysql_conn:
            self.mysql_conn.close()
            # Drop the closed handle so _ensure_mysql_connection reconnects
            # cleanly if this object is reused after closing.
            self.mysql_conn = None
            logger.info("MySQL连接已关闭")

    def get_mysql_slices(self, knowledge_id: str, document_id: Optional[str] = None,
                         mode: str = "gc") -> Tuple[Set[str], bool]:
        """Fetch the slice IDs recorded in MySQL for a knowledge base/document.

        Args:
            knowledge_id: Knowledge-base ID.
            document_id: Optional document ID; when omitted, slices of the
                whole knowledge base are returned.
            mode: "jt" excludes soft-deleted rows (``del_flag != "1"``);
                "gc" (default) returns every row regardless of del_flag.

        Returns:
            Tuple[Set[str], bool]: the slice-ID set and a success flag.
                (BUG FIX: the original annotation said ``Set[str]`` although
                callers unpack a two-tuple.)
        """
        self._ensure_mysql_connection()

        cursor = None
        try:
            cursor = self.mysql_conn.cursor()
            if mode == "jt":
                if document_id:
                    # Slices of one document, soft-deleted rows excluded.
                    query = """
                        SELECT slice_id FROM slice_info
                        WHERE knowledge_id = %s AND document_id = %s AND del_flag != "1"
                    """
                    cursor.execute(query, (knowledge_id, document_id))
                    logger.info(f"查询知识库 {knowledge_id} 中文档 {document_id} 的切片")
                else:
                    # Slices of the whole knowledge base, soft-deleted excluded.
                    query = """
                        SELECT slice_id FROM slice_info
                        WHERE knowledge_id = %s AND del_flag != "1"
                    """
                    cursor.execute(query, (knowledge_id,))
                    logger.info(f"查询知识库 {knowledge_id} 的所有切片")
            elif mode == "gc":
                if document_id:
                    # Slices of one document, including soft-deleted rows.
                    query = """
                        SELECT slice_id FROM slice_info
                        WHERE knowledge_id = %s AND document_id = %s
                    """
                    cursor.execute(query, (knowledge_id, document_id))
                    logger.info(f"查询知识库 {knowledge_id} 中文档 {document_id} 的切片")
                else:
                    # All slices of the knowledge base, including soft-deleted.
                    query = """
                        SELECT slice_id FROM slice_info
                        WHERE knowledge_id = %s
                    """
                    cursor.execute(query, (knowledge_id,))
                    logger.info(f"查询知识库 {knowledge_id} 的所有切片")
            else:
                # BUG FIX: the original fell through and called fetchall() on a
                # never-executed cursor (InterfaceError). Fail fast instead,
                # with the same observable (set(), False) outcome.
                logger.error(f"未知的查询模式: {mode}")
                return set(), False

            slice_ids = {row[0] for row in cursor.fetchall()}
            logger.info(f"MySQL中找到 {len(slice_ids)} 个切片")
            return slice_ids, True

        except Error as e:
            logger.error(f"查询MySQL切片数据失败: {e}")
            return set(), False
        finally:
            if cursor:
                cursor.close()

    def get_milvus_slices(self, knowledge_id: str, document_id: Optional[str] = None) -> Set[str]:
        """Fetch the chunk IDs stored in the Milvus collection.

        Args:
            knowledge_id: Knowledge-base ID, which is also the collection name.
            document_id: Optional document ID filter (matched on ``doc_id``).

        Returns:
            Set[str]: chunk_id values found; empty on any failure or when the
                collection does not exist.
        """
        if not self.milvus_client:
            logger.error("Milvus连接未建立")
            return set()

        try:
            if not self.milvus_client.has_collection(collection_name=knowledge_id):
                logger.warning(f"Milvus中不存在知识库 {knowledge_id} 对应的collection")
                return set()

            if document_id:
                filter_expr = f"doc_id == '{document_id}'"
                logger.info(f"查询Milvus知识库 {knowledge_id} 中文档 {document_id} 的切片")
            else:
                filter_expr = None
                logger.info(f"查询Milvus知识库 {knowledge_id} 的所有切片")

            # Page through the collection 10000 rows at a time.
            # NOTE(review): Milvus caps offset + limit (16384 by default), so
            # very large collections may need a query iterator instead —
            # confirm against the deployed Milvus version.
            all_results = []
            offset = 0
            limit = 10000
            while True:
                page = self.milvus_client.query(
                    collection_name=knowledge_id,
                    filter=filter_expr,
                    output_fields=["chunk_id"],
                    limit=limit,
                    offset=offset
                )
                if not page:
                    break
                all_results.extend(page)
                # A short page means we have drained the collection.
                if len(page) < limit:
                    break
                offset += limit

            slice_ids = {row["chunk_id"] for row in all_results if "chunk_id" in row}
            logger.info(f"Milvus中找到 {len(slice_ids)} 个切片")
            return slice_ids

        except Exception as e:
            logger.error(f"查询Milvus切片数据失败: {e}")
            return set()

    def get_milvus_slice_details(self, knowledge_id: str, chunk_ids: List[str]) -> List[Dict]:
        """Fetch full records for the given chunk IDs from Milvus.

        Args:
            knowledge_id: Knowledge-base ID / collection name.
            chunk_ids: Chunk IDs to look up.

        Returns:
            List[Dict]: matching records (pk, chunk_id, doc_id, content);
                empty on failure or empty input.
        """
        if not self.milvus_client or not chunk_ids:
            return []

        try:
            # NOTE(review): chunk IDs are interpolated into the filter string;
            # this is safe only while IDs are internally generated and contain
            # no quote characters — confirm upstream ID format.
            chunk_ids_str = "', '".join(chunk_ids)
            filter_expr = f"chunk_id in ['{chunk_ids_str}']"

            return self.milvus_client.query(
                collection_name=knowledge_id,
                filter=filter_expr,
                output_fields=["pk", "chunk_id", "doc_id", "content"],
                limit=len(chunk_ids)
            )
        except Exception as e:
            logger.error(f"获取Milvus切片详细信息失败: {e}")
            return []

    def delete_milvus_slices(self, knowledge_id: str, chunk_ids: List[str]) -> Tuple[bool, str]:
        """Delete the given chunks from Milvus, in batches of 100.

        Args:
            knowledge_id: Knowledge-base ID / collection name.
            chunk_ids: Chunk IDs to delete.

        Returns:
            Tuple[bool, str]: success flag and a human-readable message.
        """
        if not self.milvus_client or not chunk_ids:
            return True, "没有需要删除的切片"

        try:
            batch_size = 100
            total_deleted = 0

            for start in range(0, len(chunk_ids), batch_size):
                batch = chunk_ids[start:start + batch_size]

                # Resolve chunk IDs to primary keys first, then delete by pk.
                chunk_ids_str = "', '".join(batch)
                filter_expr = f"chunk_id in ['{chunk_ids_str}']"
                rows = self.milvus_client.query(
                    collection_name=knowledge_id,
                    filter=filter_expr,
                    output_fields=["pk"],
                    limit=len(batch)
                )
                if not rows:
                    continue

                primary_keys = [row["pk"] for row in rows]
                # Return value intentionally ignored (the original bound it to
                # an unused local); failures surface as exceptions.
                self.milvus_client.delete(
                    collection_name=knowledge_id,
                    ids=primary_keys
                )
                total_deleted += len(primary_keys)
                logger.info(f"删除了 {len(primary_keys)} 个切片,累计删除 {total_deleted} 个")

            # Flush so the deletions are persisted and visible to queries.
            self.milvus_client.flush(collection_name=knowledge_id)
            logger.info(f"成功删除 {total_deleted} 个脏数据切片")
            return True, f"成功删除 {total_deleted} 个脏数据切片"

        except Exception as e:
            logger.error(f"删除Milvus切片失败: {e}")
            return False, f"删除失败: {str(e)}"

    def cleanup_data(self, knowledge_id: str, document_id: Optional[str] = None,
                     dry_run: bool = True) -> Dict:
        """Compare MySQL and Milvus slices and (optionally) delete dirty ones.

        Args:
            knowledge_id: Knowledge-base ID.
            document_id: Optional document ID to restrict the scope.
            dry_run: When True, only report; never delete.

        Returns:
            Dict: statistics plus deletion results, or an ``error`` key on
                failure.
        """
        logger.info(f"开始数据清理 - 知识库: {knowledge_id}, 文档: {document_id or '全部'}, 试运行: {dry_run}")

        if not self.mysql_conn:
            return {"error": "MySQL连接失败"}
        if not self.milvus_client:
            return {"error": "Milvus连接失败"}

        try:
            mysql_slice_ids, success = self.get_mysql_slices(knowledge_id, document_id)
            if not success:
                return {"error": "MySQL查询失败"}

            milvus_slice_ids = self.get_milvus_slices(knowledge_id, document_id)

            # In Milvus but not MySQL: dirty, candidates for deletion.
            dirty_slice_ids = milvus_slice_ids - mysql_slice_ids
            # In MySQL but not Milvus: missing, reported but not repaired here.
            missing_slice_ids = mysql_slice_ids - milvus_slice_ids

            result = {
                "knowledge_id": knowledge_id,
                "document_id": document_id,
                "mysql_slice_count": len(mysql_slice_ids),
                "milvus_slice_count": len(milvus_slice_ids),
                "dirty_slice_count": len(dirty_slice_ids),
                "missing_slice_count": len(missing_slice_ids),
                "dirty_slice_ids": list(dirty_slice_ids),
                "missing_slice_ids": list(missing_slice_ids),
                "dry_run": dry_run
            }

            logger.info(f"数据统计 - MySQL切片: {len(mysql_slice_ids)}, "
                        f"Milvus切片: {len(milvus_slice_ids)}, "
                        f"脏数据: {len(dirty_slice_ids)}, "
                        f"缺失数据: {len(missing_slice_ids)}")

            if dirty_slice_ids:
                if dry_run:
                    logger.info(f"试运行模式 - 发现 {len(dirty_slice_ids)} 个脏数据切片,不会实际删除")
                    # Sample at most 10 records so the report stays small.
                    dirty_details = self.get_milvus_slice_details(
                        knowledge_id, list(dirty_slice_ids)[:10])
                    result["dirty_slice_samples"] = dirty_details
                else:
                    logger.info(f"开始删除 {len(dirty_slice_ids)} 个脏数据切片")
                    success, message = self.delete_milvus_slices(knowledge_id, list(dirty_slice_ids))
                    result["delete_success"] = success
                    result["delete_message"] = message
            else:
                logger.info("没有发现脏数据,数据一致性良好")
                result["delete_success"] = True
                result["delete_message"] = "没有脏数据需要删除"

            if missing_slice_ids:
                logger.warning(f"发现 {len(missing_slice_ids)} 个切片在MySQL中存在但在Milvus中缺失")

            # Connections are deliberately left open; the caller owns their
            # lifetime via close_connections().
            return result

        except Exception as e:
            logger.error(f"数据清理过程中发生错误: {e}")
            return {"error": str(e)}

    def generate_report(self, result: Dict) -> str:
        """Render the result dict of ``cleanup_data`` as a readable report.

        Args:
            result: output of ``cleanup_data``.

        Returns:
            str: formatted report text (or a one-line failure message).
        """
        if "error" in result:
            return f"清理失败: {result['error']}"

        # BUG FIX: cleanup_data always sets the 'document_id' key (possibly to
        # None), so the original .get(..., '全部文档') default never fired and
        # "None" was printed; use `or` to substitute the placeholder.
        report = f"""
数据清理报告
{'='*50}
知识库ID: {result['knowledge_id']}
文档ID: {result.get('document_id') or '全部文档'}
执行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
试运行模式: {'是' if result['dry_run'] else '否'}
数据统计:
- MySQL中的切片数量: {result['mysql_slice_count']}
- Milvus中的切片数量: {result['milvus_slice_count']}
- 脏数据切片数量: {result['dirty_slice_count']}
- 缺失切片数量: {result['missing_slice_count']}
清理结果:
- 删除状态: {'成功' if result.get('delete_success', False) else '失败'}
- 删除消息: {result.get('delete_message', 'N/A')}
建议:
"""

        if result['dirty_slice_count'] > 0:
            report += f"- 发现 {result['dirty_slice_count']} 个脏数据切片"
            if result['dry_run']:
                report += ",建议在非试运行模式下执行清理\n"
            else:
                report += ",已执行清理\n"
        else:
            report += "- 数据一致性良好,无需清理\n"

        if result['missing_slice_count'] > 0:
            report += f"- 发现 {result['missing_slice_count']} 个切片在Milvus中缺失,建议检查数据同步\n"

        return report
def main():
    """CLI entry point: parse arguments, run cleanup, print and save a report.

    Dry-run is the default; data is only deleted when ``--execute`` is passed.
    """
    import argparse

    parser = argparse.ArgumentParser(description='数据清理脚本 - 清理向量库中的脏数据')
    parser.add_argument('knowledge_id', help='知识库ID')
    parser.add_argument('--document_id', help='文档ID(可选,不提供则清理整个知识库)')
    # Kept for interface compatibility; dry-run is always the default and is
    # only disabled by --execute below.
    parser.add_argument('--dry-run', action='store_true', default=True,
                        help='试运行模式(默认开启,不实际删除数据)')
    parser.add_argument('--execute', action='store_true',
                        help='实际执行删除操作(关闭运行模式)')

    args = parser.parse_args()

    # BUG FIX: the original read `dry_run = args.execute`, which inverted the
    # flag — `--execute` produced a harmless dry run while the DEFAULT
    # invocation actually deleted data. Dry-run must be on unless --execute.
    dry_run = not args.execute

    cleanup_tool = VectorDataCleaner()

    report = None
    try:
        result = cleanup_tool.cleanup_data(
            knowledge_id=args.knowledge_id,
            document_id=args.document_id,
            dry_run=dry_run
        )
        report = cleanup_tool.generate_report(result)
        print(report)
    finally:
        # Always release the MySQL connection, even if cleanup raised.
        cleanup_tool.close_connections()

    # Persist the report next to previous runs, stamped with the run time.
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    os.makedirs("./data_clean", exist_ok=True)
    report_filename = f"./data_clean/data_cleanup_report_{args.knowledge_id}_{timestamp}.txt"

    try:
        with open(report_filename, 'w', encoding='utf-8') as f:
            f.write(report)
        print(f"\n报告已保存到: {report_filename}")
    except Exception as e:
        logger.error(f"保存报告失败: {e}")


if __name__ == "__main__":
    main()
|