from rag.vector_db.milvus_vector import HybridRetriever from response_info import generate_message, generate_response from utils.get_logger import setup_logger from datetime import datetime from uuid import uuid1 import mysql.connector from mysql.connector import pooling, Error, errors import threading from concurrent.futures import ThreadPoolExecutor, TimeoutError from config import milvus_uri, mysql_config logger = setup_logger(__name__) # uri = "http://localhost:19530" # if 'POOL' not in globals(): # try: # POOL = pooling.MySQLConnectionPool( # pool_name="mysql_pool", # pool_size=10, # **mysql_config # ) # logger.info("MySQL 连接池初始化成功") # except Error as e: # logger.info(f"初始化 MySQL 连接池失败: {e}") # POOL = None class MilvusOperate: def __init__(self, collection_name: str = "default", embedding_name:str = "e5"): self.collection = collection_name self.hybrid_retriever = HybridRetriever(uri=milvus_uri, embedding_name=embedding_name, collection_name=collection_name) self.mysql_client = MysqlOperate() def _has_collection(self): is_collection = self.hybrid_retriever.has_collection() return is_collection def _create_collection(self): if self._has_collection(): resp = {"code": 400, "message": "数据库已存在"} else: create_result = self.hybrid_retriever.build_collection() resp = generate_message(create_result) return resp def _delete_collection(self): delete_result = self.hybrid_retriever.delete_collection(self.collection) resp = generate_message(delete_result) return resp def _put_by_id(self, slice_json): slice_id = slice_json.get("slice_id", None) slice_text = slice_json.get("slice_text", None) update_result, chunk_len = self.hybrid_retriever.update_data(chunk_id=slice_id, chunk=slice_text) if update_result.endswith("success"): # 如果成功,更新mysql中知识库总长度和文档长度 update_json = {} update_json["knowledge_id"] = slice_json.get("knowledge_id") update_json["doc_id"] = slice_json.get("document_id") update_json["chunk_len"] = chunk_len update_json["operate"] = "update" update_json["chunk_id"] = slice_id update_json["chunk_text"] = slice_text update_flag, update_str = self.mysql_client.update_total_doc_len(update_json) else: update_flag = False if not update_flag: update_result = "update_error" resp = generate_message(update_result) return resp def _insert_slice(self, slice_json): slice_id = str(uuid1()) knowledge_id = slice_json.get("knowledge_id") doc_id = slice_json.get("document_id") slice_text = slice_json.get("slice_text", None) doc_name = slice_json.get("doc_name") chunk_len = len(slice_text) metadata = { "content": slice_text, "doc_id": doc_id, "chunk_id": slice_id, "metadata": {"source": doc_name, "chunk_len": chunk_len}, "Chapter": slice_json.get("Chapter",""), "Father_Chapter": slice_json.get("Father_Chapter",""), } insert_flag, insert_str = self.hybrid_retriever.insert_data(slice_text, metadata) if insert_flag: # 如果成功,更新mysql中知识库总长度和文档长度 update_json = {} update_json["knowledge_id"] = slice_json.get("knowledge_id") update_json["doc_id"] = slice_json.get("document_id") update_json["chunk_len"] = chunk_len update_json["operate"] = "insert" update_json["chunk_id"] = slice_id update_json["chunk_text"] = slice_text update_json["slice_index"] = slice_json.get("slice_index") update_flag, update_str = self.mysql_client.update_total_doc_len(update_json) else: logger.error(f"插入向量库出错:{insert_str}") update_flag = False update_str = "向量库写入出错" # pass if not update_flag: logger.error(f"新增切片中mysql数据库出错:{update_str}") # insert_result = "insert_error" success = "insert_error" else: # insert_result = "insert_success" success = "insert_success" # resp = generate_message(insert_result) resp = {"status": success, "slice_id":slice_id} return resp def _delete_by_chunk_id(self, chunk_id, knowledge_id, document_id): logger.info(f"删除的切片id:{chunk_id}") delete_result, delete_chunk_len = self.hybrid_retriever.delete_by_chunk_id(chunk_id=chunk_id) if delete_result.endswith("success"): chunk_len = delete_chunk_len[0] update_json = { "knowledge_id": knowledge_id, "doc_id": document_id, "chunk_len": -chunk_len, "operate": "delete", "chunk_id": chunk_id } update_flag, update_str = self.mysql_client.update_total_doc_len(update_json) else: logger.error("根据chunk id删除向量库失败") update_flag = False update_str = "根据chunk id删除失败" if not update_flag: logger.error(update_str) delete_result = "delete_error" resp = generate_message(delete_result) return resp def _delete_by_doc_id(self, doc_id: str = None): logger.info(f"删除数据的id:{doc_id}") delete_result = self.hybrid_retriever.delete_by_doc_id(doc_id=doc_id) resp = generate_message(delete_result) return resp def _search_by_chunk_id(self, chunk_id): if self._has_collection(): query_result = self.hybrid_retriever.query_chunk_id(chunk_id=chunk_id) else: query_result = [] logger.info(f"根据切片查询到的信息:{query_result}") resp = generate_response(query_result) return resp def _search_by_chunk_id_list(self, chunk_id_list): if self._has_collection(): query_result = self.hybrid_retriever.query_chunk_id_list(chunk_id_list) else: query_result = [] logger.info(f"召回的切片列表查询切片信息:{query_result}") chunk_content_list = [] for chunk_dict in query_result: chunk_content = chunk_dict.get("content") chunk_content_list.append(chunk_content) return chunk_content_list def _search_by_key_word(self, search_json): if self._has_collection(): doc_id = search_json.get("document_id", None) text = search_json.get("text", None) page_num = search_json.get("pageNum", 1) page_size = search_json.get("pageSize", 10) page_num = search_json.get("pageNum") # 根据传过来的id处理对应知识库 query_result = self.hybrid_retriever.query_filter(doc_id=doc_id, filter_field=text) else: query_result = [] resp = generate_response(query_result,page_num,page_size) return resp def _insert_data(self, docs): insert_flag = "" insert_info = "" for doc in docs: chunk = doc.get("content") insert_flag, insert_info = self.hybrid_retriever.insert_data(chunk, doc) if not insert_flag: break resp = insert_flag if insert_flag else "insert_error" return resp, insert_info def _batch_insert_data(self, docs, text_lists): insert_flag, insert_info = self.hybrid_retriever.batch_insert_data(text_lists, docs) resp = insert_flag return resp, insert_info def _search(self, query, k, mode): search_result = self.hybrid_retriever.search(query, k, mode) return search_result def _query_by_scalar_field(self, doc_id: str, field_name: str, field_value: str): """ 根据标量字段查询数据 参数: doc_id: 文档ID field_name: 字段名(如 Father_Chapter) field_value: 字段值 返回: 查询结果列表 """ return self.hybrid_retriever.query_by_scalar_field(doc_id, field_name, field_value) def _copy_docs_to_new_collection(self, new_collection_name, doc_ids, embedding_name="e5"): """ 将指定的文档数据复制到新集合或现有集合 使用雪花算法生成新的 doc_id 和 chunk_id 参数: new_collection_name: 目标集合名称(也是新的 knowledge_id) doc_ids: 要复制的文档ID列表 embedding_name: 向量模型名称 返回: 响应字典,包含 doc_id_mapping 映射关系 """ try: # 1. 从源集合查询数据 logger.info(f"从集合 {self.collection} 查询文档: {doc_ids}") query_results = self.hybrid_retriever.query_by_doc_ids(doc_ids) if not query_results: return {"code": 404, "message": "未找到匹配的文档数据", "doc_id_mapping": {}, "chunk_id_mapping": {}} logger.info(f"查询到 {len(query_results)} 条数据") # 2. 检查目标集合是否存在,不存在则创建 target_milvus_client = MilvusOperate(collection_name=new_collection_name, embedding_name=embedding_name) collection_exists = target_milvus_client._has_collection() if not collection_exists: logger.info(f"创建新集合: {new_collection_name}") create_result = target_milvus_client._create_collection() if create_result.get("code") != 200: create_result["doc_id_mapping"] = {} return create_result else: logger.info(f"集合 {new_collection_name} 已存在,直接插入数据") # 3. 为每个源 doc_id 生成新的 doc_id(使用雪花算法) doc_id_mapping = {} # {old_doc_id: new_doc_id} for old_doc_id in doc_ids: doc_id_mapping[old_doc_id] = generate_snowflake_id() # 4. 准备插入数据(移除pk字段,使用新的 doc_id 和 chunk_id) insert_data = [] chunk_id_mapping = {} # {old_chunk_id: new_chunk_id} for item in query_results: old_doc_id = item.get("doc_id") old_chunk_id = item.get("chunk_id") new_doc_id = doc_id_mapping.get(old_doc_id, generate_snowflake_id()) new_chunk_id = generate_snowflake_id() # 每个切片生成新的 chunk_id chunk_id_mapping[old_chunk_id] = new_chunk_id # 记录映射 new_item = { "content": item.get("content"), "dense_vector": item.get("dense_vector"), "doc_id": new_doc_id, "chunk_id": new_chunk_id, "Father_Chapter": item.get("Father_Chapter"), "Chapter": item.get("Chapter"), "metadata": item.get("metadata") } insert_data.append(new_item) # 5. 批量插入数据到目标集合 logger.info(f"开始向集合 {new_collection_name} 插入 {len(insert_data)} 条数据") try: target_milvus_client.hybrid_retriever.client.insert( collection_name=new_collection_name, data=insert_data ) logger.info(f"成功向集合插入数据") return { "code": 200, "message": "复制成功", "data": { "source_collection": self.collection, "target_collection": new_collection_name, "doc_ids": doc_ids, "total_records": len(insert_data), "collection_existed": collection_exists }, "doc_id_mapping": doc_id_mapping, "chunk_id_mapping": chunk_id_mapping } except Exception as e: logger.error(f"插入数据到集合失败: {e}") # 插入失败且是新创建的集合时删除 if not collection_exists: target_milvus_client._delete_collection() return {"code": 500, "message": f"插入数据失败: {str(e)}", "doc_id_mapping": {}, "chunk_id_mapping": {}} except Exception as e: logger.error(f"复制文档到新集合失败: {e}") return {"code": 500, "message": f"复制失败: {str(e)}", "doc_id_mapping": {}, "chunk_id_mapping": {}} # class MysqlOperate: # def get_connection(self): # """ # 从连接池中获取一个连接 # :return: 数据库连接对象 # """ # # try: # # with ThreadPoolExecutor() as executor: # # future = executor.submit(POOL.get_connection) # # connection = future.result(timeout=5.0) # 设置超时时间为5秒 # # logger.info("成功从连接池获取连接") # # return connection, "success" # # except TimeoutError: # # logger.error("获取mysql数据库连接池超时") # # return None, "mysql获取连接池超时" # # except errors.InterfaceError as e: # # logger.error(f"MySQL 接口异常:{e}") # # return None, "mysql接口异常" # # except errors.OperationalError as e: # # logger.error(f"MySQL 操作错误:{e}") # # return None, "mysql 操作错误" # # except Error as e: # # logger.error(f"无法从连接池获取连接: {e}") # # return None, str(e) # connection = None # event = threading.Event() # def target(): # nonlocal connection # try: # connection = POOL.get_connection() # finally: # event.set() # thread = threading.Thread(target=target) # thread.start() # event.wait(timeout=5) # if thread.is_alive(): # # 超时处理 # logger.error("获取连接超时") # return None, "获取连接超时" # else: # if connection: # return connection, "success" # else: # logger.error("获取连接失败") # return None, "获取连接失败" # def insert_to_slice(self, docs, knowledge_id, doc_id): # """ # 插入数据到切片信息表中 slice_info # """ # connection = None # cursor = None # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # values = [] # connection, cennction_info = self.get_connection() # if not connection: # return False, cennction_info # for chunk in docs: # slice_id = chunk.get("chunk_id") # slice_text = chunk.get("content") # chunk_index = chunk.get("metadata").get("chunk_index") # values.append((slice_id, knowledge_id, doc_id, slice_text, date_now, chunk_index)) # try: # cursor = connection.cursor() # # insert_sql = """ # # INSERT INTO slice_info ( # # slice_id, # # knowledge_id, # # document_id, # # slice_text, # # create_time, # # slice_index # # ) VALUES (%s, %s, %s, %s, %s,%s) # # """ # # 容错“for key 'UK_ID_TYPE_KEY'” # insert_sql = """ # INSERT INTO slice_info ( # slice_id, # knowledge_id, # document_id, # slice_text, # create_time, # slice_index # ) VALUES (%s, %s, %s, %s, %s, %s) # ON DUPLICATE KEY UPDATE # slice_text = VALUES(slice_text), # create_time = VALUES(create_time), # slice_index = VALUES(slice_index) # """ # cursor.executemany(insert_sql, values) # connection.commit() # logger.info(f"批量插入切片数据成功。") # return True, "success" # except Error as e: # logger.error(f"数据库操作出错:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() # def delete_to_slice(self, doc_id): # """ # 删除 slice_info库中切片信息 # """ # connection = None # cursor = None # connection, connection_info = self.get_connection() # if not connection: # return False, connection_info # try: # cursor = connection.cursor() # delete_sql = f"DELETE FROM slice_info WHERE document_id = %s" # cursor.execute(delete_sql, (doc_id,)) # connection.commit() # logger.info(f"删除数据成功") # return True, "success" # except Error as e: # logger.error(f"根据{doc_id}删除数据失败:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() # def insert_to_image_url(self, image_dict, knowledge_id, doc_id): # """ # 批量插入数据到指定表 # """ # connection = None # cursor = None # connection, connection_info = self.get_connection() # if not connection: # return False, connection_info # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # values = [] # for img_key, img_value in image_dict.items(): # origin_text = img_key # media_url = img_value # values.append((knowledge_id, doc_id, origin_text, "image", media_url, date_now)) # try: # cursor = connection.cursor() # # insert_sql = """ # # INSERT INTO bm_media_replacement ( # # knowledge_id, # # document_id, # # origin_text, # # media_type, # # media_url, # # create_time # # ) VALUES (%s, %s, %s, %s, %s, %s) # # """ # # 容错“for key 'UK_ID_TYPE_KEY'” # insert_sql = """ # INSERT INTO bm_media_replacement ( # knowledge_id, # document_id, # origin_text, # media_type, # media_url, # create_time # ) VALUES (%s, %s, %s, %s, %s, %s) # ON DUPLICATE KEY UPDATE # origin_text = VALUES(origin_text), # media_type = VALUES(media_type), # media_url = VALUES(media_url), # create_time = VALUES(create_time) # """ # cursor.executemany(insert_sql, values) # connection.commit() # logger.info(f"插入到bm_media_replacement表成功") # return True, "success" # except Error as e: # logger.error(f"数据库操作出错:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() # def delete_image_url(self, doc_id): # """ # 根据doc id删除bm_media_replacement中的数据 # """ # connection = None # cursor = None # connection, connection_info = self.get_connection() # if not connection: # return False, connection_info # try: # cursor = connection.cursor() # delete_sql = f"DELETE FROM bm_media_replacement WHERE document_id = %s" # cursor.execute(delete_sql, (doc_id,)) # connection.commit() # logger.info(f"根据{doc_id} 删除bm_media_replacement表中数据成功") # return True, "success" # except Error as e: # logger.error(f"根据{doc_id}删除 bm_media_replacement 数据库操作出错:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() # def update_total_doc_len(self, update_json): # """ # 更新长度表和文档长度表,删除slice info表, 插入slice info 切片信息 # """ # knowledge_id = update_json.get("knowledge_id") # doc_id = update_json.get("doc_id") # chunk_len = update_json.get("chunk_len") # operate = update_json.get("operate") # chunk_id = update_json.get("chunk_id") # chunk_text = update_json.get("chunk_text") # connection = None # cursor = None # connection, connection_info = self.get_connection() # if not connection: # return False, connection_info # try: # cursor = connection.cursor() # query_doc_word_num_sql = f"select word_num,slice_total from bm_document where document_id = %s" # query_knowledge_word_num_sql = f"select word_num from bm_knowledge where knowledge_id = %s" # cursor.execute(query_doc_word_num_sql, (doc_id,)) # doc_result = cursor.fetchone() # logger.info(f"查询到的文档长度信息:{doc_result}") # cursor.execute(query_knowledge_word_num_sql, (knowledge_id, )) # knowledge_result = cursor.fetchone() # logger.info(f"查询到的知识库总长度信息:{knowledge_result}") # if not doc_result: # new_word_num = 0 # slice_total = 0 # else: # old_word_num = doc_result[0] # slice_total = doc_result[1] # new_word_num = old_word_num + chunk_len # slice_total -= 1 if slice_total else 0 # if not knowledge_result: # new_knowledge_word_num = 0 # else: # old_knowledge_word_num = knowledge_result[0] # new_knowledge_word_num = old_knowledge_word_num + chunk_len # if operate == "update": # update_sql = f"UPDATE bm_document SET word_num = %s WHERE document_id = %s" # cursor.execute(update_sql, (new_word_num, doc_id)) # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # update_slice_sql = f"UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s" # cursor.execute(update_slice_sql, (chunk_text, date_now, chunk_id)) # elif operate == "insert": # query_slice_info_index_sql = f"select MAX(slice_index) from slice_info where document_id = %s" # cursor.execute(query_slice_info_index_sql, (doc_id,)) # chunk_index_result = cursor.fetchone()[0] # # logger.info(chunk_index_result) # if chunk_index_result: # chunk_max_index = int(chunk_index_result) # else: # chunk_max_index = 0 # update_sql = f"UPDATE bm_document SET word_num = %s WHERE document_id = %s" # cursor.execute(update_sql, (new_word_num, doc_id)) # date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # insert_slice_sql = "INSERT INTO slice_info (slice_id,knowledge_id,document_id,slice_text,create_time, slice_index) VALUES (%s, %s, %s, %s, %s, %s)" # cursor.execute(insert_slice_sql, (chunk_id, knowledge_id, doc_id, chunk_text, date_now, chunk_max_index+1)) # else: # update_sql = f"UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s" # cursor.execute(update_sql, (new_word_num, slice_total, doc_id)) # # 删除切片id对应的切片 # delete_slice_sql = f"DELETE FROM slice_info where slice_id = %s" # cursor.execute(delete_slice_sql, (chunk_id, )) # update_knowledge_sql = f"UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s" # cursor.execute(update_knowledge_sql, (new_knowledge_word_num, knowledge_id)) # connection.commit() # logger.info("bm_document和bm_knowledge数据更新成功") # return True, "success" # except Error as e: # logger.error(f"数据库操作出错:{e}") # connection.rollback() # return False, str(e) # finally: # # if cursor: # cursor.close() # # if connection and connection.is_connected(): # connection.close() TABLE_NAME = "bm_document" STATUS_FIELD = "update_by" TASK_ID_FIELD = "document_id" USER_ID_FIELD = "remark" import time # ========== 全局初始化连接池(自动检测 + 超时保护) ========== class SafeMySQLPool: def __init__(self, pool_size=10, conn_timeout=10, idle_timeout=60, **mysql_config): mysql_config.setdefault("connect_timeout", conn_timeout) mysql_config.setdefault("pool_reset_session", True) self._pool = pooling.MySQLConnectionPool( pool_name="safe_mysql_pool", pool_size=pool_size, **mysql_config ) # 使用 RLock 可重入锁,避免 _auto_reclaimer 调用 close 时死锁 self._lock = threading.RLock() self._active_conns = {} # {id(conn): (conn, last_used_time)} self._idle_timeout = idle_timeout self._stop_event = threading.Event() threading.Thread(target=self._auto_reclaimer, daemon=True).start() def get_connection(self, timeout=10): """安全获取连接(带超时检测与追踪)""" start = time.time() while True: try: conn = self._pool.get_connection() conn.ping(reconnect=True, attempts=3, delay=2) with self._lock: self._active_conns[id(conn)] = (conn, time.time()) return self._wrap_connection(conn) except errors.PoolError: if time.time() - start > timeout: raise TimeoutError(f"获取 MySQL 连接超时(超过 {timeout}s)") time.sleep(0.3) def _wrap_connection(self, conn): """包装连接对象以监控关闭事件""" pool = self orig_close = conn.close def safe_close(): try: orig_close() finally: with pool._lock: pool._active_conns.pop(id(conn), None) conn.close = safe_close return conn def _auto_reclaimer(self): """后台线程自动回收超时未关闭连接""" while not self._stop_event.is_set(): time.sleep(5) now = time.time() with self._lock: to_remove = [] for cid, (conn, last_used) in list(self._active_conns.items()): if now - last_used > self._idle_timeout: try: conn.close() logger.warning(f"[回收] 已回收超时未关闭连接 (idle={int(now - last_used)}s)") except Exception as e: logger.error(f"[回收] 回收连接失败: {e}") to_remove.append(cid) for cid in to_remove: self._active_conns.pop(cid, None) def close_all(self): """停止守护线程并关闭所有连接""" self._stop_event.set() with self._lock: for conn, _ in self._active_conns.values(): try: conn.close() except: pass self._active_conns.clear() # ========== 初始化连接池 ========== if "POOL" not in globals(): try: POOL = SafeMySQLPool(pool_size=10, idle_timeout=60, **mysql_config) logger.info("MySQL 连接池初始化成功") except Error as e: logger.error(f"MySQL 连接池初始化失败: {e}") POOL = None """ 雪花算法获取唯一id """ import time import threading class Snowflake: def __init__(self, datacenter_id=1, worker_id=1): self.worker_id_bits = 5 self.datacenter_id_bits = 5 self.sequence_bits = 12 self.max_worker_id = -1 ^ (-1 << self.worker_id_bits) self.max_datacenter_id = -1 ^ (-1 << self.datacenter_id_bits) self.worker_id = worker_id self.datacenter_id = datacenter_id self.sequence = 0 self.worker_id_shift = self.sequence_bits self.datacenter_id_shift = self.sequence_bits + self.worker_id_bits self.timestamp_left_shift = self.sequence_bits + self.worker_id_bits + self.datacenter_id_bits self.twepoch = 1288834974657 self.last_timestamp = -1 self.lock = threading.Lock() def _timestamp(self): return int(time.time() * 1000) def _wait_next_ms(self, last): ts = self._timestamp() while ts <= last: ts = self._timestamp() return ts def generate_id(self): with self.lock: timestamp = self._timestamp() if timestamp < self.last_timestamp: raise Exception("Clock moved backwards, refusing to generate id") if timestamp == self.last_timestamp: self.sequence = (self.sequence + 1) & ((1 << self.sequence_bits) - 1) if self.sequence == 0: timestamp = self._wait_next_ms(timestamp) else: self.sequence = 0 self.last_timestamp = timestamp snowflake_id = ( ((timestamp - self.twepoch) << self.timestamp_left_shift) | (self.datacenter_id << self.datacenter_id_shift) | (self.worker_id << self.worker_id_shift) | self.sequence ) # 转成固定 20 位数字 return str(snowflake_id).zfill(20) # ========== 全局雪花算法 ID 生成器 ========== snowflake_id_generator = Snowflake(datacenter_id=1, worker_id=1) def generate_snowflake_id(): """ 全局方法:使用雪花算法生成唯一ID 返回20位数字字符串 """ return snowflake_id_generator.generate_id() # ========== MysqlOperate 类 ========== class MysqlOperate: def get_connection(self): """安全获取连接""" if not POOL: return None, "连接池未初始化" try: connection = POOL.get_connection(timeout=5) return connection, "success" except TimeoutError as e: logger.error(str(e)) return None, "获取连接超时" except Error as e: logger.error(f"MySQL 获取连接失败: {e}") return None, str(e) def _execute_many(self, sql, values, success_msg, err_msg): """通用批量执行模板""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() cursor.executemany(sql, values) connection.commit() logger.info(success_msg) return True, "success" except Error as e: connection.rollback() logger.error(f"{err_msg}: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def insert_to_slice(self, docs, knowledge_id, doc_id): """批量插入切片信息(同时存储 slice_text 和 old_slice_text)""" date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') values = [ ( chunk.get("chunk_id"), knowledge_id, doc_id, chunk.get("content"), chunk.get("content"), # old_slice_text 存储原始文本 date_now, chunk.get("metadata", {}).get("chunk_index"), chunk.get("Chapter"), chunk.get("Father_Chapter", "") ) for chunk in docs ] sql = """ INSERT INTO slice_info ( slice_id, knowledge_id, document_id, slice_text, old_slice_text, create_time, slice_index, section, parent_section ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE slice_text = VALUES(slice_text), create_time = VALUES(create_time), slice_index = VALUES(slice_index), section = VALUES(section), parent_section = VALUES(parent_section) """ return self._execute_many(sql, values, "批量插入切片数据成功", "插入 slice_info 出错") def delete_to_slice(self, doc_id): """删除切片""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() cursor.execute("DELETE FROM slice_info WHERE document_id = %s", (doc_id,)) connection.commit() logger.info(f"删除 slice_info 数据成功") return True, "success" except Error as e: connection.rollback() logger.error(f"删除 slice_info 出错: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def insert_to_image_url(self, image_dict, knowledge_id, doc_id): """插入图片映射表""" date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') values = [ (knowledge_id, doc_id, k, "image", v, date_now) for k, v in image_dict.items() ] sql = """ INSERT INTO bm_media_replacement ( knowledge_id, document_id, origin_text, media_type, media_url, create_time ) VALUES (%s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE origin_text = VALUES(origin_text), media_type = VALUES(media_type), media_url = VALUES(media_url), create_time = VALUES(create_time) """ return self._execute_many(sql, values, "插入 bm_media_replacement 成功", "插入 bm_media_replacement 出错") def delete_image_url(self, doc_id): """删除图片映射""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() cursor.execute("DELETE FROM bm_media_replacement WHERE document_id = %s", (doc_id,)) connection.commit() logger.info(f"删除 bm_media_replacement 成功") return True, "success" except Error as e: connection.rollback() logger.error(f"删除 bm_media_replacement 出错: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def update_task_status_start(self, task_id): """更新任务状态为开始(1)""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() sql = f"UPDATE {TABLE_NAME} SET {STATUS_FIELD} = %s WHERE {TASK_ID_FIELD} = %s" cursor.execute(sql, (1, task_id)) connection.commit() logger.info(f"任务 {task_id} 状态更新为开始(1)") return True, "success" except Error as e: connection.rollback() logger.error(f"更新任务状态为开始失败: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def update_task_status_complete(self, task_id, user_id): """更新任务状态为完成(2)并更新用户ID""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() sql = f"UPDATE {TABLE_NAME} SET {STATUS_FIELD} = %s, {USER_ID_FIELD} = %s WHERE {TASK_ID_FIELD} = %s" cursor.execute(sql, (2, user_id, task_id)) connection.commit() logger.info(f"任务 {task_id} 状态更新为完成(2),用户ID: {user_id}") return True, "success" except Error as e: connection.rollback() logger.error(f"更新任务状态为完成失败: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def update_task_status_error(self, task_id): """更新任务状态为错误(0)""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() sql = f"UPDATE {TABLE_NAME} SET {STATUS_FIELD} = %s WHERE {TASK_ID_FIELD} = %s" cursor.execute(sql, (0, task_id)) connection.commit() logger.info(f"任务 {task_id} 状态更新为错误(0)") return True, "success" except Error as e: connection.rollback() logger.error(f"更新任务状态为错误失败: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def delete_document(self, task_id): """删除 bm_document 表中的记录(取消任务前清理)""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() sql = f"DELETE FROM {TABLE_NAME} WHERE {TASK_ID_FIELD} = %s" cursor.execute(sql, (task_id,)) affected_rows = cursor.rowcount connection.commit() logger.info(f"删除 bm_document 记录成功: task_id={task_id}, 影响行数={affected_rows}") return True, affected_rows except Error as e: connection.rollback() logger.error(f"删除 bm_document 记录失败: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def query_parent_generation_enabled(self, doc_ids: list): """ 查询文档的 parent_generation_enabled 字段 返回 parent_generation_enabled=1 的 doc_id 集合 """ if not doc_ids: return set() connection, info = self.get_connection() if not connection: return set() cursor = None try: cursor = connection.cursor() placeholders = ", ".join(["%s"] * len(doc_ids)) sql = f"SELECT document_id FROM bm_document WHERE document_id IN ({placeholders}) AND parent_generation_enabled = 1" cursor.execute(sql, tuple(doc_ids)) results = cursor.fetchall() return {row[0] for row in results} except Error as e: logger.error(f"查询 parent_generation_enabled 失败: {e}") return set() finally: if cursor: cursor.close() if connection: connection.close() def insert_oss_record(self, file_name, url, tenant_id="000000", file_extension=""): """插入 OSS 记录到 sys_oss 表""" connection, info = self.get_connection() if not connection: return False, info cursor = None try: # 生成类似 "a2922173479520702464" 的 oss_id 20位 # import random # oss_id = f"a{int(time.time() * 1000)}{random.randint(1000, 9999)}" oss_id = Snowflake().generate_id() create_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') cursor = connection.cursor() sql = """ INSERT INTO sys_oss (oss_id, tenant_id, file_name, original_name, url, create_time, file_suffix) VALUES (%s, %s, %s, %s, %s, %s, %s) """ cursor.execute(sql, (oss_id, tenant_id, file_name, file_name, url, create_time, file_extension)) connection.commit() logger.info(f"OSS 记录插入成功: {oss_id}") return True, oss_id except Error as e: connection.rollback() logger.error(f"插入 OSS 记录失败: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def update_total_doc_len(self, update_json): """ 更新长度表和文档长度表,删除/更新/插入 slice_info 表 """ knowledge_id = update_json.get("knowledge_id") doc_id = update_json.get("doc_id") chunk_len = update_json.get("chunk_len") operate = update_json.get("operate") chunk_id = update_json.get("chunk_id") chunk_text = update_json.get("chunk_text", "") slice_index = update_json.get("slice_index") connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor() # 查询文档当前信息 cursor.execute( "SELECT word_num, slice_total FROM bm_document WHERE document_id = %s", (doc_id,) ) doc_result = cursor.fetchone() logger.info(f"查询到的文档长度信息:{doc_result}") # 查询知识库当前信息 cursor.execute( "SELECT word_num FROM bm_knowledge WHERE knowledge_id = %s", (knowledge_id,) ) knowledge_result = cursor.fetchone() logger.info(f"查询到的知识库总长度信息:{knowledge_result}") # 计算新的文档长度 if not doc_result: new_word_num = chunk_len if chunk_len > 0 else 0 slice_total = 0 else: old_word_num = doc_result[0] or 0 slice_total = doc_result[1] or 0 new_word_num = old_word_num + chunk_len if operate == "delete": slice_total = max(0, slice_total - 1) # 计算新的知识库长度 if not knowledge_result: new_knowledge_word_num = chunk_len if chunk_len > 0 else 0 else: old_knowledge_word_num = knowledge_result[0] or 0 new_knowledge_word_num = old_knowledge_word_num + chunk_len date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') # 根据操作类型执行不同的 SQL if operate == "update": # # 更新文档长度 # cursor.execute( # "UPDATE bm_document SET word_num = %s WHERE document_id = %s", # (new_word_num, doc_id) # ) # 更新切片内容 cursor.execute( "UPDATE slice_info SET slice_text = %s, update_time = %s WHERE slice_id = %s", (chunk_text, date_now, chunk_id) ) elif operate == "insert": # 获取最大切片索引 # cursor.execute( # "SELECT MAX(slice_index) FROM slice_info WHERE document_id = %s", # (doc_id,) # ) # chunk_index_result = cursor.fetchone()[0] # chunk_max_index = int(chunk_index_result) if chunk_index_result else 0 # # 更新文档长度 # cursor.execute( # "UPDATE bm_document SET word_num = %s WHERE document_id = %s", # (new_word_num, doc_id) # ) # 插入新切片 cursor.execute( """INSERT INTO slice_info (slice_id, knowledge_id, document_id, slice_text, create_time, slice_index, old_slice_text) VALUES (%s, %s, %s, %s, %s, %s, %s)""", (chunk_id, knowledge_id, doc_id, chunk_text, date_now, int(slice_index), chunk_text) ) elif operate == "delete": # # 更新文档长度和切片总数 # cursor.execute( # "UPDATE bm_document SET word_num = %s, slice_total = %s WHERE document_id = %s", # (new_word_num, slice_total, doc_id) # ) # 删除切片 # cursor.execute( # "DELETE FROM slice_info WHERE slice_id = %s", # (chunk_id,) # ) cursor.execute( "UPDATE slice_info SET del_flag = 1 WHERE slice_id = %s", (chunk_id,) ) # # 更新知识库总长度 # cursor.execute( # "UPDATE bm_knowledge SET word_num = %s WHERE knowledge_id = %s", # (new_knowledge_word_num, knowledge_id) # ) connection.commit() logger.info("bm_document 和 bm_knowledge 数据更新成功") return True, "success" except Error as e: connection.rollback() logger.error(f"update_total_doc_len 数据库操作出错:{e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def query_slice_info_by_doc_ids(self, knowledge_id: str, doc_ids: list): """ 根据 knowledge_id 和 doc_ids 列表查询 slice_info 表中的数据 参数: knowledge_id: 知识库ID doc_ids: 文档ID列表 返回: (success, data_or_error): 成功时返回数据列表,失败时返回错误信息 """ if not doc_ids: return True, [] connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor(dictionary=True) placeholders = ','.join(['%s'] * len(doc_ids)) sql = f""" SELECT slice_id, knowledge_id, document_id, slice_text, old_slice_text, create_time, update_time, slice_index FROM slice_info WHERE knowledge_id = %s AND document_id IN ({placeholders}) """ cursor.execute(sql, (knowledge_id, *doc_ids)) results = cursor.fetchall() logger.info(f"查询 slice_info 成功,共 {len(results)} 条记录") return True, results except Error as e: logger.error(f"查询 slice_info 出错: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def query_media_replacement_by_doc_ids(self, knowledge_id: str, doc_ids: list): """ 根据 knowledge_id 和 doc_ids 列表查询 bm_media_replacement 表中的数据 参数: knowledge_id: 知识库ID doc_ids: 文档ID列表 返回: (success, data_or_error): 成功时返回数据列表,失败时返回错误信息 """ if not doc_ids: return True, [] connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor(dictionary=True) placeholders = ','.join(['%s'] * len(doc_ids)) sql = f""" SELECT knowledge_id, document_id, origin_text, media_type, media_url, create_time FROM bm_media_replacement WHERE knowledge_id = %s AND document_id IN ({placeholders}) """ cursor.execute(sql, (knowledge_id, *doc_ids)) results = cursor.fetchall() logger.info(f"查询 bm_media_replacement 成功,共 {len(results)} 条记录") return True, results except Error as e: logger.error(f"查询 bm_media_replacement 出错: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def query_bm_document_by_doc_ids(self, knowledge_id: str, doc_ids: list): """ 根据 knowledge_id 和 doc_ids 列表查询 bm_document 表中的数据 参数: knowledge_id: 知识库ID doc_ids: 文档ID列表 返回: (success, data_or_error): 成功时返回数据列表,失败时返回错误信息 """ if not doc_ids: return True, [] connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor(dictionary=True) placeholders = ','.join(['%s'] * len(doc_ids)) sql = f""" SELECT * FROM bm_document WHERE knowledge_id = %s AND document_id IN ({placeholders}) """ cursor.execute(sql, (knowledge_id, *doc_ids)) results = cursor.fetchall() logger.info(f"查询 bm_document 成功,共 {len(results)} 条记录") return True, results except Error as e: logger.error(f"查询 bm_document 出错: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def copy_docs_metadata_to_new_knowledge(self, source_knowledge_id: str, source_doc_ids: list, new_knowledge_id: str, doc_id_mapping: dict = None, chunk_id_mapping: dict = None, tenant_id: str = "000000"): """ 复制文档元数据到新知识库(bm_document、slice_info 和 bm_media_replacement) 参数: source_knowledge_id: 源知识库ID source_doc_ids: 源文档ID列表 new_knowledge_id: 新的知识库ID doc_id_mapping: 旧文档ID到新文档ID的映射 {old_doc_id: new_doc_id} 如果不提供,将为每个文档生成新的雪花算法ID chunk_id_mapping: 旧切片ID到新切片ID的映射 {old_chunk_id: new_chunk_id} 用于 slice_id,保持与向量库一致 tenant_id: 租户ID(由前端传入) 返回: (success, result_dict): result_dict 包含复制结果详情 """ if not source_doc_ids: return True, {"message": "无需复制的文档", "doc_id_mapping": {}} # 如果没有提供映射,为每个源文档生成新的 doc_id if doc_id_mapping is None: doc_id_mapping = {} for old_doc_id in source_doc_ids: doc_id_mapping[old_doc_id] = generate_snowflake_id() # 1. 查询源 slice_info 数据 success, slice_data = self.query_slice_info_by_doc_ids(source_knowledge_id, source_doc_ids) if not success: return False, {"error": f"查询切片数据失败: {slice_data}"} # 2. 查询源 bm_media_replacement 数据 success, media_data = self.query_media_replacement_by_doc_ids(source_knowledge_id, source_doc_ids) if not success: return False, {"error": f"查询媒体映射失败: {media_data}"} # 3. 查询源 bm_document 数据 success, doc_data = self.query_bm_document_by_doc_ids(source_knowledge_id, source_doc_ids) if not success: return False, {"error": f"查询文档数据失败: {doc_data}"} connection, info = self.get_connection() if not connection: return False, {"error": info} cursor = None try: cursor = connection.cursor() date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') slice_count = 0 media_count = 0 doc_count = 0 # 4. 插入 bm_document if doc_data: # 定义所有字段及特殊处理 all_fields = [ 'document_id', 'knowledge_id', 'custom_separator', 'sentence_size', 'length', 'word_num', 'slice_total', 'name', 'url', 'parse_image', 'tenant_id', 'create_dept', 'create_by', 'create_time', 'update_by', 'update_time', 'remark', 'parsing_type', 'oss_id', 'status', 'mark_oss_id', 'mark_url', 'ref_document_id', 'qa_checked', 'related_questions_enabled', 'summary_generation_enabled', 'parent_generation_enabled', 'suffix', 'pdf_url' ] doc_values = [] for row in doc_data: old_doc_id = row['document_id'] new_doc_id = doc_id_mapping.get(old_doc_id, generate_snowflake_id()) # 特殊处理字段 special = { 'document_id': new_doc_id, 'knowledge_id': new_knowledge_id, 'tenant_id': tenant_id, 'create_time': date_now, 'ref_document_id': old_doc_id } doc_values.append(tuple(special.get(f, row.get(f)) for f in all_fields)) placeholders = ','.join(['%s'] * len(all_fields)) doc_sql = f"INSERT INTO bm_document ({','.join(all_fields)}) VALUES ({placeholders})" cursor.executemany(doc_sql, doc_values) doc_count = len(doc_values) logger.info(f"插入文档数据成功: {doc_count} 条") # 5. 插入 slice_info(使用 old_slice_text 存储原始文本) if slice_data: slice_values = [] for row in slice_data: old_doc_id = row['document_id'] old_slice_id = row['slice_id'] new_doc_id = doc_id_mapping.get(old_doc_id, generate_snowflake_id()) # 使用向量库中的 chunk_id 作为 slice_id(保持一致) new_slice_id = chunk_id_mapping.get(old_slice_id, generate_snowflake_id()) if chunk_id_mapping else generate_snowflake_id() # old_slice_text 存储原始文本 old_slice_text = row.get('old_slice_text') or row['slice_text'] slice_values.append(( new_slice_id, new_knowledge_id, new_doc_id, row['slice_text'], old_slice_text, date_now, row.get('slice_index') )) slice_sql = """ INSERT INTO slice_info ( slice_id, knowledge_id, document_id, slice_text, old_slice_text, create_time, slice_index ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ cursor.executemany(slice_sql, slice_values) slice_count = len(slice_values) logger.info(f"插入切片数据成功: {slice_count} 条") # 6. 插入 bm_media_replacement if media_data: media_values = [] for row in media_data: old_doc_id = row['document_id'] new_doc_id = doc_id_mapping.get(old_doc_id, generate_snowflake_id()) media_values.append(( new_knowledge_id, new_doc_id, row['origin_text'], row['media_type'], row['media_url'], date_now )) media_sql = """ INSERT INTO bm_media_replacement ( knowledge_id, document_id, origin_text, media_type, media_url, create_time ) VALUES (%s, %s, %s, %s, %s, %s) """ cursor.executemany(media_sql, media_values) media_count = len(media_values) logger.info(f"插入媒体映射数据成功: {media_count} 条") connection.commit() return True, { "message": "复制元数据成功", "doc_id_mapping": doc_id_mapping, "doc_count": doc_count, "slice_count": slice_count, "media_count": media_count } except Error as e: connection.rollback() logger.error(f"复制文档元数据失败: {e}") return False, {"error": str(e)} finally: if cursor: cursor.close() if connection: connection.close() def query_slice_by_id(self, knowledge_id: str, slice_id: str): """ 根据 knowledge_id 和 slice_id 查询单个切片数据 """ connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor(dictionary=True) sql = """ SELECT slice_id, knowledge_id, document_id, slice_text, old_slice_text, slice_index FROM slice_info WHERE knowledge_id = %s AND slice_id = %s """ cursor.execute(sql, (knowledge_id, slice_id)) results = cursor.fetchall() return True, results except Error as e: logger.error(f"查询 slice_info 出错: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def query_slice_by_knowledge_and_doc(self, knowledge_id: str, document_id: str): """ 根据 knowledge_id 和 document_id 查询 slice_info 表中的切片数据 """ connection, info = self.get_connection() if not connection: return False, info cursor = None try: cursor = connection.cursor(dictionary=True) sql = """ SELECT slice_id, knowledge_id, document_id, slice_text, old_slice_text, slice_index FROM slice_info WHERE knowledge_id = %s AND document_id = %s """ cursor.execute(sql, (knowledge_id, document_id)) results = cursor.fetchall() logger.info(f"查询 slice_info 成功,共 {len(results)} 条记录") return True, results except Error as e: logger.error(f"查询 slice_info 出错: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close() def update_slice_llm_fields(self, knowledge_id: str, slice_id: str, qa: str = None, question: str = None, summary: str = None): """ 更新 slice_info 表中的 qa、question、summary 字段,值为 None 的字段不更新 基于 knowledge_id 和 slice_id 的包含关系进行更新 """ connection, info = self.get_connection() if not connection: return False, info cursor = None try: # 动态构建 SET 子句,只更新非 None 字段 set_parts = [] params = [] if qa is not None: set_parts.append("qa = %s") params.append(qa) if question is not None: set_parts.append("question = %s") params.append(question) if summary is not None: set_parts.append("summary = %s") params.append(summary) if not set_parts: return True, "no fields to update" date_now = datetime.now().strftime('%Y-%m-%d %H:%M:%S') set_parts.append("update_time = %s") params.append(date_now) params.append(knowledge_id) params.append(slice_id) cursor = connection.cursor() sql = f"UPDATE slice_info SET {', '.join(set_parts)} WHERE knowledge_id = %s AND slice_id = %s" cursor.execute(sql, tuple(params)) connection.commit() logger.info(f"更新切片 {slice_id} 的LLM字段成功") return True, "success" except Error as e: connection.rollback() logger.error(f"更新 slice_info LLM字段出错: {e}") return False, str(e) finally: if cursor: cursor.close() if connection: connection.close()