from datetime import datetime
from uuid import uuid1
import threading

from mysql.connector import pooling, Error

from rag.vector_db.milvus_vector import HybridRetriever
from response_info import generate_message, generate_response
from utils.get_logger import setup_logger
from config import milvus_uri, mysql_config

logger = setup_logger(__name__)

try:
    POOL = pooling.MySQLConnectionPool(
        pool_name="mysql_pool",
        pool_size=10,
        **mysql_config
    )
    logger.info("MySQL connection pool initialized successfully")
except Error as e:
    logger.error(f"Failed to initialize MySQL connection pool: {e}")
    POOL = None


class MilvusOperate:
    def __init__(self, collection_name: str = "default", embedding_name: str = "e5"):
        self.collection = collection_name
        self.hybrid_retriever = HybridRetriever(
            uri=milvus_uri,
            embedding_name=embedding_name,
            collection_name=collection_name,
        )
        self.mysql_client = MysqlOperate()

    def _has_collection(self):
        return self.hybrid_retriever.has_collection()

    def _create_collection(self):
        if self._has_collection():
            resp = {"code": 400, "message": "Collection already exists"}
        else:
            create_result = self.hybrid_retriever.build_collection()
            resp = generate_message(create_result)
        return resp

    def _delete_collection(self):
        delete_result = self.hybrid_retriever.delete_collection(self.collection)
        return generate_message(delete_result)

    def _put_by_id(self, slice_json):
        slice_id = slice_json.get("slice_id")
        slice_text = slice_json.get("slice_text")
        update_result, chunk_len = self.hybrid_retriever.update_data(chunk_id=slice_id, chunk=slice_text)
        if update_result.endswith("success"):
            # On success, also update the knowledge-base and document word
            # counts kept in MySQL.
            update_json = {
                "knowledge_id": slice_json.get("knowledge_id"),
                "doc_id": slice_json.get("document_id"),
                "chunk_len": chunk_len,
                "operate": "update",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            update_flag = False
        if not update_flag:
            update_result = "update_error"
        return generate_message(update_result)

    def _insert_slice(self, slice_json):
        slice_id = str(uuid1())
        knowledge_id = slice_json.get("knowledge_id")
        doc_id = slice_json.get("document_id")
        slice_text = slice_json.get("slice_text")
        doc_name = slice_json.get("doc_name")
        chunk_len = len(slice_text)
        metadata = {
            "content": slice_text,
            "doc_id": doc_id,
            "chunk_id": slice_id,
            "metadata": {"source": doc_name, "chunk_len": chunk_len},
        }
        insert_flag, insert_str = self.hybrid_retriever.insert_data(slice_text, metadata)
        if insert_flag:
            # On success, also update the knowledge-base and document word
            # counts kept in MySQL.
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": doc_id,
                "chunk_len": chunk_len,
                "operate": "insert",
                "chunk_id": slice_id,
                "chunk_text": slice_text,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error(f"Vector store insert failed: {insert_str}")
            update_flag = False
            update_str = "vector store write error"
        if not update_flag:
            logger.error(f"MySQL update failed while inserting slice: {update_str}")
            insert_result = "insert_error"
        else:
            insert_result = "insert_success"
        return generate_message(insert_result)

    def _delete_by_chunk_id(self, chunk_id, knowledge_id, document_id):
        logger.info(f"Deleting slice id: {chunk_id}")
        delete_result, delete_chunk_len = self.hybrid_retriever.delete_by_chunk_id(chunk_id=chunk_id)
        if delete_result.endswith("success"):
            chunk_len = delete_chunk_len[0]
            update_json = {
                "knowledge_id": knowledge_id,
                "doc_id": document_id,
                "chunk_len": -chunk_len,  # negative delta: shrink the word counts
                "operate": "delete",
                "chunk_id": chunk_id,
            }
            update_flag, update_str = self.mysql_client.update_total_doc_len(update_json)
        else:
            logger.error("Failed to delete from vector store by chunk id")
            update_flag = False
            update_str = "delete by chunk id failed"
        if not update_flag:
            logger.error(update_str)
            delete_result = "delete_error"
        return generate_message(delete_result)

    def _delete_by_doc_id(self, doc_id: str = None):
        logger.info(f"Deleting data for doc id: {doc_id}")
        delete_result = self.hybrid_retriever.delete_by_doc_id(doc_id=doc_id)
        return generate_message(delete_result)

    def _search_by_chunk_id(self, chunk_id):
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id(chunk_id=chunk_id)
        else:
            query_result = []
        logger.info(f"Slice lookup result: {query_result}")
        return generate_response(query_result)

    def _search_by_chunk_id_list(self, chunk_id_list):
        if self._has_collection():
            query_result = self.hybrid_retriever.query_chunk_id_list(chunk_id_list)
        else:
            query_result = []
        logger.info(f"Slice info for recalled chunk id list: {query_result}")
        return [chunk_dict.get("content") for chunk_dict in query_result]

    def _search_by_key_word(self, search_json):
        # Defaults are resolved up front so the empty-collection branch can
        # still paginate the (empty) response.
        page_num = search_json.get("pageNum", 1)
        page_size = search_json.get("pageSize", 10)
        if self._has_collection():
            doc_id = search_json.get("document_id")
            text = search_json.get("text")
            # Filtered query scoped to the given document.
            query_result = self.hybrid_retriever.query_filter(doc_id=doc_id, filter_field=text)
        else:
            query_result = []
        return generate_response(query_result, page_num, page_size)

    def _insert_data(self, docs):
        insert_flag, insert_info = True, "success"  # defaults for empty docs
        for doc in docs:
            chunk = doc.get("content")
            insert_flag, insert_info = self.hybrid_retriever.insert_data(chunk, doc)
            if not insert_flag:
                break
        return insert_flag, insert_info

    def _batch_insert_data(self, docs, text_lists):
        insert_flag, insert_info = self.hybrid_retriever.batch_insert_data(text_lists, docs)
        return insert_flag, insert_info

    def _search(self, query, k, mode):
        return self.hybrid_retriever.search(query, k, mode)
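
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes the
# "default" collection and a pre-existing knowledge/document pair; every id
# and name below is a placeholder. Defined as a function so importing this
# module stays side-effect free.
# ---------------------------------------------------------------------------
def _example_slice_roundtrip():
    op = MilvusOperate(collection_name="default", embedding_name="e5")
    # Create the collection on first use; returns a 400 message if it exists.
    print(op._create_collection())
    # Insert one slice; MySQL length bookkeeping happens inside _insert_slice.
    slice_json = {
        "knowledge_id": "kb-001",    # placeholder id
        "document_id": "doc-001",    # placeholder id
        "doc_name": "example.txt",   # placeholder name
        "slice_text": "an example chunk of text",
    }
    print(op._insert_slice(slice_json))
    # Keyword-filtered, paginated lookup scoped to the same document.
    print(op._search_by_key_word({"document_id": "doc-001", "text": "example",
                                  "pageNum": 1, "pageSize": 10}))
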
class MysqlOperate:
    def get_connection(self):
        """
        Check out a connection from the pool, giving up after 5 seconds.
        :return: (connection, "success") on success, (None, reason) otherwise
        """
        connection = None
        event = threading.Event()

        def target():
            nonlocal connection
            try:
                connection = POOL.get_connection()
            finally:
                event.set()

        # Run the checkout in a worker thread so a hung connection attempt
        # cannot block the caller for more than 5 seconds; daemon so a stuck
        # worker never prevents interpreter exit.
        thread = threading.Thread(target=target, daemon=True)
        thread.start()
        event.wait(timeout=5)
        if thread.is_alive():
            logger.error("Timed out acquiring a MySQL connection")
            return None, "timed out acquiring connection"
        if connection:
            return connection, "success"
        logger.error("Failed to acquire a MySQL connection")
        return None, "failed to acquire connection"
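
    # ------------------------------------------------------------------
    # Illustrative sketch only (not a method of the original class): shows
    # the checkout/return discipline every method below follows. The name
    # `example_ping` and the SELECT 1 probe are assumptions.
    # ------------------------------------------------------------------
    def example_ping(self):
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        cursor = None
        try:
            cursor = connection.cursor()
            cursor.execute("SELECT 1")  # cheap liveness probe
            cursor.fetchone()
            return True, "success"
        except Error as e:
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            connection.close()  # returns the connection to the pool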
    def insert_to_slice(self, docs, knowledge_id, doc_id):
        """
        Batch-insert slice rows into the slice_info table.
        """
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        values = []
        for chunk in docs:
            slice_id = chunk.get("chunk_id")
            slice_text = chunk.get("content")
            chunk_index = chunk.get("metadata").get("chunk_index")
            values.append((slice_id, knowledge_id, doc_id, slice_text, date_now, chunk_index))
        cursor = None
        try:
            cursor = connection.cursor()
            insert_sql = """
                INSERT INTO slice_info (
                    slice_id, knowledge_id, document_id, slice_text, create_time, slice_index
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """
            cursor.executemany(insert_sql, values)
            connection.commit()
            logger.info("Batch insert of slice rows succeeded.")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()

    def delete_to_slice(self, doc_id):
        """
        Delete a document's slice rows from the slice_info table.
        """
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        cursor = None
        try:
            cursor = connection.cursor()
            delete_sql = "DELETE FROM slice_info WHERE document_id = %s"
            cursor.execute(delete_sql, (doc_id,))
            connection.commit()
            logger.info("Delete succeeded")
            return True, "success"
        except Error as e:
            logger.error(f"Failed to delete slice rows for {doc_id}: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()

    def insert_to_image_url(self, image_dict, knowledge_id, doc_id):
        """
        Batch-insert image replacement rows into bm_media_replacement.
        """
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        values = [
            (knowledge_id, doc_id, origin_text, "image", media_url, date_now)
            for origin_text, media_url in image_dict.items()
        ]
        cursor = None
        try:
            cursor = connection.cursor()
            insert_sql = """
                INSERT INTO bm_media_replacement (
                    knowledge_id, document_id, origin_text, media_type, media_url, create_time
                ) VALUES (%s, %s, %s, %s, %s, %s)
            """
            cursor.executemany(insert_sql, values)
            connection.commit()
            logger.info("Insert into bm_media_replacement succeeded")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()

    def delete_image_url(self, doc_id):
        """
        Delete a document's rows from bm_media_replacement by doc id.
        """
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        cursor = None
        try:
            cursor = connection.cursor()
            delete_sql = "DELETE FROM bm_media_replacement WHERE document_id = %s"
            cursor.execute(delete_sql, (doc_id,))
            connection.commit()
            logger.info(f"Deleted bm_media_replacement rows for {doc_id}")
            return True, "success"
        except Error as e:
            logger.error(f"Failed to delete bm_media_replacement rows for {doc_id}: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
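
    # ------------------------------------------------------------------
    # Illustrative sketch only (not part of the original class): the three
    # payload shapes update_total_doc_len() below accepts, as built by its
    # callers in MilvusOperate. All ids and lengths are placeholders; the
    # module itself never calls this method.
    # ------------------------------------------------------------------
    def example_length_updates(self):
        base = {"knowledge_id": "kb-001", "doc_id": "doc-001", "chunk_id": "chunk-001"}
        # "insert": positive delta plus the slice text to append to slice_info.
        self.update_total_doc_len({**base, "operate": "insert", "chunk_len": 42,
                                   "chunk_text": "new slice text"})
        # "update": delta reported by the vector store plus the new text.
        self.update_total_doc_len({**base, "operate": "update", "chunk_len": 5,
                                   "chunk_text": "edited slice text"})
        # "delete": negative delta; the slice row is removed, no text needed.
        self.update_total_doc_len({**base, "operate": "delete", "chunk_len": -42})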
    def update_total_doc_len(self, update_json):
        """
        Keep the word counts in bm_document and bm_knowledge in sync with a
        slice change, and mirror that change in slice_info: "update" rewrites
        the slice text, "insert" appends a row with the next slice_index, and
        "delete" removes the row and decrements the document's slice_total.
        """
        knowledge_id = update_json.get("knowledge_id")
        doc_id = update_json.get("doc_id")
        chunk_len = update_json.get("chunk_len")
        operate = update_json.get("operate")
        chunk_id = update_json.get("chunk_id")
        chunk_text = update_json.get("chunk_text")
        connection, connection_info = self.get_connection()
        if not connection:
            return False, connection_info
        cursor = None
        try:
            cursor = connection.cursor()
            query_doc_word_num_sql = "SELECT word_num, slice_total FROM bm_document WHERE document_id = %s"
            query_knowledge_word_num_sql = "SELECT word_num FROM bm_knowledge WHERE knowledge_id = %s"
            cursor.execute(query_doc_word_num_sql, (doc_id,))
            doc_result = cursor.fetchone()
            logger.info(f"Document length row: {doc_result}")
            cursor.execute(query_knowledge_word_num_sql, (knowledge_id,))
            knowledge_result = cursor.fetchone()
            logger.info(f"Knowledge-base length row: {knowledge_result}")
            if not doc_result:
                new_word_num = 0
                slice_total = 0
            else:
                old_word_num, slice_total = doc_result
                new_word_num = old_word_num + chunk_len
                # Decrement only if nonzero; written back in the delete branch.
                slice_total -= 1 if slice_total else 0
            if not knowledge_result:
                new_knowledge_word_num = 0
            else:
                new_knowledge_word_num = knowledge_result[0] + chunk_len
            if operate == "update":
                update_sql = "UPDATE bm_document SET word_num = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, doc_id))
                date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                update_slice_sql = "UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s"
                cursor.execute(update_slice_sql, (chunk_text, date_now, chunk_id))
            elif operate == "insert":
                query_slice_info_index_sql = "SELECT MAX(slice_index) FROM slice_info WHERE document_id = %s"
                cursor.execute(query_slice_info_index_sql, (doc_id,))
                chunk_index_result = cursor.fetchone()[0]
                chunk_max_index = int(chunk_index_result) if chunk_index_result else 0
                update_sql = "UPDATE bm_document SET word_num = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, doc_id))
                date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
                insert_slice_sql = (
                    "INSERT INTO slice_info (slice_id, knowledge_id, document_id, slice_text, create_time, slice_index) "
                    "VALUES (%s, %s, %s, %s, %s, %s)"
                )
                cursor.execute(insert_slice_sql, (chunk_id, knowledge_id, doc_id, chunk_text, date_now, chunk_max_index + 1))
            else:
                # delete: shrink word_num and slice_total, then drop the slice row.
                update_sql = "UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s"
                cursor.execute(update_sql, (new_word_num, slice_total, doc_id))
                delete_slice_sql = "DELETE FROM slice_info WHERE slice_id = %s"
                cursor.execute(delete_slice_sql, (chunk_id,))
            update_knowledge_sql = "UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s"
            cursor.execute(update_knowledge_sql, (new_knowledge_word_num, knowledge_id))
            connection.commit()
            logger.info("bm_document and bm_knowledge updated successfully")
            return True, "success"
        except Error as e:
            logger.error(f"Database error: {e}")
            connection.rollback()
            return False, str(e)
        finally:
            if cursor:
                cursor.close()
            if connection and connection.is_connected():
                connection.close()
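
# Minimal manual smoke test (assumptions: a Milvus instance reachable at
# `milvus_uri` and a MySQL server matching `mysql_config` from config.py).
# Guarded so importing the module never talks to either service.
if __name__ == "__main__":
    _example_slice_roundtrip()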