from rag.vector_db.milvus_vector import HybridRetriever
from response_info import generate_message, generate_response
from utils.get_logger import setup_logger
from datetime import datetime
from uuid import uuid1
import mysql.connector
from mysql.connector import pooling, Error, errors
import threading
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from config import milvus_uri, mysql_config

logger = setup_logger(__name__)

# NOTE(review): a commented-out module-level pool bootstrap
# (POOL = pooling.MySQLConnectionPool(...)) previously lived here; the live
# pool is initialised further down via SafeMySQLPool. Recover the legacy
# version from version control if ever needed.


class MilvusOperate:
    """Facade that keeps a Milvus vector collection and the MySQL bookkeeping
    tables (word counts in bm_document / bm_knowledge, rows in slice_info)
    in sync for slice-level CRUD operations.

    Parameters:
        collection_name: Milvus collection this instance operates on.
        embedding_name:  embedding model key forwarded to HybridRetriever.
    """

    def __init__(self, collection_name: str = "default", embedding_name: str = "e5"):
        self.collection = collection_name
        self.hybrid_retriever = HybridRetriever(
            uri=milvus_uri,
            embedding_name=embedding_name,
            collection_name=collection_name,
        )
        self.mysql_client = MysqlOperate()

    def _has_collection(self) -> bool:
        """Return True if the configured collection exists in Milvus."""
        return self.hybrid_retriever.has_collection()

    def _create_collection(self):
        """Create the collection; refuse with code 400 when it already exists."""
        if self._has_collection():
            return {"code": 400, "message": "数据库已存在"}
        create_result = self.hybrid_retriever.build_collection()
        return generate_message(create_result)

    def _delete_collection(self):
        """Drop the whole collection and wrap the outcome in a response dict."""
        delete_result = self.hybrid_retriever.delete_collection(self.collection)
        return generate_message(delete_result)

    def _put_by_id(self, slice_json):
        """Update one slice's text in Milvus, then mirror the length change in MySQL.

        Expected slice_json keys: slice_id, slice_text, knowledge_id, document_id.
        Returns a generate_message() response dict.
        """
        slice_id = slice_json.get("slice_id", None)
        slice_text = slice_json.get("slice_text", None)
        update_result, chunk_len = self.hybrid_retriever.update_data(
            chunk_id=slice_id, chunk=slice_text
        )
        if update_result.endswith("success"):
            # Vector update succeeded: propagate the new chunk length to the
            # knowledge/document word-count tables and refresh slice_info.
            update_json = {
                "knowledge_id": slice_json.get("knowledge_id"),
                "doc_id": slice_json.get("document_id"),
                "chunk_len": chunk_len,
                "operate": "update",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            update_flag = False
        if not update_flag:
            update_result = "update_error"
        return generate_message(update_result)

    def _insert_slice(self, slice_json):
        """Insert a brand-new slice: write the vector first, then MySQL bookkeeping.

        Expected slice_json keys: knowledge_id, document_id, slice_text, doc_name.
        Returns a generate_message() response dict.
        """
        slice_id = str(uuid1())
        knowledge_id = slice_json.get("knowledge_id")
        doc_id = slice_json.get("document_id")
        slice_text = slice_json.get("slice_text", None)
        doc_name = slice_json.get("doc_name")
        chunk_len = len(slice_text)
        metadata = {
            "content": slice_text,
            "doc_id": doc_id,
            "chunk_id": slice_id,
            "metadata": {"source": doc_name, "chunk_len": chunk_len},
        }
        insert_flag, insert_str = self.hybrid_retriever.insert_data(slice_text, metadata)
        if insert_flag:
            # Vector insert succeeded: update knowledge/document word counts
            # and create the slice_info row.
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": doc_id,
                "chunk_len": chunk_len,
                "operate": "insert",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error(f"插入向量库出错:{insert_str}")
            update_flag = False
            update_str = "向量库写入出错"
        if not update_flag:
            logger.error(f"新增切片中mysql数据库出错:{update_str}")
            insert_result = "insert_error"
        else:
            insert_result = "insert_success"
        return generate_message(insert_result)

    def _delete_by_chunk_id(self, chunk_id, knowledge_id, document_id):
        """Delete one slice from Milvus, then subtract its length in MySQL."""
        logger.info(f"删除的切片id:{chunk_id}")
        delete_result, delete_chunk_len = self.hybrid_retriever.delete_by_chunk_id(
            chunk_id=chunk_id
        )
        if delete_result.endswith("success"):
            # Negative chunk_len so update_total_doc_len subtracts the length.
            chunk_len = delete_chunk_len[0]
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": document_id,
                "chunk_len": -chunk_len,
                "operate": "delete",
                "chunk_id": chunk_id,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error("根据chunk id删除向量库失败")
            update_flag = False
            update_str = "根据chunk id删除失败"
        if not update_flag:
            logger.error(update_str)
            delete_result = "delete_error"
        return generate_message(delete_result)

    def _delete_by_doc_id(self, doc_id: str = None):
        """Delete every slice belonging to one document from Milvus."""
        logger.info(f"删除数据的id:{doc_id}")
        delete_result = self.hybrid_retriever.delete_by_doc_id(doc_id=doc_id)
        return generate_message(delete_result)

    def _search_by_chunk_id(self, chunk_id):
        """Look up a single slice by its chunk id; empty result if no collection."""
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id(chunk_id=chunk_id)
        else:
            query_result = []
        logger.info(f"根据切片查询到的信息:{query_result}")
        return generate_response(query_result)

    def _search_by_chunk_id_list(self, chunk_id_list):
        """Resolve a list of chunk ids to their text contents (list of str)."""
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id_list(chunk_id_list)
        else:
            query_result = []
        logger.info(f"召回的切片列表查询切片信息:{query_result}")
        return [chunk_dict.get("content") for chunk_dict in query_result]

    def _search_by_key_word(self, search_json):
        """Keyword-filtered slice search with pagination.

        Expected search_json keys: document_id, text, pageNum (default 1),
        pageSize (default 10).
        """
        page_num = search_json.get("pageNum", 1)
        page_size = search_json.get("pageSize", 10)
        if self._has_collection():
            doc_id = search_json.get("document_id", None)
            text = search_json.get("text", None)
            # BUGFIX: the original re-read pageNum without a default right after
            # defaulting it, so a missing key passed None to generate_response.
            query_result = self.hybrid_retriever.query_filter(doc_id=doc_id, filter_field=text)
        else:
            query_result = []
        return generate_response(query_result, page_num, page_size)

    def _insert_data(self, docs):
        """Insert docs one by one, stopping at the first failure.

        Returns (flag, info). BUGFIX: the original left insert_flag/insert_info
        unbound when docs was empty, raising NameError; an empty batch now
        reports success.
        """
        insert_flag, insert_info = True, "success"
        for doc in docs:
            chunk = doc.get("content")
            insert_flag, insert_info = self.hybrid_retriever.insert_data(chunk, doc)
            if not insert_flag:
                break
        return insert_flag, insert_info

    def _batch_insert_data(self, docs, text_lists):
        """Bulk-insert texts and their metadata in one retriever call."""
        insert_flag, insert_info = self.hybrid_retriever.batch_insert_data(text_lists, docs)
        return insert_flag, insert_info

    def _search(self, query, k, mode):
        """Run a top-k retrieval in the given mode and return the raw result."""
        return self.hybrid_retriever.search(query, k, mode)
""" # 从连接池中获取一个连接 # :return: 数据库连接对象 # """ # # try: # # with ThreadPoolExecutor() as executor: # # future = executor.submit(POOL.get_connection) # # connection = future.result(timeout=5.0) # 设置超时时间为5秒 # # logger.info("成功从连接池获取连接") # # return connection, "success" # # except TimeoutError: # # logger.error("获取mysql数据库连接池超时") # # return None, "mysql获取连接池超时" # # except errors.InterfaceError as e: # # logger.error(f"MySQL 接口异常:{e}") # # return None, "mysql接口异常" # # except errors.OperationalError as e: # # logger.error(f"MySQL 操作错误:{e}") # # return None, "mysql 操作错误" # # except Error as e: # # logger.error(f"无法从连接池获取连接: {e}") # # return None, str(e) # connection = None # event = threading.Event() # def target(): # nonlocal connection # try: # connection = POOL.get_connection() # finally: # event.set() # thread = threading.Thread(target=target) # thread.start() # event.wait(timeout=5) # if thread.is_alive(): # # 超时处理 # logger.error("获取连接超时") # return None, "获取连接超时" # else: # if connection: # return connection, "success" # else: # logger.error("获取连接失败") # return None, "获取连接失败" # def insert_to_slice(self, docs, knowledge_id, doc_id): # """ # 插入数据到切片信息表中 slice_info # """ # connection = None # cursor = None # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # values = [] # connection, cennction_info = self.get_connection() # if not connection: # return False, cennction_info # for chunk in docs: # slice_id = chunk.get("chunk_id") # slice_text = chunk.get("content") # chunk_index = chunk.get("metadata").get("chunk_index") # values.append((slice_id, knowledge_id, doc_id, slice_text, date_now, chunk_index)) # try: # cursor = connection.cursor() # # insert_sql = """ # # INSERT INTO slice_info ( # # slice_id, # # knowledge_id, # # document_id, # # slice_text, # # create_time, # # slice_index # # ) VALUES (%s, %s, %s, %s, %s,%s) # # """ # # 容错“for key 'UK_ID_TYPE_KEY'” # insert_sql = """ # INSERT INTO slice_info ( # slice_id, # knowledge_id, # document_id, # slice_text, # 
create_time, # slice_index # ) VALUES (%s, %s, %s, %s, %s, %s) # ON DUPLICATE KEY UPDATE # slice_text = VALUES(slice_text), # create_time = VALUES(create_time), # slice_index = VALUES(slice_index) # """ # cursor.executemany(insert_sql, values) # connection.commit() # logger.info(f"批量插入切片数据成功。") # return True, "success" # except Error as e: # logger.error(f"数据库操作出错:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() # def delete_to_slice(self, doc_id): # """ # 删除 slice_info库中切片信息 # """ # connection = None # cursor = None # connection, connection_info = self.get_connection() # if not connection: # return False, connection_info # try: # cursor = connection.cursor() # delete_sql = f"DELETE FROM slice_info WHERE document_id = %s" # cursor.execute(delete_sql, (doc_id,)) # connection.commit() # logger.info(f"删除数据成功") # return True, "success" # except Error as e: # logger.error(f"根据{doc_id}删除数据失败:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() # def insert_to_image_url(self, image_dict, knowledge_id, doc_id): # """ # 批量插入数据到指定表 # """ # connection = None # cursor = None # connection, connection_info = self.get_connection() # if not connection: # return False, connection_info # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # values = [] # for img_key, img_value in image_dict.items(): # origin_text = img_key # media_url = img_value # values.append((knowledge_id, doc_id, origin_text, "image", media_url, date_now)) # try: # cursor = connection.cursor() # # insert_sql = """ # # INSERT INTO bm_media_replacement ( # # knowledge_id, # # document_id, # # origin_text, # # media_type, # # media_url, # # create_time # # ) VALUES (%s, %s, %s, %s, %s, %s) # # """ # # 容错“for key 'UK_ID_TYPE_KEY'” # insert_sql = """ # INSERT INTO bm_media_replacement ( 
# knowledge_id, # document_id, # origin_text, # media_type, # media_url, # create_time # ) VALUES (%s, %s, %s, %s, %s, %s) # ON DUPLICATE KEY UPDATE # origin_text = VALUES(origin_text), # media_type = VALUES(media_type), # media_url = VALUES(media_url), # create_time = VALUES(create_time) # """ # cursor.executemany(insert_sql, values) # connection.commit() # logger.info(f"插入到bm_media_replacement表成功") # return True, "success" # except Error as e: # logger.error(f"数据库操作出错:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() # def delete_image_url(self, doc_id): # """ # 根据doc id删除bm_media_replacement中的数据 # """ # connection = None # cursor = None # connection, connection_info = self.get_connection() # if not connection: # return False, connection_info # try: # cursor = connection.cursor() # delete_sql = f"DELETE FROM bm_media_replacement WHERE document_id = %s" # cursor.execute(delete_sql, (doc_id,)) # connection.commit() # logger.info(f"根据{doc_id} 删除bm_media_replacement表中数据成功") # return True, "success" # except Error as e: # logger.error(f"根据{doc_id}删除 bm_media_replacement 数据库操作出错:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() # def update_total_doc_len(self, update_json): # """ # 更新长度表和文档长度表,删除slice info表, 插入slice info 切片信息 # """ # knowledge_id = update_json.get("knowledge_id") # doc_id = update_json.get("doc_id") # chunk_len = update_json.get("chunk_len") # operate = update_json.get("operate") # chunk_id = update_json.get("chunk_id") # chunk_text = update_json.get("chunk_text") # connection = None # cursor = None # connection, connection_info = self.get_connection() # if not connection: # return False, connection_info # try: # cursor = connection.cursor() # query_doc_word_num_sql = f"select word_num,slice_total from bm_document where 
document_id = %s" # query_knowledge_word_num_sql = f"select word_num from bm_knowledge where knowledge_id = %s" # cursor.execute(query_doc_word_num_sql, (doc_id,)) # doc_result = cursor.fetchone() # logger.info(f"查询到的文档长度信息:{doc_result}") # cursor.execute(query_knowledge_word_num_sql, (knowledge_id, )) # knowledge_result = cursor.fetchone() # logger.info(f"查询到的知识库总长度信息:{knowledge_result}") # if not doc_result: # new_word_num = 0 # slice_total = 0 # else: # old_word_num = doc_result[0] # slice_total = doc_result[1] # new_word_num = old_word_num + chunk_len # slice_total -= 1 if slice_total else 0 # if not knowledge_result: # new_knowledge_word_num = 0 # else: # old_knowledge_word_num = knowledge_result[0] # new_knowledge_word_num = old_knowledge_word_num + chunk_len # if operate == "update": # update_sql = f"UPDATE bm_document SET word_num = %s WHERE document_id = %s" # cursor.execute(update_sql, (new_word_num, doc_id)) # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # update_slice_sql = f"UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s" # cursor.execute(update_slice_sql, (chunk_text, date_now, chunk_id)) # elif operate == "insert": # query_slice_info_index_sql = f"select MAX(slice_index) from slice_info where document_id = %s" # cursor.execute(query_slice_info_index_sql, (doc_id,)) # chunk_index_result = cursor.fetchone()[0] # # logger.info(chunk_index_result) # if chunk_index_result: # chunk_max_index = int(chunk_index_result) # else: # chunk_max_index = 0 # update_sql = f"UPDATE bm_document SET word_num = %s WHERE document_id = %s" # cursor.execute(update_sql, (new_word_num, doc_id)) # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # insert_slice_sql = "INSERT INTO slice_info (slice_id,knowledge_id,document_id,slice_text,create_time, slice_index) VALUES (%s, %s, %s, %s, %s, %s)" # cursor.execute(insert_slice_sql, (chunk_id, knowledge_id, doc_id, chunk_text, date_now, chunk_max_index+1)) # else: # update_sql = 
f"UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s" # cursor.execute(update_sql, (new_word_num, slice_total, doc_id)) # # 删除切片id对应的切片 # delete_slice_sql = f"DELETE FROM slice_info where slice_id = %s" # cursor.execute(delete_slice_sql, (chunk_id, )) # update_knowledge_sql = f"UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s" # cursor.execute(update_knowledge_sql, (new_knowledge_word_num, knowledge_id)) # connection.commit() # logger.info("bm_document和bm_knowledge数据更新成功") # return True, "success" # except Error as e: # logger.error(f"数据库操作出错:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() import time # ========== 全局初始化连接池(自动检测 + 超时保护) ========== class SafeMySQLPool: def __init__(self, pool_size=10, conn_timeout=10, idle_timeout=60, **mysql_config): mysql_config.setdefault("connect_timeout", conn_timeout) mysql_config.setdefault("pool_reset_session", True) self._pool = pooling.MySQLConnectionPool( pool_name="safe_mysql_pool", pool_size=pool_size, **mysql_config ) self._lock = threading.Lock() self._active_conns = {} # {id(conn): (conn, last_used_time)} self._idle_timeout = idle_timeout self._stop_event = threading.Event() threading.Thread(target=self._auto_reclaimer, daemon=True).start() def get_connection(self, timeout=10): """安全获取连接(带超时检测与追踪)""" start = time.time() while True: try: conn = self._pool.get_connection() with self._lock: self._active_conns[id(conn)] = (conn, time.time()) return self._wrap_connection(conn) except errors.PoolError: if time.time() - start > timeout: raise TimeoutError(f"获取 MySQL 连接超时(超过 {timeout}s)") time.sleep(0.3) def _wrap_connection(self, conn): """包装连接对象以监控关闭事件""" pool = self orig_close = conn.close def safe_close(): try: orig_close() finally: with pool._lock: pool._active_conns.pop(id(conn), None) conn.close = safe_close return conn def _auto_reclaimer(self): 
"""后台线程自动回收超时未关闭连接""" while not self._stop_event.is_set(): time.sleep(5) now = time.time() with self._lock: to_remove = [] for cid, (conn, last_used) in list(self._active_conns.items()): if now - last_used > self._idle_timeout: try: conn.close() print(f"[AutoReclaim] 已回收超时未关闭连接 (idle={int(now - last_used)}s)") except Exception as e: print(f"[AutoReclaim] 回收连接失败: {e}") to_remove.append(cid) for cid in to_remove: self._active_conns.pop(cid, None) def close_all(self): """停止守护线程并关闭所有连接""" self._stop_event.set() with self._lock: for conn, _ in self._active_conns.values(): try: conn.close() except: pass self._active_conns.clear() # ========== 初始化连接池 ========== if "POOL" not in globals(): try: POOL = SafeMySQLPool(pool_size=10, idle_timeout=60, **mysql_config) logger.info("MySQL 连接池初始化成功") except Error as e: logger.error(f"MySQL 连接池初始化失败: {e}") POOL = None # ========== MysqlOperate 类 ========== class MysqlOperate: def get_connection(self): """安全获取连接""" if not POOL: return None, "连接池未初始化" try: connection = POOL.get_connection(timeout=5) return connection, "success" except TimeoutError as e: logger.error(str(e)) return None, "获取连接超时" except Error as e: logger.error(f"MySQL 获取连接失败: {e}") return None, str(e) def _execute_many(self, sql, values, success_msg, err_msg): """通用批量执行模板""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() cursor.executemany(sql, values) connection.commit() logger.info(success_msg) return True, "success" except Error as e: connection.rollback() logger.error(f"{err_msg}: {e}") return False, str(e) finally: if cursor: cursor.close() connection.close() def insert_to_slice(self, docs, knowledge_id, doc_id): """批量插入切片信息""" date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') values = [ ( chunk.get("chunk_id"), knowledge_id, doc_id, chunk.get("content"), date_now, chunk.get("metadata", {}).get("chunk_index") ) for chunk in docs ] sql = """ INSERT INTO slice_info ( slice_id, 
knowledge_id, document_id, slice_text, create_time, slice_index ) VALUES (%s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE slice_text = VALUES(slice_text), create_time = VALUES(create_time), slice_index = VALUES(slice_index) """ return self._execute_many(sql, values, "批量插入切片数据成功", "插入 slice_info 出错") def delete_to_slice(self, doc_id): """删除切片""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() cursor.execute("DELETE FROM slice_info WHERE document_id = %s", (doc_id,)) connection.commit() logger.info(f"删除 slice_info 数据成功") return True, "success" except Error as e: connection.rollback() logger.error(f"删除 slice_info 出错: {e}") return False, str(e) finally: if cursor: cursor.close() connection.close() def insert_to_image_url(self, image_dict, knowledge_id, doc_id): """插入图片映射表""" date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') values = [ (knowledge_id, doc_id, k, "image", v, date_now) for k, v in image_dict.items() ] sql = """ INSERT INTO bm_media_replacement ( knowledge_id, document_id, origin_text, media_type, media_url, create_time ) VALUES (%s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE origin_text = VALUES(origin_text), media_type = VALUES(media_type), media_url = VALUES(media_url), create_time = VALUES(create_time) """ return self._execute_many(sql, values, "插入 bm_media_replacement 成功", "插入 bm_media_replacement 出错") def delete_image_url(self, doc_id): """删除图片映射""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() cursor.execute("DELETE FROM bm_media_replacement WHERE document_id = %s", (doc_id,)) connection.commit() logger.info(f"删除 bm_media_replacement 成功") return True, "success" except Error as e: connection.rollback() logger.error(f"删除 bm_media_replacement 出错: {e}") return False, str(e) finally: if cursor: cursor.close() connection.close() def delete_document(self, task_id): """删除 bm_document 
表中的记录(取消任务前清理)""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() sql = "DELETE FROM bm_document WHERE document_id = %s" cursor.execute(sql, (task_id,)) affected_rows = cursor.rowcount connection.commit() logger.info(f"删除 bm_document 记录成功: task_id={task_id}, 影响行数={affected_rows}") return True, affected_rows except Error as e: connection.rollback() logger.error(f"删除 bm_document 记录失败: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close()